diff --git a/.bazelrc b/.bazelrc index ce8406b58aaab..f8ff2215f2d6b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,4 +1,4 @@ -build --cxxopt=--std=c++14 +build --cxxopt=--std=c++17 build --copt=-I. # Bazel does not support including its cc_library targets as system # headers. We work around this for generated code diff --git a/.circleci/README.md b/.circleci/README.md new file mode 100644 index 0000000000000..e2429b4d1f037 --- /dev/null +++ b/.circleci/README.md @@ -0,0 +1,468 @@ +Warning +======= + +Contents may be out of date. Our CircleCI workflows are gradually being migrated to Github actions. + +Structure of CI +=============== + +setup job: +1. Does a git checkout +2. Persists CircleCI scripts (everything in `.circleci`) into a workspace. Why? + We don't always do a Git checkout on all subjobs, but we usually + still want to be able to call scripts one way or another in a subjob. + Persisting files this way lets us have access to them without doing a + checkout. This workspace is conventionally mounted on `~/workspace` + (this is distinguished from `~/project`, which is the conventional + working directory that CircleCI will default to starting your jobs + in.) +3. Write out the commit message to `.circleci/COMMIT_MSG`. This is so + we can determine in subjobs if we should actually run the jobs or + not, even if there isn't a Git checkout. + + +CircleCI configuration generator +================================ + +One may no longer make changes to the `.circleci/config.yml` file directly. +Instead, one must edit these Python scripts or files in the `verbatim-sources/` directory. + + +Usage +---------- + +1. Make changes to these scripts. +2. Run the `regenerate.sh` script in this directory and commit the script changes and the resulting change to `config.yml`. + +You'll see a build failure on GitHub if the scripts don't agree with the checked-in version. + + +Motivation +---------- + +These scripts establish a single, authoritative source of documentation for the CircleCI configuration matrix. +The documentation, in the form of diagrams, is automatically generated and cannot drift out of sync with the YAML content. + +Furthermore, consistency is enforced within the YAML config itself, by using a single source of data to generate +multiple parts of the file. + +* Facilitates one-off culling/enabling of CI configs for testing PRs on special targets + +Also see https://github.com/pytorch/pytorch/issues/17038 + + +Future direction +---------------- + +### Declaring sparse config subsets +See comment [here](https://github.com/pytorch/pytorch/pull/17323#pullrequestreview-206945747): + +In contrast with a full recursive tree traversal of configuration dimensions, +> in the future I think we actually want to decrease our matrix somewhat and have only a few mostly-orthogonal builds that taste as many different features as possible on PRs, plus a more complete suite on every PR and maybe an almost full suite nightly/weekly (we don't have this yet). Specifying PR jobs in the future might be easier to read with an explicit list when we come to this. +---------------- +---------------- + +# How do the binaries / nightlies / releases work? + +### What is a binary? + +A binary or package (used interchangeably) is a pre-built collection of c++ libraries, header files, python bits, and other files. We build these and distribute them so that users do not need to install from source. + +A **binary configuration** is a collection of + +* release or nightly + * releases are stable, nightlies are beta and built every night +* python version + * linux: 3.7m (mu is wide unicode or something like that. It usually doesn't matter but you should know that it exists) + * macos: 3.7, 3.8 + * windows: 3.7, 3.8 +* cpu version + * cpu, cuda 9.0, cuda 10.0 + * The supported cuda versions occasionally change +* operating system + * Linux - these are all built on CentOS. There haven't been any problems in the past building on CentOS and using on Ubuntu + * MacOS + * Windows - these are built on Azure pipelines +* devtoolset version (gcc compiler version) + * This only matters on Linux cause only Linux uses gcc. tldr is gcc made a backwards incompatible change from gcc 4.8 to gcc 5, because it had to change how it implemented std::vector and std::string + +### Where are the binaries? + +The binaries are built in CircleCI. There are nightly binaries built every night at 9pm PST (midnight EST) and release binaries corresponding to Pytorch releases, usually every few months. + +We have 3 types of binary packages + +* pip packages - nightlies are stored on s3 (pip install -f \). releases are stored in a pip repo (pip install torch) (ask Soumith about this) +* conda packages - nightlies and releases are both stored in a conda repo. Nighty packages have a '_nightly' suffix +* libtorch packages - these are zips of all the c++ libraries, header files, and sometimes dependencies. These are c++ only + * shared with dependencies (the only supported option for Windows) + * static with dependencies + * shared without dependencies + * static without dependencies + +All binaries are built in CircleCI workflows except Windows. There are checked-in workflows (committed into the .circleci/config.yml) to build the nightlies every night. Releases are built by manually pushing a PR that builds the suite of release binaries (overwrite the config.yml to build the release) + +# CircleCI structure of the binaries + +Some quick vocab: + +* A \**workflow** is a CircleCI concept; it is a DAG of '**jobs**'. ctrl-f 'workflows' on https://github.com/pytorch/pytorch/blob/master/.circleci/config.yml to see the workflows. +* **jobs** are a sequence of '**steps**' +* **steps** are usually just a bash script or a builtin CircleCI command. *All steps run in new environments, environment variables declared in one script DO NOT persist to following steps* +* CircleCI has a **workspace**, which is essentially a cache between steps of the *same job* in which you can store artifacts between steps. + +## How are the workflows structured? + +The nightly binaries have 3 workflows. We have one job (actually 3 jobs: build, test, and upload) per binary configuration + +1. binary_builds + 1. every day midnight EST + 2. linux: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/linux-binary-build-defaults.yml + 3. macos: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/macos-binary-build-defaults.yml + 4. For each binary configuration, e.g. linux_conda_3.7_cpu there is a + 1. binary_linux_conda_3.7_cpu_build + 1. Builds the build. On linux jobs this uses the 'docker executor'. + 2. Persists the package to the workspace + 2. binary_linux_conda_3.7_cpu_test + 1. Loads the package to the workspace + 2. Spins up a docker image (on Linux), mapping the package and code repos into the docker + 3. Runs some smoke tests in the docker + 4. (Actually, for macos this is a step rather than a separate job) + 3. binary_linux_conda_3.7_cpu_upload + 1. Logs in to aws/conda + 2. Uploads the package +2. update_s3_htmls + 1. every day 5am EST + 2. https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/binary_update_htmls.yml + 3. See below for what these are for and why they're needed + 4. Three jobs that each examine the current contents of aws and the conda repo and update some html files in s3 +3. binarysmoketests + 1. every day + 2. https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-build-smoke-tests-defaults.yml + 3. For each binary configuration, e.g. linux_conda_3.7_cpu there is a + 1. smoke_linux_conda_3.7_cpu + 1. Downloads the package from the cloud, e.g. using the official pip or conda instructions + 2. Runs the smoke tests + +## How are the jobs structured? + +The jobs are in https://github.com/pytorch/pytorch/tree/master/.circleci/verbatim-sources. Jobs are made of multiple steps. There are some shared steps used by all the binaries/smokes. Steps of these jobs are all delegated to scripts in https://github.com/pytorch/pytorch/tree/master/.circleci/scripts . + +* Linux jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/linux-binary-build-defaults.yml + * binary_linux_build.sh + * binary_linux_test.sh + * binary_linux_upload.sh +* MacOS jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/macos-binary-build-defaults.yml + * binary_macos_build.sh + * binary_macos_test.sh + * binary_macos_upload.sh +* Update html jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/binary_update_htmls.yml + * These delegate from the pytorch/builder repo + * https://github.com/pytorch/builder/blob/master/cron/update_s3_htmls.sh + * https://github.com/pytorch/builder/blob/master/cron/upload_binary_sizes.sh +* Smoke jobs (both linux and macos): https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-build-smoke-tests-defaults.yml + * These delegate from the pytorch/builder repo + * https://github.com/pytorch/builder/blob/master/run_tests.sh + * https://github.com/pytorch/builder/blob/master/smoke_test.sh + * https://github.com/pytorch/builder/blob/master/check_binary.sh +* Common shared code (shared across linux and macos): https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-binary-build-defaults.yml + * binary_checkout.sh - checks out pytorch/builder repo. Right now this also checks out pytorch/pytorch, but it shouldn't. pytorch/pytorch should just be shared through the workspace. This can handle being run before binary_populate_env.sh + * binary_populate_env.sh - parses BUILD_ENVIRONMENT into the separate env variables that make up a binary configuration. Also sets lots of default values, the date, the version strings, the location of folders in s3, all sorts of things. This generally has to be run before other steps. + * binary_install_miniconda.sh - Installs miniconda, cross platform. Also hacks this for the update_binary_sizes job that doesn't have the right env variables + * binary_run_in_docker.sh - Takes a bash script file (the actual test code) from a hardcoded location, spins up a docker image, and runs the script inside the docker image + +### **Why do the steps all refer to scripts?** + +CircleCI creates a final yaml file by inlining every <<* segment, so if we were to keep all the code in the config.yml itself then the config size would go over 4 MB and cause infra problems. + +### **What is binary_run_in_docker for?** + +So, CircleCI has several executor types: macos, machine, and docker are the ones we use. The 'machine' executor gives you two cores on some linux vm. The 'docker' executor gives you considerably more cores (nproc was 32 instead of 2 back when I tried in February). Since the dockers are faster, we try to run everything that we can in dockers. Thus + +* linux build jobs use the docker executor. Running them on the docker executor was at least 2x faster than running them on the machine executor +* linux test jobs use the machine executor in order for them to properly interface with GPUs since docker executors cannot execute with attached GPUs +* linux upload jobs use the machine executor. The upload jobs are so short that it doesn't really matter what they use +* linux smoke test jobs use the machine executor for the same reason as the linux test jobs + +binary_run_in_docker.sh is a way to share the docker start-up code between the binary test jobs and the binary smoke test jobs + +### **Why does binary_checkout also checkout pytorch? Why shouldn't it?** + +We want all the nightly binary jobs to run on the exact same git commit, so we wrote our own checkout logic to ensure that the same commit was always picked. Later circleci changed that to use a single pytorch checkout and persist it through the workspace (they did this because our config file was too big, so they wanted to take a lot of the setup code into scripts, but the scripts needed the code repo to exist to be called, so they added a prereq step called 'setup' to checkout the code and persist the needed scripts to the workspace). The changes to the binary jobs were not properly tested, so they all broke from missing pytorch code no longer existing. We hotfixed the problem by adding the pytorch checkout back to binary_checkout, so now there's two checkouts of pytorch on the binary jobs. This problem still needs to be fixed, but it takes careful tracing of which code is being called where. + +# Code structure of the binaries (circleci agnostic) + +## Overview + +The code that runs the binaries lives in two places, in the normal [github.com/pytorch/pytorch](http://github.com/pytorch/pytorch), but also in [github.com/pytorch/builder](http://github.com/pytorch/builder), which is a repo that defines how all the binaries are built. The relevant code is + + +``` +# All code needed to set-up environments for build code to run in, +# but only code that is specific to the current CI system +pytorch/pytorch +- .circleci/ # Folder that holds all circleci related stuff + - config.yml # GENERATED file that actually controls all circleci behavior + - verbatim-sources # Used to generate job/workflow sections in ^ + - scripts/ # Code needed to prepare circleci environments for binary build scripts +- setup.py # Builds pytorch. This is wrapped in pytorch/builder +- cmake files # used in normal building of pytorch +# All code needed to prepare a binary build, given an environment +# with all the right variables/packages/paths. +pytorch/builder +# Given an installed binary and a proper python env, runs some checks +# to make sure the binary was built the proper way. Checks things like +# the library dependencies, symbols present, etc. +- check_binary.sh +# Given an installed binary, runs python tests to make sure everything +# is in order. These should be de-duped. Right now they both run smoke +# tests, but are called from different places. Usually just call some +# import statements, but also has overlap with check_binary.sh above +- run_tests.sh +- smoke_test.sh +# Folders that govern how packages are built. See paragraphs below +- conda/ + - build_pytorch.sh # Entrypoint. Delegates to proper conda build folder + - switch_cuda_version.sh # Switches activate CUDA installation in Docker + - pytorch-nightly/ # Build-folder +- manywheel/ + - build_cpu.sh # Entrypoint for cpu builds + - build.sh # Entrypoint for CUDA builds + - build_common.sh # Actual build script that ^^ call into +- wheel/ + - build_wheel.sh # Entrypoint for wheel builds +- windows/ + - build_pytorch.bat # Entrypoint for wheel builds on Windows +``` + +Every type of package has an entrypoint build script that handles the all the important logic. + +## Conda + +Linux, MacOS and Windows use the same code flow for the conda builds. + +Conda packages are built with conda-build, see https://conda.io/projects/conda-build/en/latest/resources/commands/conda-build.html + +Basically, you pass `conda build` a build folder (pytorch-nightly/ above) that contains a build script and a meta.yaml. The meta.yaml specifies in what python environment to build the package in, and what dependencies the resulting package should have, and the build script gets called in the env to build the thing. +tl;dr on conda-build is + +1. Creates a brand new conda environment, based off of deps in the meta.yaml + 1. Note that environment variables do not get passed into this build env unless they are specified in the meta.yaml + 2. If the build fails this environment will stick around. You can activate it for much easier debugging. The “General Python” section below explains what exactly a python “environment” is. +2. Calls build.sh in the environment +3. Copies the finished package to a new conda env, also specified by the meta.yaml +4. Runs some simple import tests (if specified in the meta.yaml) +5. Saves the finished package as a tarball + +The build.sh we use is essentially a wrapper around `python setup.py build`, but it also manually copies in some of our dependent libraries into the resulting tarball and messes with some rpaths. + +The entrypoint file `builder/conda/build_conda.sh` is complicated because + +* It works for Linux, MacOS and Windows + * The mac builds used to create their own environments, since they all used to be on the same machine. There’s now a lot of extra logic to handle conda envs. This extra machinery could be removed +* It used to handle testing too, which adds more logic messing with python environments too. This extra machinery could be removed. + +## Manywheels (linux pip and libtorch packages) + +Manywheels are pip packages for linux distros. Note that these manywheels are not actually manylinux compliant. + +`builder/manywheel/build_cpu.sh` and `builder/manywheel/build.sh` (for CUDA builds) just set different env vars and then call into `builder/manywheel/build_common.sh` + +The entrypoint file `builder/manywheel/build_common.sh` is really really complicated because + +* This used to handle building for several different python versions at the same time. The loops have been removed, but there's still unnecessary folders and movements here and there. + * The script is never used this way anymore. This extra machinery could be removed. +* This used to handle testing the pip packages too. This is why there’s testing code at the end that messes with python installations and stuff + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * This should really be separate. libtorch packages are c++ only and have no python. They should not share infra with all the python specific stuff in this file. +* There is a lot of messing with rpaths. This is necessary, but could be made much much simpler if the above issues were fixed. + +## Wheels (MacOS pip and libtorch packages) + +The entrypoint file `builder/wheel/build_wheel.sh` is complicated because + +* The mac builds used to all run on one machine (we didn’t have autoscaling mac machines till circleci). So this script handled siloing itself by setting-up and tearing-down its build env and siloing itself into its own build directory. + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * Ditto the comment above. This should definitely be separated out. + +Note that the MacOS Python wheels are still built in conda environments. Some of the dependencies present during build also come from conda. + +## Windows Wheels (Windows pip and libtorch packages) + +The entrypoint file `builder/windows/build_pytorch.bat` is complicated because + +* This used to handle building for several different python versions at the same time. This is why there are loops everywhere + * The script is never used this way anymore. This extra machinery could be removed. +* This used to handle testing the pip packages too. This is why there’s testing code at the end that messes with python installations and stuff + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * This should really be separate. libtorch packages are c++ only and have no python. They should not share infra with all the python specific stuff in this file. + +Note that the Windows Python wheels are still built in conda environments. Some of the dependencies present during build also come from conda. + +## General notes + +### Note on run_tests.sh, smoke_test.sh, and check_binary.sh + +* These should all be consolidated +* These must run on all OS types: MacOS, Linux, and Windows +* These all run smoke tests at the moment. They inspect the packages some, maybe run a few import statements. They DO NOT run the python tests nor the cpp tests. The idea is that python tests on master and PR merges will catch all breakages. All these tests have to do is make sure the special binary machinery didn’t mess anything up. +* There are separate run_tests.sh and smoke_test.sh because one used to be called by the smoke jobs and one used to be called by the binary test jobs (see circleci structure section above). This is still true actually, but these could be united into a single script that runs these checks, given an installed pytorch package. + +### Note on libtorch + +Libtorch packages are built in the wheel build scripts: manywheel/build_*.sh for linux and build_wheel.sh for mac. There are several things wrong with this + +* It’s confusing. Most of those scripts deal with python specifics. +* The extra conditionals everywhere severely complicate the wheel build scripts +* The process for building libtorch is different from the official instructions (a plain call to cmake, or a call to a script) + +### Note on docker images / Dockerfiles + +All linux builds occur in docker images. The docker images are + +* pytorch/conda-cuda + * Has ALL CUDA versions installed. The script pytorch/builder/conda/switch_cuda_version.sh sets /usr/local/cuda to a symlink to e.g. /usr/local/cuda-10.0 to enable different CUDA builds + * Also used for cpu builds +* pytorch/manylinux-cuda90 +* pytorch/manylinux-cuda100 + * Also used for cpu builds + +The Dockerfiles are available in pytorch/builder, but there is no circleci job or script to build these docker images, and they cannot be run locally (unless you have the correct local packages/paths). Only Soumith can build them right now. + +### General Python + +* This is still a good explanation of python installations https://caffe2.ai/docs/faq.html#why-do-i-get-import-errors-in-python-when-i-try-to-use-caffe2 + +# How to manually rebuild the binaries + +tl;dr make a PR that looks like https://github.com/pytorch/pytorch/pull/21159 + +Sometimes we want to push a change to master and then rebuild all of today's binaries after that change. As of May 30, 2019 there isn't a way to manually run a workflow in the UI. You can manually re-run a workflow, but it will use the exact same git commits as the first run and will not include any changes. So we have to make a PR and then force circleci to run the binary workflow instead of the normal tests. The above PR is an example of how to do this; essentially you copy-paste the binarybuilds workflow steps into the default workflow steps. If you need to point the builder repo to a different commit then you'd need to change https://github.com/pytorch/pytorch/blob/master/.circleci/scripts/binary_checkout.sh#L42-L45 to checkout what you want. + +## How to test changes to the binaries via .circleci + +Writing PRs that test the binaries is annoying, since the default circleci jobs that run on PRs are not the jobs that you want to run. Likely, changes to the binaries will touch something under .circleci/ and require that .circleci/config.yml be regenerated (.circleci/config.yml controls all .circleci behavior, and is generated using `.circleci/regenerate.sh` in python 3.7). But you also need to manually hardcode the binary jobs that you want to test into the .circleci/config.yml workflow, so you should actually make at least two commits, one for your changes and one to temporarily hardcode jobs. See https://github.com/pytorch/pytorch/pull/22928 as an example of how to do this. + +```sh +# Make your changes +touch .circleci/verbatim-sources/nightly-binary-build-defaults.yml +# Regenerate the yaml, has to be in python 3.7 +.circleci/regenerate.sh +# Make a commit +git add .circleci * +git commit -m "My real changes" +git push origin my_branch +# Now hardcode the jobs that you want in the .circleci/config.yml workflows section +# Also eliminate ensure-consistency and should_run_job checks +# e.g. https://github.com/pytorch/pytorch/commit/2b3344bfed8772fe86e5210cc4ee915dee42b32d +# Make a commit you won't keep +git add .circleci +git commit -m "[DO NOT LAND] testing binaries for above changes" +git push origin my_branch +# Now you need to make some changes to the first commit. +git rebase -i HEAD~2 # mark the first commit as 'edit' +# Make the changes +touch .circleci/verbatim-sources/nightly-binary-build-defaults.yml +.circleci/regenerate.sh +# Ammend the commit and recontinue +git add .circleci +git commit --amend +git rebase --continue +# Update the PR, need to force since the commits are different now +git push origin my_branch --force +``` + +The advantage of this flow is that you can make new changes to the base commit and regenerate the .circleci without having to re-write which binary jobs you want to test on. The downside is that all updates will be force pushes. + +## How to build a binary locally + +### Linux + +You can build Linux binaries locally easily using docker. + +```sh +# Run the docker +# Use the correct docker image, pytorch/conda-cuda used here as an example +# +# -v path/to/foo:path/to/bar makes path/to/foo on your local machine (the +# machine that you're running the command on) accessible to the docker +# container at path/to/bar. So if you then run `touch path/to/bar/baz` +# in the docker container then you will see path/to/foo/baz on your local +# machine. You could also clone the pytorch and builder repos in the docker. +# +# If you know how, add ccache as a volume too and speed up everything +docker run \ + -v your/pytorch/repo:/pytorch \ + -v your/builder/repo:/builder \ + -v where/you/want/packages/to/appear:/final_pkgs \ + -it pytorch/conda-cuda /bin/bash +# Export whatever variables are important to you. All variables that you'd +# possibly need are in .circleci/scripts/binary_populate_env.sh +# You should probably always export at least these 3 variables +export PACKAGE_TYPE=conda +export DESIRED_PYTHON=3.7 +export DESIRED_CUDA=cpu +# Call the entrypoint +# `|& tee foo.log` just copies all stdout and stderr output to foo.log +# The builds generate lots of output so you probably need this when +# building locally. +/builder/conda/build_pytorch.sh |& tee build_output.log +``` + +**Building CUDA binaries on docker** + +You can build CUDA binaries on CPU only machines, but you can only run CUDA binaries on CUDA machines. This means that you can build a CUDA binary on a docker on your laptop if you so choose (though it’s gonna take a long time). + +For Facebook employees, ask about beefy machines that have docker support and use those instead of your laptop; it will be 5x as fast. + +### MacOS + +There’s no easy way to generate reproducible hermetic MacOS environments. If you have a Mac laptop then you can try emulating the .circleci environments as much as possible, but you probably have packages in /usr/local/, possibly installed by brew, that will probably interfere with the build. If you’re trying to repro an error on a Mac build in .circleci and you can’t seem to repro locally, then my best advice is actually to iterate on .circleci :/ + +But if you want to try, then I’d recommend + +```sh +# Create a new terminal +# Clear your LD_LIBRARY_PATH and trim as much out of your PATH as you +# know how to do +# Install a new miniconda +# First remove any other python or conda installation from your PATH +# Always install miniconda 3, even if building for Python <3 +new_conda="~/my_new_conda" +conda_sh="$new_conda/install_miniconda.sh" +curl -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +chmod +x "$conda_sh" +"$conda_sh" -b -p "$MINICONDA_ROOT" +rm -f "$conda_sh" +export PATH="~/my_new_conda/bin:$PATH" +# Create a clean python env +# All MacOS builds use conda to manage the python env and dependencies +# that are built with, even the pip packages +conda create -yn binary python=2.7 +conda activate binary +# Export whatever variables are important to you. All variables that you'd +# possibly need are in .circleci/scripts/binary_populate_env.sh +# You should probably always export at least these 3 variables +export PACKAGE_TYPE=conda +export DESIRED_PYTHON=3.7 +export DESIRED_CUDA=cpu +# Call the entrypoint you want +path/to/builder/wheel/build_wheel.sh +``` + +N.B. installing a brand new miniconda is important. This has to do with how conda installations work. See the “General Python” section above, but tldr; is that + +1. You make the ‘conda’ command accessible by prepending `path/to/conda_root/bin` to your PATH. +2. You make a new env and activate it, which then also gets prepended to your PATH. Now you have `path/to/conda_root/envs/new_env/bin:path/to/conda_root/bin:$PATH` +3. Now say you (or some code that you ran) call python executable `foo` + 1. if you installed `foo` in `new_env`, then `path/to/conda_root/envs/new_env/bin/foo` will get called, as expected. + 2. But if you forgot to installed `foo` in `new_env` but happened to previously install it in your root conda env (called ‘base’), then unix/linux will still find `path/to/conda_root/bin/foo` . This is dangerous, since `foo` can be a different version than you want; `foo` can even be for an incompatible python version! + +Newer conda versions and proper python hygiene can prevent this, but just install a new miniconda to be safe. + +### Windows + +TODO: fill in diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 7633f1eacac09..ebea9eda85a6a 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -33,7 +33,7 @@ function extract_all_from_image_name() { if [ "x${name}" = xpy ]; then vername=ANACONDA_PYTHON_VERSION fi - # skip non-conforming fields such as "pytorch", "linux" or "xenial" without version string + # skip non-conforming fields such as "pytorch", "linux" or "bionic" without version string if [ -n "${name}" ]; then extract_version_from_image_name "${name}" "${vername}" fi @@ -46,11 +46,7 @@ if [[ "$image" == *xla* ]]; then exit 0 fi -if [[ "$image" == *-xenial* ]]; then - UBUNTU_VERSION=16.04 -elif [[ "$image" == *-artful* ]]; then - UBUNTU_VERSION=17.10 -elif [[ "$image" == *-bionic* ]]; then +if [[ "$image" == *-bionic* ]]; then UBUNTU_VERSION=18.04 elif [[ "$image" == *-focal* ]]; then UBUNTU_VERSION=20.04 @@ -79,56 +75,17 @@ elif [[ "$image" == *rocm* ]]; then DOCKERFILE="${OS}-rocm/Dockerfile" fi -if [[ "$image" == *xenial* ]] || [[ "$image" == *bionic* ]]; then - CMAKE_VERSION=3.13.5 -fi +# CMake 3.18 is needed to support CUDA17 language variant +CMAKE_VERSION=3.18.5 TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/ubuntu/14.04/x86_64" _UCX_COMMIT=31e74cac7bee0ef66bef2af72e7d86d9c282e5ab -_UCC_COMMIT=12944da33f911daf505d9bbc51411233d0ed85e1 +_UCC_COMMIT=1c7a7127186e7836f73aafbd7697bbc274a77eee # It's annoying to rename jobs every time you want to rewrite a # configuration, so we hardcode everything here rather than do it # from scratch case "$image" in - pytorch-linux-xenial-py3.8) - ANACONDA_PYTHON_VERSION=3.8 - GCC_VERSION=7 - # Do not install PROTOBUF, DB, and VISION as a test - ;; - pytorch-linux-xenial-py3.7-gcc7.2) - ANACONDA_PYTHON_VERSION=3.7 - GCC_VERSION=7 - # Do not install PROTOBUF, DB, and VISION as a test - ;; - pytorch-linux-xenial-py3.7-gcc7) - ANACONDA_PYTHON_VERSION=3.7 - GCC_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - ;; - pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7) - CUDA_VERSION=10.2 - CUDNN_VERSION=7 - ANACONDA_PYTHON_VERSION=3.7 - GCC_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - KATEX=yes - ;; - pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7) - CUDA_VERSION=11.3.0 # Deviating from major.minor to conform to nvidia's Docker image names - CUDNN_VERSION=8 - TENSORRT_VERSION=8.0.1.6 - ANACONDA_PYTHON_VERSION=3.7 - GCC_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - KATEX=yes - ;; pytorch-linux-bionic-cuda11.3-cudnn8-py3-clang9) CUDA_VERSION=11.3.0 # Deviating from major.minor to conform to nvidia's Docker image names CUDNN_VERSION=8 @@ -167,20 +124,6 @@ case "$image" in UCC_COMMIT=${_UCC_COMMIT} CONDA_CMAKE=yes ;; - pytorch-linux-xenial-py3-clang5-asan) - ANACONDA_PYTHON_VERSION=3.7 - CLANG_VERSION=5.0 - PROTOBUF=yes - DB=yes - VISION=yes - ;; - pytorch-linux-xenial-py3-clang7-asan) - ANACONDA_PYTHON_VERSION=3.7 - CLANG_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - ;; pytorch-linux-focal-py3-clang7-asan) ANACONDA_PYTHON_VERSION=3.7 CLANG_VERSION=7 @@ -189,13 +132,6 @@ case "$image" in VISION=yes CONDA_CMAKE=yes ;; - pytorch-linux-xenial-py3-clang7-onnx) - ANACONDA_PYTHON_VERSION=3.7 - CLANG_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - ;; pytorch-linux-focal-py3-clang10-onnx) ANACONDA_PYTHON_VERSION=3.7 CLANG_VERSION=10 @@ -204,9 +140,9 @@ case "$image" in VISION=yes CONDA_CMAKE=yes ;; - pytorch-linux-xenial-py3-clang5-android-ndk-r19c) + pytorch-linux-focal-py3-clang7-android-ndk-r19c) ANACONDA_PYTHON_VERSION=3.7 - CLANG_VERSION=5.0 + CLANG_VERSION=7 LLVMDEV=yes PROTOBUF=yes ANDROID=yes @@ -214,13 +150,6 @@ case "$image" in GRADLE_VERSION=6.8.3 NINJA_VERSION=1.9.0 ;; - pytorch-linux-xenial-py3.7-clang7) - ANACONDA_PYTHON_VERSION=3.7 - CLANG_VERSION=7 - PROTOBUF=yes - DB=yes - VISION=yes - ;; pytorch-linux-bionic-py3.7-clang9) ANACONDA_PYTHON_VERSION=3.7 CLANG_VERSION=9 @@ -259,8 +188,8 @@ case "$image" in VISION=yes CONDA_CMAKE=yes ;; - pytorch-linux-focal-rocm5.1-py3.7) - ANACONDA_PYTHON_VERSION=3.7 + pytorch-linux-focal-rocm5.1-py3.8) + ANACONDA_PYTHON_VERSION=3.8 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -268,8 +197,8 @@ case "$image" in ROCM_VERSION=5.1.1 CONDA_CMAKE=yes ;; - pytorch-linux-focal-rocm5.2-py3.7) - ANACONDA_PYTHON_VERSION=3.7 + pytorch-linux-focal-rocm5.2-py3.8) + ANACONDA_PYTHON_VERSION=3.8 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -279,7 +208,6 @@ case "$image" in ;; pytorch-linux-focal-py3.7-gcc7) ANACONDA_PYTHON_VERSION=3.7 - CMAKE_VERSION=3.16.9 # Required for precompiled header support GCC_VERSION=7 PROTOBUF=yes DB=yes @@ -320,6 +248,10 @@ case "$image" in fi if [[ "$image" == *rocm* ]]; then extract_version_from_image_name rocm ROCM_VERSION + NINJA_VERSION=1.9.0 + fi + if [[ "$image" == *centos7* ]]; then + NINJA_VERSION=1.10.2 fi if [[ "$image" == *gcc* ]]; then extract_version_from_image_name gcc GCC_VERSION diff --git a/.circleci/docker/common/install_base.sh b/.circleci/docker/common/install_base.sh index 6724031c0a447..84835d6de50d7 100755 --- a/.circleci/docker/common/install_base.sh +++ b/.circleci/docker/common/install_base.sh @@ -68,7 +68,10 @@ install_ubuntu() { sudo \ vim \ jq \ - libtool + libtool \ + vim \ + unzip \ + gdb # Should resolve issues related to various apt package repository cert issues # see: https://github.com/pytorch/pytorch/issues/65931 @@ -126,7 +129,9 @@ install_centos() { opencv-devel \ sudo \ wget \ - vim + vim \ + unzip \ + gdb # Cleanup yum clean all diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index 713aad4729110..84f9538ce1248 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -104,9 +104,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then conda_install magma-cuda$(TMP=${CUDA_VERSION/./};echo ${TMP%.*[0-9]}) -c pytorch fi - # TODO: This isn't working atm - conda_install nnpack -c killeent - # Install some other packages, including those needed for Python test reporting pip_install -r /opt/conda/requirements-ci.txt diff --git a/.circleci/docker/common/install_cudnn.sh b/.circleci/docker/common/install_cudnn.sh index 4a8829b1cba11..f68fc6946c2eb 100644 --- a/.circleci/docker/common/install_cudnn.sh +++ b/.circleci/docker/common/install_cudnn.sh @@ -6,9 +6,9 @@ if [[ ${CUDNN_VERSION} == 8 ]]; then CUDNN_NAME="cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive" if [[ ${CUDA_VERSION:0:4} == "11.7" ]]; then CUDNN_NAME="cudnn-linux-x86_64-8.5.0.96_cuda11-archive" - curl -OLs https://ossci-linux.s3.amazonaws.com/${CUDNN_NAME}.tar.xz + curl --retry 3 -OLs https://ossci-linux.s3.amazonaws.com/${CUDNN_NAME}.tar.xz else - curl -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz fi tar xf ${CUDNN_NAME}.tar.xz diff --git a/.circleci/docker/common/install_docs_reqs.sh b/.circleci/docker/common/install_docs_reqs.sh index 1adc9e8009a02..e60171208ae1a 100644 --- a/.circleci/docker/common/install_docs_reqs.sh +++ b/.circleci/docker/common/install_docs_reqs.sh @@ -7,10 +7,10 @@ if [ -n "$KATEX" ]; then # Ignore error if gpg-agent doesn't exist (for Ubuntu 16.04) apt-get install -y gpg-agent || : - curl -sL https://deb.nodesource.com/setup_12.x | sudo -E bash - + curl --retry 3 -sL https://deb.nodesource.com/setup_12.x | sudo -E bash - sudo apt-get install -y nodejs - curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - + curl --retry 3 -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list apt-get update diff --git a/.circleci/docker/common/install_protobuf.sh b/.circleci/docker/common/install_protobuf.sh index 9d9f6c40ba0cf..4b7a7a6ac23f7 100755 --- a/.circleci/docker/common/install_protobuf.sh +++ b/.circleci/docker/common/install_protobuf.sh @@ -12,7 +12,7 @@ install_protobuf_317() { # g++: error: ./../lib64/crti.o: No such file or directory ln -s /usr/lib64 "$pb_dir/lib64" - curl -LO "https://github.com/protocolbuffers/protobuf/releases/download/v3.17.3/protobuf-all-3.17.3.tar.gz" + curl -LO "https://github.com/protocolbuffers/protobuf/releases/download/v3.17.3/protobuf-all-3.17.3.tar.gz" --retry 3 tar -xvz -C "$pb_dir" --strip-components 1 -f protobuf-all-3.17.3.tar.gz # -j6 to balance memory usage and speed. # naked `-j` seems to use too much memory. diff --git a/.circleci/docker/common/install_rocm.sh b/.circleci/docker/common/install_rocm.sh index 51c8402aa3787..7ad0c4f123e1c 100644 --- a/.circleci/docker/common/install_rocm.sh +++ b/.circleci/docker/common/install_rocm.sh @@ -29,7 +29,12 @@ install_ubuntu() { if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` - local amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/ubuntu" + local amdgpu_baseurl + if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu" + else + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/ubuntu" + fi echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list fi @@ -38,6 +43,10 @@ install_ubuntu() { ROCM_REPO="xenial" fi + if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then + ROCM_REPO="${UBUNTU_VERSION_NAME}" + fi + # Add rocm repository wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" @@ -78,7 +87,16 @@ install_centos() { if [[ $(ver $ROCM_VERSION) -ge $(ver 4.5) ]]; then # Add amdgpu repository - local amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/7.9/main/x86_64" + local amdgpu_baseurl + if [[ $OS_VERSION == 9 ]]; then + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/9.0/main/x86_64" + else + if [[ $(ver $ROCM_VERSION) -ge $(ver 5.3) ]]; then + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/7.9/main/x86_64" + else + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${AMDGPU_VERSIONS[$ROCM_VERSION]}/rhel/7.9/main/x86_64" + fi + fi echo "[AMDGPU]" > /etc/yum.repos.d/amdgpu.repo echo "name=AMDGPU" >> /etc/yum.repos.d/amdgpu.repo echo "baseurl=${amdgpu_baseurl}" >> /etc/yum.repos.d/amdgpu.repo diff --git a/.circleci/docker/requirements-ci.txt b/.circleci/docker/requirements-ci.txt index 018a7f6544fda..e527d29d4989b 100644 --- a/.circleci/docker/requirements-ci.txt +++ b/.circleci/docker/requirements-ci.txt @@ -159,8 +159,13 @@ pytest-shard #Pinned versions: #test that import: +pytest-flakefinder==1.1.0 +#Description: plugin for rerunning tests a fixed number of times in pytest +#Pinned versions: 1.1.0 +#test that import: + pytest-rerunfailures -#Description: plugin for rerunning tests in pytest +#Description: plugin for rerunning failure tests in pytest #Pinned versions: #test that import: diff --git a/.circleci/scripts/binary_install_miniconda.sh b/.circleci/scripts/binary_install_miniconda.sh index 43eb006742aed..3541a32ac6bf9 100755 --- a/.circleci/scripts/binary_install_miniconda.sh +++ b/.circleci/scripts/binary_install_miniconda.sh @@ -31,9 +31,9 @@ fi conda_sh="$workdir/install_miniconda.sh" if [[ "$(uname)" == Darwin ]]; then - curl --retry 3 -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh else - curl --retry 3 -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + curl --retry 3 --retry-all-errors -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh fi chmod +x "$conda_sh" "$conda_sh" -b -p "$MINICONDA_ROOT" diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index da38065847eff..48518c4707c6f 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -33,7 +33,7 @@ fi cp ${PROJ_ROOT}/LICENSE ${ZIP_DIR}/ # zip the library export DATE="$(date -u +%Y%m%d)" -export IOS_NIGHTLY_BUILD_VERSION="1.14.0.${DATE}" +export IOS_NIGHTLY_BUILD_VERSION="2.0.0.${DATE}" if [ "${BUILD_LITE_INTERPRETER}" == "1" ]; then # libtorch_lite_ios_nightly_1.11.0.20210810.zip ZIPFILE="libtorch_lite_ios_nightly_${IOS_NIGHTLY_BUILD_VERSION}.zip" diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 6e34b3e1e5f41..854ad883143b8 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -98,7 +98,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then conda install \${EXTRA_CONDA_FLAGS} -y "\$pkg" --offline ) elif [[ "$PACKAGE_TYPE" != libtorch ]]; then - pip install "\$pkg" + pip install "\$pkg" --extra-index-url "https://download.pytorch.org/whl/nightly/${DESIRED_CUDA}" retry pip install -q future numpy protobuf typing-extensions six fi if [[ "$PACKAGE_TYPE" == libtorch ]]; then diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 3294c72024aa3..7714371e26429 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -59,7 +59,7 @@ PIP_UPLOAD_FOLDER='nightly/' # We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it export DATE="$(date -u +%Y%m%d)" #TODO: We should be pulling semver version from the base version.txt -BASE_BUILD_VERSION="1.14.0.dev$DATE" +BASE_BUILD_VERSION="2.0.0.dev$DATE" # Change BASE_BUILD_VERSION to git tag when on a git tag # Use 'git -C' to make doubly sure we're in the correct directory for checking # the git tag diff --git a/.circleci/scripts/build_android_gradle.sh b/.circleci/scripts/build_android_gradle.sh index 598e9cd0a6bd2..8312a18eb0aad 100755 --- a/.circleci/scripts/build_android_gradle.sh +++ b/.circleci/scripts/build_android_gradle.sh @@ -20,6 +20,11 @@ do touch "$file" || true done < <(find /var/lib/jenkins/.gradle -type f -print0) +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/#if __cplusplus >= 201703L/#if 0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h +fi + export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties rm -f $GRADLE_LOCAL_PROPERTIES echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES diff --git a/.circleci/scripts/driver_update.bat b/.circleci/scripts/driver_update.bat index 46c05475cdba8..fb87743666213 100644 --- a/.circleci/scripts/driver_update.bat +++ b/.circleci/scripts/driver_update.bat @@ -1,5 +1,5 @@ set "DRIVER_DOWNLOAD_LINK=https://s3.amazonaws.com/ossci-windows/452.39-data-center-tesla-desktop-win10-64bit-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 452.39-data-center-tesla-desktop-win10-64bit-international.exe +curl --retry 3 --retry-all-errors -kL %DRIVER_DOWNLOAD_LINK% --output 452.39-data-center-tesla-desktop-win10-64bit-international.exe if errorlevel 1 exit /b 1 start /wait 452.39-data-center-tesla-desktop-win10-64bit-international.exe -s -noreboot diff --git a/.circleci/scripts/python_doc_push_script.sh b/.circleci/scripts/python_doc_push_script.sh index f9b019ec069b3..d255f77c82e8e 100755 --- a/.circleci/scripts/python_doc_push_script.sh +++ b/.circleci/scripts/python_doc_push_script.sh @@ -135,6 +135,9 @@ git commit -m "Generate Python docs from pytorch/pytorch@${GITHUB_SHA}" || true git status if [[ "${WITH_PUSH:-}" == true ]]; then + # push to a temp branch first to trigger CLA check and satisfy branch protections + git push -u origin HEAD:pytorchbot/temp-branch-py -f + sleep 30 git push -u origin "${branch}" fi diff --git a/.circleci/scripts/setup_ci_environment.sh b/.circleci/scripts/setup_ci_environment.sh index 8ac4f5b43a9a2..42a605cd44451 100755 --- a/.circleci/scripts/setup_ci_environment.sh +++ b/.circleci/scripts/setup_ci_environment.sh @@ -32,7 +32,7 @@ if ! command -v aws >/dev/null; then fi if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then - DRIVER_FN="NVIDIA-Linux-x86_64-515.57.run" + DRIVER_FN="NVIDIA-Linux-x86_64-515.76.run" wget "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN" sudo /bin/bash "$DRIVER_FN" -s --no-drm || (sudo cat /var/log/nvidia-installer.log && false) nvidia-smi @@ -40,8 +40,8 @@ if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then # Taken directly from https://github.com/NVIDIA/nvidia-docker # Add the package repositories distribution=$(. /etc/os-release;echo "$ID$VERSION_ID") - curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - - curl -s -L "https://nvidia.github.io/nvidia-docker/${distribution}/nvidia-docker.list" | sudo tee /etc/apt/sources.list.d/nvidia-docker.list + curl -s -L --retry 3 --retry-all-errors https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - + curl -s -L --retry 3 --retry-all-errors "https://nvidia.github.io/nvidia-docker/${distribution}/nvidia-docker.list" | sudo tee /etc/apt/sources.list.d/nvidia-docker.list retry sudo apt-get update -qq # Necessary to get the `--gpus` flag to function within docker diff --git a/.circleci/scripts/setup_linux_system_environment.sh b/.circleci/scripts/setup_linux_system_environment.sh index ce64076e2d64b..780f7c1bd3790 100755 --- a/.circleci/scripts/setup_linux_system_environment.sh +++ b/.circleci/scripts/setup_linux_system_environment.sh @@ -2,7 +2,7 @@ set -eux -o pipefail # Set up CircleCI GPG keys for apt, if needed -curl --retry 3 -s -L https://packagecloud.io/circleci/trusty/gpgkey | sudo apt-key add - +curl --retry 3 --retry-all-errors -s -L https://packagecloud.io/circleci/trusty/gpgkey | sudo apt-key add - # Stop background apt updates. Hypothetically, the kill should not # be necessary, because stop is supposed to send a kill signal to diff --git a/.circleci/scripts/vs_install.ps1 b/.circleci/scripts/vs_install.ps1 index a2e373078adb6..4bbbc24bb0437 100644 --- a/.circleci/scripts/vs_install.ps1 +++ b/.circleci/scripts/vs_install.ps1 @@ -29,7 +29,7 @@ if (Test-Path "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswher } echo "Downloading VS installer from S3." -curl.exe --retry 3 -kL $VS_DOWNLOAD_LINK --output vs_installer.exe +curl.exe --retry 3 --retry-all-errors -kL $VS_DOWNLOAD_LINK --output vs_installer.exe if ($LASTEXITCODE -ne 0) { echo "Download of the VS 2019 Version ${env:VS_VERSION} installer failed" exit 1 diff --git a/.circleci/scripts/vs_install_cmath.ps1 b/.circleci/scripts/vs_install_cmath.ps1 index c2998eba25217..62b637ec21b82 100644 --- a/.circleci/scripts/vs_install_cmath.ps1 +++ b/.circleci/scripts/vs_install_cmath.ps1 @@ -1,5 +1,5 @@ $CMATH_DOWNLOAD_LINK = "https://raw.githubusercontent.com/microsoft/STL/12c684bba78f9b032050526abdebf14f58ca26a3/stl/inc/cmath" $VC14_28_INSTALL_PATH="C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29910\include" -curl.exe --retry 3 -kL $CMATH_DOWNLOAD_LINK --output "$home\cmath" +curl.exe --retry 3 --retry-all-errors -kL $CMATH_DOWNLOAD_LINK --output "$home\cmath" Move-Item -Path "$home\cmath" -Destination "$VC14_28_INSTALL_PATH" -Force diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index c279259e83416..bbf45a3290b37 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -36,7 +36,7 @@ else tmp_dir=$(mktemp -d) ( pushd "${tmp_dir}" - curl --retry 3 -o "${cudnn_installer_name}" "$cudnn_installer_link" + curl --retry 3 --retry-all-errors -o "${cudnn_installer_name}" "$cudnn_installer_link" 7z x "${cudnn_installer_name}" -ocudnn # Use '${var:?}/*' to avoid potentially expanding to '/*' # Remove all of the directories before attempting to copy files diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index ff640de7bde5a..9c5a52153ca0e 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -16,6 +16,7 @@ self-hosted-runner: - bm-runner - linux.rocm.gpu - macos-m1-12 + - macos-m1-13 - macos-12-xl - macos-12 - macos12.3-m1 diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 6ec9e48c2df8e..0253577134c8a 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -52,7 +52,9 @@ runs: .github/scripts/filter_test_configs.py \ --test-matrix "${{ inputs.test-matrix }}" \ --pr-number "${{ github.event.pull_request.number }}" \ - --tag "${{ steps.parse-ref.outputs.tag }}" + --tag "${{ steps.parse-ref.outputs.tag }}" \ + --event-name "${{ github.event_name }}" \ + --schedule "${{ github.event.schedule }}" - name: Print the filtered test matrix shell: bash diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index 97dfd22c76ac0..d91762eb9a861 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -36,7 +36,12 @@ runs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi diff --git a/.github/actions/setup-win/action.yml b/.github/actions/setup-win/action.yml index c5f1cac550f68..6dc1a1b6c6fe2 100644 --- a/.github/actions/setup-win/action.yml +++ b/.github/actions/setup-win/action.yml @@ -55,9 +55,10 @@ runs: .circleci/scripts/windows_cudnn_install.sh - name: Setup Python3 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: - python-version: "3.x" + python-version: '3.x' + check-latest: false cache: pip cache-dependency-path: | **/requirements.txt diff --git a/.github/actions/test-pytorch-binary/action.yml b/.github/actions/test-pytorch-binary/action.yml index bc2c546f57b28..be2090db533db 100644 --- a/.github/actions/test-pytorch-binary/action.yml +++ b/.github/actions/test-pytorch-binary/action.yml @@ -15,7 +15,6 @@ runs: -e BINARY_ENV_FILE \ -e BUILDER_ROOT \ -e BUILD_ENVIRONMENT \ - -e BUILD_SPLIT_CUDA \ -e DESIRED_CUDA \ -e DESIRED_DEVTOOLSET \ -e DESIRED_PYTHON \ diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 67083a103e06a..9fd2342601f11 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -34,7 +34,7 @@ runs: run: | # Remove any previous test reports if they exist rm -f test-reports-*.zip - zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml' + zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml' -i '*.csv' - name: Zip usage log for upload if: runner.os != 'Windows' && !inputs.use-gha @@ -67,7 +67,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | # -ir => recursive include all files in pattern - 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' + 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' -ir'!test\*.csv' - name: Zip usage log for upload if: runner.os == 'Windows' && !inputs.use-gha @@ -127,5 +127,19 @@ runs: # Add the run attempt, see [Artifact run attempt] name: test-reports-runattempt${{ github.run_attempt }}-${{ inputs.file-suffix }}.zip retention-days: 14 - if-no-files-found: error - path: test/**/*.xml + # Don't want to fail the workflow here because not all workflows have csv files + if-no-files-found: ignore + path: | + test/**/*.xml + test/**/*.csv + + - name: Store Usage Logs on Github + uses: actions/upload-artifact@v3 + if: inputs.use-gha + with: + # Add the run attempt, see [Artifact run attempt] + name: usage-log-runattempt${{ github.run_attempt }}-${{ inputs.file-suffix }}.zip + retention-days: 14 + if-no-files-found: ignore + path: usage_log.txt + continue-on-error: true diff --git a/.github/auto_request_review.yml b/.github/auto_request_review.yml index 94ceafcb3d133..339f085d939af 100644 --- a/.github/auto_request_review.yml +++ b/.github/auto_request_review.yml @@ -4,17 +4,18 @@ reviewers: symbolic-shapes: - ezyang - Chillee - - wconstab - anjali411 - albanD - - Krovatkin - miladm - bdhirsh + - voznesenskym + - SherlockNoMad per_author: symbolic-shapes: - symbolic-shapes - antoniojkim + - wconstab files: # none yet, TODO: migrate CODEOWNERS here diff --git a/.github/ci_commit_pins/text.txt b/.github/ci_commit_pins/text.txt new file mode 100644 index 0000000000000..c0e01da17fd08 --- /dev/null +++ b/.github/ci_commit_pins/text.txt @@ -0,0 +1 @@ +5b78d074bd303eb230d30567646fcf0358ee2dd4 diff --git a/.github/ci_commit_pins/timm.txt b/.github/ci_commit_pins/timm.txt index 4b199567e9a7b..cdda1d14775c6 100644 --- a/.github/ci_commit_pins/timm.txt +++ b/.github/ci_commit_pins/timm.txt @@ -1 +1 @@ -ebee0a27940adfbb30444d83387b9ea0f1173f40 +6635bc3f7d06c6a0d0481803b24d6ad0004b61ac diff --git a/.github/ci_commit_pins/triton.txt b/.github/ci_commit_pins/triton.txt index 58d82813d6e13..7c5e80098f7b7 100644 --- a/.github/ci_commit_pins/triton.txt +++ b/.github/ci_commit_pins/triton.txt @@ -1 +1 @@ -db3aa1d1fb2bb536752a71d9e0f03cf6a86ddf65 +0d7e7532279e45672555e344646f5c19c3972331 diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index db0aa4e7d73c4..80fb2b961071d 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -7a62a545ce76f43ccc5cfe0009131f7db14ae7b5 +029cb3fe4526084172c30be14278d46ecd5bf17c diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 2ca663bacdea0..204ebba1034a0 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -e1f5a49664b904e3ec1ddb9095ca75b6bbb5c10d +b55aec841b9cf680b04abefaf3c0197a51de8b08 diff --git a/.github/labeler.yml b/.github/labeler.yml index 9581c9fe706cd..14f1765462569 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -12,3 +12,44 @@ - torch/_dynamo/** - torch/_inductor/** - benchmarks/dynamo/** +- torch/_subclasses/fake_tensor.py +- torch/_subclasses/fake_utils.py +- torch/_subclasses/meta_utils.py +- test/distributed/test_dynamo_distributed.py +- functorch/_src/partitioners.py +- functorch/_src/aot_autograd.py + +"module: cpu": +- aten/src/ATen/cpu/** +- aten/src/ATen/native/cpu/** +- aten/src/ATen/native/quantized/cpu/** +- aten/src/ATen/native/Convolution*.cpp +- aten/src/ATen/native/mkldnn/** +- torch/cpu/** +- torch/utils/mkldnn.py +- test/test_mkldnn.py + +"module: mkldnn": +- third_party/ideep +- caffe2/ideep/** +- caffe2/python/ideep/** +- cmake/Modules/FindMKLDNN.cmake +- third_party/mkl-dnn.BUILD +- torch/csrc/jit/codegen/onednn/** +- test/test_jit_llga_fuser.py + +"module: amp (automated mixed precision)": +- torch/amp/** +- aten/src/ATen/autocast_mode.* +- torch/csrc/jit/passes/autocast.cpp +- test/test_autocast.py + +"NNC": +- torch/csrc/jit/tensorexpr/** + +"release notes: quantization": +- torch/ao/quantization/** +- torch/quantization/** +- aten/src/ATen/quantized/** +- aten/src/ATen/native/quantized/cpu/** +- test/quantization/** diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 26b3eb437251a..c5cf415be984f 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -6,7 +6,6 @@ - docs/source/onnx* - docs/source/scripts/onnx/** - scripts/onnx/** - - test/jit/test_export_modes.py - test/onnx/** - tools/onnx/** - torch/_C/__init__.pyi.in @@ -21,6 +20,11 @@ approved_by: - BowenBao - abock + - justinchuby + - shubhambhokare1 + - thiagocrepaldi + - titaiwangms + - wschin mandatory_checks_name: - EasyCLA - Lint @@ -36,6 +40,7 @@ - csarofeen - ngimel - jjsjann123 + - kevinstephano - ptrblck mandatory_checks_name: - EasyCLA @@ -228,9 +233,14 @@ - wanchaol - fduwjj - H-Huang - - d4l3k - aazzolini - kwen2501 + - XilunWu + - wz337 + - awgu + - fegin + - kumpera + - yhcharles mandatory_checks_name: - EasyCLA - Lint @@ -241,9 +251,12 @@ - third_party/ideep - caffe2/ideep/** - caffe2/python/ideep/** + - cmake/Modules/FindMKLDNN.cmake + - third_party/mkl-dnn.BUILD approved_by: - XiaobingSuper - - yanbing-j + - jgong5 + - mingfeima mandatory_checks_name: - EasyCLA - Lint @@ -256,6 +269,7 @@ approved_by: - sanchitintel - chunyuan-w + - jgong5 mandatory_checks_name: - EasyCLA - Lint @@ -268,9 +282,11 @@ - aten/src/ATen/native/quantized/cpu/** - aten/src/ATen/native/Convolution*.cpp - aten/src/ATen/native/mkldnn/** + - test/test_mkldnn.py approved_by: - mingfeima - XiaobingSuper + - jgong5 mandatory_checks_name: - EasyCLA - Lint @@ -283,7 +299,7 @@ - test/test_mkldnn.py approved_by: - leslie-fang-intel - - CaoE + - jgong5 mandatory_checks_name: - EasyCLA - Lint @@ -297,7 +313,18 @@ - test/test_autocast.py approved_by: - leslie-fang-intel - - CaoE + - jgong5 + mandatory_checks_name: + - EasyCLA + - Lint + - pull + +- name: NNC + patterns: + - torch/csrc/jit/tensorexpr/** + approved_by: + - EikanWang + - jgong5 mandatory_checks_name: - EasyCLA - Lint @@ -308,11 +335,12 @@ - torch/csrc/lazy/** - test/cpp/lazy/** - test/lazy/** - - codegen/api/lazy.py - - codegen/dest/lazy_ir.py - - codegen/dest/lazy_ts_lowering.py - - codegen/gen_lazy_tensor.py + - torchgen/api/lazy.py + - torchgen/dest/lazy_ir.py + - torchgen/dest/lazy_ts_lowering.py + - torchgen/gen_lazy_tensor.py - aten/src/ATen/native/ts_native_functions.yaml + - .github/ci_commit_pins/xla.txt approved_by: - alanwaketan - JackCaoG diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index f331d98351ae8..6badbe2cc65c8 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -5,12 +5,14 @@ # docs/cpp/requirements.txt # functorch/docs/requirements.txt # .circleci/docker/requirements-ci.txt +boto3==1.19.12 cffi==1.15.0 dataclasses==0.6 jinja2==3.0.1 lintrunner==0.9.2 ninja==1.10.0.post1 pynvml==11.4.1 +pyyaml==6.0 requests==2.26 rich==10.9.0 rockset==0.8.10 diff --git a/.github/requirements/README.md b/.github/requirements/README.md new file mode 100644 index 0000000000000..7300eee145629 --- /dev/null +++ b/.github/requirements/README.md @@ -0,0 +1,24 @@ +### Cached requirements and consolidation of conda and pip installation + +At the moment, the installation of conda and pip dependencies happens at +different places in the CI depending at the whim of different +developers, which makes it very challenging to handle issues like +network flakiness or upstream dependency failures gracefully. So, this +center directory is created to gradually include all the conda environment +and pip requirement files that are used to setup CI jobs. Not only it +gives a clear picture of all the dependencies required by different CI +jobs, but it also allows them to be cached properly to improve CI +reliability. + +The list of support files are as follows: + +* Conda: + * conda-env-macOS-ARM64. This is used by MacOS (m1, arm64) build and + test jobs to setup the conda environment + * conda-env-macOS-X64. This is use by MacOS (x86-64) build and test + jobs to setup the conda environment + * conda-env-Linux-X64. This is used by Linux buck build and test jobs + to setup the conda environment +* Pip: + * pip-requirements-macOS.txt. This is used by MacOS build and test jobs to + setup the pip environment diff --git a/.github/requirements/conda-env-Linux-X64 b/.github/requirements/conda-env-Linux-X64 new file mode 100644 index 0000000000000..c4e8aa7ae548d --- /dev/null +++ b/.github/requirements/conda-env-Linux-X64 @@ -0,0 +1,10 @@ +cffi=1.15.1 +cmake=3.22.* +mkl=2022.1.0 +mkl-include=2022.1.0 +ninja=1.10.2 +numpy=1.23.3 +pyyaml=6.0 +requests=2.28.1 +setuptools=65.5.0 +typing_extensions=4.3.0 diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64 new file mode 100644 index 0000000000000..77f37cf463ea8 --- /dev/null +++ b/.github/requirements/conda-env-macOS-ARM64 @@ -0,0 +1,20 @@ +numpy=1.22.3 +pyyaml=6.0 +setuptools=61.2.0 +cmake=3.22.* +cffi=1.15.1 +typing_extensions=4.3.0 +dataclasses=0.8 +pip=22.2.2 +six=1.16.0 +pillow=9.2.0 +pkg-config=0.29.2 +wheel=0.37.1 +expecttest=0.1.3 + +# Not pinning certifi so that we can always get the latest certificates +certifi + +# Cross-compiling arm64 from x86-64 picks up 1.40.0 while testing on arm64 +# itself only has up to 1.39.0 from upstream conda. Both work though +libuv>=1.39.0,<=1.40.0 diff --git a/.github/requirements/conda-env-macOS-X64 b/.github/requirements/conda-env-macOS-X64 new file mode 100644 index 0000000000000..897850f0e36b9 --- /dev/null +++ b/.github/requirements/conda-env-macOS-X64 @@ -0,0 +1,18 @@ +mkl=2021.2.0 +mkl-include=2021.2.0 +numpy=1.18.5 +pyyaml=5.3 +setuptools=46.0.0 +cmake=3.22.* +cffi=1.15.1 +typing_extensions=4.3.0 +dataclasses=0.8 +pip=22.2.2 +six=1.16.0 +pillow=9.2.0 +libuv=1.40.0 +pkg-config=0.29.2 +wheel=0.37.1 + +# Not pinning certifi so that we can always get the latest certificates +certifi diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt new file mode 100644 index 0000000000000..dfbaea260116e --- /dev/null +++ b/.github/requirements/pip-requirements-macOS.txt @@ -0,0 +1,22 @@ +boto3==1.19.12 +hypothesis==6.56.4 +expecttest==0.1.3 +librosa>=0.6.2 +mpmath==1.2.1 +networkx==2.8.7 +# Use numba-0.49.1 or older on Intel Macs, but 0.56.0 on M1 machines, as older numba is not available +numba==0.56.0; platform_machine == "arm64" +numba<=0.49.1; platform_machine != "arm64" +opt-einsum>=3.3 +psutil==5.9.1 +pynvml==11.4.1 +pygments==2.12.0 +pytest==7.2.0 +pytest-xdist==3.0.2 +pytest-rerunfailures==10.2 +pytest-flakefinder==1.1.0 +pytest-shard==0.1.2 +scipy==1.9.0 +sympy==1.11.1 +unittest-xml-reporting<=3.2.0,>=2.0.0 +xdoctest==1.0.2 diff --git a/.github/scripts/README.md b/.github/scripts/README.md index 22099c3732ea5..cc9e1617b11a7 100644 --- a/.github/scripts/README.md +++ b/.github/scripts/README.md @@ -3,7 +3,7 @@ > NOTE: This README contains information for the `.github` directory but cannot be located there because it will overwrite the repo README. -This directory contains workflows and scripts to support our CI infrastructure that runs on Github Actions. +This directory contains workflows and scripts to support our CI infrastructure that runs on GitHub Actions. ## Workflows @@ -36,7 +36,7 @@ New generated binary workflows can be added in the `.github/scripts/generate_ci_ examples from that script in order to add the workflow to the stream that is relevant to what you particularly care about. -Different parameters can be used to acheive different goals, i.e. running jobs on a cron, running only on trunk, etc. +Different parameters can be used to achieve different goals, i.e. running jobs on a cron, running only on trunk, etc. #### ciflow (trunk) diff --git a/.github/scripts/build_publish_nightly_docker.sh b/.github/scripts/build_publish_nightly_docker.sh deleted file mode 100644 index db84704aa3e4c..0000000000000 --- a/.github/scripts/build_publish_nightly_docker.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env bash - -set -xeuo pipefail - -PYTORCH_DOCKER_TAG=$(git describe --tags --always)-devel -CUDA_VERSION=11.3.1 - -# Build PyTorch nightly docker -make -f docker.Makefile \ - DOCKER_REGISTRY=ghcr.io \ - DOCKER_ORG=pytorch \ - CUDA_VERSION=${CUDA_VERSION} \ - DOCKER_IMAGE=pytorch-nightly \ - DOCKER_TAG=${PYTORCH_DOCKER_TAG} \ - INSTALL_CHANNEL=pytorch-nightly BUILD_TYPE=official devel-image - -# Get the PYTORCH_NIGHTLY_COMMIT from the docker image -PYTORCH_NIGHTLY_COMMIT=$(docker run \ - ghcr.io/pytorch/pytorch-nightly:${PYTORCH_DOCKER_TAG} \ - python -c 'import torch; print(torch.version.git_version)' | head -c 7) - -docker tag ghcr.io/pytorch/pytorch-nightly:${PYTORCH_DOCKER_TAG} \ - ghcr.io/pytorch/pytorch-nightly:${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION} - -docker tag ghcr.io/pytorch/pytorch-nightly:${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION} \ - ghcr.io/pytorch/pytorch-nightly:latest - -if [[ ${WITH_PUSH:-} == "true" ]]; then - # Push the nightly docker to GitHub Container Registry - echo $GHCR_PAT | docker login ghcr.io -u pytorch --password-stdin - make -f docker.Makefile \ - DOCKER_REGISTRY=ghcr.io \ - DOCKER_ORG=pytorch \ - DOCKER_IMAGE=pytorch-nightly \ - DOCKER_TAG=${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION} \ - devel-push - - make -f docker.Makefile \ - DOCKER_REGISTRY=ghcr.io \ - DOCKER_ORG=pytorch \ - DOCKER_IMAGE=pytorch-nightly \ - DOCKER_TAG=latest \ - devel-push -fi diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index d9d2a2e98bd35..b0c7e3f8b3bd9 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -2,6 +2,7 @@ from subprocess import check_call from pathlib import Path from tempfile import TemporaryDirectory +from typing import Optional import sys import shutil SCRIPT_DIR = Path(__file__).parent @@ -29,12 +30,30 @@ def patch_setup_py(path: Path, *, version: str = "2.0.0", name: str = "triton") f.write(orig) -def build_triton(commit_hash: str) -> Path: +def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path: with TemporaryDirectory() as tmpdir: triton_basedir = Path(tmpdir) / "triton" triton_pythondir = triton_basedir / "python" check_call(["git", "clone", "https://github.com/openai/triton"], cwd=tmpdir) check_call(["git", "checkout", commit_hash], cwd=triton_basedir) + if build_conda: + with open(triton_basedir / "meta.yaml", "w") as meta: + print(f"package:\n name: torchtriton\n version: 2.0.0+{commit_hash[:10]}\n", file=meta) + print("source:\n path: .\n", file=meta) + print("build:\n string: py{{py}}\n number: 1\n script: cd python; " + "python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta) + print("requirements:\n host:\n - python\n - setuptools\n run:\n - python\n" + " - filelock\n - pytorch\n", file=meta) + print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:" + " 'A language and compiler for custom Deep Learning operation'", file=meta) + + if py_version is None: + py_version = f"{sys.version_info.major}.{sys.version_info.minor}" + check_call(["conda", "build", "--python", py_version, "--output-folder", tmpdir, "."], cwd=triton_basedir) + conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0] + shutil.copy(conda_path, Path.cwd()) + return Path.cwd() / conda_path.name + patch_setup_py(triton_pythondir / "setup.py", name="torchtriton", version=f"2.0.0+{commit_hash[:10]}") check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir) whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0] @@ -43,8 +62,13 @@ def build_triton(commit_hash: str) -> Path: def main() -> None: + from argparse import ArgumentParser + parser = ArgumentParser("Build Triton binaries") + parser.add_argument("--build-conda", action="store_true") + parser.add_argument("--py-version", type=str) + args = parser.parse_args() pin = read_triton_pin() - build_triton(pin) + build_triton(pin, build_conda=args.build_conda, py_version=args.py_version) if __name__ == "__main__": diff --git a/.github/scripts/check_labels.py b/.github/scripts/check_labels.py new file mode 100755 index 0000000000000..2d4a216daf942 --- /dev/null +++ b/.github/scripts/check_labels.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +"""check_labels.py""" + +from typing import Any, List + +from export_pytorch_labels import get_pytorch_labels +from gitutils import ( + get_git_remote_name, + get_git_repo_dir, + GitRepo, +) +from trymerge import ( + _fetch_url, + gh_post_pr_comment, + GitHubPR, +) + + +BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"] + +ERR_MSG_TITLE = "This PR needs a label" +ERR_MSG = ( + f"# {ERR_MSG_TITLE}\n" + "If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.\n\n" # noqa: E501 pylint: disable=line-too-long + "If not, please add the `topic: not user facing` label.\n\n" + "For more information, see https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work." # noqa: E501 pylint: disable=line-too-long +) + + +def get_release_notes_labels() -> List[str]: + return [label for label in get_pytorch_labels() if label.lstrip().startswith("release notes:")] + + +def delete_comment(comment_id: int) -> None: + url = f"https://api.github.com/repos/pytorch/pytorch/issues/comments/{comment_id}" + _fetch_url(url, method="DELETE") + + +def has_required_labels(pr: GitHubPR) -> bool: + pr_labels = pr.get_labels() + # Check if PR is not user facing + is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels) + return is_not_user_facing_pr or any(label.strip() in get_release_notes_labels() for label in pr_labels) + + +def delete_comments(pr: GitHubPR) -> None: + # Delete all previous comments + for comment in pr.get_comments(): + if comment.body_text.lstrip(" #").startswith(ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS: + delete_comment(comment.database_id) + + +def add_comment(pr: GitHubPR) -> None: + # Only make a comment if one doesn't exist already + for comment in pr.get_comments(): + if comment.body_text.lstrip(" #").startswith(ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS: + return + gh_post_pr_comment(pr.org, pr.project, pr.pr_num, ERR_MSG) + + +def parse_args() -> Any: + from argparse import ArgumentParser + parser = ArgumentParser("Check PR labels") + parser.add_argument("pr_num", type=int) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) + org, project = repo.gh_owner_and_name() + pr = GitHubPR(org, project, args.pr_num) + + try: + if not has_required_labels(pr): + print(ERR_MSG) + add_comment(pr) + exit(1) + else: + delete_comments(pr) + except Exception as e: + pass + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/ensure_actions_will_cancel.py b/.github/scripts/ensure_actions_will_cancel.py index c479aefb9fc43..729d02f560fa1 100755 --- a/.github/scripts/ensure_actions_will_cancel.py +++ b/.github/scripts/ensure_actions_will_cancel.py @@ -42,26 +42,26 @@ def should_check(filename: Path) -> bool: print("ERROR: duplicate workflow name:", name, file=sys.stderr) errors_found = True names.add(name) - - expected = { - "group": EXPECTED_GROUP, - "cancel-in-progress": True, - } - actual = data.get("concurrency", None) - if actual != expected: + actual = data.get("concurrency", {}) + if not actual.get("group", "").startswith(EXPECTED_GROUP): print( f"'concurrency' incorrect or not found in '{filename.relative_to(REPO_ROOT)}'", file=sys.stderr, ) print( - f"expected: {expected}", + f"concurrency group should start with {EXPECTED_GROUP} but found {actual.get('group', None)}", file=sys.stderr, ) + errors_found = True + if not actual.get("cancel-in-progress", False): print( - f"actual: {actual}", + f"'concurrency' incorrect or not found in '{filename.relative_to(REPO_ROOT)}'", + file=sys.stderr, + ) + print( + f"concurrency cancel-in-progress should be True but found {actual.get('cancel-in-progress', None)}", file=sys.stderr, ) - errors_found = True if errors_found: sys.exit(1) diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 10170161554c7..eab32401ad97f 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -23,6 +23,10 @@ "force_on_cpu", "functorch", "inductor", + "inductor_distributed", + "inductor_huggingface", + "inductor_timm", + "inductor_torchbench", "jit_legacy", "multigpu", "nogpu_AVX512", @@ -32,12 +36,21 @@ "xla", }} +# Supported modes when running periodically +SUPPORTED_PERIODICAL_MODES = { + "mem_leak_check", + "rerun_disabled_tests", +} + + def parse_args() -> Any: from argparse import ArgumentParser parser = ArgumentParser("Filter all test configurations and keep only requested ones") parser.add_argument("--test-matrix", type=str, required=True, help="the original test matrix") parser.add_argument("--pr-number", type=str, help="the pull request number") parser.add_argument("--tag", type=str, help="the associated tag if it exists") + parser.add_argument("--event-name", type=str, help="name of the event that triggered the job (pull, schedule, etc)") + parser.add_argument("--schedule", type=str, help="cron schedule that triggered the job") return parser.parse_args() @@ -106,6 +119,23 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis return filtered_test_matrix +def set_periodic_modes(test_matrix: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """ + Apply all periodic modes when running under a schedule + """ + scheduled_test_matrix: Dict[str, List[Any]] = { + "include": [], + } + + for config in test_matrix.get("include", []): + for mode in SUPPORTED_PERIODICAL_MODES: + cfg = config.copy() + cfg[mode] = mode + scheduled_test_matrix["include"].append(cfg) + + return scheduled_test_matrix + + def set_output(name: str, val: Any) -> None: if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: @@ -159,6 +189,11 @@ def main() -> None: # No PR number, no tag, we can just return the test matrix as it is filtered_test_matrix = test_matrix + if args.event_name == "schedule" and args.schedule == '29 8 * * *': + # we don't want to run the mem leack check or disabled tests on normal + # periodically scheduled jobs, only the ones at this time + filtered_test_matrix = set_periodic_modes(filtered_test_matrix) + # Set the filtered test matrix as the output set_output("test-matrix", json.dumps(filtered_test_matrix)) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 54949ff27bb1b..deb225287b3f5 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -16,7 +16,7 @@ CUDA_ARCHES = ["11.6", "11.7"] -ROCM_ARCHES = ["5.1.1", "5.2"] +ROCM_ARCHES = ["5.2", "5.3"] def arch_type(arch_version: str) -> str: @@ -219,9 +219,9 @@ def generate_wheels_matrix(os: str, "container_image": WHEEL_CONTAINER_IMAGES[arch_version], "package_type": package_type, "pytorch_extra_install_requirements": - "nvidia-cuda-runtime-cu11;" - "nvidia-cudnn-cu11==8.5.0.96;" - "nvidia-cublas-cu11==11.10.3.66", + "nvidia-cuda-runtime-cu11; platform_system == 'Linux' | " + "nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | " + "nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux'", "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn" .replace( diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 1ef3142286bf3..35680e30ee6a7 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -154,7 +154,7 @@ class OperatingSystem: package_type="libtorch", abi_version=generate_binary_build_matrix.PRE_CXX11_ABI, build_configs=generate_binary_build_matrix.generate_libtorch_matrix( - OperatingSystem.LINUX, generate_binary_build_matrix.CXX11_ABI, + OperatingSystem.LINUX, generate_binary_build_matrix.PRE_CXX11_ABI, arches=["cpu"], libtorch_variants=["shared-with-deps"], ), @@ -277,7 +277,7 @@ class OperatingSystem: BinaryBuildWorkflow( os=OperatingSystem.MACOS_ARM64, package_type="wheel", - build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS), + build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS_ARM64), cross_compile_arm64=True, ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, diff --git a/.github/scripts/generate_pytorch_version.py b/.github/scripts/generate_pytorch_version.py index 0655df137e07c..02c19844cd09f 100755 --- a/.github/scripts/generate_pytorch_version.py +++ b/.github/scripts/generate_pytorch_version.py @@ -23,27 +23,22 @@ def get_pytorch_root() -> Path: def get_tag() -> str: root = get_pytorch_root() - # We're on a tag - am_on_tag = ( - subprocess.run( - ['git', 'describe', '--tags', '--exact'], - cwd=root, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL - ).returncode == 0 - ) - tag = "" - if am_on_tag: + try: dirty_tag = subprocess.check_output( - ['git', 'describe'], + ['git', 'describe', '--tags', '--exact'], cwd=root ).decode('ascii').strip() - # Strip leading v that we typically do when we tag branches - # ie: v1.7.1 -> 1.7.1 - tag = re.sub(LEADING_V_PATTERN, "", dirty_tag) - # Strip trailing rc pattern - # ie: 1.7.1-rc1 -> 1.7.1 - tag = re.sub(TRAILING_RC_PATTERN, "", tag) + except subprocess.CalledProcessError: + return "" + # Strip leading v that we typically do when we tag branches + # ie: v1.7.1 -> 1.7.1 + tag = re.sub(LEADING_V_PATTERN, "", dirty_tag) + # Strip trailing rc pattern + # ie: 1.7.1-rc1 -> 1.7.1 + tag = re.sub(TRAILING_RC_PATTERN, "", tag) + # Ignore ciflow tags + if tag.startswith("ciflow/"): + return "" return tag def get_base_version() -> str: diff --git a/.github/scripts/gql_mocks.json b/.github/scripts/gql_mocks.json index 4a6ea6a6402c7..073658b0d6bc8 100644 --- a/.github/scripts/gql_mocks.json +++ b/.github/scripts/gql_mocks.json @@ -18634,7 +18634,7 @@ "path": "torch/ao/quantization/fx/fuse.py" }, { - "path": "torch/ao/quantization/fx/fusion_patterns.py" + "path": "torch/ao/quantization/fx/fuse_handler.py" }, { "path": "torch/ao/quantization/fx/match_utils.py" @@ -18646,7 +18646,7 @@ "path": "torch/ao/quantization/fx/prepare.py" }, { - "path": "torch/ao/quantization/fx/quantization_patterns.py" + "path": "torch/ao/quantization/fx/quantize_handler.py" }, { "path": "torch/ao/quantization/qconfig.py" @@ -20855,5 +20855,15054 @@ "team": null } } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=82169 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "ezyang" + }, + "title": "Move test_dtypes so it runs later", + "body": "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):\n* __->__ #82169\n\nThe error messages it gives are very unhelpful (because a failure\ngets translated into \"dtype was not supported\" rather than the\nactual backtrace), so I'd rather get error messages about this after\nI've tested basic functionality.\n\nSigned-off-by: Edward Z. Yang ", + "headRefName": "gh/ezyang/1279/head", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "gh/ezyang/1279/base", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "cef34da55a59da5a32494bff218ccd4978b659d3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "83ad7e73a07111ac1d85e931d14360cc22c01edd" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "28140e4008289251b695385acfb48ac7a47cd49c" + } + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + }, + "totalCount": 3 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310707890" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708140" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708223" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708332" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708496" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708710" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310708937" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823981/jobs/4310709169" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGj1lc=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696649" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc8k=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696651" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc8s=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747823982/jobs/4310707884" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGjz0w=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696656" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc9A=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696660" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRc9Q=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696715" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdAs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310708487" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310708713" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310708942" + }, + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310709174" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310709340" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310709579" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310709844" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310710003" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310710175" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310710516" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310710716" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310710890" + }, + { + "name": "linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711097" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711234" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711429" + }, + { + "name": "linux-focal-rocm5.2-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711603" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711765" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310711946" + }, + { + "name": "linux-xenial-cuda11_3-py3_7-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310712129" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4310712276" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194495" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194591" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194659" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194749" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194858" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311194934" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311195003" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311220458" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311220540" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311222725" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311222869" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223128" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223225" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223324" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223396" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223496" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223569" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311223690" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311224360" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311230050" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311301930" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311302152" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311302303" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311302433" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311302531" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311491082" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311491172" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311491232" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311491289" + }, + { + "name": "linux-bionic-cuda11.6-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2747824048/jobs/4311491348" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcG0YME=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696836" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdIQ=" + }, + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAcGjyQg=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546696896" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdMA=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697185" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdeE=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697205" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdfU=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/28140e4008289251b695385acfb48ac7a47cd49c/checks?check_suite_id=7546697224" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAcHRdgg=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": null, + "pushedDate": "2022-07-27T15:34:17Z", + "oid": "28140e4008289251b695385acfb48ac7a47cd49c" + } + } + ] + }, + "changedFiles": 1, + "files": { + "nodes": [ + { + "path": "test/test_ops.py" + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "zou3519" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "Chillee" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNy0yNVQxNDo0NTozNS0wNzowMLkyMDIyLTA3LTI1VDE0OjQ1OjM1LTA3OjAwzj6XYmg=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "@pytorchbot merge -f FORCE", + "createdAt": "2022-07-27T17:56:43Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197107402 + }, + { + "bodyText": "You need to provide a reason for using force merge, in the format @pytorchbot merge -f '[CATEGORY] Explanation'. With [CATEGORY] being one the following:\nEMERGENCY - an emergency fix to quickly address an issue\nMINOR - a minor fix such as cleaning locally unused variables, which shouldn't break anything\nPRE_TESTED - a previous CI run tested everything and you've only added minor changes like fixing lint\nOTHER - something not covered above", + "createdAt": "2022-07-27T17:56:45Z", + "author": { + "login": "pytorch-bot" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1197107439 + }, + { + "bodyText": "@pytorchbot merge -f \"[OTHER] normal land failed twice already\"", + "createdAt": "2022-07-27T17:57:28Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197108130 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-27T18:08:13Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197119348 + }, + { + "bodyText": "Hey @ezyang.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-07-27T18:08:58Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1197120095 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOR1poyg==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "Merged" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=73811 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "seemethere" + }, + "title": "ci: Migrate metrics credentials to managed IAM", + "body": "Stack from [ghstack](https://github.com/ezyang/ghstack):\n* __->__ #73811\n\r\nMigrates our credentials to upload metrics statistics to managed IAM\r\ncredentials in order to make it easier to know where the credentials are\r\ncoming from and to make it easier to add more permissions / less\r\npermissions later on.\r\n\r\nRelates to work done in [D34535827](https://www.internalfb.com/diff/D34535827)\r\n\r\nSigned-off-by: Eli Uriegas ", + "headRefName": "gh/seemethere/215/head", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "gh/seemethere/215/base", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "seemethere" + }, + "email": "eliuriegas@fb.com", + "name": "Eli Uriegas" + }, + "oid": "13c44d16a876a56bca479b4cf30715d21fa16e99" + } + }, + { + "commit": { + "author": { + "user": { + "login": "seemethere" + }, + "email": "eliuriegas@fb.com", + "name": "Eli Uriegas" + }, + "oid": "9d26f4e6d8c8df275ea546180fef42548257d2d7" + } + } + ], + "pageInfo": { + "endCursor": "Mg", + "hasNextPage": false + }, + "totalCount": 2 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUqOaHA=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658275867" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcBs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276090" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcPo=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "win-vs2019-cpu-py3" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276092" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcPw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3-clang5-mobile-build" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276094" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcP4=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276095" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcP8=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276097" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcQE=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276098" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcQI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc7-no-ops" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1983602966/jobs/2839950629" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUqObRM=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276099" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcQM=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Test tools" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276100" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcQQ=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-clang7-asan" + } + }, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": "CANCELLED", + "url": "https://github.com/pytorch/pytorch/commit/9d26f4e6d8c8df275ea546180fef42548257d2d7/checks?check_suite_id=5658276101" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVFCcQU=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17044969?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17045014?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17044975?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + } + ] + }, + "pushedDate": "2022-03-14T23:01:55Z", + "oid": "9d26f4e6d8c8df275ea546180fef42548257d2d7" + } + } + ] + }, + "changedFiles": 3, + "files": { + "nodes": [ + { + "path": ".github/templates/common.yml.j2" + }, + { + "path": ".github/workflows/generated-macos-11-py3-x86-64.yml" + }, + { + "path": ".github/workflows/update_pytorch_labels.yml" + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "kit1980" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "janeyx99" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wMy0wNFQxNDoyNDo0OC0wODowMLkyMDIyLTAzLTA0VDE0OjI0OjQ4LTA4OjAwzjWwwqA=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Merge failed due to Too many checksuites for commit\nRaised by https://github.com/pytorch/pytorch/actions/runs/1988337976", + "createdAt": "2022-03-15T17:43:28Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1068270969 + }, + { + "bodyText": "@pytorchbot force merge this", + "createdAt": "2022-03-15T20:26:36Z", + "author": { + "login": "seemethere" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1068436128 + }, + { + "bodyText": "Merge failed due to Too many checksuites for commit\nRaised by https://github.com/pytorch/pytorch/actions/runs/1989076952", + "createdAt": "2022-03-15T20:27:47Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1068437098 + }, + { + "bodyText": "@pytorchbot merge this", + "createdAt": "2022-03-15T21:18:55Z", + "author": { + "login": "seemethere" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1068482921 + }, + { + "bodyText": "Hey @seemethere.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-03-15T21:20:40Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1068484404 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOP6yFeQ==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=31093 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "mingxiaoh" + }, + "title": "improve mkldnn convolution test coverage", + "body": "This pr will improve the test coverage of mkldnn convolution.\r\n1.test input: specific sensitive numbers\r\n2.pass criteria: output of mkldnn convolution matches output of thnn convolution\r\n3.coverage: by using coverage tool, we found out the following sensitive parameters. Overall the case will test 4352 patterns, takes 8.8s on my machine.\r\n\r\nto run the test case:\r\n\r\npython test_mkldnn_conv2d_ext.py\r\nor\r\npython run_test.py -i mkldnn_conv2d_ext\r\n\r\nIn case of failure, the pattern will be printed in the log for further debugging.\r\n\r\nactually, this PR is created to replace and improve that PR we created before(https://github.com/pytorch/pytorch/pull/25085) ", + "headRefName": "master", + "headRepository": { + "nameWithOwner": "mingxiaoh/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "11pikachu" + }, + "email": "junx.du@intel.com", + "name": "dujun" + }, + "oid": "29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9" + } + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + }, + "totalCount": 1 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "clang-format" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "clang-format", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/1099676797?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHOQYu8fQ==", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9/checks?check_suite_id=1175281097" + }, + "cursor": "Y3Vyc29yOnYyOpHORg1dyQ==" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "flake8-py3", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/1099676800?check_suite_focus=true" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/1099676817?check_suite_focus=true" + }, + { + "name": "clang-tidy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/1099676829?check_suite_focus=true" + }, + { + "name": "cmakelint", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/runs/1099676840?check_suite_focus=true" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHOQYu8qA==", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9/checks?check_suite_id=1175281099" + }, + "cursor": "Y3Vyc29yOnYyOpHORg1dyw==" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "codecov/project", + "conclusion": "SUCCESS", + "detailsUrl": "https://codecov.io" + }, + { + "name": "codecov/patch", + "conclusion": "SUCCESS", + "detailsUrl": "https://codecov.io" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHOQZhcFQ==", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9/checks?check_suite_id=1176100822" + }, + "cursor": "Y3Vyc29yOnYyOpHORhnf1g==" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "codecov/patch", + "conclusion": "SUCCESS", + "detailsUrl": "https://codecov.io" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHOQZZsEQ==", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9/checks?check_suite_id=1176100824" + }, + "cursor": "Y3Vyc29yOnYyOpHORhnf2A==" + }, + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHOUquzJg==", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9/checks?check_suite_id=1487517306" + }, + "cursor": "Y3Vyc29yOnYyOpHOWKm2eg==" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406538?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_devtoolset7_shared-with-deps_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406947?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406544?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406931?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_windows_libtorch_3_7_cpu_debug_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406550?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_windows_libtorch_3_7_cpu_debug_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406887?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_windows_libtorch_3_7_cpu_release_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406526?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: binary_windows_libtorch_3_7_cpu_release_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406707?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406533?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: caffe2_onnx_main_py3_6_clang7_ubuntu16_04_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407256?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: caffe2_onnx_ort1_py3_6_clang7_ubuntu16_04_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407254?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: caffe2_onnx_ort2_py3_6_clang7_ubuntu16_04_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407255?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-cuda10.2-cudnn7-py3.6-clang9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406556?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-cuda10.2-cudnn7-py3.8-gcc9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406532?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406527?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-cuda11.0-cudnn8-py3.8-gcc9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406553?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-py3.6-clang9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406537?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-py3.8-gcc9", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406529?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-rocm3.5.1-py3.6", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406554?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-bionic-rocm3.7-py3.6", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406545?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406543?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406536?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406552?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406535?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc5.4", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406540?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406528?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406541?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-asan", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406549?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.6-clang7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406555?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.6-gcc4.8", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406546?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.6-gcc5.4", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406531?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.6-gcc7", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406534?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.6-gcc7.2", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406523?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3.8", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406539?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-rocm3.3-py3.6", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406547?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-rocm3.5.1-py3.6", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406551?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407209?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406611?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_bazel_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406607?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_bazel_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406984?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_cpp_doc_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407013?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_doc_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407011?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_ios_11_2_1_x86_64_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406548?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406563?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7408680?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_backward_compatibility_check_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407014?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_bionic_py3_6_clang9_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406567?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_bionic_py3_6_clang9_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406945?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_bionic_py3_8_gcc9_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406561?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_bionic_py3_8_gcc9_coverage_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407422?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_bionic_rocm3_7_py3_6_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406562?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406612?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7408107?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_legacy_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7408111?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_cuda10_2_cudnn7_py3_ge_config_profiling_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7408101?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc5_4_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406613?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_6_gcc5_4_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406565?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407017?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_profiling_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407019?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407012?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_6_gcc5_4_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407016?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_vulkan_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406608?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406609?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_asan_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406606?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_asan_test1", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407435?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_asan_test2", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407436?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_mobile_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406605?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_mobile_custom_build_dynamic", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406610?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_macos_10_13_py3_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406525?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_macos_10_13_py3_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407415?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_python_doc_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407018?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_vulkan_linux_bionic_py3_6_clang9_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406566?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_vulkan_linux_bionic_py3_6_clang9_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406946?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_windows_vs2019_py36_cpu_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406542?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_windows_vs2019_py36_cuda10.1_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406530?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_windows_vs2019_py36_cuda10.1_test1", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407028?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_windows_vs2019_py36_cuda10.1_test2", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407027?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_windows_vs2019_py36_cuda11.0_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406524?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_xla_linux_bionic_py3_6_clang9_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7406572?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_xla_linux_bionic_py3_6_clang9_test", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/7407253?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "codecov/patch", + "state": "SUCCESS", + "targetUrl": "https://codecov.io/gh/pytorch/pytorch/compare/69f6d94caa3559d4f50745c26af5df041b83fee8...29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9" + }, + { + "context": "codecov/project", + "state": "SUCCESS", + "targetUrl": "https://codecov.io/gh/pytorch/pytorch/compare/69f6d94caa3559d4f50745c26af5df041b83fee8...29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9" + }, + { + "context": "pr/caffe2-pytorch-linux-bionic-rocm3.7-py3.6-test", + "state": "SUCCESS", + "targetUrl": "https://ci.pytorch.org/jenkins/job/caffe2-builds/job/pytorch-linux-bionic-rocm3.7-py3.6-trigger-test/2319/" + }, + { + "context": "pr/pytorch-linux-bionic-rocm3.7-py3.6", + "state": "SUCCESS", + "targetUrl": "https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm3.7-py3.6-trigger/2325/" + } + ] + }, + "pushedDate": "2020-09-11T01:58:24Z", + "oid": "29f6aa6ecc2ece3fa58170ff4561f9d8d5c129f9" + } + } + ] + }, + "changedFiles": 5, + "files": { + "nodes": [ + { + "path": "test/math_libraries/convolutions.py" + }, + { + "path": "test/math_libraries/convolutions_cases/shapes_googlenet_v3.json" + }, + { + "path": "test/math_libraries/convolutions_cases/shapes_maskrcnn_p1.json" + }, + { + "path": "test/math_libraries/convolutions_cases/shapes_mobilenet.json" + }, + { + "path": "test/math_libraries/convolutions_cases/shapes_resnet_50.json" + } + ], + "pageInfo": { + "endCursor": "NQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "CHANGES_REQUESTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "CHANGES_REQUESTED" + }, + { + "author": { + "login": "ailzhang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "VitalyFedyunin" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mingxiaoh" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mingxiaoh" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "VitalyFedyunin" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "VitalyFedyunin" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAxOS0xMi0zMFQxMDoxOToxMS0wODowMLkyMDE5LTEyLTMwVDEwOjE5OjExLTA4OjAwzhQZLuY=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "I cloned your repo and ran the tests:\n~/pytorch/test/math_libraries$ python convolutions.py\nFFFF\n======================================================================\nFAIL: test_conv2d_ext_cpu_float32 (__main__.TestConvExtCPU)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float16 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float32 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float64 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n----------------------------------------------------------------------\nRan 4 tests in 33.838s\n\nFAILED (failures=4)\n\nStill fails.\n\n@mruberry It is suggested by @VitalyFedyunin that, we need to display fail test to avoid invalid inputs, I guess we should set it as expected failures under the pytest test framework, right? we will change it as expected failure cases under pytest test framework. The result will looks like be low, is it ok?\n2500 passed, 136 skipped, 0 failed, 0 errors, 2 expected failures, 0 unexpected passes", + "createdAt": "2020-08-14T01:36:20Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": { + "login": "mingxiaoh" + }, + "databaseId": 673816925 + }, + { + "bodyText": "Displaying tests that fail is fine, but I don't think @VitalyFedyunin meant that it was OK if the tests didn't pass. If these are expected failures then yes, you can use with self.assertRaises(RuntimeError):... when testing them. If you also want to report that the test has test cases with these properties you can print or warn, which will appear in the test output.", + "createdAt": "2020-08-14T03:09:37Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 673858224 + }, + { + "bodyText": "Codecov Report\n\nMerging #31093 into master will not change coverage.\nThe diff coverage is n/a.\n\n\n@@ Coverage Diff @@\n## master #31093 +/- ##\n=======================================\n Coverage 68.00% 68.00% \n=======================================\n Files 382 382 \n Lines 49527 49527 \n=======================================\n Hits 33679 33679 \n Misses 15848 15848 \n\nContinue to review full report at Codecov.\n\nLegend - Click here to learn more\n\u0394 = absolute (impact), \u00f8 = not affected, ? = missing data\nPowered by Codecov. Last update 69f6d94...29f6aa6. Read the comment docs.", + "createdAt": "2020-09-04T05:41:01Z", + "author": { + "login": "codecov" + }, + "authorAssociation": "NONE", + "editor": { + "login": "codecov" + }, + "databaseId": 686921371 + }, + { + "bodyText": "Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale. Feel free to remove the Stale label if you feel this was a mistake. If you are unable to remove the Stale label please contact a maintainer in order to do so. Stale pull requests will automatically be closed 30 days after being marked Stale", + "createdAt": "2022-04-12T02:35:37Z", + "author": { + "login": "pytorchbot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1095860944 + }, + { + "bodyText": "Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale. Feel free to remove the Stale label if you feel this was a mistake. If you are unable to remove the Stale label please contact a maintainer in order to do so. If you want the bot to never mark this PR stale again, add the no-stale label.Stale pull requests will automatically be closed after 30 days of inactivity.", + "createdAt": "2022-06-11T04:40:16Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1152854802 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOKCmhXQ==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Stale" + } + } + ] + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOKCmhXQ== name=pytorch number=31093 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "Hi, @mingfeima @soumith @Jianhui-Li\nthis will improve the test coverage of mkldnn convolution, would you please review it?\nThe current code is forward only, do we need to cover backward, if yes, we can add backward.", + "createdAt": "2019-12-12T01:19:02Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 564806270 + }, + { + "bodyText": "@mingxiaoh, what is the value in testing DNNL as part of Pytorch validation for the Pytorch developers? Shouldn't having these tests run in DNNL validation be enough?", + "createdAt": "2019-12-12T01:28:32Z", + "author": { + "login": "vpirogov" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 564808528 + }, + { + "bodyText": "@vpirogov The main value is to serve as a blind test to DNNL. If DNNL adds these test to DNNL test sets, it lost the value as a blind test. The spirit of validation is to cross check.\n@gottbrath @gchanan The test was developed per the request of Pytorch team. Mingxiao made an effort to reduce the execution time to a few second but still with good coverage. Although the test today is focused on DNNL, it could be easily extended to be blind test for any conv implementation used in Pytorch.", + "createdAt": "2019-12-20T07:44:30Z", + "author": { + "login": "Jianhui-Li" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 567826907 + }, + { + "bodyText": "@mruberry thanks for the comment. As for the chainer dependency, we import it is because we would like to use its testing function for pytest test cases combinations, other wise we need to write much more code to achieve same effect. So, can we use it?", + "createdAt": "2020-01-15T09:04:34Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 574563012 + }, + { + "bodyText": "@mingxiaoh You cannot import chainer. Looking at the code you should be able to achieve the same effect without it.", + "createdAt": "2020-01-16T17:59:46Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 575272358 + }, + { + "bodyText": "@mruberry ok, we will change it according to your requirement. Thanks", + "createdAt": "2020-02-10T00:59:34Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 583917522 + }, + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/31093\n\ud83d\udd27 \u00a0Opt-in to CIFlow to control what jobs run on your PRs\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit 29f6aa6 (more details on the Dr. CI page):\n\nCommit 29f6aa6 was recently pushed. Waiting for builds...\n\nThis comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2020-05-14T08:04:30Z", + "author": { + "login": "dr-ci" + }, + "authorAssociation": "NONE", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 628466876 + }, + { + "bodyText": "@mruberry how about those cudnn UT error? we add check for it but it should be NV to fix cudnn bugs.", + "createdAt": "2020-05-18T05:34:11Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 629955767 + }, + { + "bodyText": "Hey @mingxiaoh! You're right, of course, that you shouldn't have to fix cuDNN bugs. Would you please:\n\nAssert that the test case fails, so we know it's failing and if someone fixes it they'll know what test to update.\nFile a new issue explaining the behavior and providing a short PyTorch program to reproduce the issue.\n\nThen we can ping NVIDIA on that issue.", + "createdAt": "2020-05-18T07:27:08Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 629997129 + }, + { + "bodyText": "about the suggestion 'Assert that the test case fails, so we know it's failing and if someone fixes it they'll know what test to update. ', if we only assert it and continue the following test, I guess users might always ignore them in later test. Anyway, any similar example case for reference?", + "createdAt": "2020-05-18T07:55:08Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 630010734 + }, + { + "bodyText": "In this recent PR https://github.com/pytorch/pytorch/pull/38505/files, for example, you can see that the construction of bool tensors wasn't working properly, so the test author cited the relevant issue and asserted that the incorrect behavior happened, as expected. You can also see how these lines are being removed by https://github.com/pytorch/pytorch/pull/38392/files, which fixes the issue.\nAnother common pattern is to use with self.assertRaises(RuntimeError/AssertionError/etc.):.", + "createdAt": "2020-05-18T08:02:13Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 630014823 + }, + { + "bodyText": "@mruberry the failed UT case is not introduced by our modification, how to handle this issue?", + "createdAt": "2020-05-20T01:59:13Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 631187735 + }, + { + "bodyText": "@mingxiaoh You mean the failures on ROCm? You may ignore them. Be sure to re-request review when you're ready.", + "createdAt": "2020-05-20T02:12:58Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 631191425 + }, + { + "bodyText": "@mruberry we already skipped those ROCm errors, but there are stil somel error caused by the original code, they are not introduced by our modification.", + "createdAt": "2020-05-21T05:18:07Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 631886529 + }, + { + "bodyText": "I understand. Let me know when you're ready for me to review.", + "createdAt": "2020-05-21T06:24:15Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 631908011 + }, + { + "bodyText": "@mruberry thanks, we are ready for review now.", + "createdAt": "2020-05-21T06:28:11Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 631909442 + }, + { + "bodyText": "@mingxiaoh Great! I'll take a look ASAP.", + "createdAt": "2020-05-21T06:31:10Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 631910556 + }, + { + "bodyText": "@mruberry we just pull the latest code and updated the patch according to your comment, may you please help double check it? BTW, the new failed case in preci is not introduced by our modification.", + "createdAt": "2020-05-25T07:44:58Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 633430458 + }, + { + "bodyText": "@ailzhang would you please check the comment below? Thanks.\nIs there a reason why this TestConv2dExt is a new class instead a test inside TestNN?\n//comment: it is actually suggested by Tongzhou Wang in another thread before.\nAlthough this test sits in generic testing framework, it's actually comparing thnn/mkldnn/cudnn results specially. I feel it's better to make it truly generic so that it compares any device result with CPU result. Alternatively you can mark this test only run when torch.backends.mkldnn.is_available()=True\n//comment: but our goal is to compare the result with that of thnn. Anyway, if you insist, we can start to compare it with cpu.", + "createdAt": "2020-05-27T05:11:08Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": { + "login": "mingxiaoh" + }, + "databaseId": 634432326 + }, + { + "bodyText": "Pruning reviewers. @ngimel, @VitalyFedyunin, this PR is looking pretty good from a test framework perspective. Would one of you like to review?", + "createdAt": "2020-05-27T09:58:42Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 634557563 + }, + { + "bodyText": "@mruberry Thanks, would you please help review it again. BTW: failed case is not introduced by our modification.", + "createdAt": "2020-05-28T10:26:32Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 635256214 + }, + { + "bodyText": "@mruberry we moved our case to TestNNDeviceType class, would you please help review it again? BTW, those failed cases are not introduced by our code", + "createdAt": "2020-06-02T08:00:01Z", + "author": { + "login": "1pikachu" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 637364148 + }, + { + "bodyText": "@mruberry we moved our case to TestNNDeviceType class, would you please help review it again? BTW, those failed cases are not introduced by our code\n\n@ngimel will follow-up on the test itself sometime this week or early next week.", + "createdAt": "2020-06-02T10:23:47Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 637444457 + }, + { + "bodyText": "@mruberry we moved our case to TestNNDeviceType class, would you please help review it again? BTW, those failed cases are not introduced by our code\n\n@ngimel will follow-up on the test itself sometime this week or early next week.\n\n@mruberry thank you", + "createdAt": "2020-06-02T11:32:06Z", + "author": { + "login": "1pikachu" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 637479226 + }, + { + "bodyText": "Improving test coverage of math libraries is certainly a good goal and this PR is moving towards it. I have some doubts about implementation decisions made, and about running this PR as part of regular pytorch CI.\nIf the primary goal of this PR is to test correctness of the convolution implementations in the vendor library, then it does not serve this purpose. The absolute majority of the 4000+ test cases come from group 1, where different kernel sizes/strides/dilations are used to produce the output of size 1x1. This can test whether pytorch correctly passes convolution parameters to the backends (although there are cheaper ways to do that), but as actual library correctness check it is almost useless - libraries use very different kernels depending in the input/output sizes, and tests with toy sizes like this don't invoke the real bread-and-butter kernels.\nAlso, if this test suite is meant as primary a means of testing vendor libraries (which is a good goal!) it does not have a place as a part of pytorch regular CI, and should be run when the corresponding vendor libraries are updated. I'd suggest moving this test out into a separate file (maybe even outside of torch/test directory) and have it as a part of library update/qualification process rather than regular CI.\nAlso, if the primary goal is to enable easier testing of vendor libraries correctness, perhaps we should rethink the mechanism of the generation of test cases. It should be easy to add a test case with a particular set of parameters that was found to be buggy. Also, running a cross-product of cases in a multi-dimensional space (as this PR does) is rarely an efficient way of getting a signal, some forms of random sampling usually provide a way to get better correctness signal why using less resources.\nAlso, when testing libraries it is important to test both forward and backward functions, whereas this PR does forward only. I'm openminded on whether convTransposed should be tested or not - if we are testing vendor libraries, then it's not necessary, convTransposed calls the same underlying functions, if we are testing pytorch, then it makes sense to test it separately because it takes different codepaths.", + "createdAt": "2020-06-02T21:56:33Z", + "author": { + "login": "ngimel" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 637827507 + }, + { + "bodyText": "@mruberry ngimel is quite responsible, but it seems that she is not familiar with the background of this pull-request, since this pull-request is pending for so such a long time, each time we are almost done, then reviewer changes, each reviewer has different idea, it is good, but, would it be better if you help review it or ask the same reviewer to review it considering that you are more familiar with the background/change history? Thanks in advance.", + "createdAt": "2020-06-03T02:16:07Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 637912105 + }, + { + "bodyText": "@mruberry ngimel is quite responsible, but it seems that she is not familiar with the background of this pull-request, since this pull-request is pending for so such a long time, each time we are almost done, then reviewer changes, each reviewer has different idea, it is good, but, would it be better if you help review it or ask the same reviewer to review it considering that you are more familiar with the background/change history? Thanks in advance.\n\nWe know this PR has been open for awhile and we respect that your time is valuable, but we want to make sure we're making the right change here, and I think @ngimel's comments reflect that and should not be too difficult to address. As I understand, her points are:\n\nThis is a good PR with an exciting idea. To let it run longer and test more cases maybe it should run outside the regular PyTorch CI.\nTo remedy this, let's create a test/math_libraries folder and put this test there: test/math_libaries/convolutions.py. Yes, this is different from our requests in the past, which is our mistake, but it should be an easy change.\nTo make the test more interesting it'd be good for the test cases to resemble convolutions used in practice. The current test cases seem like similar \"toy\" examples. Without time pressure we should be able to run larger, more computationally intensive convolutions.\nLet's change the test cases to include some practical convolutions, make it easy to add test cases, and think about how we might generate other interesting cases. (We should also test backwards once we have more time!)\n\nAnd I think these are good points. Maybe the PR doesn't create a new way to generate interesting convolutions to start and instead only runs a few representative convolutions, but @ngimel is positioning the work for success so that it's useful and we can continue to improve on it in the future.\nDoes that make sense?", + "createdAt": "2020-06-03T03:04:55Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 637924703 + }, + { + "bodyText": "@mruberry we were required to finish the test in limited time long long before, at that time, jianhui discussed this issue with you, and you are all agreed with the current test scope and test case number and test time, so you meant you change your mind now? you are not care about the test time currently? Sorry, this issue is pending so long, we are struggling with it now and would like to finish it asap. Given this, it would be be better if you raise all the requirement at a time, considering that we have many tasks at hand, we are hoping so eagerly that we can finish this PR and use it for further test for bugs finding.", + "createdAt": "2020-06-03T05:22:43Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": { + "login": "mingxiaoh" + }, + "databaseId": 637960626 + }, + { + "bodyText": "@mruberry we were required to finish the test in limited time long long before, at that time, jianhui discussed this issue with you, and you are all agreed with the current test scope and test case number and test time, so you meant you change your mind now? you are not care about the test time currently? Sorry, this issue is pending so long, we are struggling with it now and would like to finish it asap. Given this, it would be be better if you raise all the requirement at a time, considering that we have many tasks at hand, we are hoping so eagerly that we can finish this PR and use it for further test for bugs finding.\n\nI'm sorry, I don't think I've talked to @Jianhui-Li before. It's true that the team we expressed a concern about timing if the test was to be run in the CI initially, but I think now that we understand what the test is trying to do better we're not sure the CI is the best place for it. The PR was also closed after a lengthy period of inactivity, and we assumed it had simply been abandoned.\nDo you know who @Jianhui-Li spoke with about this issue originally? Maybe I can follow-up with them for more context.", + "createdAt": "2020-06-03T05:42:28Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 637967153 + }, + { + "bodyText": "@mruberry it is reviewed and discussed with @soumith before. Anyway, since current reviewer is you, so, it should be decided by you. So, what we should do next?", + "createdAt": "2020-06-03T06:13:14Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 637978356 + }, + { + "bodyText": "@mruberry it is reviewed and discussed with @soumith before. Anyway, since current reviewer is you, so, it should be decided by you. So, what we should do next?\n\nI think this will be easier to discuss at the regular Intel-FB meeting.", + "createdAt": "2020-06-03T20:34:05Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 638446723 + }, + { + "bodyText": "@mruberry it is reviewed and discussed with @soumith before. Anyway, since current reviewer is you, so, it should be decided by you. So, what we should do next?\n\nI think this will be easier to discuss at the regular Intel-FB meeting.\n\nLet me sync with Mingxiao and follow up with this. Thanks.", + "createdAt": "2020-06-03T20:44:44Z", + "author": { + "login": "Jianhui-Li" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 638451670 + }, + { + "bodyText": "@mruberry would you please help review it again?", + "createdAt": "2020-07-02T14:09:23Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 653028208 + }, + { + "bodyText": "@mruberry would you please help review it again?\n\nHappy to help out, but as last discussed this needs some follow-up at the Intel-FB meeting. Did you get a chance to discuss it there, yet? If so, what did you decide?", + "createdAt": "2020-07-06T20:15:04Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 654443242 + }, + { + "bodyText": "@mruberry would you please help review it again?\n\nHappy to help out, but as last discussed this needs some follow-up at the Intel-FB meeting. Did you get a chance to discuss it there, yet? If so, what did you decide?\n\nyes, we talked it with jianhui, and we decided to follow your ideas. Anyway, we would like to do so modification later, will contact you for review tomorrow. Thanks", + "createdAt": "2020-07-09T11:04:06Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 656062287 + }, + { + "bodyText": "@mruberry would you please help review it again?\n\nHappy to help out, but as last discussed this needs some follow-up at the Intel-FB meeting. Did you get a chance to discuss it there, yet? If so, what did you decide?\n\nyes, we talked it with jianhui, and we decided to follow your ideas. Anyway, we would like to do so modification later, will contact you for review tomorrow. Thanks\n\n@mruberry the code is ready for review now, would you please take time for it? Thanks.", + "createdAt": "2020-07-14T09:16:48Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 658071151 + }, + { + "bodyText": "super nit: renaming files to .json will make it more IDE friendly.", + "createdAt": "2020-07-14T23:38:37Z", + "author": { + "login": "VitalyFedyunin" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 658464685 + }, + { + "bodyText": "@mruberry would you please help review it again?\n\nHappy to help out, but as last discussed this needs some follow-up at the Intel-FB meeting. Did you get a chance to discuss it there, yet? If so, what did you decide?\n\nyes, we talked it with jianhui, and we decided to follow your ideas. Anyway, we would like to do so modification later, will contact you for review tomorrow. Thanks\n\n@mruberry the code is ready for review now, would you please take time for it? Thanks.\n\nCool! I took a look with @ngimel, once these issues are addressed I think we're good to go!", + "createdAt": "2020-07-16T05:17:29Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 659164401 + }, + { + "bodyText": "@ngimel & @VitalyFedyunin We have changed the code according to your suggestions, would you please review it again? Thanks.", + "createdAt": "2020-07-20T08:30:01Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 660884305 + }, + { + "bodyText": "@ngimel & @VitalyFedyunin We have changed the code according to your suggestions, would you please review it again? Thanks.\n\nUpdated: one more question about tolerances, one code cleanup recommendation, and one task leftover from the last review.", + "createdAt": "2020-07-22T20:26:42Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 662678464 + }, + { + "bodyText": "Updated: one more question about tolerances, one code cleanup recommendation, and one task leftover from the last review.\n@mruberry we have finished the modification according to your comment, would you please review it again? Thanks.", + "createdAt": "2020-07-23T10:24:26Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 662930687 + }, + { + "bodyText": "The code looks good, but I tried running the test suite and hit the following failures:\n======================================================================\nFAIL: test_conv2d_ext_cuda_float16 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 241, in instantiated_test\n result = test(self, device_arg, dtype)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 542, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 411, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 102, in test_conv2d_ext\n msg=msg\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 1085, in assertEqual\n self.assertTrue(result, msg=msg)\nAssertionError: False is not true : device:cuda:0, dtype:torch.float16, group:1, batchsize:22input channel:448, output channel:384, bias:False, padding:[1, 1], dilation:[1, 1], stride:[1, 1], kernel:[3, 3]\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float32 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 241, in instantiated_test\n result = test(self, device_arg, dtype)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 542, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 411, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 102, in test_conv2d_ext\n msg=msg\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 1085, in assertEqual\n self.assertTrue(result, msg=msg)\nAssertionError: False is not true : device:cuda:0, dtype:torch.float32, group:1, batchsize:22input channel:80, output channel:192, bias:False, padding:[0, 0], dilation:[1, 1], stride:[1, 1], kernel:[3, 3]\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float64 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 777, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 241, in instantiated_test\n result = test(self, device_arg, dtype)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 542, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 411, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 106, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\nLooking at the first invalid convolution, for example, it's:\n {\n \"case_name\":\"masknet_p1:conv33\",\n \"mb\":1,\n \"g\":1,\n \"ic\":512,\n \"ih\":64,\n \"iw\":64,\n \"oc\":12,\n \"kh\":1,\n \"kw\":1,\n \"sh\":1,\n \"sw\":1,\n \"ph\":0,\n \"pw\":0,\n \"dh\":0,\n \"dw\":0,\n \"bias\":\"False\"\n },\n\nwhich has a dh and dw of zero, causing it to be added to invalid cases here:\ndh, dw = case['dh'], case['dw']\n has_bias = case['bias']\n if dh == 0 or dw == 0:\n invalid_cases.append(case_name)", + "createdAt": "2020-07-23T21:25:19Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "mruberry" + }, + "databaseId": 663240268 + }, + { + "bodyText": "@mruberry the failure was not detected is because we did not export the cudnn path. Yes, you are right, we need to a large atol of 1e-2 . Would you please help review it again? Thanks.", + "createdAt": "2020-07-27T12:43:44Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 664373079 + }, + { + "bodyText": "@mruberry the failure was not detected is because we did not export the cudnn path. Yes, you are right, we need to a large atol of 1e-2 . Would you please help review it again? Thanks.\n\nBefore I run these tests again, is an atol of 1e-2 needed for all types or just half? Also, how does 1e-2 compare to the values that are being compared?", + "createdAt": "2020-07-27T18:39:27Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 664569507 + }, + { + "bodyText": "@mruberry 1e-2 is experimental result, details see below, random means it might be failed sometimes.\n\n\n\natol,rtol\n1e-2,1e-2\n1e-2,1e-3\n1e-3,1e-2\n1e-3,1e-3\n1e-4,1e-3\n1e-3,1e-4\n1e-4,1e-4\n1e-4,1e-5\n1e-5,1e-4\n\n\n\n\nCuda float16\npass\npass\npass\npass\npass\nfail\nFail\nFail\nfail\n\n\nCuda float32\npass\nrandom\nrandom\nrandom\nrandom\nrandom\nrandom\nrandom\nfail", + "createdAt": "2020-07-31T03:33:27Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 666894774 + }, + { + "bodyText": "@mruberry would you please find time to review it again? Thanks.", + "createdAt": "2020-08-04T05:01:20Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 668380451 + }, + { + "bodyText": "@mruberry would you please find time to review it again? Thanks.\n\nI was just about to try and run this again locally but it looks like the files describing the convolutions are missing?", + "createdAt": "2020-08-07T03:49:44Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 670306210 + }, + { + "bodyText": "@mruberry sorry but what is missing actually?", + "createdAt": "2020-08-07T05:00:20Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 670322557 + }, + { + "bodyText": "@mruberry sorry but what is missing actually?\n\nThe JSON files.", + "createdAt": "2020-08-07T16:06:41Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 670591170 + }, + { + "bodyText": "@mruberry sorry but what is missing actually?\n\nThe JSON files.\n\n@mruberry sorry, we add them now, would you please check it again? Thanks.", + "createdAt": "2020-08-13T10:40:11Z", + "author": { + "login": "mingxiaoh" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 673402901 + }, + { + "bodyText": "I cloned your repo and ran the tests:\n~/pytorch/test/math_libraries$ python convolutions.py\nFFFF\n======================================================================\nFAIL: test_conv2d_ext_cpu_float32 (__main__.TestConvExtCPU)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float16 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float32 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n======================================================================\nFAIL: test_conv2d_ext_cuda_float64 (__main__.TestConvExtCUDA)\n----------------------------------------------------------------------\nTraceback (most recent call last):\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_utils.py\", line 815, in wrapper\n method(*args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 244, in instantiated_test\n result = test(self, *args)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 615, in only_fn\n return fn(self, device, *args, **kwargs)\n File \"/private/home/mruberry/git/pytorch/torch/testing/_internal/common_device_type.py\", line 472, in dep_fn\n return fn(slf, device, *args, **kwargs)\n File \"convolutions.py\", line 114, in test_conv2d_ext\n \"invalid cases:\" + \",\".join(invalid_cases)\nAssertionError: invalid cases:masknet_p1:conv33,masknet_p1:conv8,masknet_p1:conv2*4,masknet_p1:conv12,masknet_p1:conv4*3,masknet_p1:conv19,masknet_p1:conv4,masknet_p1:conv4,masknet_p1:conv27,masknet_p1:conv39,masknet_p1:conv23,masknet_p1:conv20,masknet_p1:conv25,masknet_p1:conv17,masknet_p1:conv9*4,masknet_p1:conv36,masknet_p1:conv18,masknet_p1:conv5,masknet_p1:conv38,masknet_p1:conv31,masknet_p1:conv14,masknet_p1:conv26,masknet_p1:conv2,masknet_p1:conv5*2,masknet_p1:conv28,masknet_p1:conv16,masknet_p1:conv20*3,masknet_p1:conv9,masknet_p1:conv14*23,masknet_p1:conv32,masknet_p1:conv30,masknet_p1:conv35,masknet_p1:conv37,masknet_p1:conv3,masknet_p1:conv24,masknet_p1:conv13,masknet_p1:conv21*3,masknet_p1:conv10,masknet_p1:conv7,masknet_p1:conv34,masknet_p1:conv13*24,masknet_p1:conv10*4,masknet_p1:conv22*2,masknet_p1:conv6,masknet_p1:conv22,masknet_p1:conv11,masknet_p1:conv40,masknet_p1:conv15,masknet_p1:conv17*23,masknet_p1:conv29,masknet_p1:conv21,masknet_p1:conv1,masknet_p1:conv11*3,mobilenet:conv3,mobilenet:conv2*4,mobilenet:conv6,mobilenet:conv7,mobilenet:conv5*4,mobilenet:conv4*4,mobilenet:conv7*4,mobilenet:conv1*3,mobilenet:conv10,mobilenet:conv2,mobilenet:conv5,mobilenet:conv4,mobilenet:conv9*4,mobilenet:conv8,mobilenet:conv9,mobilenet:conv6*4,mobilenet:conv10*4,mobilenet:conv11,mobilenet:conv8*20,mobilenet:conv1,mobilenet:conv11*4,mobilenet:conv3*4\n\n----------------------------------------------------------------------\nRan 4 tests in 33.838s\n\nFAILED (failures=4)\n\nStill fails.", + "createdAt": "2020-08-13T23:35:00Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 673760580 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOIapCfg==", + "hasPreviousPage": false + } + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=76118 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "malfet" + }, + "title": "Dummy change with lots of commits", + "body": "Draft PR with 100+ commits, to test mergebot ", + "headRefName": "malfet/pr-with-lots-of-commits", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "3067f2240afc7a29dc348000aa19eccbd9772303" + } + }, + { + "commit": { + "author": { + "user": { + "login": "andrewor14" + }, + "email": "andrewor@fb.com", + "name": "Andrew Or" + }, + "oid": "2f655b71f70c496c4e645f6cdb27d7bb7e825701" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "0c6dcaa7f58a19c42a530f4ee14bb6f0f03ca9fb" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "cad11c563d41ebcffb1683fe1f1288b8157413b3" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "jwtan@fb.com", + "name": "Jiewen Tan" + }, + "oid": "4dfd0875a68d87fccb5ad0d81692db480043b86e" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "2d37e74690582a4a26890e4c8b98f1f80e589c82" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "jwtan@fb.com", + "name": "Jiewen Tan" + }, + "oid": "d4aee60947e1a3ef23c7c42990621e0746fdd0a8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "peterbell10" + }, + "email": "peterbell10@live.co.uk", + "name": "Peter Bell" + }, + "oid": "aac6204bf710beb5e50a383d426ae6222396335a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "4b0362cab884584c24f5834b3874f5f357f56b5d" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "7536df613cbc645a9e68e6a3b0a8450753260fd1" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "20a50cb966d28d7bf82924adf781cf72a01ef90e" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "486387e8644afb46edff5aa5925b55c8119f67f0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "acb9d78b9b732d3667b881727e6ed9f92a8c549f" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "683bb7959a5b973f8470c081ad02e8fc508e784a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "qihqi" + }, + "email": "qihan@fb.com", + "name": "Han Qi" + }, + "oid": "a870cb40af65adf0b77d55f6b554d7093d284d7a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "Krovatkin" + }, + "email": "korovaikon@gmail.com", + "name": "Nikolay Korovaiko" + }, + "oid": "70793b9f328ddf52cc86336104c3a064c8582ef4" + } + }, + { + "commit": { + "author": { + "user": { + "login": "suo" + }, + "email": "suo@fb.com", + "name": "Michael Suo" + }, + "oid": "f70b31f62b1c5159eef2725484b175983517c88c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dagitses" + }, + "email": "mikeyd@fb.com", + "name": "Michael Andreas Dagitses" + }, + "oid": "04d3ec1db60defe1c6904bf77e9f8dfa87dc0b63" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "46b754a55b63e3168ad5854ad412c124934b675d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "robieta" + }, + "email": "taylorrobie@fb.com", + "name": "Taylor Robie" + }, + "oid": "13df69e13ee571fdd716139419a00aec47ade7d6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "70642e911ec80a47cdbf4a50aac475c11aa129b6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "59bb7c39384bf3e0b284a037adef8b3caa53c1c4" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "007cfb97b55d70ff63e1ed71d1a674638f847376" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "0a7b858a5af1393fa3cf2853f92eca0e1d408dde" + } + }, + { + "commit": { + "author": { + "user": { + "login": "qihqi" + }, + "email": "qihan@fb.com", + "name": "Han Qi" + }, + "oid": "7917d789f0a523715041ade5177d271082628236" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kit1980" + }, + "email": "sdym@fb.com", + "name": "Sergii Dymchenko (Meta Employee)" + }, + "oid": "91eb6017f0fb8a1b29e8cb48fac93bc9709f73b3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dagitses" + }, + "email": "mikeyd@fb.com", + "name": "Michael Andreas Dagitses" + }, + "oid": "bd04dca5fabb0c2a51ac87063a515f256ef274fa" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dagitses" + }, + "email": "mikeyd@fb.com", + "name": "Michael Andreas Dagitses" + }, + "oid": "1f805a5defda7dabc49d0059edb9ccb06bc29352" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@fb.com", + "name": "Mike Ruberry" + }, + "oid": "4982c0a8db8f23d15ec4bfcbca4ce939afc04954" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pearu" + }, + "email": "pearu.peterson@gmail.com", + "name": "Pearu Peterson" + }, + "oid": "28502265cb5925cb7db8dcb2dd2334963092714a" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "e03fcaedb1342e6d65c7f7f20243000938ba60b2" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pritamdamania" + }, + "email": "pritam.damania@fb.com", + "name": "pritam" + }, + "oid": "efb28f5a1a5d18aa96bd668ab2ab5c651be359f3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "MagiaSN" + }, + "email": "magialiao@tencent.com", + "name": "magialiao" + }, + "oid": "52cc1b9994f861ebdd3908759ed1ab11cba1f8de" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "3cd99f23d1acd6a5bedf6f3b02be79d64350a5b6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "awgu" + }, + "email": "andgu@fb.com", + "name": "Andrew Gu" + }, + "oid": "b00502c634a5146f4d996bd90e84d317f049e7b0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "davidberard98" + }, + "email": "dberard@fb.com", + "name": "David Berard" + }, + "oid": "976eb7cee799dddfbe6a4122b249aaee1b6c8854" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "9608ab28744d5cae32f371490557b248c9549c66" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "4e119f0c39eb5ff0777f0e71561e6b633d85fb34" + } + }, + { + "commit": { + "author": { + "user": { + "login": "rohan-varma" + }, + "email": "rvarm1@fb.com", + "name": "Rohan Varma" + }, + "oid": "447580dc565f3660eddb2c996c6ed25b88338684" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "2bc8f43e9233008ea23053fab87b83ab36fca5e3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "c13a8e891c3e3e714f60649ca1e3b082e090e9fe" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "fddc861b7ee473f57d3c2161e4618a2663a237e8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "jiyuanzFB" + }, + "email": "jiyuanz@fb.com", + "name": "Jiyuan Zhang" + }, + "oid": "e2336dbc539d6c021720cbe43c92c9e4c8463299" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bdhirsh" + }, + "email": "hirsheybar@fb.com", + "name": "Brian Hirsh" + }, + "oid": "26e2759d1ad59aac12168b74d1ca55e42ba9455c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bdhirsh" + }, + "email": "hirsheybar@fb.com", + "name": "Brian Hirsh" + }, + "oid": "ad7aa914ee3b3d1252e31514f010ba96c40aae87" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bdhirsh" + }, + "email": "hirsheybar@fb.com", + "name": "Brian Hirsh" + }, + "oid": "f113c5d78065aafbe7b1c0e611945bfe9f67b3c0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bdhirsh" + }, + "email": "hirsheybar@fb.com", + "name": "Brian Hirsh" + }, + "oid": "a366fd01136292544b7862968ae92feba4b6d8fe" + } + }, + { + "commit": { + "author": { + "user": { + "login": "seemethere" + }, + "email": "eliuriegas@fb.com", + "name": "Eli Uriegas" + }, + "oid": "afeba0773749da5883c378a2e6ac066e1ce62ca0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bdhirsh" + }, + "email": "hirsheybar@fb.com", + "name": "Brian Hirsh" + }, + "oid": "d306c99addc543908f64666baeecacbd0749f4a7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "awgu" + }, + "email": "andgu@fb.com", + "name": "Andrew Gu" + }, + "oid": "c2456ea658f41f64ea054a422edf22a9c977399f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "awgu" + }, + "email": "andgu@fb.com", + "name": "Andrew Gu" + }, + "oid": "a8b0a1b681c9fe41e0d553c962a5c93e81d92503" + } + }, + { + "commit": { + "author": { + "user": { + "login": "anjali411" + }, + "email": "chourdiaanjali123@gmail.com", + "name": "anjali411" + }, + "oid": "af761d9a5d058c9188f16589bae4f307d35185be" + } + }, + { + "commit": { + "author": { + "user": { + "login": "clee2000" + }, + "email": "csl@fb.com", + "name": "Catherine Lee" + }, + "oid": "beceb417baef35b15c2716e23178fb49f7fd6f9d" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "1516554e22136db89d0aeba43a1a1a987e995d68" + } + }, + { + "commit": { + "author": { + "user": { + "login": "qihqi" + }, + "email": "qihan@fb.com", + "name": "Han Qi" + }, + "oid": "68eb1fa8374eff6cbdcf0be5e37ed6775d22e722" + } + }, + { + "commit": { + "author": { + "user": { + "login": "janeyx99" + }, + "email": "janeyx@fb.com", + "name": "Jane Xu" + }, + "oid": "3c7bcb99b5c0c879c2610f427880b03881f82f38" + } + }, + { + "commit": { + "author": { + "user": { + "login": "janeyx99" + }, + "email": "janeyx@fb.com", + "name": "Jane Xu" + }, + "oid": "38c1a2028090353e40a019c673c9ab16b39e4825" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "8091cbea2c95ed2c4c406b3c61547a27c6319bae" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "d81f59121969a47c8b2213a88e02cf9be0219be9" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "20d798b319cd107a767fe220f7a3027c18a1c844" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "eb35381a770b58c1cd41e935910cb4df2f3d8f14" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "e6498a657b9aa47546dcd92d1b4ffb2e1a50ebdb" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "7f821382db5ad08efe5b09a145c606852b8a9272" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "995c0e11a97d854ff969962bd81d7341e46ecb07" + } + }, + { + "commit": { + "author": { + "user": { + "login": "davidberard98" + }, + "email": "dberard@fb.com", + "name": "David Berard" + }, + "oid": "28d6258e62c9fc361a18689877c962c69889dc23" + } + }, + { + "commit": { + "author": { + "user": { + "login": "HarborYuan" + }, + "email": "yuanhaobo@whu.edu.cn", + "name": "Haobo Yuan" + }, + "oid": "2350fad8391367ebf81c7236a2c883644b4ff622" + } + }, + { + "commit": { + "author": { + "user": { + "login": "zou3519" + }, + "email": "zou3519@gmail.com", + "name": "Richard Zou" + }, + "oid": "3f789c9ccecdd7e2e52269453646e992a68c6b92" + } + }, + { + "commit": { + "author": { + "user": { + "login": "jeffdaily" + }, + "email": "jeff.daily@amd.com", + "name": "Jeff Daily" + }, + "oid": "20f79f610c1a3314da96d49515bbfbee9442e4f8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "peterbell10" + }, + "email": "peterbell10@live.co.uk", + "name": "Peter Bell" + }, + "oid": "5823958f047f3b71a5dc8c52a20eb8ae3291bd3e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "peterbell10" + }, + "email": "peterbell10@live.co.uk", + "name": "Peter Bell" + }, + "oid": "a0b15c49ecf3844daf2c0dcaef44f0214259db20" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "4afc38c25ca2ca126ba4987a419a58a5c572223b" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "b606f58d4a36683fbe0a7d02adfdde7d5cc694c2" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "2d61b4d630f6482a6c3cc7437091fad6d27c347e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "george-qi" + }, + "email": "georgeqi94@gmail.com", + "name": "George Qi" + }, + "oid": "bc5384c47036a6cda94129f3e2f9e43c43393698" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "60fc3277634365b64465712b13db2acb76d6c890" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "1b8762e95bc38d1847fe99ed3230546c8b800bfd" + } + }, + { + "commit": { + "author": { + "user": { + "login": "jerryzh168" + }, + "email": "jerryzh168@gmail.com", + "name": "Jerry Zhang" + }, + "oid": "6acf60f95f59ecbc6e8ce830dea0abba7d3ec763" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ysiraichi" + }, + "email": "yukio.siraichi@gmail.com", + "name": "Yukio Siraichi" + }, + "oid": "8fb0276561fdd530c5a06ea195e930e0584f8705" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "1da7aed95a8700406671425eac1e4bbc2c7a24b5" + } + }, + { + "commit": { + "author": { + "user": { + "login": "thiagocrepaldi" + }, + "email": "thiago.crepaldi@microsoft.com", + "name": "Thiago Crepaldi" + }, + "oid": "83208e7dee4503c1bee1df9f6632794694dffa01" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "1a46cf08dcd3d3564604c17b2c02d7e4eb45a7ff" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "b7f9b6689445f826c83694652fea5f7cfc7070d7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "fatcat-z" + }, + "email": "jiz@microsoft.com", + "name": "Jay Zhang" + }, + "oid": "f273961c1696b156e35f8c76f7ad37934031050d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pavithranrao" + }, + "email": "pavithran@fb.com", + "name": "Pavithran Ramachandran" + }, + "oid": "eb410a51fcbc716873fd80a970eb932d4aaaea61" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "7dbb12cdc02332fa64264ed0df576511a5070d7e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "43675665fa6b5154de8b25125dd03d7be35c884f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "6c4d23c402c413667463770d9a2fa801f493d3c5" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "cf3778a35129a40dee14366515201b7ed2c0f346" + } + }, + { + "commit": { + "author": { + "user": { + "login": "dzdang" + }, + "email": "dzdang@umich.edu", + "name": "dzdang" + }, + "oid": "9d00a051373cb81f79cb6375942cf3ec9fff2fe6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "pytorchmergebot" + }, + "email": "pytorchmergebot@users.noreply.github.com", + "name": "PyTorch MergeBot" + }, + "oid": "1eae67cf404aa8dffb80b8e85180f943878d52a6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "janeyx99" + }, + "email": "janeyx@fb.com", + "name": "Jane Xu" + }, + "oid": "ce0e69dcda0fe41a6e964d6ac70ce8016979c71a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "swolchok" + }, + "email": "swolchok@fb.com", + "name": "Scott Wolchok" + }, + "oid": "6faba554f6e49777f24911928edb3061b6ed0e3d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "IvanYashchuk" + }, + "email": "ivan.yashchuk@aalto.fi", + "name": "Ivan Yashchuk" + }, + "oid": "d1d0e03f57a359f8f95331f9a34b8bed3e7cc845" + } + }, + { + "commit": { + "author": { + "user": { + "login": "Chillee" + }, + "email": "chilli@fb.com", + "name": "Horace He" + }, + "oid": "bb46bd9233a9fc631802a902cb48a4c13c2722ca" + } + }, + { + "commit": { + "author": { + "user": { + "login": "mehtanirav" + }, + "email": "niravmehta@fb.com", + "name": "Nirav Mehta" + }, + "oid": "3b1007fe4be12e483f2620fbac67cae42e703efc" + } + }, + { + "commit": { + "author": { + "user": { + "login": "mehtanirav" + }, + "email": "niravmehta@fb.com", + "name": "Nirav Mehta" + }, + "oid": "b4b65228dd0c109f5fdf17c7d9e56f60a98e398b" + } + }, + { + "commit": { + "author": { + "user": { + "login": "albanD" + }, + "email": "albandes@fb.com", + "name": "Alban Desmaison" + }, + "oid": "d629e300705196d3ae0bac5ed983b197101fa2ee" + } + }, + { + "commit": { + "author": { + "user": { + "login": "bigfootjon" + }, + "email": "jonjanzen@fb.com", + "name": "Jon Janzen" + }, + "oid": "52754b9e515f378f8476ad44d75b0a692bad8cde" + } + }, + { + "commit": { + "author": { + "user": { + "login": "samdow" + }, + "email": "samdow@fb.com", + "name": "samdow" + }, + "oid": "128c3ad747093f4970329a82c7c4720420faeff2" + } + }, + { + "commit": { + "author": { + "user": { + "login": "arindamroy-eng" + }, + "email": "61168652+arindamroy-eng@users.noreply.github.com", + "name": "arindamroy-eng" + }, + "oid": "2a0bda7d32a5bcc9827f7254a7b77cceb16ba973" + } + } + ], + "pageInfo": { + "endCursor": "MTAw", + "hasNextPage": true + }, + "totalCount": 131 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuNRg4=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693698" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRAI=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693712" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRBA=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693725" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRB0=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693741" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRC0=" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693761" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsREE=" + }, + { + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193693774" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRE4=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192463/jobs/3232430975" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuNR-Y=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193694412" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRsw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461134" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461211" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461301" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461386" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461521" + }, + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461634" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192461/jobs/3232461717" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuN84s=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193694417" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRtE=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232460797" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232460951" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461088" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461294" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461410" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461543" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461628" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461719" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461789" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461869" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232461946" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462044" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462112" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462244" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462360" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462432" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462521" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462621" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462683" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232462738" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232545510" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232545571" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547522" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547612" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547714" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547764" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547824" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547869" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547909" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232547973" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232553452" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232553558" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232553605" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232553650" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232563716" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232563763" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232582650" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232582703" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232582741" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232590204" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232608872" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232608976" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232637097" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232637199" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 1, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232637259" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232639932" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232687012" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232687074" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232785088" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2197192471/jobs/3232785153" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAWuVD9M=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/5696e8357cf38f852ef3d680381513e26f202371/checks?check_suite_id=6193694439" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXEsRuc=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": null, + "pushedDate": "2022-04-20T17:10:41Z", + "oid": "5696e8357cf38f852ef3d680381513e26f202371" + } + } + ] + }, + "changedFiles": 348, + "files": { + "nodes": [ + { + "path": ".circleci/cimodel/data/pytorch_build_data.py" + }, + { + "path": ".circleci/cimodel/data/pytorch_build_definitions.py" + }, + { + "path": ".circleci/scripts/cpp_doc_push_script.sh" + }, + { + "path": ".circleci/scripts/python_doc_push_script.sh" + }, + { + "path": ".github/actions/checkout-pytorch/action.yml" + }, + { + "path": ".github/merge_rules.json" + }, + { + "path": ".github/scripts/gitutils.py" + }, + { + "path": ".github/scripts/gql_mocks.json" + }, + { + "path": ".github/scripts/trymerge.py" + }, + { + "path": ".github/workflows/_bazel-build-test.yml" + }, + { + "path": ".github/workflows/_linux-build.yml" + }, + { + "path": ".github/workflows/_linux-test.yml" + }, + { + "path": ".github/workflows/_mac-test.yml" + }, + { + "path": ".github/workflows/_rocm-test.yml" + }, + { + "path": ".github/workflows/_win-test.yml" + }, + { + "path": ".github/workflows/buck_build_test.yml" + }, + { + "path": ".github/workflows/lint.yml" + }, + { + "path": ".github/workflows/periodic.yml" + }, + { + "path": ".github/workflows/pull.yml" + }, + { + "path": ".github/workflows/trunk.yml" + }, + { + "path": ".jenkins/pytorch/macos-test.sh" + }, + { + "path": ".jenkins/pytorch/test.sh" + }, + { + "path": ".jenkins/pytorch/win-test.sh" + }, + { + "path": ".lintrunner.toml" + }, + { + "path": "BUILD.bazel" + }, + { + "path": "CODEOWNERS" + }, + { + "path": "README.md" + }, + { + "path": "aten/src/ATen/BatchingRegistrations.cpp" + }, + { + "path": "aten/src/ATen/Dispatch.h" + }, + { + "path": "aten/src/ATen/ExpandUtils.h" + }, + { + "path": "aten/src/ATen/FunctionalInverses.cpp" + }, + { + "path": "aten/src/ATen/FunctionalStorageImpl.cpp" + }, + { + "path": "aten/src/ATen/FunctionalStorageImpl.h" + }, + { + "path": "aten/src/ATen/FunctionalTensorWrapper.cpp" + }, + { + "path": "aten/src/ATen/FunctionalTensorWrapper.h" + }, + { + "path": "aten/src/ATen/FunctionalizeFallbackKernel.cpp" + }, + { + "path": "aten/src/ATen/NestedTensorImpl.cpp" + }, + { + "path": "aten/src/ATen/OpMathType.h" + }, + { + "path": "aten/src/ATen/SparseCsrTensorUtils.h" + }, + { + "path": "aten/src/ATen/ThreadLocalState.cpp" + }, + { + "path": "aten/src/ATen/ThreadLocalState.h" + }, + { + "path": "aten/src/ATen/autocast_mode.cpp" + }, + { + "path": "aten/src/ATen/autocast_mode.h" + }, + { + "path": "aten/src/ATen/core/SymIntArrayRef.cpp" + }, + { + "path": "aten/src/ATen/core/SymIntArrayRef.h" + }, + { + "path": "aten/src/ATen/core/TensorBase.h" + }, + { + "path": "aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h" + }, + { + "path": "aten/src/ATen/core/dispatch/Dispatcher.h" + }, + { + "path": "aten/src/ATen/core/interned_strings.h" + }, + { + "path": "aten/src/ATen/core/ivalue.cpp" + }, + { + "path": "aten/src/ATen/core/ivalue.h" + }, + { + "path": "aten/src/ATen/core/ivalue_inl.h" + }, + { + "path": "aten/src/ATen/core/jit_type.h" + }, + { + "path": "aten/src/ATen/core/jit_type_base.h" + }, + { + "path": "aten/src/ATen/core/type.cpp" + }, + { + "path": "aten/src/ATen/cuda/CUDASparse.h" + }, + { + "path": "aten/src/ATen/cuda/llvm_complex.cpp" + }, + { + "path": "aten/src/ATen/cuda/llvm_jit_strings.h" + }, + { + "path": "aten/src/ATen/native/Blas.cpp" + }, + { + "path": "aten/src/ATen/native/Itertools.cpp" + }, + { + "path": "aten/src/ATen/native/LinearAlgebra.cpp" + }, + { + "path": "aten/src/ATen/native/SoftMax.cpp" + }, + { + "path": "aten/src/ATen/native/TensorConversions.cpp" + }, + { + "path": "aten/src/ATen/native/TensorShape.cpp" + }, + { + "path": "aten/src/ATen/native/TensorShape.h" + }, + { + "path": "aten/src/ATen/native/Unique.cpp" + }, + { + "path": "aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu" + }, + { + "path": "aten/src/ATen/native/cuda/CUDAJitLoops.cuh" + }, + { + "path": "aten/src/ATen/native/cuda/JitLoops.cuh" + }, + { + "path": "aten/src/ATen/native/cuda/Lerp.cu" + }, + { + "path": "aten/src/ATen/native/cuda/PersistentSoftmax.cuh" + }, + { + "path": "aten/src/ATen/native/cuda/SoftMax.cu" + }, + { + "path": "aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu" + }, + { + "path": "aten/src/ATen/native/cuda/Unique.cu" + }, + { + "path": "aten/src/ATen/native/cuda/jit_utils.cpp" + }, + { + "path": "aten/src/ATen/native/cuda/jit_utils.h" + }, + { + "path": "aten/src/ATen/native/native_functions.yaml" + }, + { + "path": "aten/src/ATen/native/nested/NestedTensorMath.cpp" + }, + { + "path": "aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp" + }, + { + "path": "aten/src/ATen/native/quantized/cpu/qsoftmax.cpp" + }, + { + "path": "aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp" + }, + { + "path": "aten/src/ATen/native/quantized/cudnn/Linear.cpp" + }, + { + "path": "aten/src/ATen/native/quantized/cudnn/utils.h" + }, + { + "path": "aten/src/ATen/native/sparse/SparseCsrTensor.cpp" + }, + { + "path": "aten/src/ATen/native/ts_native_functions.yaml" + }, + { + "path": "aten/src/ATen/record_function.cpp" + }, + { + "path": "aten/src/ATen/record_function.h" + }, + { + "path": "aten/src/ATen/templates/Operators.h" + }, + { + "path": "aten/src/ATen/templates/RegisterFunctionalization.cpp" + }, + { + "path": "aten/src/ATen/test/basic.cpp" + }, + { + "path": "aten/src/ATen/test/vmap_test.cpp" + }, + { + "path": "binaries/record_function_benchmark.cc" + }, + { + "path": "c10/core/DispatchKey.cpp" + }, + { + "path": "c10/core/DispatchKey.h" + }, + { + "path": "c10/core/DispatchKeySet.h" + }, + { + "path": "c10/test/core/DispatchKeySet_test.cpp" + }, + { + "path": "c10/util/ArrayRef.h" + }, + { + "path": "caffe2/core/tensor.h" + }, + { + "path": "docs/source/conf.py" + }, + { + "path": "docs/source/fx.rst" + } + ], + "pageInfo": { + "endCursor": "MTAw", + "hasNextPage": true + } + }, + "reviews": { + "nodes": [], + "pageInfo": { + "startCursor": null, + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Merge failed due to Matched rule superuser, but it was not reviewed yet by any of:zou3519,abhikrish,mehtanirav,wconstab,lc0, ...", + "createdAt": "2022-04-20T17:26:18Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1104215370 + }, + { + "bodyText": "Merge failed due to Matched rule superuser, but PR has not been reviewed yet", + "createdAt": "2022-04-20T17:31:26Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1104220908 + }, + { + "bodyText": "@pytorchbot merge this", + "createdAt": "2022-04-20T19:30:50Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1104378397 + }, + { + "bodyText": "Merge failed due to Matched rule superuser, but PR has not been reviewed yet\nRaised by https://github.com/pytorch/pytorch/actions/runs/2197877090", + "createdAt": "2022-04-20T19:32:10Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1104379712 + }, + { + "bodyText": "Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale. Feel free to remove the Stale label if you feel this was a mistake. If you are unable to remove the Stale label please contact a maintainer in order to do so. If you want the bot to never mark this PR stale again, add the no-stale label.Stale pull requests will automatically be closed after 30 days of inactivity.", + "createdAt": "2022-06-20T16:44:05Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1160658699 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQdD9Sg==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Stale" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=76123 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "kumpera" + }, + "title": "Introduce distributed checkpoint with ShardedTensor.", + "body": "Co-authored-by: Wen Zhang \r\nCo-authored-by: Yifu Wang \r\n\r\n", + "headRefName": "st_checkpoint", + "headRepository": { + "nameWithOwner": "kumpera/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "kumpera" + }, + "email": "kumpera@fb.com", + "name": "Rodrigo Kumpera" + }, + "oid": "6bf248bc20a71f248064b795f38276326fe43aae" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kumpera" + }, + "email": "kumpera@fb.com", + "name": "Rodrigo Kumpera" + }, + "oid": "10f84fb90bf02d7062e565ebf2c1da6352b64db7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kumpera" + }, + "email": "kumpera@fb.com", + "name": "Rodrigo Kumpera" + }, + "oid": "96c5299740ec791f3cf0975c03a40a7b219b6747" + } + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + }, + "totalCount": 3 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXgS2l4=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6380755666" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXxSmtI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063614/jobs/3379894109" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXd2r3Q=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6380755785" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXxSm0k=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894107" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894332" + }, + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894444" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894520" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894567" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894616" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063615/jobs/3379894672" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXd2shU=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6380755786" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXxSm0o=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902301" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902363" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902507" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902560" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902579" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902603" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902637" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902685" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902740" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902761" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902794" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379902874" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903006" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903111" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903193" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903284" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903357" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903446" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903512" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379903546" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379944655" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379944695" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946308" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946337" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946359" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946391" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946423" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946453" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946496" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379946529" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379950041" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379950137" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379950165" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379950192" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379950646" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379951202" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379951230" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379963877" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379963928" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379963976" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 4, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379964018" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379966372" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379996173" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379996218" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379997861" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379998374" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379998397" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379998422" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3379998441" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2273063632/jobs/3380042106" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXd5yuY=", + "hasNextPage": true + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6380755806" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXxSm14=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387419477" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387419699" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387419923" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387419992" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387420129" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387420208" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796859/jobs/3387420309" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXgS3SE=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6390363240" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXzlNGg=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796862/jobs/3387419465" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXgS1-o=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6390363271" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXzlNIc=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387419999" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387420164" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387420316" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387420477" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387420675" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387420934" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387421278" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387421672" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387421888" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387421982" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387422191" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387422303" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387422476" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387422715" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387422963" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387423092" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387423234" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387423421" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387423622" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387423739" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387545789" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387546032" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387546119" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553028" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553144" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553251" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553438" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553556" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387553668" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387554002" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387554098" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387558927" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387559016" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387559071" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387559139" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387563803" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387563894" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387580868" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387580936" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387580993" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 4, 4, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387581053" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387592286" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387631950" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387632035" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387649916" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387649974" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387650084" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387650151" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387650373" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2276796865/jobs/3387753429" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAXgaCXo=", + "hasNextPage": true + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/96c5299740ec791f3cf0975c03a40a7b219b6747/checks?check_suite_id=6390363300" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXzlNKQ=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": null, + "pushedDate": "2022-05-05T00:34:26Z", + "oid": "96c5299740ec791f3cf0975c03a40a7b219b6747" + } + } + ] + }, + "changedFiles": 11, + "files": { + "nodes": [ + { + "path": "test/distributed/_shard/checkpoint/test_checkpoint.py" + }, + { + "path": "test/distributed/_shard/checkpoint/test_file_system_checkpoint.py" + }, + { + "path": "test/distributed/_shard/sharded_tensor/test_sharded_tensor.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/__init__.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/filesystem.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/metadata.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/resharding.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/state_dict_loader.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/state_dict_saver.py" + }, + { + "path": "torch/distributed/_shard/checkpoint/storage.py" + }, + { + "path": "torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py" + } + ], + "pageInfo": { + "endCursor": "MTE", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "wanchaol" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "simpkins" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zzzwen" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "simpkins" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "simpkins" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "wilson100hong" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "wilson100hong" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "wilson100hong" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "DISMISSED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "xunnanxu" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pritamdamania87" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kumpera" + }, + "state": "COMMENTED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNC0yNVQxMTozNTowMS0wNzowMLkyMDIyLTA0LTI1VDExOjM1OjAwLTA3OjAwzjjC2d0=", + "hasPreviousPage": true + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Merge failed due to Can't fetch all PR reviews\nRaised by https://github.com/pytorch/pytorch/actions/runs/2275691136", + "createdAt": "2022-05-05T12:35:49Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1118495479 + }, + { + "bodyText": "Merge failed due to Can't fetch all PR reviews\nRaised by https://github.com/pytorch/pytorch/actions/runs/2275691136", + "createdAt": "2022-05-05T12:53:15Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1118511287 + }, + { + "bodyText": "Merge failed due to Can't fetch all PR reviews\nRaised by https://github.com/pytorch/pytorch/actions/runs/2275691136", + "createdAt": "2022-05-05T15:00:08Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1118662274 + }, + { + "bodyText": "Merge failed due to Can't fetch all PR reviews Raised by https://github.com/pytorch/pytorch/actions/runs/2275691136\n\n@osalpekar @malfet This is failing because there are 109 review comments on this PR but we only fetch the first 100. This could be solved with a similar concept as how we fetch more comments/check_runs.", + "createdAt": "2022-05-05T15:20:46Z", + "author": { + "login": "janeyx99" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1118689010 + }, + { + "bodyText": "On a side note, has the test_fsdp_clip_grad_norm_norm_type_2_0_nested_fsdp_False_cpu_offload_CPUOffload failure on the distributed test first shard of this PR been addressed?", + "createdAt": "2022-05-05T15:24:08Z", + "author": { + "login": "janeyx99" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1118693497 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQqri9w==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: distributed" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=71759 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "coolteemf" + }, + "title": "Optimize grid sample 3d", + "body": "Fixes #71415\r\nI have implemented the changes that replicate what @to-mi did in this [PR](https://github.com/pytorch/pytorch/pull/65986#issue-1012959443) for the 3D case :\r\n\r\n> Fixes #64977\r\n> \r\n> Avoids creating a tensor for and calculating `input` gradient if it's not needed in the backward pass of `grid_sample` (2d case, native CPU & CUDA kernels). Especially the tensor creation seemed time consuming (see #64977).\r\n> \r\n> Brief description of the changes:\r\n> \r\n> * I have tried to go with rather minimal changes. It would probably be possible to make a more elegant version with a bit larger refactoring (or possibly with better understanding of PyTorch internals and C++ functionalities).\r\n> \r\n> * Changed the `native_functions.yaml` and `derivatives.yaml` so that the gradient input mask is passed to the functions.\r\n> \r\n> * Changed the CPU kernels:\r\n> (1) added `bool input_requires_grad` template parameter to the `backward` function,\r\n> (2) added if branches based on it to remove `input` gradient computations if it's not requested,\r\n> (3) feed in `TensorAccessor* gInp_slice_ptr` instead of `TensorAccessor& gInp_slice` so that I can pass a `nullptr` in case gradient for `input` is not requested. (A bit inelegant perhaps, but allows to keep one signature for `backward` function and not require breaking it to smaller pieces. Perhaps there's a more elegant way to achieve this?)\r\n> \r\n> * Changed CUDA kernel:\r\n> (1) added ~`bool input_requires_grad` template parameter~ `const bool input_requires_grad` argument to the `backward` function,\r\n> (2) added if branches based on it to remove `input` gradient computations if it's not requested,\r\n> (3) feed in `TensorInfo()` instead of `getTensorInfo(grad_input)` in case gradient for `input` is not requested.\r\n> \r\n> * Modified tests in `test/test_nn.py` so that they run also cases with no `input` gradient needed.\r\n> \r\n> * Have not touched the CPU fallback kernel.\r\n\r\nNote: the changes number (3) are N/A in this case.\r\n\r\n", + "headRefName": "optimize_grid_sample_3d", + "headRepository": { + "nameWithOwner": "coolteemf/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "e0b0d1e695aeddceaf265da602c4704592053e9e" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "563ec73747ad53b63b36736c47c4342f962c2a09" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "51abe41a132d9dd5b1c0551bdca902aacc028ff8" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "be9898205992034a00e8ace8a55c2ecdcee2c2f8" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "2929c60b64384c2deae0f7dea8bab94ad4bc9ec8" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "9241b737e7e2b257905cc74ad9c50b737d7f9d0a" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "64d6b795d0636928a8aa2fd3da01302fb5f5f7af" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "4503577e53760a0006f1e80ca6bfe04d2be90470" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "b16f4b11ffbbbf2ca2098f9702af4ef6b6fc5e1f" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "7ffc23368a604afdc92d2818747f730ce31a2bb5" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "b85292604b9ad6c31706b76b5a5498c4f6d94309" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "9d81d7bae8ad91aaa24b3ceab83e3138894dbc69" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "e79f6a2202512b294c55bf4bfb2e0524fafd4c48" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "f683e8aec7aea76097a264eec01511e704c31154" + } + }, + { + "commit": { + "author": { + "user": { + "login": "coolteemf" + }, + "email": "67541941+coolteemf@users.noreply.github.com", + "name": "Fran\u00e7ois Lecomte" + }, + "oid": "b932e9e286c22aaf352375186df851ef060b295a" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "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", + "name": "coolteemf" + }, + "oid": "346e0c547953d98eb84d23c1391a95badb9c4a22" + } + } + ], + "pageInfo": { + "endCursor": "MTY", + "hasNextPage": false + }, + "totalCount": 16 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwGYqY=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801320" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_T6g=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-clang7-onnx" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754066/jobs/2663109808" + }, + { + "name": "test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754066/jobs/2663214802" + }, + { + "name": "test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754066/jobs/2663214856" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwIob0=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801849" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Ubk=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3-clang5-mobile-build" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754064/jobs/2663109676" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwGZ1E=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801852" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Ubw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-bionic-rocm4.5-py3.7" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754065/jobs/2663109684" + }, + { + "name": "test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754065/jobs/2663401083" + }, + { + "name": "test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754065/jobs/2663401143" + }, + { + "name": "test (distributed, 1, 1, linux.rocm.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754065/jobs/2663401186" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwMsZY=", + "hasNextPage": false + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801853" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Ub0=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "win-vs2019-cuda11.3-py3" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754068/jobs/2663109680" + }, + { + "name": "test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754068/jobs/2663995756" + }, + { + "name": "test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754068/jobs/2663995819" + }, + { + "name": "test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754068/jobs/2663995900" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwZbzg=", + "hasNextPage": false + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801855" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Ub8=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "mypy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663109683" + }, + { + "name": "shellcheck", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663109827" + }, + { + "name": "py2-setup-validate-errormsg", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663109962" + }, + { + "name": "clang-format", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110044" + }, + { + "name": "cmakelint", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110132" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110233" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110320" + }, + { + "name": "clang-tidy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110461" + }, + { + "name": "flake8-py3", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754069/jobs/2663110575" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwGbAQ=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801856" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_UcA=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-clang7-asan" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754070/jobs/2663109804" + }, + { + "name": "test (default, 3, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754070/jobs/2663233675" + }, + { + "name": "test (default, 1, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754070/jobs/2663233731" + }, + { + "name": "test (default, 2, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754070/jobs/2663233805" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwJC4U=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801857" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_UcE=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754076/jobs/2663109810" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwGZ_w=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801862" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_UcY=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc5.4" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663109777" + }, + { + "name": "test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201383" + }, + { + "name": "test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201458" + }, + { + "name": "test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201512" + }, + { + "name": "test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201580" + }, + { + "name": "test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201672" + }, + { + "name": "test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754078/jobs/2663201839" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwIWu4=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801866" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Uco=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1886754079/jobs/2663109681" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATwGZ1k=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/346e0c547953d98eb84d23c1391a95badb9c4a22/checks?check_suite_id=5414801869" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAUK_Uc0=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17017798?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17017799?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17017816?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17017800?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + } + ] + }, + "pushedDate": "2022-02-23T10:39:30Z", + "oid": "346e0c547953d98eb84d23c1391a95badb9c4a22" + } + } + ] + }, + "changedFiles": 9, + "files": { + "nodes": [ + { + "path": "aten/src/ATen/native/GridSampler.cpp" + }, + { + "path": "aten/src/ATen/native/cpu/GridSamplerKernel.cpp" + }, + { + "path": "aten/src/ATen/native/cuda/GridSampler.cpp" + }, + { + "path": "aten/src/ATen/native/cuda/GridSampler.cu" + }, + { + "path": "aten/src/ATen/native/cuda/GridSampler.h" + }, + { + "path": "aten/src/ATen/native/native_functions.yaml" + }, + { + "path": "test/forward_backward_compatibility/check_forward_backward_compatibility.py" + }, + { + "path": "test/test_nn.py" + }, + { + "path": "tools/autograd/derivatives.yaml" + } + ], + "pageInfo": { + "endCursor": "OQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "coolteemf" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "albanD" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "albanD" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wMS0yNVQwODoyODoxMC0wODowMLkyMDIyLTAxLTI1VDA3OjU0OjA1LTA4OjAwzjNooqI=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Merge failed due to 'NoneType' object is not subscriptable\nRaised by https://github.com/pytorch/pytorch/actions/runs/1887945630", + "createdAt": "2022-02-23T14:55:36Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048868910 + }, + { + "bodyText": "Thanks for the update! The windows failure is not your fault, you can ignore it!\n\nThank you very much for all of your feedback and sorry for the delay !", + "createdAt": "2022-02-23T16:44:36Z", + "author": { + "login": "coolteemf" + }, + "authorAssociation": "CONTRIBUTOR", + "editor": null, + "databaseId": 1048983572 + }, + { + "bodyText": "@coolteemf can you please send either me or @albanD an email? (or I can send you and invite to collab on private repo)", + "createdAt": "2022-02-23T17:49:55Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1049048119 + }, + { + "bodyText": "@pytorchbot merge this please", + "createdAt": "2022-02-23T19:23:55Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1049131992 + }, + { + "bodyText": "Hey @coolteemf.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-02-23T19:26:51Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1049134520 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOPoR4Lg==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "release notes: nn" + } + }, + { + "node": { + "name": "topic: performance" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=75095 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "mruberry" + }, + "title": "Initial prims, references, and test architecture for them", + "body": "This PR adds an initial set of experimental primitive operations and Python references that reimplement existing PyTorch operations using them. See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 for additional context.\r\n\r\nThe following experimental primitives are added:\r\n\r\n- Elementwise unary prims -- abs, acos, acosh, asin, atan, cos, cosh, bessel_i0e, bessel_i1e, cbrt, ceil, digamma, erf, erf_inv, erfc, exp, expm1, floor, igamma, igammac, is_finite, lgamma, log, log1p, neg, reciprocal, round, sign, sinh, sqrt, square, tan. \r\n- Elementwise binary prims -- add, atan2, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, div, eq, ge, gt, le, lt, max, min, mul, ne, nextafter, pow, rsqrt, shift_left, shift_right_arithmetic\r\n- View prims -- brodcast_in_dim, collapse_view, split_dim, squeeze\r\n- Shape prims -- collapse, concatenate, reshape\r\n- Conditional prims -- select\r\n- Data conversion & movement prims -- convert_element_type, device_put\r\n- Inplace prims -- copy_to, resize\r\n\r\nThese primitives do not add any new functionality to PyTorch, but are intended to be the semantic building blocks for reference operators. We have tried to make them consistent with the operations in [jax.lax](https://jax.readthedocs.io/en/latest/jax.lax.html) where possible (because PyTorch prefers being consistent with other frameworks), although there are key differences between these prims and operations in jax.lax. Most notably is that these prims model view semantics and inplace operations.\r\n\r\nIn addition to these primitives the following elementwise binary Python references are added:\r\n\r\n- Elementwise binary Python references -- add, atan2, bitwise_and, bitwise_left_shift, bitwise_or, bitwise_right_shift, bitwise_xor, eq, float_power, ge, gt, le, lt, maximum, minimum, mul, ne, nextafter, pow, sub, true_divide\r\n- Conditional Python references - where\r\n- Data conversion & movement references - copy_to\r\n\r\nA Python reference implements the same behavior as its corresponding PyTorch operator (excepting slight numerical differences, bug fixes, and in some cases additional features). \r\n\r\nThe start of an OpInfo-based test architecture for these references is also included in this PR. A new list, `python_ref_db`, is added to `common_methods_invocations.py`. This list introduces the new `ElementwiseBinaryPythonRefInfo`, which inherits input arguments from the original operators' OpInfo, allows them to be overridden, and then constructs the OpInfo for the Python reference using the (potentially modified) arguments. OpInfo-based tests can opt-into testing references by including this new list in the Sequence passed to the `@ops` decorator. \r\n\r\ncc @ngimel @csarofeen @kevinstephano @Lezcano ", + "headRefName": "prims_and_references", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "a790467c650be92775103cde5e866c90b56f5376" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "bd6fcf50692e208ebecdc2eaa517a2bfcdcd35cf" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "4a119c8f21529fe1375e7e8789b91f41a3df80c5" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "ea6750dc34d66be759fdfe84b09fb0e23ee59c79" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "2eef8a55fe0227e1921b51bf1f56f9d0a29b49ac" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "b886ed6c20dd1785fd31ed6fa6a8c5b6d0d0b16c" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "9ad9b63d09aa4f7a8549bcf1d88ea4ff0674299c" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "63fdd580118477416ae160e0670ae722ea248090" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "0ccf7dc292af1d40d0a094eb2b2fb0c7ab4ccc70" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "e8a8a4d1fbe35f20eb88e1a43cf5a653883638e5" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "186634dfdd25645c05b58a212f9e8d77c4125fc0" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "f5b4741312b5c42a79f6c8a1d3930b79db38ed8f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "23d50391bb0fd12111fd3171591c4235ffb2fc1a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "bac9d45422d58f513b60b4b854441cfdc253d4c5" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "13240ae0b4a0332c3167b65ac026a3172da90cb7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "1ee34468cb1db3dc6cbae204669f4fec20e2a466" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ezyang" + }, + "email": "ezyang@fb.com", + "name": "Edward Z. Yang" + }, + "oid": "561d132bc686d00e8911f7feb3da5901b2bdc574" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "ac42bedc84b7c96256376ad09917263bb020b2c3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "7f7d5ba40a0b5e10526d90b018b30b54673d12d8" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "37a6b4a8b1adb712d5777c7c3479866c27fb3c4e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "65b613868c44e519c1777af79b9fd3498c5a7e58" + } + }, + { + "commit": { + "author": { + "user": { + "login": "ngimel" + }, + "email": "ngimel@fb.com", + "name": "Natalia Gimelshein" + }, + "oid": "442c405e9da0d66744ef03e379224c41eedf5b57" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "031ac49ae9c192989385986b6707fa781e3229e0" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "9a6c3b00039c0c985c1c9cb59490012d1c0b38ba" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "d5c30e408af1889b90012d2e09f6ec3cda333bcb" + } + }, + { + "commit": { + "author": { + "user": null, + "email": "mruberry@devfair044.h1.fair", + "name": "Mike Ruberry" + }, + "oid": "db355d55655bb252a699cd532441bb98e52b98d5" + } + } + ], + "pageInfo": { + "endCursor": "MjY", + "hasNextPage": false + }, + "totalCount": 26 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + }, + { + "name": "Meta Internal-Only Changes Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://opensource.facebook.com/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAW6ux14=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454954" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC2o=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454956" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC2w=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454965" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC3U=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454970" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC3o=" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454974" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC34=" + }, + { + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241454977" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFC4E=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622865/jobs/3270915028" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAW6e-c8=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241455322" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFDNo=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915027" + }, + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915071" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915141" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915194" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915229" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915283" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622869/jobs/3270915321" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAW6e-zM=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241455334" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFDOY=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927344" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927442" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927507" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927567" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927674" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927727" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927802" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927853" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927948" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270927996" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928061" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928116" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928198" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928256" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928291" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928317" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928338" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928367" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928410" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270928445" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991071" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991125" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991162" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991195" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991233" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991261" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991305" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270991349" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270996024" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270996068" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270996092" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270996505" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270998987" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3270999027" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271006886" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271006941" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271018097" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271018135" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271018162" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271021143" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271034041" + }, + { + "name": "linux-bionic-rocm5.0-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271034072" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271048218" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271049553" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271049587" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 1, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271049616" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271068293" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271068336" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271149276" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2217622878/jobs/3271149321" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAW6jVK8=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/db355d55655bb252a699cd532441bb98e52b98d5/checks?check_suite_id=6241455360" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAXQFDQA=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": null, + "pushedDate": "2022-04-25T02:30:31Z", + "oid": "db355d55655bb252a699cd532441bb98e52b98d5" + } + } + ] + }, + "changedFiles": 5, + "files": { + "nodes": [ + { + "path": "test/test_ops.py" + }, + { + "path": "torch/_prims/__init__.py" + }, + { + "path": "torch/_prims/utils.py" + }, + { + "path": "torch/_refs/__init__.py" + }, + { + "path": "torch/testing/_internal/common_methods_invocations.py" + } + ], + "pageInfo": { + "endCursor": "NQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "zou3519" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "peterbell10" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "lezcano" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "ngimel" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "ezyang" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "mruberry" + }, + "state": "COMMENTED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNC0wNlQxMjo1NjoyNC0wNzowMLkyMDIyLTA0LTA2VDA4OjQwOjM4LTA3OjAwzjenO6Y=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Ref implementations by themselves can handle any shapes (and broadcast ops by themselves don't bake in any shapes). The question is can we decide if a particular trace is applicable for a different input, but that depends on the tracing technology and what we are caching on, so out of scope for initial PR.", + "createdAt": "2022-04-21T19:00:28Z", + "author": { + "login": "ngimel" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1105643418 + }, + { + "bodyText": "@pytorchbot merge this please", + "createdAt": "2022-04-25T04:42:29Z", + "author": { + "login": "mruberry" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1108072887 + }, + { + "bodyText": "Merge failed due to 'mruberry'\nRaised by https://github.com/pytorch/pytorch/actions/runs/2218044244", + "createdAt": "2022-04-25T04:43:54Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1108073536 + }, + { + "bodyText": "@mruberry has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-04-25T04:51:11Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1108075965 + }, + { + "bodyText": "Hey @mruberry.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-04-25T09:57:56Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1108351107 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQebHmg==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "topic: not user facing" + } + }, + { + "node": { + "name": "module: primTorch" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=77700 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "kit1980" + }, + "title": "Move pull linux-docs job to Ubuntu 20.04", + "body": "", + "headRefName": "sdym/pull-xenial-focal-linux-docs", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "kit1980" + }, + "email": "sdym@fb.com", + "name": "Sergii Dymchenko" + }, + "oid": "81261599614423baa17df72300b8e109677b6799" + } + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + }, + "totalCount": 1 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAYNmNqE=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147714" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuMI=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147726" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuM4=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147733" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuNU=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147746" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuOI=" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147762" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuPI=" + }, + { + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567147780" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuQQ=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528127876" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128023" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128196" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128519" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128575" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128663" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867841/jobs/3528128857" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAYNdYVY=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567148336" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuzA=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867843/jobs/3528127882" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAYNdXEg=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567148344" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduuzg=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "docker-builds" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "docker-build (pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528127883" + }, + { + "name": "docker-build (pytorch-linux-bionic-cuda11.3-cudnn8-py3-clang9)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528127945" + }, + { + "name": "docker-build (pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128001" + }, + { + "name": "docker-build (pytorch-linux-bionic-py3.7-clang9)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128067" + }, + { + "name": "docker-build (pytorch-linux-bionic-rocm5.0-py3.7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128124" + }, + { + "name": "docker-build (pytorch-linux-bionic-rocm5.1-py3.7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128191" + }, + { + "name": "docker-build (pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128259" + }, + { + "name": "docker-build (pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128321" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3-clang5-android-ndk-r19c)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128365" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3-clang5-asan)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128446" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3-clang7-asan)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128507" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3-clang7-onnx)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128563" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3.7-gcc5.4)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128639" + }, + { + "name": "docker-build (pytorch-linux-xenial-py3.7-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128687" + }, + { + "name": "docker-build (pytorch-linux-focal-py3.7-gcc7)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867844/jobs/3528128741" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAYNdYLI=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567148352" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduu0A=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528150762" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528150903" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528151086" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528151258" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528151511" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528151776" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528151896" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152014" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152139" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152216" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152378" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152516" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152599" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152723" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152802" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152913" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528152969" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528153005" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528153062" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528153125" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528153207" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528242483" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528242528" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528245875" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528245914" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528245964" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528246008" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528248520" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528255086" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528255128" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528274064" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528274097" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528274133" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528274173" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528274209" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8 / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528277014" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528308958" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309747" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309810" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309837" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 4, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309864" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309895" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 2, 2, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528309925" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528310044" + }, + { + "name": "linux-bionic-rocm5.1-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528310101" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528384337" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528384379" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528384408" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528384441" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2348867849/jobs/3528384471" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAYNi1Nc=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/81261599614423baa17df72300b8e109677b6799/checks?check_suite_id=6567148369" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAYduu1E=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": null, + "pushedDate": "2022-05-19T00:02:11Z", + "oid": "81261599614423baa17df72300b8e109677b6799" + } + } + ] + }, + "changedFiles": 3, + "files": { + "nodes": [ + { + "path": ".circleci/docker/build.sh" + }, + { + "path": ".circleci/docker/common/install_katex.sh" + }, + { + "path": ".github/workflows/pull.yml" + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "suo" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "kit1980" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "janeyx99" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNS0xOFQxMjo0MTowNS0wNzowMLkyMDIyLTA1LTE4VDEyOjQxOjA0LTA3OjAwzjpD7es=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/77700\n\ud83d\udcc4 \u00a0Preview Python docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\u2753Need help or want to give feedback on the CI? Visit our office hours\n\n\u2705 No Failures (0 Pending)\nAs of commit 8126159 (more details on the Dr. CI page):\nExpand to see more\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-05-17T23:01:48Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1129400934 + }, + { + "bodyText": "@pytorchbot merge", + "createdAt": "2022-05-19T15:39:05Z", + "author": { + "login": "kit1980" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1131884232 + }, + { + "bodyText": "Merge failed due to Refusing to merge as mandatory check(s) linux-docs / build-docs (cpp), linux-docs / build-docs (python) are pending/not yet run for rule OSS CI\nRaised by https://github.com/pytorch/pytorch/actions/runs/2353067846", + "createdAt": "2022-05-19T15:40:59Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1131886153 + }, + { + "bodyText": "@pytorchbot merge -f", + "createdAt": "2022-05-19T16:41:29Z", + "author": { + "login": "kit1980" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1131945610 + }, + { + "bodyText": "Hey @kit1980.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-05-19T16:43:37Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1131947473 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQ1FKZg==", + "hasPreviousPage": false + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "Merged" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=68111 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "chunyuan-w" + }, + "title": "Add JIT graph fuser for oneDNN Graph API (Preview4)", + "body": "## Description\r\nPreview4 PR of this [RFC](https://github.com/pytorch/pytorch/issues/49444).\r\n\r\nOn the basis of https://github.com/pytorch/pytorch/pull/50256, the below improvements are included:\r\n\r\n- The [preview4 release branch](https://github.com/oneapi-src/oneDNN/releases/tag/graph-v0.4.1) of the oneDNN Graph API is used\r\n- The fuser now works with the profiling graph executor. We have inserted type check nodes to guard the profiled tensor properties.\r\n\r\n### User API:\r\nThe optimization pass is disabled by default. Users could enable it by:\r\n```\r\ntorch.jit.enable_onednn_fusion(True)\r\n```\r\n\r\n### Performance:\r\n[pytorch/benchmark](https://github.com/pytorch/benchmark) tool is used to compare the performance:\r\n- SkyLake 8180 (1 socket of 28 cores):\r\n\r\n ![image](https://user-images.githubusercontent.com/65992142/151162305-05e44425-a24e-4d5e-94e1-743b40b87a8c.png)\r\n\r\n- SkyLake 8180 (single thread):\r\n\r\n ![image](https://user-images.githubusercontent.com/65992142/151162528-69f90b79-d08d-46b8-8775-d80a6ccbce8a.png)\r\n \\* By mapping hardswish to oneDNN Graph, it\u2019s 8% faster than PyTorch JIT (NNC + OFI)\r\n \\** We expect performance gain after mapping transpose, contiguous & view to oneDNN graph ops\r\n\r\n\r\n### Directory structure of the integration code\r\nFuser-related code are placed under:\r\n```\r\ntorch/csrc/jit/codegen/onednn/\r\n```\r\n\r\nOptimization pass registration is done in:\r\n```\r\ntorch/csrc/jit/passes/onednn_graph_fuser.h\r\n```\r\n\r\nCMake for the integration code is:\r\n```\r\ncaffe2/CMakeLists.txt\r\n```\r\n\r\n## Limitations\r\n\r\n- In this PR, we have only supported the optimization on Linux platform. The support on Windows and MacOS will be enabled as the next step.\r\n- We have only optimized the inference use case.", + "headRefName": "chunyuan/llga_preview2", + "headRepository": { + "nameWithOwner": "chunyuan-w/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "0096fcc49f277fd8e006fcb42e0cb28a1422ec98" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "7bcc4de26a5472f1d252735dd425b46794b0844f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "3a2a588bfe6bbf9bf74d88d441cd22affda207da" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "ca7df12fbfaa3ddbabeca39b76300d17f4a33f2f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "81d44f35b8bc043c38837d0694e5bc072203b832" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "14fd5d1bfc2c58a71379f778871e3fca0a8e79b2" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "954dc23663125897f4b199eb2a8607dc5fca3274" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "9f77a0b476accc678b6f0569e4ff33fa6bbe97fc" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "fbf3b23bc1288697e1aec539a7c4ee3dc0bcb84c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "f8b8e78f786586c3cdf3966fd83ffa124d3eda70" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "6fffa2f7453ee7e0f8d8e2f73ea8a65230539589" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "849385404e6f3cd1cf7cef19f931ecf4fa28afdb" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "adbae7b77f8c0dbc59fccf15207d97ba86cfade2" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "6dcf2a4981aff24fa16fc7461ae4ec29690f956f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "54f3e05ad524cffd0911ee93be3c50f589b51f58" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "edbfc640ea79a0af85757d9e73796dcc90231519" + } + }, + { + "commit": { + "author": { + "user": { + "login": "chunyuan-w" + }, + "email": "chunyuan.wu@intel.com", + "name": "chunyuan" + }, + "oid": "67654db7cba562809d1b4a44cdda58af5cc9daaf" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "9c9d99b930b11af9ff03f52d45bf49c652df758d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "ffb25119cd9ce815cc4d9d14a2317fcbbfa9ea86" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "ab9eee84512ca1bdfbc81e25c6eb67b29d0f302a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "62a4642cf3330524990a69ac29e002c97812320a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "ca9b1223be4af2c8b4929303d498eafd71793128" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "6f4a23d24514a02954d2ec792830085f612223c9" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "b2a9a9c0926b02d0b2e87722ed61450f224a61d0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "e88b492be733f24b6aa395829c76add67d0901e7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "c44336d7a914952bfb78e012e08d9a6d6dde5937" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "5157930f7b3921d41a586260582b574c915f6ca1" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "04cb8353813f6bbd0d913a994923cc7e1e291406" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "62991eaad0e638bb0bced327e03f932f66f68732" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "7496bf1588050191595d833d23b8972b2f22655e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "d9d35f23cca0cd29c78a845731b24826152dcf1c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "f74ec134f18a65a7c72455bdf44f72e3ebb27105" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "eb32cc65a975361160948bfc3d6a577991ea262e" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "c7665f8d695b680c54db0bad2b7b7df46d886b50" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "e6321ad8f59ea01130568c202d186448bb9cb9d0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "a72cd0d02693f45e5354a70654581ad514581ec7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "b3cd3028b4ed31805e82f7eaf02217ab74ca59b9" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "49a592d9788d08e6cd0593882f867e129057c1cc" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "0575766b2144b13f6a38227c4e2b8d22ec8db80f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "b5c9b10ff87d622350e8ca64fae3a476eb70d5aa" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "66bc652a30ccc329adb929870a4ac726bb98b38c" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "72b9ca9c8e2dac98cbb7199b3dfac7c7305b80c5" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "a7892ed7373207d96406c8b5734a089643c5cdbd" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "d54cb084e1daad8a08c3f8de0ad3f7afb5b05ac1" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "aef71d692a8a159e0ca56be363e2cc1225ce7647" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "bf618e205ec31cff962dcc8ab478e0a699a9572d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "e4a331f1088448f7d7d86256ce71e0e71da006b0" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "0b743523d1430fec759d5fefbb687f17c89335a5" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "e80a351a62d98b810ec8985c4b25257af1d6c5bb" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "c189eca154b6691919d0e21489d1c322c7435c0b" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "e080a067c75d7b888a8a362682a2d5ba70e0c3a8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "028561fbf8f3ed90e074e6e0e3a4ca4dd7ffa2a8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "d550cf14037badd4caa2f52202e2f20bc4db8432" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "574159ebadd1dec24daaf883879ffeca8d9e71b7" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "9eb3ee98ea756067ed1c8f52f309f6d3e211a904" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "29929f48be03dcdd1bbfade572de7feafa825547" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "8a7358ca8da547b40ea1a99ddc57ebed19959684" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "6606637d2c5525b43e294a8b366a85052e1be0c6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "5ecfd1f28b87045deb8bc8ffe33b3d8b906f3264" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchit.jain" + }, + "oid": "be2d4345c65442c4cfbe8afdfb2ae0893945da42" + } + }, + { + "commit": { + "author": { + "user": { + "login": "sanchitintel" + }, + "email": "sanchit.jain@intel.com", + "name": "sanchitintel" + }, + "oid": "b5b89d3644a43e2dbda841cafb71b32edbe07c8a" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nikita.shulga@gmail.com", + "name": "Nikita Shulga" + }, + "oid": "73881411e2bfb3aaa2e89926a82390b4c587ad75" + } + } + ], + "pageInfo": { + "endCursor": "NjI", + "hasNextPage": false + }, + "totalCount": 62 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + }, + { + "name": "Meta Internal-Only Changes Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://opensource.facebook.com/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAU_NXnc=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/73881411e2bfb3aaa2e89926a82390b4c587ad75/checks?check_suite_id=5743625010" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVZYwzI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "clang-format", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903895825" + }, + { + "name": "py2-setup-validate-errormsg", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903895911" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903895963" + }, + { + "name": "shellcheck", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896134" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896253" + }, + { + "name": "clang-tidy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896371" + }, + { + "name": "cmakelint", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896525" + }, + { + "name": "flake8-py3", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896658" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896771" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896795" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896838" + }, + { + "name": "mypy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440028/jobs/2903896897" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAU_NZqw=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/73881411e2bfb3aaa2e89926a82390b4c587ad75/checks?check_suite_id=5743625458" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVZYxPI=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440031/jobs/2903895828" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAU_NYIw=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/73881411e2bfb3aaa2e89926a82390b4c587ad75/checks?check_suite_id=5743625463" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVZYxPc=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896014" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896165" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896394" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896572" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896666" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896778" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896837" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896896" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903896936" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897025" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897161" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897213" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897280" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897368" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897431" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897476" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897578" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897630" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897699" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2903897733" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904327787" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904327838" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904327956" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904327997" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904328035" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904328093" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904328131" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904328177" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904333962" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904334006" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904430419" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904430459" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (noarch, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904430508" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904430573" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 3, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904443663" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904443723" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904443787" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904454239" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904454303" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904554602" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904554698" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904588855" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904588886" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904588924" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904655702" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904656104" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904656150" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 1, linux.8xlarge.nvidia.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904656192" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904706520" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2018440039/jobs/2904706565" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAU_fN1g=", + "hasNextPage": false + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/73881411e2bfb3aaa2e89926a82390b4c587ad75/checks?check_suite_id=5743625483" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVZYxQs=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17048428?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17048429?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17048431?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17048430?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + } + ] + }, + "pushedDate": "2022-03-21T19:58:52Z", + "oid": "73881411e2bfb3aaa2e89926a82390b4c587ad75" + } + } + ] + }, + "changedFiles": 37, + "files": { + "nodes": [ + { + "path": "aten/src/ATen/core/interned_strings.h" + }, + { + "path": "caffe2/CMakeLists.txt" + }, + { + "path": "cmake/Dependencies.cmake" + }, + { + "path": "cmake/Modules/FindMKLDNN.cmake" + }, + { + "path": "cmake/public/mkldnn.cmake" + }, + { + "path": "docs/source/jit.rst" + }, + { + "path": "test/test_jit_llga_fuser.py" + }, + { + "path": "torch/_C/__init__.pyi.in" + }, + { + "path": "torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/README.md" + }, + { + "path": "torch/csrc/jit/codegen/onednn/defer_size_check.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/defer_size_check.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/graph_fuser.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/graph_fuser.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/graph_helper.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/graph_helper.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/graph_rewriter.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/guard_shape.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/guard_shape.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/interface.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/interface.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/kernel.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/kernel.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/layout_propagation.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/layout_propagation.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/operator.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/prepare_binary.cpp" + }, + { + "path": "torch/csrc/jit/codegen/onednn/prepare_binary.h" + }, + { + "path": "torch/csrc/jit/codegen/onednn/register_interface.cpp" + }, + { + "path": "torch/csrc/jit/ir/alias_analysis.cpp" + }, + { + "path": "torch/csrc/jit/ir/ir.cpp" + }, + { + "path": "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp" + }, + { + "path": "torch/csrc/jit/passes/onednn_graph_fuser.h" + }, + { + "path": "torch/csrc/jit/python/init.cpp" + }, + { + "path": "torch/csrc/jit/runtime/operator.cpp" + }, + { + "path": "torch/jit/__init__.py" + } + ], + "pageInfo": { + "endCursor": "Mzc", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "pinzhenx" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pinzhenx" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "pinzhenx" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "chunyuan-w" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "wukong1992" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "APPROVED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "eellison" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "malfet" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "malfet" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "malfet" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + }, + { + "author": { + "login": "sanchitintel" + }, + "state": "COMMENTED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMS0xMi0xMFQwOToyNDoxOS0wODowMLkyMDIxLTEyLTEwVDA5OjI0OjE5LTA4OjAwzjFryLE=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "Looks like this broke master https://hud.pytorch.org/pytorch/pytorch/commit/7dd08230117f4fa8bb82b3524e90fb00340198c7. I am reverting.", + "createdAt": "2022-03-21T22:51:38Z", + "author": { + "login": "suo" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1074498483 + }, + { + "bodyText": "@pytorchbot revert this", + "createdAt": "2022-03-21T22:51:44Z", + "author": { + "login": "suo" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1074498550 + }, + { + "bodyText": "Looks like this broke master https://hud.pytorch.org/pytorch/pytorch/commit/7dd08230117f4fa8bb82b3524e90fb00340198c7. I am reverting.\n\nOops! Will fix it ASAP.", + "createdAt": "2022-03-21T22:53:34Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1074499668 + }, + { + "bodyText": "This pull request has been reverted by e5bf879. To re-land this change, please open another pull request, assignthe same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).", + "createdAt": "2022-03-21T23:07:23Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1074508608 + }, + { + "bodyText": "This pull request has been reverted by e5bf879. To re-land this change, please open another pull request, assignthe same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).", + "createdAt": "2022-03-30T00:53:50Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1082508130 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQAuLsw==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: jit" + } + }, + { + "node": { + "name": "triaged" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Reverted" + } + }, + { + "node": { + "name": "intel priority" + } + } + ] + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOQAuLsw== name=pytorch number=68111 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "CI Flow Status\n\u269b\ufe0f CI Flow\nRuleset - Version: v1\nRuleset - File: https://github.com/chunyuan-w/pytorch/blob/7496bf1588050191595d833d23b8972b2f22655e/.github/generated-ciflow-ruleset.json\nPR ciflow labels: ciflow/default\n\n\n\nWorkflows\nLabels (bold enabled)\nStatus\n\n\n\n\nTriggered Workflows\n\n\n\n\nlinux-bionic-py3.7-clang9\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk\n\u2705 triggered\n\n\nlinux-docs\nciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-vulkan-bionic-py3.7-clang9\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan\n\u2705 triggered\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7-bazel-test\nciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3-clang5-mobile-build\nciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3-clang5-mobile-custom-build-static\nciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-clang7-asan\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-clang7-onnx\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc7\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc7-no-ops\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single\nciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit\nciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nwin-vs2019-cpu-py3\nciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win\n\u2705 triggered\n\n\nwin-vs2019-cuda11.3-py3\nciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win\n\u2705 triggered\n\n\nSkipped Workflows\n\n\n\n\ncaffe2-linux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\ndocker-builds\nciflow/all, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-coreml\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-custom-ops\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-full-jit\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-metal\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-x86-64\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-x86-64-coreml\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-x86-64-full-jit\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlibtorch-linux-xenial-cuda10.2-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlibtorch-linux-xenial-cuda11.3-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlinux-binary-conda\nciflow/binaries, ciflow/binaries/conda\n\ud83d\udeab skipped\n\n\nlinux-binary-libtorch-cxx11-abi\nciflow/binaries, ciflow/binaries/libtorch\n\ud83d\udeab skipped\n\n\nlinux-binary-libtorch-pre-cxx11\nciflow/binaries, ciflow/binaries/libtorch\n\ud83d\udeab skipped\n\n\nlinux-binary-manywheel\nciflow/binaries, ciflow/binaries/wheel\n\ud83d\udeab skipped\n\n\nlinux-bionic-cuda10.2-py3.9-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlinux-docs-push\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7-no-ops\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-10-15-py3-arm64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-10-15-py3-lite-interpreter-x86-64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-11-py3-x86-64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nparallelnative-linux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nperiodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-linux-bionic-cuda11.5-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck\n\ud83d\udeab skipped\n\n\nperiodic-linux-xenial-cuda11.1-py3.7-gcc7-debug\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-win-vs2019-cuda11.1-py3\nciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win\n\ud83d\udeab skipped\n\n\nperiodic-win-vs2019-cuda11.5-py3\nciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win\n\ud83d\udeab skipped\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-build\nciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\n\n\nYou can add a comment to the PR and tag @pytorchbot with the following commands:\n\n# ciflow rerun, \"ciflow/default\" will always be added automatically\n@pytorchbot ciflow rerun\n\n# ciflow rerun with additional labels \"-l \", which is equivalent to adding these labels manually and trigger the rerun\n@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow\n\nFor more information, please take a look at the CI Flow Wiki.", + "createdAt": "2021-11-10T08:42:49Z", + "author": { + "login": "pytorch-probot" + }, + "authorAssociation": "NONE", + "editor": { + "login": "pytorch-probot" + }, + "databaseId": 964902865 + }, + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/68111\nNeed help or want to give feedback on the CI? Visit our office hours\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit 7388141 (more details on the Dr. CI page):\n\n\n29/29 failures introduced in this PR\n\n\n\ud83d\udd75\ufe0f 29 new failures recognized by patterns\nThe following CI failures do not appear to be due to upstream breakages:\n pull / linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge) (1/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:31:38.6978776Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:31:38.3001628Z + python3 -m pip install boto3==1.19.12\n2022-03-21T21:31:38.5169168Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T21:31:38.5362923Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T21:31:38.5413452Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T21:31:38.5458747Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T21:31:38.5484014Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T21:31:38.5497924Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:31:38.5656491Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T21:31:38.5678893Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:31:38.6888479Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0f6488c20adb4dca4\n2022-03-21T21:31:38.6978776Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:31:38.6992648Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:31:38.7003010Z ##[error]Process completed with exit code 2.\n2022-03-21T21:31:38.7044027Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:31:38.7044261Z with:\n2022-03-21T21:31:38.7044413Z env:\n2022-03-21T21:31:38.7044565Z IN_CI: 1\n2022-03-21T21:31:38.7044709Z IS_GHA: 1\n2022-03-21T21:31:38.7044885Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:31:38.7045067Z ##[endgroup]\n2022-03-21T21:31:38.7060958Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge) (2/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:35:19.2635222Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:35:18.9028722Z + python3 -m pip install boto3==1.19.12\n2022-03-21T21:35:19.1132721Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T21:35:19.1310590Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T21:35:19.1360251Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T21:35:19.1386865Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T21:35:19.1429182Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T21:35:19.1441925Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T21:35:19.1468280Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:35:19.1617667Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:35:19.2545368Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-098be2985e0392130\n2022-03-21T21:35:19.2635222Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:35:19.2648463Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:35:19.2658727Z ##[error]Process completed with exit code 2.\n2022-03-21T21:35:19.2706355Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:35:19.2706591Z with:\n2022-03-21T21:35:19.2706748Z env:\n2022-03-21T21:35:19.2706908Z IN_CI: 1\n2022-03-21T21:35:19.2707061Z IS_GHA: 1\n2022-03-21T21:35:19.2707246Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:35:19.2707438Z ##[endgroup]\n2022-03-21T21:35:19.2724554Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge) (3/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T23:11:57.5531419Z C:\\actions-runner\\...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T23:11:52.7662022Z Downloading botocore-1.22.12-py3-none-any.whl (8.1 MB)\n2022-03-21T23:11:53.1213298Z ---------------------------------------- 8.1/8.1 MB 23.6 MB/s eta 0:00:00\n2022-03-21T23:11:53.1644665Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T23:11:53.2218699Z Collecting python-dateutil<3.0.0,>=2.1\n2022-03-21T23:11:53.2389674Z Downloading python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n2022-03-21T23:11:53.2787295Z -------------------------------------- 247.7/247.7 KB 7.4 MB/s eta 0:00:00\n2022-03-21T23:11:53.3761842Z Requirement already satisfied: six>=1.5 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T23:11:53.5457622Z Installing collected packages: python-dateutil, jmespath, botocore, s3transfer, boto3\n2022-03-21T23:11:57.4175080Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2\n2022-03-21T23:11:57.5296815Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0105d4db093574f40\n2022-03-21T23:11:57.5531419Z C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\\python3.exe: can't open file 'C:\\\\actions-runner\\\\_work\\\\pytorch\\\\pytorch\\\\.github\\\\scripts\\\\get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T23:11:57.5564814Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T23:11:57.5587712Z ##[error]Process completed with exit code 2.\n2022-03-21T23:11:57.5790311Z ##[group]Run pytorch/pytorch/.github/actions/teardown-win@master\n2022-03-21T23:11:57.5790832Z with:\n2022-03-21T23:11:57.5791104Z env:\n2022-03-21T23:11:57.5791358Z IN_CI: 1\n2022-03-21T23:11:57.5791620Z IS_GHA: 1\n2022-03-21T23:11:57.5791939Z GIT_DEFAULT_BRANCH: master\n2022-03-21T23:11:57.5792425Z pythonLocation: C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\n2022-03-21T23:11:57.5792884Z ##[endgroup]\n\n\n pull / linux-bionic-rocm4.5-py3.7 / test (default, 1, 2, linux.rocm.gpu) (4/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-22T02:17:12.6257577Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-22T02:17:11.9280556Z Using cached https://files.pythonhosted.org/packages/7b/9c/f51775ebe7df5a7aa4e7c79ed671bde94e154bd968aca8d65bb24aba0c8c/s3transfer-0.5.2-py3-none-any.whl\n2022-03-22T02:17:11.9335199Z Collecting urllib3<1.27,>=1.25.4 (from botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T02:17:11.9682045Z Using cached https://files.pythonhosted.org/packages/ec/03/062e6444ce4baf1eac17a6a0ebfe36bb1ad05e1df0e20b110de59c278498/urllib3-1.26.9-py2.py3-none-any.whl\n2022-03-22T02:17:11.9850357Z Collecting python-dateutil<3.0.0,>=2.1 (from botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T02:17:12.0403171Z Using cached https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl\n2022-03-22T02:17:12.0468875Z Collecting six>=1.5 (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T02:17:12.0590000Z Using cached https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl\n2022-03-22T02:17:12.0607093Z Installing collected packages: jmespath, urllib3, six, python-dateutil, botocore, s3transfer, boto3\n2022-03-22T02:17:12.5273459Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2 six-1.16.0 urllib3-1.26.9\n2022-03-22T02:17:12.6032812Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 worker-rocm-amd-114\n2022-03-22T02:17:12.6257577Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-22T02:17:12.6259543Z + GHA_WORKFLOW_JOB_ID=\n2022-03-22T02:17:12.6291924Z ##[error]Process completed with exit code 2.\n2022-03-22T02:17:12.6387977Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-22T02:17:12.6388298Z with:\n2022-03-22T02:17:12.6388521Z wait-ssh: false\n2022-03-22T02:17:12.6388727Z env:\n2022-03-22T02:17:12.6388932Z IN_CI: 1\n2022-03-22T02:17:12.6389143Z IS_GHA: 1\n2022-03-22T02:17:12.6389368Z GIT_DEFAULT_BRANCH: master\n2022-03-22T02:17:12.6389669Z DOCKER_HOST: unix:///run/user/1121/docker.sock\n\n\n pull / linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge) (5/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:19:24.4890693Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:19:24.0962005Z + python3 -m pip install boto3==1.19.12\n2022-03-21T22:19:24.3152253Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T22:19:24.3341183Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T22:19:24.3391374Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T22:19:24.3436392Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T22:19:24.3448982Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T22:19:24.3474092Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T22:19:24.3502003Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:19:24.3655072Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:19:24.4799309Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0bc9250521f338cae\n2022-03-21T22:19:24.4890693Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:19:24.4903625Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:19:24.4913841Z ##[error]Process completed with exit code 2.\n2022-03-21T22:19:24.4957338Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T22:19:24.4957575Z with:\n2022-03-21T22:19:24.4957735Z env:\n2022-03-21T22:19:24.4957900Z IN_CI: 1\n2022-03-21T22:19:24.4958055Z IS_GHA: 1\n2022-03-21T22:19:24.4958246Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:19:24.4958437Z ##[endgroup]\n2022-03-21T22:19:24.4989649Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-bionic-rocm4.5-py3.7 / test (default, 2, 2, linux.rocm.gpu) (6/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-22T01:05:07.6983899Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-22T01:05:06.8364546Z Using cached https://files.pythonhosted.org/packages/7b/9c/f51775ebe7df5a7aa4e7c79ed671bde94e154bd968aca8d65bb24aba0c8c/s3transfer-0.5.2-py3-none-any.whl\n2022-03-22T01:05:06.8431763Z Collecting urllib3<1.27,>=1.25.4 (from botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T01:05:06.8949391Z Using cached https://files.pythonhosted.org/packages/ec/03/062e6444ce4baf1eac17a6a0ebfe36bb1ad05e1df0e20b110de59c278498/urllib3-1.26.9-py2.py3-none-any.whl\n2022-03-22T01:05:06.9180079Z Collecting python-dateutil<3.0.0,>=2.1 (from botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T01:05:06.9803351Z Using cached https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl\n2022-03-22T01:05:06.9882133Z Collecting six>=1.5 (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12)\n2022-03-22T01:05:07.0067062Z Using cached https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl\n2022-03-22T01:05:07.0088676Z Installing collected packages: urllib3, jmespath, six, python-dateutil, botocore, s3transfer, boto3\n2022-03-22T01:05:07.5819667Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2 six-1.16.0 urllib3-1.26.9\n2022-03-22T01:05:07.6774717Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 worker-rocm-amd-60\n2022-03-22T01:05:07.6983899Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-22T01:05:07.6988652Z + GHA_WORKFLOW_JOB_ID=\n2022-03-22T01:05:07.7023073Z ##[error]Process completed with exit code 2.\n2022-03-22T01:05:07.7102087Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-22T01:05:07.7102389Z with:\n2022-03-22T01:05:07.7102603Z wait-ssh: false\n2022-03-22T01:05:07.7102820Z env:\n2022-03-22T01:05:07.7103015Z IN_CI: 1\n2022-03-22T01:05:07.7103224Z IS_GHA: 1\n2022-03-22T01:05:07.7103458Z GIT_DEFAULT_BRANCH: master\n2022-03-22T01:05:07.7103737Z DOCKER_HOST: unix:///run/user/1502/docker.sock\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge) (7/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T20:51:39.3637996Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T20:51:39.2041249Z Attempting uninstall: s3transfer\n2022-03-21T20:51:39.2043010Z Found existing installation: s3transfer 0.3.7\n2022-03-21T20:51:39.2083799Z Uninstalling s3transfer-0.3.7:\n2022-03-21T20:51:39.2089675Z Successfully uninstalled s3transfer-0.3.7\n2022-03-21T20:51:39.2480546Z Attempting uninstall: boto3\n2022-03-21T20:51:39.2482953Z Found existing installation: boto3 1.16.34\n2022-03-21T20:51:39.2584292Z Uninstalling boto3-1.16.34:\n2022-03-21T20:51:39.2599474Z Successfully uninstalled boto3-1.16.34\n2022-03-21T20:51:39.3130921Z Successfully installed boto3-1.19.12 botocore-1.22.12 s3transfer-0.5.2\n2022-03-21T20:51:39.3550598Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-03ef7efc3078e3da5\n2022-03-21T20:51:39.3637996Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T20:51:39.3650651Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T20:51:39.3660484Z ##[error]Process completed with exit code 2.\n2022-03-21T20:51:39.3696465Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T20:51:39.3696693Z with:\n2022-03-21T20:51:39.3696850Z env:\n2022-03-21T20:51:39.3697012Z IN_CI: 1\n2022-03-21T20:51:39.3697161Z IS_GHA: 1\n2022-03-21T20:51:39.3697342Z GIT_DEFAULT_BRANCH: master\n2022-03-21T20:51:39.3697528Z ##[endgroup]\n2022-03-21T20:51:39.3730420Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge) (8/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:03:36.3916860Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:03:36.0096309Z + python3 -m pip install boto3==1.19.12\n2022-03-21T21:03:36.2278560Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T21:03:36.2461618Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T21:03:36.2513260Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T21:03:36.2541524Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T21:03:36.2554899Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T21:03:36.2598277Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:03:36.2758299Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T21:03:36.2780690Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:03:36.3825021Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0a4a552890e6ef7d3\n2022-03-21T21:03:36.3916860Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:03:36.3930343Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:03:36.3941263Z ##[error]Process completed with exit code 2.\n2022-03-21T21:03:36.3979258Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:03:36.3979496Z with:\n2022-03-21T21:03:36.3979654Z env:\n2022-03-21T21:03:36.3979814Z IN_CI: 1\n2022-03-21T21:03:36.3979968Z IS_GHA: 1\n2022-03-21T21:03:36.3980157Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:03:36.3980360Z ##[endgroup]\n2022-03-21T21:03:36.3996257Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu) (9/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-22T00:41:15.5325784Z C:\\actions-runner\\...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-22T00:41:10.3015614Z Downloading s3transfer-0.5.2-py3-none-any.whl (79 kB)\n2022-03-22T00:41:10.3625659Z ---------------------------------------- 79.5/79.5 KB 1.1 MB/s eta 0:00:00\n2022-03-22T00:41:10.4120236Z Collecting python-dateutil<3.0.0,>=2.1\n2022-03-22T00:41:10.4170155Z Downloading python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n2022-03-22T00:41:10.4722115Z -------------------------------------- 247.7/247.7 KB 5.2 MB/s eta 0:00:00\n2022-03-22T00:41:10.4843512Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-22T00:41:10.6596108Z Requirement already satisfied: six>=1.5 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-22T00:41:10.8733354Z Installing collected packages: python-dateutil, jmespath, botocore, s3transfer, boto3\n2022-03-22T00:41:15.3745408Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2\n2022-03-22T00:41:15.4987162Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-09cacc848abc3dd32\n2022-03-22T00:41:15.5325784Z C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\\python3.exe: can't open file 'C:\\\\actions-runner\\\\_work\\\\pytorch\\\\pytorch\\\\.github\\\\scripts\\\\get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-22T00:41:15.5373630Z + GHA_WORKFLOW_JOB_ID=\n2022-03-22T00:41:15.5404353Z ##[error]Process completed with exit code 2.\n2022-03-22T00:41:15.5790508Z ##[group]Run pytorch/pytorch/.github/actions/teardown-win@master\n2022-03-22T00:41:15.5791192Z with:\n2022-03-22T00:41:15.5791530Z env:\n2022-03-22T00:41:15.5791849Z IN_CI: 1\n2022-03-22T00:41:15.5792186Z IS_GHA: 1\n2022-03-22T00:41:15.5792599Z GIT_DEFAULT_BRANCH: master\n2022-03-22T00:41:15.5793237Z pythonLocation: C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\n2022-03-22T00:41:15.5793831Z ##[endgroup]\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge) (10/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T20:50:32.9799307Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T20:50:32.8167560Z Attempting uninstall: s3transfer\n2022-03-21T20:50:32.8169351Z Found existing installation: s3transfer 0.3.7\n2022-03-21T20:50:32.8213295Z Uninstalling s3transfer-0.3.7:\n2022-03-21T20:50:32.8219209Z Successfully uninstalled s3transfer-0.3.7\n2022-03-21T20:50:32.8602320Z Attempting uninstall: boto3\n2022-03-21T20:50:32.8603289Z Found existing installation: boto3 1.16.34\n2022-03-21T20:50:32.8704535Z Uninstalling boto3-1.16.34:\n2022-03-21T20:50:32.8719403Z Successfully uninstalled boto3-1.16.34\n2022-03-21T20:50:32.9244278Z Successfully installed boto3-1.19.12 botocore-1.22.12 s3transfer-0.5.2\n2022-03-21T20:50:32.9710449Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0c568461a276d4a71\n2022-03-21T20:50:32.9799307Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T20:50:32.9812238Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T20:50:32.9823052Z ##[error]Process completed with exit code 2.\n2022-03-21T20:50:32.9859290Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T20:50:32.9859527Z with:\n2022-03-21T20:50:32.9859664Z env:\n2022-03-21T20:50:32.9859817Z IN_CI: 1\n2022-03-21T20:50:32.9859977Z IS_GHA: 1\n2022-03-21T20:50:32.9860144Z GIT_DEFAULT_BRANCH: master\n2022-03-21T20:50:32.9860327Z ##[endgroup]\n2022-03-21T20:50:32.9893642Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-clang7-asan / test (default, 1, 3, linux.2xlarge) (11/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:05:00.7163042Z SUMMARY: Undefined.../jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in\n\n2022-03-21T21:05:00.6660824Z #10 0x55fc8a3ea801 in run_mod /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:1037\n2022-03-21T21:05:00.6661768Z #11 0x55fc8a3f57a9 in PyRun_StringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:961\n2022-03-21T21:05:00.6662455Z #12 0x55fc8a3f580b in PyRun_SimpleStringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:455\n2022-03-21T21:05:00.6663570Z #13 0x55fc8a3f5908 in pymain_run_command /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:420\n2022-03-21T21:05:00.6663952Z #14 0x55fc8a3f5908 in pymain_run_python /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:2907\n2022-03-21T21:05:00.6664431Z #15 0x55fc8a3f5908 in pymain_main /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3460\n2022-03-21T21:05:00.6665304Z #16 0x55fc8a3f5ccb in _Py_UnixMain /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3495\n2022-03-21T21:05:00.7162113Z #17 0x7f940d00f83f in __libc_start_main /build/glibc-S7Ft5T/glibc-2.23/csu/../csu/libc-start.c:291\n2022-03-21T21:05:00.7162534Z #18 0x55fc8a39a554 in _start (/opt/conda/bin/python3.7+0x1d7554)\n2022-03-21T21:05:00.7162711Z \n2022-03-21T21:05:00.7163042Z SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in \n2022-03-21T21:05:00.7334595Z + retcode=1\n2022-03-21T21:05:00.7334954Z + set -e\n2022-03-21T21:05:00.7335215Z + return 1\n2022-03-21T21:05:00.7338688Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX-* ]]\n2022-03-21T21:05:00.7339232Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X ]]\n2022-03-21T21:05:00.7340113Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX2-* ]]\n2022-03-21T21:05:00.7340612Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\2 ]]\n2022-03-21T21:05:00.7341187Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX512-* ]]\n2022-03-21T21:05:00.7341668Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\5\\1\\2 ]]\n2022-03-21T21:05:00.7344466Z + [[ linux-xenial-py3.7-clang7-asan-default == *tbb* ]]\n\n\n pull / linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge) (12/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:06:03.4437430Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:06:03.0752199Z + python3 -m pip install boto3==1.19.12\n2022-03-21T22:06:03.2853252Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T22:06:03.3032326Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T22:06:03.3081589Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T22:06:03.3093911Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T22:06:03.3120244Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T22:06:03.3162406Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T22:06:03.3188431Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:06:03.3337181Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:06:03.4348072Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0ee48c8811fafc444\n2022-03-21T22:06:03.4437430Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:06:03.4450920Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:06:03.4461263Z ##[error]Process completed with exit code 2.\n2022-03-21T22:06:03.4502346Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T22:06:03.4502576Z with:\n2022-03-21T22:06:03.4502730Z env:\n2022-03-21T22:06:03.4502888Z IN_CI: 1\n2022-03-21T22:06:03.4503038Z IS_GHA: 1\n2022-03-21T22:06:03.4503302Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:06:03.4503492Z ##[endgroup]\n2022-03-21T22:06:03.4519156Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (13/29)\nStep: \"Test\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T20:50:13.2205634Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T20:50:12.8679322Z + python3 -m pip install boto3==1.19.12\n2022-03-21T20:50:13.0744228Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T20:50:13.0916284Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T20:50:13.0964264Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T20:50:13.1005656Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T20:50:13.1017299Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T20:50:13.1041042Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T20:50:13.1189450Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T20:50:13.1208751Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T20:50:13.2119445Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0d02da60fd18c22f5\n2022-03-21T20:50:13.2205634Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T20:50:13.2217939Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T20:50:13.2220259Z ##[error]Process completed with exit code 2.\n2022-03-21T20:50:13.2248664Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T20:50:13.2249012Z with:\n2022-03-21T20:50:13.2249260Z env:\n2022-03-21T20:50:13.2249500Z IN_CI: 1\n2022-03-21T20:50:13.2249738Z IS_GHA: 1\n2022-03-21T20:50:13.2250025Z GIT_DEFAULT_BRANCH: master\n2022-03-21T20:50:13.2250329Z ##[endgroup]\n2022-03-21T20:50:13.2272735Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 1, linux.8xlarge.nvidia.gpu) (14/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T23:47:38.0451999Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T23:47:37.5554508Z + python3 -m pip install boto3==1.19.12\n2022-03-21T23:47:37.8411473Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T23:47:37.8631484Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T23:47:37.8699561Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T23:47:37.8737037Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T23:47:37.8754443Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T23:47:37.8814393Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T23:47:37.8849540Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T23:47:37.9059579Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T23:47:38.0336298Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0b44f47f4292089a2\n2022-03-21T23:47:38.0451999Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T23:47:38.0469471Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T23:47:38.0484106Z ##[error]Process completed with exit code 2.\n2022-03-21T23:47:38.0532678Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T23:47:38.0533007Z with:\n2022-03-21T23:47:38.0533223Z env:\n2022-03-21T23:47:38.0533440Z IN_CI: 1\n2022-03-21T23:47:38.0533649Z IS_GHA: 1\n2022-03-21T23:47:38.0533902Z GIT_DEFAULT_BRANCH: master\n2022-03-21T23:47:38.0534170Z GPU_FLAG: --gpus all\n2022-03-21T23:47:38.0534401Z ##[endgroup]\n\n\n pull / linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge) (15/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:04:59.3115800Z SUMMARY: Undefined.../jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in\n\n2022-03-21T21:04:59.2595213Z #10 0x55a7f39a4801 in run_mod /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:1037\n2022-03-21T21:04:59.2595707Z #11 0x55a7f39af7a9 in PyRun_StringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:961\n2022-03-21T21:04:59.2597203Z #12 0x55a7f39af80b in PyRun_SimpleStringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:455\n2022-03-21T21:04:59.2598205Z #13 0x55a7f39af908 in pymain_run_command /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:420\n2022-03-21T21:04:59.2598697Z #14 0x55a7f39af908 in pymain_run_python /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:2907\n2022-03-21T21:04:59.2599178Z #15 0x55a7f39af908 in pymain_main /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3460\n2022-03-21T21:04:59.2599747Z #16 0x55a7f39afccb in _Py_UnixMain /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3495\n2022-03-21T21:04:59.3114751Z #17 0x7f3b3822383f in __libc_start_main /build/glibc-S7Ft5T/glibc-2.23/csu/../csu/libc-start.c:291\n2022-03-21T21:04:59.3115277Z #18 0x55a7f3954554 in _start (/opt/conda/bin/python3.7+0x1d7554)\n2022-03-21T21:04:59.3115468Z \n2022-03-21T21:04:59.3115800Z SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in \n2022-03-21T21:04:59.3292385Z + retcode=1\n2022-03-21T21:04:59.3292781Z + set -e\n2022-03-21T21:04:59.3293062Z + return 1\n2022-03-21T21:04:59.3295462Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX-* ]]\n2022-03-21T21:04:59.3295802Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X ]]\n2022-03-21T21:04:59.3296394Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX2-* ]]\n2022-03-21T21:04:59.3296700Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\2 ]]\n2022-03-21T21:04:59.3297055Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX512-* ]]\n2022-03-21T21:04:59.3297416Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\5\\1\\2 ]]\n2022-03-21T21:04:59.3299623Z + [[ linux-xenial-py3.7-clang7-asan-default == *tbb* ]]\n\n\n pull / win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (16/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:14:31.7846086Z C:\\actions-runner\\...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:14:25.5525714Z Collecting jmespath<1.0.0,>=0.7.1\n2022-03-21T22:14:25.5568155Z Downloading jmespath-0.10.0-py2.py3-none-any.whl (24 kB)\n2022-03-21T22:14:25.5952617Z Collecting python-dateutil<3.0.0,>=2.1\n2022-03-21T22:14:25.6169392Z Downloading python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n2022-03-21T22:14:25.6629996Z -------------------------------------- 247.7/247.7 KB 5.1 MB/s eta 0:00:00\n2022-03-21T22:14:25.6710247Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:14:25.8284354Z Requirement already satisfied: six>=1.5 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:14:25.9816751Z Installing collected packages: python-dateutil, jmespath, botocore, s3transfer, boto3\n2022-03-21T22:14:31.6672236Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2\n2022-03-21T22:14:31.7630473Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0ed0915ecee5d2424\n2022-03-21T22:14:31.7846086Z C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\\python3.exe: can't open file 'C:\\\\actions-runner\\\\_work\\\\pytorch\\\\pytorch\\\\.github\\\\scripts\\\\get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:14:31.7876742Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:14:31.7897140Z ##[error]Process completed with exit code 2.\n2022-03-21T22:14:31.8195621Z ##[group]Run pytorch/pytorch/.github/actions/teardown-win@master\n2022-03-21T22:14:31.8196110Z with:\n2022-03-21T22:14:31.8196356Z env:\n2022-03-21T22:14:31.8196614Z IN_CI: 1\n2022-03-21T22:14:31.8196876Z IS_GHA: 1\n2022-03-21T22:14:31.8197169Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:14:31.8197652Z pythonLocation: C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\n2022-03-21T22:14:31.8198093Z ##[endgroup]\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge) (17/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:19:15.8845728Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:19:15.5116060Z + python3 -m pip install boto3==1.19.12\n2022-03-21T21:19:15.7231476Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T21:19:15.7409711Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T21:19:15.7458478Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T21:19:15.7470508Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T21:19:15.7496799Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T21:19:15.7538362Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T21:19:15.7566161Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:19:15.7711630Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:19:15.8753543Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0e2b3b4ddb246ff2a\n2022-03-21T21:19:15.8845728Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:19:15.8859814Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:19:15.8870165Z ##[error]Process completed with exit code 2.\n2022-03-21T21:19:15.8917039Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:19:15.8917279Z with:\n2022-03-21T21:19:15.8917433Z env:\n2022-03-21T21:19:15.8917586Z IN_CI: 1\n2022-03-21T21:19:15.8917734Z IS_GHA: 1\n2022-03-21T21:19:15.8917917Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:19:15.8918102Z ##[endgroup]\n2022-03-21T21:19:15.8934572Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu) (18/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T23:19:48.5900162Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T23:19:48.0742254Z + python3 -m pip install boto3==1.19.12\n2022-03-21T23:19:48.3742563Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T23:19:48.3976536Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T23:19:48.4048700Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T23:19:48.4065374Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T23:19:48.4128076Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T23:19:48.4164273Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T23:19:48.4202610Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T23:19:48.4416723Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T23:19:48.5773033Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-07ab7a3c4a5402af2\n2022-03-21T23:19:48.5900162Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T23:19:48.5919822Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T23:19:48.5936087Z ##[error]Process completed with exit code 2.\n2022-03-21T23:19:48.6007930Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T23:19:48.6008268Z with:\n2022-03-21T23:19:48.6008483Z env:\n2022-03-21T23:19:48.6008701Z IN_CI: 1\n2022-03-21T23:19:48.6008920Z IS_GHA: 1\n2022-03-21T23:19:48.6009170Z GIT_DEFAULT_BRANCH: master\n2022-03-21T23:19:48.6009440Z GPU_FLAG: --gpus all\n2022-03-21T23:19:48.6009671Z ##[endgroup]\n\n\n pull / win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu) (19/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:54:04.2844259Z C:\\actions-runner\\...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:53:59.0889659Z Downloading botocore-1.22.12-py3-none-any.whl (8.1 MB)\n2022-03-21T22:53:59.6881416Z ---------------------------------------- 8.1/8.1 MB 14.0 MB/s eta 0:00:00\n2022-03-21T22:53:59.7427779Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:53:59.7691882Z Collecting python-dateutil<3.0.0,>=2.1\n2022-03-21T22:53:59.7779847Z Downloading python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n2022-03-21T22:53:59.8281663Z -------------------------------------- 247.7/247.7 KB 5.1 MB/s eta 0:00:00\n2022-03-21T22:54:00.0185115Z Requirement already satisfied: six>=1.5 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:54:00.2359770Z Installing collected packages: python-dateutil, jmespath, botocore, s3transfer, boto3\n2022-03-21T22:54:04.1208891Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2\n2022-03-21T22:54:04.2505862Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-03b4fbe63be8ef4b0\n2022-03-21T22:54:04.2844259Z C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\\python3.exe: can't open file 'C:\\\\actions-runner\\\\_work\\\\pytorch\\\\pytorch\\\\.github\\\\scripts\\\\get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:54:04.2891082Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:54:04.2919900Z ##[error]Process completed with exit code 2.\n2022-03-21T22:54:04.3377901Z ##[group]Run pytorch/pytorch/.github/actions/teardown-win@master\n2022-03-21T22:54:04.3378575Z with:\n2022-03-21T22:54:04.3378930Z env:\n2022-03-21T22:54:04.3379275Z IN_CI: 1\n2022-03-21T22:54:04.3379600Z IS_GHA: 1\n2022-03-21T22:54:04.3380023Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:54:04.3380691Z pythonLocation: C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\n2022-03-21T22:54:04.3381278Z ##[endgroup]\n\n\n pull / linux-bionic-py3.7-clang9 / test (noarch, 1, 1, linux.2xlarge) (20/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:09:34.0074610Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:09:33.6365531Z + python3 -m pip install boto3==1.19.12\n2022-03-21T22:09:33.8475619Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T22:09:33.8655152Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T22:09:33.8704395Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T22:09:33.8716774Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T22:09:33.8760145Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T22:09:33.8785000Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T22:09:33.8811316Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:09:33.8960134Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:09:33.9984866Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0d325eb9fd156146f\n2022-03-21T22:09:34.0074610Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:09:34.0087465Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:09:34.0101743Z ##[error]Process completed with exit code 2.\n2022-03-21T22:09:34.0154014Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T22:09:34.0154246Z with:\n2022-03-21T22:09:34.0154412Z env:\n2022-03-21T22:09:34.0154574Z IN_CI: 1\n2022-03-21T22:09:34.0154728Z IS_GHA: 1\n2022-03-21T22:09:34.0154917Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:09:34.0155112Z ##[endgroup]\n2022-03-21T22:09:34.0191047Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge) (21/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:03:17.8502655Z [E request_callbac...yUniqueId(created_on=0, local_id=0) to be created.\n\n2022-03-21T21:03:14.4669960Z INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpxgdsmeer\n2022-03-21T21:03:14.4671407Z INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpxgdsmeer/_remote_module_non_sriptable.py\n2022-03-21T21:03:14.4973023Z INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmp1i2hfmpc\n2022-03-21T21:03:14.4973800Z INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmp1i2hfmpc/_remote_module_non_sriptable.py\n2022-03-21T21:03:14.5532339Z INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpgx4da7b0\n2022-03-21T21:03:14.5533064Z INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpgx4da7b0/_remote_module_non_sriptable.py\n2022-03-21T21:03:14.7050673Z INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 0\n2022-03-21T21:03:14.7097127Z INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 3\n2022-03-21T21:03:14.7398339Z INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 2\n2022-03-21T21:03:14.7922283Z INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 1\n2022-03-21T21:03:17.8502655Z [E request_callback_no_python.cpp:559] Received error while processing request type 261: false INTERNAL ASSERT FAILED at \"/var/lib/jenkins/workspace/torch/csrc/distributed/rpc/rref_context.cpp\":387, please report a bug to PyTorch. Expected OwnerRRef with id GloballyUniqueId(created_on=0, local_id=0) to be created.\n2022-03-21T21:03:17.8503603Z Exception raised from getOwnerRRef at /var/lib/jenkins/workspace/torch/csrc/distributed/rpc/rref_context.cpp:387 (most recent call first):\n2022-03-21T21:03:17.8504385Z frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x69 (0x7f180df19e19 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)\n2022-03-21T21:03:17.8505131Z frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string, std::allocator > const&) + 0xd2 (0x7f180df160e2 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)\n2022-03-21T21:03:17.8505927Z frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string, std::allocator > const&) + 0x4e (0x7f180df17a7e in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)\n2022-03-21T21:03:17.8506674Z frame #3: torch::distributed::rpc::RRefContext::getOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, bool) + 0x4b4 (0x7f18118b7b64 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)\n2022-03-21T21:03:17.8507642Z frame #4: torch::distributed::rpc::RequestCallbackNoPython::assignOwnerRRef(torch::distributed::rpc::GloballyUniqueId const&, torch::distributed::rpc::GloballyUniqueId const&, c10::intrusive_ptr >) const + 0x70 (0x7f18118a7bf0 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)\n2022-03-21T21:03:17.8508613Z frame #5: torch::distributed::rpc::RequestCallbackImpl::processPythonRemoteCall(torch::distributed::rpc::RpcCommandBase&, std::vector >) const + 0xc8 (0x7f1819736208 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)\n2022-03-21T21:03:17.8509749Z frame #6: torch::distributed::rpc::RequestCallbackNoPython::processRpc(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::vector >) const + 0x194 (0x7f18118ac914 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)\n2022-03-21T21:03:17.8510708Z frame #7: torch::distributed::rpc::RequestCallbackImpl::processRpcWithErrors(torch::distributed::rpc::RpcCommandBase&, torch::distributed::rpc::MessageType const&, std::vector >) const + 0x65 (0x7f1819735865 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)\n2022-03-21T21:03:17.8511369Z frame #8: + 0x375249a (0x7f18118a949a in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)\n\n\n pull / linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test (22/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T20:01:07.7015580Z \ufffd[36;1m echo \"ERR...t available for the merge-base of your branch\"\ufffd[0m\n\n2022-03-21T20:01:07.7012399Z \ufffd[36;1mfi\ufffd[0m\n2022-03-21T20:01:07.7012634Z \ufffd[36;1m# Covers the case where a previous tag doesn't exist for the tree\ufffd[0m\n2022-03-21T20:01:07.7012992Z \ufffd[36;1m# this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly\ufffd[0m\n2022-03-21T20:01:07.7013373Z \ufffd[36;1mif ! git rev-parse \"$MERGE_BASE:.circleci/docker\"; then\ufffd[0m\n2022-03-21T20:01:07.7013784Z \ufffd[36;1m echo \"Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit\"\ufffd[0m\n2022-03-21T20:01:07.7014149Z \ufffd[36;1m exit 1\ufffd[0m\n2022-03-21T20:01:07.7014325Z \ufffd[36;1mfi\ufffd[0m\n2022-03-21T20:01:07.7014573Z \ufffd[36;1mPREVIOUS_DOCKER_TAG=$(git rev-parse \"$MERGE_BASE:.circleci/docker\")\ufffd[0m\n2022-03-21T20:01:07.7014907Z \ufffd[36;1m# If no image exists but the hash is the same as the previous hash then we should error out here\ufffd[0m\n2022-03-21T20:01:07.7015231Z \ufffd[36;1mif [[ \"${PREVIOUS_DOCKER_TAG}\" = \"${DOCKER_TAG}\" ]]; then\ufffd[0m\n2022-03-21T20:01:07.7015580Z \ufffd[36;1m echo \"ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch\"\ufffd[0m\n2022-03-21T20:01:07.7015931Z \ufffd[36;1m echo \" contact the PyTorch team to restore the original images\"\ufffd[0m\n2022-03-21T20:01:07.7016225Z \ufffd[36;1m exit 1\ufffd[0m\n2022-03-21T20:01:07.7016400Z \ufffd[36;1mfi\ufffd[0m\n2022-03-21T20:01:07.7016608Z \ufffd[36;1mecho ::set-output name=rebuild::yes\ufffd[0m\n2022-03-21T20:01:07.7027605Z shell: /usr/bin/bash --noprofile --norc -e -o pipefail {0}\n2022-03-21T20:01:07.7027837Z env:\n2022-03-21T20:01:07.7028006Z IN_CI: 1\n2022-03-21T20:01:07.7028159Z IS_GHA: 1\n2022-03-21T20:01:07.7028346Z GIT_DEFAULT_BRANCH: master\n2022-03-21T20:01:07.7028589Z BASE_REVISION: 6643522db9ff595f564b8081de58b3a33c546178\n\n\n pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu) (23/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-22T00:49:54.2949572Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-22T00:49:53.8049151Z + python3 -m pip install boto3==1.19.12\n2022-03-22T00:49:54.0981629Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-22T00:49:54.1207562Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-22T00:49:54.1277146Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-22T00:49:54.1315027Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-22T00:49:54.1331813Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-22T00:49:54.1391622Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-22T00:49:54.1609217Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-22T00:49:54.1637417Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-22T00:49:54.2830197Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0f7c32fe13be12fea\n2022-03-22T00:49:54.2949572Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-22T00:49:54.2966933Z + GHA_WORKFLOW_JOB_ID=\n2022-03-22T00:49:54.2982588Z ##[error]Process completed with exit code 2.\n2022-03-22T00:49:54.3031464Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-22T00:49:54.3031794Z with:\n2022-03-22T00:49:54.3032012Z env:\n2022-03-22T00:49:54.3032227Z IN_CI: 1\n2022-03-22T00:49:54.3032434Z IS_GHA: 1\n2022-03-22T00:49:54.3032681Z GIT_DEFAULT_BRANCH: master\n2022-03-22T00:49:54.3033084Z GPU_FLAG: --gpus all\n2022-03-22T00:49:54.3033312Z ##[endgroup]\n\n\n pull / win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (24/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:56:12.5872636Z C:\\actions-runner\\...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:56:07.3365589Z Downloading botocore-1.22.12-py3-none-any.whl (8.1 MB)\n2022-03-21T21:56:07.7926584Z ---------------------------------------- 8.1/8.1 MB 17.3 MB/s eta 0:00:00\n2022-03-21T21:56:07.9319362Z Collecting python-dateutil<3.0.0,>=2.1\n2022-03-21T21:56:07.9366132Z Downloading python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)\n2022-03-21T21:56:08.0077590Z -------------------------------------- 247.7/247.7 KB 3.0 MB/s eta 0:00:00\n2022-03-21T21:56:08.0164070Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:56:08.1775537Z Requirement already satisfied: six>=1.5 in c:\\actions-runner\\_work\\_tool\\python\\3.10.3\\x64\\lib\\site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:56:08.3393469Z Installing collected packages: python-dateutil, jmespath, botocore, s3transfer, boto3\n2022-03-21T21:56:12.4576766Z Successfully installed boto3-1.19.12 botocore-1.22.12 jmespath-0.10.0 python-dateutil-2.8.2 s3transfer-0.5.2\n2022-03-21T21:56:12.5641959Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0afad69838118af0e\n2022-03-21T21:56:12.5872636Z C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\\python3.exe: can't open file 'C:\\\\actions-runner\\\\_work\\\\pytorch\\\\pytorch\\\\.github\\\\scripts\\\\get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:56:12.5905611Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:56:12.5927729Z ##[error]Process completed with exit code 2.\n2022-03-21T21:56:12.6239531Z ##[group]Run pytorch/pytorch/.github/actions/teardown-win@master\n2022-03-21T21:56:12.6240039Z with:\n2022-03-21T21:56:12.6240299Z env:\n2022-03-21T21:56:12.6240557Z IN_CI: 1\n2022-03-21T21:56:12.6240805Z IS_GHA: 1\n2022-03-21T21:56:12.6241118Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:56:12.6241613Z pythonLocation: C:\\actions-runner\\_work\\_tool\\Python\\3.10.3\\x64\n2022-03-21T21:56:12.6242052Z ##[endgroup]\n\n\n pull / linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge) (25/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:46:39.5474616Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:46:39.1884210Z + python3 -m pip install boto3==1.19.12\n2022-03-21T21:46:39.3928976Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T21:46:39.4105069Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T21:46:39.4152571Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T21:46:39.4194931Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T21:46:39.4218947Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T21:46:39.4230812Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T21:46:39.4380089Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T21:46:39.4399461Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T21:46:39.5387703Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0888bed1149cca415\n2022-03-21T21:46:39.5474616Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:46:39.5487145Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:46:39.5497480Z ##[error]Process completed with exit code 2.\n2022-03-21T21:46:39.5541319Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:46:39.5541544Z with:\n2022-03-21T21:46:39.5541698Z env:\n2022-03-21T21:46:39.5541851Z IN_CI: 1\n2022-03-21T21:46:39.5541997Z IS_GHA: 1\n2022-03-21T21:46:39.5542176Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:46:39.5542361Z ##[endgroup]\n2022-03-21T21:46:39.5557878Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge) (26/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:34:57.0623859Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:34:56.9039884Z Attempting uninstall: s3transfer\n2022-03-21T21:34:56.9041446Z Found existing installation: s3transfer 0.3.7\n2022-03-21T21:34:56.9090783Z Uninstalling s3transfer-0.3.7:\n2022-03-21T21:34:56.9095968Z Successfully uninstalled s3transfer-0.3.7\n2022-03-21T21:34:56.9453014Z Attempting uninstall: boto3\n2022-03-21T21:34:56.9454356Z Found existing installation: boto3 1.16.34\n2022-03-21T21:34:56.9564320Z Uninstalling boto3-1.16.34:\n2022-03-21T21:34:56.9578035Z Successfully uninstalled boto3-1.16.34\n2022-03-21T21:34:57.0091363Z Successfully installed boto3-1.19.12 botocore-1.22.12 s3transfer-0.5.2\n2022-03-21T21:34:57.0536230Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-034a3afd5d80b91fd\n2022-03-21T21:34:57.0623859Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:34:57.0637167Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:34:57.0647396Z ##[error]Process completed with exit code 2.\n2022-03-21T21:34:57.0688237Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:34:57.0688481Z with:\n2022-03-21T21:34:57.0688631Z env:\n2022-03-21T21:34:57.0688769Z IN_CI: 1\n2022-03-21T21:34:57.0688930Z IS_GHA: 1\n2022-03-21T21:34:57.0689109Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:34:57.0689462Z ##[endgroup]\n2022-03-21T21:34:57.0704768Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n pull / linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge) (27/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:05:00.7896545Z SUMMARY: Undefined.../jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in\n\n2022-03-21T21:05:00.7395504Z #10 0x5597fd5a9801 in run_mod /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:1037\n2022-03-21T21:05:00.7396330Z #11 0x5597fd5b47a9 in PyRun_StringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:961\n2022-03-21T21:05:00.7396688Z #12 0x5597fd5b480b in PyRun_SimpleStringFlags /tmp/build/80754af9/python_1627392990942/work/Python/pythonrun.c:455\n2022-03-21T21:05:00.7398664Z #13 0x5597fd5b4908 in pymain_run_command /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:420\n2022-03-21T21:05:00.7399177Z #14 0x5597fd5b4908 in pymain_run_python /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:2907\n2022-03-21T21:05:00.7399663Z #15 0x5597fd5b4908 in pymain_main /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3460\n2022-03-21T21:05:00.7399986Z #16 0x5597fd5b4ccb in _Py_UnixMain /tmp/build/80754af9/python_1627392990942/work/Modules/main.c:3495\n2022-03-21T21:05:00.7895241Z #17 0x7f0a5905983f in __libc_start_main /build/glibc-S7Ft5T/glibc-2.23/csu/../csu/libc-start.c:291\n2022-03-21T21:05:00.7895772Z #18 0x5597fd559554 in _start (/opt/conda/bin/python3.7+0x1d7554)\n2022-03-21T21:05:00.7896033Z \n2022-03-21T21:05:00.7896545Z SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/Utils.cpp:20:3 in \n2022-03-21T21:05:00.8063448Z + retcode=1\n2022-03-21T21:05:00.8063787Z + set -e\n2022-03-21T21:05:00.8064058Z + return 1\n2022-03-21T21:05:00.8067638Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX-* ]]\n2022-03-21T21:05:00.8068127Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X ]]\n2022-03-21T21:05:00.8069018Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX2-* ]]\n2022-03-21T21:05:00.8069500Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\2 ]]\n2022-03-21T21:05:00.8070105Z + [[ linux-xenial-py3.7-clang7-asan-default == *-NO_AVX512-* ]]\n2022-03-21T21:05:00.8070580Z + [[ default == \\n\\o\\g\\p\\u\\_\\N\\O\\_\\A\\V\\X\\5\\1\\2 ]]\n2022-03-21T21:05:00.8072640Z + [[ linux-xenial-py3.7-clang7-asan-default == *tbb* ]]\n\n\n pull / linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu) (28/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T22:48:17.3384813Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T22:48:16.8599645Z + python3 -m pip install boto3==1.19.12\n2022-03-21T22:48:17.1464241Z Defaulting to user installation because normal site-packages is not writeable\n2022-03-21T22:48:17.1685222Z Requirement already satisfied: boto3==1.19.12 in /home/ec2-user/.local/lib/python3.7/site-packages (1.19.12)\n2022-03-21T22:48:17.1754164Z Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.10.0)\n2022-03-21T22:48:17.1771662Z Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (0.5.2)\n2022-03-21T22:48:17.1808722Z Requirement already satisfied: botocore<1.23.0,>=1.22.12 in /home/ec2-user/.local/lib/python3.7/site-packages (from boto3==1.19.12) (1.22.12)\n2022-03-21T22:48:17.1868636Z Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (2.8.2)\n2022-03-21T22:48:17.1903889Z Requirement already satisfied: urllib3<1.27,>=1.25.4 in /home/ec2-user/.local/lib/python3.7/site-packages (from botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.26.9)\n2022-03-21T22:48:17.2113746Z Requirement already satisfied: six>=1.5 in /home/ec2-user/.local/lib/python3.7/site-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.23.0,>=1.22.12->boto3==1.19.12) (1.16.0)\n2022-03-21T22:48:17.3267404Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-01fe178c405417375\n2022-03-21T22:48:17.3384813Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T22:48:17.3402286Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T22:48:17.3418376Z ##[error]Process completed with exit code 2.\n2022-03-21T22:48:17.3470528Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T22:48:17.3470874Z with:\n2022-03-21T22:48:17.3471096Z env:\n2022-03-21T22:48:17.3471327Z IN_CI: 1\n2022-03-21T22:48:17.3471538Z IS_GHA: 1\n2022-03-21T22:48:17.3471802Z GIT_DEFAULT_BRANCH: master\n2022-03-21T22:48:17.3472083Z GPU_FLAG: --gpus all\n2022-03-21T22:48:17.3472322Z ##[endgroup]\n\n\n pull / linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge) (29/29)\nStep: \"Upload test statistics\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-03-21T21:16:38.9646300Z python3: can't ope...ow_job_id.py': [Errno 2] No such file or directory\n\n2022-03-21T21:16:38.7995969Z Attempting uninstall: s3transfer\n2022-03-21T21:16:38.7998039Z Found existing installation: s3transfer 0.3.7\n2022-03-21T21:16:38.8066994Z Uninstalling s3transfer-0.3.7:\n2022-03-21T21:16:38.8072844Z Successfully uninstalled s3transfer-0.3.7\n2022-03-21T21:16:38.8449275Z Attempting uninstall: boto3\n2022-03-21T21:16:38.8451430Z Found existing installation: boto3 1.16.34\n2022-03-21T21:16:38.8559828Z Uninstalling boto3-1.16.34:\n2022-03-21T21:16:38.8574290Z Successfully uninstalled boto3-1.16.34\n2022-03-21T21:16:38.9100438Z Successfully installed boto3-1.19.12 botocore-1.22.12 s3transfer-0.5.2\n2022-03-21T21:16:38.9558098Z ++ python3 .github/scripts/get_workflow_job_id.py 2018440039 i-0d779c59d277d32ee\n2022-03-21T21:16:38.9646300Z python3: can't open file '.github/scripts/get_workflow_job_id.py': [Errno 2] No such file or directory\n2022-03-21T21:16:38.9658894Z + GHA_WORKFLOW_JOB_ID=\n2022-03-21T21:16:38.9673240Z ##[error]Process completed with exit code 2.\n2022-03-21T21:16:38.9720106Z ##[group]Run pytorch/pytorch/.github/actions/teardown-linux@master\n2022-03-21T21:16:38.9720333Z with:\n2022-03-21T21:16:38.9720485Z env:\n2022-03-21T21:16:38.9720645Z IN_CI: 1\n2022-03-21T21:16:38.9720793Z IS_GHA: 1\n2022-03-21T21:16:38.9720970Z GIT_DEFAULT_BRANCH: master\n2022-03-21T21:16:38.9721151Z ##[endgroup]\n2022-03-21T21:16:38.9736762Z ##[group]Run # ignore expansion of \"docker ps -q\" since it could be empty\n\n\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2021-11-10T08:42:52Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 964902894 + }, + { + "bodyText": "@vitaly-fedyunin @gottbrath FYI that this is the oneDNN Graph API integration. It depends on the #63748.", + "createdAt": "2021-11-16T16:36:52Z", + "author": { + "login": "Jianhui-Li" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 970451860 + }, + { + "bodyText": "CI failures are currently being caused by some issues in the CI infra, and are also occurring with other PRs.", + "createdAt": "2021-12-10T05:59:17Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 990641309 + }, + { + "bodyText": "CI failures are unrelated.", + "createdAt": "2021-12-10T20:44:09Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 991281407 + }, + { + "bodyText": "The CI failure is unrelated.", + "createdAt": "2021-12-16T02:45:59Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 995389295 + }, + { + "bodyText": "Hi, thank you for the PR!\nDo you mind running a larger amount of torchbench and reporting numbers ? You can look at Jason's post here for what models are supported in script. Initially just the vision models would be useful. @Krovatkin also did some benchmarking of a traced Bert model and found on average a ~16% speedup with this PR.", + "createdAt": "2022-01-18T18:22:34Z", + "author": { + "login": "eellison" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1015689390 + }, + { + "bodyText": "Thanks a lot for reviewing, @eellison & @Krovatkin!\nWe just wanted to let you know that we're working on the benchmarking & will get back to you in a day, or two.\nUPDATE (Jan 21): While running some TorchBench models, we discovered some composability issues, and are working to ensure that oneDNN Graph would complement PyTorch's existing fusion capabilities, not hinder them.\nUPDATE (Jan 24): We've resolved the issues & will update this PR later today. Thanks!", + "createdAt": "2022-01-20T00:31:01Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "sanchitintel" + }, + "databaseId": 1016996190 + }, + { + "bodyText": "Hello @eellison,\nWe used this TorchBench branch for comparison. compare_llga.sh can be run for comparison.\nFor benchmarking mobilenet_v3_large with hardswish support in oneDNN Graph, this oneDNN Graph branch can be used in third_party/ideep/mkl-dnn. It delivers a speedup over PyTorch JIT (NNC + OFI) because 21 additional reorders are prevented (the major factor here), and fusion with conv also helps further.\nThe next release of oneDNN Graph would have hardswish support.\nWe're also exploring adding a hardsigmoid op in oneDNN Graph.\nThank you!", + "createdAt": "2022-01-26T23:51:38Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "sanchitintel" + }, + "databaseId": 1022709513 + }, + { + "bodyText": "Please note that this PR should be merged after #71546, as #71546 changes the third_party/ideep commit (this PR also uses that ideep commit, but it'd probably be better to merge #71546 first, so that oneDNN v2.5.2 upgrade would be in a separate PR). Thank you!", + "createdAt": "2022-01-31T23:57:21Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1026330085 + }, + { + "bodyText": "@sanchitintel mind rebasing and i'll land ?", + "createdAt": "2022-03-01T20:07:57Z", + "author": { + "login": "eellison" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1055813984 + }, + { + "bodyText": "@eellison has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-03-02T17:44:47Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1057203495 + }, + { + "bodyText": "Thanks a lot for taking a look, @eellison! To fix this error, we would enable Bazel build for oneDNN Graph.", + "createdAt": "2022-03-07T23:03:45Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "sanchitintel" + }, + "databaseId": 1061230087 + }, + { + "bodyText": "@eellison has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-03-09T19:24:13Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1063276600 + }, + { + "bodyText": "@malfet has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-03-21T19:59:41Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1074355779 + }, + { + "bodyText": "And graph_rewriter.cpp is full of DOS newlines...", + "createdAt": "2022-03-21T20:53:40Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1074407452 + }, + { + "bodyText": "Hey @chunyuan-w.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-03-21T22:12:51Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1074471758 + }, + { + "bodyText": "Thanks a ton for your help, @malfet & @eellison! :)\nWe'll incorporate your suggestions in subsequent PR(s).", + "createdAt": "2022-03-21T22:41:25Z", + "author": { + "login": "sanchitintel" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "sanchitintel" + }, + "databaseId": 1074492365 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOOYM_0Q==", + "hasPreviousPage": false + } + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=73969 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "malfet" + }, + "title": "Dummy change", + "body": "Test Plan: None at all\n\nDifferential Revision: D34753911\n\n", + "headRefName": "export-D34753911", + "headRepository": { + "nameWithOwner": "malfet/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "4746da707a9912356f5179625da89616b228dc21" + } + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + }, + "totalCount": 1 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-vulkan-bionic-py3.7-clang9" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280134/jobs/2794078044" + }, + { + "name": "test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280134/jobs/2794189060" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbRQMQ=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592963" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-QM=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280135/jobs/2794078023" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbO2aM=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592965" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-QU=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-bionic-rocm4.5-py3.7" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280132/jobs/2794078060" + }, + { + "name": "test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280132/jobs/2794292071" + }, + { + "name": "test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280132/jobs/2794292205" + }, + { + "name": "test (distributed, 1, 1, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280132/jobs/2794292306" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbTiXw=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592966" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-QY=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "win-vs2019-cuda11.3-py3" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280139/jobs/2794078053" + }, + { + "name": "test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280139/jobs/2794536907" + }, + { + "name": "test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280139/jobs/2794536998" + }, + { + "name": "test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280139/jobs/2794537089" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbY_vU=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592967" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Qc=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280136/jobs/2794078031" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbO2ao=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592969" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Qk=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-docs" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280138/jobs/2794078055" + }, + { + "name": "build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280138/jobs/2794183768" + }, + { + "name": "build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280138/jobs/2794183828" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbRIt0=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592970" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Qo=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc7" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280140/jobs/2794078017" + }, + { + "name": "test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280140/jobs/2794181109" + }, + { + "name": "test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280140/jobs/2794181305" + }, + { + "name": "test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280140/jobs/2794181488" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbRFm4=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592971" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Qs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3-clang5-mobile-build" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280143/jobs/2794078025" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbO2aw=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592974" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Q4=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "shellcheck", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078028" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078196" + }, + { + "name": "clang-tidy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078407" + }, + { + "name": "clang-format", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078610" + }, + { + "name": "cmakelint", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078760" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078898" + }, + { + "name": "py2-setup-validate-errormsg", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794078999" + }, + { + "name": "flake8-py3", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794079087" + }, + { + "name": "mypy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280145/jobs/2794079199" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbO4Es=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592975" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-Q8=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1958280146/jobs/2794078040" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAUbO2b0=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/4746da707a9912356f5179625da89616b228dc21/checks?check_suite_id=5595592976" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAU2F-RA=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17040614?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17040643?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17040615?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + } + ] + }, + "pushedDate": "2022-03-09T15:57:16Z", + "oid": "4746da707a9912356f5179625da89616b228dc21" + } + } + ] + }, + "changedFiles": 1, + "files": { + "nodes": [ + { + "path": "tools/build_variables.bzl" + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [], + "pageInfo": { + "startCursor": null, + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "CI Flow Status\n\u269b\ufe0f CI Flow\nRuleset - Version: v1\nRuleset - File: https://github.com/malfet/pytorch/blob/4746da707a9912356f5179625da89616b228dc21/.github/generated-ciflow-ruleset.json\nPR ciflow labels: ciflow/default\nAdd ciflow labels to this PR to trigger more builds:\n\n\n\nWorkflows\nLabels (bold enabled)\nStatus\n\n\n\n\nTriggered Workflows\n\n\n\n\nlinux-binary-conda\nciflow/binaries, ciflow/binaries_conda, ciflow/default\n\u2705 triggered\n\n\nlinux-binary-libtorch-cxx11-abi\nciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nlinux-binary-libtorch-pre-cxx11\nciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nlinux-binary-manywheel\nciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nlinux-bionic-py3.7-clang9\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk\n\u2705 triggered\n\n\nlinux-bionic-rocm4.5-py3.7\nciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk\n\u2705 triggered\n\n\nlinux-docs\nciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-vulkan-bionic-py3.7-clang9\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan\n\u2705 triggered\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7-bazel-test\nciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3-clang5-mobile-build\nciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3-clang5-mobile-custom-build-static\nciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-clang7-asan\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-clang7-onnx\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build\nciflow/all, ciflow/cpu, ciflow/default, ciflow/libtorch, ciflow/linux, ciflow/mobile, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc7\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nlinux-xenial-py3.7-gcc7-no-ops\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nmacos-arm64-binary-conda\nciflow/binaries, ciflow/binaries_conda, ciflow/default\n\u2705 triggered\n\n\nmacos-arm64-binary-wheel\nciflow/binaries, ciflow/binaries_wheel, ciflow/default\n\u2705 triggered\n\n\nmacos-binary-conda\nciflow/binaries, ciflow/binaries_conda, ciflow/default\n\u2705 triggered\n\n\nmacos-binary-libtorch-cxx11-abi\nciflow/binaries, ciflow/binaries_libtorch, ciflow/default\n\u2705 triggered\n\n\nmacos-binary-libtorch-pre-cxx11\nciflow/binaries, ciflow/binaries_libtorch, ciflow/default\n\u2705 triggered\n\n\nmacos-binary-wheel\nciflow/binaries, ciflow/binaries_wheel, ciflow/default\n\u2705 triggered\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single\nciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit\nciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk\n\u2705 triggered\n\n\nwin-vs2019-cpu-py3\nciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win\n\u2705 triggered\n\n\nwin-vs2019-cuda11.3-py3\nciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win\n\u2705 triggered\n\n\nwindows-binary-conda\nciflow/binaries, ciflow/binaries_conda, ciflow/default\n\u2705 triggered\n\n\nwindows-binary-libtorch-debug\nciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nwindows-binary-libtorch-release\nciflow/all, ciflow/binaries, ciflow/binaries_libtorch, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nwindows-binary-wheel\nciflow/all, ciflow/binaries, ciflow/binaries_wheel, ciflow/default, ciflow/trunk\n\u2705 triggered\n\n\nSkipped Workflows\n\n\n\n\ncaffe2-linux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\ndocker-builds\nciflow/all, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64\nciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-coreml\nciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-custom-ops\nciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nios-12-5-1-arm64-metal\nciflow/all, ciflow/ios, ciflow/macos, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nios-12-5-1-x86-64\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nios-12-5-1-x86-64-coreml\nciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlibtorch-linux-xenial-cuda10.2-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlibtorch-linux-xenial-cuda11.3-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlinux-bionic-cuda10.2-py3.9-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk\n\ud83d\udeab skipped\n\n\nlinux-docs-push\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nlinux-xenial-cuda11.3-py3.7-gcc7-no-ops\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-10-15-py3-arm64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-10-15-py3-lite-interpreter-x86-64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nmacos-11-py3-x86-64\nciflow/all, ciflow/macos, ciflow/trunk\n\ud83d\udeab skipped\n\n\nparallelnative-linux-xenial-py3.7-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\nperiodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-linux-bionic-cuda11.5-py3.7-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck\n\ud83d\udeab skipped\n\n\nperiodic-linux-xenial-cuda11.3-py3.7-gcc7-debug\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-win-vs2019-cuda11.5-py3\nciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win\n\ud83d\udeab skipped\n\n\npytorch-linux-xenial-py3-clang5-android-ndk-r19c-build\nciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk\n\ud83d\udeab skipped\n\n\npytorch-xla-linux-bionic-py3.7-clang8\nciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk, ciflow/xla\n\ud83d\udeab skipped", + "createdAt": "2022-03-09T15:57:11Z", + "author": { + "login": "pytorch-bot" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1063079053 + }, + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/73969\n\ud83d\udcc4 \u00a0Preview docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\ud83d\udd27 \u00a0Opt-in to CIFlow to control what jobs run on your PRs\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit 4746da7 (more details on the Dr. CI page):\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-03-09T15:57:12Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1063079113 + }, + { + "bodyText": "This pull request was exported from Phabricator. Differential Revision: D34753911", + "createdAt": "2022-03-09T15:57:34Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1063079731 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOP11MjQ==", + "hasPreviousPage": false + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "fb-exported" + } + }, + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=73099 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "BowenBao" + }, + "title": "[ONNX] Make graph name spec-compliant (#71961)", + "body": "Stack from [ghstack](https://github.com/ezyang/ghstack):\n* #73104\n* #73103\n* #73102\n* #73101\n* #73100\n* __->__ #73099\n\n[According to the ONNX spec](https://github.com/onnx/onnx/blob/main/docs/IR.md#names-within-a-graph),\nall names must adhere to C90 identifier syntax rules, which means no\ndashes.\n\nFixes: #30952", + "headRefName": "gh/BowenBao/138/head", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "gh/BowenBao/138/base", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "BowenBao" + }, + "email": "bowbao@microsoft.com", + "name": "BowenBao" + }, + "oid": "3038b939eb2069653305c419326a0f47d2598e39" + } + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + }, + "totalCount": 1 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041786/jobs/2626264278" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkNn9o=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189561" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS7k=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-cuda11.3-py3.7-gcc7" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041785/jobs/2626264385" + }, + { + "name": "test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041785/jobs/2626417658" + }, + { + "name": "test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041785/jobs/2626417743" + }, + { + "name": "test (distributed, 1, 1, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041785/jobs/2626417885" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkRE_E=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189562" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS7o=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc7-no-ops" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041789/jobs/2626264416" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkNoJE=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189563" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS7s=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3-clang5-mobile-build" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041787/jobs/2626264407" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkNoIY=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189564" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS7w=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041788/jobs/2626264422" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkNoJs=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189566" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS74=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-bionic-py3.7-clang9" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041790/jobs/2626264414" + }, + { + "name": "test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041790/jobs/2626349405" + }, + { + "name": "test (noarch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041790/jobs/2626349522" + }, + { + "name": "test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041790/jobs/2626349618" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkPiwA=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189567" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS78=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-vulkan-bionic-py3.7-clang9" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041793/jobs/2626264431" + }, + { + "name": "test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041793/jobs/2626359364" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkPxgQ=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189568" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS8A=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041792/jobs/2626264427" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkNoKA=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189570" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS8I=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "win-vs2019-cpu-py3" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041791/jobs/2626264386" + }, + { + "name": "test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041791/jobs/2626722677" + }, + { + "name": "test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041791/jobs/2626722710" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkX070=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189571" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS8M=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-xenial-py3.7-gcc7" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041803/jobs/2626264401" + }, + { + "name": "test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041803/jobs/2626349045" + }, + { + "name": "test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041803/jobs/2626349141" + }, + { + "name": "test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/1866041803/jobs/2626349272" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAATkPiQA=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/3038b939eb2069653305c419326a0f47d2598e39/checks?check_suite_id=5365189572" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAT_KS8Q=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": { + "contexts": [ + { + "context": "ci/circleci: binary_linux_libtorch_3_7m_cpu_gcc5_4_cxx11-abi_shared-with-deps_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17010288?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: docker-pytorch-linux-xenial-py3-clang5-android-ndk-r19c", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17010289?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17010488?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + }, + { + "context": "ci/circleci: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build", + "state": "SUCCESS", + "targetUrl": "https://circleci.com/gh/pytorch/pytorch/17010326?utm_campaign=vcs-integration-link&utm_medium=referral&utm_source=github-build-link" + } + ] + }, + "pushedDate": "2022-02-18T18:46:28Z", + "oid": "3038b939eb2069653305c419326a0f47d2598e39" + } + } + ] + }, + "changedFiles": 162, + "files": { + "nodes": [ + { + "path": "test/onnx/expect/TestOperators.test_acos.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_add_broadcast.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_add_left_broadcast.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_add_size1_broadcast.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_add_size1_right_broadcast.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_add_size1_singleton_broadcast.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_addconstant.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_addmm.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_arange_dynamic.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_argmax.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_asin.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_at_op.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_atan.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_aten_embedding_1.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_aten_embedding_2.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_avg_pool2d.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_baddbmm.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_basic.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_batchnorm.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_batchnorm_1d.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_batchnorm_onnx_irv4.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_batchnorm_training.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_bitshift.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_c2_op.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_chunk.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_clip.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_clip_max.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_clip_min.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_concat2.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_conv.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_conv_onnx_irv4.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_conv_onnx_irv4_opset8.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_convtranspose.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_cos.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_cumsum.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_det.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dict.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dict_str.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dim.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dropout.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dropout_default.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dropout_opset12.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dropout_training.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dropout_training_opset12.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dynamic_axes_add.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dynamic_axes_add_inputs_same_symbolic_shape.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_elu.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_embedding_bags.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_empty_like.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_empty_like_opset7.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_equal.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_erf.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_exp.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_expand.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_flatten.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_flatten2D.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_fmod.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_frobenius_norm.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_full.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_full_like.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_gather.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_gather_opset11.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_ge.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_gelu.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_gt.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_hardtanh.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_implicit_expand.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_index.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_isnan.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_layer_norm_aten.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_le.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_linear.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_log_sigmoid.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_logsoftmax.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_lt.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_master_opset.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_max.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_maxpool.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_maxpool_dilations.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_maxpool_indices.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_mean.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_mean_dtype.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_meshgrid.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_min.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_mm.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_narrow.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_ne.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_nonzero.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_norm_p1.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_norm_p2.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_ones_like.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_pad.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_params.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_params_onnx_irv4.expect" + }, + { + "path": "test/onnx/expect/TestOperators.test_permute2.expect" + } + ], + "pageInfo": { + "endCursor": "MTAw", + "hasNextPage": true + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "garymm" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wMi0xOFQxNzoxODo0NC0wODowMLkyMDIyLTAyLTE4VDE3OjE4OjQ0LTA4OjAwzjTr0H0=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "This PR cannot be merged by bot due to changing > 100 files. @malfet \n \n \n pytorch/.github/scripts/trymerge.py\n \n \n Line 63\n in\n 932adf2\n \n \n \n \n\n \n \n files(last: 100) { \n \n \n \n\n Can this be relaxed? If not please import.", + "createdAt": "2022-02-22T18:22:40Z", + "author": { + "login": "BowenBao" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1048084569 + }, + { + "bodyText": "This PR cannot be merged by bot due to changing > 100 files. @malfet\nCan this be relaxed? If not please import.\n\nWow, you've hit a really interesting problem. 100 is a limitation enforced by GitHub, see https://docs.github.com/en/graphql/overview/resource-limitations, but I can implement a pagination. Do you mind keeping it like that for a bit, want to land a fix soonish.", + "createdAt": "2022-02-22T18:27:29Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048088691 + }, + { + "bodyText": "@malfet Thank you for info. Sure, I have separated the rest of stack from this one, we'll wait for the fix to try again.", + "createdAt": "2022-02-22T18:29:48Z", + "author": { + "login": "BowenBao" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1048090640 + }, + { + "bodyText": "@pytorchbot merge this", + "createdAt": "2022-02-24T21:42:36Z", + "author": { + "login": "BowenBao" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1050293881 + }, + { + "bodyText": "Hey @BowenBao.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-02-24T21:44:39Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1050295451 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOPniAWQ==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "oncall: jit" + } + }, + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "release notes: onnx" + } + }, + { + "node": { + "name": "topic: bug fixes" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=74649 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": false, + "author": { + "login": "malfet" + }, + "title": "This should fail flake8", + "body": "Test issue for GHF mandatory checks", + "headRefName": "malfet-patch-8", + "headRepository": { + "nameWithOwner": "pytorch/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "57c86ff1c5ab948888fd329986c9d55796680e33" + } + }, + { + "commit": { + "author": { + "user": { + "login": "malfet" + }, + "email": "nshulga@fb.com", + "name": "Nikita Shulga" + }, + "oid": "6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4" + } + } + ], + "pageInfo": { + "endCursor": "Mg", + "hasNextPage": false + }, + "totalCount": 2 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.intern.facebook.com/cla/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAVHsK3w=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018129" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj1E=" + }, + { + "node": { + "app": { + "name": "Netlify", + "databaseId": 13473 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018131" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj1M=" + }, + { + "node": { + "app": { + "name": "Azure Pipelines", + "databaseId": 9426 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018132" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj1Q=" + }, + { + "node": { + "app": { + "name": "Dependabot", + "databaseId": 29110 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018134" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj1Y=" + }, + { + "node": { + "app": { + "name": "Codecov", + "databaseId": 254 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018139" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj1s=" + }, + { + "node": { + "app": { + "name": "PyTorch Bot", + "databaseId": 40112 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [], + "pageInfo": { + "endCursor": null, + "hasNextPage": false + } + }, + "conclusion": null, + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018142" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlj14=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "clang-format", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925132" + }, + { + "name": "clang-tidy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925189" + }, + { + "name": "cmakelint", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925230" + }, + { + "name": "flake8-py3", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925307" + }, + { + "name": "mypy", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925365" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925427" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925449" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925537" + }, + { + "name": "py2-setup-validate-errormsg", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925644" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925688" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925809" + }, + { + "name": "shellcheck", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576283/jobs/2928925945" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAVHsMiY=", + "hasNextPage": false + } + }, + "conclusion": "FAILURE", + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018384" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlkFA=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576288/jobs/2928925134" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAVHsLW0=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018395" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlkFs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928935743" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928935775" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928935850" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928935994" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936064" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936179" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936265" + }, + { + "name": "linux-xenial-py3.7-gcc5.4-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936309" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936353" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936395" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936426" + }, + { + "name": "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936483" + }, + { + "name": "win-vs2019-cuda11.3-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936516" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936558" + }, + { + "name": "linux-xenial-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936633" + }, + { + "name": "linux-xenial-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936705" + }, + { + "name": "deploy-linux-xenial-cuda11.3-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936736" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936756" + }, + { + "name": "pytorch-xla-linux-bionic-py3.7-clang8", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936796" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928936823" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928990551" + }, + { + "name": "linux-xenial-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928990588" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928992832" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928992868" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928992932" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928992965" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928993011" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928993042" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928993086" + }, + { + "name": "linux-xenial-py3.7-gcc5.4 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928993128" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928995802" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928995853" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (noarch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928995889" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928997626" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928999058" + }, + { + "name": "linux-xenial-py3.7-clang7-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2928999075" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 1, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929012407" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 2, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929012438" + }, + { + "name": "linux-xenial-py3.7-clang7-asan / test (default, 3, 3, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929012469" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / test (default, 1, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929034328" + }, + { + "name": "linux-bionic-rocm4.5-py3.7 / test (default, 2, 2, linux.rocm.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929034340" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929040801" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 1, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929045939" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 2, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929046016" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7 / test (distributed, 1, 1, linux.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929046063" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929082254" + }, + { + "name": "win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929082275" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 1, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929157614" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929157635" + }, + { + "name": "win-vs2019-cuda11.3-py3 / test (force_on_cpu, 1, 1, windows.4xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2031576300/jobs/2929157656" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAVHxIT4=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4/checks?check_suite_id=5778018405" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAVhlkGU=" + } + ], + "pageInfo": { + "hasNextPage": false + } + }, + "status": null, + "pushedDate": "2022-03-24T00:42:33Z", + "oid": "6c3c3de6a5c1183d9a08f3c54148bc0b5de11bb4" + } + } + ] + }, + "changedFiles": 1, + "files": { + "nodes": [ + { + "path": "torch/nn/cpp.py" + } + ], + "pageInfo": { + "endCursor": "MQ", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "seemethere" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wMy0yM1QxNTo1MDo0NS0wNzowMLkyMDIyLTAzLTIzVDE1OjUwOjQ1LTA3OjAwzjbPEDg=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/74649\n\u21a9\ufe0f \u00a0[fb-only] Re-run with SSH instructions\nNeed help or want to give feedback on the CI? Visit our office hours\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit 6c3c3de (more details on the Dr. CI page):\n\n\n1/1 failures introduced in this PR\n\n\n1 failure not recognized by patterns:\n\n\n\nJob\nStep\nAction\n\n\n\n\n Lint / flake8-py3\nFail if there were any warnings\n\ud83d\udd01 rerun\n\n\n\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-03-23T22:40:51Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1076891218 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQDAOUg==", + "hasPreviousPage": false + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "cla signed" + } + } + ] + } + } + } + } + }, + "query_sha=81fd873151c3cded18314e9e53bf54a93ffb0afa9c52fa2cbafb2ceab7df5e45 name=pytorch number=79694 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "closed": true, + "isCrossRepository": true, + "author": { + "login": "kshitij12345" + }, + "title": "[complex] conv_transpose1d", + "body": "Reference: https://github.com/pytorch/pytorch/issues/71108", + "headRefName": "develop/complex/conv_transpose1d", + "headRepository": { + "nameWithOwner": "kshitij12345/pytorch" + }, + "baseRefName": "master", + "baseRepository": { + "nameWithOwner": "pytorch/pytorch", + "isPrivate": false, + "defaultBranchRef": { + "name": "master" + } + }, + "mergeCommit": null, + "commits_with_authors": { + "nodes": [ + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "d1ea948e65ac6d31ad056287ab65d38ecc68b30d" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "b4ba1db9a3a71bd8c03158dcd1b68711360633d8" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "655a4220beae163bfe578f0318a130df01ec05d6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "Kshiteej K" + }, + "oid": "8181716be7a8005eb13ad5c3f2e1279ed1c60aff" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "9e5ca3663e7471786eeebebfdf84aea5d761712f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "9c110f39bcdc4e56386b6f9c4e2c082c8940ade6" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "49315e79d0eee8008e2a74575c6fc0f6a9531ee4" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "728752480760226270c374a0acc08e28b9b133f3" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "ffe43399d6f60ef7844523a5f465c11d9a67062f" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "9672a2198472567bae4ac6f55d004f7e1fa8a9fa" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "48a0ebf32b895286f036b36c871f671dc867e400" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "52fbe80d5c8a94e03d816c0bd21fd82019dcd5ac" + } + }, + { + "commit": { + "author": { + "user": { + "login": "kshitij12345" + }, + "email": "kshitijkalambarkar@gmail.com", + "name": "kshitij12345" + }, + "oid": "2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce" + } + } + ], + "pageInfo": { + "endCursor": "MTM", + "hasNextPage": false + }, + "totalCount": 13 + }, + "commits": { + "nodes": [ + { + "commit": { + "checkSuites": { + "edges": [ + { + "node": { + "app": { + "name": "Facebook GitHub Tools", + "databaseId": 12274 + }, + "workflowRun": null, + "checkRuns": { + "nodes": [ + { + "name": "Facebook CLA Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://code.facebook.com/cla/" + }, + { + "name": "Meta Internal-Only Changes Check", + "conclusion": "SUCCESS", + "detailsUrl": "https://opensource.facebook.com/" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdtq8Hc=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7929899098" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdioqFo=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "TorchBench CI (pytorch-linux-py3.7-cu102)" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "run-torchbench", + "conclusion": "NEUTRAL", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393316/jobs/4628529923" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdqTEwk=", + "hasNextPage": false + } + }, + "conclusion": "SKIPPED", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7929899387" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdioqXs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "Lint" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "lintrunner", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628529910" + }, + { + "name": "quick-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628530162" + }, + { + "name": "Test collect_env (with_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628530698" + }, + { + "name": "Test collect_env (without_torch)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628530867" + }, + { + "name": "Test collect_env (older_python_version)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628530989" + }, + { + "name": "pr-sanity-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628531151" + }, + { + "name": "workflow-checks", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628531475" + }, + { + "name": "Test tools", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628531753" + }, + { + "name": "toc", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393315/jobs/4628531853" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdqTHFY=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7929899388" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdioqXw=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "pull" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "linux-focal-py3.7-clang7-asan / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628531149" + }, + { + "name": "linux-bionic-cuda11.6-py3.10-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628531473" + }, + { + "name": "linux-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628531754" + }, + { + "name": "linux-jammy-cuda11.6-cudnn8-py3.8-clang12 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628531857" + }, + { + "name": "linux-focal-py3.7-gcc7-pch / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628532179" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628532543" + }, + { + "name": "linux-bionic-cuda11.3-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628532694" + }, + { + "name": "linux-focal-py3.7-gcc7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628532918" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533033" + }, + { + "name": "linux-focal-py3.7-gcc7-no-ops / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533181" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533420" + }, + { + "name": "linux-xenial-cuda11.3-py3.7-gcc7-bazel-test / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533630" + }, + { + "name": "linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533825" + }, + { + "name": "linux-xenial-py3-clang5-mobile-custom-build-static / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628533959" + }, + { + "name": "linux-xenial-py3-clang5-mobile-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534129" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534256" + }, + { + "name": "linux-focal-rocm5.2-py3.7 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534388" + }, + { + "name": "linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534571" + }, + { + "name": "linux-bionic-cuda11_6-py3_10-gcc7-deploy / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534714" + }, + { + "name": "win-vs2019-cuda11.6-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628534989" + }, + { + "name": "win-vs2019-cpu-py3 / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628535311" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639115" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639198" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (distributed, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639265" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639339" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (docs_test, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639395" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (jit_legacy, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639450" + }, + { + "name": "linux-focal-py3.7-gcc7 / test (backwards_compat, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639509" + }, + { + "name": "linux-docs / build-docs (cpp)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639572" + }, + { + "name": "linux-docs / build-docs (python)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628639635" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647047" + }, + { + "name": "linux-focal-py3.7-clang10-onnx / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647119" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647215" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (default, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647277" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647348" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647432" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647522" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647641" + }, + { + "name": "linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628647762" + }, + { + "name": "linux-vulkan-bionic-py3.7-clang9 / test (default, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628653797" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 1, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628679376" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 2, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628679431" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628679469" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628679519" + }, + { + "name": "linux-focal-py3.7-clang7-asan / test (default, 5, 5, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628679594" + }, + { + "name": "linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.2xlarge)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628681226" + }, + { + "name": "linux-bionic-cuda11_6-py3_10-gcc7-deploy / test (deploy, 1, 1, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628854932" + }, + { + "name": "linux-bionic-cuda11.6-py3.10-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628856434" + }, + { + "name": "linux-bionic-cuda11.6-py3.10-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628856501" + }, + { + "name": "linux-bionic-cuda11.6-py3.10-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu)", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2907393329/jobs/4628856575" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdqZ2fA=", + "hasNextPage": true + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7929899419" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdioqZs=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "windows-binary-libtorch-debug" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "libtorch-cpu-shared-with-deps-debug-build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351637/jobs/4634503587" + }, + { + "name": "libtorch-cpu-shared-with-deps-debug-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351637/jobs/4635312938" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsbsmM=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953056" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUSuA=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "windows-binary-wheel" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "wheel-py3_7-cuda11_3-build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351640/jobs/4634503571" + }, + { + "name": "wheel-py3_7-cuda11_3-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351640/jobs/4636146265" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsskcw=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953059" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUSuM=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "windows-binary-libtorch-release" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "libtorch-cpu-shared-with-deps-release-build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351643/jobs/4634503570" + }, + { + "name": "libtorch-cpu-shared-with-deps-release-test", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351643/jobs/4635003925" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsVbD8=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953061" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUSuU=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-binary-libtorch-cxx11-abi" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "libtorch-cpu-shared-with-deps-cxx11-abi-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351698/jobs/4634504079" + }, + { + "name": "libtorch-cpu-shared-with-deps-cxx11-abi-test / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351698/jobs/4635072931" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsW5Aw=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953185" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUS2E=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-binary-libtorch-pre-cxx11" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "libtorch-cpu-shared-with-deps-cxx11-abi-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351700/jobs/4634503897" + }, + { + "name": "libtorch-cpu-shared-with-deps-cxx11-abi-test / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351700/jobs/4635077148" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsW-jo=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953186" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUS2I=" + }, + { + "node": { + "app": { + "name": "GitHub Actions", + "databaseId": 15368 + }, + "workflowRun": { + "workflow": { + "name": "linux-binary-manywheel" + } + }, + "checkRuns": { + "nodes": [ + { + "name": "manywheel-py3_7-cuda10_2-build / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351699/jobs/4634503896" + }, + { + "name": "manywheel-py3_7-cuda10_2-test / build", + "conclusion": "SUCCESS", + "detailsUrl": "https://github.com/pytorch/pytorch/actions/runs/2910351699/jobs/4635934290" + } + ], + "pageInfo": { + "endCursor": "Y3Vyc29yOnYyOpHPAAAAAdsoMEA=", + "hasNextPage": false + } + }, + "conclusion": "SUCCESS", + "url": "https://github.com/pytorch/pytorch/commit/2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce/checks?check_suite_id=7936953187" + }, + "cursor": "Y3Vyc29yOnYyOpHPAAAAAdkUS2M=" + } + ], + "pageInfo": { + "hasNextPage": true + } + }, + "status": null, + "pushedDate": "2022-08-22T22:04:19Z", + "oid": "2fd08f1c669bbb0f2e14ae40e76f9e0d3195f4ce" + } + } + ] + }, + "changedFiles": 3, + "files": { + "nodes": [ + { + "path": "aten/src/ATen/native/Convolution.cpp" + }, + { + "path": "torch/testing/_internal/common_methods_invocations.py" + }, + { + "path": "torch/testing/_internal/common_modules.py" + } + ], + "pageInfo": { + "endCursor": "Mw", + "hasNextPage": false + } + }, + "reviews": { + "nodes": [ + { + "author": { + "login": "ngimel" + }, + "state": "APPROVED" + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpO5MjAyMi0wNy0xOVQxMDowNzo1NC0wNzowMLkyMDIyLTA3LTE5VDEwOjA3OjU0LTA3OjAwzj43QcY=", + "hasPreviousPage": false + } + }, + "comments": { + "nodes": [ + { + "bodyText": "@pytorchbot merge -g\nAll is green internally!", + "createdAt": "2022-08-23T19:29:55Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1224702749 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here.\nThe merge job was triggered with the green (-g) flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.\nPlease reach out to the PyTorch DevX Team with feedback or questions!", + "createdAt": "2022-08-23T19:31:18Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1224705564 + }, + { + "bodyText": "Thanks for looking into it \ud83d\ude42 @albanD @jeanschmidt", + "createdAt": "2022-08-23T19:34:36Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1224712351 + }, + { + "bodyText": "Hey @kshitij12345.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-08-23T22:31:58Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1224956051 + }, + { + "bodyText": "Yeah, discussed with my manager and I got the required permissions to do so. Sorry for not responding promptly yesterday. But I am available from now on to provide assistance :)", + "createdAt": "2022-08-24T09:24:04Z", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1225462612 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOSP97HQ==", + "hasPreviousPage": true + } + }, + "labels": { + "edges": [ + { + "node": { + "name": "open source" + } + }, + { + "node": { + "name": "Merged" + } + }, + { + "node": { + "name": "cla signed" + } + }, + { + "node": { + "name": "Reverted" + } + }, + { + "node": { + "name": "ciflow/trunk" + } + }, + { + "node": { + "name": "ciflow/periodic" + } + } + ] + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOSP97HQ== name=pytorch number=79694 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/79694\n\ud83d\udcc4 \u00a0Preview Python docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\u2753Need help or want to give feedback on the CI? Visit our office hours\n\n\u2705 No Failures (0 Pending)\nAs of commit 2fd08f1 (more details on the Dr. CI page):\nExpand to see more\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-06-16T09:43:16Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1157454523 + }, + { + "bodyText": "Unable to reproduce jit failure locally (will skip the test)\nCI Failure : https://github.com/pytorch/pytorch/runs/6926187074?check_suite_focus=true#step:9:20230\npytest test/test_ops_jit.py -k test_variant_consistency_jit_nn_functional_conv_transpose1d_cpu_complex64 -v\n=============================================================== test session starts ===============================================================\nplatform linux -- Python 3.10.0, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 -- /home/kshiteej/.conda/envs/pytorch-cuda-dev/bin/python\ncachedir: .pytest_cache\nhypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/kshiteej/Pytorch/pytorch_complex_convolution.py/.hypothesis/examples')\nrootdir: /home/kshiteej/Pytorch/pytorch_complex_convolution.py, configfile: pytest.ini\nplugins: hypothesis-6.23.2, repeat-0.9.1\ncollected 1976 items / 1975 deselected / 1 selected \n\ntest/test_ops_jit.py::TestJitCPU::test_variant_consistency_jit_nn_functional_conv_transpose1d_cpu_complex64 PASSED [100%]\n\n================================================================ warnings summary =================================================================\n../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_cuda.py:9\n /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/testing/_internal/common_cuda.py:9: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives\n from distutils.version import LooseVersion\n\n../../.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:91\n /home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:91: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.\n warnings.warn(\n\n-- Docs: https://docs.pytest.org/en/stable/warnings.html\n================================================= 1 passed, 1975 deselected, 2 warnings in 4.90s =================================================", + "createdAt": "2022-07-18T09:05:35Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": { + "login": "kshitij12345" + }, + "databaseId": 1186949486 + }, + { + "bodyText": "@pytorchbot merge", + "createdAt": "2022-07-19T17:12:23Z", + "author": { + "login": "ngimel" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189347786 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-19T17:13:42Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189350009 + }, + { + "bodyText": "Hey @kshitij12345.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-07-19T17:14:25Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1189350932 + }, + { + "bodyText": "@pytorchbot revert -m \"broke slow test https://github.com/pytorch/pytorch/runs/7414560957?check_suite_focus=true#step:9:31516\" -c \"nosignal\"", + "createdAt": "2022-07-19T19:15:41Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1189459845 + }, + { + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "createdAt": "2022-07-19T19:16:59Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189460926 + }, + { + "bodyText": "Will not revert as @kshitij12345 is not a MEMBER, but COLLABORATOR", + "createdAt": "2022-07-19T19:17:00Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189460942 + }, + { + "bodyText": "@pytorchbot revert -m \"broke slow test https://github.com/pytorch/pytorch/runs/7414560957?check_suite_focus=true#step:9:31516\" -c \"nosignal\"", + "createdAt": "2022-07-19T20:40:04Z", + "author": { + "login": "anjali411" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189529734 + }, + { + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "createdAt": "2022-07-19T20:41:20Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189530756 + }, + { + "bodyText": "@kshitij12345 your PR has been successfully reverted.", + "createdAt": "2022-07-19T20:41:25Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1189530831 + }, + { + "bodyText": "@pytorchbot merge -g", + "createdAt": "2022-07-20T09:53:08Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1190070141 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-20T09:54:24Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1190071424 + }, + { + "bodyText": "Hey @kshitij12345.\nYou've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.\nFor changes that are 'topic: not user facing' there is no need for a release notes label.", + "createdAt": "2022-07-20T13:00:51Z", + "author": { + "login": "github-actions" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1190258272 + }, + { + "bodyText": "commit is breaking internal builds/tests https://pastebin.com/HX4RUusH (pytorch/functorch/test:test_eager_transforms)", + "createdAt": "2022-07-21T10:39:01Z", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191327616 + }, + { + "bodyText": "@pytorchbot revert -m \"breaking internal builds\" -c \"ghfirst\"", + "createdAt": "2022-07-21T10:39:27Z", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191328013 + }, + { + "bodyText": "@pytorchbot revert -m \"breaking internal builds\" -c \"ghfirst\"", + "createdAt": "2022-07-21T10:41:23Z", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191329792 + }, + { + "bodyText": "@pytorchbot successfully started a revert job. Check the current status here", + "createdAt": "2022-07-21T10:42:16Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191330586 + }, + { + "bodyText": "@kshitij12345 your PR has been successfully reverted.", + "createdAt": "2022-07-21T10:42:23Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1191330690 + }, + { + "bodyText": "@jeanschmidt which test is it failing on? I tried running the test_eager_transforms in functorch but couldn't reproduce it.", + "createdAt": "2022-07-25T07:11:19Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1193667568 + }, + { + "bodyText": "@jbschlosser have added a ref as discussed offline. Can you please take a look? And if it looks good, can you import the PR to check if it is breaking anything internally.\nThanks", + "createdAt": "2022-08-03T18:30:17Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1204329491 + }, + { + "bodyText": "@jbschlosser @jeanschmidt @albanD anything we can do to unblock this on our side?", + "createdAt": "2022-08-20T09:27:17Z", + "author": { + "login": "lezcano" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1221266218 + }, + { + "bodyText": "Functorch tests should be running here now so can you rebase on top of master please?", + "createdAt": "2022-08-22T21:42:37Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1223129944 + }, + { + "bodyText": "@albanD have rebased on latest master.", + "createdAt": "2022-08-23T08:49:10Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1223758571 + }, + { + "bodyText": "I triggered all the tests not to have any issues with slow tests again", + "createdAt": "2022-08-23T09:20:18Z", + "author": { + "login": "lezcano" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1223796413 + }, + { + "bodyText": "Thanks @lezcano! However, last time it was reverted for internal failures. So it would be great if someone can import and verify that.\ncc: @albanD @jeanschmidt", + "createdAt": "2022-08-23T10:17:50Z", + "author": { + "login": "kshitij12345" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1223863075 + }, + { + "bodyText": "@albanD has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-08-23T14:43:02Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1224175731 + }, + { + "bodyText": "I am not the right person to provide assistence, as currently I am not based in a Tier 1 location, so my permissions to access are so restricted that I am not able to import this commit, run the tests and provide meaningful responses.", + "createdAt": "2022-08-23T15:57:48Z", + "author": { + "login": "jeanschmidt" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1224272324 + }, + { + "bodyText": "@jeanschmidt has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.", + "createdAt": "2022-08-23T17:00:53Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1224351135 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHORP1auw==", + "hasPreviousPage": false + } + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOR1poyg== name=pytorch number=82169 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/82169\n\ud83d\udcc4 \u00a0Preview Python docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\u2753Need help or want to give feedback on the CI? Visit our office hours\n\n\u2705 No Failures (0 Pending)\nAs of commit 28140e4 (more details on the Dr. CI page):\nExpand to see more\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-07-25T21:41:41Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1194667199 + }, + { + "bodyText": "@pytorchbot merge -g", + "createdAt": "2022-07-25T21:46:04Z", + "author": { + "login": "ezyang" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194671445 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-25T21:47:25Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194672744 + }, + { + "bodyText": "Merge failed due to Refusing to merge as mandatory check(s) pull failed for rule superuser\nRaised by https://github.com/pytorch/pytorch/actions/runs/2735501647", + "createdAt": "2022-07-25T23:22:45Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194761219 + }, + { + "bodyText": "@pytorchbot rebase", + "createdAt": "2022-07-26T00:54:17Z", + "author": { + "login": "ezyang" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194839920 + }, + { + "bodyText": "@pytorchbot successfully started a rebase job. Check the current status here", + "createdAt": "2022-07-26T01:01:32Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194846575 + }, + { + "bodyText": "Successfully rebased gh/ezyang/1279/orig onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/82169)", + "createdAt": "2022-07-26T01:01:53Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1194846838 + }, + { + "bodyText": "@pytorchbot rebase", + "createdAt": "2022-07-27T15:32:13Z", + "author": { + "login": "ezyang" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196915484 + }, + { + "bodyText": "@pytorchbot successfully started a rebase job. Check the current status here", + "createdAt": "2022-07-27T15:33:49Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196917359 + }, + { + "bodyText": "Successfully rebased gh/ezyang/1279/orig onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/82169)", + "createdAt": "2022-07-27T15:34:03Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196917609 + }, + { + "bodyText": "@pytorchbot merge -g", + "createdAt": "2022-07-27T15:41:52Z", + "author": { + "login": "ezyang" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196927174 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-27T15:43:11Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196928771 + }, + { + "bodyText": "Merge failed due to Refusing to merge as mandatory check(s) Lint failed for rule superuser\nRaised by https://github.com/pytorch/pytorch/actions/runs/2747872935", + "createdAt": "2022-07-27T15:43:14Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1196928849 + }, + { + "bodyText": "@pytorchbot merge -g", + "createdAt": "2022-07-27T16:59:37Z", + "author": { + "login": "ezyang" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197046487 + }, + { + "bodyText": "@pytorchbot successfully started a merge job. Check the current status here", + "createdAt": "2022-07-27T17:07:32Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197055101 + }, + { + "bodyText": "Merge failed due to Refusing to merge as mandatory check(s) Lint failed for rule superuser\nRaised by https://github.com/pytorch/pytorch/actions/runs/2748317347", + "createdAt": "2022-07-27T17:07:36Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197055259 + }, + { + "bodyText": "@pytorchbot merge -f", + "createdAt": "2022-07-27T17:56:26Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1197107106 + }, + { + "bodyText": "\u274c \ud83e\udd16 pytorchbot command failed:\n@pytorchbot merge: error: argument -f/--force: expected one argument\n\nusage: @pytorchbot merge [-g | -f FORCE | -l]\n\nTry @pytorchbot --help for more info.", + "createdAt": "2022-07-27T17:56:27Z", + "author": { + "login": "pytorch-bot" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1197107129 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHORzUsvw==", + "hasPreviousPage": false + } + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOPoR4Lg== name=pytorch number=71759 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "CI Flow Status\n\u269b\ufe0f CI Flow\nRuleset - Version: v1\nRuleset - File: https://github.com/coolteemf/pytorch/blob/7647f7953a68e4f1c3feaa19c77d925abfe8e377/.github/generated-ciflow-ruleset.json\nPR ciflow labels: ciflow/default\nAdd ciflow labels to this PR to trigger more builds:\n\n\n\nWorkflows\nLabels (bold enabled)\nStatus\n\n\n\n\nTriggered Workflows\n\n\n\n\nlinux-bionic-py3.6-clang9\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla\n\u2705 triggered\n\n\nlinux-xenial-cuda11.3-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/default, ciflow/linux\n\u2705 triggered\n\n\nlinux-xenial-py3.6-clang7-asan\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers\n\u2705 triggered\n\n\nlinux-xenial-py3.6-gcc5.4\nciflow/all, ciflow/cpu, ciflow/default, ciflow/linux\n\u2705 triggered\n\n\nlinux-xenial-py3.6-gcc7-bazel-test\nciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux\n\u2705 triggered\n\n\nwin-vs2019-cpu-py3\nciflow/all, ciflow/cpu, ciflow/default, ciflow/win\n\u2705 triggered\n\n\nwin-vs2019-cuda11.3-py3\nciflow/all, ciflow/cuda, ciflow/default, ciflow/win\n\u2705 triggered\n\n\nSkipped Workflows\n\n\n\n\nlibtorch-linux-xenial-cuda10.2-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux\n\ud83d\udeab skipped\n\n\nlibtorch-linux-xenial-cuda11.3-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux\n\ud83d\udeab skipped\n\n\nlinux-bionic-cuda10.2-py3.9-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow\n\ud83d\udeab skipped\n\n\nlinux-xenial-cuda10.2-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow\n\ud83d\udeab skipped\n\n\nparallelnative-linux-xenial-py3.6-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux\n\ud83d\udeab skipped\n\n\nperiodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-linux-xenial-cuda11.1-py3.6-gcc7\nciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled\n\ud83d\udeab skipped\n\n\nperiodic-win-vs2019-cuda11.1-py3\nciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win\n\ud83d\udeab skipped\n\n\npuretorch-linux-xenial-py3.6-gcc5.4\nciflow/all, ciflow/cpu, ciflow/linux\n\ud83d\udeab skipped", + "createdAt": "2022-01-25T09:31:05Z", + "author": { + "login": "pytorch-bot" + }, + "authorAssociation": "NONE", + "editor": null, + "databaseId": 1020983378 + }, + { + "bodyText": "Hi @coolteemf!\nThank you for your pull request and welcome to our community.\nAction Required\nIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.\nProcess\nIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.\nOnce the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.\nIf you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!", + "createdAt": "2022-01-25T09:31:06Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1020983383 + }, + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/71759\n\ud83d\udcc4 \u00a0Preview docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\ud83d\udd27 \u00a0Opt-in to CIFlow to control what jobs run on your PRs\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit 346e0c5 (more details on the Dr. CI page):\n\n\n2/3 failures introduced in this PR\n1/3 tentatively recognized as flaky \u2744\ufe0f\n\nClick here to rerun these jobs\n\n\n\n\n\ud83d\udd75\ufe0f 2 new failures recognized by patterns\nThe following CI failures do not appear to be due to upstream breakages:\n win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (1/2)\nStep: \"Test\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-02-23T14:12:58.9371445Z FAIL [0.010s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)\n\n2022-02-23T14:12:58.9258506Z test_sparse_zeros_tanh_cpu_float64 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.002s)\n2022-02-23T14:12:58.9274771Z test_sparse_zeros_tanh_cpu_int16 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.001s)\n2022-02-23T14:12:58.9290805Z test_sparse_zeros_tanh_cpu_int32 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.001s)\n2022-02-23T14:12:58.9306695Z test_sparse_zeros_tanh_cpu_int64 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.000s)\n2022-02-23T14:12:58.9322595Z test_sparse_zeros_tanh_cpu_int8 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.000s)\n2022-02-23T14:12:58.9338535Z test_sparse_zeros_tanh_cpu_uint8 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.000s)\n2022-02-23T14:12:58.9354468Z test_sparse_zeros_trunc_cpu_float32 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.000s)\n2022-02-23T14:12:58.9370208Z test_sparse_zeros_trunc_cpu_float64 (__main__.TestSparseUnaryUfuncsCPU) ... ok (0.000s)\n2022-02-23T14:12:58.9370712Z \n2022-02-23T14:12:58.9370976Z ======================================================================\n2022-02-23T14:12:58.9371445Z FAIL [0.010s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)\n2022-02-23T14:12:58.9372134Z ----------------------------------------------------------------------\n2022-02-23T14:12:58.9372597Z Traceback (most recent call last):\n2022-02-23T14:12:58.9374021Z File \"C:\\actions-runner\\_work\\pytorch\\pytorch\\build\\win_tmp\\build\\torch\\testing\\_internal\\common_device_type.py\", line 376, in instantiated_test\n2022-02-23T14:12:58.9374740Z result = test(self, **param_kwargs)\n2022-02-23T14:12:58.9375570Z File \"C:\\actions-runner\\_work\\pytorch\\pytorch\\build\\win_tmp\\build\\torch\\testing\\_internal\\common_utils.py\", line 2951, in wrapped\n2022-02-23T14:12:58.9376266Z f(self, *args, **kwargs, coalesced=False)\n2022-02-23T14:12:58.9376972Z File \"test_sparse.py\", line 1272, in test_sparse_addmm\n2022-02-23T14:12:58.9377402Z test_shape(7, 8, 9, 20, True, None)\n2022-02-23T14:12:58.9377939Z File \"test_sparse.py\", line 1264, in test_shape\n2022-02-23T14:12:58.9378373Z self.assertEqual(Y, Y_dense)\n\n\n win-vs2019-cuda11.3-py3 / test (default, 2, 2, windows.8xlarge.nvidia.gpu) (2/2)\nStep: \"Test\" (full log | diagnosis details | \ud83d\udd01 rerun)\n\n\n2022-02-23T15:20:20.5710678Z FAIL [0.031s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)\n\n2022-02-23T15:20:20.5569146Z test_sparse_zeros_tanh_cuda_float64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5589083Z test_sparse_zeros_tanh_cuda_int16 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5609025Z test_sparse_zeros_tanh_cuda_int32 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5629080Z test_sparse_zeros_tanh_cuda_int64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.016s)\n2022-02-23T15:20:20.5649102Z test_sparse_zeros_tanh_cuda_int8 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5668867Z test_sparse_zeros_tanh_cuda_uint8 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5688700Z test_sparse_zeros_trunc_cuda_float32 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5708285Z test_sparse_zeros_trunc_cuda_float64 (__main__.TestSparseUnaryUfuncsCUDA) ... ok (0.000s)\n2022-02-23T15:20:20.5709405Z \n2022-02-23T15:20:20.5709879Z ======================================================================\n2022-02-23T15:20:20.5710678Z FAIL [0.031s]: test_sparse_addmm_cpu_bfloat16 (__main__.TestSparseCPU)\n2022-02-23T15:20:20.5711399Z ----------------------------------------------------------------------\n2022-02-23T15:20:20.5712013Z Traceback (most recent call last):\n2022-02-23T15:20:20.5713280Z File \"C:\\actions-runner\\_work\\pytorch\\pytorch\\build\\win_tmp\\build\\torch\\testing\\_internal\\common_device_type.py\", line 376, in instantiated_test\n2022-02-23T15:20:20.5714267Z result = test(self, **param_kwargs)\n2022-02-23T15:20:20.5715299Z File \"C:\\actions-runner\\_work\\pytorch\\pytorch\\build\\win_tmp\\build\\torch\\testing\\_internal\\common_utils.py\", line 2951, in wrapped\n2022-02-23T15:20:20.5716240Z f(self, *args, **kwargs, coalesced=False)\n2022-02-23T15:20:20.5716943Z File \"test_sparse.py\", line 1275, in test_sparse_addmm\n2022-02-23T15:20:20.5717516Z test_shape(7, 8, 9, 20, False, (1, 1))\n2022-02-23T15:20:20.5718323Z File \"test_sparse.py\", line 1264, in test_shape\n2022-02-23T15:20:20.5718915Z self.assertEqual(Y, Y_dense)\n\n\n\n\u2744\ufe0f 1 failure tentatively classified as flaky\nbut reruns have not yet been triggered to confirm:\n linux-bionic-rocm4.5-py3.7 / test (distributed, 1, 1, linux.rocm.gpu) (1/1)\nStep: \"Test\" (full log | diagnosis details | \ud83d\udd01 rerun) \u2744\ufe0f\n\n\n2022-02-23T16:16:26.7221984Z RuntimeError: Proc...ated or timed out after 100.06913685798645 seconds\n\n2022-02-23T16:16:26.7207909Z ERROR [100.093s]: test_collect_shards (__main__.TestZeroRedundancyOptimizerDistributed)\n2022-02-23T16:16:26.7209206Z Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer\n2022-02-23T16:16:26.7213073Z ----------------------------------------------------------------------\n2022-02-23T16:16:26.7213996Z Traceback (most recent call last):\n2022-02-23T16:16:26.7215434Z File \"/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py\", line 483, in wrapper\n2022-02-23T16:16:26.7216409Z self._join_processes(fn)\n2022-02-23T16:16:26.7217801Z File \"/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py\", line 702, in _join_processes\n2022-02-23T16:16:26.7218822Z self._check_return_codes(elapsed_time)\n2022-02-23T16:16:26.7220266Z File \"/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_distributed.py\", line 754, in _check_return_codes\n2022-02-23T16:16:26.7221201Z i, elapsed_time\n2022-02-23T16:16:26.7221984Z RuntimeError: Process 0 terminated or timed out after 100.06913685798645 seconds\n2022-02-23T16:16:26.7222551Z \n2022-02-23T16:16:26.7223245Z ----------------------------------------------------------------------\n2022-02-23T16:16:26.7224032Z Ran 26 tests in 303.663s\n2022-02-23T16:16:26.7224400Z \n2022-02-23T16:16:26.7224780Z FAILED (errors=1, skipped=8, unexpected successes=3)\n2022-02-23T16:16:26.7225718Z \n2022-02-23T16:16:26.7225992Z Generating XML reports...\n2022-02-23T16:16:26.7336797Z Generated XML report: test-reports/python-unittest/distributed.optim.test_zero_redundancy_optimizer/TEST-TestZeroRedundancyOptimizerDistributed-20220223161123.xml\n2022-02-23T16:16:26.7349296Z Generated XML report: test-reports/python-unittest/distributed.optim.test_zero_redundancy_optimizer/TEST-TestZeroRedundancyOptimizerSingleRank-20220223161123.xml\n2022-02-23T16:16:27.6823633Z Traceback (most recent call last):\n\n\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-01-25T09:31:08Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1020983433 + }, + { + "bodyText": "Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!", + "createdAt": "2022-01-25T18:07:45Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1021467314 + }, + { + "bodyText": "@albanD Is there something that needs to be done to correct the failed check ?", + "createdAt": "2022-02-04T13:18:05Z", + "author": { + "login": "coolteemf" + }, + "authorAssociation": "CONTRIBUTOR", + "editor": null, + "databaseId": 1029978104 + }, + { + "bodyText": "Hi,\nI think you didn't do the merge properly as there are now a lot more commits than it should be in this PR.\nYou can either clean up the branch locally and force push here or open a new clean PR.\nNote that in general, it is better to rebase on top of master than merge master into your branch!", + "createdAt": "2022-02-04T14:28:28Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1030038719 + }, + { + "bodyText": "Okay thank you for the heads up", + "createdAt": "2022-02-04T16:44:46Z", + "author": { + "login": "coolteemf" + }, + "authorAssociation": "CONTRIBUTOR", + "editor": null, + "databaseId": 1030159616 + }, + { + "bodyText": "@albanD I just rebased and updated the branch to take into account changes from 28388b4. Is it all clear for merging ?", + "createdAt": "2022-02-16T15:34:59Z", + "author": { + "login": "coolteemf" + }, + "authorAssociation": "CONTRIBUTOR", + "editor": null, + "databaseId": 1041720345 + }, + { + "bodyText": "Thanks! The CI needs fixing for bc-compat and lint though\n\nThe lint should be fixed, however I didn't find clear instructions on how to fix the bc compat.\nI guess output_mask could be made optional, however in the case of native_group_norm_backward the same argument is not optional.", + "createdAt": "2022-02-17T08:04:30Z", + "author": { + "login": "coolteemf" + }, + "authorAssociation": "CONTRIBUTOR", + "editor": null, + "databaseId": 1042672732 + }, + { + "bodyText": "Since we are changing the signature on purpose here, you can add it to the list at https://github.com/pytorch/pytorch/blob/master/test/forward_backward_compatibility/check_forward_backward_compatibility.py#L29 to silence the test.", + "createdAt": "2022-02-17T14:41:16Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1043020903 + }, + { + "bodyText": "@pytorchbot merge this please", + "createdAt": "2022-02-23T14:48:05Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048861185 + }, + { + "bodyText": "Merge failed due to 'NoneType' object is not subscriptable\nRaised by https://github.com/pytorch/pytorch/actions/runs/1887914411", + "createdAt": "2022-02-23T14:49:16Z", + "author": { + "login": "pytorchmergebot" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048862374 + }, + { + "bodyText": "@coolteemf you can ignore me playing with the bot. Nothing is needed on your end anymore, I'll take it from here.", + "createdAt": "2022-02-23T14:52:10Z", + "author": { + "login": "albanD" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048865236 + }, + { + "bodyText": "@pytorchbot merge this", + "createdAt": "2022-02-23T14:54:23Z", + "author": { + "login": "malfet" + }, + "authorAssociation": "MEMBER", + "editor": null, + "databaseId": 1048867615 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOPNr4Ug==", + "hasPreviousPage": false + } + } + } + } + } + }, + "query_sha=2e2877d2452c4f233f042b7ccd50ab9c2a6e9a73d8819a0c876203c12364e8a3 cursor=Y3Vyc29yOnYyOpHOQebHmg== name=pytorch number=75095 owner=pytorch": { + "data": { + "repository": { + "pullRequest": { + "comments": { + "nodes": [ + { + "bodyText": "\ud83d\udd17 Helpful links\n\n\ud83e\uddea \u00a0See artifacts and rendered test results at hud.pytorch.org/pr/75095\n\ud83d\udcc4 \u00a0Preview Python docs built from this PR\n\ud83d\udcc4 \u00a0Preview C++ docs built from this PR\n\u2753Need help or want to give feedback on the CI? Visit our office hours\n\n\ud83d\udc8a CI failures summary and remediations\nAs of commit db355d5 (more details on the Dr. CI page):\nExpand to see more\n\n\ud83d\udc9a \ud83d\udc9a Looks good so far! There are no failures yet. \ud83d\udc9a \ud83d\udc9a\n\nThis comment was automatically generated by Dr. CI (expand for details).\nPlease report bugs/suggestions to the (internal) Dr. CI Users group.\nClick here to manually regenerate this comment.", + "createdAt": "2022-04-01T08:49:06Z", + "author": { + "login": "facebook-github-bot" + }, + "authorAssociation": "MEMBER", + "editor": { + "login": "facebook-github-bot" + }, + "databaseId": 1085625658 + }, + { + "bodyText": "High level question: how do we plan to validate that our ref implementations are compatible with somewhat-symbolic shapes? There are multiple ways to write the shape processing logic to be compatible vs not, it'd be good to catch such instances early. Does it make sense to throw in some proxy objects (that have state of 0,1,N) in tests early on? (maybe in a follow up PR). Otherwise it's not clear to me that squeeze/broadcast/etc are the right set of primitives for symbolic shapes", + "createdAt": "2022-04-21T18:51:24Z", + "author": { + "login": "dzhulgakov" + }, + "authorAssociation": "COLLABORATOR", + "editor": null, + "databaseId": 1105634766 + } + ], + "pageInfo": { + "startCursor": "Y3Vyc29yOnYyOpHOQLVVOg==", + "hasPreviousPage": false + } + } + } + } + } } } diff --git a/.github/scripts/install_nvidia_utils_linux.sh b/.github/scripts/install_nvidia_utils_linux.sh deleted file mode 100755 index 855d15dde83b4..0000000000000 --- a/.github/scripts/install_nvidia_utils_linux.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - - -DISTRIBUTION=$(. /etc/os-release;echo $ID$VERSION_ID) -DRIVER_VERSION="515.57" -DRIVER_FN="NVIDIA-Linux-x86_64-${DRIVER_VERSION}.run" -YUM_REPO_URL="https://nvidia.github.io/nvidia-docker/${DISTRIBUTION}/nvidia-docker.repo" - -install_nvidia_docker2_amzn2() { - ( - set -x - # Needed for yum-config-manager - sudo yum install -y yum-utils - sudo yum-config-manager --add-repo "${YUM_REPO_URL}" - sudo yum install -y nvidia-docker2 - sudo systemctl restart docker - ) -} - -install_nvidia_driver_amzn2() { - ( - set -x - - # Purge any nvidia driver installed from RHEL repo - sudo yum remove -y nvidia-driver-latest-dkms - - HAS_NVIDIA_DRIVER=0 - # Check if NVIDIA driver has already been installed - if [ -x "$(command -v nvidia-smi)" ]; then - # The driver exists, check its version next - INSTALLED_DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader) - - if [ "$INSTALLED_DRIVER_VERSION" != "$DRIVER_VERSION" ]; then - echo "NVIDIA driver ($INSTALLED_DRIVER_VERSION) has been installed, but we expect to have $DRIVER_VERSION instead. Continuing" - else - HAS_NVIDIA_DRIVER=1 - echo "NVIDIA driver ($INSTALLED_DRIVER_VERSION) has already been installed. Skipping NVIDIA driver installation" - fi - fi - - if [ "$HAS_NVIDIA_DRIVER" -eq 0 ]; then - sudo yum groupinstall -y "Development Tools" - # ensure our kernel install is the same as our underlying kernel, - # groupinstall "Development Tools" has a habit of mismatching kernel headers - sudo yum install -y "kernel-devel-uname-r == $(uname -r)" - sudo modprobe backlight - sudo curl -fsL -o /tmp/nvidia_driver "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN" - sudo /bin/bash /tmp/nvidia_driver -s --no-drm || (sudo cat /var/log/nvidia-installer.log && false) - sudo rm -fv /tmp/nvidia_driver - fi - - nvidia-smi - ) -} - -echo "== Installing nvidia driver ${DRIVER_FN} ==" -case "${DISTRIBUTION}" in - amzn*) - install_nvidia_driver_amzn2 - ;; - *) - echo "ERROR: Unknown distribution ${DISTRIBUTION}" - exit 1 - ;; -esac - -# Install container toolkit based on distribution -echo "== Installing nvidia container toolkit for ${DISTRIBUTION} ==" -case "${DISTRIBUTION}" in - amzn*) - install_nvidia_docker2_amzn2 - ;; - *) - echo "ERROR: Unknown distribution ${DISTRIBUTION}" - exit 1 - ;; -esac diff --git a/.github/scripts/process_commit.py b/.github/scripts/process_commit.py deleted file mode 100644 index 358f9012c92fd..0000000000000 --- a/.github/scripts/process_commit.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -""" -This script finds the user/pr creator responsible for labeling a PR by a commit SHA. It is used by the workflow in -'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled, -this script is a no-op. - -Note: we ping the user only, not the reviewers, as the reviewers can sometimes be external to pytorch -with no labeling responsibility, so we don't want to bother them. -This script is based on: https://github.com/pytorch/vision/blob/main/.github/process_commit.py -""" - -import sys -from typing import Any, Set, Tuple, List -import re -import os -import json -import requests - -# For a PR to be properly labeled it should have release notes label and one topic label -PULL_REQUEST_EXP = "Pull Request resolved:.*pull/(.*)" -PRIMARY_LABEL_FILTER = "release notes:" -SECONDARY_LABELS = { - "topic: bc_breaking", - "topic: deprecation", - "topic: new feature", - "topic: improvements", - "topic: bug fixes", - "topic: performance", - "topic: documentation", - "topic: developer feature", - "topic: not user facing", -} -# This secondary does not require a primary -ALLOWED_ONLY_SECONDARY = {"topic: not user facing"} -PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch" -GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') -REQUEST_HEADERS = {'Accept': 'application/vnd.github.v3+json', 'Authorization': f'token {GITHUB_TOKEN}'} - - -def query_pytorch(cmd: str) -> Any: - response = requests.get(f"{PYTORCH_REPO}/{cmd}", headers=REQUEST_HEADERS) - return response.json() - - -def get_pr_number(commit_hash: str) -> Any: - data = query_pytorch(f"commits/{commit_hash}") - if not data or (not data["commit"]["message"]): - return None - message = data["commit"]["message"] - p = re.compile(PULL_REQUEST_EXP) - result = p.search(message) - if not result: - return None - return result.group(1) - - -def get_pr_author_and_labels(pr_number: int) -> Tuple[str, Set[str]]: - # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request - data = query_pytorch(f"pulls/{pr_number}") - user = data["user"]["login"] - labels = {label["name"] for label in data["labels"]} - return user, labels - -def get_repo_labels() -> List[str]: - collected_labels: List[str] = list() - for page in range(0, 10): - response = query_pytorch(f"labels?per_page=100&page={page}") - page_labels = list(map(lambda x: str(x["name"]), response)) - if not page_labels: - break - collected_labels += page_labels - return collected_labels - -def post_pytorch_comment(pr_number: int, merger: str) -> Any: - message = {'body' : f"Hey @{merger}." + """ -You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. \ -Please add one of each to the PR. The 'release notes: ...' label should represent the part of \ -PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should \ -represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). \ -The list of valid labels can be found [here](https://github.com/pytorch/pytorch/labels?q=release+notes) \ -for the 'release notes: ...' and [here](https://github.com/pytorch/pytorch/labels?q=topic) for the \ -'topics: ...'. -For changes that are 'topic: not user facing' there is no need for a release notes label."""} - - response = requests.post( - f"{PYTORCH_REPO}/issues/{pr_number}/comments", - json.dumps(message), - headers=REQUEST_HEADERS) - return response.json() - -if __name__ == "__main__": - commit_hash = sys.argv[1] - pr_number = get_pr_number(commit_hash) - - if not pr_number: - sys.exit(0) - - user, labels = get_pr_author_and_labels(pr_number) - repo_labels = get_repo_labels() - - primary_labels = set(filter(lambda x: x.startswith(PRIMARY_LABEL_FILTER), repo_labels)) - has_both_labels = bool(primary_labels.intersection(labels)) and bool(SECONDARY_LABELS.intersection(labels)) - is_properly_labeled = has_both_labels or bool(ALLOWED_ONLY_SECONDARY.intersection(labels)) - - if not is_properly_labeled: - post_pytorch_comment(pr_number, user) diff --git a/.github/scripts/test_check_labels.py b/.github/scripts/test_check_labels.py new file mode 100644 index 0000000000000..64e91dcd8ecbe --- /dev/null +++ b/.github/scripts/test_check_labels.py @@ -0,0 +1,77 @@ +"""test_check_labels.py""" + +from typing import Any +from unittest import TestCase, mock, main + +from trymerge import GitHubPR +from test_trymerge import mocked_gh_graphql +from check_labels import has_required_labels + +release_notes_labels = [ + "release notes: AO frontend", + "release notes: autograd", + "release notes: benchmark", + "release notes: build", + "release notes: complex", + "release notes: composability", + "release notes: cpp", + "release notes: cuda", + "release notes: cudnn", + "release notes: dataloader", + "release notes: distributed (c10d)", + "release notes: distributed (ddp)", + "release notes: distributed (fsdp)", + "release notes: distributed (pipeline)", + "release notes: distributed (rpc)", + "release notes: distributed (sharded)", + "release notes: foreach_frontend", + "release notes: functorch", + "release notes: fx", + "release notes: hub", + "release notes: jit", + "release notes: lazy", + "release notes: linalg_frontend", + "release notes: memory format", + "release notes: Meta API", + "release notes: mobile", + "release notes: mps", + "release notes: nested tensor", + "release notes: nn", + "release notes: onnx", + "release notes: package/deploy", + "release notes: performance_as_product", + "release notes: profiler", + "release notes: python_frontend", + "release notes: quantization", + "release notes: releng", + "release notes: rocm", + "release notes: sparse", + "release notes: visualization", + "release notes: vulkan", +] + + +class TestCheckLabels(TestCase): + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + @mock.patch('check_labels.get_release_notes_labels', return_value=release_notes_labels) + def test_pr_with_missing_labels(self, mocked_rn_labels: Any, mocked_gql: Any) -> None: + "Test PR with no 'release notes:' label or 'topic: not user facing' label" + pr = GitHubPR("pytorch", "pytorch", 82169) + self.assertFalse(has_required_labels(pr)) + + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + @mock.patch('check_labels.get_release_notes_labels', return_value=release_notes_labels) + def test_pr_with_release_notes_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None: + "Test PR with 'release notes: nn' label" + pr = GitHubPR("pytorch", "pytorch", 71759) + self.assertTrue(has_required_labels(pr)) + + @mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql) + @mock.patch('check_labels.get_release_notes_labels', return_value=release_notes_labels) + def test_pr_with_not_user_facing_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None: + "Test PR with 'topic: not user facing' label" + pr = GitHubPR("pytorch", "pytorch", 75095) + self.assertTrue(has_required_labels(pr)) + +if __name__ == "__main__": + main() diff --git a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index a043a35355431..55410e846c972 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -4,7 +4,14 @@ import yaml import json from unittest import TestCase, main, mock -from filter_test_configs import get_labels, filter, PREFIX, VALID_TEST_CONFIG_LABELS +from filter_test_configs import ( + get_labels, + filter, + set_periodic_modes, + PREFIX, + VALID_TEST_CONFIG_LABELS, + SUPPORTED_PERIODICAL_MODES +) import requests from requests.models import Response from typing import Any, Dict @@ -86,5 +93,26 @@ def test_filter_with_valid_label(self) -> None: self.assertEqual(case["expected"], json.dumps(filtered_test_matrix)) + def test_set_periodic_modes(self) -> None: + testcases = [ + { + "test_matrix": "{include: []}", + "description": "Empty test matrix", + }, + { + "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', + "descripion": "Replicate each periodic mode in a different config", + }, + ] + + for case in testcases: + test_matrix = yaml.safe_load(case["test_matrix"]) + scheduled_test_matrix = set_periodic_modes(test_matrix) + self.assertEqual( + len(test_matrix["include"]) * len(SUPPORTED_PERIODICAL_MODES), + len(scheduled_test_matrix["include"]) + ) + + if __name__ == '__main__': main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 87ca3ac06579c..697b4b94faac4 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -179,6 +179,7 @@ class WorkflowCheckState(NamedTuple): comments(last: 5) { nodes { bodyText + createdAt author { login } @@ -336,6 +337,7 @@ class WorkflowCheckState(NamedTuple): comments(last: 100, before: $cursor) { nodes { bodyText + createdAt author { login } @@ -405,6 +407,7 @@ class WorkflowCheckState(NamedTuple): r'https://github.com/(?P[^/]+)/(?P[^/]+)/pull/(?P[0-9]+)', re.MULTILINE ) +RE_PR_CC_LINE = re.compile(r'^cc:? @\w+.*\r?\n?$', re.MULTILINE) RE_DIFF_REV = re.compile(r'^Differential Revision:.+?(D[0-9]+)', re.MULTILINE) CIFLOW_LABEL = re.compile(r"^ciflow/.+") CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk") @@ -583,6 +586,7 @@ def can_skip_internal_checks(pr: "GitHubPR", comment_id: Optional[int] = None) - @dataclass class GitHubComment: body_text: str + created_at: str author_login: str author_association: str editor_login: Optional[str] @@ -807,6 +811,7 @@ def get_pr_url(self) -> str: def _comment_from_node(node: Any) -> GitHubComment: editor = node["editor"] return GitHubComment(body_text=node["bodyText"], + created_at=node["createdAt"] if "createdAt" in node else "", author_login=node["author"]["login"], author_association=node["authorAssociation"], editor_login=editor["login"] if editor else None, @@ -903,8 +908,12 @@ def gen_commit_message(self, filter_ghstack: bool = False) -> str: filters out ghstack info """ # Adding the url here makes it clickable within the Github UI approved_by_urls = ', '.join(prefix_with_github_url(login) for login in self.get_approved_by()) + # Remove "cc: " line from the message body + msg_body = re.sub(RE_PR_CC_LINE, "", self.get_body()) + if filter_ghstack: + msg_body = re.sub(RE_GHSTACK_DESC, "", msg_body) msg = self.get_title() + f" (#{self.pr_num})\n\n" - msg += self.get_body() if not filter_ghstack else re.sub(RE_GHSTACK_DESC, "", self.get_body()) + msg += msg_body msg += f"\nPull Request resolved: {self.get_pr_url()}\n" msg += f"Approved by: {approved_by_urls}\n" return msg diff --git a/.github/scripts/tryrebase.py b/.github/scripts/tryrebase.py index 1b69f653e525a..2e8987e9faaa1 100755 --- a/.github/scripts/tryrebase.py +++ b/.github/scripts/tryrebase.py @@ -69,6 +69,7 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: push_result = ghstack_result.stdout.decode("utf-8") print(push_result) if ghstack_result.returncode != 0: + print(ghstack_result.stderr.decode("utf-8")) raise Exception(f"\n```{push_result}```") # The contents of a successful push result should look like: # Summary of changes (ghstack 0.6.0) diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index a2941546abe1c..edb652ff16ce5 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -78,7 +78,12 @@ concurrency: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 149c007daef9e..eb0c2ff4b3734 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -58,17 +58,8 @@ jobs: {%- for config in build_configs %} !{{ config["build_name"] }}-build: if: ${{ github.repository_owner == 'pytorch' }} - {%- if config["package_type"] == "libtorch" %} - runs-on: macos-10.15 - {%- else %} runs-on: macos-12-xl - {%- endif %} -{%- if config["package_type"] == "libtorch" %} - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 -{%- else %} timeout-minutes: !{{ common.timeout_minutes }} -{%- endif %} !{{ upload.binary_env(config, true) }} # For sccache access (only on non-forked PRs) AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} @@ -78,10 +69,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" !{{ common.checkout(deep_clone=False, directory="pytorch") }} !{{ common.checkout(deep_clone=False, directory="builder", repository="pytorch/builder", branch=common.builder_branch) }} - name: Install sccache (only for non-forked PRs, and pushes to trunk) @@ -92,7 +84,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -105,7 +97,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: !{{ config["build_name"] }} diff --git a/.github/workflows/_android-build-test.yml b/.github/workflows/_android-build-test.yml index 5538bc58cf425..dfa48daa84acd 100644 --- a/.github/workflows/_android-build-test.yml +++ b/.github/workflows/_android-build-test.yml @@ -28,6 +28,11 @@ jobs: if: github.repository_owner == 'pytorch' runs-on: [self-hosted, linux.2xlarge] steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -35,11 +40,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Calculate docker image id: calculate-docker-image uses: ./.github/actions/calculate-docker-image diff --git a/.github/workflows/_android-full-build-test.yml b/.github/workflows/_android-full-build-test.yml index 1680461be78ef..ea07fda814b1d 100644 --- a/.github/workflows/_android-full-build-test.yml +++ b/.github/workflows/_android-full-build-test.yml @@ -28,6 +28,11 @@ jobs: if: github.repository_owner == 'pytorch' runs-on: [self-hosted, linux.2xlarge] steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -35,11 +40,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Calculate docker image id: calculate-docker-image uses: ./.github/actions/calculate-docker-image @@ -128,7 +128,7 @@ jobs: # run gradle buildRelease (echo "./.circleci/scripts/build_android_gradle.sh" | docker exec \ - -e BUILD_ENVIRONMENT="pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build" \ + -e BUILD_ENVIRONMENT="pytorch-linux-focal-py3-clang7-android-ndk-r19c-gradle-build" \ -e MAX_JOBS="$(nproc --ignore=2)" \ -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml index a64758c2b1182..79445e1dad6c1 100644 --- a/.github/workflows/_bazel-build-test.yml +++ b/.github/workflows/_bazel-build-test.yml @@ -28,6 +28,11 @@ jobs: if: github.repository_owner == 'pytorch' runs-on: [self-hosted, linux.2xlarge] steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -35,11 +40,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Calculate docker image id: calculate-docker-image uses: ./.github/actions/calculate-docker-image diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index a665f53bab5e0..192ca251b79ff 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -67,8 +67,8 @@ on: jobs: build: - runs-on: linux.4xlarge - timeout-minutes: 270 + runs-on: linux.12xlarge + timeout-minutes: 150 env: PYTORCH_ROOT: ${{ inputs.PYTORCH_ROOT }} BUILDER_ROOT: ${{ inputs.BUILDER_ROOT }} @@ -126,16 +126,16 @@ jobs: - name: List the env shell: bash run: env + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.github-token }} - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - name: Setup Linux uses: ./.github/actions/setup-linux - name: Chown workspace uses: ./.github/actions/chown-workspace - - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.github-token }} - name: Clean workspace shell: bash run: | @@ -167,11 +167,6 @@ jobs: git clean -fxd working-directory: builder - - name: Set BUILD_SPLIT_CUDA - if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' && startsWith(inputs.GPU_ARCH_VERSION, '11') }} - shell: bash - run: | - echo "BUILD_SPLIT_CUDA='ON'" >> "$GITHUB_ENV" - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: @@ -184,7 +179,6 @@ jobs: -e BINARY_ENV_FILE \ -e BUILDER_ROOT \ -e BUILD_ENVIRONMENT \ - -e BUILD_SPLIT_CUDA \ -e DESIRED_CUDA \ -e DESIRED_DEVTOOLSET \ -e DESIRED_PYTHON \ diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index c18afe1b5b6ce..471a2af88b8f5 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -122,6 +122,10 @@ jobs: echo "SHA1=${{ env.SHA1 }}" } >> "${GITHUB_ENV} }}" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.github-token }} # Setup the environment - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -129,10 +133,6 @@ jobs: uses: ./.github/actions/setup-linux - name: Chown workspace uses: ./.github/actions/chown-workspace - - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.github-token }} - name: Clean workspace shell: bash run: | @@ -171,17 +171,8 @@ jobs: path: "${{ runner.temp }}/artifacts/" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' }} - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - set -ex - pushd pytorch - bash .github/scripts/install_nvidia_utils_linux.sh - echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" - popd - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml index 52b5d4b3c6f45..07f41299c711b 100644 --- a/.github/workflows/_buck-build-test.yml +++ b/.github/workflows/_buck-build-test.yml @@ -21,29 +21,10 @@ jobs: distribution: 'temurin' - name: Setup miniconda - uses: conda-incubator/setup-miniconda@v2 + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: - auto-update-conda: true python-version: 3.8 - activate-environment: build - - - name: Install dependencies - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 - with: - timeout_minutes: 10 - max_attempts: 5 - command: | - conda install -y \ - cffi \ - cmake \ - mkl \ - mkl-include \ - ninja \ - numpy \ - pyyaml \ - requests \ - setuptools \ - typing_extensions + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} - name: Install Buck uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index d46e28f844f2d..318471e7c7860 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -48,8 +48,9 @@ jobs: # to the next available tier of 12xlarge. So much memory just to generate cpp # doc runner: linux.12xlarge - # Nightly cpp docs take about 150m to finish, and the number is stable - timeout-minutes: 180 + # TODO: Nightly cpp docs take longer and longer to finish (more than 3h now) + # Let's try to figure out how this can be improved + timeout-minutes: 240 - docs_type: python runner: linux.2xlarge # It takes less than 30m to finish python docs unless there are issues @@ -58,7 +59,20 @@ jobs: runner: linux.2xlarge # It takes less than 15m to finish functorch docs unless there are issues timeout-minutes: 15 + # Set a fixed name for this job instead of using the current matrix-generated name, i.e. build-docs (cpp, linux.12xlarge, 180) + # The current name requires updating the Rockset last docs push query from test-infra every time the matrix is updated + name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + instructions: | + All builds are done inside the container, to start an interactive session run: + docker exec -it $(docker container ps --format '{{.ID}}') bash + To start Python docs build type: + cd docs && make html && make coverage + # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -66,11 +80,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Pull docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: @@ -95,7 +104,10 @@ jobs: timeout-minutes: ${{ matrix.timeout-minutes }} id: build-docs env: - WITH_PUSH: ${{ github.event_name == 'schedule' || startsWith(github.event.ref, 'refs/tags/v') }} + # After https://github.com/pytorch/pytorch/pull/88373, pull workflow can now be run periodically, + # so using a schedule event to determine if the docs should be pushed or not doesn't hold true + # anymore + WITH_PUSH: ${{ inputs.push }} DOCKER_IMAGE: ${{ inputs.docker-image }} DOCS_TYPE: ${{ matrix.docs_type }} RUN_DOXYGEN: ${{ inputs.run-doxygen }} @@ -163,3 +175,6 @@ jobs: if-no-files-found: error path: functorch_ghpages/nightly/ s3-prefix: pytorch/${{ github.event.pull_request.number }}/functorchdocs + - name: Teardown Linux + uses: pytorch/test-infra/.github/actions/teardown-linux@main + if: always() diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index 665ff1b9ce16f..269ad3f153ca4 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -68,7 +68,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -106,7 +106,7 @@ jobs: - name: Build TestApp if: inputs.ios-platform == 'SIMULATOR' - timeout-minutes: 5 + timeout-minutes: 15 run: | # run the ruby build script if ! [ -x "$(command -v xcodebuild)" ]; then diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index cc7945c98760b..be3d2ce98c030 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -34,6 +34,12 @@ on: default: "5.2" description: | List of CUDA architectures CI build should target. + runner: + required: false + type: string + default: "linux.2xlarge" + description: | + List of CUDA architectures CI build should target. test-matrix: required: false @@ -55,11 +61,16 @@ jobs: build: # Don't run on forked repos if: github.repository_owner == 'pytorch' - runs-on: [self-hosted, linux.2xlarge] + runs-on: ${{ inputs.runner }} timeout-minutes: 240 outputs: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # [pytorch repo ref] # Use a pytorch/pytorch reference instead of a reference to the local # checkout because when we run this action we don't *have* a local @@ -70,11 +81,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Calculate docker image id: calculate-docker-image uses: ./.github/actions/calculate-docker-image diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index e1e95ee5e7892..a444a5fc530a8 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -22,6 +22,12 @@ on: description: | If this is set, our linter will use this to make sure that every other job with the same `sync-tag` is identical. + timeout-minutes: + required: false + type: number + default: 240 + description: | + Set the maximum (in minutes) how long the workflow should take to finish env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -56,33 +62,30 @@ jobs: matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} fail-fast: false runs-on: ${{ matrix.runner }} + timeout-minutes: ${{ inputs.timeout-minutes }} steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + instructions: | + All testing is done inside the container, to start an interactive session run: + docker exec -it $(docker container ps --format '{{.ID}}') bash + - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Pull docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: docker-image: ${{ inputs.docker-image }} - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - set -ex - bash .github/scripts/install_nvidia_utils_linux.sh - echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" - name: Start monitoring script id: monitor-script @@ -122,7 +125,8 @@ jobs: DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla - timeout-minutes: 240 + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} run: | set -x @@ -175,6 +179,8 @@ jobs: -e SCCACHE_S3_KEY_PREFIX \ -e XLA_CUDA \ -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + -e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \ + -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --security-opt seccomp=unconfined \ @@ -189,6 +195,7 @@ jobs: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" ) + echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}" docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}" - name: Get workflow job id @@ -213,6 +220,12 @@ jobs: with: file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} + - name: Collect backtraces from coredumps (if any) + if: always() + run: | + # shellcheck disable=SC2156 + find . -iname "core.[1-9]*" -exec docker exec "${DOCKER_CONTAINER_ID}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; + - name: Store Core dumps on S3 uses: seemethere/upload-artifact-s3@v5 if: failure() diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 895b07164213e..ac018b66b9ee7 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -33,6 +33,10 @@ on: default: "3.8" description: | The python version to be used. Will be 3.8 by default + environment-file: + required: false + type: string + description: Set the conda environment file used to setup macOS build. test-matrix: required: false type: string @@ -59,8 +63,8 @@ on: jobs: build: - # Don't run on forked repos. - if: github.repository_owner == 'pytorch' + # # Don't run on forked repos. + # if: github.repository_owner == 'pytorch' runs-on: ${{ inputs.runner-type }} env: # For sccache access (only on non-forked PRs) @@ -83,9 +87,20 @@ jobs: fi - name: Setup miniconda + if: inputs.environment-file == '' + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: ${{ inputs.python_version }} + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + + # This option is used when cross-compiling arm64 from x86-64. Specifically, we need arm64 conda + # environment even though the arch is x86-64 + - name: Setup miniconda using the provided environment file + if: inputs.environment-file != '' uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: ${{ inputs.python_version }} + environment-file: ${{ inputs.environment-file }} - name: Install macOS homebrew dependencies run: | @@ -94,12 +109,17 @@ jobs: brew link --force libomp - name: Install sccache (only for non-forked PRs, and pushes to trunk) + uses: nick-fields/retry@v2.8.2 if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} - run: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache - sudo chmod +x /usr/local/bin/sccache - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" + with: + timeout_minutes: 5 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo chmod +x /usr/local/bin/sccache + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" + echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" - name: Get workflow job id id: get-job-id @@ -131,7 +151,7 @@ jobs: zip -1 -r artifacts.zip dist/ build/.ninja_log build/compile_commands.json .pytorch-test-times.json - name: Store PyTorch Build Artifacts on GHA - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' with: name: ${{ env.BUILD_ENVIRONMENT }} @@ -140,7 +160,7 @@ jobs: path: artifacts.zip - name: Upload sccache stats to GHA - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 # Only if sccache is installed, see above if: ${{ (github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository) && steps.build.outcome != 'skipped' }} with: diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml index e2c6ec74d3f44..5fac3126e20d5 100644 --- a/.github/workflows/_mac-test-mps.yml +++ b/.github/workflows/_mac-test-mps.yml @@ -14,11 +14,16 @@ on: description: | If this is set, our linter will use this to make sure that every other job with the same `sync-tag` is identical. + runs-on: + required: false + type: string + default: "macos-m1-12" + description: Hardware to run tests on jobs: run_mps_test: name: "Run MPS tests" - runs-on: macos-m1-12 + runs-on: ${{ inputs.runs-on }} steps: - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -37,10 +42,21 @@ jobs: name: ${{ inputs.build-environment }} use-gha: true + # This is copied from the main macos test workflow. It was missed in the earlier fix because macos M1 + # runners are shared and not ephemeral, so the issue wasn't manifested if the runners with the fix were + # used + - name: Install macOS homebrew dependencies + run: | + # Install dependencies + brew install libomp + brew link --force libomp + - name: Setup miniconda uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: 3.9 + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + pip-requirements-file: .github/requirements/pip-requirements-${{ runner.os }}.txt - name: Install PyTorch env: @@ -50,13 +66,12 @@ jobs: run: | # shellcheck disable=SC1090 set -ex - ${CONDA_INSTALL} expecttest numpy=1.22.3 pyyaml=6.0 - ${CONDA_RUN} python3 -mpip install "unittest-xml-reporting<=3.2.0,>=2.0.0" # As wheels are cross-compiled they are reported as x86_64 ones ORIG_WHLNAME=$(ls -1 dist/*.whl); ARM_WHLNAME=${ORIG_WHLNAME/x86_64/arm64}; mv ${ORIG_WHLNAME} ${ARM_WHLNAME} - ${CONDA_RUN} python3 -mpip install dist/*.whl + ${CONDA_RUN} python3 -mpip install --no-index --no-deps dist/*.whl - name: Run MPS tests + id: test env: ENV_NAME: conda-test-env-${{ github.run_id }} shell: arch -arch arm64 bash {0} @@ -65,5 +80,24 @@ jobs: set -ex # TODO(https://github.com/pytorch/pytorch/issues/79293) - ${CONDA_RUN} --cwd test python3 test_mps.py -v - ${CONDA_RUN} --cwd test python3 test_metal.py -v + ${CONDA_RUN} python3 test/run_test.py --mps --verbose + + - name: Print remaining test logs + shell: bash + if: always() + run: | + cat test/**/*.log || true + + - name: Get workflow job id + id: get-job-id + uses: ./.github/actions/get-workflow-job-id + if: always() + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload test artifacts + uses: ./.github/actions/upload-test-artifacts + if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' + with: + use-gha: true + file-suffix: ${{ github.job }}-mps-1-1-macos-m1-12_${{ steps.get-job-id.outputs.job-id }} diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 72ee311498503..39236a0dd0828 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -82,7 +82,6 @@ jobs: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - - name: Download build artifacts uses: ./.github/actions/download-build-artifacts with: @@ -94,18 +93,21 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: 3.8 + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + pip-requirements-file: .github/requirements/pip-requirements-${{ runner.os }}.txt - name: Setup miniconda (arm64, py3.9) if: ${{ runner.arch == 'ARM64' }} uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: python-version: 3.9 + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + pip-requirements-file: .github/requirements/pip-requirements-${{ runner.os }}.txt - name: Start monitoring script id: monitor-script + continue-on-error: true run: | - ${CONDA_RUN} python3 -m pip install psutil==5.9.1 - ${CONDA_RUN} python3 -m pip install pynvml==11.4.1 ${CONDA_RUN} python3 -m tools.stats.monitor > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -127,6 +129,9 @@ jobs: - name: Test id: test + env: + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") @@ -143,9 +148,15 @@ jobs: export PR_BODY="${PR_BODY//[\'\"]}" arch - ${CONDA_RUN} python3 -mpip install $(echo dist/*.whl)[opt-einsum] + ${CONDA_RUN} python3 -mpip install --no-index --no-deps $(echo dist/*.whl) ${CONDA_RUN} .jenkins/pytorch/macos-test.sh + - name: Print remaining test logs + shell: bash + if: always() + run: | + cat test/**/*.log || true + - name: Get workflow job id id: get-job-id uses: ./.github/actions/get-workflow-job-id @@ -163,7 +174,7 @@ jobs: - name: Upload test artifacts uses: ./.github/actions/upload-test-artifacts - if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') + if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' with: use-gha: true file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} @@ -189,6 +200,4 @@ jobs: GHA_WORKFLOW_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} run: | set -x - ${CONDA_RUN} python3 -m pip install -r requirements.txt - ${CONDA_RUN} python3 -m pip install boto3==1.19.12 ${CONDA_RUN} python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 62d487ffe3441..be4a5c9dcc6cd 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -39,12 +39,34 @@ env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} jobs: + # This needs to be run right before the test starts so that it can gather the + # latest labels from the PR + filter: + runs-on: [self-hosted, linux.large] + outputs: + test-matrix: ${{ steps.filter.outputs.test-matrix }} + is-test-matrix-empty: ${{ steps.filter.outputs.is-test-matrix-empty }} + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + with: + fetch-depth: 1 + submodules: false + + - name: Select all requested test configurations + id: filter + uses: ./.github/actions/filter-test-configs + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + test-matrix: ${{ inputs.test-matrix }} + test: - # Don't run on forked repos. - if: github.repository_owner == 'pytorch' + needs: filter + # Don't run on forked repos or empty test matrix + if: github.repository_owner == 'pytorch' && needs.filter.outputs.is-test-matrix-empty == 'False' timeout-minutes: 300 strategy: - matrix: ${{ fromJSON(inputs.test-matrix) }} + matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} fail-fast: false runs-on: ${{ matrix.runner }} steps: @@ -97,6 +119,8 @@ jobs: DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_JIT_ENABLE_NVFUSER: 1 + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} timeout-minutes: 270 run: | set -x @@ -146,6 +170,8 @@ jobs: -e MAX_JOBS="$(nproc --ignore=2)" \ -e SCCACHE_BUCKET \ -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ + -e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \ + -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --security-opt seccomp=unconfined \ diff --git a/.github/workflows/_run_android_tests.yml b/.github/workflows/_run_android_tests.yml index 273ec2db81aed..d949e193b76b3 100644 --- a/.github/workflows/_run_android_tests.yml +++ b/.github/workflows/_run_android_tests.yml @@ -11,31 +11,16 @@ jobs: build-and-test: runs-on: ubuntu-latest steps: - - name: Setup miniconda - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: 3.8 - activate-environment: build - - - name: Install dependencies - run: | - conda install -y \ - cffi \ - cmake \ - mkl \ - mkl-include \ - ninja \ - numpy \ - pyyaml \ - requests \ - setuptools \ - typing_extensions - # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main + with: + python-version: 3.8 + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + - name: Build PyTorch Android run: | # Install NDK 21 after GitHub update @@ -49,7 +34,7 @@ jobs: ln -sfn ${ANDROID_SDK_ROOT}/ndk/21.4.7075529 ${ANDROID_NDK} echo "CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname "$(which conda)")/../"}" >> "${GITHUB_ENV}" - ./scripts/build_pytorch_android.sh x86 + ${CONDA_RUN} ./scripts/build_pytorch_android.sh x86 - name: Run tests uses: reactivecircus/android-emulator-runner@v2 diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index faa37060d321d..b04dc7f6626cb 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -46,6 +46,20 @@ jobs: runs-on: [self-hosted, windows.4xlarge] timeout-minutes: 240 steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + instructions: | + To forward remote desktop on your local machine ssh as follows: + ssh -L 3389:localhost:3389 %%username%%@%%hostname%% + And then change password using `passwd` command. + + To start build locally, change working folder to \actions-runner\_work\pytorch\pytorch, + Activate miniconda and Visual Studio environment, by running: + call C:\Jenkins\Miniconda3\Scripts\activate.bat C:\Jenkins\Miniconda3 + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 + # [see note: pytorch repo ref] - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -57,11 +71,6 @@ jobs: with: cuda-version: ${{ inputs.cuda-version }} - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Parse ref id: parse-ref run: .github/scripts/parse_ref.py diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 4a099b742cce7..234b2c7faad78 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -78,6 +78,16 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} + instructions: | + To forward remote desktop on your local machine ssh as follows: + ssh -L 3389:localhost:3389 %%username%%@%%hostname%% + And then change password using `passwd` command. + + To start tests locally, change working folder to \actions-runner\_work\pytorch\pytorch\test, + Activate miniconda and Visual Studio environment and set PYTHON_PATH, by running: + call C:\Jenkins\Miniconda3\Scripts\activate.bat C:\Jenkins\Miniconda3 + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 + set PYTHONPATH=C:\actions-runner\_work\pytorch\pytorch\build\win_tmp\build - name: Start monitoring script id: monitor-script @@ -124,6 +134,8 @@ jobs: TEST_CONFIG: ${{ matrix.config }} PR_BODY: ${{ github.event.pull_request.body }} TORCH_CUDA_ARCH_LIST: "7.0" + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") diff --git a/.github/workflows/auto_request_review.yml b/.github/workflows/auto_request_review.yml index 01df7a054005f..7c98c2990fba7 100644 --- a/.github/workflows/auto_request_review.yml +++ b/.github/workflows/auto_request_review.yml @@ -6,6 +6,8 @@ on: jobs: auto-request-review: + # Don't run on forked repos + if: ${{ !github.event.pull_request.head.repo.fork }} name: Auto Request Review runs-on: ubuntu-latest steps: diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 074d53498faa6..171495c0322d2 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -3,7 +3,8 @@ name: Build Triton wheels on: push: branches: - main + - main + - master paths: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py @@ -30,6 +31,11 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.6 PY_VERS: ${{ matrix.py_vers }} steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master with: @@ -38,11 +44,6 @@ jobs: - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: @@ -84,13 +85,129 @@ jobs: ;; esac - docker exec -t "${container_name}" yum install -y zlib-devel + docker exec -t "${container_name}" yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" /pytorch/.github/scripts/build_triton_wheel.py docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - uses: actions/upload-artifact@v3 with: - name: "pytorch-triton-${{ matrix.py_vers }}" + name: "pytorch-triton-wheel-${{ matrix.py_vers }}" + if-no-files-found: error + path: + ${{ runner.temp }}/artifacts/* + + - name: Teardown Linux + uses: pytorch/test-infra/.github/actions/teardown-linux@main + if: always() + upload-wheel: + runs-on: linux.20_04.4x + needs: build-wheel + container: + image: continuumio/miniconda3:4.12.0 + env: + GITHUB_TOKEN: ${{ secrets.github-token }} + steps: + - name: Download Build Artifacts (3.7) + uses: actions/download-artifact@v3 + with: + name: "pytorch-triton-wheel-3.7" + path: "${{ runner.temp }}/artifacts/" + - name: Download Build Artifacts (3.8) + uses: actions/download-artifact@v3 + with: + name: "pytorch-triton-wheel-3.8" + path: "${{ runner.temp }}/artifacts/" + - name: Download Build Artifacts (3.9) + uses: actions/download-artifact@v3 + with: + name: "pytorch-triton-wheel-3.9" + path: "${{ runner.temp }}/artifacts/" + - name: Download Build Artifacts (3.10) + uses: actions/download-artifact@v3 + with: + name: "pytorch-triton-wheel-3.10" + path: "${{ runner.temp }}/artifacts/" + - name: Download Build Artifacts (3.11) + uses: actions/download-artifact@v3 + with: + name: "pytorch-triton-wheel-3.11" + path: "${{ runner.temp }}/artifacts/" + - name: Upload binaries + if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || github.event.ref == 'refs/heads/main') }} + env: + PKG_DIR: "${{ runner.temp }}/artifacts" + # When running these on pull_request events these should be blank + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_S3_UPDATE_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_S3_UPDATE_SECRET_ACCESS_KEY }} + UPLOAD_BUCKET: "s3://pytorch" + run: | + set -ex + pip install -q awscli + s3_dir="${UPLOAD_BUCKET}/whl/nightly/" + for pkg in "${PKG_DIR}/"*.whl; do + aws s3 cp --no-progress --acl public-read "${pkg}" "${s3_dir}" + done + build-conda: + runs-on: [self-hosted, linux.2xlarge] + strategy: + fail-fast: false + matrix: + py_vers: [ "3.7", "3.8", "3.9", "3.10" ] + timeout-minutes: 40 + env: + DOCKER_IMAGE: pytorch/conda-builder:cuda11.6 + PY_VERS: ${{ matrix.py_vers }} + ANACONDA_API_TOKEN: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + with: + submodules: false + + - name: Setup Linux + uses: ./.github/actions/setup-linux + + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ env.DOCKER_IMAGE }} + + - name: Build Triton conda package + run: | + set -x + mkdir -p "${RUNNER_TEMP}/artifacts/" + container_name=$(docker run \ + --tty \ + --detach \ + -v "${GITHUB_WORKSPACE}:/pytorch" \ + -v "${RUNNER_TEMP}/artifacts:/artifacts" \ + -w /artifacts/ \ + -e ANACONDA_API_TOKEN \ + "${DOCKER_IMAGE}" \ + ) + + docker exec -t "${container_name}" yum install -y llvm11 llvm11-devel llvm11-static llvm11-libs zlib-devel + docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" + + - name: Upload artifacts to Anaconda + if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/master' || github.event.ref == 'refs/heads/main') }} + run: | + container_name=$(docker container ps --format '{{.ID}}') + docker exec -t "${container_name}" sh -c "anaconda upload /artifacts/torch*.tar.bz2 -u pytorch-nightly --label main --no-progress --force" + + - name: Chown artifacts + run: | + container_name=$(docker container ps --format '{{.ID}}') + docker exec -t "${container_name}" chown -R 1000.1000 /artifacts + + - uses: actions/upload-artifact@v3 + with: + name: "pytorch-triton-conda-${{ matrix.py_vers }}" if-no-files-found: error path: ${{ runner.temp }}/artifacts/* diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 62699dde2243d..3108f4b926a89 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -33,20 +33,15 @@ jobs: strategy: matrix: include: - - docker-image-name: pytorch-linux-bionic-cuda10.2-cudnn7-py3.9-gcc7 - docker-image-name: pytorch-linux-bionic-cuda11.3-cudnn8-py3-clang9 - docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - docker-image-name: pytorch-linux-bionic-cuda11.7-cudnn8-py3-gcc7 - docker-image-name: pytorch-linux-bionic-py3.7-clang9 - - docker-image-name: pytorch-linux-focal-rocm5.1-py3.7 - - docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 + - docker-image-name: pytorch-linux-focal-rocm5.1-py3.8 + - docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 - docker-image-name: pytorch-linux-jammy-cuda11.6-cudnn8-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.7-cudnn8-py3.8-clang12 - - docker-image-name: pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7 - - docker-image-name: pytorch-linux-xenial-cuda11.3-cudnn8-py3-gcc7 - - docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c - - docker-image-name: pytorch-linux-xenial-py3-clang5-asan - - docker-image-name: pytorch-linux-xenial-py3-clang7-onnx + - docker-image-name: pytorch-linux-focal-py3-clang7-android-ndk-r19c - docker-image-name: pytorch-linux-focal-py3.7-gcc7 - docker-image-name: pytorch-linux-focal-py3-clang7-asan - docker-image-name: pytorch-linux-focal-py3-clang10-onnx diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index d61b1b2c1242b..0f9638e210ade 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -45,6 +45,10 @@ jobs: BUILD_IMAGE_TYPE: ${{ matrix.image_type }} BUILD_PLATFORMS: ${{ matrix.platform }} steps: + - name: Setup SSH (Click me for login details) + uses: pytorch/test-infra/.github/actions/setup-ssh@main + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required for git merge-base - name: Checkout PyTorch @@ -54,10 +58,6 @@ jobs: submodules: 'recursive' - name: Setup Linux uses: ./.github/actions/setup-linux - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - name: Login to GitHub Container Registry if: ${{ env.WITH_PUSH == 'true' }} uses: docker/login-action@v2 @@ -80,14 +80,31 @@ jobs: # Generate PyTorch version to use echo "PYTORCH_VERSION=$(python3 .github/scripts/generate_pytorch_version.py)" >> "${GITHUB_ENV}" - name: Setup nightly specific variables - if: ${{ github.event.ref == 'refs/heads/nightly' }} + if: ${{ github.event.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/ciflow/nightly/') }} run: | - # Use nightly image if building for nightly - echo "DOCKER_IMAGE=pytorch-nightly" >> "${GITHUB_ENV}" + { + echo "DOCKER_IMAGE=pytorch-nightly"; + echo "INSTALL_CHANNEL=pytorch-nightly"; + echo "TRITON_VERSION=2.0.0+$(cut -c -10 .github/ci_commit_pins/triton.txt)"; + } >> "${GITHUB_ENV}" - name: Run docker build / push # WITH_PUSH is used here to determine whether or not to add the --push flag run: | make -f docker.Makefile "${BUILD_IMAGE_TYPE}-image" + - name: Push nightly tags + if: ${{ github.event.ref == 'refs/heads/nightly' && matrix.image_type == 'runtime' }} + run: | + PYTORCH_DOCKER_TAG="${PYTORCH_VERSION}-runtime" + CUDA_VERSION=$(python3 -c "import re;print(re.search('CUDA_VERSION\s+=\s+([0-9\.]+)',open('docker.Makefile').read())[1],end='')") + PYTORCH_NIGHTLY_COMMIT=$(docker run ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \ + python -c 'import torch; print(torch.version.git_version[:7],end="")') + docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \ + ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" + docker push ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" + + docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" \ + ghcr.io/pytorch/pytorch-nightly:latest + docker push ghcr.io/pytorch/pytorch-nightly:latest - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 6a23b85f433a0..6b1765b9a405d 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -780,7 +780,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_1_1-shared-with-deps-cxx11-abi-build: + libtorch-rocm5_2-shared-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -789,20 +789,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_1_1-shared-with-deps-cxx11-abi + build_name: libtorch-rocm5_2-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_1_1-shared-with-deps-cxx11-abi-test: # Testing + libtorch-rocm5_2-shared-with-deps-cxx11-abi-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-shared-with-deps-cxx11-abi-build + needs: libtorch-rocm5_2-shared-with-deps-cxx11-abi-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -811,11 +811,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -845,7 +845,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -859,7 +864,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_1_1-shared-with-deps-cxx11-abi + name: libtorch-rocm5_2-shared-with-deps-cxx11-abi path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -890,7 +895,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm5.1.1 + docker-image: pytorch/libtorch-cxx11-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -901,29 +906,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_1_1-shared-with-deps-cxx11-abi-upload: # Uploading + libtorch-rocm5_2-shared-with-deps-cxx11-abi-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-shared-with-deps-cxx11-abi-test + needs: libtorch-rocm5_2-shared-with-deps-cxx11-abi-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_1_1-shared-with-deps-cxx11-abi + build_name: libtorch-rocm5_2-shared-with-deps-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_1_1-static-with-deps-cxx11-abi-build: + libtorch-rocm5_2-static-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -932,20 +937,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_1_1-static-with-deps-cxx11-abi + build_name: libtorch-rocm5_2-static-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_1_1-static-with-deps-cxx11-abi-test: # Testing + libtorch-rocm5_2-static-with-deps-cxx11-abi-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-static-with-deps-cxx11-abi-build + needs: libtorch-rocm5_2-static-with-deps-cxx11-abi-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -954,11 +959,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -988,7 +993,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1002,7 +1012,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_1_1-static-with-deps-cxx11-abi + name: libtorch-rocm5_2-static-with-deps-cxx11-abi path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1033,7 +1043,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm5.1.1 + docker-image: pytorch/libtorch-cxx11-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1044,29 +1054,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_1_1-static-with-deps-cxx11-abi-upload: # Uploading + libtorch-rocm5_2-static-with-deps-cxx11-abi-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-static-with-deps-cxx11-abi-test + needs: libtorch-rocm5_2-static-with-deps-cxx11-abi-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_1_1-static-with-deps-cxx11-abi + build_name: libtorch-rocm5_2-static-with-deps-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_2-shared-with-deps-cxx11-abi-build: + libtorch-rocm5_3-shared-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1075,20 +1085,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_2-shared-with-deps-cxx11-abi + build_name: libtorch-rocm5_3-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_2-shared-with-deps-cxx11-abi-test: # Testing + libtorch-rocm5_3-shared-with-deps-cxx11-abi-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-shared-with-deps-cxx11-abi-build + needs: libtorch-rocm5_3-shared-with-deps-cxx11-abi-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1097,11 +1107,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -1131,7 +1141,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1145,7 +1160,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_2-shared-with-deps-cxx11-abi + name: libtorch-rocm5_3-shared-with-deps-cxx11-abi path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1176,7 +1191,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm5.2 + docker-image: pytorch/libtorch-cxx11-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1187,29 +1202,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_2-shared-with-deps-cxx11-abi-upload: # Uploading + libtorch-rocm5_3-shared-with-deps-cxx11-abi-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-shared-with-deps-cxx11-abi-test + needs: libtorch-rocm5_3-shared-with-deps-cxx11-abi-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_2-shared-with-deps-cxx11-abi + build_name: libtorch-rocm5_3-shared-with-deps-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_2-static-with-deps-cxx11-abi-build: + libtorch-rocm5_3-static-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1218,20 +1233,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_2-static-with-deps-cxx11-abi + build_name: libtorch-rocm5_3-static-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_2-static-with-deps-cxx11-abi-test: # Testing + libtorch-rocm5_3-static-with-deps-cxx11-abi-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-static-with-deps-cxx11-abi-build + needs: libtorch-rocm5_3-static-with-deps-cxx11-abi-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1240,11 +1255,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi steps: @@ -1274,7 +1289,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1288,7 +1308,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_2-static-with-deps-cxx11-abi + name: libtorch-rocm5_3-static-with-deps-cxx11-abi path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1319,7 +1339,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/libtorch-cxx11-builder:rocm5.2 + docker-image: pytorch/libtorch-cxx11-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1330,22 +1350,22 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_2-static-with-deps-cxx11-abi-upload: # Uploading + libtorch-rocm5_3-static-with-deps-cxx11-abi-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-static-with-deps-cxx11-abi-test + needs: libtorch-rocm5_3-static-with-deps-cxx11-abi-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.2 + DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-rocm5_2-static-with-deps-cxx11-abi + build_name: libtorch-rocm5_3-static-with-deps-cxx11-abi secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml index edacb2e949b00..39e41e67853ac 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml @@ -31,7 +31,7 @@ concurrency: cancel-in-progress: true jobs: - libtorch-cpu-shared-with-deps-cxx11-abi-build: + libtorch-cpu-shared-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -42,17 +42,17 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu LIBTORCH_VARIANT: shared-with-deps - DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-cpu-shared-with-deps-cxx11-abi + DESIRED_DEVTOOLSET: pre-cxx11 + build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cpu-shared-with-deps-cxx11-abi-test: # Testing + libtorch-cpu-shared-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-cxx11-abi-build + needs: libtorch-cpu-shared-with-deps-pre-cxx11-build uses: ./.github/workflows/_binary-test-linux.yml with: PYTORCH_ROOT: /pytorch @@ -62,10 +62,10 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu LIBTORCH_VARIANT: shared-with-deps - DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-cpu-shared-with-deps-cxx11-abi + DESIRED_DEVTOOLSET: pre-cxx11 + build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 runs_on: linux.4xlarge secrets: diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index 27358089ba2df..eaa928f3e09a9 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -780,7 +780,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_1_1-shared-with-deps-pre-cxx11-build: + libtorch-rocm5_2-shared-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -789,20 +789,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_1_1-shared-with-deps-pre-cxx11 + build_name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_1_1-shared-with-deps-pre-cxx11-test: # Testing + libtorch-rocm5_2-shared-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-shared-with-deps-pre-cxx11-build + needs: libtorch-rocm5_2-shared-with-deps-pre-cxx11-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -811,11 +811,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -845,7 +845,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -859,7 +864,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_1_1-shared-with-deps-pre-cxx11 + name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -890,7 +895,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -901,29 +906,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_1_1-shared-with-deps-pre-cxx11-upload: # Uploading + libtorch-rocm5_2-shared-with-deps-pre-cxx11-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-shared-with-deps-pre-cxx11-test + needs: libtorch-rocm5_2-shared-with-deps-pre-cxx11-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_1_1-shared-with-deps-pre-cxx11 + build_name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_1_1-static-with-deps-pre-cxx11-build: + libtorch-rocm5_2-static-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -932,20 +937,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_1_1-static-with-deps-pre-cxx11 + build_name: libtorch-rocm5_2-static-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_1_1-static-with-deps-pre-cxx11-test: # Testing + libtorch-rocm5_2-static-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-static-with-deps-pre-cxx11-build + needs: libtorch-rocm5_2-static-with-deps-pre-cxx11-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -954,11 +959,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -988,7 +993,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1002,7 +1012,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_1_1-static-with-deps-pre-cxx11 + name: libtorch-rocm5_2-static-with-deps-pre-cxx11 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1033,7 +1043,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1044,29 +1054,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_1_1-static-with-deps-pre-cxx11-upload: # Uploading + libtorch-rocm5_2-static-with-deps-pre-cxx11-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_1_1-static-with-deps-pre-cxx11-test + needs: libtorch-rocm5_2-static-with-deps-pre-cxx11-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_1_1-static-with-deps-pre-cxx11 + build_name: libtorch-rocm5_2-static-with-deps-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_2-shared-with-deps-pre-cxx11-build: + libtorch-rocm5_3-shared-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1075,20 +1085,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 + build_name: libtorch-rocm5_3-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_2-shared-with-deps-pre-cxx11-test: # Testing + libtorch-rocm5_3-shared-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-shared-with-deps-pre-cxx11-build + needs: libtorch-rocm5_3-shared-with-deps-pre-cxx11-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1097,11 +1107,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -1131,7 +1141,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1145,7 +1160,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 + name: libtorch-rocm5_3-shared-with-deps-pre-cxx11 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1176,7 +1191,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1187,29 +1202,29 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_2-shared-with-deps-pre-cxx11-upload: # Uploading + libtorch-rocm5_3-shared-with-deps-pre-cxx11-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-shared-with-deps-pre-cxx11-test + needs: libtorch-rocm5_3-shared-with-deps-pre-cxx11-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_2-shared-with-deps-pre-cxx11 + build_name: libtorch-rocm5_3-shared-with-deps-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm5_2-static-with-deps-pre-cxx11-build: + libtorch-rocm5_3-static-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1218,20 +1233,20 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_2-static-with-deps-pre-cxx11 + build_name: libtorch-rocm5_3-static-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm5_2-static-with-deps-pre-cxx11-test: # Testing + libtorch-rocm5_3-static-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-static-with-deps-pre-cxx11-build + needs: libtorch-rocm5_3-static-with-deps-pre-cxx11-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1240,11 +1255,11 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 steps: @@ -1274,7 +1289,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1288,7 +1308,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: libtorch-rocm5_2-static-with-deps-pre-cxx11 + name: libtorch-rocm5_3-static-with-deps-pre-cxx11 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1319,7 +1339,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1330,22 +1350,22 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - libtorch-rocm5_2-static-with-deps-pre-cxx11-upload: # Uploading + libtorch-rocm5_3-static-with-deps-pre-cxx11-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-rocm5_2-static-with-deps-pre-cxx11-test + needs: libtorch-rocm5_3-static-with-deps-pre-cxx11-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 LIBTORCH_VARIANT: static-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - build_name: libtorch-rocm5_2-static-with-deps-pre-cxx11 + build_name: libtorch-rocm5_3-static-with-deps-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index ac9edc252c28e..b93f797d7e01c 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -169,7 +169,7 @@ jobs: DESIRED_PYTHON: "3.7" build_name: manywheel-py3_7-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -274,7 +274,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_7-rocm5_1_1-build: + manywheel-py3_7-rocm5_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -283,19 +283,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.7" - build_name: manywheel-py3_7-rocm5_1_1 + build_name: manywheel-py3_7-rocm5_2 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_7-rocm5_1_1-test: # Testing + manywheel-py3_7-rocm5_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_7-rocm5_1_1-build + needs: manywheel-py3_7-rocm5_2-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -304,11 +304,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.7" steps: - name: Clean workspace @@ -337,7 +337,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -351,7 +356,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_7-rocm5_1_1 + name: manywheel-py3_7-rocm5_2 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -382,7 +387,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -393,28 +398,28 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_7-rocm5_1_1-upload: # Uploading + manywheel-py3_7-rocm5_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_7-rocm5_1_1-test + needs: manywheel-py3_7-rocm5_2-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.7" - build_name: manywheel-py3_7-rocm5_1_1 + build_name: manywheel-py3_7-rocm5_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_7-rocm5_2-build: + manywheel-py3_7-rocm5_3-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -423,19 +428,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.7" - build_name: manywheel-py3_7-rocm5_2 + build_name: manywheel-py3_7-rocm5_3 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_7-rocm5_2-test: # Testing + manywheel-py3_7-rocm5_3-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_7-rocm5_2-build + needs: manywheel-py3_7-rocm5_3-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -444,11 +449,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.7" steps: - name: Clean workspace @@ -477,7 +482,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -491,7 +501,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_7-rocm5_2 + name: manywheel-py3_7-rocm5_3 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -522,7 +532,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -533,21 +543,21 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_7-rocm5_2-upload: # Uploading + manywheel-py3_7-rocm5_3-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_7-rocm5_2-test + needs: manywheel-py3_7-rocm5_3-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.7" - build_name: manywheel-py3_7-rocm5_2 + build_name: manywheel-py3_7-rocm5_3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} @@ -687,7 +697,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -792,7 +802,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_8-rocm5_1_1-build: + manywheel-py3_8-rocm5_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -801,19 +811,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.8" - build_name: manywheel-py3_8-rocm5_1_1 + build_name: manywheel-py3_8-rocm5_2 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_8-rocm5_1_1-test: # Testing + manywheel-py3_8-rocm5_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_8-rocm5_1_1-build + needs: manywheel-py3_8-rocm5_2-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -822,11 +832,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.8" steps: - name: Clean workspace @@ -855,7 +865,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -869,7 +884,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_8-rocm5_1_1 + name: manywheel-py3_8-rocm5_2 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -900,7 +915,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -911,28 +926,28 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_8-rocm5_1_1-upload: # Uploading + manywheel-py3_8-rocm5_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_8-rocm5_1_1-test + needs: manywheel-py3_8-rocm5_2-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.8" - build_name: manywheel-py3_8-rocm5_1_1 + build_name: manywheel-py3_8-rocm5_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_8-rocm5_2-build: + manywheel-py3_8-rocm5_3-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -941,19 +956,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.8" - build_name: manywheel-py3_8-rocm5_2 + build_name: manywheel-py3_8-rocm5_3 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_8-rocm5_2-test: # Testing + manywheel-py3_8-rocm5_3-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_8-rocm5_2-build + needs: manywheel-py3_8-rocm5_3-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -962,11 +977,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.8" steps: - name: Clean workspace @@ -995,7 +1010,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1009,7 +1029,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_8-rocm5_2 + name: manywheel-py3_8-rocm5_3 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1040,7 +1060,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1051,21 +1071,21 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_8-rocm5_2-upload: # Uploading + manywheel-py3_8-rocm5_3-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_8-rocm5_2-test + needs: manywheel-py3_8-rocm5_3-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.8" - build_name: manywheel-py3_8-rocm5_2 + build_name: manywheel-py3_8-rocm5_3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} @@ -1205,7 +1225,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1310,7 +1330,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-rocm5_1_1-build: + manywheel-py3_9-rocm5_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1319,19 +1339,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-rocm5_1_1 + build_name: manywheel-py3_9-rocm5_2 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-rocm5_1_1-test: # Testing + manywheel-py3_9-rocm5_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_9-rocm5_1_1-build + needs: manywheel-py3_9-rocm5_2-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1340,11 +1360,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.9" steps: - name: Clean workspace @@ -1373,7 +1393,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1387,7 +1412,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_9-rocm5_1_1 + name: manywheel-py3_9-rocm5_2 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1418,7 +1443,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1429,28 +1454,28 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_9-rocm5_1_1-upload: # Uploading + manywheel-py3_9-rocm5_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_9-rocm5_1_1-test + needs: manywheel-py3_9-rocm5_2-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-rocm5_1_1 + build_name: manywheel-py3_9-rocm5_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-rocm5_2-build: + manywheel-py3_9-rocm5_3-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1459,19 +1484,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-rocm5_2 + build_name: manywheel-py3_9-rocm5_3 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-rocm5_2-test: # Testing + manywheel-py3_9-rocm5_3-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_9-rocm5_2-build + needs: manywheel-py3_9-rocm5_3-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1480,11 +1505,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.9" steps: - name: Clean workspace @@ -1513,7 +1538,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1527,7 +1557,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_9-rocm5_2 + name: manywheel-py3_9-rocm5_3 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1558,7 +1588,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1569,21 +1599,21 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_9-rocm5_2-upload: # Uploading + manywheel-py3_9-rocm5_3-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_9-rocm5_2-test + needs: manywheel-py3_9-rocm5_3-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-rocm5_2 + build_name: manywheel-py3_9-rocm5_3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} @@ -1723,7 +1753,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1828,7 +1858,7 @@ jobs: aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm5_1_1-build: + manywheel-py3_10-rocm5_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1837,19 +1867,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm5_1_1 + build_name: manywheel-py3_10-rocm5_2 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-rocm5_1_1-test: # Testing + manywheel-py3_10-rocm5_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_10-rocm5_1_1-build + needs: manywheel-py3_10-rocm5_2-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1858,11 +1888,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.10" steps: - name: Clean workspace @@ -1891,7 +1921,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1905,7 +1940,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_10-rocm5_1_1 + name: manywheel-py3_10-rocm5_2 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -1936,7 +1971,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.1.1 + docker-image: pytorch/manylinux-builder:rocm5.2 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -1947,28 +1982,28 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_10-rocm5_1_1-upload: # Uploading + manywheel-py3_10-rocm5_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_10-rocm5_1_1-test + needs: manywheel-py3_10-rocm5_2-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.1.1 - GPU_ARCH_VERSION: 5.1.1 + DESIRED_CUDA: rocm5.2 + GPU_ARCH_VERSION: 5.2 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.1.1 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm5_1_1 + build_name: manywheel-py3_10-rocm5_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm5_2-build: + manywheel-py3_10-rocm5_3-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -1977,19 +2012,19 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm5_2 + build_name: manywheel-py3_10-rocm5_3 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-rocm5_2-test: # Testing + manywheel-py3_10-rocm5_3-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_10-rocm5_2-build + needs: manywheel-py3_10-rocm5_3-build runs-on: linux.rocm.gpu timeout-minutes: 240 env: @@ -1998,11 +2033,11 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.10" steps: - name: Clean workspace @@ -2031,7 +2066,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -2045,7 +2085,7 @@ jobs: - uses: actions/download-artifact@v3 name: Download Build Artifacts with: - name: manywheel-py3_10-rocm5_2 + name: manywheel-py3_10-rocm5_3 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 @@ -2076,7 +2116,7 @@ jobs: - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: - docker-image: pytorch/manylinux-builder:rocm5.2 + docker-image: pytorch/manylinux-builder:rocm5.3 - name: Test Pytorch binary uses: ./pytorch/.github/actions/test-pytorch-binary - name: Kill containers, clean up images @@ -2087,21 +2127,21 @@ jobs: docker stop $(docker ps -q) || true # Prune all of the docker images docker system prune -af - manywheel-py3_10-rocm5_2-upload: # Uploading + manywheel-py3_10-rocm5_3-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} - needs: manywheel-py3_10-rocm5_2-test + needs: manywheel-py3_10-rocm5_3-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm5.2 - GPU_ARCH_VERSION: 5.2 + DESIRED_CUDA: rocm5.3 + GPU_ARCH_VERSION: 5.3 GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.2 + DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.3 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm5_2 + build_name: manywheel-py3_10-rocm5_3 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} @@ -2241,7 +2281,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index 52fe582aa59ee..c88b107a90a94 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -67,10 +67,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -102,7 +103,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -115,7 +116,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_8-cpu @@ -176,10 +177,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -211,7 +213,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -224,7 +226,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_9-cpu @@ -285,10 +287,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -320,7 +323,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -333,7 +336,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_10-cpu diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index fc5f84d9484ea..c8858fd0501bd 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -34,115 +34,6 @@ concurrency: cancel-in-progress: true jobs: - wheel-py3_7-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl - timeout-minutes: 240 - env: - PYTORCH_ROOT: ${{ github.workspace }}/pytorch - BUILDER_ROOT: ${{ github.workspace }}/builder - PACKAGE_TYPE: wheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.7" - # For sccache access (only on non-forked PRs) - AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} - steps: - # NOTE: These environment variables are put here so that they can be applied on every job equally - # They are also here because setting them at a workflow level doesn't give us access to the - # runner.temp variable, which we need. - - name: Populate binary env - shell: bash - run: | - # shellcheck disable=SC2129 - echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" - # shellcheck disable=SC2129 - echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" - # shellcheck disable=SC2129 - echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}" - - name: Install conda and dependencies - run: | - # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh - chmod +x "${RUNNER_TEMP}/conda.sh" - /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" - echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" - - name: Checkout PyTorch - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: Install sccache (only for non-forked PRs, and pushes to trunk) - uses: nick-fields/retry@v2.8.2 - if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} - with: - timeout_minutes: 5 - max_attempts: 3 - retry_wait_seconds: 90 - command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache - sudo chmod +x /usr/local/bin/sccache - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - - name: Populate binary env - run: | - # shellcheck disable=SC1091 - source "${RUNNER_TEMP}/anaconda/bin/activate" - "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" - - name: Build PyTorch binary - run: | - # shellcheck disable=SC1091 - source "${RUNNER_TEMP}/anaconda/bin/activate" - "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 - if: always() - with: - name: wheel-py3_7-cpu - retention-days: 14 - if-no-files-found: error - path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - wheel-py3_7-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_7-cpu-build - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: wheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu - DESIRED_PYTHON: "3.7" - build_name: wheel-py3_7-cpu - use_s3: False - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} - aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - uses: ./.github/workflows/_binary-upload.yml wheel-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: macos-12-xl @@ -176,10 +67,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -211,7 +103,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -224,7 +116,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_8-cpu @@ -285,10 +177,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -320,7 +213,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -333,7 +226,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_9-cpu @@ -394,10 +287,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -429,7 +323,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -442,7 +336,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_10-cpu diff --git a/.github/workflows/generated-macos-binary-conda-nightly.yml b/.github/workflows/generated-macos-binary-conda-nightly.yml index 8fab29ddaed9f..52cfb3d98f764 100644 --- a/.github/workflows/generated-macos-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-binary-conda-nightly.yml @@ -65,10 +65,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -100,7 +101,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -113,7 +114,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_7-cpu @@ -174,10 +175,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -209,7 +211,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -222,7 +224,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_8-cpu @@ -283,10 +285,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -318,7 +321,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -331,7 +334,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_9-cpu @@ -392,10 +395,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -427,7 +431,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -440,7 +444,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: conda-py3_10-cpu diff --git a/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml index ae63f95bc3189..cd9ad45ba5610 100644 --- a/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml @@ -34,9 +34,8 @@ concurrency: jobs: libtorch-cpu-shared-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -70,10 +69,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -105,7 +105,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -118,7 +118,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-shared-with-deps-cxx11-abi @@ -149,9 +149,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-shared-without-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -185,10 +184,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -220,7 +220,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -233,7 +233,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-shared-without-deps-cxx11-abi @@ -264,9 +264,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-static-with-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -300,10 +299,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -335,7 +335,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -348,7 +348,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-static-with-deps-cxx11-abi @@ -379,9 +379,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-static-without-deps-cxx11-abi-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -415,10 +414,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -450,7 +450,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -463,7 +463,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-static-without-deps-cxx11-abi diff --git a/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml index 39ad514a56702..4ce5c6f32c36d 100644 --- a/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml @@ -34,9 +34,8 @@ concurrency: jobs: libtorch-cpu-shared-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -70,10 +69,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -105,7 +105,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -118,7 +118,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-shared-with-deps-pre-cxx11 @@ -149,9 +149,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-shared-without-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -185,10 +184,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -220,7 +220,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -233,7 +233,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-shared-without-deps-pre-cxx11 @@ -264,9 +264,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-static-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -300,10 +299,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -335,7 +335,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -348,7 +348,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-static-with-deps-pre-cxx11 @@ -379,9 +379,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cpu-static-without-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-10.15 - # libtorch builds take a long time on github hosted runners - timeout-minutes: 720 + runs-on: macos-12-xl + timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder @@ -415,10 +414,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -450,7 +450,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -463,7 +463,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: libtorch-cpu-static-without-deps-pre-cxx11 diff --git a/.github/workflows/generated-macos-binary-wheel-nightly.yml b/.github/workflows/generated-macos-binary-wheel-nightly.yml index 70d6783dbe881..a3839d6e8a142 100644 --- a/.github/workflows/generated-macos-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-binary-wheel-nightly.yml @@ -65,10 +65,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -100,7 +101,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -113,7 +114,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_7-cpu @@ -174,10 +175,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -209,7 +211,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -222,7 +224,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_8-cpu @@ -283,10 +285,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -318,7 +321,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -331,7 +334,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_9-cpu @@ -392,10 +395,11 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - name: Checkout PyTorch uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 with: @@ -427,7 +431,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -440,7 +444,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 if: always() with: name: wheel-py3_10-cpu diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index a5aa7acaec0b9..9179b186e9182 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -1,11 +1,12 @@ name: inductor on: + schedule: + - cron: 45 1,5,9,13,17,21 * * * push: - branches: - - master tags: - ciflow/inductor/* + - ciflow/periodic/* workflow_dispatch: concurrency: @@ -19,11 +20,15 @@ jobs: with: build-environment: linux-bionic-cuda11.6-py3.10-gcc7-sm86 docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - cuda-arch-list: 8.6 + cuda-arch-list: '8.6' test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 6, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 6, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, ]} linux-bionic-cuda11_6-py3_10-gcc7-inductor-test: diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 1d48268beccd1..bdef7a1367bfc 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/labeler@v4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" - sync-labels: true + sync-labels: '' concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1803395f81d97..0b846bc5a90fa 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -5,14 +5,13 @@ on: push: branches: - master - - main - - release/* - - landchecks/* workflow_dispatch: +# The names of steps that actually test the code should be suffixed with `(nonretryable)`. +# When any other step fails, it's job will be retried once by retryBot. jobs: lintrunner: - runs-on: linux.20_04.16x + runs-on: macos-m1-12 steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@master @@ -20,295 +19,64 @@ jobs: submodules: false fetch-depth: 1 - - name: Setup Python - uses: actions/setup-python@v4 + - name: Setup miniconda + uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: - python-version: 3.8 - architecture: x64 - cache: pip - cache-dependency-path: | - **/.github/requirements-gha-cache.txt + python-version: 3.9 + environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }} + # pip-requirements-file: .github/requirements/pip-requirements-${{ runner.os }}.txt - - name: Install lintrunner - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 - with: - timeout_minutes: 5 - max_attempts: 3 - command: pip install lintrunner==0.9.2 + - name: Install requirements + env: + ENV_NAME: conda-test-env-${{ github.run_id }} + PY_VERS: 3.9 + shell: arch -arch arm64 bash {0} + run: | + # shellcheck disable=SC1090 + set -ex + ${CONDA_RUN} python3 -m pip install --force-reinstall -r .github/requirements-gha-cache.txt - name: Initialize lint dependencies - run: lintrunner init + env: + ENV_NAME: conda-test-env-${{ github.run_id }} + PY_VERS: 3.9 + shell: arch -arch arm64 bash {0} + run: | + # shellcheck disable=SC1090 + set -ex + ${CONDA_RUN} lintrunner init - name: Do build steps necessary for linters - run: | - python3 -m tools.linter.clang_tidy.generate_build_files - python3 -m tools.generate_torch_version --is_debug=false - python3 -m tools.pyi.gen_pyi \ + env: + ENV_NAME: conda-test-env-${{ github.run_id }} + PY_VERS: 3.9 + shell: arch -arch arm64 bash {0} + run: | + # shellcheck disable=SC1090 + set -ex + ${CONDA_RUN} python3 -m tools.linter.clang_tidy.generate_build_files + ${CONDA_RUN} python3 -m tools.generate_torch_version --is_debug=false + ${CONDA_RUN} python3 -m tools.pyi.gen_pyi \ --native-functions-path aten/src/ATen/native/native_functions.yaml \ --tags-path aten/src/ATen/native/tags.yaml \ --deprecated-functions-path "tools/autograd/deprecated.yaml" - - name: Run lintrunner on all files + - name: Run lintrunner on all MPS files (nonretryable) + env: + ENV_NAME: conda-test-env-${{ github.run_id }} + PY_VERS: 3.9 + shell: arch -arch arm64 bash {0} run: | + # shellcheck disable=SC1090 + set -ex set +e - if ! lintrunner --force-color --all-files --tee-json=lint.json; then + if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py; then echo "" echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m" echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" exit 1 fi - - name: Store annotations - if: always() && github.event_name == 'pull_request' - # Don't show this as an error; the above step will have already failed. - continue-on-error: true - run: | - # Use jq to massage the JSON lint output into GitHub Actions workflow commands. - jq --raw-output \ - '"::\(if .severity == "advice" or .severity == "disabled" then "warning" else .severity end) file=\(.path),line=\(.line),col=\(.char),title=\(.code) \(.name)::" + (.description | gsub("\\n"; "%0A"))' \ - lint.json - - quick-checks: - name: quick-checks - runs-on: linux.20_04.4x - steps: - # [see note: pytorch repo ref] - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - fetch-depth: 1 - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: 3.x - architecture: x64 - cache: pip - cache-dependency-path: | - **/requirements.txt - - name: Install requirements - id: requirements - run: pip install -r requirements.txt --user - - name: Ensure no non-breaking spaces - if: always() - run: | - # NB: We use 'printf' below rather than '\u000a' since bash pre-4.2 - # does not support the '\u000a' syntax (which is relevant for local linters) - (! git --no-pager grep -In "$(printf '\xC2\xA0')" -- . || (echo "The above lines have non-breaking spaces (U+00A0); please convert them to spaces (U+0020)"; false)) - - name: Ensure cross-OS compatible file names - if: always() - run: | - (! git ls-files | grep -E '([<>:"|?*]|[ .]$)' || (echo "The above file names are not valid across all operating systems. Please ensure they don't contain the characters '<>:""|?*' and don't end with a white space or a '.' "; false)) - - name: Ensure no versionless Python shebangs - if: always() - run: | - (! git --no-pager grep -In '#!.*python$' -- . || (echo "The above lines have versionless Python shebangs; please specify either python2 or python3"; false)) - - name: C++ docs check - if: ${{ always() && steps.requirements.outcome == 'success' }} - run: | - sudo apt-get install -y doxygen - cd docs/cpp/source && ./check-doxygen.sh - - name: CUDA kernel launch check - if: ${{ always() && steps.requirements.outcome == 'success' }} - run: | - set -eux - python torch/testing/_internal/check_kernel_launches.py |& tee "${GITHUB_WORKSPACE}"/cuda_kernel_launch_checks.txt - - pr-sanity-checks: - name: pr-sanity-checks - runs-on: linux.20_04.4x - # Only run this on pull requests - if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') - steps: - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - fetch-depth: -1 - - name: PR size check - env: - BASE: ${{ github.event.pull_request.base.sha }} - HEAD: ${{ github.event.pull_request.head.sha }} - run: | - bash .github/scripts/pr-sanity-check.sh - - workflow-checks: - name: workflow-checks - runs-on: linux.20_04.4x - steps: - # [see note: pytorch repo ref] - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - fetch-depth: 1 - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: 3.x - architecture: x64 - cache: pip - cache-dependency-path: | - **/requirements.txt - **/.github/requirements-gha-cache.txt - - name: Install requirements - id: requirements - run: | - pip install -r requirements.txt --user - - name: Install Jinja2 - run: | - pip install Jinja2==3.0.1 --user - - name: Regenerate workflows - id: generate_workflows - run: .github/scripts/generate_ci_workflows.py - - name: Assert that regenerating the workflows didn't change them - run: | - if ! .github/scripts/report_git_status.sh .github/workflows; then - echo - echo 'As shown by the above diff, the committed .github/workflows' - echo 'are not up to date according to .github/templates.' - echo 'Please run this command, commit, and push again to your PR:' - echo - echo ' .github/scripts/generate_ci_workflows.py' - echo - echo 'If running that command does nothing, you may need to rebase' - echo 'onto a more recent commit from the PyTorch master branch.' - false - fi - - name: Check that jobs will be cancelled - if: ${{ always() && steps.generate_workflows.outcome == 'success' }} - run: | - .github/scripts/ensure_actions_will_cancel.py - - toc: - name: toc - runs-on: linux.20_04.4x - # https://github.com/actions/virtual-environments/issues/599#issuecomment-602754687 - env: - NPM_CONFIG_PREFIX: ~/.npm-global - steps: - # [see note: pytorch repo ref] - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - fetch-depth: 1 - # This is not a node project so there is no package-lock.json to cache - - name: Setup Node - uses: actions/setup-node@v3 - - name: Install markdown-toc - run: npm install -g markdown-toc - - name: Regenerate ToCs and check that they didn't change - run: | - set -eu - export PATH=~/.npm-global/bin:"$PATH" - for FILE in $(git grep -Il '' -- '**.md'); do - markdown-toc --bullets='-' -i "$FILE" - done - - if ! .github/scripts/report_git_status.sh .; then - echo - echo 'As shown by the above diff, the table of contents in one or' - echo 'more Markdown files is not up to date with the file contents.' - echo 'You can either apply that Git diff directly to correct the' - echo 'table of contents, or if you have npm installed, you can' - echo 'install the npm package markdown-toc and run the following' - # shellcheck disable=SC2016 - echo 'command (replacing $FILE with the filename for which you want' - echo 'to regenerate the table of contents):' - echo - # shellcheck disable=SC2016 - echo " markdown-toc --bullets='-' -i \"\$FILE\"" - false - fi - - test-tools: - name: Test tools - if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: linux.20_04.4x - steps: - # [see note: pytorch repo ref] - # deep clone (fetch-depth 0) required, to allow us to use git log - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - - name: Setup Python - uses: actions/setup-python@v4 - with: - python-version: 3.8 - architecture: x64 - cache: pip - cache-dependency-path: | - **/requirements.txt - **/requirements-flake8.txt - **/.circleci/docker/requirements-ci.txt - **/.github/requirements-gha-cache.txt - - name: Install dependencies - # mypy and boto3 versions copied from - # .circleci/docker/common/install_conda.sh - run: | - set -eux - pip install -r requirements.txt - pip install boto3==1.19.12 - pip install typing-extensions==3.10 --user - pip install -r requirements-flake8.txt --user - pip install rockset==0.8.10 --user - pip install -r requirements.txt --user - pip install mypy==0.960 --user - make setup_lint - - name: Test tools - run: | - python3 -m unittest discover -vs tools/test -p 'test_*.py' - python3 -m unittest discover -vs .github/scripts -p 'test_*.py' - - test_collect_env: - if: ${{ github.repository == 'pytorch/pytorch' }} - name: Test collect_env - runs-on: linux.20_04.4x - strategy: - matrix: - test_type: [with_torch, without_torch, older_python_version] - steps: - # [see note: pytorch repo ref] - # deep clone (fetch-depth 0) required, to allow us to use git log - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@master - with: - submodules: false - fetch-depth: 1 - - name: Setup Python 3.5 - if: matrix.test_type == 'older_python_version' - uses: actions/setup-python@v4 - with: - python-version: 3.5 - architecture: x64 - cache: pip - cache-dependency-path: | - **/.github/requirements-gha-cache.txt - - name: Setup Python 3.8 - if: matrix.test_type != 'older_python_version' - uses: actions/setup-python@v4 - with: - python-version: 3.8 - architecture: x64 - cache: pip - cache-dependency-path: | - **/.github/requirements-gha-cache.txt - - name: Install torch - if: matrix.test_type == 'with_torch' - run: | - # Doesn't really matter what torch version, we just need ANY torch installed - pip install 'torch==1.*' - - name: Run collect_env.py - run: | - # All we need to see is that it passes - python3 torch/utils/collect_env.py - concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true diff --git a/.github/workflows/mac-mps.yml b/.github/workflows/mac-mps.yml index 8fc2dd8336bff..a2ca4867fd76b 100644 --- a/.github/workflows/mac-mps.yml +++ b/.github/workflows/mac-mps.yml @@ -1,10 +1,11 @@ name: Mac MPS on: - push: - tags: - - ciflow/mps/* - workflow_dispatch: + # push: + # tags: + # - ciflow/mps/* + # workflow_dispatch: + pull_request: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -18,10 +19,14 @@ jobs: sync-tag: macos-12-py3-arm64-build build-environment: macos-12-py3-arm64 xcode-version: "13.3.1" - runner-type: macos-12-xl + runner-type: macos-m1-13 build-generates-artifacts: true # To match the one pre-installed in the m1 runners python_version: 3.9.12 + # We need to set the environment file here instead of trying to detect it automatically because + # MacOS arm64 is cross-compiled from x86-64. Specifically, it means that arm64 conda environment + # is needed when building PyTorch MacOS arm64 from x86-64 + environment-file: .github/requirements/conda-env-macOS-ARM64 secrets: MACOS_SCCACHE_S3_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} MACOS_SCCACHE_S3_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} @@ -33,3 +38,11 @@ jobs: with: sync-tag: macos-12-py3-arm64-mps-test build-environment: macos-12-py3-arm64 + + macos-13-py3-arm64-mps-test: + name: macos-13-py3-arm64-mps + uses: ./.github/workflows/_mac-test-mps.yml + needs: macos-12-py3-arm64-build + with: + build-environment: macos-12-py3-arm64 + runs-on: macos-m1-13 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index a8de37ca85be2..5c1de3dac5479 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -38,6 +38,7 @@ jobs: update-vision-commit-hash: uses: ./.github/workflows/_update-commit-hash.yml + if: ${{ github.event_name == 'schedule' }} with: repo-name: vision branch: main diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index e0b69e6b6d91e..4c47cdfe57a0a 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -3,13 +3,14 @@ name: periodic on: schedule: - cron: 45 0,4,8,12,16,20 * * * + - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests push: tags: - ciflow/periodic/* workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true jobs: @@ -33,48 +34,51 @@ jobs: build-environment: linux-bionic-cuda11.6-py3-gcc7-slow-gradcheck docker-image: ${{ needs.linux-bionic-cuda11_6-py3-gcc7-slow-gradcheck-build.outputs.docker-image }} test-matrix: ${{ needs.linux-bionic-cuda11_6-py3-gcc7-slow-gradcheck-build.outputs.test-matrix }} + timeout-minutes: 300 - linux-focal-rocm5_2-py3_7-slow-build: - name: linux-focal-rocm5.2-py3.7-slow + linux-focal-rocm5_2-py3_8-slow-build: + name: linux-focal-rocm5.2-py3.8-slow uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 - - linux-focal-rocm5_2-py3_7-slow-test: - name: linux-focal-rocm5.2-py3.7-slow - uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm5_2-py3_7-slow-build - with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-slow-build.outputs.docker-image }} + build-environment: linux-focal-rocm5.2-py3.8 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, ]} + + linux-focal-rocm5_2-py3_8-slow-test: + name: linux-focal-rocm5.2-py3.8-slow + uses: ./.github/workflows/_rocm-test.yml + needs: linux-focal-rocm5_2-py3_8-slow-build + with: + build-environment: linux-focal-rocm5.2-py3.8 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-slow-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-slow-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} - linux-focal-rocm5_2-py3_7-distributed-build: - name: linux-focal-rocm5.2-py3.7-distributed + linux-focal-rocm5_2-py3_8-distributed-build: + name: linux-focal-rocm5.2-py3.8-distributed uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 - - linux-focal-rocm5_2-py3_7-distributed-test: - name: linux-focal-rocm5.2-py3.7-distributed - uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm5_2-py3_7-distributed-build - with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-distributed-build.outputs.docker-image }} + build-environment: linux-focal-rocm5.2-py3.8 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, { config: "distributed", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, ]} + + linux-focal-rocm5_2-py3_8-distributed-test: + name: linux-focal-rocm5.2-py3.8-distributed + uses: ./.github/workflows/_rocm-test.yml + needs: linux-focal-rocm5_2-py3_8-distributed-build + with: + build-environment: linux-focal-rocm5.2-py3.8 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-distributed-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-distributed-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} @@ -164,8 +168,9 @@ jobs: cuda-version: "11.7" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, ]} diff --git a/.github/workflows/pr-labels.yml b/.github/workflows/pr-labels.yml deleted file mode 100644 index aa8cf4472b784..0000000000000 --- a/.github/workflows/pr-labels.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: pr-labels - -on: - push: - branches: - - master - - main - -jobs: - is-properly-labeled: - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@v3 - - - name: Set up python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - cache: pip - cache-dependency-path: | - **/.github/requirements-gha-cache.txt - - - name: Install requests - run: pip install requests==2.26 - - - name: Process commit and find merger responsible for labeling - id: commit - env: - SHA1: ${{ github.sha }} - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: python .github/scripts/process_commit.py "${SHA1}" - -concurrency: - group: pr-labels-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml deleted file mode 100644 index 4192537795557..0000000000000 --- a/.github/workflows/pull.yml +++ /dev/null @@ -1,308 +0,0 @@ -name: pull - -on: - pull_request: - push: - branches: - - master - - main - - release/* - - landchecks/* - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - linux-focal-py3_7-gcc7-build: - name: linux-focal-py3.7-gcc7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-gcc7 - docker-image-name: pytorch-linux-focal-py3.7-gcc7 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - { config: "distributed", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "docs_test", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-focal-py3_7-gcc7-test: - name: linux-focal-py3.7-gcc7 - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_7-gcc7-build - with: - build-environment: linux-focal-py3.7-gcc7 - docker-image: ${{ needs.linux-focal-py3_7-gcc7-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_7-gcc7-build.outputs.test-matrix }} - - linux-docs: - name: linux-docs - uses: ./.github/workflows/_docs.yml - needs: linux-focal-py3_7-gcc7-build - with: - build-environment: linux-focal-py3.7-gcc7 - docker-image: ${{ needs.linux-focal-py3_7-gcc7-build.outputs.docker-image }} - - linux-focal-py3_7-gcc7-no-ops: - name: linux-focal-py3.7-gcc7-no-ops - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-gcc7-no-ops - docker-image-name: pytorch-linux-focal-py3.7-gcc7 - - linux-focal-py3_7-gcc7-pch: - name: linux-focal-py3.7-gcc7-pch - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-gcc7-pch - docker-image-name: pytorch-linux-focal-py3.7-gcc7 - - linux-focal-py3_7-clang7-asan-build: - name: linux-focal-py3.7-clang7-asan - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-clang7-asan - docker-image-name: pytorch-linux-focal-py3-clang7-asan - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 5, runner: "linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 5, runner: "linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 5, runner: "linux.2xlarge" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.2xlarge" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-focal-py3_7-clang7-asan-test: - name: linux-focal-py3.7-clang7-asan - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_7-clang7-asan-build - with: - build-environment: linux-focal-py3.7-clang7-asan - docker-image: ${{ needs.linux-focal-py3_7-clang7-asan-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_7-clang7-asan-build.outputs.test-matrix }} - - linux-focal-py3_7-clang10-onnx-build: - name: linux-focal-py3.7-clang10-onnx - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-clang10-onnx - docker-image-name: pytorch-linux-focal-py3-clang10-onnx - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - ]} - - linux-focal-py3_7-clang10-onnx-test: - name: linux-focal-py3.7-clang10-onnx - uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_7-clang10-onnx-build - with: - build-environment: linux-focal-py3.7-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_7-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_7-clang10-onnx-build.outputs.test-matrix }} - - linux-bionic-py3_7-clang9-build: - name: linux-bionic-py3.7-clang9 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-bionic-py3.7-clang9 - docker-image-name: pytorch-linux-bionic-py3.7-clang9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-bionic-py3_7-clang9-test: - name: linux-bionic-py3.7-clang9 - uses: ./.github/workflows/_linux-test.yml - needs: linux-bionic-py3_7-clang9-build - with: - build-environment: linux-bionic-py3.7-clang9 - docker-image: ${{ needs.linux-bionic-py3_7-clang9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-bionic-py3_7-clang9-build.outputs.test-matrix }} - - linux-vulkan-bionic-py3_7-clang9-build: - name: linux-vulkan-bionic-py3.7-clang9 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-vulkan-bionic-py3.7-clang9 - docker-image-name: pytorch-linux-bionic-py3.7-clang9 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-vulkan-bionic-py3_7-clang9-test: - name: linux-vulkan-bionic-py3.7-clang9 - uses: ./.github/workflows/_linux-test.yml - needs: linux-vulkan-bionic-py3_7-clang9-build - with: - build-environment: linux-vulkan-bionic-py3.7-clang9 - docker-image: ${{ needs.linux-vulkan-bionic-py3_7-clang9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-vulkan-bionic-py3_7-clang9-build.outputs.test-matrix }} - - linux-bionic-cuda11_6-py3_10-gcc7-build: - name: linux-bionic-cuda11.6-py3.10-gcc7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-bionic-cuda11.6-py3.10-gcc7 - docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "distributed", shard: 1, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "linux.8xlarge.nvidia.gpu" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "deploy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - ]} - - linux-bionic-cuda11_6-py3_10-gcc7-test: - name: linux-bionic-cuda11.6-py3.10-gcc7 - uses: ./.github/workflows/_linux-test.yml - needs: linux-bionic-cuda11_6-py3_10-gcc7-build - with: - build-environment: linux-bionic-cuda11.6-py3.10-gcc7 - docker-image: ${{ needs.linux-bionic-cuda11_6-py3_10-gcc7-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-bionic-cuda11_6-py3_10-gcc7-build.outputs.test-matrix }} - - linux-xenial-py3-clang5-mobile-build: - name: linux-xenial-py3-clang5-mobile-build - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-xenial-py3-clang5-mobile-build - docker-image-name: pytorch-linux-xenial-py3-clang5-asan - build-generates-artifacts: false - - linux-jammy-cuda-11_6-cudnn8-py3_8-clang12-build: - name: linux-jammy-cuda11.6-cudnn8-py3.8-clang12 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-jammy-cuda11.6-cudnn8-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.6-cudnn8-py3.8-clang12 - - linux-xenial-py3-clang5-mobile-custom-build-static: - name: linux-xenial-py3-clang5-mobile-custom-build-static - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-xenial-py3-clang5-mobile-custom-build-static - docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c - build-generates-artifacts: false - - linux-bionic-py3_7-clang8-xla-build: - name: linux-bionic-py3_7-clang8-xla - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-bionic-py3_7-clang8-xla - docker-image-name: xla_base - test-matrix: | - { include: [ - { config: "xla", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, - ]} - - linux-bionic-py3_7-clang8-xla-test: - name: linux-bionic-py3_7-clang8-xla - uses: ./.github/workflows/_linux-test.yml - needs: linux-bionic-py3_7-clang8-xla-build - with: - build-environment: linux-bionic-py3_7-clang8-xla - docker-image: ${{ needs.linux-bionic-py3_7-clang8-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-bionic-py3_7-clang8-xla-build.outputs.test-matrix }} - - win-vs2019-cpu-py3-build: - name: win-vs2019-cpu-py3 - uses: ./.github/workflows/_win-build.yml - with: - build-environment: win-vs2019-cpu-py3 - cuda-version: cpu - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "windows.4xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "windows.4xlarge" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, - ]} - - win-vs2019-cpu-py3-test: - name: win-vs2019-cpu-py3 - uses: ./.github/workflows/_win-test.yml - needs: win-vs2019-cpu-py3-build - with: - build-environment: win-vs2019-cpu-py3 - cuda-version: cpu - test-matrix: ${{ needs.win-vs2019-cpu-py3-build.outputs.test-matrix }} - - win-vs2019-cuda11_6-py3-build: - if: github.event_name == 'pull_request' - name: win-vs2019-cuda11.6-py3 - uses: ./.github/workflows/_win-build.yml - with: - build-environment: win-vs2019-cuda11.6-py3 - cuda-version: "11.6" - sync-tag: win-cuda-build - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "functorch", shard: 1, num_shards: 1, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, - ]} - - linux-bionic-cuda11_6-py3_10-gcc7-bazel-test: - name: linux-bionic-cuda11.6-py3.10-gcc7-bazel-test - uses: ./.github/workflows/_bazel-build-test.yml - with: - build-environment: linux-bionic-cuda11.6-py3.10-gcc7-bazel-test - docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - - linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single: - name: linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single - uses: ./.github/workflows/_android-build-test.yml - with: - build-environment: linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single - docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c - - linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit: - name: linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit - uses: ./.github/workflows/_android-build-test.yml - with: - build-environment: linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit - docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c - - linux-focal-py3_7-gcc7-mobile-lightweight-dispatch-build: - name: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-py3.7-gcc7-mobile-lightweight-dispatch-build - docker-image-name: pytorch-linux-focal-py3.7-gcc7 - build-generates-artifacts: false - - linux-focal-rocm5_2-py3_7-build: - # don't run build twice on master - if: github.event_name == 'pull_request' - name: linux-focal-rocm5.2-py3.7 - uses: ./.github/workflows/_linux-build.yml - with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 - sync-tag: rocm-build diff --git a/.github/workflows/push_nightly_docker_ghcr.yml b/.github/workflows/push_nightly_docker_ghcr.yml deleted file mode 100644 index ac443a4d558c1..0000000000000 --- a/.github/workflows/push_nightly_docker_ghcr.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: docker-release-builds -on: - schedule: - # Push the nightly docker daily at 1 PM UTC - - cron: '0 13 * * *' - # Trigger when we modify something related to these images - pull_request: - paths: - - .github/scripts/build_publish_nightly_docker.sh - - .github/workflows/push_nightly_docker_ghcr.yml - - Dockerfile - - docker.Makefile - # Have the ability to trigger this job manually using the API as well - workflow_dispatch: - -jobs: - docker-release-build: - if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: linux.2xlarge - env: - GHCR_PAT: ${{ secrets.GHCR_PAT }} - WITH_PUSH: ${{ github.event_name == 'schedule' }} - steps: - - name: Checkout PyTorch - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 - name: Build and upload nightly docker - with: - timeout_minutes: 30 - max_attempts: 3 - command: | - set -ex - bash .github/scripts/build_publish_nightly_docker.sh - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true diff --git a/.github/workflows/revert.yml b/.github/workflows/revert.yml index d207840f383b4..2a2fff27044ea 100644 --- a/.github/workflows/revert.yml +++ b/.github/workflows/revert.yml @@ -21,9 +21,10 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.8' architecture: x64 - cache: 'pip' + check-latest: false + cache: pip - run: pip install pyyaml==6.0 - name: Setup committer id diff --git a/.github/workflows/run_torchbench.yml b/.github/workflows/run_torchbench.yml deleted file mode 100644 index 9a46a23af5bfc..0000000000000 --- a/.github/workflows/run_torchbench.yml +++ /dev/null @@ -1,102 +0,0 @@ -name: TorchBench CI (pytorch-linux-py3.7-cu102) -on: - pull_request: - -env: - PYTHON_VERSION: "3.8" - # must be consistent with https://github.com/pytorch/benchmark/blob/main/requirements.txt#L19 - NUMPY_VERSION: "1.21.2" - PR_NUM: ${{ github.event.number }} - PR_BODY: ${{ github.event.pull_request.body }} - PR_BASE_SHA: ${{ github.event.pull_request.base.sha }} - PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} - -jobs: - run-torchbench: - # We don't accept running on non-pytorch repos because of security concerns - # Only run the job when the body contains magic word "RUN_TORCHBENCH:" - if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.body, 'RUN_TORCHBENCH:') }} - runs-on: [self-hosted, bm-runner] - # Set to 12 hours - timeout-minutes: 720 - steps: - - name: Checkout PyTorch - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - path: pytorch - - name: Update self-hosted PyTorch - run: | - pushd "${HOME}"/pytorch - git remote prune origin - git fetch - popd - - name: Create conda environment and install deps - run: | - conda create -y -n pr-ci python="${PYTHON_VERSION}" - # shellcheck disable=SC1091 - . "${HOME}"/anaconda3/etc/profile.d/conda.sh - conda activate pr-ci - # pin cmake version to 3.22 since 3.23 breaks pytorch build - # see details at: https://github.com/pytorch/pytorch/issues/74985 - conda install -y numpy="${NUMPY_VERSION}" requests ninja pyyaml mkl mkl-include \ - setuptools cmake=3.22 cffi typing_extensions boto3 \ - future six dataclasses pillow pytest tabulate gitpython git-lfs tqdm psutil - - name: Setup TorchBench branch - run: | - # shellcheck disable=SC1091 - . "${HOME}"/anaconda3/etc/profile.d/conda.sh - conda activate pr-ci - PR_BODY_FILE=/tmp/pr-body.txt - echo "$PR_BODY" > ${PR_BODY_FILE} - python pytorch/.github/scripts/run_torchbench.py --pr-body "${PR_BODY_FILE}" set-torchbench-branch - - name: Checkout TorchBench - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - repository: pytorch/benchmark - path: benchmark - lfs: false - ref: ${{ env.TORCHBENCH_BRANCH }} - - name: GPU Info - run: | - nvidia-smi - - name: Run TorchBench - run: | - set -x - pushd "${HOME}"/pytorch - PR_MERGE_BASE=$(git merge-base "$PR_BASE_SHA" "$PR_HEAD_SHA") - popd - PR_BODY_FILE=/tmp/pr-body.txt - echo "$PR_BODY" > ${PR_BODY_FILE} - # shellcheck disable=SC1091 - . "${HOME}"/anaconda3/etc/profile.d/conda.sh - conda activate pr-ci - python3 pytorch/.github/scripts/run_torchbench.py \ - --pr-body "$PR_BODY_FILE" \ - run \ - --pytorch-path "${HOME}"/pytorch \ - --torchbench-path "${PWD}"/benchmark \ - --pr-num "$PR_NUM" \ - --pr-base-sha "$PR_MERGE_BASE" \ - --pr-head-sha "$PR_HEAD_SHA" - - name: Upload result to S3 - run: | - . "${HOME}"/anaconda3/etc/profile.d/conda.sh - conda activate pr-ci - python3 pytorch/.github/scripts/run_torchbench.py \ - upload-s3 \ - --result-dir "${HOME}/.torchbench/bisection/pr${{ github.event.number }}" - - name: Remove conda environment and cleanup - run: | - conda env remove --name pr-ci - rm /tmp/pr-body.txt - - name: Upload artifact - uses: actions/upload-artifact@v2 - with: - name: TorchBench result - path: ~/.torchbench/bisection/pr${{ github.event.number }} - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 516998bfa95be..8abee79cf400f 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -21,11 +21,11 @@ jobs: # Used to receive a badge. id-token: write - if: github.repository == 'pytorch/pytorch' # don't run on forks + if: false && github.repository == 'pytorch/pytorch' # don't run on forks steps: - name: "Checkout code" - uses: actions/checkout@2541b1294d2704b0964813337f33b291d3f8596b # tag=v3.0.2 + uses: actions/checkout@v3 with: persist-credentials: false @@ -42,7 +42,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # tag=v3.1.0 + uses: actions/upload-artifact@v3 with: name: SARIF file path: results.sarif diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index af348a84556c9..85526deccbf90 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -10,9 +10,11 @@ on: tags: - ciflow/trunk/* workflow_dispatch: + schedule: + - cron: 29 8 * * * # about 1:29am PDT concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true jobs: @@ -58,8 +60,6 @@ jobs: { config: "default", shard: 3, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 4, runner: "linux.4xlarge.nvidia.gpu" }, { config: "functorch", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 1, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 2, num_shards: 2, runner: "linux.4xlarge.nvidia.gpu" }, { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.4xlarge.nvidia.gpu" }, @@ -90,6 +90,8 @@ jobs: { config: "default", shard: 2, num_shards: 4, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "default", shard: 3, num_shards: 4, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 4, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "functorch", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} @@ -109,6 +111,7 @@ jobs: build-environment: libtorch-linux-bionic-cuda11.6-py3.7-gcc7 docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 build-generates-artifacts: false + runner: linux.4xlarge # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated linux-bionic-cuda11_7-py3_10-gcc7-no-ops-build: @@ -118,12 +121,12 @@ jobs: build-environment: linux-bionic-cuda11.7-py3.10-gcc7-no-ops docker-image-name: pytorch-linux-bionic-cuda11.7-cudnn8-py3-gcc7 - pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build: - name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build + pytorch-linux-focal-py3-clang7-android-ndk-r19c-build: + name: pytorch-linux-focal-py3-clang7-android-ndk-r19c-build uses: ./.github/workflows/_android-full-build-test.yml with: - build-environment: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build - docker-image-name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c + build-environment: pytorch-linux-focal-py3-clang7-android-ndk-r19c-build + docker-image-name: pytorch-linux-focal-py3-clang7-android-ndk-r19c linux-bionic-py3_7-clang9-slow-build: name: linux-bionic-py3.7-clang9-slow @@ -226,6 +229,10 @@ jobs: build-generates-artifacts: true # To match the one pre-installed in the m1 runners python_version: 3.9.12 + # We need to set the environment file here instead of trying to detect it automatically because + # MacOS arm64 is cross-compiled from x86-64. Specifically, it means that arm64 conda environment + # is needed when building PyTorch MacOS arm64 from x86-64 + environment-file: .github/requirements/conda-env-macOS-ARM64 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 2, runner: "macos-m1-12" }, @@ -284,26 +291,27 @@ jobs: cuda-version: "11.6" test-matrix: ${{ needs.win-vs2019-cuda11_6-py3-build.outputs.test-matrix }} - linux-focal-rocm5_2-py3_7-build: - name: linux-focal-rocm5.2-py3.7 + linux-focal-rocm5_2-py3_8-build: + name: linux-focal-rocm5.2-py3.8 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image-name: pytorch-linux-focal-rocm5.2-py3.7 + build-environment: linux-focal-rocm5.2-py3.8 + docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 sync-tag: rocm-build - - linux-focal-rocm5_2-py3_7-test: - name: linux-focal-rocm5.2-py3.7 - uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm5_2-py3_7-build - with: - build-environment: linux-focal-rocm5.2-py3.7 - docker-image: ${{ needs.linux-focal-rocm5_2-py3_7-build.outputs.docker-image }} test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, ]} + + linux-focal-rocm5_2-py3_8-test: + name: linux-focal-rocm5.2-py3.8 + uses: ./.github/workflows/_rocm-test.yml + needs: linux-focal-rocm5_2-py3_8-build + with: + build-environment: linux-focal-rocm5.2-py3.8 + docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/trymerge.yml b/.github/workflows/trymerge.yml index dff92303f5056..3d1d92967d885 100644 --- a/.github/workflows/trymerge.yml +++ b/.github/workflows/trymerge.yml @@ -21,8 +21,9 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.8 - cache: 'pip' + python-version: '3.8' + check-latest: false + cache: pip architecture: x64 - run: pip install pyyaml==6.0 diff --git a/.github/workflows/tryrebase.yml b/.github/workflows/tryrebase.yml index fed9000c420e9..53434310c3d00 100644 --- a/.github/workflows/tryrebase.yml +++ b/.github/workflows/tryrebase.yml @@ -20,9 +20,10 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.8' architecture: x64 - cache: 'pip' + check-latest: false + cache: pip - run: pip install pyyaml==6.0 - name: Setup committer id diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 5901b1f4cda1b..12bf4e271f927 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -22,8 +22,9 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.8' architecture: x64 + check-latest: false cache: pip cache-dependency-path: | **/.circleci/docker/requirements-ci.txt diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 688a55b6eabca..3f3db80670d8c 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -2,7 +2,7 @@ name: Upload test stats on: workflow_run: - workflows: [pull, trunk, periodic] + workflows: [pull, trunk, periodic, inductor] types: - completed @@ -58,6 +58,31 @@ jobs: python3 -m tools.stats.upload_test_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --head-branch "${HEAD_BRANCH}" python3 -m tools.stats.upload_sccache_stats --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" + - name: Upload test artifacts + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_ARTIFACTS_URL: ${{ github.event.workflow_run.artifacts_url }} + WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} + WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} + REPO_FULLNAME: ${{ github.event.workflow_run.repository.full_name }} + run: | + echo "${WORKFLOW_ARTIFACTS_URL}" + + # Note that in the case of Linux and Windows, their artifacts have already been uploaded to S3, so there simply won't be + # anything on GitHub to upload. The command should return right away + python3 -m tools.stats.upload_artifacts --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" + + - name: Analyze disabled tests rerun + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_ARTIFACTS_URL: ${{ github.event.workflow_run.artifacts_url }} + WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} + WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} + REPO_FULLNAME: ${{ github.event.workflow_run.repository.full_name }} + run: | + # Analyze the results from disable tests rerun and upload them to S3 + python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" + check-api-rate: if: ${{ always() }} runs-on: [self-hosted, linux.2xlarge] diff --git a/.gitignore b/.gitignore index d4f07f2cf10fb..597ae390abe9c 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ docs/source/generated/ log usage_log.txt test-reports/ +test/*.bak test/.coverage test/.hypothesis/ test/cpp/api/mnist @@ -338,3 +339,6 @@ third_party/glog/ # Virtualenv venv/ + +# Log files +*.log diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 0204907ee865d..d245dabda4daa 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -149,6 +149,9 @@ export DNNL_MAX_CPU_ISA=AVX2 # Should still run even in the absence of SHARD_NUMBER if [[ "${SHARD_NUMBER:-1}" == "1" ]]; then + # TODO(sdym@meta.com) remove this when the linked issue resolved. + # py is temporary until https://github.com/Teemu/pytest-sugar/issues/241 is fixed + pip install --user py==1.11.0 pip install --user pytest-sugar # NB: Warnings are disabled because they make it harder to see what # the actual erroring test is @@ -173,7 +176,9 @@ fi ############## if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" - pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.12.1 beartype==0.10.4 + pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.12.1 beartype==0.10.4 onnx==1.12.0 + # TODO: change this when onnx-script is on testPypi + pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script' # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21. # We don't actually need it for our tests, but it's imported if it's present, so uninstall. pip uninstall -q --yes numba diff --git a/.jenkins/pytorch/build-asan.sh b/.jenkins/pytorch/build-asan.sh index d2cafa323fc56..91953c322f223 100755 --- a/.jenkins/pytorch/build-asan.sh +++ b/.jenkins/pytorch/build-asan.sh @@ -26,7 +26,7 @@ CC="clang" CXX="clang++" LDSHARED="clang --shared" \ CFLAGS="-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all -fsanitize-address-use-after-scope -shared-libasan" \ USE_ASAN=1 USE_CUDA=0 USE_MKLDNN=0 \ python setup.py bdist_wheel - python -mpip install "$(echo dist/*.whl)[opt-einsum]" + pip_install_whl "$(echo dist/*.whl)" # Test building via the sdist source tarball python setup.py sdist diff --git a/.jenkins/pytorch/build-tsan.sh b/.jenkins/pytorch/build-tsan.sh index 41ebdd5cb1eed..e10edb310d813 100755 --- a/.jenkins/pytorch/build-tsan.sh +++ b/.jenkins/pytorch/build-tsan.sh @@ -22,7 +22,7 @@ CC="clang" CXX="clang++" LDSHARED="clang --shared" \ CFLAGS="-fsanitize=thread" \ USE_TSAN=1 USE_CUDA=0 USE_MKLDNN=0 \ python setup.py bdist_wheel - python -mpip install dist/*.whl + pip_install_whl "$(echo dist/*.whl)" print_sccache_stats diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index 24567449424a6..bb7b2c5d03c88 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -41,8 +41,6 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then fi if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then - # enable split torch_cuda build option in CMake - export BUILD_SPLIT_CUDA=ON if [[ "$BUILD_ENVIRONMENT" != *cuda11.3* && "$BUILD_ENVIRONMENT" != *clang* ]]; then # TODO: there is a linking issue when building with UCC using clang, # disable it for now and to be fix later. @@ -62,9 +60,6 @@ elif [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then export ATEN_THREADING=NATIVE fi -# TODO: Don't run this... -pip_install -r requirements.txt || true - # Enable LLVM dependency for TensorExpr testing if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then export USE_LLVM=/opt/rocm/llvm @@ -74,13 +69,11 @@ else export LLVM_DIR=/opt/llvm/lib/cmake/llvm fi -# TODO: Don't install this here if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with # intel cpu and later run tests on machines with amd cpu. # Also leave out two builds to make sure non-mkldnn builds still work. if [[ "$BUILD_ENVIRONMENT" != *rocm* ]]; then - pip_install mkl mkl-devel export USE_MKLDNN=1 else export USE_MKLDNN=0 @@ -189,17 +182,8 @@ if [[ "${BUILD_ENVIRONMENT}" == *linux-focal-py3.7-gcc7-build* ]]; then export USE_GLOO_WITH_OPENSSL=ON fi -# TODO: Remove after xenial->focal migration -if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3* ]]; then - if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* ]]; then - export BUILD_STATIC_RUNTIME_BENCHMARK=ON - fi -fi - -if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-focal-py3* ]]; then - if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* ]]; then - export BUILD_STATIC_RUNTIME_BENCHMARK=ON - fi +if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* ]]; then + export BUILD_STATIC_RUNTIME_BENCHMARK=ON fi if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then @@ -230,7 +214,7 @@ else else python setup.py bdist_wheel fi - python -mpip install "$(echo dist/*.whl)[opt-einsum]" + pip_install_whl "$(echo dist/*.whl)" # TODO: I'm not sure why, but somehow we lose verbose commands set -x diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index c0e51bc80aa8c..6060a7179e0f9 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -9,6 +9,10 @@ log() { printf '%s\n' "$*"; } error() { log "ERROR: $*" >&2; } fatal() { error "$@"; exit 1; } +retry () { + "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@") +} + # compositional trap taken from https://stackoverflow.com/a/7287873/23845 # appends a command to a trap # @@ -49,6 +53,12 @@ function assert_git_not_dirty() { fi } +function pip_install_whl() { + # This is used to install PyTorch and other build artifacts wheel locally + # without using any network connection + python3 -mpip install --no-index --no-deps "$@" +} + function pip_install() { # retry 3 times # old versions of pip don't have the "--progress-bar" flag @@ -72,12 +82,12 @@ function get_exit_code() { function get_bazel() { if [[ $(uname) == "Darwin" ]]; then # download bazel version - curl https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-darwin-x86_64 -Lo tools/bazel + retry curl https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-darwin-x86_64 -Lo tools/bazel # verify content echo '74d93848f0c9d592e341e48341c53c87e3cb304a54a2a1ee9cff3df422f0b23c tools/bazel' | shasum -a 256 -c >/dev/null else # download bazel version - curl https://ossci-linux.s3.amazonaws.com/bazel-4.2.1-linux-x86_64 -o tools/bazel + retry curl https://ossci-linux.s3.amazonaws.com/bazel-4.2.1-linux-x86_64 -o tools/bazel # verify content echo '1a4f3a3ce292307bceeb44f459883859c793436d564b95319aacb8af1f20557c tools/bazel' | shasum -a 256 -c >/dev/null fi @@ -95,20 +105,16 @@ function get_pinned_commit() { cat .github/ci_commit_pins/"${1}".txt } -function install_torchvision() { +function install_torchtext() { local commit - commit=$(get_pinned_commit vision) - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" + commit=$(get_pinned_commit text) + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/text.git@${commit}" } -function checkout_install_torchvision() { +function install_torchvision() { local commit commit=$(get_pinned_commit vision) - git clone https://github.com/pytorch/vision - pushd vision - git checkout "${commit}" - time python setup.py install - popd + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" } function clone_pytorch_xla() { @@ -134,12 +140,12 @@ function install_triton() { else commit=$(get_pinned_commit triton) pip_install --user "git+https://github.com/openai/triton@${commit}#subdirectory=python" + pip_install --user jinja2 fi } function setup_torchdeploy_deps(){ - conda install -y cmake - conda install -y -c conda-forge libpython-static=3.10 + conda install -y -n "py_${ANACONDA_PYTHON_VERSION}" "libpython-static=${ANACONDA_PYTHON_VERSION}" local CC local CXX CC="$(which gcc)" @@ -151,13 +157,14 @@ function setup_torchdeploy_deps(){ function checkout_install_torchdeploy() { local commit + commit=$(get_pinned_commit multipy) setup_torchdeploy_deps pushd .. git clone --recurse-submodules https://github.com/pytorch/multipy.git pushd multipy - # with ABI flag change + git checkout "${commit}" python multipy/runtime/example/generate_examples.py - pip install -e . --install-option="--abicxx" + pip install -e . --install-option="--cudatests" popd popd } @@ -166,6 +173,7 @@ function test_torch_deploy(){ pushd .. pushd multipy ./multipy/runtime/build/test_deploy + ./multipy/runtime/build/test_deploy_gpu popd popd } @@ -187,13 +195,12 @@ function install_timm() { } function checkout_install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) git clone https://github.com/pytorch/benchmark torchbench pushd torchbench - git checkout "${commit}" - python install.py - pip_install gym==0.25.2 # workaround issue in 0.26.0 + git checkout no_torchaudio + # Occasionally the installation may fail on one model but it is ok to continue + # to install and test other models + python install.py --continue_on_fail popd } diff --git a/.jenkins/pytorch/macos-common.sh b/.jenkins/pytorch/macos-common.sh index 319e88e40aa8d..d1b31ec941889 100755 --- a/.jenkins/pytorch/macos-common.sh +++ b/.jenkins/pytorch/macos-common.sh @@ -7,52 +7,6 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" sysctl -a | grep machdep.cpu -if [[ ${BUILD_ENVIRONMENT} = *arm64* ]]; then - # We use different versions here as the arm build/tests runs on python 3.9 - # while the x86 one runs on python 3.8 - retry conda install -y \ - numpy=1.22.3 \ - pyyaml=6.0 \ - setuptools=61.2.0 \ - cmake=3.22.1 \ - cffi \ - ninja \ - typing_extensions \ - dataclasses \ - pip -else - # NOTE: mkl 2021.3.0+ cmake requires sub-command PREPEND, may break the build - retry conda install -y \ - mkl=2021.2.0 \ - mkl-include=2021.2.0 \ - numpy=1.18.5 \ - pyyaml=5.3 \ - setuptools=46.0.0 \ - cmake=3.22.1 \ - cffi \ - ninja \ - typing_extensions \ - dataclasses \ - pip -fi - -# The torch.hub tests make requests to GitHub. -# -# The certifi package from conda-forge is new enough to make the -# following error disappear (included for future reference): -# -# > ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] -# > certificate verify failed: unable to get local issuer certificate -# > (_ssl.c:1056) -# -retry conda install -y -c conda-forge certifi wheel=0.36.2 - -# Needed by torchvision, which is imported from TestHub in test_utils.py. -retry conda install -y pillow - -# Building with USE_DISTRIBUTED=1 requires libuv (for Gloo). -retry conda install -y libuv pkg-config - # These are required for both the build job and the test job. # In the latter to test cpp extensions. export MACOSX_DEPLOYMENT_TARGET=10.9 diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 7103f1a5dbee3..4beab880ddbb3 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -4,24 +4,6 @@ # shellcheck source=./macos-common.sh source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" -conda install -y six -if [[ ${BUILD_ENVIRONMENT} = *arm64* ]]; then - pip install hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba==0.56.0" psutil "scipy==1.9.0" -else - pip install hypothesis "expecttest==0.1.3" "librosa>=0.6.2" "numba<=0.49.1" psutil "scipy==1.6.3" -fi - -# TODO move this to docker -# Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 -pip install "unittest-xml-reporting<=3.2.0,>=2.0.0" \ - pytest \ - pytest-xdist \ - pytest-shard \ - pytest-rerunfailures \ - "xdoctest==1.0.2" \ - "pygments==2.12.0" \ - "opt-einsum>=3.3" - if [ -z "${CI}" ]; then rm -rf "${WORKSPACE_DIR}"/miniconda3/lib/python3.6/site-packages/torch* fi diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index bbd1c370a638e..9d7efc969823c 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -8,11 +8,6 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" -if [ -n "${CI}" ]; then - # TODO move this to docker - # Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 - pip_install "unittest-xml-reporting<=3.2.0,>=2.0.0" -fi # Disabling tests to see if they solve timeout issues; see https://github.com/pytorch/pytorch/issues/70015 # python tools/download_mnist.py --quiet -d test/cpp/api/mnist @@ -28,8 +23,8 @@ time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_a # FSDP tests for f in test/distributed/fsdp/*.py ; do time python test/run_test.py --verbose -i "${f#*/}" ; done # ShardedTensor tests -time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint -time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint +time python test/run_test.py --verbose -i distributed/checkpoint/test_checkpoint +time python test/run_test.py --verbose -i distributed/checkpoint/test_file_system_checkpoint time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 7e9d4f37edec1..4e52f31a74c7f 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -16,6 +16,7 @@ BUILD_RENAMED_DIR="build_renamed" BUILD_BIN_DIR="$BUILD_DIR"/bin export VALGRIND=ON +export TORCH_INDUCTOR_INSTALL_GXX=ON if [[ "$BUILD_ENVIRONMENT" == *clang9* ]]; then # clang9 appears to miscompile code involving c10::optional, # such that valgrind complains along these lines: @@ -97,10 +98,6 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* || "$BUILD_ENVIRONMENT" == *rocm* ]]; then export PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" fi -if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then - export BUILD_SPLIT_CUDA=ON -fi - if [[ "$TEST_CONFIG" == *crossref* ]]; then export PYTORCH_TEST_WITH_CROSSREF=1 fi @@ -113,14 +110,6 @@ if [[ "$TEST_CONFIG" == *inductor* ]]; then export PYTORCH_TEST_WITH_INDUCTOR=1 fi -# TODO: this condition is never true, need to fix this. -if [[ -n "$PR_NUMBER" ]] && [[ -z "$CI_MASTER" || "$CI_MASTER" == "false" ]]; then - # skip expensive checks when on PR and CI_MASTER flag is not set - export PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK=1 -else - export PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK=0 -fi - if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # Print GPU info rocminfo @@ -129,7 +118,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then # JIT C++ extensions require ninja. - pip_install --user ninja + pip_install --user "ninja==1.10.2" # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" @@ -139,9 +128,8 @@ fi # if you're not careful. Check this if you made some changes and the # ASAN test is not working if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then - # Suppress vptr violations arising from multiple copies of pybind11 export ASAN_OPTIONS=detect_leaks=0:symbolize=1:detect_stack_use_after_return=1:strict_init_order=true:detect_odr_violation=0 - export UBSAN_OPTIONS=print_stacktrace=1:suppressions=$PWD/ubsan.supp + export UBSAN_OPTIONS=print_stacktrace=1 export PYTORCH_TEST_WITH_ASAN=1 export PYTORCH_TEST_WITH_UBSAN=1 # TODO: Figure out how to avoid hard-coding these paths @@ -184,9 +172,10 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then ulimit -s 81920 (cd test && python -c "import torch; print(torch.__version__, torch.version.git_version)") - echo "The next three invocations are expected to crash; if they don't that means ASAN/UBSAN is misconfigured" + echo "The next four invocations are expected to crash; if they don't that means ASAN/UBSAN is misconfigured" (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_asan(3)") (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_ubsan(0)") + (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_vptr_ubsan()") (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)") fi @@ -227,6 +216,7 @@ test_dynamo_shard() { echo "NUM_TEST_SHARDS must be defined to run a Python test shard" exit 1 fi + python tools/dynamo/verify_dynamo.py # Temporarily disable test_fx for dynamo pending the investigation on TTS # regression in https://github.com/pytorch/torchdynamo/issues/784 time python test/run_test.py \ @@ -247,34 +237,63 @@ test_dynamo_shard() { test_python_dispatch \ test_fx \ test_package \ - test_vmap \ + test_legacy_vmap \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose assert_git_not_dirty } +test_inductor_distributed() { + # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported + # with if required # gpus aren't available + PYTORCH_TEST_WITH_INDUCTOR=0 PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed --verbose + assert_git_not_dirty +} test_inductor() { - echo "TODO: enable inductor unit tests" - # time python test/run_test.py --core --exclude test_autograd --continue-through-error --verbose + python tools/dynamo/verify_dynamo.py + python test/run_test.py --include test_modules test_ops test_ops_gradients --verbose + PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include inductor/test_torchinductor --include inductor/test_torchinductor_opinfo --verbose +} - # PYTORCH_TEST_WITH_DYNAMO and PYTORCH_TEST_WITH_INDUCTOR are only needed for PyTorch tests not written with - # using dynamo/inductor. For dynamo/inductor unit tests, specifiying them will trigger an error like - # "Detected two calls to `torchdynamo.optimize(...)` with a different backend compiler arguments." - # PYTORCH_TEST_WITH_DYNAMO=0 PYTORCH_TEST_WITH_INDUCTOR=0 pytest test/inductor +test_inductor_benchmark() { + # Use test-reports directory under test folder will allow the CI to automatically pick up + # the test reports and upload them to S3. Need to use full path here otherwise the script + # will bark about file not found later on + TEST_REPORTS_DIR=$(pwd)/test/test-reports + PARTITION_FLAGS="" + if [[ -n "$NUM_TEST_SHARDS" && -n "$2" ]]; then + PARTITION_FLAGS="--total-partitions 2 --partition-id $2" + fi + mkdir -p "$TEST_REPORTS_DIR" + # Check inference with --float32 + # shellcheck disable=SC2086 + python benchmarks/dynamo/$1.py --ci --accuracy \ + --device cuda --inductor --float32 $PARTITION_FLAGS --output "$TEST_REPORTS_DIR"/inductor_inference_$1.csv + # shellcheck disable=SC2086 + python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_inference_$1.csv + # Check training with --amp + # shellcheck disable=SC2086 + python benchmarks/dynamo/$1.py --ci --training --accuracy \ + --device cuda --inductor --amp $PARTITION_FLAGS --output "$TEST_REPORTS_DIR"/inductor_training_$1.csv + # shellcheck disable=SC2086 + python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_training_$1.csv } -test_inductor_huggingface_shard() { +test_inductor_huggingface() { + test_inductor_benchmark huggingface +} + +test_inductor_timm_shard() { if [[ -z "$NUM_TEST_SHARDS" ]]; then echo "NUM_TEST_SHARDS must be defined to run a Python test shard" exit 1 fi - TEST_REPORTS_DIR=/tmp/test-reports - mkdir -p "$TEST_REPORTS_DIR" - python benchmarks/dynamo/huggingface.py --ci --training --accuracy \ - --device cuda --inductor --float32 --total-partitions 1 --partition-id "$1" \ - --output "$TEST_REPORTS_DIR"/inductor_huggingface_"$1".csv - python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_huggingface_"$1".csv + test_inductor_benchmark timm_models "$1" +} + +test_inductor_torchbench() { + PYTHONPATH=$(pwd)/torchbench test_inductor_benchmark torchbench } test_python_gloo_with_tls() { @@ -393,12 +412,9 @@ test_libtorch() { OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml - # TODO: this condition is never (BUILD_ENVIRONMENT doesn't start with pytorch-), need to fix this. - if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-py3* ]]; then - if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* ]]; then - # TODO: Consider to run static_runtime_test from $TORCH_BIN_DIR (may need modify build script) - "$BUILD_BIN_DIR"/static_runtime_test --gtest_output=xml:$TEST_REPORTS_DIR/static_runtime_test.xml - fi + if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* ]]; then + # TODO: Consider to run static_runtime_test from $TORCH_BIN_DIR (may need modify build script) + "$BUILD_BIN_DIR"/static_runtime_test --gtest_output=xml:$TEST_REPORTS_DIR/static_runtime_test.xml fi assert_git_not_dirty fi @@ -708,6 +724,8 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then # TODO: run some C++ tests echo "no-op at the moment" elif [[ "$TEST_CONFIG" == distributed ]]; then + install_filelock + install_triton test_distributed # Only run RPC C++ tests on the first shard if [[ "${SHARD_NUMBER}" == 1 ]]; then @@ -716,6 +734,11 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then elif [[ "$TEST_CONFIG" == deploy ]]; then checkout_install_torchdeploy test_torch_deploy +elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then + install_filelock + install_triton + install_huggingface + test_inductor_distributed elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then test_without_numpy install_torchvision @@ -727,30 +750,41 @@ elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHAR install_filelock install_triton test_dynamo_shard 2 -elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then install_torchvision install_filelock install_triton - test_inductor -elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then + install_huggingface + test_inductor_huggingface +elif [[ "${TEST_CONFIG}" == *inductor_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then install_torchvision install_filelock install_triton - install_huggingface - test_inductor_huggingface_shard 0 + install_timm + id=$((SHARD_NUMBER-1)) + test_inductor_timm_shard $id +elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then + install_torchtext + install_torchvision + install_filelock + install_triton + checkout_install_torchbench + test_inductor_torchbench +elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 ]]; then + install_torchvision + install_filelock + install_triton + test_inductor + test_inductor_distributed elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then test_without_numpy install_torchvision - if ! [[ "${BUILD_ENVIRONMENT}" == *sm86 ]]; then - install_triton - fi + install_triton test_python_shard 1 test_aten elif [[ "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then install_torchvision - if ! [[ "${BUILD_ENVIRONMENT}" == *sm86 ]]; then - install_triton - fi + install_triton test_python_shard 2 test_libtorch test_aot_compilation @@ -759,9 +793,7 @@ elif [[ "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then test_torch_function_benchmark elif [[ "${SHARD_NUMBER}" -gt 2 ]]; then # Handle arbitrary number of shards - if ! [[ "${BUILD_ENVIRONMENT}" == *sm86 ]]; then - install_triton - fi + install_triton test_python_shard "$SHARD_NUMBER" elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then test_vulkan @@ -779,9 +811,7 @@ elif [[ "${TEST_CONFIG}" == *functorch* ]]; then test_functorch else install_torchvision - if ! [[ "${BUILD_ENVIRONMENT}" == *sm86 ]]; then - install_triton - fi + install_triton install_monkeytype test_python test_aten diff --git a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat index b85dad0616cd7..da28956cae971 100644 --- a/.jenkins/pytorch/win-test-helpers/build_pytorch.bat +++ b/.jenkins/pytorch/win-test-helpers/build_pytorch.bat @@ -135,16 +135,17 @@ if "%REBUILD%" == "" ( if not errorlevel 0 exit /b ) ) -:: tests if BUILD_ENVIRONMENT contains cuda11 as a substring -if not x%BUILD_ENVIRONMENT:cuda11=%==x%BUILD_ENVIRONMENT% ( - set BUILD_SPLIT_CUDA=ON -) -python setup.py bdist_wheel && sccache --show-stats && python -c "import os, glob; os.system('python -mpip install ' + glob.glob('dist/*.whl')[0] + '[opt-einsum]')" ( +python setup.py bdist_wheel +if errorlevel 1 exit /b +if not errorlevel 0 exit /b +sccache --show-stats +python -c "import os, glob; os.system('python -mpip install ' + glob.glob('dist/*.whl')[0] + '[opt-einsum]')" +( if "%BUILD_ENVIRONMENT%"=="" ( echo NOTE: To run `import torch`, please make sure to activate the conda environment by running `call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3` in Command Prompt before running Git Bash. ) else ( - 7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\" + 7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torchgen %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\functorch && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\" if errorlevel 1 exit /b if not errorlevel 0 exit /b diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat index e6660a17b3890..0552d85a407a5 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat @@ -13,7 +13,7 @@ if not exist %CONDA_PARENT_DIR%\Miniconda3 ( ) if "%INSTALL_FRESH_CONDA%"=="1" ( - curl --retry 3 -k https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe --output %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe + curl --retry 3 --retry-all-errors -k https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe --output %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe if errorlevel 1 exit /b if not errorlevel 0 exit /b diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat index d9f3ab1cf8211..d0fbf5b20d888 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat @@ -24,7 +24,7 @@ if "%CUDA_SUFFIX%" == "" ( if "%REBUILD%"=="" ( if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z ) else ( aws s3 cp s3://ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --quiet ) diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat index c700a04a1e4af..6c676d1baeded 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat @@ -1,6 +1,6 @@ if "%REBUILD%"=="" ( if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/mkl_2020.2.254.7z --output %TMP_DIR_WIN%\mkl.7z + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/mkl_2020.2.254.7z --output %TMP_DIR_WIN%\mkl.7z ) else ( aws s3 cp s3://ossci-windows/mkl_2020.2.254.7z %TMP_DIR_WIN%\mkl.7z --quiet ) diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat index 0165604400ddc..6f8cc15ba8684 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat @@ -7,8 +7,8 @@ if "%REBUILD%"=="" ( del %TMP_DIR_WIN%\bin\sccache.exe || ver > nul del %TMP_DIR_WIN%\bin\sccache-cl.exe || ver > nul if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/sccache.exe --output %TMP_DIR_WIN%\bin\sccache.exe - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/sccache-cl.exe --output %TMP_DIR_WIN%\bin\sccache-cl.exe + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/sccache.exe --output %TMP_DIR_WIN%\bin\sccache.exe + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/sccache-cl.exe --output %TMP_DIR_WIN%\bin\sccache-cl.exe ) else ( aws s3 cp s3://ossci-windows/sccache.exe %TMP_DIR_WIN%\bin\sccache.exe aws s3 cp s3://ossci-windows/sccache-cl.exe %TMP_DIR_WIN%\bin\sccache-cl.exe diff --git a/.jenkins/pytorch/win-test.sh b/.jenkins/pytorch/win-test.sh index dc28521204878..560b039dbf679 100755 --- a/.jenkins/pytorch/win-test.sh +++ b/.jenkins/pytorch/win-test.sh @@ -39,10 +39,6 @@ fi export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers -if [[ "${BUILD_ENVIRONMENT}" == *cuda11* ]]; then - export BUILD_SPLIT_CUDA=ON -fi - if [[ "$TEST_CONFIG" = "force_on_cpu" ]]; then # run the full test suite for force_on_cpu test export USE_CUDA=0 diff --git a/.lintrunner.toml b/.lintrunner.toml index 56ecfc7295f4c..10756f22a8e75 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -101,13 +101,22 @@ exclude_patterns = [ 'torch/csrc/**', 'torch/_dynamo/**/*.py', 'torch/_inductor/**/*.py', + 'torch/_functorch/aot_autograd.py', + 'torch/_functorch/benchmark_utils.py', + 'torch/_functorch/compile_utils.py', + 'torch/_functorch/compilers.py', + 'torch/_functorch/eager_transforms.py', + 'torch/_functorch/fx_minifier.py', + 'torch/_functorch/partitioners.py', + 'torch/_functorch/make_functional.py', + 'torch/_functorch/top_operators_github_usage.py', + 'torch/_functorch/vmap.py', 'torch/distributed/elastic/agent/server/api.py', 'torch/testing/_internal/**', 'torch/distributed/fsdp/fully_sharded_data_parallel.py', 'torch/distributed/distributed_c10d.py', # TODO(suo): these exclusions were added just to get lint clean on master. # Follow up to do more target suppressions and remove them. - 'torch/distributed/fsdp/flatten_params_wrapper.py', 'torch/ao/quantization/fx/convert.py', 'torch/ao/quantization/_dbr/function_fusion.py', 'test/test_datapipe.py', @@ -141,6 +150,32 @@ init_command = [ 'pyyaml==6.0', ] +[[linter]] +code = 'MYPYNOFOLLOW' +include_patterns = [ + 'torch/_dynamo/eval_frame.py', + 'torch/_dynamo/convert_frame.py', + 'torch/_dynamo/symbolic_convert.py', + 'torch/_dynamo/types.py', + 'torch/_dynamo/output_graph.py', + 'torch/_dynamo/guards.py', + 'torch/_dynamo/side_effects.py', + 'torch/_dynamo/optimizations/__init__.py', + 'torch/_dynamo/optimizations/backends.py', + 'torch/_dynamo/optimizations/training.py', + 'torch/_C/_dynamo/**/*.py', +] +exclude_patterns = [ +] +command = [ + 'python3', + 'tools/linter/adapters/mypy_linter.py', + '--config=mypy-nofollow.ini', + '--code=MYPYNOFOLLOW', + '--', + '@{{PATHSFILE}}' +] + [[linter]] code = 'MYPYSTRICT' include_patterns = [ @@ -156,6 +191,7 @@ include_patterns = [ exclude_patterns = [ # (linbinyu) copied from internal repo 'tools/code_analyzer/gen_operators_yaml.py', + 'tools/dynamo/verify_dynamo.py', 'tools/gen_vulkan_spv.py', 'tools/test/gen_operators_yaml_test.py', 'tools/test/gen_oplist_test.py', @@ -165,6 +201,7 @@ command = [ 'python3', 'tools/linter/adapters/mypy_linter.py', '--config=mypy-strict.ini', + '--code=MYPYSTRICT', '--', '@{{PATHSFILE}}' ] @@ -371,6 +408,7 @@ include_patterns = [ exclude_patterns = [ 'aten/src/ATen/native/quantized/cpu/qnnpack/**', 'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h', + 'aten/src/ATen/native/vulkan/glsl/**', 'torch/csrc/jit/serialization/mobile_bytecode_generated.h', ] command = [ @@ -420,6 +458,35 @@ command = [ '@{{PATHSFILE}}' ] +[[linter]] +code = 'ERROR_PRONE_ISINSTANCE' +include_patterns = [ + 'torch/_refs/**/*.py', + 'torch/_prims/**/*.py', + 'torch/_prims_common/**/*.py', + 'torch/_decomp/**/*.py', + 'torch/_meta_registrations.py', +] +command = [ + 'python3', + 'tools/linter/adapters/grep_linter.py', + '--pattern=isinstance\([^)]+(int|float)\)', + '--linter-name=ERROR_PRONE_ISINSTANCE', + '--error-name=error prone isinstance', + """--error-description=\ + This line has an isinstance call that directly refers to \ + int or float. This is error-prone because you may also \ + have wanted to allow SymInt or SymFloat in your test. \ + To suppress this lint, use an appropriate type alias defined \ + in torch._prims_common; use IntLike/FloatLike when you would accept \ + both regular and symbolic numbers, Dim for ints representing \ + dimensions, or IntWithoutSymInt/FloatWithoutSymFloat if you really \ + meant to exclude symbolic numbers. + """, + '--', + '@{{PATHSFILE}}' +] + [[linter]] code = 'PYBIND11_SPECIALIZATION' include_patterns = [ @@ -726,6 +793,7 @@ include_patterns = [ 'torchgen/**/*.py', 'functorch/functorch/_src/aot_autograd.py', 'functorch/functorch/_src/compilers.py', + 'torch/testing/*.py', ] command = [ 'python3', diff --git a/BUILD.bazel b/BUILD.bazel index 172a31723a0bf..938630e2e2bd2 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1648,6 +1648,17 @@ cu_library( deps = [":torch_headers"], ) +torch_sources = ({ + k: "" for k in ( + libtorch_core_sources + + libtorch_distributed_sources + + torch_cpp_srcs + + libtorch_extra_sources + + jit_core_sources + + lazy_tensor_ts_sources + + GENERATED_AUTOGRAD_CPP) +}).keys() + cc_library( name = "torch", srcs = if_cuda(glob( @@ -1657,11 +1668,7 @@ cc_library( "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], - )) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + lazy_tensor_ts_sources + GENERATED_AUTOGRAD_CPP + [ - "torch/csrc/jit/serialization/flatbuffer_serializer.cpp", - "torch/csrc/jit/mobile/flatbuffer_loader.cpp", - "torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp", - ], + )) + torch_sources, copts = TORCH_COPTS, defines = [ "CAFFE2_NIGHTLY_VERSION=20200115", diff --git a/CMakeLists.txt b/CMakeLists.txt index dae1dd4bc14fb..003fe7fa3d1b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.13 FATAL_ERROR) +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) #cmake_policy(SET CMP0022 NEW) #cmake_policy(SET CMP0023 NEW) @@ -11,13 +11,9 @@ cmake_policy(SET CMP0025 NEW) # Suppress warning flags in default MSVC configuration. It's not # mandatory that we do this (and we don't if cmake is old), but it's # nice when it's possible, and it's possible on our Windows configs. -if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) - cmake_policy(SET CMP0092 NEW) -endif() +cmake_policy(SET CMP0092 NEW) -if(NOT CMAKE_VERSION VERSION_LESS 3.10) - set(FIND_CUDA_MODULE_DEPRECATED ON) -endif() +set(FIND_CUDA_MODULE_DEPRECATED ON) # ---[ Project and semantic versioning. project(Torch CXX C) @@ -35,9 +31,9 @@ string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) if(env_cxx_standard GREATER -1) message( WARNING "C++ standard version definition detected in environment variable." - "PyTorch requires -std=c++14. Please remove -std=c++ settings in your environment.") + "PyTorch requires -std=c++17. Please remove -std=c++ settings in your environment.") endif() -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_C_STANDARD 11 CACHE STRING "The C standard whose features are requested to build this target.") if(DEFINED GLIBCXX_USE_CXX11_ABI) @@ -184,20 +180,13 @@ cmake_dependent_option( "BUILD_TEST" OFF) option(USE_CPP_CODE_COVERAGE "Compile C/C++ with code coverage flags" OFF) option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) -option(USE_ASAN "Use Address Sanitizer" OFF) +option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) option(USE_CUDA "Use CUDA" ON) -# BUILD_SPLIT_CUDA must also be exported as an environment variable before building, with -# `export BUILD_SPLIT_CUDA=1` because cpp_extension.py can only work properly if this variable -# also exists in the environment. -# This option is incompatible with CUDA_SEPARABLE_COMPILATION. -cmake_dependent_option( - BUILD_SPLIT_CUDA "Split torch_cuda library into torch_cuda_cu and torch_cuda_cpp" OFF - "USE_CUDA AND NOT CUDA_SEPARABLE_COMPILATION" OFF) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) option(USE_FAST_NVCC "Use parallel NVCC build" OFF) -option(USE_ROCM "Use ROCm" ON) +cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) cmake_dependent_option( USE_CUDNN "Use cuDNN" ON @@ -292,6 +281,7 @@ if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER}) endif() option(USE_ZMQ "Use ZMQ" OFF) option(USE_ZSTD "Use ZSTD" OFF) +option(TORCH_DISABLE_GPU_ASSERTS "Disable GPU asserts by default" OFF) # Ensure that an ITT build is the default for x86 CPUs cmake_dependent_option( USE_ITT "Use Intel(R) VTune Profiler ITT functionality" ON @@ -552,6 +542,9 @@ if(MSVC) # Try harder string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /w -w") + + string(APPEND CMAKE_CXX_FLAGS " /FS") + string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /FS") endif(MSVC) string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin -compress-all") @@ -818,7 +811,6 @@ endif() # ---[ Build flags if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -O2 -fPIC") - string(APPEND CMAKE_CXX_FLAGS " -Wno-narrowing") # Eigen fails to build with some versions, so convert this to a warning # Details at http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1459 string(APPEND CMAKE_CXX_FLAGS " -Wall") @@ -827,6 +819,7 @@ if(NOT MSVC) append_cxx_flag_if_supported("-Werror=non-virtual-dtor" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=braced-scalar-init" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=range-loop-construct" CMAKE_CXX_FLAGS) + append_cxx_flag_if_supported("-Wnarrowing" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-missing-field-initializers" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-type-limits" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-array-bounds" CMAKE_CXX_FLAGS) @@ -891,7 +884,6 @@ if(NOT MSVC) append_cxx_flag_if_supported("-Wno-unused-private-field" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-inconsistent-missing-override" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-aligned-allocation-unavailable" CMAKE_CXX_FLAGS) - append_cxx_flag_if_supported("-Wno-c++14-extensions" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-constexpr-not-const" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wunused-lambda-capture" CMAKE_CXX_FLAGS) @@ -932,8 +924,8 @@ if(NOT MSVC) endif() if(USE_ASAN) - string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fsanitize=address") - string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fsanitize=address") + string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fsanitize=address -fsanitize=undefined") + string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fsanitize=address -fsanitize=undefined") endif() if(USE_TSAN) @@ -996,7 +988,6 @@ if(APPLE) endif() append_cxx_flag_if_supported("-Wno-unused-private-field" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS) - append_cxx_flag_if_supported("-Wno-c++14-extensions" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-constexpr-not-const" CMAKE_CXX_FLAGS) endif() diff --git a/CODEOWNERS b/CODEOWNERS index 3bddc2f0373e4..c5699b137a5ed 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -15,7 +15,7 @@ /torch/autograd/ @albanD @soulitzer /tools/autograd/ @albanD @soulitzer /torch/nn/ @albanD @jbschlosser -/torch/optim/ @albanD +/torch/optim/ @albanD @janeyx99 /test/test_public_bindings.py @albanD /test/allowlist_for_publicAPI.json @albanD @anjali411 /docs/source/conf.py @albanD @@ -25,8 +25,8 @@ /aten/src/ATen/native/ao_sparse @z-a-f @salilsdesai @kimishpatel @digantdesai @jianyuh /aten/src/ATen/native/quantized @jerryzh168 @z-a-f @salilsdesai @kimishpatel @digantdesai @jianyuh /aten/src/ATen/native/quantized/cpu @jerryzh168 @z-a-f @salilsdesai @kimishpatel @digantdesai @jianyuh -/aten/src/ATen/native/quantized/cuda @jerryzh168 @dzdang -/aten/src/ATen/native/quantized/cudnn @jerryzh168 @dzdang +/aten/src/ATen/native/quantized/cuda @jerryzh168 +/aten/src/ATen/native/quantized/cudnn @jerryzh168 /test/test_quantization.py @jerryzh168 /test/ao/ @jerryzh168 @z-a-f @hdcharles /test/quantization/ @jerryzh168 @z-a-f @@ -39,21 +39,22 @@ nn/quantizable/ @jerryzh168 @z-a-f nn/qat/ @jerryzh168 # Tensorpipe RPC Agent. -/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @jiayisuse @osalpekar @lw @beauby -/torch/csrc/distributed/rpc/tensorpipe_agent.h @jiayisuse @osalpekar @lw @beauby +/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @jiayisuse @osalpekar @lw +/torch/csrc/distributed/rpc/tensorpipe_agent.h @jiayisuse @osalpekar @lw # Distributed package # This list is mostly if you'd like to be tagged as reviewer, feel free to add # or remove yourself from it. -/torch/csrc/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 @H-Huang @awgu @kwen2501 -/torch/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 @H-Huang @awgu @kwen2501 -/torch/nn/parallel/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088 @H-Huang @awgu @kwen2501 +/torch/csrc/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol +/torch/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol +/torch/distributed/_composable @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @H-Huang @awgu @kwen2501 @yhcharles +/torch/nn/parallel/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol # Distributed tests # This list is mostly if you'd like to be tagged as reviewer, feel free to add # or remove yourself from it. -/test/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 -/torch/testing/_internal/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 +/test/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol +/torch/testing/_internal/distributed @mrshenli @pritamdamania87 @zhaojuanmao @rohan-varma @H-Huang @awgu @kwen2501 @wanchaol # ONNX Export /torch/csrc/jit/passes/onnx.h @bowenbao @abock @@ -90,12 +91,19 @@ nn/qat/ @jerryzh168 /torch/testing/_internal/common_methods_invocations.py @mruberry @ngimel /torch/testing/_internal/common_device_type.py @mruberry @ngimel test/test_ops.py @mruberry @ngimel -test/test_ops_gradients.py @mruberry @ngimel +test/test_ops_gradients.py @mruberry @ngimel @soulitzer +test/test_ops_fwd_gradients.py @mruberry @ngimel @soulitzer test/test_unary_ufuncs.py @mruberry @ngimel test/test_binary_ufuncs.py @mruberry @ngimel test/test_reductions.py @mruberry @ngimel test/test_type_promotion.py @mruberry @ngimel +# functorch-related things +# This list is for people wanting to be notified every time there's a change +# Useful for e.g. auditing xfails that other folks add to tests +test/functorch/test_ops.py @zou3519 +test/functorch/test_vmap.py @zou3519 + # torch MPS test/test_mps.py @kulinseth aten/src/ATen/mps/ @kulinseth @@ -106,3 +114,6 @@ torch/csrc/autograd/profiler* @robieta torch/autograd/profiler* @robieta torch/csrc/profiler/ @robieta torch/profiler/ @robieta + +# AOTDispatch tests +test/functorch/test_aotdispatch.py @ezyang @Chillee diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 05e98c3b9a673..eaf81b19eefaf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -118,21 +118,9 @@ git submodule sync --recursive git submodule update --init --recursive --jobs 0 ``` -If you want to have no-op incremental rebuilds (which are fast), see the section below titled "Make no-op build fast." +If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below. -3. Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source), except when it's time to install PyTorch instead of invoking `setup.py install` you'll want to call `setup.py develop` instead: - -Specifically, the change you have to make is to replace - -```bash -python setup.py install -``` - -with - -```bash -python setup.py develop -``` +3. Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source), but instead of installing PyTorch via `python setup.py install`, use `python setup.py develop`. This mode will symlink the Python files from the current local source tree into the Python install. This way when you modify a Python file, you @@ -1290,8 +1278,9 @@ our [CI wiki](https://github.com/pytorch/pytorch/wiki/Debugging-using-with-ssh-f ### Which commit is used in CI? For CI run on `master`, this repository is checked out for a given `master` -commit, and CI is run on that commit (there isn't really any other choice). For -PRs, however, it's a bit more complicated. Consider this commit graph, where +commit, and CI is run on that commit (there isn't really any other choice). + +For PRs, however, it's a bit more complicated. Consider this commit graph, where `master` is at commit `A`, and the branch for PR #42 (just a placeholder) is at commit `B`: @@ -1300,7 +1289,7 @@ commit `B`: / \ / C (refs/pull/42/merge) / / ----o---o---o---A (refs/heads/master) +---o---o---o---A (merge-destination) - usually master ``` There are two possible choices for which commit to use: @@ -1308,37 +1297,18 @@ There are two possible choices for which commit to use: 1. Checkout commit `B`, the head of the PR (manually committed by the PR author). 2. Checkout commit `C`, the hypothetical result of what would happen if the PR - were merged into `master` (automatically generated by GitHub). - -This choice depends on several factors; here is the decision tree as of -2021-03-30: - -- For CI jobs on CircleCI: - - If the name of the job (or one of its ancestors in the workflow DAG) - contains "xla" or "gcc5", choice **2** is used. This includes the following - jobs: - - pytorch_linux_xenial_py3_6_gcc5_4_build - - pytorch_cpp_doc_build - - pytorch_doc_test - - pytorch_linux_forward_backward_compatibility_check_test - - pytorch_linux_xenial_py3_6_gcc5_4_jit_legacy_test - - pytorch_linux_xenial_py3_6_gcc5_4_test - - pytorch_python_doc_build - - pytorch_xla_linux_bionic_py3_6_clang9_build - - pytorch_xla_linux_bionic_py3_6_clang9_test - - Otherwise, choice **1** is used. -- For CI jobs on GitHub Actions: - - If the PR was created using [`ghstack`](https://github.com/ezyang/ghstack), - choice **1** is used. - - Otherwise, choice **2** is used. - -This is important to be aware of, because if you see a CI failure on your PR and -choice **2** is being used for that CI job, it is possible that the failure is -nondeterministically caused by a commit that does not exist in the ancestry of -your PR branch. If you happen to have write access to this repo, you can choose -to use `ghstack` to eliminate this nondeterminism for GitHub Actions jobs on -your PRs, but it will still be present for the select CircleCI jobs listed -above. + were merged into it's destination (usually `master`). + +For all practical purposes, most people can think of the commit being used as +commit `B` (choice **1**). + +However, if workflow files (which govern CI behavior) were modified (either by your PR or since dev branch were created ) there's +a nuance to know about: +The workflow files themselves get taken from checkpoint `C`, the merger of your +PR and the `master` branch. But only the workflow files get taken from that merged +checkpoint. Everything else (tests, code, etc) all get taken directly from your +PR's commit (commit `B`). Please note, this scenario would never affect PRs authored by `ghstack` as they would not automatically ingest the updates from default branch. + ## Dev Infra Office Hours [Dev Infra Office Hours](https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours) are hosted every Friday to answer any questions regarding developer experience, Green HUD, and CI. diff --git a/Dockerfile b/Dockerfile index 815a9108ce946..e125271607c93 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,20 +59,23 @@ RUN --mount=type=cache,target=/opt/ccache \ FROM conda as conda-installs ARG PYTHON_VERSION=3.8 -ARG CUDA_VERSION=11.3 +ARG CUDA_VERSION=11.6 ARG CUDA_CHANNEL=nvidia ARG INSTALL_CHANNEL=pytorch-nightly -ENV CONDA_OVERRIDE_CUDA=${CUDA_VERSION} # Automatically set by buildx +RUN /opt/conda/bin/conda update -y conda RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION} ARG TARGETPLATFORM +ARG TRITON_VERSION + # On arm64 we can only install wheel packages RUN case ${TARGETPLATFORM} in \ "linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchtext ;; \ - *) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchtext "cudatoolkit=${CUDA_VERSION}" ;; \ + *) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ esac && \ /opt/conda/bin/conda clean -ya RUN /opt/conda/bin/pip install torchelastic +RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then /opt/conda/bin/pip install "torchtriton==${TRITON_VERSION}" --extra-index-url https://download.pytorch.org/whl/nightly/cpu ; fi FROM ${BASE_IMAGE} as official ARG PYTORCH_VERSION diff --git a/MANIFEST.in b/MANIFEST.in index 403b90b702df2..f6ffb4e02a8af 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,6 @@ include MANIFEST.in include CMakeLists.txt -include CITATION +include CITATION.cff include LICENSE include NOTICE include .gitmodules diff --git a/Makefile b/Makefile index 21745f42a8873..45dfeb8cda267 100644 --- a/Makefile +++ b/Makefile @@ -31,3 +31,7 @@ lint: quicklint: lintrunner + +triton: + $(PIP) uninstall -y triton + $(PIP) install -U "git+https://github.com/openai/triton@$(shell cat .github/ci_commit_pins/triton.txt)#subdirectory=python" diff --git a/README.md b/README.md index 3a80c8083a499..bcce2997b25b6 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ python tools/amd_build/build_amd.py Install PyTorch ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -python setup.py install +python setup.py develop ``` Note that if you are using [Anaconda](https://www.anaconda.com/distribution/#download-section), you may experience an error caused by the linker: @@ -251,7 +251,7 @@ This is caused by `ld` from the Conda environment shadowing the system `ld`. You ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py develop ``` **On Windows** @@ -274,7 +274,7 @@ In this mode PyTorch computations will run on your CPU, not your GPU ```cmd conda activate -python setup.py install +python setup.py develop ``` Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/master/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used. @@ -315,7 +315,7 @@ for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\ :: [Optional] If you want to override the CUDA host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe -python setup.py install +python setup.py develop ``` diff --git a/RELEASE.md b/RELEASE.md index e2b69b5bf82ee..22d279bceec3a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -21,6 +21,7 @@ - [Patch Release Criteria](#patch-release-criteria) - [Patch Release Process](#patch-release-process) - [Triage](#triage) + - [Issue Tracker for Patch releases](#issue-tracker-for-patch-releases) - [Building a release schedule / cherry picking](#building-a-release-schedule--cherry-picking) - [Building Binaries / Promotion to Stable](#building-binaries--promotion-to-stable) - [Hardware / Software Support in Binary Build Matrix](#hardware--software-support-in-binary-build-matrix) @@ -234,6 +235,18 @@ Patch releases should be considered if a regression meets the following criteria 3. Triage reviewers will then add the issue / pull request to the related milestone (i.e. `1.9.1`) if the regressions if found to be within the [Patch Release Criteria](#patch-release-criteria) * ![adding to milestone](https://user-images.githubusercontent.com/1700823/131175980-148ff38d-44c3-4611-8a1f-cd2fd1f4c49d.png) +### Issue Tracker for Patch releases + +For patch releases issue tracker needs to be created. For patch release, we require all cherry-pick changes to have links to either a high-priority Github issue or a CI failure from previous RC. An example of this would look like: +* https://github.com/pytorch/pytorch/issues/51886 + +Only following issues are accepted: +1. Fixes to regressions against previous major version (e.g. regressions introduced in 1.13.0 from 1.12.0 are pickable for 1.13.1) +2. Critical fixes for: silent correctness, backwards compatibility, crashes, deadlocks, (large) memory leaks +3. Fixes to new features being introduced in this release +4. Documentation improvements +5. Release branch specific changes (e.g. blocking ci fixes, change version identifiers) + ### Building a release schedule / cherry picking > Main POC: Patch Release Managers @@ -281,7 +294,7 @@ need to support these particular versions of software. In the event a submodule cannot be fast forwarded and a patch must be applied we can take two different approaches: -* (preferred) Fork the said repository under the pytorch Github organization, apply the patches we need there, and then switch our submodule to accept our fork. +* (preferred) Fork the said repository under the pytorch GitHub organization, apply the patches we need there, and then switch our submodule to accept our fork. * Get the dependencies maintainers to support a release branch for us Editing submodule remotes can be easily done with: (running from the root of the git repository) diff --git a/android/gradle.properties b/android/gradle.properties index ecefc09a587ba..25695a1762f63 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,6 +1,6 @@ ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 -VERSION_NAME=1.14.0-SNAPSHOT +VERSION_NAME=2.0.0-SNAPSHOT GROUP=org.pytorch MAVEN_GROUP=org.pytorch SONATYPE_STAGING_PROFILE=orgpytorch diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index ad2647c2f4df6..9691d5694c441 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -14,7 +14,7 @@ endif() include(GNUInstallDirs) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) message(STATUS "ANDROID_STL:${ANDROID_STL}") diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp index 1b0d54784d76f..6ef4f462df169 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -195,14 +195,16 @@ class PytorchJni : public facebook::jni::HybridClass { std::vector inputs{}; size_t n = jinputs->size(); inputs.reserve(n); + const bool requires_backend_transfers = + module_.attr("requires_backend_transfers", at::IValue(true)).toBool(); for (size_t i = 0; i < n; i++) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - if (at::kVulkan == deviceType_) { + if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back( atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} : std::move(atIValue)); } else { - TORCH_CHECK(at::kCPU == deviceType_); + TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers); inputs.push_back(std::move(atIValue)); } } @@ -223,14 +225,16 @@ class PytorchJni : public facebook::jni::HybridClass { std::vector inputs{}; size_t n = jinputs->size(); inputs.reserve(n); + const bool requires_backend_transfers = + module_.attr("requires_backend_transfers", at::IValue(true)).toBool(); for (size_t i = 0; i < n; i++) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - if (at::kVulkan == deviceType_) { + if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back( atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} : std::move(atIValue)); } else { - TORCH_CHECK(at::kCPU == deviceType_); + TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers); inputs.push_back(std::move(atIValue)); } } diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp index 86fd1e2260f9c..802bb801a1f9c 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_lite.cpp @@ -158,14 +158,16 @@ class PytorchJni : public facebook::jni::HybridClass { std::vector inputs{}; size_t n = jinputs->size(); inputs.reserve(n); + const bool requires_backend_transfers = + module_.attr("requires_backend_transfers", at::IValue(true)).toBool(); for (const auto i : c10::irange(n)) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - if (at::kVulkan == deviceType_) { + if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back( atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} : std::move(atIValue)); } else { - TORCH_CHECK(at::kCPU == deviceType_); + TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers); inputs.push_back(std::move(atIValue)); } } @@ -187,14 +189,16 @@ class PytorchJni : public facebook::jni::HybridClass { std::vector inputs{}; size_t n = jinputs->size(); inputs.reserve(n); + const bool requires_backend_transfers = + module_.attr("requires_backend_transfers", at::IValue(true)).toBool(); for (const auto i : c10::irange(n)) { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); - if (at::kVulkan == deviceType_) { + if (at::kVulkan == deviceType_ && requires_backend_transfers) { inputs.push_back( atIValue.isTensor() ? at::IValue{atIValue.toTensor().vulkan()} : std::move(atIValue)); } else { - TORCH_CHECK(at::kCPU == deviceType_); + TORCH_CHECK(at::kCPU == deviceType_ || !requires_backend_transfers); inputs.push_back(std::move(atIValue)); } } diff --git a/android/pytorch_android_torchvision/CMakeLists.txt b/android/pytorch_android_torchvision/CMakeLists.txt index 08de7cebde491..849e4d07cc1d5 100644 --- a/android/pytorch_android_torchvision/CMakeLists.txt +++ b/android/pytorch_android_torchvision/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.4.1) project(pytorch_vision_jni CXX) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) diff --git a/android/test_app/app/CMakeLists.txt b/android/test_app/app/CMakeLists.txt index 457ccbe189bd7..cfdc4976ef48d 100644 --- a/android/test_app/app/CMakeLists.txt +++ b/android/test_app/app/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.4.1) set(PROJECT_NAME pytorch_testapp_jni) project(${PROJECT_NAME} CXX) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) set(build_DIR ${CMAKE_SOURCE_DIR}/build) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 629db87dc15d3..613c6a6834e33 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -56,7 +56,7 @@ if(NOT BUILD_CAFFE2 AND NOT BUILD_LITE_INTERPRETER) EXCLUDE(ATen_CORE_TEST_SRCS "${ATen_CORE_TEST_SRCS}" ${ATen_CORE_EXCLUDED_TEST_SRCS}) endif() -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index d7dce2561d4f9..5fd06c442750d 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -127,8 +127,8 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { using detail::CPUGeneratorImplState; using detail::CPUGeneratorImplStateLegacy; - static_assert(std::is_pod::value, "CPUGeneratorImplStateLegacy is not a PODType"); - static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType"); + static_assert(std::is_standard_layout::value, "CPUGeneratorImplStateLegacy is not a PODType"); + static_assert(std::is_standard_layout::value, "CPUGeneratorImplState is not a PODType"); static const size_t size_legacy = sizeof(CPUGeneratorImplStateLegacy); static const size_t size_current = sizeof(CPUGeneratorImplState); @@ -207,7 +207,7 @@ c10::intrusive_ptr CPUGeneratorImpl::get_state() const { using detail::CPUGeneratorImplState; static const size_t size = sizeof(CPUGeneratorImplState); - static_assert(std::is_pod::value, "CPUGeneratorImplState is not a PODType"); + static_assert(std::is_standard_layout::value, "CPUGeneratorImplState is not a PODType"); auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); auto rng_state = state_tensor.data_ptr(); diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index c96b36975214e..256a4bd9e5fdf 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -21,10 +21,6 @@ #include #endif // USE_FBGEMM -#ifdef USE_MPS -#include -#endif - namespace at { Context::Context() = default; @@ -112,6 +108,14 @@ void Context::setSDPUseFlash(bool e) { enabled_flashSDP = e; } +bool Context::userEnabledMemEfficientSDP() const { + return enabled_mem_efficientSDP; +} + +void Context::setSDPUseMemEfficient(bool e) { + enabled_mem_efficientSDP = e; +} + bool Context::userEnabledMathSDP() const { return enabled_mathSDP; } @@ -262,14 +266,6 @@ bool Context::hasMKLDNN() { #endif } -bool Context::hasMPS() { -#if USE_MPS - return at::mps::is_available(); -#else - return false; -#endif -} - bool Context::hasOpenMP() { #ifdef _OPENMP return true; @@ -287,8 +283,24 @@ bool Context::hasLAPACK() { } at::QEngine Context::qEngine() const { - // If wasn't explicitly set - take the last one available - return quantized_engine.value_or(supportedQEngines().back()); + static auto _quantized_engine = []() { + at::QEngine qengine = at::kNoQEngine; +#if defined(C10_MOBILE) && defined(USE_PYTORCH_QNNPACK) + qengine = at::kQNNPACK; +#endif + +#if AT_MKLDNN_ENABLED() + qengine = at::kONEDNN; +#endif + +#ifdef USE_FBGEMM + if (fbgemm::fbgemmSupportedCPU()) { + qengine = at::kFBGEMM; + } +#endif + return qengine; + }(); + return quantized_engine.value_or(_quantized_engine); } void Context::setQEngine(at::QEngine e) { @@ -324,8 +336,8 @@ const std::vector& Context::supportedQEngines() { #ifdef USE_FBGEMM if (fbgemm::fbgemmSupportedCPU()) { - // The X86 qengine is available if and only if FBGEMM is available engines.push_back(at::kX86); + // The X86 qengine is available if and only if FBGEMM is available engines.push_back(at::kFBGEMM); } #endif diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 43f4433b7ce99..9f1c571b66968 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,8 @@ class TORCH_API Context { return at::detail::getDefaultCPUGenerator(); } else if (device_type == at::kCUDA) { return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index()); + } else if (device_type == at::kMPS) { + return at::detail::getMPSHooks().getDefaultMPSGenerator(); } else { AT_ERROR(DeviceTypeName(device_type), " device type not enabled."); } @@ -83,6 +86,9 @@ class TORCH_API Context { static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } + static bool hasMPS() { + return detail::getMPSHooks().hasMPS(); + } static bool hasIPU() { return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU); } @@ -92,8 +98,6 @@ class TORCH_API Context { static bool hasLazy() { return c10::impl::hasDeviceGuardImpl(at::DeviceType::Lazy); } - static bool hasMPS(); - static bool hasORT() { return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT); } @@ -128,8 +132,9 @@ class TORCH_API Context { // Note [Disabling Fused SDP Kernels] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Flash SDP kernels are enabled by default. However, they can be disabled - // by setting at::globalContext().setUserEnabledFlashSDP(false) flag. + // Flash and Memory Efficient SDP kernels are enabled by default. + // However, they can be disabled by setting + // at::globalContext().setUserEnabledFlashSDP(false) flag. // This is useful for debugging purposes. For example, if you want to // compare the performance of the flash SDP kernels with the unfused // kernel, you can disable the flash SDP kernels. By disabling @@ -139,6 +144,9 @@ class TORCH_API Context { void setSDPUseFlash(bool); bool userEnabledFlashSDP() const; + void setSDPUseMemEfficient(bool); + bool userEnabledMemEfficientSDP() const; + void setSDPUseMath(bool); bool userEnabledMathSDP() const; @@ -270,6 +278,7 @@ class TORCH_API Context { bool _deterministic_algorithms = false; bool _deterministic_algorithms_warn_only = false; bool enabled_flashSDP = true; + bool enabled_mem_efficientSDP = true; bool enabled_mathSDP = true; #ifdef USE_ROCM bool benchmark_cudnn = true; @@ -414,6 +423,13 @@ static inline void manual_seed(uint64_t seed) { } } } + + if (hasMPS()) { + auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(mps_gen.mutex()); + mps_gen.set_current_seed(seed); + } } // When the global flag `allow_tf32` is set to true, cuBLAS handles are diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp index a44005a2ef815..ee846c9b82e34 100644 --- a/aten/src/ATen/ExpandUtils.cpp +++ b/aten/src/ATen/ExpandUtils.cpp @@ -13,8 +13,8 @@ TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) { namespace { // NOTE: are_expandable did a similar check, please keep them sync if change is needed -template -Container infer_size_impl(IntArrayRef a, IntArrayRef b) { +template +Container infer_size_impl(ArrayType a, ArrayType b) { size_t dimsA = a.size(); size_t dimsB = b.size(); size_t ndim = dimsA > dimsB ? dimsA : dimsB; @@ -25,8 +25,8 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ptrdiff_t offset = ndim - 1 - i; ptrdiff_t dimA = dimsA - 1 - offset; ptrdiff_t dimB = dimsB - 1 - offset; - int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; - int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; + auto sizeA = (dimA >= 0) ? a[dimA] : 1; + auto sizeB = (dimB >= 0) ? b[dimB] : 1; TORCH_CHECK( sizeA == sizeB || sizeA == 1 || sizeB == 1, @@ -35,7 +35,7 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ") at non-singleton dimension ", i); // 1s map to the other size (even 0). - expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA); } return expandedSizes; diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 779894645b8ec..9e48421e540fe 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -21,6 +21,8 @@ namespace at { TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); +TORCH_API SymDimVector +infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); // Named type instead of a pair/tuple so that we can be sure to // construct the vectors in place and get NRVO. @@ -94,10 +96,11 @@ inline void check_defined( inline c10::MaybeOwned expand_inplace( const Tensor& tensor, const Tensor& to_expand) { - if (tensor.sizes().equals(to_expand.sizes())) { + if (tensor.sym_sizes().equals(to_expand.sym_sizes())) { return c10::MaybeOwned::borrowed(to_expand); } - return c10::MaybeOwned::owned(to_expand.expand(tensor.sizes())); + return c10::MaybeOwned::owned( + to_expand.expand_symint(tensor.sym_sizes())); } inline c10::MaybeOwned expand_inplace( diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index ed1026152e32c..2bdc76c7764af 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -159,10 +159,11 @@ Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const } } -Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, int64_t index) { +Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) { // Pessimism: we can't reapply views for slice_scatter. - return base.select_scatter(mutated_view, dim, index); + return base.select_scatter_symint(mutated_view, dim, index); } + Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { // the functionalization pass doesn't care about autograd metadata - as a view, I think detach() is just an identity function return mutated_view; diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp index e50ffbdcf5112..8e80ce0ca7ddc 100644 --- a/aten/src/ATen/FunctionalStorageImpl.cpp +++ b/aten/src/ATen/FunctionalStorageImpl.cpp @@ -15,23 +15,9 @@ ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { return ViewMeta(forward_fn, reverse_fn, out_idx); } -Alias::Alias(const at::Tensor& base) { - TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base)); - base_ = base; -} - -const at::Tensor& Alias::base() const { - return base_; -} - -void Alias::add_update(const at::Tensor& updated_val, const std::vector& metas) { - updates_.push_back({updated_val, metas}); - generation_++; -} - // Note [Functionalization: Alias Removal Part 2] // See Note [Functionalization: Alias Removal] for more details. -// This function applies a single update from one of the views to the Alias object. +// This function applies a single update from one of the views to the StorageImpl. // We start out with and , and our goal is to end up with . // Consider this program: // @@ -46,15 +32,15 @@ void Alias::add_update(const at::Tensor& updated_val, const std::vector 0; - for (auto& update_data: updates_) { - base_ = apply_update(update_data, base_); - } - updates_.clear(); - return any_updates; -} c10::SymInt get_nbytes(const Tensor& value) { if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) { // Today, the two implementations of SymInt are in Python (proxy tensor), // and lazy tensor (LTC/XLA). - // LTC hasn't implemented SymInt support yet though (torch::lazy::SymIntNodeImpl). + // LTC hasn't implemented SymInt support yet though // Once it does, we should remove this check. if (value.key_set().has(c10::DispatchKey::Python)) { return value.storage().sym_nbytes(); @@ -105,31 +78,37 @@ c10::SymInt get_nbytes(const Tensor& value) { return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset()); } -FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& value) +FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) : c10::StorageImpl( c10::StorageImpl::use_byte_size_t(), - get_nbytes(value), - DataPtr{nullptr, value.device()}, + get_nbytes(base), + DataPtr{nullptr, base.device()}, GetAllocator(kMeta), /*resizeable=*/true ), - alias_(Alias(value)) - {} - -void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& view_metas) { - alias_.add_update(updated_val, view_metas); -} - -bool FunctionalStorageImpl::apply_updates() { - return alias_.apply_updates(); + base_(base) + { + TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); } -const Tensor& FunctionalStorageImpl::base() { - return alias_.base(); +void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector& metas) { + TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); + updates_.push_back({updated_val, metas}); + generation_++; } -size_t FunctionalStorageImpl::generation() const { - return alias_.generation(); +bool FunctionalStorageImpl::apply_updates() { + // N.B:none of the tensors used in this function should be FunctionalTensorWrappers at this point. + // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack. + // It adds the Functionalize key into TLS before redispatching to the functionalization kernels, + // which means that we need to explicitly exclude it here before doing any other work underneath the pass. + at::AutoDispatchSkipFunctionalize guard; + bool any_updates = updates_.size() > 0; + for (auto& update_data: updates_) { + base_ = apply_update(update_data, base_); + } + updates_.clear(); + return any_updates; } } // namespace functionalization diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h index 6caeac2737fd0..dbaf30c9963d9 100644 --- a/aten/src/ATen/FunctionalStorageImpl.h +++ b/aten/src/ATen/FunctionalStorageImpl.h @@ -46,13 +46,18 @@ struct ViewMeta { ViewMeta to_out_idx(int64_t out_idx); }; -// Alias represents the state shared by (potentially multiple) views of the same -// tensor. For example, in the following code: +// FunctionalStorageImpl is a subclass of StorageImpl used by the +// functionalization pass. It has no underlying data (similar to meta storage). +// It also knows how to reflect mutations to tensors in the absence of a valid +// data pointer. +// +// A storage represents the state shared by (potentially multiple) views of the +// same tensor. For example, in the following code: // // b = a.view1(...) // c = b.view2(...) // b.add_(1) -// --> alias.add_update(b, {view1_meta}) +// --> storage.add_update(b, {view1_meta}) // // The call to add_(1) will result in a call to alias.add_update(b, // {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose @@ -65,58 +70,49 @@ struct ViewMeta { // --> c.sync_() // --> alias.apply_updates() // after this, the alias will be updated to // reflect the mutation to b -class Alias { +struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { public: struct Update { const at::Tensor new_val; const std::vector view_metas; }; - explicit Alias(const at::Tensor& base); - const at::Tensor& base() const; + + explicit FunctionalStorageImpl(const Tensor& value); + + void add_update( + const Tensor& updated_val, + const std::vector& view_metas); + bool apply_updates(); + const Tensor& base() { + return base_; + } size_t generation() const { return generation_; } - void add_update( - const at::Tensor& updated_val, - const std::vector& metas); - bool apply_updates(); + void freeze() { + frozen_ = true; + } + + ~FunctionalStorageImpl() override = default; private: // NB: base_ should always point to a tensor BELOW the current // functionalization layer. This is mainly to avoid reference cycles. e.g. // given `b = a.view(...)` Both a.storage_ and b.storage_ are a - // FunctionStorageImpl containing an Alias, with contains a Tensor `base_`. In - // this case (where a and b are FunctionalTensorWrapper's), base_ should point - // not to a, but to a's unwrapped value, a.value_` See Note - // [Functionalization: Alias Removal] for a diagram that shows this visually. + // FunctionStorageImpl containing an Walualias, with contains a Tensor + // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_ + // should point not to a, but to a's unwrapped value, a.value_` See Note + // [Functionalization: Walualias Removal] for a diagram that shows this + // visually. at::Tensor base_; std::vector updates_; // generation_ gets incremented every time a mutation is queued onto the // alias. It is used to determine if a given tensor is "up to date", or if it // needs to be regenerated from the alias. size_t generation_ = 0; -}; - -// FunctionalStorageImpl is a subclass of StorageImpl used by the -// functionalization pass. It has no underlying data (similar to meta storage). -// It also knows how to reflect mutations to tensors in the absence of a valid -// data pointer. It does this by separately storing an Alias object, which knows -// how to reflect mutations that may have happened to views of the original -// tensor. -struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { - explicit FunctionalStorageImpl(const Tensor& value); - - void add_update( - const Tensor& updated_val, - const std::vector& view_metas); - bool apply_updates(); - const Tensor& base(); - size_t generation() const; - - ~FunctionalStorageImpl() override = default; - - private: - at::functionalization::Alias alias_; + // If frozen, no more mutations are allowed on this storage. Once frozen, a + // storage cannot be unfrozen. + bool frozen_ = false; }; } // namespace functionalization diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 91136f921b1ad..2c3a12020eb68 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -37,22 +37,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() { // Functorch transforms all have their own wrapper tensors (e.g. BatchedTensorImpl) which expect // to participate in the functorch transforms. key_set_ = key_set_ - c10::functorch_transforms_ks - c10::python_ks; - // For better error handling, - // we also don't want our wrapper tensor to be able to dispatch directly - // to a backend kernel. - // Dispatching directly to e.g. a CPU kernel would always segfault, - // because wrapper tensors don't have any real data. - // (This should never happen because we should always hit a functionalization kernel, - // but can help make bugs less nasty). - // Here, we defensively remove any backend keys from the wrapper's keyset. - // We don't want to remove actual backend bits though (say we're redispatching to autograd; - // we need to know if we're dispatching to AutogradCPU or AutogradXLA). - // Instead, it's sufficient to remove the `Dense` dispatch key, - // which prevents us from accidentally trying to directly run a CPU/CUDA kernel. - key_set_ = key_set_.remove(c10::DispatchKey::Dense); // We override a bunch of _custom(), so make sure they get called // TODO: metadata copying may not actually be necessary then set_custom_sizes_strides(SizesStridesPolicy::CustomSizes); + set_custom_device(true); } FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value) @@ -66,6 +54,10 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value) set_constructor_metadata(); } +void FunctionalTensorWrapper::freeze_storage() const { + functional_storage_impl()->freeze(); +} + // Note [Functionalization: Alias Removal] // When someone calls a view() op during the functionalization pass, e.g. 'b = a.view(...)', // we link `b` and `a` to a shared Alias object to preserve the aliasing relationship. @@ -302,12 +294,16 @@ c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach_ return r; } } + auto impl = c10::make_intrusive(value_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), /*version_counter=*/std::forward(version_counter), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->level_ = level_; + impl->generation_ = generation_; + impl->view_metas_ = view_metas_; impl->refresh_numel(); impl->refresh_contiguous(); return impl; @@ -327,6 +323,9 @@ c10::intrusive_ptr FunctionalTensorWrapper::shallow_copy_and_detach( std::move(version_counter), allow_tensor_metadata_change); } +c10::Device FunctionalTensorWrapper::device_custom() const { + return value_.unsafeGetTensorImpl()->device(); +} at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const { return value_.unsafeGetTensorImpl()->sizes(); } @@ -533,6 +532,12 @@ bool isFunctionalTensor(ITensorListRef list) { return isFunctionalTensorIListRef(list); } +void freeze_functional_tensor(const Tensor& tensor) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(tensor)); + auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); + functional_base_impl->freeze_storage(); +} + Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index 9f98353dad868..0762fb1f7f9b0 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -100,6 +100,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { // used to determine if it's up-to-date with its alias. The act of syncing a // tensor will set a tensor's generation equal to its alias's generation. bool is_up_to_date() const; + // Freezes the storage of this tensor, preventing subsequent mutations + void freeze_storage() const; // Every FunctionalTensorWrapper contains a vector objects // describing the series of view ops that ran to generate the current tensor // from the base tensor. This method is used by inplace-view ops like @@ -146,6 +148,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { c10::SymInt sym_size_custom(int64_t d) const override; c10::SymIntArrayRef sym_strides_custom() const override; c10::SymInt sym_storage_offset_custom() const override; + c10::Device device_custom() const override; private: const char* tensorimpl_type_name() const override; @@ -197,6 +200,8 @@ TORCH_API c10::List> to_functional_tensor( const c10::List>& t_list); TORCH_API std::vector to_functional_tensor(ITensorListRef t_list); +TORCH_API void freeze_functional_tensor(const Tensor& tensor); + TORCH_API Tensor from_functional_tensor(const Tensor& tensor, bool assert_functional = true); TORCH_API c10::optional from_functional_tensor( diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 594b87373a209..111c7eb8f5fc7 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -80,7 +80,7 @@ inline at::SymDimVector infer_size_dv( c10::SymInt numel) { auto res = at::SymDimVector(shape); infer_size_impl( - shape, numel, res); + shape, std::move(numel), res); return res; } diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/LegacyBatchedFallback.cpp similarity index 99% rename from aten/src/ATen/BatchedFallback.cpp rename to aten/src/ATen/LegacyBatchedFallback.cpp index 7ca516182cc4c..72794ece1c5a8 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/LegacyBatchedFallback.cpp @@ -1,7 +1,7 @@ #include -#include +#include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/BatchedFallback.h b/aten/src/ATen/LegacyBatchedFallback.h similarity index 100% rename from aten/src/ATen/BatchedFallback.h rename to aten/src/ATen/LegacyBatchedFallback.h diff --git a/aten/src/ATen/BatchedTensorImpl.cpp b/aten/src/ATen/LegacyBatchedTensorImpl.cpp similarity index 99% rename from aten/src/ATen/BatchedTensorImpl.cpp rename to aten/src/ATen/LegacyBatchedTensorImpl.cpp index fdedfa7c6316e..eea6d7859930c 100644 --- a/aten/src/ATen/BatchedTensorImpl.cpp +++ b/aten/src/ATen/LegacyBatchedTensorImpl.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/aten/src/ATen/BatchedTensorImpl.h b/aten/src/ATen/LegacyBatchedTensorImpl.h similarity index 100% rename from aten/src/ATen/BatchedTensorImpl.h rename to aten/src/ATen/LegacyBatchedTensorImpl.h diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/LegacyBatchingRegistrations.cpp similarity index 99% rename from aten/src/ATen/BatchingRegistrations.cpp rename to aten/src/ATen/LegacyBatchingRegistrations.cpp index 5a01f949745f6..c235da67d5a71 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/LegacyBatchingRegistrations.cpp @@ -1,7 +1,7 @@ #include #include -#include -#include +#include +#include #include #include #include diff --git a/aten/src/ATen/VmapMode.cpp b/aten/src/ATen/LegacyVmapMode.cpp similarity index 95% rename from aten/src/ATen/VmapMode.cpp rename to aten/src/ATen/LegacyVmapMode.cpp index 4f0a2413f4513..f10e1005debcd 100644 --- a/aten/src/ATen/VmapMode.cpp +++ b/aten/src/ATen/LegacyVmapMode.cpp @@ -1,4 +1,4 @@ -#include +#include namespace at { namespace impl { diff --git a/aten/src/ATen/VmapMode.h b/aten/src/ATen/LegacyVmapMode.h similarity index 100% rename from aten/src/ATen/VmapMode.h rename to aten/src/ATen/LegacyVmapMode.h diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/LegacyVmapTransforms.cpp similarity index 99% rename from aten/src/ATen/VmapTransforms.cpp rename to aten/src/ATen/LegacyVmapTransforms.cpp index 71ef7a169026d..1457e572812a4 100644 --- a/aten/src/ATen/VmapTransforms.cpp +++ b/aten/src/ATen/LegacyVmapTransforms.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/LegacyVmapTransforms.h similarity index 99% rename from aten/src/ATen/VmapTransforms.h rename to aten/src/ATen/LegacyVmapTransforms.h index cece52dcbc410..0afb3247ac86e 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/LegacyVmapTransforms.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace at { diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index 94c9c8d073a94..4ed527cfd4865 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -172,6 +173,7 @@ NestedTensorImpl::NestedTensorImpl( nested_stride_tensor_(std::move(nested_stride_tensor)), storage_offsets_(std::move(offsets)), opt_sizes_(construct_opt_sizes(nested_size_tensor_)) { + C10_LOG_API_USAGE_ONCE("torch.NestedTensor"); TORCH_WARN_ONCE( "The PyTorch API of nested tensors is in prototype stage and will change " "in the near future."); diff --git a/aten/src/ATen/PadNd.h b/aten/src/ATen/PadNd.h index 2c0d67e9d5d3f..573d1a7b88ab7 100644 --- a/aten/src/ATen/PadNd.h +++ b/aten/src/ATen/PadNd.h @@ -1,4 +1,6 @@ #pragma once +#include +#include namespace at { diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 4693997624e98..ff14f568d22a6 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -29,7 +29,7 @@ TORCH_API bool in_parallel_region(); namespace internal { // Initialise num_threads lazily at first parallel call -inline TORCH_API void lazy_init_num_threads() { +inline void lazy_init_num_threads() { thread_local bool init = false; if (C10_UNLIKELY(!init)) { at::init_num_threads(); diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp index c4e1241805a88..c9487c6958cbf 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.cpp +++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp @@ -6,18 +6,6 @@ namespace impl { static thread_local PythonTorchFunctionTLS pythonTorchFunctionState; -void PythonTorchFunctionTLS::set_mode(std::shared_ptr mode) { - pythonTorchFunctionState.mode_ = std::move(mode); -} - -const std::shared_ptr& PythonTorchFunctionTLS::get_mode() { - return pythonTorchFunctionState.mode_; -} - -void PythonTorchFunctionTLS::swap_mode(std::shared_ptr& mode) { - pythonTorchFunctionState.mode_.swap(mode); -} - void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr mode) { pythonTorchFunctionState.stack_.push_back(std::move(mode)); } @@ -54,8 +42,8 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() { return pythonTorchFunctionState; } -bool function_mode_enabled() { - return static_cast(PythonTorchFunctionTLS::get_mode()); +bool torch_function_mode_enabled() { + return PythonTorchFunctionTLS::stack_len() > 0; } } // namespace impl diff --git a/aten/src/ATen/PythonTorchFunctionTLS.h b/aten/src/ATen/PythonTorchFunctionTLS.h index ef283164246d3..5940fb6f2dee2 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.h +++ b/aten/src/ATen/PythonTorchFunctionTLS.h @@ -10,10 +10,6 @@ struct TORCH_API PythonTorchFunctionTLS { static void set_disabled(bool); static bool is_disabled(); - static void set_mode(std::shared_ptr); - static const std::shared_ptr& get_mode(); - static void swap_mode(std::shared_ptr&); - static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); static const std::shared_ptr& get_stack_at(int64_t idx); @@ -26,16 +22,13 @@ struct TORCH_API PythonTorchFunctionTLS { // The mode TLS is split into // - disabled_, which says whether or not to disable all torch function // modes - // - mode_, which is the C++ mode, that can only be the mode handling mode - // or null // - stack_, which is a vector of modes representing the stack of user // defined modes bool disabled_; - std::shared_ptr mode_ = nullptr; std::vector> stack_; }; -TORCH_API bool function_mode_enabled(); +TORCH_API bool torch_function_mode_enabled(); } // namespace impl } // namespace at diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h index e76d2707c6f49..13ed74c7e8a55 100644 --- a/aten/src/ATen/SparseCsrTensorUtils.h +++ b/aten/src/ATen/SparseCsrTensorUtils.h @@ -122,6 +122,13 @@ } \ }() +#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__)) + namespace at { namespace sparse_csr { @@ -287,5 +294,13 @@ inline Layout flip_compressed_layout(Layout layout) { } } +inline DimVector getBlockSize(Tensor const& self) { + int64_t n_batch = numBatchDimensions(self); + Tensor values = self.values(); + return { + std::max(1, values.size(n_batch + 1)), + std::max(1, values.size(n_batch + 2))}; +} + } // namespace sparse_csr } // namespace at diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 197ae21438967..36c93b706db86 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -89,16 +89,16 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons TORCH_CHECK(indices.options().backend() == values.options().backend(), "backend of indices (", indices.options().backend(), ") must match backend of values (", values.options().backend(), ")"); TORCH_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")"); - TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes()); - TORCH_CHECK(indices.size(1) == values.size(0), "indices and values must have same nnz, but got nnz from indices: ", indices.size(1), ", nnz from values: ", values.size(0)); - TORCH_CHECK(indices.size(0) == sparse_dim_, "indices has incorrect first dimension, expected ", sparse_dim_, ", got ", indices.size(0)); + TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sym_sizes()); + TORCH_CHECK(indices.sym_size(1) == values.sym_size(0), "indices and values must have same nnz, but got nnz from indices: ", indices.sym_size(1), ", nnz from values: ", values.sym_size(0)); + TORCH_CHECK(indices.sym_size(0) == sparse_dim_, "indices has incorrect first dimension, expected ", sparse_dim_, ", got ", indices.sym_size(0)); TORCH_CHECK(values.dim() == dense_dim_ + 1, "values has incorrect number of dimensions, expected ", dense_dim_ + 1, ", got ", values.dim()); - auto dense_size_original = sizes().slice(sparse_dim_); - std::vector expected_values_size_vec = {values.size(0)}; + auto dense_size_original = sym_sizes().slice(sparse_dim_); + std::vector expected_values_size_vec = {values.sym_size(0)}; expected_values_size_vec.insert(expected_values_size_vec.end(), dense_size_original.begin(), dense_size_original.end()); - IntArrayRef expected_values_size(expected_values_size_vec); - auto new_values_size = values.sizes(); + SymIntArrayRef expected_values_size(expected_values_size_vec); + auto new_values_size = values.sym_sizes(); TORCH_CHECK( std::equal(expected_values_size.begin(), expected_values_size.end(), new_values_size.begin()), "values has incorrect size, expected ", expected_values_size, ", got ", new_values_size @@ -109,7 +109,7 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons AT_ASSERT(device() == values_.device()); AT_ASSERT(values_.device() == indices_.device()); - coalesced_ = nnz() < 2; + coalesced_ = sym_nnz() < 2; } diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index c36d89be5b610..d90734100ca6c 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -9,6 +9,7 @@ #include #else #include +#include #endif namespace at { @@ -51,6 +52,10 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { int64_t nnz() const { return values_.size(0); } + + c10::SymInt sym_nnz() const { + return values_.sym_size(0); + } int64_t sparse_dim() const { return sparse_dim_; } @@ -85,7 +90,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { TORCH_CHECK( !has_symbolic_sizes_strides_, "raw_resize_ called on tensor with symbolic shape") - sizes_and_strides_.set_sizes(size); + set_sizes_and_strides(size, std::vector(size.size())); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); @@ -116,7 +121,8 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { // 4. When we attempt to shrink the size of any of the sparse dimensions on a // non-empty sparse tensor (this could make some of the stored indices // out-of-bound and thus unsafe). - void resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { + template + void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { TORCH_CHECK( allow_tensor_metadata_change(), "resize_ ", @@ -160,7 +166,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { bool shrinking_sparse_dims = false; bool shrinking_dense_dim = false; - auto sparse_size_original = sizes().slice(0, sparse_dim); + auto sparse_size_original = generic_sizes().slice(0, sparse_dim); auto sparse_size_new = size.slice(0, sparse_dim); for (const auto i : c10::irange(sparse_dim)) { if (sparse_size_new[i] < sparse_size_original[i]) { @@ -168,7 +174,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { break; } } - auto dense_size_original = sizes().slice(sparse_dim); + auto dense_size_original = generic_sizes().slice(sparse_dim); auto dense_size_new = size.slice(sparse_dim); for (const auto i : c10::irange(dense_dim)) { if (dense_size_new[i] < dense_size_original[i]) { @@ -196,7 +202,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { alt_options_msg); } - IntArrayRef sizes_and_strides = sizes_and_strides_.sizes_arrayref(); + auto sizes_and_strides = generic_sizes(); const bool size_equals_sizes = std::equal( size.begin(), size.end(), @@ -204,23 +210,34 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { sizes_and_strides.end()); if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) { - auto nnz = values().size(0); - std::vector values_size = {nnz}; + auto nnz = at::symint::sizes(values())[0]; + std::vector values_size = {nnz}; auto dense_size = size.slice(sparse_dim); values_size.insert( values_size.end(), dense_size.begin(), dense_size.end()); - values_.resize_(values_size); - indices_.resize_({sparse_dim, nnz}); + at::symint::resize_(values_, values_size); + at::symint::resize_(indices_, {T(sparse_dim), nnz}); } if (!size_equals_sizes) { - sizes_and_strides_.set_sizes(size); + set_sizes_and_strides(size, std::vector(size.size())); } sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; refresh_numel(); } + void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { + return _resize_(sparse_dim, dense_dim, size); + } + + void resize_( + int64_t sparse_dim, + int64_t dense_dim, + ArrayRef size) { + return _resize_(sparse_dim, dense_dim, size); + } + // NOTE: this function will resize the sparse tensor and also set `indices` // and `values` to empty. void resize_and_clear_( @@ -243,7 +260,7 @@ struct TORCH_API SparseTensorImpl : public TensorImpl { "), but got ", size.size()); - sizes_and_strides_.set_sizes(size); + set_sizes_and_strides(size, std::vector(size.size())); sparse_dim_ = sparse_dim; dense_dim_ = dense_dim; diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 7b1442db75ad4..7e86163f1ca4c 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -431,7 +431,7 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) { } // Computes a common dtype, if needed - if (has_different_input_dtypes && config.promote_inputs_to_common_dtype_) { + if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) { common_dtype_ = compute_common_dtype(); } @@ -1237,6 +1237,7 @@ void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) { shape_ = infer_size_dimvector(shape_, shape); } } + all_ops_are_scalars_ = !has_tensors; } void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) { diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index 59f52d9dbd2ed..31ae65466870a 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -659,9 +659,12 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// in operands_). int num_outputs_ = 0; - /// Whether or not all operands have the same shape. Having all the same - /// shape affects whether or not the iterator is eligible for fast setup. + /// Whether or not all operands have the same shape and are 1d+. Having all + /// the same shape affects whether or not the iterator is eligible for fast + /// setup. bool all_ops_same_shape_ = false; + /// Whether or not all operands are 0d, this affects type promotion + bool all_ops_are_scalars_ = false; /// The "computation" dtype of TensorIterator, specifying what the dtype /// we will do the internal computation in TensorIterator. Typically, diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 16c0aa42232f8..5c8214b7d8829 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -6,6 +6,7 @@ #include #include +#include namespace at { @@ -15,7 +16,8 @@ ThreadLocalState::ThreadLocalState() functorch_tls_(functorch::getCopyOfFuncTorchTLS()), autograd_tls_(c10::AutogradState::get_tls_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), - python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) { + python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), + functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()) { rf_tls_ = at::get_record_function_tls_(); saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state(); @@ -53,6 +55,8 @@ void ThreadLocalState::setThreadLocalState( c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_); functorch::setFuncTorchTLS(state.functorch_tls_); + + at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_); } } // namespace at diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 9e5f70a4224f3..0184cc9b82c47 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -74,6 +74,8 @@ class TORCH_API ThreadLocalState { // TLS for saved tensors default hooks at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; + bool functionalization_reapply_views_state_; + friend class ThreadLocalStateGuard; }; diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h index e942245703287..b0bc583b90c2e 100644 --- a/aten/src/ATen/WrapDimUtils.h +++ b/aten/src/ATen/WrapDimUtils.h @@ -38,14 +38,29 @@ inline int64_t maybe_wrap_dim( return maybe_wrap_dim(dim, tensor_sizes[0].size()); } -// wrap each dim in the dims array, taking dim_post_expr as the true number of -// dimensions +// Given an array of dimensions `dims` of length `ndims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). inline void maybe_wrap_dims_n( int64_t* dims, int64_t ndims, - int64_t dim_post_expr) { + int64_t dim_post_expr, + bool wrap_scalars = true) { if (dim_post_expr <= 0) { - dim_post_expr = 1; // this will make range [-1, 0] + if (wrap_scalars) { + dim_post_expr = 1; // this will make range [-1, 0] + } else { + TORCH_CHECK_INDEX( + ndims == 0, + "Dimension specified as ", + dims[0], + " but tensor has no dimensions"); + return; + } } int64_t min = -dim_post_expr; int64_t max = dim_post_expr - 1; @@ -67,11 +82,20 @@ inline void maybe_wrap_dims_n( } } -// Wrap each dim in a contiguous container, taking dim_post_expr as the true -// number of dimensions E.g. could also be std::array or c10::SmallVector +// Given a contiguous container of dimensions `dims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). template -inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) { - return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr); +inline void maybe_wrap_dims( + Container& dims, + int64_t dim_post_expr, + bool wrap_scalars = true) { + return maybe_wrap_dims_n( + dims.data(), dims.size(), dim_post_expr, wrap_scalars); } // previously, size [0] tensors were the only possible empty tensors; thus, it diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 437eadf873a15..ee8b4b30b1520 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -9,6 +9,7 @@ #include #include +#include namespace at { namespace autocast { @@ -64,7 +65,8 @@ namespace { // directly against incoming TensorImpl*s. using weakref_type = c10::weak_intrusive_ptr; using val_type = std::tuple; -thread_local std::unordered_map cached_casts; +std::unordered_map cached_casts; +std::mutex cached_casts_mutex; // nesting tracks the nesting depth of the Python-side context manager. // When the autocast context manager exits to a nesting level that's outside @@ -89,6 +91,7 @@ thread_local at::ScalarType autocast_gpu_dtype = at::kHalf; } void clear_cache() { + const std::lock_guard lock(cached_casts_mutex); cached_casts.clear(); } @@ -155,6 +158,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_ arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view() && cache_enabled); if (can_try_cache) { + const std::lock_guard lock(cached_casts_mutex); auto it = cached_casts.find(arg.unsafeGetTensorImpl()); if (it != cached_casts.end()) { return std::get<1>(it->second); @@ -446,6 +450,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL2(cumprod, dimname, fp32_set_opt_dtype) KERNEL(cumsum, fp32_set_opt_dtype) KERNEL2(cumsum, dimname, fp32_set_opt_dtype) + KERNEL(linalg_vector_norm, fp32_set_opt_dtype) + KERNEL(linalg_matrix_norm, fp32_set_opt_dtype) + KERNEL2(linalg_matrix_norm, str_ord, fp32_set_opt_dtype) // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even // when autocasting. // KERNEL2(norm, ScalarOpt_dtype, fp32_set_opt_dtype) @@ -572,8 +579,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(fft_irfftn, fp32) KERNEL_CPU(fft_hfft, fp32) KERNEL_CPU(fft_ihfft, fp32) - KERNEL_CPU(linalg_matrix_norm, fp32) - KERNEL_CPU2(linalg_matrix_norm, str_ord, fp32) KERNEL_CPU(linalg_cond, fp32) KERNEL_CPU2(linalg_cond, p_str, fp32) KERNEL_CPU(linalg_matrix_rank, fp32) diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 155a52f669f99..3d57ac9231164 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -126,6 +126,16 @@ inline at::ScalarType prioritize( return current; } +inline at::ScalarType prioritize( + at::ScalarType current, + const ITensorListRef& list, + DeviceType device_type = DeviceType::CUDA) { + for (const auto& tensor : list) { + current = prioritize(current, tensor, device_type); + } + return current; +} + // Template to catch non-Tensor args (no-op that returns current best guess) template inline at::ScalarType prioritize( @@ -196,6 +206,18 @@ inline std::vector cached_cast( return vec; } +inline std::vector cached_cast( + at::ScalarType to_type, + const ITensorListRef& arg, + DeviceType device_type = DeviceType::CUDA) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.push_back(cached_cast(to_type, t, device_type)); + } + return vec; +} + // Template to catch non-Tensor args. template inline T cached_cast( diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h index 7ae106b6618cf..c4fb44ce0c636 100644 --- a/aten/src/ATen/core/Dict.h +++ b/aten/src/ATen/core/Dict.h @@ -101,8 +101,15 @@ class DictEntryRef final { // this wraps map_type::iterator to make sure user code can't rely // on it being the type of the underlying map. template -class DictIterator final : public std::iterator> { +class DictIterator final { public: + // C++17 friendly std::iterator implementation + using iterator_category = std::forward_iterator_tag; + using value_type = DictEntryRef; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + explicit DictIterator() = default; ~DictIterator() = default; @@ -136,7 +143,7 @@ class DictIterator final : public std::iterator>::difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { + friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_; } diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index 875b9ef3d0427..4537adff5aa4b 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream & out, Backend b) { return out << toString(b); } -std::ostream& operator<<(std::ostream & out, Scalar s) { +std::ostream& operator<<(std::ostream & out, const Scalar& s) { if (s.isFloatingPoint()) { return out << s.toDouble(); } @@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream & out, Scalar s) { throw std::logic_error("Unknown type in Scalar"); } -std::string toString(Scalar s) { +std::string toString(const Scalar& s) { std::stringstream out; out << s; return out.str(); diff --git a/aten/src/ATen/core/Formatting.h b/aten/src/ATen/core/Formatting.h index 6dcfc6c7b3cd1..9dcd14e1902ee 100644 --- a/aten/src/ATen/core/Formatting.h +++ b/aten/src/ATen/core/Formatting.h @@ -8,8 +8,8 @@ namespace c10 { TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); -TORCH_API std::ostream& operator<<(std::ostream & out, Scalar s); -TORCH_API std::string toString(Scalar s); +TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s); +TORCH_API std::string toString(const Scalar& s); } namespace at { diff --git a/aten/src/ATen/core/IListRef.h b/aten/src/ATen/core/IListRef.h index 0b0ff67b02e2d..340e519f43dbc 100644 --- a/aten/src/ATen/core/IListRef.h +++ b/aten/src/ATen/core/IListRef.h @@ -359,7 +359,7 @@ using MaterializedIListRef = std::vector>; * than 0. */ template -class IListRefIterator : public std::iterator { +class IListRefIterator { private: #define DEFINE_FRIEND_CLASS(TAG, ...) \ friend class detail::IListRefTagImpl; \ @@ -371,6 +371,13 @@ class IListRefIterator : public std::iterator::list_type::const_iterator; using boxed_iterator_type = typename detail:: diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index fe75bf37cb7fa..610417774774c 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -111,13 +111,15 @@ class ListElementReference final { // this wraps vector::iterator to make sure user code can't rely // on it being the type of the underlying vector. template -class ListIterator final : public std::iterator< - std::random_access_iterator_tag, - T, - std::ptrdiff_t, - T*, - ListElementReference> { +class ListIterator final { public: + // C++17 friendly std::iterator implementation + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = ListElementReference; + explicit ListIterator() = default; ~ListIterator() = default; @@ -166,7 +168,7 @@ class ListIterator final : public std::iterator< return ListIterator{iterator_ - offset}; } - friend typename std::iterator::difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { + friend difference_type operator-(const ListIterator& lhs, const ListIterator& rhs) { return lhs.iterator_ - rhs.iterator_; } diff --git a/aten/src/ATen/core/List_test.cpp b/aten/src/ATen/core/List_test.cpp index e16e26b6042e3..f37f3c0084932 100644 --- a/aten/src/ATen/core/List_test.cpp +++ b/aten/src/ATen/core/List_test.cpp @@ -1118,7 +1118,7 @@ TEST(ListTest, canAccessStringByReference) { List list({"one", "two"}); const auto& listRef = list; static_assert(std::is_same::value, - "const List acccess should be by const reference"); + "const List access should be by const reference"); std::string str = list[1]; const std::string& strRef = listRef[1]; EXPECT_EQ("two", str); @@ -1130,7 +1130,7 @@ TEST(ListTest, canAccessOptionalStringByReference) { const auto& listRef = list; static_assert( std::is_same>>::value, - "List> acccess should be by const reference"); + "List> access should be by const reference"); c10::optional str1 = list[1]; c10::optional str2 = list[2]; decltype(auto) strRef1 = listRef[1]; diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index fcdb018b6ff7b..2d8834afe59ef 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -52,9 +52,10 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch - const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_mode(); - if (maybe_torch_dispatch_mode_state) { - maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack); + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); + if (mode_stack_len > 0) { + const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack); return; } @@ -73,10 +74,13 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { (*interpreter)->dispatch(op, stack); return; } - } else if (ivalue.isTensorList() || (ivalue.isOptionalTensorList() && !ivalue.isNone())) { + } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef // is not a thing) for (const auto& nv : ivalue.toListRef()) { + if (nv.isNone()) { + continue; + } auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { (*interpreter)->dispatch(op, stack); diff --git a/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp b/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp new file mode 100644 index 0000000000000..2d9b15a6b03cb --- /dev/null +++ b/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp @@ -0,0 +1,28 @@ +#include + +namespace at { +namespace impl { + +// The strategy is that all python interpreters attempt to register themselves +// as the main interpreter, but only one wins. Only that interpreter is +// allowed to interact with the C++ dispatcher. Furthermore, when we execute +// logic on that interpreter, we do so hermetically, never setting pyobj field +// on Tensor. + +std::atomic PythonOpRegistrationTrampoline::interpreter_{nullptr}; + +bool PythonOpRegistrationTrampoline::registerInterpreter(c10::impl::PyInterpreter* interp) { + c10::impl::PyInterpreter* expected = nullptr; + interpreter_.compare_exchange_strong(expected, interp); + if (expected != nullptr) { + // This is the second (or later) Python interpreter, which means we need + // non-trivial hermetic PyObject TLS + c10::impl::HermeticPyObjectTLS::init_state(); + return false; + } else { + return true; + } +} + +} // namespace impl +} // namespace at diff --git a/aten/src/ATen/core/PythonOpRegistrationTrampoline.h b/aten/src/ATen/core/PythonOpRegistrationTrampoline.h new file mode 100644 index 0000000000000..00d3c635859a3 --- /dev/null +++ b/aten/src/ATen/core/PythonOpRegistrationTrampoline.h @@ -0,0 +1,18 @@ +#include + +// TODO: this can probably live in c10 + +namespace at { +namespace impl { + +class TORCH_API PythonOpRegistrationTrampoline final { + static std::atomic interpreter_; + +public: + // Returns true if you successfully registered yourself (that means + // you are in the hot seat for doing the operator registrations!) + static bool registerInterpreter(c10::impl::PyInterpreter*); +}; + +} // namespace impl +} // namespace at diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 08a14f2e09580..0ecd4456033b0 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -955,11 +955,21 @@ c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); } template > IntArrayRef sizes(const TensorBase& t) { return t.sizes(); } +template > +c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); } +template > +int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); } + template > c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); } template > IntArrayRef strides(const TensorBase& t) { return t.strides(); } +template > +c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); } +template > +int64_t numel(const TensorBase& t) { return t.numel(); } + } // namespace symint } // namespace at diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index a99f45040788d..bf2a8819f989b 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -577,14 +577,16 @@ namespace impl { // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`. #ifdef __cpp_if_constexpr - using ReturnType_ = std::decay_t; + // [Note: VC++ and 'std': ambiguous symbol] + using ReturnType_ = ::std::decay_t; ReturnType_ output = call_functor_with_args_from_stack(functor, dispatchKeySet, stack); #else using ReturnType_ = std::decay_t>; ReturnType_ output = call_functor_with_args_from_stack(functor, dispatchKeySet, delay_check(stack)); #endif torch::jit::drop(*stack, num_inputs); - push_outputs::call(std::move(output), stack); + // See note [ VC++ and 'std': ambiguous symbol] + push_outputs::call(::std::move(output), stack); #ifdef __cpp_if_constexpr } else { #else diff --git a/aten/src/ATen/core/class_type.cpp b/aten/src/ATen/core/class_type.cpp index 9d7b38d4d67b6..2478bde034bc7 100644 --- a/aten/src/ATen/core/class_type.cpp +++ b/aten/src/ATen/core/class_type.cpp @@ -86,7 +86,7 @@ std::string ClassType::getForwardPreHookErrorMessage(int pre_hook_idx) const { std::string pre_hook_schema = pre_hook_name + "(self, input: Tuple[" + input_types + "])"; std::string return_string = - "This error occured while scripting the forward pre-hook '" + + "This error occurred while scripting the forward pre-hook '" + pre_hook_name + "' on module '" + name()->name() + "'. If you did not want to script this pre-hook remove it from the " "original NN module before scripting. Pre-hooks for module '" + @@ -111,7 +111,7 @@ std::string ClassType::getForwardHookErrorMessage(int hook_idx) const { std::string hook_schema = hook_name + "(self, input: Tuple[" + input_types + "], output: " + output_types + ")"; std::string return_string = - "This error occured while scripting the forward hook '" + "This error occurred while scripting the forward hook '" + hook_name + "' on module " + name()->name() + ". If you did not want to script this hook remove it from" + " the original NN module before scripting. This hook was" + diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 27c6e26721a2e..7401297c66a69 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -80,7 +80,7 @@ namespace detail { ts = ts | x.key_set(); } } - void operator()(at::ArrayRef>) { + [[noreturn]] void operator()(at::ArrayRef>) { // Just checking that the handling of Tensor?[] didn't change. TORCH_INTERNAL_ASSERT(false); } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 667eefdcc5ab8..8b2257605161e 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace c10 { @@ -9,6 +10,12 @@ bool show_dispatch_trace() { return temp != nullptr; } +static thread_local int64_t dispatch_trace_nesting_value_; + +void dispatch_trace_nesting_incr() { ++dispatch_trace_nesting_value_; } +void dispatch_trace_nesting_decr() { --dispatch_trace_nesting_value_; } +int64_t dispatch_trace_nesting_value() { return dispatch_trace_nesting_value_; } + namespace detail { class RegistrationListenerList final { @@ -44,7 +51,9 @@ Dispatcher::Dispatcher() , operatorLookupTable_() , backendFallbackKernels_() , listeners_(std::make_unique()) -, mutex_() {} +, mutex_() +, cond_var_() +{} Dispatcher::~Dispatcher() = default; @@ -63,6 +72,41 @@ c10::optional Dispatcher::findOp(const OperatorName& overload_na }); } +// NB: If you add more waitFor* implementations, you also have to add +// appropriate notify_all() calls to the relevant register calls + +void Dispatcher::waitForDef(const FunctionSchema& schema) { + using namespace std::chrono_literals; + std::unique_lock lock(mutex_); + bool r = cond_var_.wait_for(lock, 2s, [&]{ + return findOp(schema.operator_name()) != c10::nullopt; + }); + TORCH_INTERNAL_ASSERT(r, + "Expected main interpreter to define ", schema.operator_name(), + ", but this didn't happen within timeout. Are you trying to load " + "different models in the same torchdeploy/multipy instance? You " + "must warmup each interpreter identically, e.g., import all " + "the same dependencies."); +} + +void Dispatcher::waitForImpl(const OperatorName& op_name, c10::optional maybe_dk) { + using namespace std::chrono_literals; + std::unique_lock lock(mutex_); + auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd); + auto op = findOrRegisterName_(op_name); + bool r = cond_var_.wait_for(lock, 2s, [&]{ + // NB: this is slightly unsound for overrides, but overrides are + // funny business anyway + return op.hasKernelForDispatchKey(dk); + }); + TORCH_INTERNAL_ASSERT(r, + "Expected main interpreter to implement ", dk, " for ", op_name, + ", but this didn't happen within timeout. Are you trying to load " + "different models in the same torchdeploy/multipy instance? You " + "must warmup each interpreter identically, e.g., import all " + "the same dependencies."); +} + c10::optional Dispatcher::findSchema(const OperatorName& overload_name) { auto it = findOp(overload_name); if (it.has_value()) { @@ -169,6 +213,8 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin ++op.operatorDef_->def_count; ++op.operatorDef_->def_and_impl_count; + cond_var_.notify_all(); + return RegistrationHandleRAII([this, op, op_name] { deregisterDef_(op, op_name); }); @@ -221,6 +267,8 @@ RegistrationHandleRAII Dispatcher::registerImpl( ++op.operatorDef_->def_and_impl_count; + cond_var_.notify_all(); + return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] { deregisterImpl_(op, op_name, dispatch_key, handle); }); @@ -243,6 +291,7 @@ RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) { std::lock_guard lock(mutex_); auto op = findOrRegisterName_(op_name); ++op.operatorDef_->def_and_impl_count; + return RegistrationHandleRAII( [this, op, op_name] { deregisterName_(op, op_name); }); } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 1ea677b54ef5a..5af8ef1e52ded 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -19,6 +20,14 @@ namespace c10 { TORCH_API bool show_dispatch_trace(); +TORCH_API void dispatch_trace_nesting_incr(); +TORCH_API void dispatch_trace_nesting_decr(); +TORCH_API int64_t dispatch_trace_nesting_value(); + +struct DispatchTraceNestingGuard { + DispatchTraceNestingGuard() { dispatch_trace_nesting_incr(); } + ~DispatchTraceNestingGuard() { dispatch_trace_nesting_decr(); } +}; class TORCH_API OperatorHandle; template class TypedOperatorHandle; @@ -174,6 +183,9 @@ class TORCH_API Dispatcher final { return backendFallbackKernels_[dispatch_ix].kernel.isValid(); } + // Used by torchdeploy/multipy for multiple interpreters racing. + void waitForDef(const FunctionSchema& schema); + void waitForImpl(const OperatorName& op_name, c10::optional dispatch_key); // ------------------------------------------------------------------------ // @@ -299,7 +311,23 @@ class TORCH_API Dispatcher final { std::array backendFallbackKernels_; std::unique_ptr listeners_; + + // This mutex protects concurrent access to the dispatcher std::mutex mutex_; + + // This condition variable gets notified whenever we add a new def/impl to the + // dispatch table. This is primarily used by multipy/torchdeploy, when + // we have multiple interpreters trying to register to the dispatch table. + // In this situation, whenever the non-primary interpreter would have tried + // to register to the dispatch table, instead it will check to see if the + // expected registration has already been made, and if it hasn't, wait on + // this condition variable to see if it was just racing with the primary + // interpreter. + // + // We expect it to be rare for there to be any waiters on this condition + // variable. This is mostly just to help give better diagnostics if + // something goes horribly wrong + std::condition_variable cond_var_; }; /** @@ -308,6 +336,8 @@ class TORCH_API Dispatcher final { * to lookup a kernel for a certain set of arguments. */ class TORCH_API OperatorHandle { + template friend class std::hash; + public: OperatorHandle(OperatorHandle&&) noexcept = default; OperatorHandle& operator=(OperatorHandle&&) noexcept = default; @@ -403,6 +433,14 @@ class TORCH_API OperatorHandle { return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor); } + bool operator==(const OperatorHandle& other) const { + return operatorDef_ == other.operatorDef_; + } + + bool operator!=(const OperatorHandle& other) const { + return operatorDef_ != other.operatorDef_; + } + private: explicit OperatorHandle(std::list::iterator operatorIterator) : operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {} @@ -583,7 +621,10 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor() .template getDispatchKeySetUnboxed(args...); #ifndef NDEBUG + DispatchTraceNestingGuard debug_guard; if (show_dispatch_trace()) { + auto nesting_value = dispatch_trace_nesting_value(); + for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " "; std::cerr << "[call] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; } #endif @@ -603,7 +644,10 @@ inline Return Dispatcher::redispatch(const TypedOperatorHandle detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 // do not use RecordFunction on redispatch #ifndef NDEBUG + DispatchTraceNestingGuard debug_guard; if (show_dispatch_trace()) { + auto nesting_value = dispatch_trace_nesting_value(); + for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " "; std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; } #endif @@ -616,7 +660,10 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const const auto& entry = op.operatorDef_->op; auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); #ifndef NDEBUG + DispatchTraceNestingGuard debug_guard; if (show_dispatch_trace()) { + auto nesting_value = dispatch_trace_nesting_value(); + for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " "; std::cerr << "[callBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; } #endif @@ -666,7 +713,10 @@ inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet // note: this doesn't need the mutex because write operations on the list keep iterators intact. const auto& entry = op.operatorDef_->op; #ifndef NDEBUG + DispatchTraceNestingGuard debug_guard; if (show_dispatch_trace()) { + auto nesting_value = dispatch_trace_nesting_value(); + for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " "; std::cerr << "[redispatchBoxed] op=[" << op.operator_name() << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl; } #endif @@ -675,3 +725,14 @@ inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet } } // namespace c10 + +namespace std { + +template <> +struct hash { + size_t operator()(c10::OperatorHandle op) const noexcept { + return std::hash{}(static_cast(op.operatorDef_)); + } +}; + +} // namespace std diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 5d53500e7dfe0..cbc7ff8bf309b 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -147,13 +147,17 @@ OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel( #else if (k.size() > 0) { #endif - TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n", - " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", - " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", - " dispatch key: ", toString(dispatch_key), "\n", - " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n", - " new kernel: ", debug - ); + // Suppress the warning for Meta key as we are overriding C++ meta functions with python meta functions + // for some ops + if (dispatch_key != DispatchKey::Meta) { + TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n", + " operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n", + " ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n", + " dispatch key: ", toString(dispatch_key), "\n", + " previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : (sym_cpp_signature_.has_value() ? sym_cpp_signature_->debug : "no debug info")), "\n", + " new kernel: ", debug + ); + } } #ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY @@ -275,6 +279,7 @@ std::pair OperatorEntry::computeDispatchTab // cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available) // in this case. // (2.4) Use kernel from DispatchKey::Autograd if available + // (2.5) Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available // The implementation of (2.2) relies on the invariant that for a given backend, // `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the // backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_] @@ -327,6 +332,7 @@ std::pair OperatorEntry::computeDispatchTab // We have no intention to change the behavior of Undefined, // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined` // to let the original CompositeImplicitAutograd handle Undefined + // See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) { if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) { return {*nested_registration, "nested kernel"}; @@ -351,6 +357,14 @@ std::pair OperatorEntry::computeDispatchTab } } + // 2.5. For batched backend keys, use kernel from DispatchKey::FuncTorchBatchedDecomposition if available + // See Note: [Disjoint AliasKeyset] The order for this alias key doesn't matter + if (isIncludedInAlias(dispatch_key, DispatchKey::FuncTorchBatchedDecomposition)) { + if (auto batched_registration = getKernelForDispatchKey(DispatchKey::FuncTorchBatchedDecomposition)) { + return {*batched_registration, "batched kernel"}; + } + } + // 3. Backend fallback auto dispatch_ix = getDispatchTableIndexForDispatchKey(dispatch_key); if (dispatch_ix < 0) { @@ -495,6 +509,22 @@ void OperatorEntry::reportSignatureError(const CppSignature& call_signature, con ); }; +std::string post_process_dispatch_key_str(std::string dispatch_key) { + const std::string substr = "PrivateUse1"; + if (substr.size() <= dispatch_key.size() && std::equal(substr.rbegin(), substr.rend(), dispatch_key.rbegin())) { + auto privateuse1_backend = get_privateuse1_backend(); + if (privateuse1_backend != "privateuseone") { + // remove trailing "*PrivateUse1" + dispatch_key.erase(dispatch_key.length() - substr.length()); + // append the registered backend's name. + // AutogradPrivateUse1 -> AutogradFoo + auto backend_name = c10::get_privateuse1_backend(); + dispatch_key = dispatch_key + backend_name; + } + } + return dispatch_key; +} + void OperatorEntry::reportError(DispatchKey dispatchKey) const { // If there is an invariant problem, report it now. checkInvariants(); @@ -509,7 +539,7 @@ void OperatorEntry::reportError(DispatchKey dispatchKey) const { } TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run '", name_, "' with arguments", - " from the '", toString(dispatchKey), "' backend. This could be because " + " from the '", post_process_dispatch_key_str(toString(dispatchKey)), "' backend. This could be because " "the operator doesn't exist for this backend, or was omitted during ", "the selective/custom build process (if using custom build). If you are a ", "Facebook employee using PyTorch on mobile, please visit ", diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index dc5860ebf2c4e..2abc6217516de 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -50,8 +50,11 @@ namespace c10 { _(prim, FunctionalGraph) \ _(prim, add_optional) \ _(prim, view_copy) \ + _(prim, permute_copy) \ _(prim, reshape_copy) \ _(prim, squeeze_copy) \ + _(prim, t_copy) \ + _(prim, transpose_copy) \ _(prim, unsqueeze_copy) \ _(prim, flatten_copy) \ _(prim, expand_copy) \ @@ -236,6 +239,7 @@ namespace c10 { _(onnx, LSTM) \ _(onnx, MatMul) \ _(onnx, Min) \ + _(onnx, Max) \ _(onnx, Mul) \ _(onnx, Pow) \ _(onnx, RNN) \ diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 122afcba4d843..3461fe2300e45 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -425,6 +425,7 @@ struct TORCH_API IValue final { at::Tensor& toTensor() &; const at::Tensor& toTensor() const&; at::TensorImpl* unsafeToTensorImpl() const { + TORCH_INTERNAL_ASSERT(isTensor()); return payload.as_tensor.unsafeGetTensorImpl(); } @@ -562,7 +563,7 @@ struct TORCH_API IValue final { IValue(c10::SymInt i) { if (i.is_symbolic()) { tag = Tag::SymInt; - payload.u.as_intrusive_ptr = i.toSymIntNodeImpl().release(); + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); } else { tag = Tag::Int; payload.u.as_int = i.as_int_unchecked(); @@ -578,7 +579,7 @@ struct TORCH_API IValue final { IValue(c10::SymFloat i) { if (i.is_symbolic()) { tag = Tag::SymFloat; - payload.u.as_intrusive_ptr = i.toSymFloatNodeImpl().release(); + payload.u.as_intrusive_ptr = i.toSymNodeImpl().release(); } else { tag = Tag::Double; payload.u.as_double = i.as_float_unchecked(); @@ -812,10 +813,10 @@ struct TORCH_API IValue final { // for both SymFloat and double if (s.isSymInt()) { tag = Tag::SymInt; - payload.u.as_intrusive_ptr = s.toSymInt().toSymIntNodeImpl().release(); + payload.u.as_intrusive_ptr = s.toSymInt().toSymNodeImpl().release(); } else if (s.isSymFloat()) { tag = Tag::SymFloat; - payload.u.as_intrusive_ptr = s.toSymFloat().toSymFloatNodeImpl().release(); + payload.u.as_intrusive_ptr = s.toSymFloat().toSymNodeImpl().release(); } else if (s.isFloatingPoint()) { tag = Tag::Double; payload.u.as_double = s.toDouble(); diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 1c3453abb4c88..bea795c8d81e8 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -219,7 +219,7 @@ inline at::Generator IValue::toGenerator() const& { inline c10::SymInt IValue::toSymInt() const { AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind()); if (isSymInt()) { - return c10::SymInt::toSymInt(toIntrusivePtr()); + return c10::SymInt(toIntrusivePtr()); } else { return c10::SymInt(payload.u.as_int); } @@ -228,7 +228,7 @@ inline c10::SymInt IValue::toSymInt() const { inline c10::SymFloat IValue::toSymFloat() const { AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind()); if (isSymFloat()) { - return c10::SymFloat::toSymFloat(toIntrusivePtr()); + return c10::SymFloat(toIntrusivePtr()); } else { return c10::SymFloat(payload.u.as_double); } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index e554bd586272f..0a8f5e14d9a5d 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1310,7 +1310,6 @@ struct TORCH_API SymIntType : public Type { return "SymInt"; } std::string annotation_str_impl(TypePrinter printer = nullptr) const override { - // TODO: will become a Union[SymIntNodeImpl|int] in the near future return "int"; } static const TypeKind Kind = TypeKind::SymIntType; diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp index 5c9cea05ea76b..965d3f243d01c 100644 --- a/aten/src/ATen/core/library.cpp +++ b/aten/src/ATen/core/library.cpp @@ -89,7 +89,7 @@ Library::Library(Kind kind, std::string ns, c10::optional k, c // merge everything #define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): " -Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector& tags) & { +Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name, const std::vector& tags, _RegisterOrVerify rv) & { TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT, DEF_PRELUDE, "Cannot define an operator inside of a ", toString(kind_), " block. " @@ -125,13 +125,20 @@ Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name if (out_name) { *out_name = schema.operator_name(); // copy! } - registrars_.emplace_back( - c10::Dispatcher::singleton().registerDef( - std::move(schema), - debugString(file_, line_), - tags - ) - ); + switch (rv) { + case _RegisterOrVerify::REGISTER: + registrars_.emplace_back( + c10::Dispatcher::singleton().registerDef( + std::move(schema), + debugString(file_, line_), + tags + ) + ); + break; + case _RegisterOrVerify::VERIFY: + c10::Dispatcher::singleton().waitForDef(schema); + break; + } return *this; } #undef DEF_PRELUDE @@ -174,11 +181,10 @@ Library& Library::_def(c10::either&& nam } #define IMPL_PRELUDE "impl(\"", name_str, "\", ...): " -Library& Library::_impl(const char* name_str, CppFunction&& f) & { +at::OperatorName Library::_parseNameForLib(const char* name_str) const { auto name = torch::jit::parseName(name_str); auto ns_opt = name.getNamespace(); - // This is kind of similar to the checking in def(), but the error - // messages are a little different for this call site + // This is a copy paste of Library::_impl if (ns_opt.has_value()) { // See Note [Redundancy in registration code is OK] TORCH_CHECK(*ns_opt == *ns_, @@ -193,6 +199,11 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & { bool b = name.setNamespaceIfNotSet(ns_->c_str()); TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT); } + return name; +} + +Library& Library::_impl(const char* name_str, CppFunction&& f, _RegisterOrVerify rv) & { + at::OperatorName name = _parseNameForLib(name_str); // See Note [Redundancy in registration code is OK] TORCH_CHECK(!(f.dispatch_key_.has_value() && dispatch_key_.has_value() && @@ -205,19 +216,30 @@ Library& Library::_impl(const char* name_str, CppFunction&& f) & { ERROR_CONTEXT ); auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_; - registrars_.emplace_back( - c10::Dispatcher::singleton().registerImpl( - std::move(name), - dispatch_key, - std::move(f.func_), - // NOLINTNEXTLINE(performance-move-const-arg) - std::move(f.cpp_signature_), - std::move(f.schema_), - debugString(std::move(f.debug_), file_, line_) - ) - ); + switch (rv) { + case _RegisterOrVerify::REGISTER: + registrars_.emplace_back( + c10::Dispatcher::singleton().registerImpl( + std::move(name), + dispatch_key, + std::move(f.func_), + // NOLINTNEXTLINE(performance-move-const-arg) + std::move(f.cpp_signature_), + std::move(f.schema_), + debugString(std::move(f.debug_), file_, line_) + ) + ); + break; + case _RegisterOrVerify::VERIFY: + c10::Dispatcher::singleton().waitForImpl(name, dispatch_key); + break; + } return *this; } + +c10::OperatorName Library::_resolve(const char* name_str) const { + return _parseNameForLib(name_str); +} #undef IMPL_PRELUDE Library& Library::_fallback(CppFunction&& f) & { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 98ec588137ce3..f9c8794560be7 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -222,6 +222,59 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_ps(v, mask_float); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_epi32(v, mask_int32); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask = _mm256_set_epi8( + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14, + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 + ); + auto reversed = _mm256_shuffle_epi8(v, mask); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +inline __m256i flip8(const __m256i & v) { + const __m256i mask_int8 = _mm256_set_epi8( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + auto reversed = _mm256_shuffle_epi8(v, mask_int8); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return flip8(v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return flip8(v); +} + #endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h index 487233bc3c407..2614e5f85e24d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -185,7 +185,7 @@ template <> class Vectorized> { return _mm256_div_pd(log(), log10_); } Vectorized> log1p() const { - AT_ERROR("not supported for complex numbers"); + return map(std::log1p); } Vectorized> asin() const { // asin(x) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h index 4093022a7e349..4a8f30f0c6ccc 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -221,7 +221,7 @@ template <> class Vectorized> { return _mm256_div_ps(log(), log10_); } Vectorized> log1p() const { - AT_ERROR("not supported for complex numbers"); + return map(std::log1p); } Vectorized> asin() const { // asin(x) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 0cc36d590019d..81e9d687d10a7 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -133,7 +133,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { return _mm256_cmpeq_epi64(values, other.values); @@ -253,7 +252,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { return _mm256_cmpeq_epi32(values, other.values); @@ -467,7 +465,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { return _mm256_cmpeq_epi16(values, other.values); @@ -496,34 +493,37 @@ class Vectorized : public Vectorizedi { Vectorized le(const Vectorized& other) const; }; -template <> -class Vectorized : public Vectorizedi { -private: - static const Vectorized ones; +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same::value || std::is_same::value, + "Only int8_t/uint8_t are supported"); +protected: + static const Vectorized ones; public: - using value_type = int8_t; + using value_type = T; static constexpr int size() { return 32; } using Vectorizedi::Vectorizedi; - Vectorized() {} - Vectorized(int8_t v) { values = _mm256_set1_epi8(v); } - Vectorized(int8_t val1, int8_t val2, int8_t val3, int8_t val4, - int8_t val5, int8_t val6, int8_t val7, int8_t val8, - int8_t val9, int8_t val10, int8_t val11, int8_t val12, - int8_t val13, int8_t val14, int8_t val15, int8_t val16, - int8_t val17, int8_t val18, int8_t val19, int8_t val20, - int8_t val21, int8_t val22, int8_t val23, int8_t val24, - int8_t val25, int8_t val26, int8_t val27, int8_t val28, - int8_t val29, int8_t val30, int8_t val31, int8_t val32) { + Vectorized8() {} + Vectorized8(T v) { values = _mm256_set1_epi8(v); } + Vectorized8(T val1, T val2, T val3, T val4, + T val5, T val6, T val7, T val8, + T val9, T val10, T val11, T val12, + T val13, T val14, T val15, T val16, + T val17, T val18, T val19, T val20, + T val21, T val22, T val23, T val24, + T val25, T val26, T val27, T val28, + T val29, T val30, T val31, T val32) { values = _mm256_setr_epi8(val1, val2, val3, val4, val5, val6, val7, val8, val9, val10, val11, val12, val13, val14, val15, val16, val17, val18, val19, val20, val21, val22, val23, val24, val25, val26, val27, val28, val29, val30, val31, val32); } template - static Vectorized blend(Vectorized a, Vectorized b) { - __at_align__ int8_t tmp_values[size()]; + static Vectorized blend(Vectorized a, Vectorized b) { + __at_align__ T tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi8(b.values, 0); @@ -591,13 +591,13 @@ class Vectorized : public Vectorizedi { tmp_values[31] = _mm256_extract_epi8(b.values, 31); return loadu(tmp_values); } - static Vectorized blendv(const Vectorized& a, const Vectorized& b, - const Vectorized& mask) { + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { return _mm256_blendv_epi8(a.values, b.values, mask.values); } template - static Vectorized arange(int8_t base = 0, step_t step = static_cast(1)) { - return Vectorized( + static Vectorized arange(T base = 0, step_t step = static_cast(1)) { + return Vectorized( base, base + step, base + 2 * step, base + 3 * step, base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, @@ -607,8 +607,8 @@ class Vectorized : public Vectorizedi { base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step); } - static Vectorized - set(Vectorized a, Vectorized b, int8_t count = size()) { + static Vectorized + set(Vectorized a, Vectorized b, T count = size()) { switch (count) { case 0: return a; @@ -677,18 +677,18 @@ class Vectorized : public Vectorizedi { } return b; } - static Vectorized loadu(const void* ptr) { + static Vectorized loadu(const void* ptr) { return _mm256_loadu_si256(reinterpret_cast(ptr)); } - static Vectorized loadu(const void* ptr, int8_t count) { - __at_align__ int8_t tmp_values[size()]; + static Vectorized loadu(const void* ptr, T count) { + __at_align__ T tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { tmp_values[i] = 0; } - std::memcpy(tmp_values, ptr, count * sizeof(int8_t)); + std::memcpy(tmp_values, ptr, count * sizeof(T)); return loadu(tmp_values); } void store(void* ptr, int count = size()) const { @@ -697,27 +697,35 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align__ int8_t tmp_values[size()]; + __at_align__ T tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); - std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); + std::memcpy(ptr, tmp_values, count * sizeof(T)); } } - const int8_t& operator[](int idx) const = delete; - int8_t& operator[](int idx) = delete; - Vectorized abs() const { - return _mm256_abs_epi8(values); - } - Vectorized real() const { + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { return *this; } - Vectorized imag() const { + Vectorized imag() const { return _mm256_set1_epi8(0); } - Vectorized conj() const { + Vectorized conj() const { return *this; } - Vectorized frac() const; +}; + +template<> +class Vectorized: public Vectorized8 { +public: + using Vectorized8::Vectorized8; + Vectorized neg() const; + + Vectorized abs() const { + return _mm256_abs_epi8(values); + } + Vectorized operator==(const Vectorized& other) const { return _mm256_cmpeq_epi8(values, other.values); } @@ -731,10 +739,10 @@ class Vectorized : public Vectorizedi { return invert(_mm256_cmpgt_epi8(values, other.values)); } Vectorized operator>(const Vectorized& other) const { - return _mm256_cmpgt_epi8(values, other.values); + return other < *this; } Vectorized operator>=(const Vectorized& other) const { - return invert(_mm256_cmpgt_epi8(other.values, values)); + return other <= *this; } Vectorized eq(const Vectorized& other) const; @@ -745,6 +753,46 @@ class Vectorized : public Vectorizedi { Vectorized le(const Vectorized& other) const; }; +template<> +class Vectorized: public Vectorized8 { +public: + using Vectorized8::Vectorized8; + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + return _mm256_cmpeq_epi8(values, other.values); + } + Vectorized operator!=(const Vectorized& other) const { + return invert(_mm256_cmpeq_epi8(values, other.values)); + } + Vectorized operator<(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return invert(_mm256_cmpeq_epi8(max, values)); + } + Vectorized operator<=(const Vectorized& other) const { + __m256i max = _mm256_max_epu8(values, other.values); + return _mm256_cmpeq_epi8(max, other.values); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + template <> Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { return _mm256_add_epi64(a, b); @@ -765,6 +813,11 @@ Vectorized inline operator+(const Vectorized& a, const Vectorize return _mm256_add_epi8(a, b); } +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm256_add_epi8(a, b); +} + template <> Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { return _mm256_sub_epi64(a, b); @@ -785,6 +838,11 @@ Vectorized inline operator-(const Vectorized& a, const Vectorize return _mm256_sub_epi8(a, b); } +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm256_sub_epi8(a, b); +} + // Negation. Defined here so we can utilize operator- inline Vectorized Vectorized::neg() const { return Vectorized(0) - *this; @@ -802,6 +860,10 @@ inline Vectorized Vectorized::neg() const { return Vectorized(0) - *this; } +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + // Emulate operations with no native 64-bit support in avx, // by extracting each element, performing the operation pointwise, // then combining the results into a vector. @@ -888,6 +950,12 @@ Vectorized inline operator*(const Vectorized& a, const Vectorize return int_elementwise_binary_256(a, b, std::multiplies()); } +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t + return int_elementwise_binary_256(a, b, std::multiplies()); +} + template <> Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);}); @@ -908,6 +976,11 @@ Vectorized inline minimum(const Vectorized& a, const Vectorized< return _mm256_min_epi8(a, b); } +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm256_min_epu8(a, b); +} + template <> Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);}); @@ -928,6 +1001,11 @@ Vectorized inline maximum(const Vectorized& a, const Vectorized< return _mm256_max_epi8(a, b); } +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm256_max_epu8(a, b); +} + template <> Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { return emulate(a, min_val, max_val, [](int64_t a_point, int64_t min_point, int64_t max_point) {return std::min(max_point, std::max(a_point, min_point));}); @@ -948,6 +1026,11 @@ Vectorized inline clamp(const Vectorized& a, const Vectorized +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm256_min_epu8(max_val, _mm256_max_epu8(a, min_val)); +} + template <> Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { return emulate(a, max_val, [](int64_t a_point, int64_t max_point) {return std::min(max_point, a_point);}); @@ -968,6 +1051,11 @@ Vectorized inline clamp_max(const Vectorized& a, const Vectorize return _mm256_min_epi8(max_val, a); } +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm256_min_epu8(max_val, a); +} + template <> Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { return emulate(a, min_val, [](int64_t a_point, int64_t min_point) {return std::max(min_point, a_point);}); @@ -988,6 +1076,11 @@ Vectorized inline clamp_min(const Vectorized& a, const Vectorize return _mm256_max_epi8(min_val, a); } +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm256_max_epu8(min_val, a); +} + template Vectorized inline convert_to_int32(const T* ptr) { return Vectorized::loadu(ptr); @@ -1019,6 +1112,10 @@ template <> Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { return int_elementwise_binary_256(a, b, std::divides()); } +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_256(a, b, std::divides()); +} template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { @@ -1133,6 +1230,292 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +inline Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template +Vectorized inline shift_256_16(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80, + 21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80, + 13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80, + 5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26, + 0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18, + 0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10, + 0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template ::value || std::is_same::value, int> = 0> +Vectorized inline shift_256_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifially, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. + __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80, + 20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80, + 12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80, + 4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25, + 0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17, + 0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9, + 0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1); + __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80, + 21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80, + 13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80, + 5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80); + __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26, + 0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18, + 0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10, + 0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2); + __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80, + 22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80, + 14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80, + 6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80); + __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, + 0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, + 0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, + 0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3); + __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, + 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, + 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, + 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80); + __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80, + 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80, + 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80, + 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + else + if (std::is_same::value) + c0 = _mm256_srav_epi32(a0, b0); + else + c0 = _mm256_srlv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Peform shifting the same way for input array elements with + // idx%4==1. + __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + else + if (std::is_same::value) + c1 = _mm256_srav_epi32(a1, b1); + else + c1 = _mm256_srlv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Peform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + else + if (std::is_same::value) + c2 = _mm256_srav_epi32(a2, b2); + else + c2 = _mm256_srlv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Peform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + else + if (std::is_same::value) + c3 = _mm256_srav_epi32(a3, b3); + else + c3 = _mm256_srlv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. + __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + // No vector instruction for right shifting int64_t, so emulating it + // instead. + + // Shift the number logically to the right, thus filling the most + // significant bits with 0s. Then, replace these bits with the sign + // bit. + __m256i sign_bits = _mm256_cmpgt_epi64(_mm256_set1_epi64x(0), a); + __m256i b_inv_mod_64 = _mm256_sub_epi64(_mm256_set1_epi64x(64), b); + __m256i sign_ext = _mm256_sllv_epi64(sign_bits, b_inv_mod_64); + __m256i c = _mm256_srlv_epi64(a, b); + c = _mm256_or_si256(c, sign_ext); + + return c; +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm256_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h index cb8bb78597854..dfa4a852f4d84 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -142,7 +142,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -284,6 +284,10 @@ class Vectorized { return ret.elwise_mult(vd_log10e_inv); } + Vectorized log1p() const { + return map(std::log1p); + } + Vectorized asin() const { // asin(x) // = -i*ln(iz + sqrt(1 -z^2)) @@ -481,10 +485,6 @@ class Vectorized { TORCH_CHECK(false, "not supported for complex numbers"); } - Vectorized log1p() const { - TORCH_CHECK(false, "not supported for complex numbers"); - } - Vectorized atan2(const Vectorized& b) const { TORCH_CHECK(false, "not supported for complex numbers"); } diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h index 8445a31fb3d60..56a6f4e6e39a6 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -196,7 +196,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -321,6 +321,10 @@ class Vectorized { return ret.elwise_mult(log10e_inv); } + Vectorized log1p() const { + return map(std::log1p); + } + Vectorized el_swapped() const { vfloat32 v0 = vec_perm(_vec0, _vec0, swap_mask); vfloat32 v1 = vec_perm(_vec1, _vec1, swap_mask); @@ -568,10 +572,6 @@ class Vectorized { TORCH_CHECK(false,"not supported for complex numbers"); } - Vectorized log1p() const { - TORCH_CHECK(false,"not supported for complex numbers"); - } - Vectorized expm1() const { TORCH_CHECK(false,"not supported for complex numbers"); } diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index c53b7c792e471..810e79ebfe83d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -171,7 +171,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 77cf3695ab912..ac09531c4d2fa 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -180,7 +180,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -256,29 +256,29 @@ class Vectorized { } Vectorized C10_ALWAYS_INLINE acos() const { - return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; + return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE asin() const { - return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; + return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; } Vectorized atan() const { - return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; + return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; } Vectorized atan2(const Vectorized& b) const { - return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; + return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; } Vectorized copysign(const Vectorized &sign) const { return {Sleef_copysignf4_vsx(_vec0, sign._vec0), Sleef_copysignf4_vsx(_vec1, sign._vec1)}; } Vectorized lgamma() const { - return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; + return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; } Vectorized erf() const { - return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; + return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; } Vectorized erfc() const { - return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; + return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; } Vectorized erfinv() const { @@ -301,133 +301,32 @@ class Vectorized { } Vectorized C10_ALWAYS_INLINE exp() const { - // implementation logic from avx_mathfun with some modifications from sleef - // Express e**x = e**g 2**n - /// = e**g e**( n loge(2) ) - /// = e**( g + n loge(2) ) - // - auto tmp_x = *this; - auto fx = (tmp_x * log2e_inv).round(); - - auto x = fx.madd(negln2f_hi, tmp_x); - x = fx.madd(negln2f_lo, x); - auto z = x * x; - auto y = x.madd(exp_p0, exp_p1); - y = y.madd(x, exp_p2); - y = y.madd(x, exp_p3); - y = y.madd(x, exp_p4); - y = y.madd(x, exp_p5); - y = y.madd(z, x) + one; - - // vm_pow2n 2^n - vint32 imm0 = vec_signed(fx._vec0); - vint32 imm1 = vec_signed(fx._vec1); - // this pow2n logic is from Sleef code - vint32 imm00 = imm0 >> 1; //>>1 - vint32 imm01 = imm1 >> 1; - vint32 imm10 = imm0 - imm00; - vint32 imm11 = imm1 - imm01; - imm00 = (imm00 + v0x7f) << vu_23; - imm01 = (imm01 + v0x7f) << vu_23; - imm10 = (imm10 + v0x7f) << vu_23; - imm11 = (imm11 + v0x7f) << vu_23; - // treat imm as float vector without conversion - - y._vec0 = (y._vec0 * (vfloat32)imm00) * (vfloat32)imm10; - y._vec1 = (y._vec1 * (vfloat32)imm01) * (vfloat32)imm11; - // boundary check - auto tmp = blendv(y, v_inf, (Vectorized(exp_hi) <= tmp_x)); - y = blendv(tmp, zero, (tmp_x < Vectorized(exp_lo))); - - return y; + return {Sleef_expf4_u10vsx(_vec0), Sleef_expf4_u10vsx(_vec1)}; } Vectorized expm1() const { - return exp() - one; + return {Sleef_expm1f4_u10vsx(_vec0), Sleef_expm1f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log() const { return {Sleef_logf4_u10vsx(_vec0), Sleef_logf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log10() const { - return {Sleef_log10f4_u10vsx(_vec0), Sleef_log10f4_u10vsx(_vec1)}; + return {Sleef_log10f4_u10vsx(_vec0), Sleef_log10f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log1p() const { - return {Sleef_log1pf4_u10vsx(_vec0), Sleef_log1pf4_u10vsx(_vec1)}; + return {Sleef_log1pf4_u10vsx(_vec0), Sleef_log1pf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log2() const { - return {Sleef_log2f4_u10vsx(_vec0), Sleef_log2f4_u10vsx(_vec1)}; + return {Sleef_log2f4_u10vsx(_vec0), Sleef_log2f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE ceil() const { return {vec_ceil(_vec0), vec_ceil(_vec1)}; } Vectorized C10_ALWAYS_INLINE cos() const { - // take the absolute value - auto x = abs(); - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - // scale by 4/Pi - auto y = x * _4div_pi; - // store the integer part of y in mm0 - // j=(j+1) & (~1) (see the cephes sources) - vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; - vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; - y._vec0 = vec_float(imm0); - y._vec1 = vec_float(imm1); - - imm0 = imm0 - vi_2; - imm1 = imm1 - vi_2; - Vectorized poly_mask; - // get the swap sign flag - vint32 tmp0 = vec_and(vec_nand(imm0, imm0), vi_4); - vint32 tmp1 = vec_and(vec_nand(imm1, imm1), vi_4); - sign_bit._vecb0 = (vbool32)vec_sl(tmp0, vu_29); - sign_bit._vecb1 = (vbool32)vec_sl(tmp1, vu_29); - // get the polynom selection mask - // there is one polynom for 0 <= x <= Pi / 4 - // and another one for Pi / 4 < x <= Pi / 2 - // Both branches will be computed. - - poly_mask._vecb0 = (vbool32)vec_cmpeq((imm0 & vi_2), vi_0); - poly_mask._vecb1 = (vbool32)vec_cmpeq((imm1 & vi_2), vi_0); - - // The magic pass: "Extended precision modular arithmetic" - // x = ((x - y * DP1) - y * DP2) - y * DP3; - x = y.madd(minus_cephes_dp1, x); - x = y.madd(minus_cephes_dp2, x); - x = y.madd(minus_cephes_dp3, x); - - // Evaluate the first polynom (0 <= x <= Pi/4) - auto z = x * x; - y = z.madd(coscof_p0, coscof_p1); - y = y.madd(z, coscof_p2); - y = y * z * z; - y = y - z * half + one; - - // Evaluate the second polynom (Pi/4 <= x <= 0) - auto y_2 = z.madd(sincof_p0, sincof_p1); - y_2 = y_2.madd(z, sincof_p2); - y_2 = y_2 * z; - y_2 = y_2.madd(x, x); - - // select the correct result from the two polynoms - y = blendv(y, y_2, poly_mask); - // update the sign - y = y ^ sign_bit; - - return y; + return {Sleef_cosf4_u10vsx(_vec0), Sleef_cosf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE cosh() const { - // cosh = 1/2 * (e^x + e^-x) - auto x = abs(); - auto e_x = x.exp(); - auto ret = (e_x + Vectorized(one) / e_x) * half; - // inf and nan checks -#if 0 - ret = blendv(ret, v_inf, x >= vf_89); - ret = blendv(ret, v_inf, ret.isnan()); - ret = blendv(ret, v_nan, this->isnan()); -#endif - return ret; + return {Sleef_coshf4_u10vsx(_vec0), Sleef_coshf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE floor() const { return {vec_floor(_vec0), vec_floor(_vec1)}; @@ -440,97 +339,16 @@ class Vectorized { return {vec_round(_vec0), vec_round(_vec1)}; } Vectorized C10_ALWAYS_INLINE sin() const { - // take the absolute value and xtract sign - auto x = abs(); - auto sign_bit = (*this) & sign_mask; - - // scale by 4/Pi - auto y = x * _4div_pi; - // store the integer part of y in mm0 - - // j=(j+1) & (~1) (see the cephes sources) - vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; - vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; - y._vec0 = vec_float(imm0); - y._vec1 = vec_float(imm1); - // get the swap sign flag - Vectorized swap_sign_bit, poly_mask; - swap_sign_bit._vecb0 = (vbool32)vec_sl(imm0 & vi_4, vu_29); - swap_sign_bit._vecb1 = (vbool32)vec_sl(imm1 & vi_4, vu_29); - // get the polynom selection mask - // there is one polynom for 0 <= x <= Pi/4 - // and another one for Pi/4 C10_ALWAYS_INLINE sinh() const { - auto temp_abs = abs(); - // get exponent - auto ret = temp_abs.exp(); - auto recp = Vectorized(half) / ret; - auto v = ret * half - recp; - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - auto z = temp_abs * temp_abs; - auto y = z.madd(p0, p1); - y = y.madd(z, p2); - y = (y * z).madd(temp_abs, temp_abs); - // check and select - auto result = blendv(y, v, temp_abs > one); - return result | sign_bit; + return {Sleef_sinhf4_u10vsx(_vec0), Sleef_sinhf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE tan() const { - return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; + return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE tanh() const { - auto x = *this; - auto vabs = abs(); - // get exponent - auto exp2x = (vabs + vabs).exp(); - auto vv = Vectorized(one) - Vectorized(two) / (exp2x + one); - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - auto z = vabs * vabs; - auto y = z.madd(tanh_p0, tanh_p1); - auto tmp = y.madd(z, tanh_p2); - y = z.madd(tmp, tanh_p3); - tmp = y.madd(z, tanh_p4); - y = tmp * z; - tmp = y.madd(x, x); - // add sign - vv = vv | sign_bit; - // check and select - auto sel_mask = vabs >= tanh_0p625; - auto max_mask = vabs > tanh_half_max; - auto max_ret = sign_bit ^ one; - return blendv(blendv(tmp, vv, sel_mask), max_ret, max_mask); + return {Sleef_tanhf4_u10vsx(_vec0), Sleef_tanhf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE trunc() const { return {vec_trunc(_vec0), vec_trunc(_vec1)}; @@ -555,15 +373,15 @@ class Vectorized { } Vectorized fmod(const Vectorized& b) const { - return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; + return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; } Vectorized hypot(const Vectorized& b) const { - return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; + return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; } Vectorized nextafter(const Vectorized& b) const { - return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; + return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; } Vectorized igamma(const Vectorized& x) const { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 464a13c9f5f77..7c300c8087cff 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -269,7 +269,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index 6ef6147447d54..c98ab6215e620 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -199,7 +199,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index c0f1146d9d357..a4171026a2b99 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -148,7 +148,7 @@ class Vectorized { (vint64)vec_vsx_ld(offset16, dptr)}; } - __at_align__ double tmp_values[size()]; + __at_align__ double tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index c3cec14a5b13e..a85730c9a6df8 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -81,7 +81,7 @@ struct Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()] = {}; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h index 0c6f33fa08a06..8656756aaed56 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -190,6 +190,65 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_ps(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_pd(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_epi64(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_epi32(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi16( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ); + return _mm512_permutexvar_epi16(mask, v); +} + +inline __m512i flip8(const __m512i & v) { + const __m512i mask1 = _mm512_set_epi8( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); + auto reversed_vec = _mm512_shuffle_epi8(v, mask1); + return _mm512_permutexvar_epi64(mask2, reversed_vec); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return flip8(v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return flip8(v); +} + #endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h index 9d862534a9d67..cb73beaaedd60 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -248,7 +248,7 @@ template <> class Vectorized> { return _mm512_div_pd(log(), log10_); } Vectorized> log1p() const { - AT_ERROR("not supported for complex numbers"); + return map(std::log1p); } Vectorized> asin() const { // asin(x) diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h index 966f42a253484..03b75ed035131 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -753,7 +753,7 @@ template <> class Vectorized> { return _mm512_div_ps(log(), log10_); } Vectorized> log1p() const { - AT_ERROR("not supported for complex numbers"); + return map(std::log1p); } Vectorized> asin() const { // asin(x) diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index c2cbc0b1d7f94..73aae89d51be3 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -135,7 +135,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { auto mask = _mm512_cmpeq_epi64_mask(values, other.values); @@ -285,7 +284,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { auto mask = _mm512_cmpeq_epi32_mask(values, other.values); @@ -517,7 +515,6 @@ class Vectorized : public Vectorizedi { Vectorized conj() const { return *this; } - Vectorized frac() const; Vectorized neg() const; Vectorized operator==(const Vectorized& other) const { auto mask = _mm512_cmpeq_epi16_mask(values, other.values); @@ -552,35 +549,38 @@ class Vectorized : public Vectorizedi { Vectorized le(const Vectorized& other) const; }; -template <> -class Vectorized : public Vectorizedi { -private: +template +class Vectorized8 : public Vectorizedi { + static_assert( + std::is_same::value || std::is_same::value, + "Only int8_t/uint8_t are supported"); +protected: static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; - static const Vectorized ones; + static const Vectorized ones; public: - using value_type = int8_t; + using value_type = T; static constexpr int size() { return 64; } using Vectorizedi::Vectorizedi; - Vectorized() {} - Vectorized(int8_t v) { values = _mm512_set1_epi8(v); } - Vectorized(int8_t val1, int8_t val2, int8_t val3, int8_t val4, - int8_t val5, int8_t val6, int8_t val7, int8_t val8, - int8_t val9, int8_t val10, int8_t val11, int8_t val12, - int8_t val13, int8_t val14, int8_t val15, int8_t val16, - int8_t val17, int8_t val18, int8_t val19, int8_t val20, - int8_t val21, int8_t val22, int8_t val23, int8_t val24, - int8_t val25, int8_t val26, int8_t val27, int8_t val28, - int8_t val29, int8_t val30, int8_t val31, int8_t val32, - int8_t val33, int8_t val34, int8_t val35, int8_t val36, - int8_t val37, int8_t val38, int8_t val39, int8_t val40, - int8_t val41, int8_t val42, int8_t val43, int8_t val44, - int8_t val45, int8_t val46, int8_t val47, int8_t val48, - int8_t val49, int8_t val50, int8_t val51, int8_t val52, - int8_t val53, int8_t val54, int8_t val55, int8_t val56, - int8_t val57, int8_t val58, int8_t val59, int8_t val60, - int8_t val61, int8_t val62, int8_t val63, int8_t val64){ + Vectorized8() {} + Vectorized8(T v) { values = _mm512_set1_epi8(v); } + Vectorized8(T val1, T val2, T val3, T val4, + T val5, T val6, T val7, T val8, + T val9, T val10, T val11, T val12, + T val13, T val14, T val15, T val16, + T val17, T val18, T val19, T val20, + T val21, T val22, T val23, T val24, + T val25, T val26, T val27, T val28, + T val29, T val30, T val31, T val32, + T val33, T val34, T val35, T val36, + T val37, T val38, T val39, T val40, + T val41, T val42, T val43, T val44, + T val45, T val46, T val47, T val48, + T val49, T val50, T val51, T val52, + T val53, T val54, T val55, T val56, + T val57, T val58, T val59, T val60, + T val61, T val62, T val63, T val64){ values = _mm512_set_epi8(val64, val63, val62, val61, val60, val59, val58, val57, val56, val55, val54, val53,val52, val51, val50, val49, val48, val47, val46, val45, val44, val43, val42, val41, @@ -591,18 +591,12 @@ class Vectorized : public Vectorizedi { val8, val7, val6, val5, val4, val3, val2, val1); } template - static Vectorized blend(Vectorized a, Vectorized b) { + static Vectorized blend(Vectorized a, Vectorized b) { return _mm512_mask_blend_epi8(mask, a.values, b.values); } - static Vectorized blendv(const Vectorized& a, const Vectorized& b, - const Vectorized& mask) { - auto msb_one = _mm512_set1_epi8(0xFF); - auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); - return _mm512_mask_blend_epi8(mask_, a.values, b.values); - } template - static Vectorized arange(int8_t base = 0, step_t step = static_cast(1)) { - return Vectorized( + static Vectorized arange(T base = 0, step_t step = static_cast(1)) { + return Vectorized( base, base + step, base + 2 * step, base + 3 * step, base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, @@ -620,8 +614,8 @@ class Vectorized : public Vectorizedi { base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step, base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step); } - static Vectorized - set(Vectorized a, Vectorized b, int8_t count = size()) { + static Vectorized + set(Vectorized a, Vectorized b, T count = size()) { switch (count) { case 0: return a; @@ -754,18 +748,18 @@ class Vectorized : public Vectorizedi { } return b; } - static Vectorized loadu(const void* ptr) { + static Vectorized loadu(const void* ptr) { return _mm512_loadu_si512(reinterpret_cast(ptr)); } - static Vectorized loadu(const void* ptr, int8_t count) { - __at_align__ int8_t tmp_values[size()]; + static Vectorized loadu(const void* ptr, T count) { + __at_align__ T tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(size())) { tmp_values[i] = 0; } - std::memcpy(tmp_values, ptr, count * sizeof(int8_t)); + std::memcpy(tmp_values, ptr, count * sizeof(T)); return loadu(tmp_values); } void store(void* ptr, int count = size()) const { @@ -774,27 +768,42 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); } else if (count > 0) { - __at_align__ int8_t tmp_values[size()]; + __at_align__ T tmp_values[size()]; _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); - std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); + std::memcpy(ptr, tmp_values, count * sizeof(T)); } } - const int8_t& operator[](int idx) const = delete; - int8_t& operator[](int idx) = delete; - Vectorized abs() const { - return _mm512_abs_epi8(values); - } - Vectorized real() const { + const T& operator[](int idx) const = delete; + T& operator[](int idx) = delete; + Vectorized real() const { return *this; } - Vectorized imag() const { + Vectorized imag() const { return _mm512_set1_epi8(0); } - Vectorized conj() const { + Vectorized conj() const { return *this; } - Vectorized frac() const; +}; + +template<> +class Vectorized: public Vectorized8 { +public: + using Vectorized8::Vectorized8; + + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + Vectorized neg() const; + + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + Vectorized operator==(const Vectorized& other) const { auto mask = _mm512_cmpeq_epi8_mask(values, other.values); return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); @@ -812,12 +821,10 @@ class Vectorized : public Vectorizedi { return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); } Vectorized operator>(const Vectorized& other) const { - auto mask = _mm512_cmpgt_epi8_mask(values, other.values); - return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + return other < *this; } Vectorized operator>=(const Vectorized& other) const { - auto mask = _mm512_cmpge_epi8_mask(values, other.values); - return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + return other <= *this; } Vectorized eq(const Vectorized& other) const; @@ -828,6 +835,55 @@ class Vectorized : public Vectorizedi { Vectorized le(const Vectorized& other) const; }; +template<> +class Vectorized: public Vectorized8 { +public: + using Vectorized8::Vectorized8; + + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epu8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + + Vectorized neg() const; + + Vectorized abs() const { + return *this; + } + + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epu8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + return other < *this; + } + Vectorized operator>=(const Vectorized& other) const { + return other <= *this; + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + template <> Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { return _mm512_add_epi64(a, b); @@ -848,6 +904,11 @@ Vectorized inline operator+(const Vectorized& a, const Vectorize return _mm512_add_epi8(a, b); } +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + template <> Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { return _mm512_sub_epi64(a, b); @@ -868,6 +929,11 @@ Vectorized inline operator-(const Vectorized& a, const Vectorize return _mm512_sub_epi8(a, b); } +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + // Negation. Defined here so we can utilize operator- inline Vectorized Vectorized::neg() const { return Vectorized(0) - *this; @@ -885,6 +951,10 @@ inline Vectorized Vectorized::neg() const { return Vectorized(0) - *this; } +inline Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + template <> Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { return _mm512_mullo_epi64(a, b); @@ -918,6 +988,12 @@ Vectorized inline operator*(const Vectorized& a, const Vectorize return int_elementwise_binary_512(a, b, std::multiplies()); } +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + // We don't have an instruction for multiplying uint8_t + return int_elementwise_binary_512(a, b, std::multiplies()); +} + template <> Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { return _mm512_min_epi64(a, b); @@ -938,6 +1014,11 @@ Vectorized inline minimum(const Vectorized& a, const Vectorized< return _mm512_min_epi8(a, b); } +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epu8(a, b); +} + template <> Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { return _mm512_max_epi64(a, b); @@ -958,6 +1039,11 @@ Vectorized inline maximum(const Vectorized& a, const Vectorized< return _mm512_max_epi8(a, b); } +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + template <> Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); @@ -978,6 +1064,11 @@ Vectorized inline clamp(const Vectorized& a, const Vectorized +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epu8(max_val, _mm512_max_epu8(a, min_val)); +} + template <> Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { return _mm512_min_epi64(max_val, a); @@ -998,6 +1089,11 @@ Vectorized inline clamp_max(const Vectorized& a, const Vectorize return _mm512_min_epi8(max_val, a); } +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epu8(max_val, a); +} + template <> Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { return _mm512_max_epi64(min_val, a); @@ -1018,6 +1114,11 @@ Vectorized inline clamp_min(const Vectorized& a, const Vectorize return _mm512_max_epi8(min_val, a); } +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epu8(min_val, a); +} + template Vectorized inline convert_to_int32(const T* ptr) { return Vectorized::loadu(ptr); @@ -1049,6 +1150,10 @@ template <> Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { return int_elementwise_binary_512(a, b, std::divides()); } +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { @@ -1163,6 +1268,164 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +inline Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +inline Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +inline Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +inline Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +inline Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +inline Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +template ::value || std::is_same::value, int> = 0> +Vectorized inline shift_512_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t/uint8_t, so emulating + // it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m512i ctl_0_1 = _mm512_set_epi8(62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80, + 54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80, + 46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80, + 38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80, + 30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80, + 22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80, + 14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80, + 6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80); + __m512i ctl_1_0 = _mm512_set_epi8(0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57, + 0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49, + 0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41, + 0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33, + 0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25, + 0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17, + 0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9, + 0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + else + if (std::is_same::value) + c0 = _mm512_srav_epi16(a0, b0); + else + c0 = _mm512_srlv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + else + if (std::is_same::value) + c1 = _mm512_srav_epi16(a1, b1); + else + c1 = _mm512_srlv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. + __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi64(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi16(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index b9b3745e99d5f..abf106e8d5b36 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -799,6 +799,21 @@ inline Vectorized operator~(const Vectorized& a) { return a ^ ones; } +template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] << b[i]; + } + return c; +} + +template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] >> b[i]; + } + return c; +} template inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { @@ -826,6 +841,18 @@ inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { return a; } +template +inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + +template +inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { + a = a >> b; + return a; +} + template inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; @@ -988,4 +1015,16 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) { } } +template +inline Vectorized flip(const Vectorized & data) { + static constexpr int size = Vectorized::size(); + T output[size]; + T buffer[size]; + data.store(static_cast(buffer)); + for (const auto i : c10::irange(size)) { + output[i] = buffer[size - i - 1]; + } + return Vectorized::loadu(static_cast(output)); +} + }}} diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 42975411e841e..3d60b672e9725 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -6,6 +6,10 @@ #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + template struct AtomicFPOp; @@ -219,10 +223,15 @@ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) } static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { - return AtomicFPOp()(address, val, - [](at::BFloat16 bsum, at::BFloat16 val) { - return bsum + val; - }); +#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +#else + __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val)); + return *reinterpret_cast(&r); +#endif } #if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 866f53ee7f87f..648b55774f194 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -741,7 +741,7 @@ void gemm_and_bias( TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); } - TORCH_CUDABLAS_CHECK(cublasLtMatmul( + cublasStatus_t cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), &alpha_val, @@ -757,7 +757,33 @@ void gemm_and_bias( &heuristicResult.algo, workspace.data_ptr(), workspaceSize, - at::cuda::getCurrentCUDAStream())); + at::cuda::getCurrentCUDAStream()); + TORCH_CHECK( + cublasStatus == CUBLAS_STATUS_SUCCESS, + "CUDA error: ", + at::cuda::blas::_cublasGetErrorEnum(cublasStatus), + " when calling cublasLtMatmul with transpose_mat1 ", + transpose_mat1, + " transpose_mat2 ", + transpose_mat2, + " m ", + m, + " n ", + n, + " k ", + k, + " mat1_ld ", + mat1_ld, + " mat2_ld ", + mat2_ld, + " result_ld ", + result_ld, + " abcType ", + abcType, + " computeType ", + computeType, + " scaleType ", + scaleType); } template void gemm_and_bias( diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 24ee0b19ab90c..92eddeb4b755c 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -8,6 +8,8 @@ namespace at { namespace cuda { +static bool _cuda_graphs_debug = false; + MempoolId_t graph_pool_handle() { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 // uuid count starts at 1. 0 is reserved to mean "wasn't set by graph_pool_handle". @@ -16,7 +18,7 @@ MempoolId_t graph_pool_handle() { // cudaStreamGetCaptureInfo id_s in capture_begin. return {0, uuid++}; #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); return {0, 0}; #endif } @@ -46,7 +48,7 @@ CUDAGraph::CUDAGraph() // CUDAStreams may not be default-constructed. : capture_stream_(at::cuda::getCurrentCUDAStream()) { #if (defined(CUDA_VERSION) && CUDA_VERSION < 11000) || defined(USE_ROCM) - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif } @@ -122,7 +124,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/) { // kernel will end up as part of the capture or not. c10::cuda::CUDACachingAllocator::notifyCaptureBegin(capture_dev_, id_, mempool_id_); #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif } @@ -179,12 +181,24 @@ void CUDAGraph::capture_end() { "when capture began"); wholegraph_increment_ = gen->capture_epilogue(); - // Now that we've instantiated graph_ into graph_exec_, - // we don't need graph_ anymore. - AT_CUDA_CHECK(cudaGraphDestroy(graph_)); - has_graph_ = false; + size_t numCUDAGraphNodes = 0; + AT_CUDA_CHECK(cudaGraphGetNodes(graph_, NULL, &numCUDAGraphNodes)); + if (numCUDAGraphNodes == 0) { + TORCH_WARN("The CUDA Graph is empty. This ususally means that the graph was ", + "attempted to be captured on wrong device or stream."); + } + + // check if debug path is set + if (!_cuda_graphs_debug) { + // Now that we've instantiated graph_ into graph_exec_, + // we don't need graph_ anymore. + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + has_graph_ = false; + } else { + TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called."); + } #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif } @@ -219,7 +233,33 @@ void CUDAGraph::replay() { AT_CUDA_CHECK(cudaDeviceSynchronize()); } #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); +#endif +} + +void CUDAGraph::enable_debug_mode() { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + _cuda_graphs_debug = true; +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); +#endif + +} + +void CUDAGraph::debug_dump(const std::string& debug_path) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + if (_cuda_graphs_debug) { + TORCH_WARN("DEBUG: calling debug_dump()"); + if (has_graph_) { + TORCH_WARN("DEBUG: calling cudaGraphDebugDotPrint() with ", debug_path); + C10_CUDA_CHECK_WARN(cudaGraphDebugDotPrint(graph_, debug_path.c_str(), 1<<10)); // most verbose output + AT_CUDA_CHECK(cudaGraphDestroy(graph_)); + } + } else { + TORCH_WARN("CUDA Graphs debug not enabled, set with torch._C._cuda_enable_graphs_debug_mode"); + } +#else + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif } @@ -255,7 +295,7 @@ void CUDAGraph::reset() { C10_CUDA_CHECK_WARN(cudaGraphExecDestroy(graph_exec_)); } #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif } @@ -265,7 +305,7 @@ MempoolId_t CUDAGraph::pool() { TORCH_CHECK(has_graph_exec_, "Called CUDAGraph::pool() without a preceding successful capture."); #else - TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and not yet supported on ROCM"); + TORCH_CHECK(false, "CUDA graphs may only be used in Pytorch built with CUDA >= 11.0 and is not yet supported on ROCM"); #endif return mempool_id_; } diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index bacad79102a3e..fa5a73b65e05e 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -24,6 +24,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph { void replay(); void reset(); MempoolId_t pool(); + void enable_debug_mode(); + void debug_dump(const std::string& debug_path); protected: #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index b5e685dac65f1..25e4c2b44fa99 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -82,7 +82,7 @@ void CUDAHooks::initCUDA() const { at::cuda::detail::init_p2p_access_cache(num_devices); #if AT_MAGMA_ENABLED() - TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initilaize magma, init routine not set"); + TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initialize magma, init routine not set"); magma_init_fn(); #endif } diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index f954bbf5623ad..0e739a49bb33c 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -164,7 +164,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo filter_format = CUDNN_TENSOR_NHWC; break; default: - TORCH_INTERNAL_ASSERT(false, "unsurpported memory_format for cuDNN filters"); + TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); } set(getDataType(t), (int) dim, size, filter_format); } diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h index a393548bd4d3f..e111987785cc5 100644 --- a/aten/src/ATen/cudnn/Descriptors.h +++ b/aten/src/ATen/cudnn/Descriptors.h @@ -46,7 +46,8 @@ inline int dataSize(cudnnDataType_t dataType) // that the stride for dim i is the product of the sizes of dims // i+1 to the end. This stride is indeed uniquely determined. This // function modifies 'stride' in place so this invariant holds. -static inline void fixSizeOneDimStride(int dim, const int *size, int *stride, bool nhwc) { +template +static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) { int64_t z = 1; int index = 0; std::vector permutation(dim); @@ -150,7 +151,7 @@ class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc); void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) { - fixSizeOneDimStride(dim, size, stride, nhwc); + fixSizeOneDimStride(dim, size, stride, nhwc); AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride)); } }; diff --git a/aten/src/ATen/detail/FunctionTraits.h b/aten/src/ATen/detail/FunctionTraits.h index aab7300b585fe..f49a55e1326d5 100644 --- a/aten/src/ATen/detail/FunctionTraits.h +++ b/aten/src/ATen/detail/FunctionTraits.h @@ -76,3 +76,27 @@ struct binary_function_traits { using arg1_t = typename traits::template arg<0>::type; using arg2_t = typename traits::template arg<1>::type; }; + + +// Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType +template +struct invoke_traits : public function_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : + public function_traits { +}; + +template +struct invoke_traits : + public function_traits { +}; diff --git a/aten/src/ATen/detail/MPSHooksInterface.cpp b/aten/src/ATen/detail/MPSHooksInterface.cpp new file mode 100644 index 0000000000000..a73e456caff58 --- /dev/null +++ b/aten/src/ATen/detail/MPSHooksInterface.cpp @@ -0,0 +1,31 @@ +// Copyright © 2022 Apple Inc. + +#include +#include +#include + +namespace at { +namespace detail { + +const MPSHooksInterface& getMPSHooks() { + static std::unique_ptr mps_hooks; +#if !defined C10_MOBILE + static c10::once_flag once; + c10::call_once(once, [] { + mps_hooks = MPSHooksRegistry()->Create("MPSHooks", MPSHooksArgs{}); + if (!mps_hooks) { + mps_hooks = std::make_unique(); + } + }); +#else + if (mps_hooks == nullptr) { + mps_hooks = std::make_unique(); + } +#endif + return *mps_hooks; +} +} // namespace detail + +C10_DEFINE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs) + +} // namespace at diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h new file mode 100644 index 0000000000000..fd1f2f5a75c67 --- /dev/null +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -0,0 +1,54 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace at { +class Context; +} + +namespace at { + +struct TORCH_API MPSHooksInterface { + virtual ~MPSHooksInterface() = default; + + // Initialize the MPS library state + virtual void initMPS() const { + AT_ERROR("Cannot initialize MPS without MPS backend."); + } + + virtual bool hasMPS() const { + return false; + } + + virtual const Generator& getDefaultMPSGenerator() const { + AT_ERROR("Cannot get default MPS generator without MPS backend."); + } + + virtual Allocator* getMPSDeviceAllocator() const { + AT_ERROR("MPSDeviceAllocator requires MPS."); + } + + virtual void deviceSynchronize() const { + TORCH_CHECK(false, "Cannot synchronize MPS device without MPS backend. "); + } +}; + +struct TORCH_API MPSHooksArgs {}; + +C10_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); +#define REGISTER_MPS_HOOKS(clsname) \ + C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API const MPSHooksInterface& getMPSHooks(); + +} // namespace detail +} // namespace at diff --git a/aten/src/ATen/functorch/ADInterpreters.cpp b/aten/src/ATen/functorch/ADInterpreters.cpp index 46c134f59d61b..fb97114bec504 100644 --- a/aten/src/ATen/functorch/ADInterpreters.cpp +++ b/aten/src/ATen/functorch/ADInterpreters.cpp @@ -28,7 +28,7 @@ static void checkForInvalidMutationOnCaptures( "as inputs."); } -static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) { +Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) { if (!tensor.defined()) { return tensor; } @@ -44,6 +44,19 @@ static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_leve return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true); } +static Tensor base_lift(const Tensor& tensor, int64_t level) { + auto tensor_ = unwrapIfDead(tensor); + return materializeGradWrappers(tensor_, level); +} + +Tensor GradInterpreterPtr::lift(const Tensor& tensor) const { + return base_lift(tensor, level()); +} + +Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const { + return base_lift(tensor, level()); +} + static void autogradBasedTransformProcess( const c10::OperatorHandle& op, torch::jit::Stack* stack, diff --git a/aten/src/ATen/functorch/ADInterpreters.h b/aten/src/ATen/functorch/ADInterpreters.h index b8ad638c5aee4..6ec1cca065d61 100644 --- a/aten/src/ATen/functorch/ADInterpreters.h +++ b/aten/src/ATen/functorch/ADInterpreters.h @@ -7,7 +7,7 @@ namespace at { namespace functorch { // (grad, vjp and jvp). // See NOTE: [functorch interpreter stack] for more details. -struct GradInterpreterPtr { +struct TORCH_API GradInterpreterPtr { explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } TransformType key() const { return base_->key(); } int64_t level() const { return base_->level(); } @@ -16,11 +16,12 @@ struct GradInterpreterPtr { bool prevGradMode() const { return c10::get(base_->meta()).prevGradMode_; } + Tensor lift(const Tensor& tensor) const; private: const Interpreter* base_; }; -struct JvpInterpreterPtr { +struct TORCH_API JvpInterpreterPtr { explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } TransformType key() const { return base_->key(); } int64_t level() const { return base_->level(); } @@ -29,6 +30,7 @@ struct JvpInterpreterPtr { bool prevFwdGradMode() const { return c10::get(base_->meta()).prevFwdGradMode_; } + Tensor lift(const Tensor& tensor) const; private: const Interpreter* base_; }; diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index 4e228afdfc614..db601d3b0b8f1 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -385,7 +385,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { BINARY_SCALAR_2(div, Tensor_mode, Scalar_mode); BINARY_POINTWISE(floor_divide); - UNARY_POINTWISE2(floor_divide, Scalar); BINARY_POINTWISE(fmax); BINARY_POINTWISE(fmin); diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp index 0640af3a1b533..90cd68b2e0da1 100644 --- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp +++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp @@ -17,7 +17,7 @@ namespace at { namespace functorch { // we do not support batch_group_count (which is needed for convolution backwards). // Instead, there's a convolution_backward op that needs a batching rule. std::tuple> -convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tensor& rhs, optional rhs_bdim, const optional& bias, optional bias_bdim, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) { +convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tensor& rhs, optional rhs_bdim, const optional& bias, optional bias_bdim, IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, int64_t groups) { DimVector lhs_spec(stride.size() + 2); std::iota(lhs_spec.begin(), lhs_spec.end(), 0); DimVector rhs_spec = lhs_spec; @@ -42,13 +42,13 @@ convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tens std::tuple> result; if (lhs_bdim && !rhs_bdim) { auto new_x = reshape_dim_into(*lhs_bdim, lhs_spec[0], lhs); - auto out = at::convolution(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + auto out = at::convolution_symint(new_x, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); out = reshape_dim_outof(out_spec[0], lhs.sizes()[*lhs_bdim], out); result = std::make_tuple(out, out_spec[0]); } else if (!lhs_bdim && rhs_bdim) { if (groups == 1) { auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[0], rhs); - auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); out = reshape_dim_outof(out_spec[1], rhs.size(*rhs_bdim), out); result = std::make_tuple(out, out_spec[1]); } else { @@ -62,7 +62,7 @@ convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tens // BIOHW -> I(BO)HW auto new_w = reshape_dim_into(*rhs_bdim, 1, rhs); // NIHW, I(BO)HW -> N(GBO)HW - auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); // N(GBO)HW -> NG(BO)HW out = reshape_dim_outof(1, groups, out); // NG(BO)HW -> NGBOHW @@ -84,7 +84,7 @@ convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tens // G(BO)IHW -> (GBO)IHW new_w = reshape_dim_into(0, 0, new_w); // N(GI)HW, (GBO)IHW -> N(GBO)HW - auto out = at::convolution(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + auto out = at::convolution_symint(lhs, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); // N(GBO)HW -> NG(BO)HW out = reshape_dim_outof(1, groups, out); // NG(BO)HW -> NGBOHW @@ -99,11 +99,11 @@ convolution_batch_rule(const Tensor& lhs, optional lhs_bdim, const Tens groups *= lhs.sizes()[*lhs_bdim]; auto dim_with_groups = transposed ? 1 : 0; auto new_w = reshape_dim_into(*rhs_bdim, rhs_spec[dim_with_groups], rhs); - auto out = at::convolution(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); + auto out = at::convolution_symint(new_x, new_w, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups); out = reshape_dim_outof(out_spec[1], lhs.sizes()[*lhs_bdim], out); result = std::make_tuple(out, out_spec[1]); } else { - result = std::make_tuple(at::convolution(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt); + result = std::make_tuple(at::convolution_symint(lhs, rhs, unbatched_bias, stride, padding, dilation, transposed, output_padding, groups), nullopt); } if (separate_bias) { auto A = std::get<0>(result); @@ -244,8 +244,8 @@ convolution_backward_input_batch_rule( const Tensor& grad_output, optional grad_output_bdim, const Tensor& input, optional input_bdim, const Tensor& weight, optional weight_bdim, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, int64_t groups) { + IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, + c10::SymIntArrayRef output_padding, int64_t groups) { const std::array mask = {true, false, false}; if (grad_output_bdim && weight_bdim) { // regular: BNO, BOI -> N(BO), (BO)I -> N(BI) @@ -254,7 +254,7 @@ convolution_backward_input_batch_rule( const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); const auto weight_ = reshape_dim_into(*weight_bdim, 0, weight); auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, dummy_input, weight_, nullopt, stride, padding, dilation, transposed, output_padding, groups * batch_size, mask); const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result)); @@ -265,7 +265,7 @@ convolution_backward_input_batch_rule( const auto batch_size = grad_output.size(*grad_output_bdim); const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); auto dummy_input = make_dummy(input, input_bdim, 0, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, dummy_input, weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); const auto grad_input = reshape_dim_outof(0, batch_size, std::get<0>(result)); @@ -278,7 +278,7 @@ convolution_backward_input_batch_rule( const auto in_ch_dim = transposed ? 0 : 1; const auto weight_ = reshape_dim_into(*weight_bdim, in_ch_dim, weight); auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, dummy_input, weight_, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result)); @@ -289,7 +289,7 @@ convolution_backward_input_batch_rule( // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI) const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight); auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, dummy_input, weight_, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); grad_input = std::get<0>(result); // N(GBI) @@ -300,7 +300,7 @@ convolution_backward_input_batch_rule( weight_ = weight_.transpose(0, 1); // GBIO weight_ = weight_.flatten(0, 2); // (GBI)O const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, dummy_input, weight_, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); grad_input = std::get<0>(result); // N(GBI) @@ -314,7 +314,7 @@ convolution_backward_input_batch_rule( } else { TORCH_INTERNAL_ASSERT(input_bdim); const auto dummy_input = make_dummy(input, input_bdim, 0, 1); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, dummy_input, weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); return std::make_tuple(std::get<0>(result), nullopt); @@ -325,8 +325,8 @@ convolution_backward_weight_batch_rule( const Tensor& grad_output, optional grad_output_bdim, const Tensor& input, optional input_bdim, const Tensor& weight, optional weight_bdim, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, int64_t groups) { + IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, + c10::SymIntArrayRef output_padding, int64_t groups) { const std::array mask = {false, true, false}; if (grad_output_bdim && input_bdim) { // BNO, BNI -> N(BO), N(BI) -> (BO)I (regular) (BI)O (transposed) @@ -334,7 +334,7 @@ convolution_backward_weight_batch_rule( const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); const auto input_ = reshape_dim_into(*input_bdim, 1, input); const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, input_, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups * batch_size, mask); auto grad_weight = std::get<1>(result); @@ -348,7 +348,7 @@ convolution_backward_weight_batch_rule( const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output); const auto out_ch_dim = transposed ? 1 : 0; const auto dummy_weight = make_dummy(weight, weight_bdim, out_ch_dim, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, input, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -362,7 +362,7 @@ convolution_backward_weight_batch_rule( if (!transposed) { // BN(GO), N(GI) -> N(GBO), N(GI) -> (GBO)I const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, input, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -373,7 +373,7 @@ convolution_backward_weight_batch_rule( } else { // BN(GO), N(GI) -> N(GBO), N(GI) -> (GI)(BO) const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output_, input, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -389,7 +389,7 @@ convolution_backward_weight_batch_rule( const auto input_ = reshape_dim_into(*input_bdim, 1, input); const auto in_ch_dim = transposed ? 0 : 1; const auto dummy_weight = make_dummy(weight, weight_bdim, in_ch_dim, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, input_, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -403,7 +403,7 @@ convolution_backward_weight_batch_rule( if (!transposed) { // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI) const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, input_, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -412,7 +412,7 @@ convolution_backward_weight_batch_rule( } else { // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, input_, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); auto grad_weight = std::get<1>(result); @@ -425,7 +425,7 @@ convolution_backward_weight_batch_rule( } else { TORCH_INTERNAL_ASSERT(weight_bdim); const auto dummy_weight = make_dummy(weight, weight_bdim, 0, 1); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, input, dummy_weight, nullopt, stride, padding, dilation, transposed, output_padding, groups, mask); return std::make_tuple(std::get<1>(result), nullopt); @@ -436,10 +436,10 @@ convolution_backward_weight_batch_rule( std::tuple convolution_backward_plumbing( const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_, const c10::OptionalArrayRef bias_sizes_opt, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, int64_t groups, std::array output_mask) { + IntArrayRef stride, c10::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, + c10::SymIntArrayRef output_padding, int64_t groups, std::array output_mask) { const auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "convolution_backward_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({grad_output_, input_, weight_}, cur_level)){ @@ -487,7 +487,7 @@ std::tuple convolution_backward_plumbing( const auto batch_size = weight.size(*weight_bdim); input = reshape_dim_into(*input_bdim, 1, input); weight = reshape_dim_into(*weight_bdim, 0, weight); - const auto result = at::convolution_backward( + const auto result = at::convolution_backward_symint( grad_output, input, weight, nullopt, stride, padding, dilation, transposed, output_padding, batch_size * groups, output_mask); // N(BI), (BO)I -> NBI, BOI diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 157fbf23bf6fd..eebb0ab6349dd 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -25,7 +25,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { OP_DECOMPOSE(feature_dropout_); } -TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { +TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(__and__, Scalar); OP_DECOMPOSE2(__and__, Tensor); OP_DECOMPOSE2(__iand__, Tensor); @@ -41,11 +41,12 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(_batch_norm_impl_index); OP_DECOMPOSE(absolute); OP_DECOMPOSE(arctan2); + OP_DECOMPOSE(argsort); OP_DECOMPOSE(avg_pool1d); OP_DECOMPOSE(adaptive_max_pool1d); OP_DECOMPOSE(adaptive_avg_pool1d); m.impl("adaptive_avg_pool2d", native::adaptive_avg_pool2d_symint); - OP_DECOMPOSE(adaptive_avg_pool3d); + m.impl("adaptive_avg_pool3d", native::adaptive_avg_pool3d_symint); OP_DECOMPOSE(adjoint); OP_DECOMPOSE(arccos); OP_DECOMPOSE(arccosh); @@ -63,26 +64,29 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(bitwise_or, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); OP_DECOMPOSE(broadcast_tensors); - OP_DECOMPOSE(broadcast_to); + m.impl("broadcast_to", native::broadcast_to_symint); OP_DECOMPOSE(cartesian_prod); OP_DECOMPOSE(cdist); + OP_DECOMPOSE(chunk); OP_DECOMPOSE(clip); OP_DECOMPOSE2(clip, Tensor ); OP_DECOMPOSE(concat); OP_DECOMPOSE(conj_physical); + OP_DECOMPOSE(contiguous); OP_DECOMPOSE(combinations); OP_DECOMPOSE(corrcoef); OP_DECOMPOSE(cosine_embedding_loss); OP_DECOMPOSE(cosine_similarity); OP_DECOMPOSE(cov); + OP_DECOMPOSE(cross); m.impl("cross_entropy_loss", native::cross_entropy_loss_symint); OP_DECOMPOSE2(cumulative_trapezoid, x); OP_DECOMPOSE2(cumulative_trapezoid, dx); OP_DECOMPOSE2(dsplit, int); OP_DECOMPOSE2(dsplit, array); OP_DECOMPOSE(det); - m.impl("diag_backward", native::diag_backward_symint); OP_DECOMPOSE(diff); + OP_DECOMPOSE(diag); OP_DECOMPOSE(dstack); OP_DECOMPOSE(einsum); m.impl("embedding_backward", native::embedding_backward_symint); @@ -110,6 +114,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(flipud); OP_DECOMPOSE2(float_power, Tensor_Tensor); OP_DECOMPOSE2(float_power, Tensor_Scalar); + OP_DECOMPOSE2(floor_divide, Scalar); OP_DECOMPOSE(ger); OP_DECOMPOSE2(gradient, scalarint); OP_DECOMPOSE2(gradient, scalararray); @@ -133,7 +138,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(instance_norm); OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); - OP_DECOMPOSE(layer_norm); + m.impl("layer_norm", native::layer_norm_symint); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); @@ -185,7 +190,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(positive); OP_DECOMPOSE(qr); OP_DECOMPOSE(ravel); - OP_DECOMPOSE2(repeat_interleave, self_int); + m.impl("repeat_interleave.self_int", native::repeat_interleave_symint); OP_DECOMPOSE2(repeat_interleave, self_Tensor); m.impl("reshape", native::reshape_symint); OP_DECOMPOSE(resolve_conj); @@ -201,6 +206,22 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(special_multigammaln); OP_DECOMPOSE(special_polygamma); OP_DECOMPOSE(special_softmax); + OP_DECOMPOSE(special_digamma); + OP_DECOMPOSE(special_erf); + OP_DECOMPOSE(special_erfc); + OP_DECOMPOSE(special_erfinv); + OP_DECOMPOSE(special_exp2); + OP_DECOMPOSE(special_expm1); + OP_DECOMPOSE(special_expit); + OP_DECOMPOSE(special_gammaln); + OP_DECOMPOSE(special_i0); + OP_DECOMPOSE(special_log1p); + OP_DECOMPOSE(special_ndtr); + OP_DECOMPOSE(special_psi); + OP_DECOMPOSE(special_round); + OP_DECOMPOSE(special_sinc); + + m.impl("split.sizes", native::split_symint); OP_DECOMPOSE(square); OP_DECOMPOSE(numpy_T); @@ -243,7 +264,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(where, ScalarSelf); OP_DECOMPOSE(orgqr); OP_DECOMPOSE2(unflatten, int); - OP_DECOMPOSE(_convolution_double_backward); + m.impl("_convolution_double_backward", native::_convolution_double_backward); OP_DECOMPOSE(conv_transpose1d); OP_DECOMPOSE2(conv_transpose2d, input); OP_DECOMPOSE2(conv_transpose3d, input); @@ -254,9 +275,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(conv2d, padding); OP_DECOMPOSE2(conv3d, padding); OP_DECOMPOSE(_convolution_mode); - OP_DECOMPOSE(frobenius_norm); OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); + OP_DECOMPOSE(diagonal_copy); m.impl("pad", native::pad_symint); m.impl("_pad_circular", native::_pad_circular_symint); OP_DECOMPOSE(t_); diff --git a/aten/src/ATen/functorch/BatchRulesHelper.h b/aten/src/ATen/functorch/BatchRulesHelper.h index 219c01c89c56e..8e78ba71029b1 100644 --- a/aten/src/ATen/functorch/BatchRulesHelper.h +++ b/aten/src/ATen/functorch/BatchRulesHelper.h @@ -3,6 +3,9 @@ // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. +#pragma once + +#include #include #include @@ -65,7 +68,7 @@ template struct BasicUnaryBatchRuleHelper; template -struct BasicUnaryBatchRuleHelper> { +struct BasicUnaryBatchRuleHelper> { static std::tuple> apply( const Tensor& tensor, optional batch_dim, @@ -90,7 +93,7 @@ template struct VariadicBdimsBatchRuleHelper; template -struct VariadicBdimsBatchRuleHelper> { +struct VariadicBdimsBatchRuleHelper> { static std::tuple> apply( const Tensor& tensor, optional batch_dim, @@ -123,7 +126,8 @@ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::S c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule"); + int64_t cur_level = maybe_layer->layerId(); auto orig_arguments = torch::jit::last(*stack, num_arguments); @@ -241,7 +245,7 @@ inline void boxed_existing_bdim_all_batch_rule( c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule"); int64_t cur_level = maybe_layer->layerId(); const auto arguments = torch::jit::last(stack, num_arguments); @@ -297,7 +301,7 @@ inline void boxed_all_tensors_have_optional_bdim( c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim"); int64_t cur_level = maybe_layer->layerId(); const auto arguments = torch::jit::last(stack, num_arguments); @@ -379,7 +383,7 @@ template struct ExistingBdimBatchRuleHelper; template -struct ExistingBdimBatchRuleHelper> { +struct ExistingBdimBatchRuleHelper> { static std::tuple> apply( const Tensor& self, optional self_bdim, diff --git a/aten/src/ATen/functorch/BatchRulesLoss.cpp b/aten/src/ATen/functorch/BatchRulesLoss.cpp index 66c2b7fb3194d..6429856572878 100644 --- a/aten/src/ATen/functorch/BatchRulesLoss.cpp +++ b/aten/src/ATen/functorch/BatchRulesLoss.cpp @@ -59,7 +59,7 @@ Tensor binary_cross_entropy_plumbing( const Tensor& self, const Tensor& target, const optional& weight, int64_t reduction) { auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "binary_cross_entropy_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) @@ -99,7 +99,7 @@ Tensor binary_cross_entropy_backward_plumbing( const Tensor& grad, const Tensor& input, const Tensor& target, const c10::optional& weight_opt, int64_t reduction) { auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "binary_cross_entropy_backward_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) { diff --git a/aten/src/ATen/functorch/BatchRulesModules.cpp b/aten/src/ATen/functorch/BatchRulesModules.cpp index 3968e2400397d..506ed3ae44052 100644 --- a/aten/src/ATen/functorch/BatchRulesModules.cpp +++ b/aten/src/ATen/functorch/BatchRulesModules.cpp @@ -21,16 +21,16 @@ static Tensor getStepTensor(const Tensor& indices, c10::SymInt bdim_size, c10::S std::tuple> embedding_batch_rule( const Tensor& weight, optional weight_bdim, const Tensor& indices, optional indices_bdim, - int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { if (!weight_bdim && indices_bdim) { // B*, ED -> B*D - const auto result = at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); + const auto result = at::embedding_symint(weight, indices, padding_idx, scale_grad_by_freq, sparse); return std::make_tuple(result, indices_bdim); } else if (weight_bdim && !indices_bdim) { // *, BED -> *, E(BD) -> *(BD) -> *BD const auto batch_size = weight.size(*weight_bdim); const auto weight_ = reshape_dim_into(*weight_bdim, /*embedding_dim*/1, weight); - auto result = at::embedding(weight_, indices, padding_idx, scale_grad_by_freq, sparse); + auto result = at::embedding_symint(weight_, indices, padding_idx, scale_grad_by_freq, sparse); result = reshape_dim_outof(-1, batch_size, result); return std::make_tuple(result, result.dim() - 2); } @@ -44,7 +44,7 @@ std::tuple> embedding_batch_rule( const auto range = getStepTensor(indices, batch_size, num_embeddings); indices_ = indices_ + range; - const auto result = at::embedding(weight_, indices_, padding_idx, scale_grad_by_freq, sparse); + const auto result = at::embedding_symint(weight_, indices_, padding_idx, scale_grad_by_freq, sparse); return std::make_tuple(result, 0); } @@ -52,7 +52,7 @@ std::tuple> embedding_dense_backward_batch_rule( const Tensor& grad_, optional grad_bdim, const Tensor& indices_, optional indices_bdim, - c10::SymInt num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { Tensor grad = grad_; Tensor indices = indices_; if (!indices_bdim && grad_bdim) { @@ -74,7 +74,7 @@ embedding_dense_backward_batch_rule( // Fill in the padding. We can't do it in the embedding_dense_backward call // because we need to fill in multiple rows! if (padding_idx >= 0) { - result.select(1, padding_idx).fill_(0); + result.select_symint(1, padding_idx).fill_(0); } return std::make_tuple(result, 0); } @@ -295,7 +295,7 @@ template struct UpsampleBackwardBatchRuleHelper> { static std::tuple> apply( const Tensor& grad_output, optional grad_output_bdim, - OptionalSymIntArrayRef output_size, c10::SymIntArrayRef input_size, + c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, T... extra_args) { auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); TORCH_INTERNAL_ASSERT(input_size.size() > 0); @@ -375,11 +375,11 @@ struct CudnnGridSampleBackwardBatchRuleHelper { #define CUDNN_GRID_SAMPLE_BW_BATCH_RULE(fn)\ CudnnGridSampleBackwardBatchRuleHelper::apply -#define UPSAMPLE_BACKWARD(op, overload) VMAP_SUPPORT2(op, overload, SINGLE_ARG(\ +#define UPSAMPLE_BACKWARD(op) VMAP_SUPPORT(op, SINGLE_ARG(\ UpsampleBackwardBatchRuleHelper<\ - decltype(&ATEN_FN2(op, overload)),\ - &ATEN_FN2(op, overload),\ - c10::guts::function_traits::parameter_types>::apply)) + decltype(&ATEN_FN(op)),\ + &ATEN_FN(op),\ + c10::guts::function_traits::parameter_types>::apply)) #define UPSAMPLE_BATCH(op) \ EXISTING_BDIM2(op, vec); \ @@ -401,7 +401,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(cudnn_grid_sampler_backward, CUDNN_GRID_SAMPLE_BW_BATCH_RULE(cudnn_grid_sampler_backward)); VMAP_SUPPORT(cudnn_grid_sampler, GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); - VMAP_SUPPORT(cross, cross_batch_rule); EXISTING_BDIM(pixel_shuffle); EXISTING_BDIM(pixel_unshuffle); @@ -430,13 +429,13 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { UPSAMPLE_BATCH(upsample_nearest3d); UPSAMPLE_BATCH(upsample_trilinear3d); - UPSAMPLE_BACKWARD(upsample_bicubic2d_backward, vec); - UPSAMPLE_BACKWARD(upsample_bilinear2d_backward, vec); - UPSAMPLE_BACKWARD(upsample_linear1d_backward, vec); - UPSAMPLE_BACKWARD(upsample_nearest1d_backward, vec); - UPSAMPLE_BACKWARD(upsample_nearest2d_backward, vec); - UPSAMPLE_BACKWARD(upsample_nearest3d_backward, vec); - UPSAMPLE_BACKWARD(upsample_trilinear3d_backward, vec); + UPSAMPLE_BACKWARD(upsample_bicubic2d_backward); + UPSAMPLE_BACKWARD(upsample_bilinear2d_backward); + UPSAMPLE_BACKWARD(upsample_linear1d_backward); + UPSAMPLE_BACKWARD(upsample_nearest1d_backward); + UPSAMPLE_BACKWARD(upsample_nearest2d_backward); + UPSAMPLE_BACKWARD(upsample_nearest3d_backward); + UPSAMPLE_BACKWARD(upsample_trilinear3d_backward); m.impl("one_hot", one_hot_decomposition_hack); } }} diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp index 5e6f85510163d..bdd80540e649c 100644 --- a/aten/src/ATen/functorch/BatchRulesNorm.cpp +++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp @@ -222,7 +222,7 @@ std::tuple batch_norm_backward_plumbing( // plumbing auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "batch_norm_backward_plumbing"); int64_t cur_level = maybe_layer->layerId(); Tensor grad_out_value; @@ -304,7 +304,7 @@ std::tuple native_group_norm_plumbing( const Tensor& bias = *bias_maybe_owned; auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "native_group_norm_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({input, weight_opt, bias_opt}, cur_level)) { @@ -393,7 +393,7 @@ std::tuple native_group_norm_backward_plumbing( // plumbing auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "native_group_norm_backward_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt}, cur_level)) { @@ -604,7 +604,7 @@ std::tuple native_layer_norm_backward_plumbing // plumbing auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "native_layer_norm_backward_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({grad_out, input, mean, rstd, weight_opt, bias_opt}, cur_level)) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); @@ -756,7 +756,7 @@ struct NativeBatchNormBackwardBatchRuleHelper { std::array output_mask) { auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "NativeBatchNormBackwardBatchRuleHelper.apply"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({grad_out, input, weight_opt, running_mean_opt, @@ -786,7 +786,7 @@ struct CudnnBatchNormBackwardBatchRuleHelper { const at::Tensor & reserve) { auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "CudnnBatchNormBackwardBatchRuleHelper.apply"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt, @@ -814,7 +814,7 @@ struct MiopenBatchNormBackwardBatchRuleHelper { double eps) { auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "MiopenBatchNormBackwardBatchRuleHelper.apply"); int64_t cur_level = maybe_layer->layerId(); if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt, @@ -875,10 +875,28 @@ std::tuple cudnn_batch_norm_backward_wrapper( return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); } +// NB: This is NOT good. In the ideal world, we do NOT want to convert the new legit op back into native_batch_norm +// as native_batch_norm has a problematic schema--it promises it is functional when it is not. However, vmap doesn't +// work with dynamo anyway so we gain some buffer room to do wrong things here. The (reasonable) hope is that we will +// make native_batch_norm composite implicit within a few weeks and we can fix this before vmap works with dynamo. +std::tuple _native_batch_norm_legit_batch( + const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, + Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) { + return at::native_batch_norm(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps); +} + +std::tuple _native_batch_norm_legit_no_stats_batch( + const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, + bool train, double momentum, double eps) { + return at::native_batch_norm(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps); +} + TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm)); VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm)); VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm)); + m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch); + m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch); m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward)); m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper)); m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper)); diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index 8654b78455014..ec849c9794b4d 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -72,7 +72,7 @@ void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "boxed_reduction_batch_rule"); int64_t cur_level = maybe_layer->layerId(); auto orig_arguments = torch::jit::last(*stack, num_arguments); @@ -168,7 +168,7 @@ void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack #define REDUCTION_BOXED_ARGS(op, dim_pos) \ m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); -// Skipping frobenius/nuclear/all/any since they don't have opinfo tests right now :P +// Skipping all/any since they don't have opinfo tests right now :P Tensor dist_decomp(const Tensor& self, const Tensor& other, const Scalar& p) { return at::norm((self - other), p); @@ -412,7 +412,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { REDUCTION_BOXED(_softmax); REDUCTION_BOXED(sort); REDUCTION_BOXED_ARGS(sort.stable, 2); - REDUCTION_BOXED(argsort); REDUCTION_BOXED(std_mean.correction); m.impl("sum", sum_decomp); REDUCTION_BOXED(sum.dim_IntList); diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 5eecbedd93e7b..c1d66369fb1fe 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -319,7 +319,7 @@ Tensor index_plumbing(const Tensor & self, const List> & indice ) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "index_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { return at::index(self, indices); @@ -506,7 +506,7 @@ Tensor& index_put__plumbing(Tensor & self, const List> & indice , const Tensor & values, bool accumulate) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "index_put__plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { return self.index_put_(indices, values, accumulate); @@ -545,7 +545,7 @@ Tensor &_index_put_impl__plumbing(Tensor &self, const List> &in const Tensor &values, bool accumulate, bool unsafe) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "_index_put_impl__plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { return at::_index_put_impl_(self, indices, values, accumulate, unsafe); @@ -666,7 +666,7 @@ Tensor index_put_plumbing(const Tensor & self, const List> & in const Tensor & values, bool accumulate) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "index_put_plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { return self.index_put(indices, values, accumulate); @@ -928,6 +928,11 @@ Tensor index_copy_decomp( return at::scatter(self, dim, index_, source); ; } +// Note [Fix vmap slice_scatter] +// registers a decomposition for `slice_scatter` that calls into `slice.src` +// *_scatter operators have some special semantics though, that we can't easily +// through a decomposition: slice_scatter's output needs to have the same +// size, size, strides and storage_offset as the input. Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, int64_t dim, c10::optional start, c10::optional end, int64_t step) diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index ee6391c6e2844..8cd4385fea863 100644 --- a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -93,7 +93,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { UNARY_POINTWISE(view_as_real); VMAP_SUPPORT(view_as_complex, view_as_complex_batch_rule); VMAP_SUPPORT(clone, clone_batch_rule); - VMAP_SUPPORT(contiguous, contiguous_batch_rule); VMAP_SUPPORT2(to, device, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, device))); VMAP_SUPPORT2(to, dtype, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype))); VMAP_SUPPORT2(to, dtype_layout, BASIC_UNARY_BATCH_RULE(ATEN_FN2(to, dtype_layout))); @@ -163,25 +162,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { // torch.special.* functions UNARY_POINTWISE(special_entr); - UNARY_POINTWISE(special_erf); - UNARY_POINTWISE(special_erfc); UNARY_POINTWISE(special_erfcx); - UNARY_POINTWISE(special_erfinv); - UNARY_POINTWISE(special_expit); - UNARY_POINTWISE(special_expm1); - UNARY_POINTWISE(special_digamma); - UNARY_POINTWISE(special_psi); - UNARY_POINTWISE(special_exp2); - UNARY_POINTWISE(special_gammaln); - UNARY_POINTWISE(special_i0); UNARY_POINTWISE(special_i0e); UNARY_POINTWISE(special_i1); UNARY_POINTWISE(special_i1e); - UNARY_POINTWISE(special_log1p); - UNARY_POINTWISE(special_ndtr); UNARY_POINTWISE(special_ndtri); - UNARY_POINTWISE(special_round); - UNARY_POINTWISE(special_sinc); // Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity) UNARY_POINTWISE_ALL(elu); diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp index 9dd014a4307f9..e083d9d1c4ea5 100644 --- a/aten/src/ATen/functorch/BatchRulesViews.cpp +++ b/aten/src/ATen/functorch/BatchRulesViews.cpp @@ -172,7 +172,7 @@ const Tensor& resize__plumbing( optional_memory_format == c10::MemoryFormat::Contiguous, "resize_: batching rule only supports None or Contiguous MemoryFormat"); auto maybe_layer = maybeCurrentDynamicLayer(); - TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + vmap_check_escaped(maybe_layer, "resize__plumbing"); int64_t cur_level = maybe_layer->layerId(); if (!isBatchedAtLevel(self, cur_level)) { c10::impl::ExcludeDispatchKeyGuard guard2(DispatchKey::FuncTorchBatched); @@ -275,14 +275,14 @@ std::tuple, optional> chunk_batching_rule(const Ten return std::make_tuple(at::chunk(self_, chunks, new_dim), 0); } -std::tuple> select_batching_rule(const Tensor& self, optional bdim, int64_t dim, int64_t index) { +std::tuple> select_batching_rule(const Tensor& self, optional bdim, int64_t dim, c10::SymInt index) { if (!bdim) { - return std::make_tuple(self.select(dim, index), nullopt); + return std::make_tuple(self.select_symint(dim, index), nullopt); } auto _self = moveBatchDimToFront(self, bdim); auto dim_physical = getPhysicalDim(_self, true, dim); - auto result = _self.select(dim_physical, index); + auto result = _self.select_symint(dim_physical, index); return std::make_tuple(result, 0); } @@ -402,7 +402,7 @@ std::tuple> permute_batching_rule( std::tuple> select_backward_batch_rule( const Tensor& grad_input, optional grad_input_bdim, - SymIntArrayRef input_sizes, int64_t dim, int64_t index) { + c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim); auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim); dim = maybe_wrap_dim(dim, logical_rank + 1) + 1; @@ -438,6 +438,19 @@ std::tuple> view_batching_rule( return std::make_tuple(self_.view_symint(size_), 0); } +std::tuple> view_copy_batch_rule( + const Tensor& self, + optional self_bdim, + c10::SymIntArrayRef size) { + auto self_ = moveBatchDimToFront(self, self_bdim); + SymDimVector view_size(size.size() + 1); + view_size[0] = self_.size(0); + std::copy(size.cbegin(), size.cend(), view_size.begin() + 1); + + return std::make_tuple(at::view_copy_symint(self_, view_size), 0); +} + + template std::tuple> expand_batch_rule( const Tensor &self, optional self_bdim, SymIntArrayRef size, bool implicit) @@ -490,6 +503,18 @@ std::tuple> unfold_batch_rule( return std::make_tuple(result, 0); } +std::tuple> narrow_copy_batch_rule( + const Tensor &self, optional self_bdim, int64_t dim, c10::SymInt start, c10::SymInt length) +{ + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, logical_rank) + 1; + auto result = self_.narrow_copy_symint(dim, start, length); + + return std::make_tuple(result, 0); +} + std::tuple> movedim_batch_rule(const Tensor& self, optional self_bdim, IntArrayRef source, IntArrayRef destination) { auto self_ = moveBatchDimToFront(self, self_bdim); auto source_ = getPhysicalDims(self_, self_bdim.has_value(), source); @@ -506,12 +531,10 @@ std::tuple> diag_embed_batch_rule(const Tensor& self, } Tensor trace_decomp(const Tensor& tensor) { - return tensor.diag().sum(); + return tensor.diagonal().sum(); } TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { - VMAP_SUPPORT(diag, diag_batch_rule); - VMAP_SUPPORT(chunk, chunk_batching_rule); m.impl("flatten.using_ints", static_cast(native::flatten)); VMAP_SUPPORT(flip, flip_batch_rule); m.impl("trace", trace_decomp); @@ -532,6 +555,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(select_backward, select_backward_batch_rule); VMAP_SUPPORT(slice_backward, slice_backward_batch_rule); VMAP_SUPPORT(view, view_batching_rule); + VMAP_SUPPORT(view_copy, view_copy_batch_rule); VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule)); VMAP_SUPPORT(expand_copy, SINGLE_ARG(expand_batch_rule)); VMAP_SUPPORT(unfold, unfold_batch_rule); @@ -539,6 +563,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT2(slice, Tensor, slice_batch_rule); VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule); VMAP_SUPPORT(diag_embed, diag_embed_batch_rule); + VMAP_SUPPORT(narrow_copy, narrow_copy_batch_rule); } }} diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index 8a2668fe748b1..30fcc9e70bb25 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -101,7 +101,7 @@ class FuncTorchTLS : public FuncTorchTLSBase { } int64_t checkSupportsAutogradFunction() const override { - TORCH_CHECK(dynamicLayerStack.size() == 0, + TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(), "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ", "Please rewrite your function to not use autograd.Function while we work on fixing this"); return 0; @@ -128,6 +128,7 @@ class FuncTorchTLS : public FuncTorchTLSBase { std::vector dynamicLayerStack; bool allow_inplace_requires_grad_ = false; + bool allow_autograd_function_ = false; }; static FuncTorchTLS* getRawFunctorchTLS() { @@ -151,6 +152,16 @@ bool getInplaceRequiresGradAllowed() { return functorch_tls->allow_inplace_requires_grad_; } +void setAutogradFunctionAllowed(bool allowed) { + auto* functorch_tls = getRawFunctorchTLS(); + functorch_tls->allow_autograd_function_ = allowed; +} + +bool getAutogradFunctionAllowed() { + auto* functorch_tls = getRawFunctorchTLS(); + return functorch_tls->allow_autograd_function_; +} + static std::vector& dynamicLayerStackAccessor() { return getRawFunctorchTLS()->dynamicLayerStack; } @@ -203,7 +214,7 @@ bool areTransformsActive() { return !data.empty(); } -static DynamicLayer popDynamicLayer() { +DynamicLayer popDynamicLayer() { auto& dynamicLayerStack = dynamicLayerStackAccessor(); TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); auto result = dynamicLayerStack.back(); @@ -221,7 +232,7 @@ static DynamicLayer popDynamicLayer() { return result; } -static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) { +int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); int64_t layerId = 1 + dynamicLayerStack.size(); TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId()); @@ -280,6 +291,14 @@ DynamicLayer popDynamicLayerAndDeleteMetadata() { return result; } +bool isDeadTensorWrapper(const Tensor& tensor) { + auto* wrapped = maybeGetTensorWrapper(tensor); + if (!wrapped) { + return false; + } + return !wrapped->is_alive(); +} + Tensor unwrapIfDead(const Tensor& tensor) { auto* wrapped = maybeGetTensorWrapper(tensor); if (!wrapped) { diff --git a/aten/src/ATen/functorch/DynamicLayer.h b/aten/src/ATen/functorch/DynamicLayer.h index 576a9621651a4..90e9ae514f5be 100644 --- a/aten/src/ATen/functorch/DynamicLayer.h +++ b/aten/src/ATen/functorch/DynamicLayer.h @@ -108,15 +108,25 @@ TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema); TORCH_API c10::optional findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input); TORCH_API Tensor unwrapIfDead(const Tensor& tensor); +TORCH_API bool isDeadTensorWrapper(const Tensor& tensor); // Pretty printers TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); +// While a functorch transform is active, autograd.Function is disabled +// by default. The following two APIs are APIs for enabling +// autograd.Function. These are not user-facing APIs. +TORCH_API void setAutogradFunctionAllowed(bool allowed); +TORCH_API bool getAutogradFunctionAllowed(); + // While a functorch grad transform is active, Tensor.requires_grad_() gets // disabled. These two functions are the mechanism to controlling that. TORCH_API void setInplaceRequiresGradAllowed(bool allowed); TORCH_API bool getInplaceRequiresGradAllowed(); +TORCH_API DynamicLayer popDynamicLayer(); +TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer); + } } // namespace at diff --git a/aten/src/ATen/functorch/PlumbingHelper.cpp b/aten/src/ATen/functorch/PlumbingHelper.cpp index 5dd01d0abbcbe..5877d2380d247 100644 --- a/aten/src/ATen/functorch/PlumbingHelper.cpp +++ b/aten/src/ATen/functorch/PlumbingHelper.cpp @@ -10,6 +10,17 @@ namespace at { namespace functorch { +void vmap_check_escaped(const optional &layer, const char* what) { + TORCH_CHECK( + layer.has_value(), + "Either your tensor may have escaped from inside a function being vmapped and this is a user error ", + "(see https://pytorch.org/functorch/stable/ux_limitations.html), " + "or there is an internal functorch error in `", + what, + "` Please file an issue if it looks like the latter" + ) +} + Tensor makeBatched(const Tensor& tensor, optional bdim, int64_t level) { if (bdim.has_value()) { TORCH_INTERNAL_ASSERT(*bdim >= 0); diff --git a/aten/src/ATen/functorch/PlumbingHelper.h b/aten/src/ATen/functorch/PlumbingHelper.h index 9eb486a6eefa0..dfb7da5227d5b 100644 --- a/aten/src/ATen/functorch/PlumbingHelper.h +++ b/aten/src/ATen/functorch/PlumbingHelper.h @@ -26,6 +26,8 @@ namespace at { namespace functorch { +void vmap_check_escaped(const optional &layer, const char* what); + // Create a BatchedTensor given a tensor, bdim, and level TORCH_API Tensor makeBatched(const Tensor& tensor, optional bdim, int64_t level); diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index d739e8956d814..beb5723ea1c94 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -1,5 +1,6 @@ // Copyright © 2022 Apple Inc. +#include #include #include #include @@ -9,27 +10,10 @@ // this implementation is based on CUDACachingAllocator. // It utilizes Metal Heaps to improve the performance with buffer allocation. +// Do not include this header. Use MPSAllocatorInterface.h instead. // TODO: Unify the logic with CUDACachingAllocator and remove redundant code. namespace at { namespace mps { - -class IMpsAllocatorCallback { - public: - enum class EventType { - ALLOCATED, // buffer got allocated to be used immediately - RECYCLED, // buffer pulled from free list to be reused - FREED, // buffer put to free list for future recycling - RELEASED, // buffer memory released - }; - virtual ~IMpsAllocatorCallback() = default; - virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; -}; - -// MPS allocator will execute every registered callback when a block of memory is freed. -C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); -#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \ - C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__); - namespace HeapAllocator { #define MB(x) round_page(x * 1048576UL) @@ -263,30 +247,47 @@ class MPSHeapAllocatorImpl // interface exposed to at::Allocator id malloc(size_t size, uint32_t usage); + // frees a buffer and returns it into buffer pool void free(void* ptr); + // releases all the cached buffers and their associated heaps void emptyCache(); - // interface exposed to internal MPS operations + // returns true if buffer was allocated from the shared pool bool isSharedBuffer(void* ptr); - ssize_t getRequestedBufferSize(void* ptr); + // get the requested unaligned size of an MTLBuffer + ssize_t getUnalignedBufferSize(void* ptr); + // set the shape of a base tensor from a view tensor void setBufferShape(void* ptr, const IntArrayRef& shape); + // retrieve the shape of a base tensor from a view tensor IntArrayRef getBufferShape(void* ptr); + // allocate a buffer from a specialized pool to import CPU scalars into GPU id allocScalarBufferWithValue(void* value, size_t size); // this indicates how far (in Megabytes) the current total allocations are from the // low watermark limit which is used to detect if we're under memory pressure // This returns zero if we've reached the low watermark limit ssize_t getLowWatermarkValue(); - - bool getDebugVerbosity() const { return m_debug_verbosity; } - size_t getMaxTotalAllowedSize() const { return m_max_total_allowed_size; } + // (see m_low_watermark_ratio for description) + void setLowWatermarkRatio(double ratio); + // (see m_high_watermark_ratio for description) + void setHighWatermarkRatio(double ratio); + // (see m_low_watermark_limit for description) size_t getLowWatermarkLimit() const { return m_low_watermark_limit; } + // (see m_max_total_allowed_size for description) + size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; } + // (see m_total_allocated_memory for description) + size_t getTotalAllocatedMemory() const {return m_total_allocated_memory; } + // (see enum DebugVerbosity for description) + uint32_t getDebugVerbosity() const { return m_debug_verbosity; } + // returns the device that we allocate from inline id Device() const { return m_device; } private: // (see m_high_watermark_ratio for description) - constexpr static double default_high_watermark_ratio = 0.0; + constexpr static double default_high_watermark_ratio = 1.7; + // we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize. + constexpr static double default_high_watermark_upper_bound = 2.0; // (see m_low_watermark_ratio for description) // on unified memory, we could allocate beyond the recommendedMaxWorkingSetSize - constexpr static double default_low_watermark_ratio_unified = 1.5; + constexpr static double default_low_watermark_ratio_unified = 1.4; constexpr static double default_low_watermark_ratio_discrete = 1.0; const id m_device; @@ -375,17 +376,5 @@ class MPSHeapAllocatorImpl }; } // namespace HeapAllocator - -// interface exposed to internal MPS operations - -// get the requested non-aligned size of an MTL buffer -ssize_t get_requested_buffer_size(void* ptr); -// retrieve the shape of a base tensor from a view tensor -IntArrayRef get_buffer_shape(void* ptr); -// set the shape of a base tensor from a view tensor -void set_buffer_shape(void* ptr, const IntArrayRef& shape); -// allocate a buffer from a specialized pool to import CPU scalars into GPU -DataPtr allocate_scalar_buffer(void* value, size_t size); - } // namespace mps } // namespace at diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index a40ddd7992a29..ba3a63b5595a0 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -22,27 +22,35 @@ static const char *verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR"); m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT; - // on unified memory, we set the allowed upper bound to twice the size of recommendedMaxWorkingSetSize. - const double high_watermark_upper_bound = m_device.hasUnifiedMemory ? 2.0 : 1.0; - static const char *high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO"); - m_high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio; - TORCH_CHECK(m_high_watermark_ratio >= 0.0 && m_high_watermark_ratio <= high_watermark_upper_bound, - "invalid high watermark ratio ", m_high_watermark_ratio); + const double high_watermark_ratio = high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : + default_high_watermark_ratio; + setHighWatermarkRatio(high_watermark_ratio); - m_max_total_allowed_size = (m_high_watermark_ratio == 0.0) ? std::numeric_limits::max() : - static_cast(m_high_watermark_ratio * (double)max_device_size()); - // used for comparison with lower_watermark_ratio - const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? high_watermark_upper_bound : m_high_watermark_ratio; const double default_low_watermark_ratio = m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete; static const char *low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO"); - m_low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio; - TORCH_CHECK(m_low_watermark_ratio >= 0.0 && m_low_watermark_ratio <= high_watermark_limit, - "invalid low watermark ratio ", m_low_watermark_ratio); + const double low_watermark_ratio = low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio; + setLowWatermarkRatio(low_watermark_ratio); +} + +void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) +{ + TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio); + m_max_total_allowed_size = (ratio == 0.0) ? std::numeric_limits::max() : + static_cast(ratio * (double)max_device_size()); + m_high_watermark_ratio = ratio; +} + +void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) +{ + // used for comparison with lower_watermark_ratio + const double high_watermark_limit = m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio; + TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio); // we use this to detect if there's memory pressure - m_low_watermark_limit = (m_low_watermark_ratio == 0.0) ? std::numeric_limits::max() : - static_cast(m_low_watermark_ratio * (double)max_device_size()); + m_low_watermark_limit = (ratio == 0.0) ? std::numeric_limits::max() : + static_cast(ratio * (double)max_device_size()); + m_low_watermark_ratio = ratio; } HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) @@ -386,7 +394,9 @@ void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) { - TORCH_INTERNAL_ASSERT(current_allocated_size() >= m_low_watermark_limit); + // skip garbage collection if memory pressure has already relieved + if (current_allocated_size() < m_low_watermark_limit) + return; // attempt to collect garbage until we reach below low watermark limit const auto target_size = current_allocated_size() - m_low_watermark_limit; const BufferPool& pool = *params.pool; @@ -468,7 +478,7 @@ return buffer_block->buffer; } -ssize_t MPSHeapAllocatorImpl::getRequestedBufferSize(void* ptr) +ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(void* ptr) { std::lock_guard lock(m_mutex); @@ -550,15 +560,15 @@ } // MPS allocator struct to be registered with Pytorch -struct TORCH_API MPSAllocator final : public at::Allocator { +struct TORCH_API MPSAllocator final : public IMPSAllocator { public: explicit MPSAllocator(uint32_t Usage) : m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) { if (_getAllocImpl().getDebugVerbosity()) { if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) { - const size_t max_total_allowed_size = _getAllocImpl().getMaxTotalAllowedSize(); - const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit(); + const size_t high_watermark_limit = _getAllocImpl().getHighWatermarkLimit(); + const size_t low_watermark_limit = _getAllocImpl().getLowWatermarkLimit(); std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private") << " heap allocator on " @@ -566,8 +576,8 @@ explicit MPSAllocator(uint32_t Usage) : << " device memory of size " << _getAllocImpl().Device().recommendedMaxWorkingSetSize / 1048576UL << " MB" << " (max allowed: " - << (max_total_allowed_size == std::numeric_limits::max() ? "unlimited" : - (to_string(max_total_allowed_size / 1048576UL) + " MB")) + << (high_watermark_limit == std::numeric_limits::max() ? "unlimited" : + (to_string(high_watermark_limit / 1048576UL) + " MB")) << ", low watermark: " << (low_watermark_limit == std::numeric_limits::max() ? "unlimited" : (to_string(low_watermark_limit / 1048576UL) + " MB")) << ")\n"; @@ -578,20 +588,28 @@ explicit MPSAllocator(uint32_t Usage) : ~MPSAllocator() override { _getAllocImpl().emptyCache(); } + DeleterFnPtr raw_deleter() const override { return &Delete; } DataPtr allocate(const size_t nbytes) const override { __block id buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr; return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; } - - DataPtr allocate_scalar_buffer(void *value, size_t size) const { + DataPtr allocScalarBufferWithValue(void *value, size_t size) const override { id buf = _getAllocImpl().allocScalarBufferWithValue(value, size); return { buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)}; } - - DeleterFnPtr raw_deleter() const override { return &Delete; } - bool is_shared(void* ptr) const { return _getAllocImpl().isSharedBuffer(ptr); } - bool is_shared_storage_supported() const { return m_has_unified_memory; } + bool isSharedBuffer(void* ptr) const override { return _getAllocImpl().isSharedBuffer(ptr); } + bool isSharedStorageSupported() const override { return m_has_unified_memory; } + void emptyCache() const override { _getAllocImpl().emptyCache(); } + ssize_t getUnalignedBufferSize(void* ptr) const override { return _getAllocImpl().getUnalignedBufferSize(ptr); } + IntArrayRef getBufferShape(void* ptr) const override { return _getAllocImpl().getBufferShape(ptr); } + void setBufferShape(void* ptr, const IntArrayRef& shape) override { _getAllocImpl().setBufferShape(ptr, shape); } + size_t getTotalAllocatedMemory() const override { return _getAllocImpl().getTotalAllocatedMemory(); } + ssize_t getLowWatermarkValue() const override { return _getAllocImpl().getLowWatermarkValue(); } + size_t getLowWatermarkLimit() const override { return _getAllocImpl().getLowWatermarkLimit(); } + size_t getHighWatermarkLimit() const override { return _getAllocImpl().getHighWatermarkLimit(); } + void setLowWatermarkRatio(double ratio) const override { _getAllocImpl().setLowWatermarkRatio(ratio); } + void setHighWatermarkRatio(double ratio) const override { _getAllocImpl().setHighWatermarkRatio(ratio); } private: bool m_has_unified_memory; @@ -616,41 +634,17 @@ static void Delete(void* ptr) { } } // anonymous namespace -at::Allocator* getMPSSharedAllocator() -{ +IMPSAllocator* getIMPSAllocator(bool sharedAllocator) { + if (!sharedAllocator) { + return &_getPrivateAllocator(); + } auto& sa = _getSharedAllocator(); - if (sa.is_shared_storage_supported()) { + if (sa.isSharedStorageSupported()) { return &sa; } - return nullptr; } -at::Allocator* getMPSPrivateAllocator() { - return &_getPrivateAllocator(); -} - -// TODO: create MPSHooks interface and move these there. -ssize_t get_requested_buffer_size(void* ptr) { - return _getAllocImpl().getRequestedBufferSize(ptr); -} - -void set_buffer_shape(void* ptr, const IntArrayRef& shape) { - _getAllocImpl().setBufferShape(ptr, shape); -} - -IntArrayRef get_buffer_shape(void* ptr) { - return _getAllocImpl().getBufferShape(ptr); -} - -DataPtr allocate_scalar_buffer(void *value, size_t size) { - return _getPrivateAllocator().allocate_scalar_buffer(value, size); -} - -uint32_t get_adaptive_commit_threshold() { - return _getAllocImpl().getLowWatermarkValue(); -} - } // namespace mps namespace native { @@ -662,14 +656,14 @@ uint32_t get_adaptive_commit_threshold() { bool is_pinned_mps(const Tensor& self, c10::optional device) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps()); - return at::mps::_getSharedAllocator().is_shared(self.storage().data()); + return at::mps::_getSharedAllocator().isSharedBuffer(self.storage().data()); } // torch.pin_memory() implementation Tensor _pin_memory_mps(const Tensor& self, c10::optional device) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_mps()); - auto* shared_allocator = at::mps::getMPSSharedAllocator(); + auto* shared_allocator = at::mps::getIMPSAllocator(true); TORCH_CHECK(shared_allocator, "unable to pin memory on a non-unified memory device"); const size_t storage_size = detail::computeStorageNbytes(self.sizes(), self.strides(), self.dtype().itemsize()); diff --git a/aten/src/ATen/mps/MPSAllocatorInterface.h b/aten/src/ATen/mps/MPSAllocatorInterface.h new file mode 100644 index 0000000000000..3278c599d34d3 --- /dev/null +++ b/aten/src/ATen/mps/MPSAllocatorInterface.h @@ -0,0 +1,50 @@ +// Copyright © 2023 Apple Inc. + +#include +#include +#include + +namespace at { +namespace mps { + +// this is a public interface to access MPSAllocator. +// Do not declare methods that would depend on MPS or Metal frameworks. +class IMPSAllocator : public c10::Allocator { +public: + // see the comments in MPSAllocator.h for the description of these methods. + virtual void emptyCache() const = 0; + virtual ssize_t getUnalignedBufferSize(void* ptr) const = 0; + virtual IntArrayRef getBufferShape(void* ptr) const = 0; + virtual void setBufferShape(void* ptr, const IntArrayRef& shape) = 0; + virtual bool isSharedBuffer(void* ptr) const = 0; + virtual bool isSharedStorageSupported() const = 0; + virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0; + virtual void setLowWatermarkRatio(double ratio) const = 0; + virtual void setHighWatermarkRatio(double ratio) const = 0; + virtual ssize_t getLowWatermarkValue() const = 0; + virtual size_t getLowWatermarkLimit() const = 0; + virtual size_t getHighWatermarkLimit() const = 0; + virtual size_t getTotalAllocatedMemory() const = 0; +}; + +class IMpsAllocatorCallback { + public: + enum class EventType { + ALLOCATED, // buffer got allocated to be used immediately + RECYCLED, // buffer pulled from free list to be reused + FREED, // buffer put to free list for future recycling + RELEASED, // buffer memory released + }; + virtual ~IMpsAllocatorCallback() = default; + virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; +}; + +// MPS allocator will execute every registered callback when a block of memory is freed. +C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); +#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__); + +IMPSAllocator* getIMPSAllocator(bool sharedAllocator = false); + +} // namespace mps +} // namespace at diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 77e93ea1234a4..7bd1774f482fd 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -53,6 +53,10 @@ class TORCH_API MPSDevice { MTLDevice_t device() { return _mtl_device; } + /** + * Returns whether running on Ventura or newer + */ + bool isMacOS13Plus(int32_t subVersion) const; MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues); @@ -66,7 +70,8 @@ class TORCH_API MPSDevice { }; TORCH_API bool is_available(); - +TORCH_API bool is_macos_13_or_newer(int32_t subVersion = 0); +TORCH_API void device_synchronize(); TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); } // namespace mps diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 6569d57420c87..2a976bf117d3c 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -3,6 +3,8 @@ #include #include +#include +#include #include namespace at { @@ -66,6 +68,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de // Create the MPSGraph and check method introduced in 12.3+ // which is used by MPS backend. id mpsCD = NSClassFromString(@"MPSGraph"); + if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor: recurrentWeight: inputWeight: @@ -76,6 +79,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de name:)] == NO) { return; } + NSArray* devices = [MTLCopyAllDevices() autorelease]; for (unsigned long i = 0 ; i < [devices count] ; i++) { id device = devices[i]; @@ -85,17 +89,39 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de } } TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_mtl_device); + +} + +bool MPSDevice::isMacOS13Plus(int32_t subVersion) const { + id mpsCD = NSClassFromString(@"MPSGraph"); + static bool _macos_13_0_plus = [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES; + static bool _macos_13_1_plus = [mpsCD instancesRespondToSelector:@selector( + sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES; + static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES; + + switch (subVersion) { + case 0: return _macos_13_0_plus; + case 1: return _macos_13_1_plus; + case 2: return _macos_13_2_plus; + default: return false; + } } -at::Allocator* getMPSSharedAllocator(); -at::Allocator* getMPSPrivateAllocator(); at::Allocator* GetMPSAllocator(bool useSharedAllocator) { - return useSharedAllocator ? getMPSSharedAllocator() : getMPSPrivateAllocator(); + return getIMPSAllocator(useSharedAllocator); } bool is_available() { return MPSDevice::getInstance()->device() != nil; } +bool is_macos_13_or_newer(int32_t subVersion) { + return MPSDevice::getInstance()->isMacOS13Plus(subVersion); +} + +void device_synchronize() { + getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT); +} + } // namespace mps } // namespace at diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm index f1c0dbbacdca3..bb2ea6e693793 100644 --- a/aten/src/ATen/mps/MPSFallback.mm +++ b/aten/src/ATen/mps/MPSFallback.mm @@ -59,10 +59,11 @@ Tensor slow_conv2d_forward_mps( m.impl("repeat_interleave.self_int", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_fft_c2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); + m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold + m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); - m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); - m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps); + m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); } } // namespace at diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.h b/aten/src/ATen/mps/MPSGeneratorImpl.h new file mode 100644 index 0000000000000..9695eb719274c --- /dev/null +++ b/aten/src/ATen/mps/MPSGeneratorImpl.h @@ -0,0 +1,52 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace mps { +namespace detail { + +static const uint32_t PHILOX_STATE_N = 7; +struct rng_data_pod { + std::array state{1}; + uint64_t seed = default_rng_seed_val; +}; + +TORCH_API const Generator& getDefaultMPSGenerator(); +TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val); + +} // namespace detail +} // namespace mps + +struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl { + // Constructors + MPSGeneratorImpl(uint64_t seed_in = default_rng_seed_val); + ~MPSGeneratorImpl() override = default; + + // MPSGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + void update_philox_counters(); + + void set_engine(at::Philox4_32 engine) { engine_ = engine; }; + at::Philox4_32 engine() { return engine_; }; + uint32_t* state_data() { return data_.state.data(); } + static DeviceType device_type() { return DeviceType::MPS; }; + +private: + mps::detail::rng_data_pod data_; + at::Philox4_32 engine_; + + MPSGeneratorImpl* clone_impl() const override; +}; + +} // namespace at diff --git a/aten/src/ATen/mps/MPSGeneratorImpl.mm b/aten/src/ATen/mps/MPSGeneratorImpl.mm new file mode 100644 index 0000000000000..7eb6b7d987826 --- /dev/null +++ b/aten/src/ATen/mps/MPSGeneratorImpl.mm @@ -0,0 +1,100 @@ +// Copyright © 2022 Apple Inc. + +#include +#include +#include + +namespace at { +namespace mps { +namespace detail { + +const Generator& getDefaultMPSGenerator() { + static auto default_gen_mps = createMPSGenerator(c10::detail::getNonDeterministicRandom()); + return default_gen_mps; +} + +Generator createMPSGenerator(uint64_t seed_val) { + auto gen = make_generator(seed_val); + gen.set_current_seed(seed_val); + return gen; +} + +} // namespace detail +} // namespace mps + +MPSGeneratorImpl::MPSGeneratorImpl(uint64_t seed_in) + : c10::GeneratorImpl{Device(DeviceType::MPS), DispatchKeySet(c10::DispatchKey::MPS)}, + data_({.seed = seed_in}), engine_(seed_in, 0, 0) { } + +void MPSGeneratorImpl::set_current_seed(uint64_t seed) { + data_.seed = seed; + data_.state.fill(1); + // the two last state values are the Philox keys + // TODO: make "key" in PhiloxRNGEngine.h public so we don't duplicate code here + data_.state[5] = static_cast(seed); + data_.state[6] = static_cast(seed >> 32); + engine_.reset_state(seed); +} + +uint64_t MPSGeneratorImpl::current_seed() const { + return data_.seed; +} + +uint64_t MPSGeneratorImpl::seed() { + auto random = c10::detail::getNonDeterministicRandom(); + this->set_current_seed(random); + return random; +} + +// See Note [Acquire lock when using random generators] +void MPSGeneratorImpl::update_philox_counters() { + // calling engine_() would call operator() of philox_engine class to + // get each of the four newly generated counter values (see PhiloxRNGEngine.h). + for (int i = 1; i <= 4; i++) { + data_.state[i] = engine_(); + } +} + +c10::intrusive_ptr MPSGeneratorImpl::get_state() const { + static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); + static const size_t seed_size = sizeof(uint64_t); + static const size_t total_size = states_size + seed_size; + + auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + auto rng_state = state_tensor.data_ptr(); + auto current_seed = this->current_seed(); + memcpy(rng_state, this->data_.state.data(), states_size); + memcpy(rng_state + states_size, ¤t_seed, seed_size); + + return state_tensor.getIntrusivePtr(); +} + +void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { + static const size_t states_size = mps::detail::PHILOX_STATE_N * sizeof(uint32_t); + static const size_t seed_size = sizeof(uint64_t); + static const size_t total_size = states_size + seed_size; + + detail::check_rng_state(new_state); + + auto new_state_size = new_state.numel(); + TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size"); + + uint64_t input_seed = default_rng_seed_val; + auto new_rng_state = new_state.data(); + memcpy(&input_seed, new_rng_state + states_size, seed_size); + this->set_current_seed(input_seed); + // state.data must be copied after input_seed to not reset the state in set_current_seed() + memcpy(this->state_data(), new_rng_state, states_size); +} + +std::shared_ptr MPSGeneratorImpl::clone() const { + return std::shared_ptr(this->clone_impl()); +} + +MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { + auto gen = new MPSGeneratorImpl(); + gen->set_current_seed(this->data_.seed); + return gen; +} + +} // namespace at diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h index 27d32bf652e7a..b6002497d223d 100644 --- a/aten/src/ATen/mps/MPSGuardImpl.h +++ b/aten/src/ATen/mps/MPSGuardImpl.h @@ -109,12 +109,12 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface struct OptionalMPSGuard { explicit OptionalMPSGuard() : guard_() {} - explicit OptionalMPSGuard(optional device_opt) + explicit OptionalMPSGuard(c10::optional device_opt) : guard_(device_opt) {} /// Set the current MPS device to the passed device index, if it is not /// nullopt - explicit OptionalMPSGuard(optional device_index_opt) + explicit OptionalMPSGuard(c10::optional device_index_opt) : guard_(device_index_opt) {} // Copy is not allowed @@ -144,14 +144,14 @@ struct OptionalMPSGuard { /// Returns the device that was set immediately prior to initialization of the /// guard, or nullopt if the guard is uninitialized. - optional original_device() const { + c10::optional original_device() const { return guard_.original_device(); } /// Returns the most recent device that was set using this device guard, /// either from construction, or via set_device, if the guard is initialized, /// or nullopt if the guard is uninitialized. - optional current_device() const { + c10::optional current_device() const { return guard_.current_device(); } diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.cpp new file mode 100644 index 0000000000000..4a549bfc72252 --- /dev/null +++ b/aten/src/ATen/mps/MPSHooks.cpp @@ -0,0 +1,37 @@ +// Copyright © 2022 Apple Inc. + +#include +#include +#include + +namespace at { +namespace mps { + +void MPSHooks::initMPS() const { + C10_LOG_API_USAGE_ONCE("aten.init.mps"); + // TODO: initialize MPS devices and streams here +} + +bool MPSHooks::hasMPS() const { + return at::mps::is_available(); +} + +Allocator* MPSHooks::getMPSDeviceAllocator() const { + return at::mps::GetMPSAllocator(); +} + +const Generator& MPSHooks::getDefaultMPSGenerator() const { + return at::mps::detail::getDefaultMPSGenerator(); +} + +void MPSHooks::deviceSynchronize() const { + at::mps::device_synchronize(); +} + +using at::MPSHooksRegistry; +using at::RegistererMPSHooksRegistry; + +REGISTER_MPS_HOOKS(MPSHooks); + +} // namespace mps +} // namespace at diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h new file mode 100644 index 0000000000000..d64781930bff5 --- /dev/null +++ b/aten/src/ATen/mps/MPSHooks.h @@ -0,0 +1,21 @@ +// Copyright © 2022 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace at { namespace mps { + +// The real implementation of MPSHooksInterface +struct MPSHooks : public at::MPSHooksInterface { + MPSHooks(at::MPSHooksArgs) {} + void initMPS() const override; + bool hasMPS() const override; + Allocator* getMPSDeviceAllocator() const override; + const Generator& getDefaultMPSGenerator() const override; + void deviceSynchronize() const override; +}; + +}} // at::mps diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index 04115fc268c76..f1f2d47cf1e6a 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -1,15 +1,13 @@ // Copyright © 2022 Apple Inc. #include +#include namespace at { namespace mps { #define USE_COMMIT_AND_CONTINUE 1 -// the frequency that we commit the command buffer calculated based on low watermark ratio in MPSAllocator -uint32_t get_adaptive_commit_threshold(); - //----------------------------------------------------------------- // MPSStream //----------------------------------------------------------------- @@ -52,7 +50,7 @@ break; case SyncType::COMMIT_ADAPTIVE: // the adaptive commit only commits if we hit the low watermark memory threshold - if (get_adaptive_commit_threshold() <= 1) { + if (getIMPSAllocator()->getLowWatermarkValue() <= 1) { #if USE_COMMIT_AND_CONTINUE commitAndContinue(); #else diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 37e832d1e457b..bef09e81a5ea5 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -699,7 +699,7 @@ Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) { auto as_nd = [&](const Tensor& t) { TORCH_CHECK( t.dim() == 1 || t.dim() == 0, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", t.dim()); + "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", t.dim()); if (ndim >= 2) { sizes[1] = t.dim() == 1 ? t.size(0) : 1; strides[1] = t.dim() == 1 ? t.stride(0) : 0; diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index 40b05d74053ca..b612ef009b651 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -130,9 +130,9 @@ namespace { Tensor out = input.mean({-1, -2}, /* keepdim = */ true); if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) { // assert ndim == 4, since ndim = 3 doesn't give channels_last - const int n = input.size(0); - const int c = input.size(1); - out.as_strided_({n, c, 1, 1}, {c, 1, c, c}); + const auto n = input.sym_size(0); + const auto c = input.sym_size(1); + out.as_strided__symint({n, c, 1, 1}, {c, 1, c, c}); } return out; } else { diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp index 427368e2c06ae..a0a02ca531600 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -313,7 +313,7 @@ Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) { return output; } -Tensor adaptive_avg_pool3d(Tensor const& input, IntArrayRef output_size) { +Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_size) { TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3"); TORCH_CHECK( (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), @@ -326,7 +326,7 @@ Tensor adaptive_avg_pool3d(Tensor const& input, IntArrayRef output_size) { Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true); return out; } else { - return _adaptive_avg_pool3d(input, output_size); + return _adaptive_avg_pool3d_symint(input, output_size); } } diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index ef53b266ab1e9..e53d8cd2d38fc 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -451,15 +451,6 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) { return result; } -// we use `enum class LapackLstsqDriverType` as keys in an unordered_map. -// Clang5 and Gcc5 do not support std::hash for enum classes, hence -// we provide our own hash function. -struct LapackLstsqDriverTypeHash { - std::size_t operator()(const LapackLstsqDriverType& driver_type) const { - return static_cast(driver_type); - } -}; - /* Solves a least squares problem. That is minimizing ||B - A X||. @@ -490,7 +481,7 @@ void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_valu auto lapack_func = lapackLstsq; static auto driver_type_to_func - = std::unordered_map({ + = std::unordered_map({ {driver_t::Gels, lapackLstsq}, {driver_t::Gelsy, lapackLstsq}, {driver_t::Gelsd, lapackLstsq}, diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index 090a3a8a71db2..5ce747e9c7a7e 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -144,7 +144,6 @@ static void col2im_out_cpu_template( int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height); output.resize_({batch_size, n_output_plane, output_height, output_width}); - output.zero_(); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "col2im_out_cpu", [&] { diff --git a/aten/src/ATen/native/ComplexHelper.h b/aten/src/ATen/native/ComplexHelper.h index 88668d13145c5..9533115a7066c 100644 --- a/aten/src/ATen/native/ComplexHelper.h +++ b/aten/src/ATen/native/ComplexHelper.h @@ -1,8 +1,15 @@ #pragma once -#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + // WARNING: this header contains non-inline functions and should be only // included from ONE cpp file @@ -11,19 +18,18 @@ namespace at { namespace native { // View tensor with new dtype, storage offset, sizes and strides inline Tensor view_tensor( const Tensor &tensor, ScalarType dtype, - int64_t offset, IntArrayRef sizes, IntArrayRef strides) { + c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) { Storage storage = tensor.storage(); auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); auto new_tensor = detail::make_tensor( c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); auto * impl = new_tensor.unsafeGetTensorImpl(); - impl->set_storage_offset(offset); - impl->set_sizes_and_strides(sizes, strides); + impl->set_sizes_and_strides(sizes, strides, offset); return new_tensor; } -inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { - DimVector res(oldstride.size() + 1); +inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) { + SymDimVector res(oldstride.size() + 1); for (const auto i : c10::irange(oldstride.size())) { res[i] = oldstride[i] * 2; } @@ -33,13 +39,13 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { Tensor _view_as_real_physical(const Tensor& self) { TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); - auto old_sizes = self.sizes(); - DimVector new_sizes(old_sizes.size() + 1); + auto old_sizes = self.sym_sizes(); + SymDimVector new_sizes(old_sizes.size() + 1); std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); // last dimension will always have two elements containing the real and imag vals new_sizes.back() = 2; - auto new_strides = computeStrideForViewAsReal(self.strides()); - auto new_storage_offset = 2 * self.storage_offset(); + auto new_strides = computeStrideForViewAsReal(self.sym_strides()); + auto new_storage_offset = self.sym_storage_offset() * 2; const auto float_type = c10::toRealValueType(self.scalar_type()); auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides); return real_tensor; @@ -53,11 +59,11 @@ Tensor view_as_real(const Tensor& self) { return _view_as_real_physical(self); } -inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) { +inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) { const int64_t dim = oldstride.size(); TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1"); - DimVector res(dim - 1); + SymDimVector res(dim - 1); for (const auto i : c10::irange(res.size())) { TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); res[i] = oldstride[i] / 2; @@ -72,16 +78,16 @@ Tensor view_as_complex(const Tensor& self) { self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); - auto old_sizes = self.sizes(); + auto old_sizes = self.sym_sizes(); TORCH_CHECK(old_sizes.size() != 0, "Input tensor must have one or more dimensions"); TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); - DimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); + SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); - const auto new_strides = computeStrideForViewAsComplex(self.strides()); + const auto new_strides = computeStrideForViewAsComplex(self.sym_strides()); const auto complex_type = c10::toComplexType(self.scalar_type()); - TORCH_CHECK(self.storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); - const auto new_storage_offset = self.storage_offset() / 2; + TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); + const auto new_storage_offset = self.sym_storage_offset() / 2; return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); } diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index a31dbee2bd759..880ce0c2af54a 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -80,40 +80,7 @@ static inline bool cudnnv8_use_heur_mode_b() { return cudnnv8_heuristic_mode_b; } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -struct ConvParams { - std::vector stride; - std::vector padding; - std::vector dilation; - bool transposed; - std::vector output_padding; - int groups; - bool benchmark; - bool deterministic; - bool cudnn_enabled; - bool allow_tf32; - - bool is_strided() const; - bool is_dilated() const; - bool is_padded() const; - bool is_output_padding_neg() const; - bool is_output_padding_big() const; - bool is_padding_neg() const; - bool is_stride_nonpos() const; - void view1d_as_2d(); - bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const; - bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const; - bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const; - bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const; - bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const; - bool use_mps(const at::Tensor& input, const at::Tensor& weight) const; - bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; -}; - +// Keep in sync with py::enum_ in Module.cpp enum class ConvBackend { CudaDepthwise2d, CudaDepthwise3d, @@ -139,33 +106,16 @@ enum class ConvBackend { MpsTranspose, }; -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. -TORCH_API ConvBackend _select_conv_backend( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias_opt, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -// For BC reasons, have a copy that does not require bias_opt -TORCH_API ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - // Overload for selecting the convolution backend from the full set of convolution inputs. // This overload is exposed to python for testing, etc. TORCH_API ConvBackend select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias_opt, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups); + IntArrayRef stride, SymIntArrayRef padding, IntArrayRef dilation, + bool transposed, SymIntArrayRef output_padding, int64_t groups, const at::OptionalSymIntArrayRef bias_sizes_opt); + +TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input, + const Tensor& weight, + const ConvBackend backend); // --------------------------------------------------------------------- // @@ -250,15 +200,16 @@ static void convolution_shape_check( // as conv_output_size loses information; this is why conv_input_size // takes an extra output_padding argument to resolve the ambiguity. -static inline std::vector conv_output_size( - IntArrayRef input_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +template +static inline std::vector _conv_output_size( + ArrayRef input_size, ArrayRef weight_size, + ArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { // ASSERT(input_size.size() > 2) // ASSERT(input_size.size() == weight_size.size()) bool has_dilation = dilation.size() > 0; auto dim = input_size.size(); - std::vector output_size(dim); + std::vector output_size(dim); output_size[0] = input_size[input_batch_size_dim]; output_size[1] = weight_size[weight_output_channels_dim]; for (const auto d : c10::irange(2, dim)) { @@ -269,40 +220,84 @@ static inline std::vector conv_output_size( return output_size; } -static inline std::vector conv_input_size( - IntArrayRef output_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +static inline std::vector conv_output_size( + IntArrayRef input_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +static inline std::vector conv_output_size( + SymIntArrayRef input_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +template +std::vector _conv_input_size( + ArrayRef output_size, ArrayRef weight_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { // ASSERT(output_size.size() > 2) // ASSERT(output_size.size() == weight_size.size()) auto dim = output_size.size(); - std::vector input_size(dim); + std::vector input_size(dim); input_size[0] = output_size[output_batch_size_dim]; input_size[1] = weight_size[weight_input_channels_dim] * groups; for (const auto d : c10::irange(2, dim)) { - int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1; - input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) + + auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1; + input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + kernel + output_padding[d - 2]; } return input_size; } -static inline std::vector conv_weight_size( - IntArrayRef input_size, IntArrayRef output_size, +static inline std::vector conv_input_size( + SymIntArrayRef output_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_input_size( + IntArrayRef output_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +template +std::vector _conv_weight_size( + ArrayRef input_size, ArrayRef output_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { auto dim = input_size.size(); - std::vector weight_size(dim); + std::vector weight_size(dim); weight_size[0] = output_size[1]; weight_size[1] = input_size[1] / groups; for (const auto d : c10::irange(2, dim)) { - int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] - + 2 * padding[d - 2] - output_padding[d - 2]; + auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] + + padding[d - 2] * 2 - output_padding[d - 2]; weight_size[d] = (kernel - 1) / dilation[d - 2] + 1; } return weight_size; } +static inline std::vector conv_weight_size( + SymIntArrayRef input_size, SymIntArrayRef output_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_weight_size( + IntArrayRef input_size, IntArrayRef output_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 4d68f23c0734f..edb51a5c837d8 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -11,11 +11,15 @@ #include #include #include - #include - #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + #if AT_NNPACK_ENABLED() #include #endif @@ -82,309 +86,12 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { -DEFINE_DISPATCH(conv_depthwise2d_backward_stub); -DEFINE_DISPATCH(conv_depthwise3d_backward_stub); -DEFINE_DISPATCH(cudnn_convolution_backward_stub); -DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub); -DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub); -DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub); -DEFINE_DISPATCH(miopen_convolution_backward_stub); -DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub); -DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub); -DEFINE_DISPATCH(mkldnn_convolution_backward_stub); -DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub); -DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub); -DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub); -REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub); -REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub); -REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub); -REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); - -std::ostream& operator<<(std::ostream & out, const ConvParams& params) { - out << "ConvParams {" - << " stride = " << IntArrayRef{params.stride} - << " padding = " << IntArrayRef{params.padding} - << " dilation = " << IntArrayRef{params.dilation} - << " transposed = " << params.transposed - << " output_padding = " << IntArrayRef{params.output_padding} - << " groups = " << params.groups - << " benchmark = " << params.benchmark - << " deterministic = " << params.deterministic - << " cudnn_enabled = " << params.cudnn_enabled - << " allow_tf32 = " << params.allow_tf32 - << "}"; - return out; -} - -auto ConvParams::is_strided() const -> bool { - bool is_strided = false; - for (auto s : stride) { - is_strided |= (s != 1); - } - return is_strided; -} - -auto ConvParams::is_dilated() const -> bool { - bool is_dilated = false; - for (auto d : dilation) { - is_dilated |= (d != 1); - } - return is_dilated; -} - -auto ConvParams::is_padded() const -> bool { - bool is_padded = false; - for (auto p : padding) { - is_padded |= (p != 0); - } - return is_padded; -} - -auto ConvParams::is_output_padding_neg() const -> bool { - bool is_non_neg = false; - for (auto p : output_padding) { - is_non_neg |= (p < 0); - } - return is_non_neg; -} - -auto ConvParams::is_output_padding_big() const -> bool { - bool is_big = false; - for (auto i: c10::irange(output_padding.size())) { - is_big |= (output_padding[i] >= stride[i]); - } - return is_big; -} - -auto ConvParams::is_padding_neg() const -> bool { - bool is_non_neg = false; - for (auto p : padding) { - is_non_neg |= (p < 0); - } - return is_non_neg; -} - -auto ConvParams::is_stride_nonpos() const -> bool { - bool is_nonpos = false; - for (auto s : stride) { - is_nonpos |= (s <= 0); - } - return is_nonpos; -} - -auto ConvParams::view1d_as_2d() -> void { - if (stride.size() == 1) { - stride.insert(stride.begin(), 1); - padding.insert(padding.begin(), 0); - dilation.insert(dilation.begin(), 1); - output_padding.insert(output_padding.begin(), 0); - } -} - -auto ConvParams::use_cpu_depthwise3x3_winograd( - const at::Tensor& input, - const at::Tensor& weight, - const c10::optional& bias) const -> bool { -#if defined(__ARM_NEON__) - // Currently only 3x3 depthwise convolutions on tensors of float are supported. - return (input.ndimension() == 4) && - (input.size(1) == groups) && - (weight.ndimension() == 4 ) && - (weight.size(0) % input.size(1) == 0) && - (weight.size(1) == 1) && - (weight.size(2) == 3) && - (weight.size(3) == 3) && - (input.device().is_cpu()) && - (input.scalar_type() == at::kFloat) && - input.is_contiguous() && - (weight.device().is_cpu()) && - (weight.scalar_type() == at::kFloat) && - weight.is_contiguous() && - (!bias.has_value() || bias->is_contiguous()) && - !is_strided() && - !is_dilated() && - !transposed; -#else - return false; -#endif -} - -auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const -> bool { - constexpr int64_t int_max = std::numeric_limits::max(); - int64_t numel_input = input.numel(); - // empty input - if (numel_input == 0) { - return false; - } - // input size can not be reduced to the range of int by splitting the batch dim - int64_t n = input.size(0); - if (numel_input / n > int_max) { - return true; - } - // output size can not be reduced to the range of int by splitting the batch dim - int64_t outsize = 1; - if (transposed) { - std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); - outsize = c10::multiply_integers(o.begin() + 1, o.end()); - } else { - std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - outsize = c10::multiply_integers(o.begin() + 1, o.end()); - } - return outsize > int_max; -} - -auto ConvParams::use_cudnn(const at::Tensor& input, const at::Tensor& weight) const -> bool { - -// Note [Mobile check segfaults] -// cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest -// that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) -#if !defined(C10_MOBILE) - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - if (!detail::getCUDAHooks().compiledWithCuDNN()) { - return false; - } - if (!input.is_cuda() || !cudnn_enabled) { - return false; - } - if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { - if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { - return false; - } - } - if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { - // bypass dilation checks for channels_last convolution - if (deterministic && is_dilated()) { - // cudnn doesn't support deterministic dilated convolution fully yet - return false; - } - if (is_dilated()) { - return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big(); - } - } - return !is_output_padding_big(); -#else - return false; -#endif -} - -auto ConvParams::use_mps( const at::Tensor& input, const at::Tensor& weight) const -> bool { - // These checks need to be expanded. Currently we have very limited set of - // checks for MPS. -#ifdef USE_MPS - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - if (!input.is_mps()) { - return false; - } - return true; -#else - return false; -#endif -} - -auto ConvParams::use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const -> bool { - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16)) - && detail::getCUDAHooks().compiledWithMIOpen() - && input.is_cuda() - && input.dim() <= MIOPEN_DIM_MAX - && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 - && !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16 - && cudnn_enabled - ; -} - -auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const -> bool { -#if AT_MKLDNN_ENABLED() - if (!at::globalContext().userEnabledMkldnn()) { - return false; - } - if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) { - return true; - } - return (input.is_mkldnn()) || // input is mkldnn Tensor - (input.device().is_cpu() && - input.scalar_type() == kFloat && // only on CPU Float Tensors - !transposed && // or transposed tensors - // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, - // but THNN is faster when single-threaded. - (is_strided() || is_dilated() || input.size(0) >= 16 || - weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && - (groups > 1 - || (weight.size(-1) > 3 && weight.size(-2) > 3) - || input.size(0) > 1 - || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster - ); - -#endif - return false; -} - -auto ConvParams::use_nnpack(const at::Tensor& input, const at::Tensor& weight) const -> bool { -#if AT_NNPACK_ENABLED() - return at::_nnpack_available() && - input.device().is_cpu() && - input.scalar_type() == kFloat && // only on CPU Float Tensors - !is_dilated() && // or dilation - !transposed && // or transposed tensors - input.ndimension() == 4 && // must be in NCHW format - weight.ndimension() == 4 && - (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 -#if !defined(C10_MOBILE) - && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable -#endif - ; -#endif - return false; -} - -auto ConvParams::use_xnnpack( - const at::Tensor& input, - const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const -> bool { -#if defined(C10_MOBILE) - if (!transposed) { - return (input.size(1) == groups) && - xnnpack::use_convolution2d( - input, - weight, - bias_sizes_opt, - padding, - stride, - dilation, - groups, - transposed); - } -#endif - return false; -} - -// We currently only have depthwise support for the case where groups == -// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of -// a depthwise multiplier) -auto ConvParams::is_depthwise( - const at::Tensor& input, const at::Tensor& weight) const -> bool { - return input.is_cuda() && - !transposed && - (input.ndimension() == 4 || input.ndimension() == 5) && - input.size(1) == groups && - groups > 1 && // no point if there is only a single group - weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels -} - // Check workload to activate fast depthwise FP16 cudnn conv kernels +template bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { - int w = input.size(3); // same as h - int ch = input.size(1); - int bs = input.size(0); + auto w = at::symint::size(input, 3); // same as h + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); if (stride==1) { if (w >= 7) { // All batch sizes and nb_channels @@ -503,27 +210,28 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { } // simplified version for cudnn 8.2 and above +template bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) { // 1D conv - if(input.size(2) == 1 && stride == 1){ + if(at::symint::size(input, 2) == 1 && stride == 1){ return true; } // 2d conv // only square filters - if (weight.size(2) != weight.size(3)) return false; - int filter = weight.size(3); + if (at::symint::size(weight, 2) != at::symint::size(weight, 3)) return false; + auto filter = at::symint::size(weight, 3); // only 1/3/5 filter if (filter != 1 && filter != 3 && filter != 5) return false; // we don't enforce square input but only check width to reduce heuristic space - if (input.size(3) < 7) return false; // min width 7 - int w = input.size(3); + if (at::symint::size(input, 3) < 7) return false; // min width 7 + auto w = at::symint::size(input, 3); // only 1/2 stride, use cudnn for all stride 1 if (stride == 1) return true; if (stride != 2) return false; - int ch = input.size(1); - int bs = input.size(0); + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); // special case since bs1 show good perf in lots of cases if (bs == 1) { if (filter == 1 && w <= 28) return true; @@ -537,54 +245,390 @@ bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int str return false; } -// Use cudnn for FP16 depthwise convolutions -auto ConvParams::use_cudnn_depthwise( - const at::Tensor& input, const at::Tensor& weight) const -> bool { - if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { - // always use cudnn_depthwise for channels_last format - return true; + +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed); +} + +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalSymIntArrayRef bias_sizes_opt, + const SymIntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + // Never use xnnpack for symbolic tracing + return false; +} + +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +// This struct is templated so that we can run backend selection in a dynamic +// shapes context; all of the real kernel selection in eager mode runs with +// int64_t +template +struct ConvParams { + std::vector stride; + std::vector padding; + std::vector dilation; + bool transposed; + std::vector output_padding; + int groups; + bool benchmark; + bool deterministic; + bool cudnn_enabled; + bool allow_tf32; + + bool is_strided() const { + bool is_strided = false; + for (auto s : stride) { + is_strided |= (s != 1); + } + return is_strided; + } + + bool is_dilated() const { + bool is_dilated = false; + for (auto d : dilation) { + is_dilated |= (d != 1); + } + return is_dilated; + } + + bool is_padded() const { + bool is_padded = false; + for (auto p : padding) { + is_padded |= (p != 0); + } + return is_padded; + } + + bool is_output_padding_neg() const { + bool is_non_neg = false; + for (auto p : output_padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_output_padding_big() const { + bool is_big = false; + for (auto i: c10::irange(output_padding.size())) { + is_big |= (output_padding[i] >= stride[i]); + } + return is_big; + } + + bool is_padding_neg() const { + bool is_non_neg = false; + for (auto p : padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_stride_nonpos() const { + bool is_nonpos = false; + for (auto s : stride) { + is_nonpos |= (s <= 0); + } + return is_nonpos; } - if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { - long cudnn_version = detail::getCUDAHooks().versionCuDNN(); - if (cudnn_version >= 8200) { - bool kernel_cond = (use_cudnn(input, weight) && + + void view1d_as_2d() { + if (stride.size() == 1) { + stride.insert(stride.begin(), 1); + padding.insert(padding.begin(), 0); + dilation.insert(dilation.begin(), 1); + output_padding.insert(output_padding.begin(), 0); + } + } + + bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const { +#if defined(__ARM_NEON__) + // Currently only 3x3 depthwise convolutions on tensors of float are supported. + return (input.ndimension() == 4) && + (at::symint::size(input, 1) == groups) && + (weight.ndimension() == 4 ) && + (at::symint::size(weight, 0) % at::symint::size(input, 1) == 0) && + (at::symint::size(weight, 1) == 1) && + (at::symint::size(weight, 2) == 3) && + (at::symint::size(weight, 3) == 3) && + (input.device().is_cpu()) && + (input.scalar_type() == at::kFloat) && + input.is_contiguous() && + (weight.device().is_cpu()) && + (weight.scalar_type() == at::kFloat) && + weight.is_contiguous() && + (!bias.has_value() || bias->is_contiguous()) && + !is_strided() && + !is_dilated() && + !transposed; +#else + return false; +#endif + } + + bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const { + constexpr int64_t int_max = std::numeric_limits::max(); + auto numel_input = at::symint::numel(input); + // empty input + if (numel_input == 0) { + return false; + } + // input size can not be reduced to the range of int by splitting the batch dim + auto n = at::symint::size(input, 0); + if (numel_input / n > int_max) { + return true; + } + // output size can not be reduced to the range of int by splitting the batch dim + T outsize = 1; + if (transposed) { + auto o = conv_input_size(at::symint::sizes(input), at::symint::sizes(weight), padding, output_padding, stride, dilation, groups); + outsize = c10::multiply_integers(o.begin() + 1, o.end()); + } else { + auto o = conv_output_size(at::symint::sizes(input), at::symint::sizes(weight), padding, stride, dilation); + outsize = c10::multiply_integers(o.begin() + 1, o.end()); + } + return outsize > int_max; + } + + bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const { + // Note [Mobile check segfaults] + // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest + // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) +#if !defined(C10_MOBILE) + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } + if (!input.is_cuda() || !cudnn_enabled) { + return false; + } + if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { + if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { + return false; + } + } + if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { + // bypass dilation checks for channels_last convolution + if (deterministic && is_dilated()) { + // cudnn doesn't support deterministic dilated convolution fully yet + return false; + } + if (is_dilated()) { + return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big(); + } + } + return !is_output_padding_big(); +#else + return false; +#endif + } + + // Use cudnn for FP16 depthwise convolutions + bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const { + if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { + // always use cudnn_depthwise for channels_last format + return true; + } + if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { + long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + if (cudnn_version >= 8200) { + bool kernel_cond = (use_cudnn(input, weight) && + input.scalar_type() == kHalf && // only for FP16 + weight.scalar_type() == kHalf && + is_depthwise(input, weight) && + input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks + !is_dilated() && // no dilation supported + (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d + at::symint::size(input, 1) >= 32); // min 32 channels supported) + if (kernel_cond) { + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + } + } + // keep (7600 <= cudnn < 8200) code unchanged + bool kernel_cond = (cudnn_version >= 7600 && + use_cudnn(input, weight) && input.scalar_type() == kHalf && // only for FP16 weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks + at::symint::size(weight, 2) == at::symint::size(weight, 3) && // only square kernels + at::symint::size(input, 2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported - (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d - input.size(1) >= 32); // min 32 channels supported) + stride[0] == stride[1] && // equal strides + ((at::symint::size(weight, 3) == 3) || (at::symint::size(weight, 3) == 1)) && + at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + return check_cudnn_depthwise_workload(input, stride[0]); + } else { + return false; } - } - // keep (7600 <= cudnn < 8200) code unchanged - bool kernel_cond = (cudnn_version >= 7600 && - use_cudnn(input, weight) && - input.scalar_type() == kHalf && // only for FP16 - weight.scalar_type() == kHalf && - is_depthwise(input, weight) && - input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - weight.size(2) == weight.size(3) && // only square kernels - input.size(2) >= 7 && // min width/height 7 - !is_dilated() && // no dilation supported - stride[0] == stride[1] && // equal strides - ((weight.size(3) == 3) || (weight.size(3) == 1)) && - input.size(1) >= 32); // min 32 channels supported) - if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); } else { return false; } - } else { + } + + bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const { + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16)) + && detail::getCUDAHooks().compiledWithMIOpen() + && input.is_cuda() + && input.dim() <= MIOPEN_DIM_MAX + && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 + && !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16 + && cudnn_enabled + ; + } + bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const { +#if AT_MKLDNN_ENABLED() + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) { + return true; + } + return (input.is_mkldnn()) || // input is mkldnn Tensor + (input.device().is_cpu() && + input.scalar_type() == kFloat && // only on CPU Float Tensors + !transposed && // or transposed tensors + // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, + // but THNN is faster when single-threaded. + (is_strided() || is_dilated() || at::symint::size(input, 0) >= 16 || + at::symint::size(weight, -1) != 1 || at::symint::size(weight, -2) != 1 || at::get_num_threads() > 1) && + (groups > 1 + || (at::symint::size(weight, -1) > 3 && at::symint::size(weight, -2) > 3) + || at::symint::size(input, 0) > 1 + || at::symint::size(input, 0)*at::symint::size(input, 1)*at::symint::size(input, 2)*at::symint::size(input, 3) > 20480) // for some case, native is faster + ); + +#endif + return false; + } + bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const { +#if AT_NNPACK_ENABLED() + return at::_nnpack_available() && + input.device().is_cpu() && + input.scalar_type() == kFloat && // only on CPU Float Tensors + !is_dilated() && // or dilation + !transposed && // or transposed tensors + input.ndimension() == 4 && // must be in NCHW format + weight.ndimension() == 4 && + (at::symint::size(weight, 2) < 17) && (at::symint::size(weight, 3) < 17) // NNPACK only supports kernels up to 16x16 +#if !defined(C10_MOBILE) + && at::symint::size(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable +#endif + ; +#endif return false; } + bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, + const at::OptionalArrayRef bias_sizes_opt) const { +#if defined(C10_MOBILE) + if (!transposed) { + // NB: for the call here, it MATTERS that we are templated. If you + // untemplate this to always use SymInt, the function + // xnnpack_use_convolution2d will always return false + return (at::symint::size(input, 1) == groups) && + xnnpack_use_convolution2d( + input, + weight, + bias_sizes_opt, + padding, + stride, + dilation, + groups, + transposed); + } +#endif + return false; + } + + bool use_mps(const at::Tensor& input, const at::Tensor& weight) const { + // These checks need to be expanded. Currently we have very limited set of + // checks for MPS. +#ifdef USE_MPS + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + if (!input.is_mps()) { + return false; + } + return true; +#else + return false; +#endif + } + + // We currently only have depthwise support for the case where groups == + // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of + // a depthwise multiplier) + bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const { + return input.is_cuda() && + !transposed && + (input.ndimension() == 4 || input.ndimension() == 5) && + at::symint::size(input, 1) == groups && + groups > 1 && // no point if there is only a single group + at::symint::size(weight, 0) % at::symint::size(input, 1) == 0; // output channels must be a multiple of input channels + } +}; + +DEFINE_DISPATCH(conv_depthwise2d_backward_stub); +DEFINE_DISPATCH(conv_depthwise3d_backward_stub); +DEFINE_DISPATCH(cudnn_convolution_backward_stub); +DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub); +DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub); +DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub); +DEFINE_DISPATCH(miopen_convolution_backward_stub); +DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub); +DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub); +DEFINE_DISPATCH(mkldnn_convolution_backward_stub); +DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub); +DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub); +DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub); +REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub); +REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub); +REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub); +REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); + +template +std::ostream& operator<<(std::ostream & out, const ConvParams& params) { + out << "ConvParams {" + << " stride = " << IntArrayRef{params.stride} + << " padding = " << ArrayRef{params.padding} + << " dilation = " << IntArrayRef{params.dilation} + << " transposed = " << params.transposed + << " output_padding = " << ArrayRef{params.output_padding} + << " groups = " << params.groups + << " benchmark = " << params.benchmark + << " deterministic = " << params.deterministic + << " cudnn_enabled = " << params.cudnn_enabled + << " allow_tf32 = " << params.allow_tf32 + << "}"; + return out; } +template static void check_shape_forward(const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, const at::Tensor& bias, - const ConvParams& params) { + const c10::ArrayRef& weight_sizes, const at::Tensor& bias, + const ConvParams& params) { int64_t k = input.ndimension(); int64_t weight_dim = weight_sizes.size(); int64_t groups = params.groups; @@ -599,7 +643,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ", - input.sizes(), " instead"); + at::symint::sizes(input), " instead"); TORCH_CHECK(weight_sizes[0] >= groups, "Given groups=", groups, ", expected weight to be at least ", groups, " at dimension 0, but got weight of size ", weight_sizes, " instead"); @@ -609,23 +653,23 @@ static void check_shape_forward(const at::Tensor& input, "] instead"); if (!transposed) { - std::vector input_shape; - std::vector kernel_shape; + std::vector input_shape; + std::vector kernel_shape; bool kernel_size_correct = true; - TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups), + TORCH_CHECK(at::symint::size(input, 1) == (weight_sizes[1] * groups), "Given groups=", groups, ", weight of size ", weight_sizes, - ", expected input", input.sizes(), " to have ", - (weight_sizes[1] * groups), " channels, but got ", input.size(1), + ", expected input", at::symint::sizes(input), " to have ", + (weight_sizes[1] * groups), " channels, but got ", at::symint::size(input, 1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]), + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[0]), "Given weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements", - ", but got bias of size ", bias.sizes(), " instead"); + ", but got bias of size ", at::symint::sizes(bias), " instead"); for (const auto i : c10::irange(2, k)) { - input_shape.push_back(input.size(i) + 2 * padding[i-2]); + input_shape.push_back(at::symint::size(input, i) + 2 * padding[i-2]); // log new kernel size considering dilation kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1); if (input_shape.back() < kernel_shape.back()) { @@ -651,22 +695,23 @@ static void check_shape_forward(const at::Tensor& input, "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); } } else { // transposed - TORCH_CHECK(input.size(1) == weight_sizes[0], + TORCH_CHECK(at::symint::size(input, 1) == weight_sizes[0], "Given transposed=", transposed, ", weight of size ", weight_sizes, - ", expected input", input.sizes(), " to have ", weight_sizes[0], - " channels, but got ", input.size(1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[1] * groups), + ", expected input", at::symint::sizes(input), " to have ", weight_sizes[0], + " channels, but got ", at::symint::size(input, 1), " channels instead"); + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[1] * groups), "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements", - ", but got bias of size ", bias.sizes(), " instead"); + ", but got bias of size ", at::symint::sizes(bias), " instead"); } } +template static void check_shape_backward( const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, - const ConvParams& params) { - check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); + const c10::ArrayRef& weight_sizes, + const ConvParams& params) { + check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); } // Given an input tensor and an expected number of spatial dimensions, checks that the @@ -910,8 +955,8 @@ static Tensor convolution_same( auto k = weight.dim(); TORCH_CHECK(k > 2, "weight should have at least three dimensions"); auto dim = static_cast(k - 2); - auto weight_sizes = weight.sizes(); - auto input_sizes = input.sizes(); + auto weight_sizes = weight.sym_sizes(); + auto input_sizes = input.sym_sizes(); TORCH_CHECK(k == input.dim(), "Expected ", k, "-dimensional input for ", k, "-dimensional weight", weight_sizes, ", but got ", @@ -926,7 +971,7 @@ static Tensor convolution_same( } // Calculate the correct padding - DimVector padding_l, padding_r; + SymDimVector padding_l, padding_r; bool symmetric_padding = true; for (auto i: c10::irange(dim)) { auto s = stride.size() == 1 ? stride[0] : stride[i]; @@ -942,14 +987,14 @@ static Tensor convolution_same( if (symmetric_padding) { // All backends handle symmetric padding natively - DimVector output_padding(static_cast(dim)); - return at::convolution(input, weight, bias, stride, padding_l, dilation, + SymDimVector output_padding(static_cast(dim)); + return at::convolution_symint(input, weight, bias, stride, padding_l, dilation, false, output_padding, groups); } TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may" " require a zero-padded copy of the input be created"); - SmallVector pad_nd(static_cast(2 * dim)); + SmallVector pad_nd(static_cast(2 * dim)); for (auto i: c10::irange(dim)) { // Apply padding by the difference, leaving only a symmetric padding auto delta_pad = padding_r[i] - padding_l[i]; @@ -961,10 +1006,10 @@ static Tensor convolution_same( padding_l[i] = padding_r[i]; } } - auto padded_input = at::constant_pad_nd(input, pad_nd, 0); - DimVector output_padding(static_cast(dim)); - return at::convolution(padded_input, weight, bias, stride, padding_l, - dilation, false, output_padding, groups); + auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0); + SymDimVector output_padding(static_cast(dim)); + return at::convolution_symint(padded_input, weight, bias, stride, padding_l, + dilation, false, output_padding, groups); } Tensor _convolution_mode( @@ -1066,8 +1111,14 @@ at::Tensor conv_transpose2d( Tensor input; bool is_batched; std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d"); - auto output = at::convolution( + Tensor output; + if (at::isComplexType(input_.scalar_type())) { + output = complex_convolution( + input, weight, bias, stride, padding, dilation, true, output_padding, groups); + } else { + output = at::convolution( input, weight, bias, stride, padding, dilation, true, output_padding, groups); + } return is_batched ? output : output.squeeze(0); } @@ -1081,8 +1132,14 @@ at::Tensor conv_transpose3d( Tensor input; bool is_batched; std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d"); - auto output = at::convolution( + Tensor output; + if (at::isComplexType(input_.scalar_type())) { + output = complex_convolution( input, weight, bias, stride, padding, dilation, true, output_padding, groups); + } else { + output = at::convolution( + input, weight, bias, stride, padding, dilation, true, output_padding, groups); + } return is_batched ? output : output.squeeze(0); } @@ -1112,71 +1169,25 @@ at::Tensor convolution_overrideable( TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); } -// Selects a backend for convolution based on the inputs and params. -ConvBackend select_conv_backend( - const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, - IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, - bool transposed_, IntArrayRef output_padding_, int64_t groups_) { - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - auto& ctx = at::globalContext(); - auto k = weight_r.ndimension(); - int64_t dim = k - 2; - ConvParams params; - params.stride = expand_param_if_needed(stride_, "stride", dim); - params.padding = expand_param_if_needed(padding_, "padding", dim); - params.dilation = expand_param_if_needed(dilation_, "dilation", dim); - params.transposed = transposed_; - params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); - params.groups = groups_; - params.benchmark = ctx.benchmarkCuDNN(); - params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); - params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); - - auto input = input_r; - auto weight = weight_r; - check_shape_forward(input, weight.sizes(), bias, params); - - // Expand 1d -> 2d. - // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { - // avoid accidentally going through NHWC for permuted 3d input. - input = input.contiguous(); - params.view1d_as_2d(); - input = view4d(input); - weight = view4d(weight); - } - - auto bias_sizes_opt = bias.defined() ? c10::optional(bias.sizes()) : c10::nullopt; - bool need_backward = GradMode::is_enabled() && - (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - return _select_conv_backend(input, weight, bias, bias_sizes_opt, need_backward, params); -} - -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params) { - return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); -} - +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. +template ConvBackend _select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias, - const at::OptionalIntArrayRef bias_sizes_opt, + const at::OptionalArrayRef bias_sizes_opt, const bool need_backward, - const ConvParams& params) { + const ConvParams& params) { // don't send empty inputs through backends - if (input.size(0) == 0 || input.size(1) == 0) { + if (at::symint::size(input, 0) == 0 || at::symint::size(input, 1) == 0) { return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty; - } else if (input.numel() == 0) { - TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", input.sizes()); + } else if (at::symint::numel(input) == 0) { + TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes(input)); } if (params.is_depthwise(input, weight)) { @@ -1268,12 +1279,65 @@ ConvBackend _select_conv_backend( AT_ERROR("unsupported ConvNd parameters"); } +// Selects a backend for convolution based on the inputs and params. +ConvBackend select_conv_backend( + const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, + IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) { + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + auto& ctx = at::globalContext(); + auto k = weight_r.ndimension(); + int64_t dim = k - 2; + ConvParams params; + params.stride = expand_param_if_needed(stride_, "stride", dim); + params.padding = expand_param_if_needed(padding_, "padding", dim); + params.dilation = expand_param_if_needed(dilation_, "dilation", dim); + params.transposed = transposed_; + params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); + params.groups = groups_; + params.benchmark = ctx.benchmarkCuDNN(); + params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); + params.cudnn_enabled = ctx.userEnabledCuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN(); + + auto input = input_r; + auto weight = weight_r; + check_shape_forward(input, weight.sym_sizes(), bias, params); + + // Expand 1d -> 2d. + // This is only done for backends that don't natively support 1d spatial input. + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { + // avoid accidentally going through NHWC for permuted 3d input. + input = input.contiguous(); + params.view1d_as_2d(); + input = view4d(input); + weight = view4d(weight); + } + + auto bias_sizes = bias.defined() ? c10::optional(bias.sym_sizes()) : bias_sizes_opt; + bool need_backward = GradMode::is_enabled() && + (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); + return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); +} + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params) { + return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); +} + at::Tensor _convolution_nogroup_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::NnpackSpatial: @@ -1304,7 +1368,7 @@ at::Tensor _convolution_nogroup_backend( static inline std::vector calc_output_size( const Tensor& input, const Tensor& weight, - const ConvParams& params) { + const ConvParams& params) { std::vector output_size = params.transposed ? conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding, params.stride, params.dilation, params.groups) : @@ -1359,6 +1423,13 @@ static inline at::MemoryFormat determine_backend_memory_format( return backend_memory_format; } +at::MemoryFormat _determine_backend_memory_format( + const Tensor& input, + const Tensor& weight, + const ConvBackend backend) { + return determine_backend_memory_format(input, weight, backend); +} + at::Tensor _convolution( const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_r_opt, IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, @@ -1378,7 +1449,7 @@ at::Tensor _convolution( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); TORCH_CHECK(groups_ > 0, "non-positive groups is not supported"); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1406,7 +1477,7 @@ at::Tensor _convolution( auto bias_sizes_opt = bias.defined() ? c10::optional(bias.sizes()) : c10::nullopt; bool need_backward = GradMode::is_enabled() && (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - ConvBackend backend = _select_conv_backend(input, weight, bias, bias_sizes_opt, need_backward, params); + ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params); at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); // Call the backend. @@ -1441,7 +1512,19 @@ at::Tensor _convolution( break; case ConvBackend::Empty: { - auto weight_view = at::_unsafe_view(weight, -1); + Tensor weight_view; + // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where + // view size is not compatible with input tensor's size and stride. + if(weight.is_contiguous()) { + weight_view = at::_unsafe_view(weight, -1); + } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) { + weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1); + } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) { + weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1); + } else { + weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1); + } + output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]); if (bias.defined()) { output.add_(bias[0]); @@ -1619,7 +1702,7 @@ std::tuple _convolution_double_backward( const c10::option auto weight = weight_r; int64_t dim = weight.ndimension() - 2; - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1682,7 +1765,7 @@ std::tuple _convolution_double_backward( const c10::option if (ggI.defined()) { // Modified params with correct padding - ConvParams gw_conv_params(params); + ConvParams gw_conv_params(params); // Disable groups as they are handled separately auto groups = gw_conv_params.groups; @@ -1751,7 +1834,7 @@ std::tuple _convolution_double_backward( const c10::option Tensor gI; if (input.numel() != 0) { if (ggW.defined()) { - ConvParams gi_conv_params(params); + ConvParams gi_conv_params(params); gi_conv_params.transposed = !params.transposed; if (params.transposed) { @@ -1807,7 +1890,7 @@ std::tuple _convolution_backward_nogroup_bac const Tensor& weight, const std::array output_mask, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::Slow2d: @@ -1872,7 +1955,7 @@ std::tuple convolution_backward( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); auto& ctx = at::globalContext(); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride, "stride", dim); params.padding = expand_param_if_needed(padding, "padding", dim); params.dilation = expand_param_if_needed(dilation, "dilation", dim); diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index a44f39c5bb2eb..0c99943eb0cb0 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -124,12 +124,17 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) // 1. Memory Format for source and destination tensors is contiguous. // 2. Device for both the source and destination tensor is CPU. // 3. dtype conversion between FP32->FP16 and FP16->FP32. + // This checks that self.sizes() == src.sizes() because this code path doesn't + // support broadcasting. This also guards against out of bounds memory access + // when copying, see fbgemm::Float16ToFloat_ref. + // https://github.com/pytorch/pytorch/issues/88543 #ifdef USE_FBGEMM if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) || (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) && (self.device().is_cpu() && src.device().is_cpu()) && ((self.is_contiguous() && src.is_contiguous()) || - (self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) { + (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) && + (self.sizes() == src.sizes())) { if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) { auto* output_ptr = reinterpret_cast(self.data_ptr()); @@ -220,6 +225,18 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return at::metal::metal_copy_(self, src); } + // Exit early if self and src are views of the same data + const bool is_same_data = ( + self.is_alias_of(src) && + self.storage_offset() == src.storage_offset() && + self.strides().equals(src.strides()) && + self.sizes().equals(src.sizes()) && + self.scalar_type() == src.scalar_type() + ); + if (is_same_data) { + return self; + } + auto iter = TensorIteratorConfig() .add_output(self) @@ -261,27 +278,39 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return self; } +// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198 +at::Tensor clone_preserve_strides(const at::Tensor& self) { + TORCH_INTERNAL_ASSERT(self.has_storage()); + // In cases where the input tensor has internal memory overlap, we cannot actually + // preserve the strides/storage_offset of the input tensor, because + // *_scatter ops will try to copy_() into the cloned tensor. + // However, this should **never** show up in functionalized user code; + // most aten ops that try to mutate a tensor with internal memory overlap would error anyway. + // + // The one place that this does come up is in autograd - if there's a select_scatter + // in the forward, then autograd will generate one for the backward. + // If the input to the select_scatter is grad_output, then this could be an expanded tensor + // with internal overlap. + //if (at::has_internal_overlap(self) == at::MemOverlap::Yes) { + // return self.clone(); + //} + auto dtype_size = self.dtype().itemsize(); + auto nbytes = self.storage().sym_nbytes(); + TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0); + auto numel = nbytes / dtype_size; + auto self_full_size = self.as_strided_symint({numel}, {1}, 0); + auto clone = self_full_size.clone(); + auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset()); + return out; +} + Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) { // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but: // (1) It isn't exposed to the frontend (no python bindings) // (2) It isn't exposed to the backend (it's a composite, that decomposes into to() and expand_as() calls. - // Note: This implementation doesn't currently preserve the strides of `self`. - // That might be fine for functorch (which already doesn't preserve strides in vmap), - // but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA. - auto intermediate = src.to(self, non_blocking); - // We can't use expand() here. Why? - // The contract for copy_() is that the output tensor has the same amount of storage as the original tensor. - // e.g. This should work: - // a = torch.ones(4, 4) - // b = torch.ones(1, 4) - // c = torch.ones(4, 4) - // torch.ops.aten.copy(a, b).add_(c) - // We don't want to emit an extra copy every time though, so we only do it if the shapes are different. - if (self.sym_sizes() != intermediate.sym_sizes()) { - return at::expand_copy_symint(intermediate, self.sym_sizes()); - } else { - return intermediate; - } + auto r = clone_preserve_strides(self); + r.copy_(src, non_blocking); + return r; } Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) { diff --git a/aten/src/ATen/native/Correlation.cpp b/aten/src/ATen/native/Correlation.cpp index 204e4f2cb5688..9aca753c78ca5 100644 --- a/aten/src/ATen/native/Correlation.cpp +++ b/aten/src/ATen/native/Correlation.cpp @@ -139,7 +139,7 @@ Tensor corrcoef(const Tensor& self) { } // normalize covariance - const auto d = c.diag(); + const auto d = c.diagonal(); const auto stddev = at::sqrt(d.is_complex() ? at::real(d) : d); c = c / stddev.view({-1, 1}); c = c / stddev.view({1, -1}); diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index f23594022991e..4c37325c48171 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -33,8 +33,8 @@ namespace at { namespace native { -Tensor embedding(const Tensor & weight, const Tensor & indices, - int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { +Tensor embedding_symint(const Tensor & weight, const Tensor & indices, + c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { TORCH_CHECK(weight.dim() == 2, "'weight' must be 2-D"); auto indices_arg = TensorArg(indices, "indices", 1); checkScalarTypes("embedding", indices_arg, {kLong, kInt}); @@ -53,18 +53,21 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, } Tensor embedding_backward_symint( - const Tensor & grad, const Tensor & indices, SymInt num_weights, - int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + const Tensor & grad, const Tensor & indices, c10::SymInt num_weights, + c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { if (sparse) { // TODO: if we teach sparse tensor how to propagate symints, the guard // here is not strictly necessary. However, we think it is fine as is // because num weights is derived from a parameter and therefore // typically not varying. return at::embedding_sparse_backward( - grad, indices, num_weights.guard_int(__FILE__, __LINE__), padding_idx, scale_grad_by_freq); + grad, indices, + num_weights.guard_int(__FILE__, __LINE__), + padding_idx.guard_int(__FILE__, __LINE__), + scale_grad_by_freq); } else { return at::embedding_dense_backward_symint( - grad, indices, num_weights, padding_idx, scale_grad_by_freq); + grad, indices, num_weights, padding_idx, scale_grad_by_freq); } } @@ -89,20 +92,20 @@ Tensor embedding_sparse_backward( grad = grad.index(c); } - int64_t num_features = grad_.size(-1); - auto weight_size = std::array{{ num_weights, num_features }}; + auto num_features = grad_.sym_size(-1); + auto weight_size = std::array{{ num_weights, num_features }}; auto dense_options = grad.options(); // check if all our grad come from padding_idx - if (grad.numel() == 0) { - return at::_sparse_coo_tensor_unsafe(at::empty({1, 0}, indices_.options().dtype(kLong)), - at::empty({0, num_features}, dense_options), + if (grad.sym_numel() == 0) { + return at::_sparse_coo_tensor_unsafe_symint(at::empty({1, 0}, indices_.options().dtype(kLong)), + at::empty_symint({c10::SymInt(0), num_features}, dense_options), weight_size); } auto index = indices.reshape({1, -1}); - auto values = grad.reshape({-1, num_features}); - return at::_sparse_coo_tensor_unsafe(index.to(kLong), values, weight_size); + auto values = grad.reshape_symint({c10::SymInt(-1), num_features}); + return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size); } Tensor embedding_dense_backward_cpu( diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 7d4a89d6b40f7..21404947b3dbb 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -1307,7 +1307,7 @@ Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_ checkContiguous("embedding_bag", offsets_arg); Tensor offset2bag_; - if (indices.numel() != 0 && offset2bag.numel() == 0) { + if (indices.sym_numel() != 0 && offset2bag.sym_numel() == 0) { offset2bag_ = offsets.new_zeros( {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index bbe12b73592b1..4b6ef9196f990 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -196,7 +196,30 @@ void foreach_tensor_##OP##_scalarlist_slow_(TensorList input, TensorList tensors for(const auto i : c10::irange(input.size())) { \ input[i].OP##_(tensors1[i], tensors2[i], scalars[i]); \ } \ -} \ +} + +#define FOREACH_POINTWISE_OP_TENSOR(OP) \ + std::vector foreach_tensor_##OP##_tensor_slow( \ + TensorList input, \ + TensorList tensors1, \ + TensorList tensors2, \ + const Tensor& scalars_) { \ + auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + return foreach_tensor_##OP##_scalarlist_slow( \ + input, tensors1, tensors2, scalars); \ + } \ + \ + void foreach_tensor_##OP##_tensor_slow_( \ + TensorList input, \ + TensorList tensors1, \ + TensorList tensors2, \ + const Tensor& scalars_) { \ + auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + foreach_tensor_##OP##_scalarlist_slow_( \ + input, tensors1, tensors2, scalars); \ + } FOREACH_BINARY_OP_LIST_ALPHA(add); FOREACH_BINARY_OP_LIST_ALPHA(sub); @@ -249,6 +272,9 @@ FOREACH_POINTWISE_OP_SCALAR(addcmul); FOREACH_POINTWISE_OP_SCALARLIST(addcdiv); FOREACH_POINTWISE_OP_SCALARLIST(addcmul); +FOREACH_POINTWISE_OP_TENSOR(addcdiv); +FOREACH_POINTWISE_OP_TENSOR(addcmul); + // NOTE(crcrpar): It didn't seem feasible to use `self[i]` as both the first and the last // arguments of `maximum_out` and `minimum_out` so I tentatively embarrassingly get and copy // the result to `self[i]`. diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index 033052f401f6b..0166d040863c5 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -2,6 +2,7 @@ #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -123,6 +124,45 @@ bool check_fast_path_restrictions( return true; } +std::vector convert_tensor_to_scalar_list( + const Tensor& scalarList_, + int64_t expect_length) { + std::vector scalarList; + TORCH_CHECK( + scalarList_.device() == c10::kCPU, + "Expected scalars to be on CPU, got ", + scalarList_.device(), + " instead."); + TORCH_CHECK( + scalarList_.is_contiguous(), "Expected scalars to be contiguous."); + TORCH_CHECK( + scalarList_.dim() == 1, + "Expected packed scalar Tensor to be of dimension 1. Got ", + scalarList_.dim(), + " instead."); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( + kComplexHalf, + kHalf, + kBool, + kBFloat16, + scalarList_.scalar_type(), + "convert_tensor_to_scalar_list", + [&]() { + const scalar_t* scalar_data = scalarList_.data_ptr(); + TORCH_CHECK( + (expect_length == scalarList_.size(0)), + "Expected length of scalars to match input of length ", + expect_length, + " but got ", + scalarList_.size(0), + " instead."); + for (int64_t i = 0; i < scalarList_.size(0); i++) { + scalarList.push_back(c10::Scalar(scalar_data[i])); + } + }); + return scalarList; +} + bool can_use_fast_route(ArrayRef tensorLists, ArrayRef scalarList = {}, bool does_op_promote_integer_inputs_to_float = false) { diff --git a/aten/src/ATen/native/FractionalMaxPool2d.cpp b/aten/src/ATen/native/FractionalMaxPool2d.cpp index 82512c83f4337..1e9bf9c3902fd 100644 --- a/aten/src/ATen/native/FractionalMaxPool2d.cpp +++ b/aten/src/ATen/native/FractionalMaxPool2d.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -128,28 +129,6 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)( namespace native { namespace { -template -static std::vector fractional_max_pool2d_generate_intervals( - scalar_t sample, - int inputSize, - int outputSize, - int poolSize) { - std::vector sequence(outputSize); - if (outputSize > 1) { - scalar_t alpha = static_cast(inputSize - poolSize) / - static_cast(outputSize - 1); - - for (int i = 0; i < outputSize - 1; ++i) { - sequence[i] = - static_cast((i + sample) * alpha) - static_cast(sample * alpha); - } - } - if (outputSize > 0) { - sequence[outputSize - 1] = inputSize - poolSize; - } - return sequence; -} - template static void fractional_max_pool2d_out_single_batch_frame( scalar_t* input, @@ -166,9 +145,9 @@ static void fractional_max_pool2d_out_single_batch_frame( scalar_t* randomSamplesForPlane = randomSamples + plane * 2; /* Generate interval sequence */ - auto sequenceW = fractional_max_pool2d_generate_intervals( + auto sequenceW = generate_intervals( randomSamplesForPlane[0], inputW, outputW, poolSizeW); - auto sequenceH = fractional_max_pool2d_generate_intervals( + auto sequenceH = generate_intervals( randomSamplesForPlane[1], inputH, outputH, poolSizeH); /* loop over output */ @@ -305,10 +284,16 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) ( const at::Tensor& input_, IntArrayRef pool_size, IntArrayRef output_size, - const at::Tensor& randomSamples, + const at::Tensor& randomSamples_, const at::Tensor& output, const at::Tensor& indices) { + fractional_max_pool_check_shape(input_, randomSamples_); + + if (output.numel() == 0) { + return; + } + int64_t numBatch = 1; int64_t planeDim = 0; int64_t heightDim = 1; @@ -318,8 +303,9 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_out_cpu) ( int64_t poolSizeH = pool_size[0]; int64_t poolSizeW = pool_size[1]; - /* get contiguous input */ + /* get contiguous input and samples */ auto input = input_.contiguous(); + auto randomSamples = randomSamples_.contiguous(); int64_t ndims = input.ndimension(); diff --git a/aten/src/ATen/native/FractionalMaxPool3d.cpp b/aten/src/ATen/native/FractionalMaxPool3d.cpp index 5890026872a85..c524f0545473c 100644 --- a/aten/src/ATen/native/FractionalMaxPool3d.cpp +++ b/aten/src/ATen/native/FractionalMaxPool3d.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -100,28 +101,6 @@ TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)( namespace native { namespace { -template -static std::vector generate_intervals( - scalar_t sample, - int64_t inputSize, - int64_t outputSize, - int64_t poolSize) { - std::vector sequence(outputSize); - if (outputSize > 1) { - scalar_t alpha = static_cast(inputSize - poolSize) / - static_cast(outputSize - 1); - - for (const auto i : c10::irange(outputSize - 1)) { - sequence[i] = - static_cast((i + sample) * alpha) - static_cast(sample * alpha); - } - } - if (outputSize > 0) { - sequence[outputSize - 1] = inputSize - poolSize; - } - return sequence; -} - template static void fractional_max_pool3d_out_single_batch_frame( scalar_t* input, @@ -241,7 +220,7 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)( int64_t outputT, int64_t outputH, int64_t outputW, - const at::Tensor& randomSamples, + const at::Tensor& randomSamples_, int64_t numBatch, int64_t numPlanes, int64_t inputT, @@ -249,8 +228,16 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)( int64_t inputW, const at::Tensor& output, const at::Tensor& indices) { - /* get contiguous input */ + + fractional_max_pool_check_shape(input_, randomSamples_); + + if (output.numel() == 0) { + return; + } + + /* get contiguous input and samples */ auto input = input_.contiguous(); + auto randomSamples = randomSamples_.contiguous(); AT_DISPATCH_FLOATING_TYPES( input.scalar_type(), diff --git a/aten/src/ATen/native/FractionalMaxPooling.h b/aten/src/ATen/native/FractionalMaxPooling.h new file mode 100644 index 0000000000000..6631450faaa88 --- /dev/null +++ b/aten/src/ATen/native/FractionalMaxPooling.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include + +namespace at { namespace native { + +template +static inline std::vector generate_intervals( + scalar_t sample, + int64_t inputSize, + int64_t outputSize, + int64_t poolSize) { + std::vector sequence(outputSize); + if (outputSize > 1) { + scalar_t alpha = static_cast(inputSize - poolSize) / + static_cast(outputSize - 1); + + for (const auto i : c10::irange(outputSize - 1)) { + sequence[i] = + static_cast((i + sample) * alpha) - static_cast(sample * alpha); + } + } + if (outputSize > 0) { + sequence[outputSize - 1] = inputSize - poolSize; + } + return sequence; +} + +template +static inline void fractional_max_pool_check_shape( + const Tensor& input, + const Tensor& randomSamples) { + + TORCH_CHECK( + input.scalar_type() == randomSamples.scalar_type(), + "Expect _random_samples to have the same dtype as input"); + + int64_t ndimension = randomSamples.ndimension(); + TORCH_CHECK( + ndimension == 3, + "Expect _random_samples to have 3 dimensions, got ", ndimension); + + int64_t N = randomSamples.size(0); + int64_t C = randomSamples.size(1); + int64_t D = randomSamples.size(2); + + int64_t input_batch, input_channel; + if (ndim == 2) { + // fractional_max_pool2d + if (input.ndimension() == 3) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } else { + // factional_max_pool3d + if (input.ndimension() == 4) { + input_batch = 1; + input_channel = input.size(0); + } else { + input_batch = input.size(0); + input_channel = input.size(1); + } + } + + TORCH_CHECK( + N >= input_batch, + "Expect _random_samples.size(0) no less then input batch size."); + TORCH_CHECK( + C == input_channel, + "Expect _random_samples.size(1) equals to input channel size."); + TORCH_CHECK( + D == ndim, + "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, "."); +} + +}} // at::native diff --git a/aten/src/ATen/native/GridSamplerUtils.h b/aten/src/ATen/native/GridSamplerUtils.h index 0b6f29de8c427..7c22fedfe94e2 100644 --- a/aten/src/ATen/native/GridSamplerUtils.h +++ b/aten/src/ATen/native/GridSamplerUtils.h @@ -101,7 +101,7 @@ bool cond_cudnn_grid_sampler( at::native::canUse32BitIndexMath(input) && at::native::canUse32BitIndexMath(grid) && input.dim() == 4 && - input.size(1) <= 1024); + input.sym_size(1) <= 1024); } } // anonymous namespace diff --git a/aten/src/ATen/native/Histogram.cpp b/aten/src/ATen/native/Histogram.cpp index c3a007f2c2dcb..89ede6bea35c1 100644 --- a/aten/src/ATen/native/Histogram.cpp +++ b/aten/src/ATen/native/Histogram.cpp @@ -1,10 +1,28 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include #include diff --git a/aten/src/ATen/native/Histogram.h b/aten/src/ATen/native/Histogram.h index 9df0aafafc18d..3305cc5e315fb 100644 --- a/aten/src/ATen/native/Histogram.h +++ b/aten/src/ATen/native/Histogram.h @@ -3,8 +3,6 @@ #include #include -#include - namespace at { namespace native { using histogramdd_fn = void(*)(const Tensor&, const c10::optional&, bool, Tensor&, const TensorList&); diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index dd6c8b303a5fe..416e77e9ff199 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -1,12 +1,21 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include -#include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + namespace at { namespace native { namespace { @@ -85,7 +94,6 @@ static void im2col_out_cpu_template( int64_t output_length = output_height * output_width; output.resize_({batch_size, n_output_plane, output_length}); - output.zero_(); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "im2col_out_cpu", [&] { diff --git a/aten/src/ATen/native/IndexingUtils.cpp b/aten/src/ATen/native/IndexingUtils.cpp index e91eff03ab856..2dba1972ce574 100644 --- a/aten/src/ATen/native/IndexingUtils.cpp +++ b/aten/src/ATen/native/IndexingUtils.cpp @@ -1,9 +1,10 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include namespace at { namespace native { bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) { - int64_t elements = t.numel(); + auto elements = t.sym_numel(); if (elements >= max_elem) { return false; } @@ -11,16 +12,16 @@ bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) { return max_elem > 0; } - int64_t offset = 0; - int64_t linearId = elements - 1; + c10::SymInt offset = 0; + auto linearId = elements - 1; // NOTE: Assumes all strides are positive, which is true for now // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) for (int i = t.dim() - 1; i >= 0; --i) { - int64_t curDimIndex = linearId % t.size(i); - int64_t curDimOffset = curDimIndex * t.stride(i); + auto curDimIndex = linearId % t.sym_size(i); + auto curDimOffset = curDimIndex * t.sym_stride(i); offset += curDimOffset; - linearId /= t.size(i); + linearId /= t.sym_size(i); } if (offset >= max_elem) { diff --git a/aten/src/ATen/native/Integration.cpp b/aten/src/ATen/native/Integration.cpp index 7ca01bae18a57..09e444476d1fd 100644 --- a/aten/src/ATen/native/Integration.cpp +++ b/aten/src/ATen/native/Integration.cpp @@ -1,12 +1,23 @@ -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include +#include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/Itertools.cpp b/aten/src/ATen/native/Itertools.cpp index 265b05054b0a3..8d6ff506a43f8 100644 --- a/aten/src/ATen/native/Itertools.cpp +++ b/aten/src/ATen/native/Itertools.cpp @@ -1,5 +1,20 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif #include diff --git a/aten/src/ATen/native/Batching.cpp b/aten/src/ATen/native/LegacyBatching.cpp similarity index 98% rename from aten/src/ATen/native/Batching.cpp rename to aten/src/ATen/native/LegacyBatching.cpp index b50b6201b7a2d..6dcacbd1f23f5 100644 --- a/aten/src/ATen/native/Batching.cpp +++ b/aten/src/ATen/native/LegacyBatching.cpp @@ -1,7 +1,7 @@ #include -#include +#include #include -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/Lerp.cpp b/aten/src/ATen/native/Lerp.cpp index bfac91a881ae0..2e67dec35033f 100644 --- a/aten/src/ATen/native/Lerp.cpp +++ b/aten/src/ATen/native/Lerp.cpp @@ -1,5 +1,14 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS #include +#else +#include +#endif namespace at { namespace meta { diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 7192cc6e1138c..591289a726ac8 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -1,17 +1,36 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include #include -#include #include -#include +#include +#include #include #include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include -#include #include #include #include @@ -545,6 +564,9 @@ Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArr // Sum out contraction dims if (perm_index - out_num_dim > 0) { + // if there were ops to contract, we would have already done so + // in the previous loop and all the dims to sum are now 1 + // NB: use view instead of squeeze (or sum) for faster (mps) performance if (num_ops > 1) { auto sizes = ops[0].sym_sizes().vec(); for (auto dim = perm_index - 1; dim >= out_num_dim; --dim) { diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index c658d4427c97d..7e47170cd72ee 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1,27 +1,132 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include #include #include #include #include -#include #include #include #include #include #include -#include -#include #include +#include +#include +#include #include -#include #include #include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include #include @@ -772,7 +877,7 @@ std::vector> matrix_chain_order(TensorList tensors) { /** * @brief Recursively multiplies the tensors i...j using the given order * - * @param tensors matrices to multiply togther + * @param tensors matrices to multiply together * @param order optimal chain multiplication order from #matrix_chain_order * @param i index of first tensor to be multiplied * @param j index of last tensor to be multiplied @@ -2665,7 +2770,7 @@ Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayR //////////////////////////////////////////////////////////////////////////////// // Frobenius Norm // -// Just used in linalg.norm. It should not be removed. // +// Just used in torch..norm. It should not be removed. // //////////////////////////////////////////////////////////////////////////////// Tensor frobenius_norm(const Tensor& self) { @@ -2711,7 +2816,7 @@ Tensor &frobenius_norm_out(const Tensor& self, //////////////////////////////////////////////////////////////////////////////// // Nuclear Norm // -// Just used in linalg.norm. It should not be removed. // +// Just used in torch.norm. It should not be removed. // //////////////////////////////////////////////////////////////////////////////// Tensor nuclear_norm(const Tensor& self, bool keepdim) { diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp index 52569ba6b4995..78b7d70236207 100644 --- a/aten/src/ATen/native/Loss.cpp +++ b/aten/src/ATen/native/Loss.cpp @@ -1,15 +1,62 @@ -#include -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include +#include +#include +#include #include #include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + constexpr float EPSILON = 1e-12; namespace { diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index 1ddb8f2285640..dcfad968cad79 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -5,16 +5,36 @@ // 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf // We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based. // Graves et al call the probabilities y, we use log_probs (also calling them inputs) +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include +#include #include #include -#include +#include +#include #include #include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include namespace at { diff --git a/aten/src/ATen/native/LossMulti.h b/aten/src/ATen/native/LossMulti.h index 54736bcc123b2..148615e7e14f1 100644 --- a/aten/src/ATen/native/LossMulti.h +++ b/aten/src/ATen/native/LossMulti.h @@ -1,8 +1,8 @@ -#include -#include -#include - #pragma once +#include +#include +#include +#include namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/LossMultiLabelMargin.cpp b/aten/src/ATen/native/LossMultiLabelMargin.cpp index f59de5c8817a4..26d7a748df8d4 100644 --- a/aten/src/ATen/native/LossMultiLabelMargin.cpp +++ b/aten/src/ATen/native/LossMultiLabelMargin.cpp @@ -1,10 +1,23 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/LossMultiMargin.cpp b/aten/src/ATen/native/LossMultiMargin.cpp index c7ab53f1d211b..110520cf8f950 100644 --- a/aten/src/ATen/native/LossMultiMargin.cpp +++ b/aten/src/ATen/native/LossMultiMargin.cpp @@ -1,9 +1,19 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index 79e98c877548a..28fc60508ab10 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -1,13 +1,32 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include +#include #include +#include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include @@ -637,7 +656,7 @@ Tensor nll_loss(const Tensor & self, const Tensor & target, const c10::optional< c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - return std::get<0>(at::nll_loss_forward(self, target, weight, reduction, ignore_index)); + return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index)); } Tensor nll_loss_nd_symint( diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp index 6950cb2805e9e..aee22ce3edeb5 100644 --- a/aten/src/ATen/native/LossNLL2d.cpp +++ b/aten/src/ATen/native/LossNLL2d.cpp @@ -1,12 +1,23 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace native { @@ -487,7 +498,7 @@ Tensor nll_loss2d(const Tensor & self, const Tensor & target, const c10::optiona c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - return std::get<0>(at::nll_loss2d_forward(self, target, weight, reduction, ignore_index)); + return std::get<0>(at::nll_loss2d_forward_symint(self, target, weight, reduction, ignore_index)); } } // namespace native diff --git a/aten/src/ATen/native/MathBitsFallback.h b/aten/src/ATen/native/MathBitsFallback.h index 4e9c2d9e98b18..84e72aa724d0e 100644 --- a/aten/src/ATen/native/MathBitsFallback.h +++ b/aten/src/ATen/native/MathBitsFallback.h @@ -1,12 +1,17 @@ -#include +#include #include #include #include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + namespace at { namespace native { // This fallback should only be used for operations that are self inverse and have a corresponding tensor diff --git a/aten/src/ATen/native/MaxPooling.cpp b/aten/src/ATen/native/MaxPooling.cpp index 0f05eeac7d3e9..e809c75ba21d6 100644 --- a/aten/src/ATen/native/MaxPooling.cpp +++ b/aten/src/ATen/native/MaxPooling.cpp @@ -1,4 +1,5 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include @@ -6,6 +7,16 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp index 33cc4dc7a61ce..adab802d65cd5 100644 --- a/aten/src/ATen/native/MaxUnpooling.cpp +++ b/aten/src/ATen/native/MaxUnpooling.cpp @@ -1,8 +1,17 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/Memory.cpp b/aten/src/ATen/native/Memory.cpp index df6949b2d7d95..2b66f08933934 100644 --- a/aten/src/ATen/native/Memory.cpp +++ b/aten/src/ATen/native/Memory.cpp @@ -1,6 +1,17 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 3df0a0623e437..4fb40a17d0267 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -1,10 +1,21 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + #if !AT_NNPACK_ENABLED() namespace at { @@ -198,8 +209,8 @@ Tensor _nnpack_spatial_convolution( .height = (size_t)output.size(2), }; const nnp_size output_subsample = { - .width = stride[1], - .height = stride[0], + .width = static_cast(stride[1]), + .height = static_cast(stride[0]), }; const auto input_ = input.contiguous(); diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp index ea604c426c3b4..a9cf36a004f4c 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp @@ -1,5 +1,5 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include @@ -8,6 +8,17 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + #include #include diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index 3d34091fd036a..cf60f56f9df44 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -1,11 +1,23 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index fa7b30f5977ef..827bf204b093f 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -1,14 +1,25 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include #include #include #include #include #include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + namespace at { namespace native { namespace { diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index d725c26a14631..6ee2f095b6d09 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -1,8 +1,30 @@ -#include -#include - +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include diff --git a/aten/src/ATen/native/NegateFallback.cpp b/aten/src/ATen/native/NegateFallback.cpp index a2b134a91e40e..0a34b4f4331d6 100644 --- a/aten/src/ATen/native/NegateFallback.cpp +++ b/aten/src/ATen/native/NegateFallback.cpp @@ -1,3 +1,4 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include diff --git a/aten/src/ATen/native/NonSymbolicBC.h b/aten/src/ATen/native/NonSymbolicBC.h index e7d31ae3fa020..0b942efb52c3b 100644 --- a/aten/src/ATen/native/NonSymbolicBC.h +++ b/aten/src/ATen/native/NonSymbolicBC.h @@ -22,4 +22,6 @@ TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, con TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim); TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes); TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index); +TORCH_API at::Tensor select(const at::Tensor& self, int64_t dim, int64_t index); +TORCH_API std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim); }} diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 6911d780c1d0e..ab9094d9b5981 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -1,18 +1,53 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include #include -#include #include #include #include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include @@ -89,17 +124,17 @@ std::tuple batch_norm_cpu_transform_input_template( const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */, const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */, - bool train, double eps) { + bool train, double eps, Tensor& output) { bool all_contiguous = is_contiguous(input) - && (!weight.defined() || weight.is_contiguous()) - && (!bias.defined() || bias.is_contiguous()) - && running_mean.is_contiguous() - && running_var.is_contiguous(); + && is_contiguous(output) + && (!weight.defined() || weight.is_contiguous()) + && (!bias.defined() || bias.is_contiguous()) + && running_mean.is_contiguous() + && running_var.is_contiguous(); // inference contiguous path if (all_contiguous) { - Tensor output = at::empty_like(input, suggest_memory_format_contig(input)); batch_norm_cpu_stub(kCPU, output, input, weight, bias, save_mean, save_invstd, running_mean, running_var, train, eps); return std::make_tuple(output, save_mean, save_invstd); @@ -131,7 +166,6 @@ std::tuple batch_norm_cpu_transform_input_template( auto b = bias.defined() ? as_nd(bias) : at::detail::scalar_tensor_static(0, dtype, kCPU); - Tensor output = at::empty_like(input, input.suggest_memory_format()); auto iter = TensorIteratorConfig() .add_output(output) .add_input(input) @@ -151,30 +185,17 @@ std::tuple batch_norm_cpu_transform_input_template( template class VarTransform> std::tuple batch_norm_cpu_update_stats_template( const Tensor& input, const Tensor& running_mean, const Tensor& running_var, - double momentum, double eps) { + double momentum, double eps, Tensor& save_mean, Tensor& save_var_transform) { using accscalar_t = at::acc_type; int64_t n_input = input.size(1); int64_t n = input.numel() / n_input; - const int64_t ndim = input.dim(); - - // Reduce all dimensions except dim=1 - DimVector reduce_dims(ndim - 1); - reduce_dims[0] = 0; - for (const auto i : c10::irange(2, ndim)) { - reduce_dims[i - 1] = i; - } bool all_contiguous = is_contiguous(input); const bool mixed_type = !std::is_same::value; const auto dtype = mixed_type ? kFloat : input.scalar_type(); - // For contiguous case, leave 'mean' computation to kernel - Tensor save_mean = all_contiguous - ? at::empty({n_input}, input.options().dtype(dtype)) - : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype); - Tensor save_var_transform = at::empty({n_input}, input.options().dtype(dtype)); auto save_mean_a = save_mean.accessor(); auto save_var_transform_a = save_var_transform.accessor(); @@ -244,6 +265,25 @@ std::tuple batch_norm_cpu_update_stats_template( return std::make_tuple(save_mean, save_var_transform); } +template class VarTransform> +std::tuple batch_norm_cpu_update_stats_template( + const Tensor& input, const Tensor& running_mean, const Tensor& running_var, + double momentum, double eps) { + int64_t n_input = input.size(1); + const int64_t ndim = input.dim(); + DimVector reduce_dims(ndim - 1); + reduce_dims[0] = 0; + for (const auto i : c10::irange(2, ndim)) { + reduce_dims[i - 1] = i; + } + + const bool mixed_type = !std::is_same::value; + const auto dtype = mixed_type ? kFloat : input.scalar_type(); + Tensor save_mean = is_contiguous(input) ? at::empty({n_input}, input.options().dtype(dtype)) : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype); + Tensor save_var_transform = at::empty({n_input}, input.options().dtype(dtype)); + return batch_norm_cpu_update_stats_template(input, running_mean, running_var, momentum, eps, save_mean, save_var_transform); +} + template std::tuple batch_norm_backward_cpu_template( const Tensor& grad_out_, const Tensor& input, const Tensor& weight, @@ -656,8 +696,8 @@ std::tuple batch_norm_update_stats_cpu( }); } -std::tuple batch_norm_cpu(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, - bool train, double momentum, double eps) { +std::tuple batch_norm_cpu_out(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, + bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; @@ -665,33 +705,112 @@ std::tuple batch_norm_cpu(const Tensor& self, const c10: const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); + checkBackend("batch_norm_cpu_out", {self, weight, bias, running_mean, running_var}, Backend::CPU); + // Resize out + at::native::resize_output(out, self.sizes()); const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var); - return AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "batch_norm", [&] { + AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "batch_norm", [&] { if (mixed_type) { check_mixed_data_type(self, weight, bias, running_mean, running_var); if (!train) { - auto save_mean = at::empty({0}, self.options().dtype(kFloat)); - auto save_var = at::empty({0}, self.options().dtype(kFloat)); - return batch_norm_cpu_transform_input_template(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps); + return batch_norm_cpu_transform_input_template(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out); } else { - auto save_stats = batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, eps); - return batch_norm_cpu_transform_input_template(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps); + // Resize save_mean and save_var + at::native::resize_output(save_mean, {self.size(1)}); + at::native::resize_output(save_var, {self.size(1)}); + auto save_stats = batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, eps, save_mean, save_var); + return batch_norm_cpu_transform_input_template(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out); } } else { if (!train) { - auto save_mean = at::empty({0}, self.options()); - auto save_var = at::empty({0}, self.options()); - return batch_norm_cpu_transform_input_template(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps); + return batch_norm_cpu_transform_input_template(self, weight, bias, save_mean, save_var, running_mean, running_var, train, eps, out); } else { - auto save_stats = batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, eps); - return batch_norm_cpu_transform_input_template(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps); + // Resize save_mean and save_var + at::native::resize_output(save_mean, {self.size(1)}); + at::native::resize_output(save_var, {self.size(1)}); + auto save_stats = batch_norm_cpu_update_stats_template(self, running_mean, running_var, momentum, eps, save_mean, save_var); + return batch_norm_cpu_transform_input_template(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps, out); } } }); + + return std::tuple(out, save_mean, save_var); } +std::tuple batch_norm_cpu(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, + bool train, double momentum, double eps) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); + const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); + const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + + checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); + + // Prepare output tensor + const bool all_contiguous = is_contiguous(self) + && (!weight.defined() || weight.is_contiguous()) + && (!bias.defined() || bias.is_contiguous()) + && running_mean.is_contiguous() + && running_var.is_contiguous(); + Tensor output = at::empty_like(self, all_contiguous ? suggest_memory_format_contig(self) : self.suggest_memory_format()); + + // Prepare save_mean and save_var + Tensor save_var; + Tensor save_mean; + const bool mixed_type = is_mixed_type(self, weight, bias, running_mean, running_var); + const int64_t ndim = self.dim(); + DimVector reduce_dims(ndim - 1); + reduce_dims[0] = 0; + for (const auto i : c10::irange(2, ndim)) { + reduce_dims[i - 1] = i; + } + if (mixed_type) { + if (!train) { + save_mean = at::empty({0}, self.options().dtype(kFloat)); + save_var = at::empty({0}, self.options().dtype(kFloat)); + } else { + save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options().dtype(kFloat)) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false, kFloat); + save_var = at::empty({self.size(1)}, self.options().dtype(kFloat)); + } + } else { + if (!train) { + save_mean = at::empty({0}, self.options()); + save_var = at::empty({0}, self.options()); + } else { + save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options()) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false); + save_var = at::empty({self.size(1)}, self.options()); + } + } + return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean_opt, running_var_opt, train, momentum, eps, output, save_mean, save_var); +} + + +std::tuple _batch_norm_legit_cpu( + const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, + Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps) { + return batch_norm_cpu(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps); +} + +std::tuple _batch_norm_legit_no_stats_cpu( + const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, + bool train, double momentum, double eps) { + return batch_norm_cpu(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps); +} + + +std::tuple _batch_norm_legit_cpu_out(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) { + return batch_norm_cpu_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps, out, save_mean, save_var); +} + + +std::tuple _batch_norm_legit_no_stats_cpu_out(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) { + return batch_norm_cpu_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps, out, save_mean, save_var); +} + + std::tuple batch_norm_backward_cpu(const Tensor& grad_out, const Tensor& self, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, double eps, std::array grad_input_mask) { // See [Note: hacky wrapper removal for optional tensor] diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index a0c061062174b..41b7a69618636 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -1,4 +1,14 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp index 736829eb6d118..19b12b0819607 100644 --- a/aten/src/ATen/native/PackedSequence.cpp +++ b/aten/src/ATen/native/PackedSequence.cpp @@ -1,5 +1,20 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include #include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif #include diff --git a/aten/src/ATen/native/PadNd.cpp b/aten/src/ATen/native/PadNd.cpp index c6b18c1257b51..9421d537717c8 100644 --- a/aten/src/ATen/native/PadNd.cpp +++ b/aten/src/ATen/native/PadNd.cpp @@ -1,8 +1,29 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace native { Tensor constant_pad_nd(const Tensor& self, IntArrayRef pad, const Scalar& value) { diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp index 2a100321a6400..e535909a73429 100644 --- a/aten/src/ATen/native/PixelShuffle.cpp +++ b/aten/src/ATen/native/PixelShuffle.cpp @@ -1,10 +1,21 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include -#include -#include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +#include +#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp index a99bc959eb958..8259135ce14a3 100644 --- a/aten/src/ATen/native/PointwiseOps.cpp +++ b/aten/src/ATen/native/PointwiseOps.cpp @@ -1,12 +1,17 @@ // Ternary and higher-order pointwise operations +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include -#include -#include +#include +#include +#include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif namespace at { namespace meta { diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index cf5b45b365d05..0ff4490086b7e 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -67,17 +67,18 @@ static inline T pooling_output_shape( inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode); } -inline std::pair pooling_same_mode_padding_lr( - int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) { +template +std::pair _pooling_same_mode_padding_lr( + T inputSize, T kernelSize, int64_t stride, int64_t dilation) { // NOTE: with strides, the output shape is ceil(inputSize/stride) - auto total_padding = dilation * (kernelSize - 1); + auto total_padding = T(dilation) * (kernelSize - 1); // Prefer symmetric padding if possible if (stride > 2 && (total_padding % 2 == 1)) { // The floor in the output size calculation gives us a little wiggle room auto wiggle_room = inputSize % stride - 1; if (wiggle_room > 0) { - --total_padding; + total_padding = total_padding - 1; } } @@ -85,6 +86,15 @@ inline std::pair pooling_same_mode_padding_lr( return {left, total_padding - left}; } +inline std::pair pooling_same_mode_padding_lr( + int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) { + return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation); +} + +inline std::pair pooling_same_mode_padding_lr( + c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) { + return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation); +} // AveragePool2d/DilatedMaxPool2d (forward) static inline void diff --git a/aten/src/ATen/native/Pooling.cpp b/aten/src/ATen/native/Pooling.cpp index 724c53fdd0c00..fcbe741ab0ea0 100644 --- a/aten/src/ATen/native/Pooling.cpp +++ b/aten/src/ATen/native/Pooling.cpp @@ -1,12 +1,31 @@ -#include - -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include -#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include namespace at { namespace native { diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 4326853a8165a..7050524acebf2 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -1,11 +1,20 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include -#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace meta { diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index af7643ec18b6c..002bb1adc4386 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -1,20 +1,28 @@ -#include -#include -#include -#include -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include +#include #include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #ifdef USE_FBGEMM diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 01a25e3a978cc..651b21ae01863 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -47,10 +47,9 @@ signature. if one argument is a `FloatTensor`, all other arguments are checked to be `FloatTensor`s). `Tensor` or `Tensor?` must sometimes be annotated to indicate aliasing and mutability. - In general annotations can be defined via the following four situations: - - `Tensor(a)` - `a` is a set of Tensors that may alias to the same data. + In general annotations can be defined via the following situations: + - `Tensor(a)` - `a` is a set of Tensors that may alias to the same data. The set could have a size of one. - `Tensor(a!)` - members of `a` may be written to thus mutating the underlying data. - - `Tensor!` - shorthand for Tensor(fresh\_identifier!) - `Tensor(a! -> a|b)` - Tensor is in set `a`, written to, and after the write is in set `a` AND `b`. For more details on when and why this needs to happen, please see the section on annotations. - `Tensor[]`. A `Tensor[]` argument translates into a C++ argument of type `ArrayRef` @@ -445,7 +444,7 @@ By default, ATen code generation will generate device check, which will ensure all the tensor parameters passed to kernel are on the same device. -However, in some cases, checking the device is unncessary, because, +However, in some cases, checking the device is unnecessary, because, e.g., you call a function allows to work on multiple devices. In that case, code generation of the device check can be disabled by adding `device_check: NoCheck` to your function definition. @@ -556,7 +555,7 @@ Here're steps to follow to decide the right dispatch keyword: Note: to support training, you're required to write a formula in derivatives.yaml since your backend implementations don't support autograd. - - Yes: you're likely calling other `at::` ops in the implemetation. Go to step 2. + - Yes: you're likely calling other `at::` ops in the implementation. Go to step 2. 2. Think about training: does your kernel support autograd? [check autograd support](#will-your-function-be-automatically-differentiable) - Yes: in other words, you're providing a `CompositeImplicitAutograd` kernel which supports both inference and autograd. @@ -610,7 +609,7 @@ It shows for a certain operator, what the computed dispatch table looks like aft 4. TODO: AutogradCPUOrCUDA Note that in native_functions.yaml you can mix using backend keywords and alias keywords above for one op: - - direct registration to backend always has higher precendence than alias + - direct registration to backend always has higher precedence than alias - DO NOT provide multiple alias keywords to the same op: alias keywords have precedence `CompositeExplicitAutograd > CompositeImplicitAutograd`, e.g. adding both `CompositeImplicitAutograd` and `CompositeExplicitAutograd` kernels for one op will completely ignore `CompositeImplicitAutograd` kernel for both inference and training. Thus this will trigger an error when native_functions.yaml is parsed. diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 670395893d8ef..52efc6929f54e 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -1,8 +1,10 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include -#include +#include +#include +#include +#include #include #include #include @@ -10,6 +12,46 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + int register_linear_params(); namespace at { namespace native { diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 038da93456edb..408bf0a27e6fe 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -1,13 +1,23 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include #include -#include #include -#include +#include +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/ReduceAllOps.cpp b/aten/src/ATen/native/ReduceAllOps.cpp index 1ef5e9b93733c..e1d51a1666af2 100644 --- a/aten/src/ATen/native/ReduceAllOps.cpp +++ b/aten/src/ATen/native/ReduceAllOps.cpp @@ -1,8 +1,21 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include -#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include #include +#else +#include +#include +#include +#include +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 2bb01abd51b5f..a10f6c7255760 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1,21 +1,114 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include +#include #include -#include -#include +#include #include #include #include +#include +#include +#include #include #include -#include -#include #include -#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include @@ -24,9 +117,7 @@ #include #include #include -#include #include -#include #include namespace at { @@ -1653,6 +1744,9 @@ static Tensor& std_var_out( const auto correction = correction_opt.value_or(1); ScalarType dtype = get_dtype_from_result(result, {}); auto iter = make_reduction(fname, result, self, dim, keepdim, dtype); + TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()), + "result type ", self.scalar_type(), " can't be cast to the " + "desired output type ", result.scalar_type()); if (iter.numel() == 0) { // Trivial reduction diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 9db9802ea788b..2b46eb683f1c9 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -102,7 +102,7 @@ static inline void check_scalar_type_device_layout_equal(const Tensor& out, cons OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); } -static inline Tensor integer_upcast(const Tensor& self, optional dtype) { +static inline Tensor integer_upcast(const Tensor& self, c10::optional dtype) { ScalarType scalarType = self.scalar_type(); ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType); return self.toType(upcast_scalarType); diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp index 7824de63805f3..3a6ad683d0457 100644 --- a/aten/src/ATen/native/ReflectionPad.cpp +++ b/aten/src/ATen/native/ReflectionPad.cpp @@ -1,9 +1,26 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp index b6e5c04f77026..c8c4e134929f9 100644 --- a/aten/src/ATen/native/Repeat.cpp +++ b/aten/src/ATen/native/Repeat.cpp @@ -1,8 +1,19 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + template static void compute_cpu( index_t* repeat_ptr, @@ -64,11 +75,11 @@ Tensor repeat_interleave( } Tensor repeats_ = repeats; - if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) { - repeats_ = repeats.reshape({1}).expand({input.size(dim.value())}); + if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) { + repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())}); } else if (repeats.dim() == 1) { TORCH_CHECK( - repeats.size(0) == input.size(dim.value()), + repeats.sym_size(0) == input.sym_size(dim.value()), "repeats must have the same size as input along dim") } else { AT_ERROR("repeats must be 0-dim or 1-dim tensor"); @@ -91,10 +102,17 @@ Tensor repeat_interleave( int64_t repeats, c10::optional dim, c10::optional output_size) { - at::Tensor repeats_ = - at::empty(1, self.options().dtype(at::kLong)).fill_(repeats); + at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats); return at::native::repeat_interleave(self, repeats_, dim, output_size); } +Tensor repeat_interleave_symint( + const Tensor& self, + c10::SymInt repeats, + c10::optional dim, + c10::optional output_size) { + return at::native::repeat_interleave(self, repeats.guard_int(__FILE__, __LINE__), dim, output_size); + } + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/ReplicationPadding.cpp b/aten/src/ATen/native/ReplicationPadding.cpp index 40fdb788a4ffa..d0a4ea919acbf 100644 --- a/aten/src/ATen/native/ReplicationPadding.cpp +++ b/aten/src/ATen/native/ReplicationPadding.cpp @@ -1,9 +1,24 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index 08286f3983cc9..bd47a25e69601 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -1,9 +1,16 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include +#include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/RowwisePrune.cpp b/aten/src/ATen/native/RowwisePrune.cpp index 40ae2215cbccc..c27707c4d3075 100644 --- a/aten/src/ATen/native/RowwisePrune.cpp +++ b/aten/src/ATen/native/RowwisePrune.cpp @@ -1,8 +1,17 @@ // Copyright 2004-present Facebook. All Rights Reserved. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include +#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 7342c4806d44c..f8932ea03bb2e 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -1,5 +1,15 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include #include +#else +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 3e562b7cf859f..1e5e28dab86b2 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -1,10 +1,23 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include +#include #include #include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + namespace at { namespace native { diff --git a/aten/src/ATen/native/SobolEngineOps.cpp b/aten/src/ATen/native/SobolEngineOps.cpp index 48366976a2e70..187faeba16a7b 100644 --- a/aten/src/ATen/native/SobolEngineOps.cpp +++ b/aten/src/ATen/native/SobolEngineOps.cpp @@ -1,11 +1,21 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include #include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/SobolEngineOpsUtils.cpp b/aten/src/ATen/native/SobolEngineOpsUtils.cpp index ef7cbb1faae92..709d5c06d3c97 100644 --- a/aten/src/ATen/native/SobolEngineOpsUtils.cpp +++ b/aten/src/ATen/native/SobolEngineOpsUtils.cpp @@ -1,4 +1,5 @@ /// This file contains tensor-agnostic SoboleEngine constants +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include /* diff --git a/aten/src/ATen/native/SobolEngineOpsUtils.h b/aten/src/ATen/native/SobolEngineOpsUtils.h index d3d7a362f2e87..495a43ed8a7cf 100644 --- a/aten/src/ATen/native/SobolEngineOpsUtils.h +++ b/aten/src/ATen/native/SobolEngineOpsUtils.h @@ -1,6 +1,14 @@ /// This file contains some tensor-agnostic operations to be used in the /// core functions of the `SobolEngine` -#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index d9d1b90534d73..0332f57e9e23e 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -1,13 +1,36 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include +#include #include #include #include +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include #include @@ -139,9 +162,6 @@ void host_softmax( int64_t mask_type = mask_type_.value(); // If mask_type == 2, then mask_.sizes() must equal input_.sizes() TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)"); - - // TODO: Add support for TxT src_mask - TORCH_CHECK(mask_type != 0, "src_mask not currently supported on CPU"); } int64_t outer_size = 1; @@ -171,8 +191,22 @@ void host_softmax( output_data_base + outer_idx * outer_stride + inner_idx; bool* mask_data = nullptr; if (MaskedSoftMax) { - mask_data = mask_data_base + outer_idx * outer_stride + inner_idx; - } + // Process mask differently depending on the type: + // For a generic mask of mask_type == 2, mask shape is the same as the input shape, + // so indexing is the same. + auto mask_outer_idx = outer_idx; + if (mask_type_ == 0) { + // Optimized case: attention mask of shape LxL + // outer_idx goes over BxHxL, mask_outer_idx goes over L. + mask_outer_idx = outer_idx % input.size(2); + } else if (mask_type_ == 1) { + // Optimized case: padding mask of shape BxL + // outer_idx goes over BxHxL, mask_outer_idx goes over B. + mask_outer_idx = outer_idx / (input.size(1) * input.size(2)); + } + + mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx; + }; // Calc max in softmax dim bool is_meaningful_max = false; @@ -554,15 +588,48 @@ Tensor log_softmax(const Tensor& self, Dimname dim, optional dtype) } Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const c10::optional dim_, const c10::optional mask_type_) { - TORCH_CHECK( - input_.sizes() == mask_.sizes(), "Mask shape should match input shape"); + + auto mask = mask_.contiguous(); + auto mask_type = mask_type_; // Mask type might get transformed below + TORCH_CHECK( mask_.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor"); + if ((mask.dim() != 2) || (input_.dim() != 4)) { + // Mask types 0 and 1 are only allowed for 2D masks and 4D inputs + mask_type = 2; + } + + if (mask_type == 2) { + TORCH_CHECK(input_.sizes() == mask.sizes(), + "For mask_type == 2 mask shape should match input shape") + } else if (mask_type == 1) { + // Padding mask of shape (B, L) + TORCH_CHECK((input_.sizes()[0] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]), + "For mask_type == 1 mask shape should be (B, L)"); + if (dim_ != input_.dim() - 1) { + // We only process padding mask in the optimized way if softmax is applied along the last dimesion, + // otherwise we need to expand the mask into a generic 4D one + mask = mask_.view({input_.sizes()[0], 1, 1, input_.sizes()[2]}); + mask = mask.expand(input_.sizes()).contiguous(); + mask_type = 2; + } + } else if (mask_type == 0) { + // Attention mask of shape (L, L) + TORCH_CHECK((mask.dim() == 2) && (input_.sizes()[2] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]), + "For mask_type == 0 mask shape should be (L, L)"); + if (dim_ != input_.dim() - 1) { + // We only process attention mask in a optimized way if softmax is applied along the last dimesion, + // otherwise we need to expand the mask into a generic 4D one + mask = mask.view({1, 1, input_.sizes()[2], input_.sizes()[2]}); + mask = mask.expand(input_.sizes()).contiguous(); + mask_type = 2; + } + } + Tensor output = at::empty_like(input_, input_.options()); auto input = input_.contiguous(); - auto mask = mask_.contiguous(); int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1; dim = maybe_wrap_dim(dim, input_.dim()); @@ -576,7 +643,7 @@ Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const c10:: scalar_t, false /* LogSoftMax */, true /* MaskedSoftMax */>( - output, input, dim, mask.data_ptr(), mask_type_); + output, input, dim, mask.data_ptr(), mask_type); }); return output; } diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 66b9daf7fad8c..3b50d7744aa28 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -1,8 +1,16 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include #include #include #include #include +#include +#include +#include +#include +#include #include #include #include @@ -11,6 +19,32 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include namespace at { diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index f39eeaccf9d4f..124c2d06d9e83 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1,16 +1,67 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include #include -#include -#include +#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include -#include -#include namespace at { namespace native { @@ -148,7 +199,7 @@ Tensor fft_c2r(c10::string_view function_name, " expects a floating point output tensor, but got ", out.scalar_type()); input = promote_tensor_fft(input, /*require_complex=*/true); const auto input_dim = input.dim(); - const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); + const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false); const auto n = n_opt.value_or(2*(input.sizes()[dim] - 1)); TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); if (n_opt) { @@ -157,7 +208,7 @@ Tensor fft_c2r(c10::string_view function_name, const auto norm = norm_from_string(norm_str, forward); if (forward) { // FIXME: _fft does not support complex_output=false with inverse=false - input = at::conj(input); + input = input.conj(); } return fft_c2r_maybe_out( function_name, out, input, dim, static_cast(norm), n); @@ -174,7 +225,7 @@ Tensor fft_r2c(c10::string_view function_name, " expects a complex output tensor, but got ", out.scalar_type()); input = promote_tensor_fft(input); const auto input_dim = input.dim(); - const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); + const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false); const auto n = n_opt.value_or(input.sizes()[dim]); TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); if (n_opt) { @@ -192,7 +243,7 @@ Tensor fft_r2c(c10::string_view function_name, if (!forward) { // FIXME: _fft_r2c doesn't support native r2c IFFT - return out.defined() ? at::conj_physical_out(out, ret) : at::conj(ret); + return out.defined() ? at::conj_physical_out(out, ret) : ret.conj(); } else { return ret; } @@ -206,7 +257,7 @@ Tensor fft_c2c(c10::string_view function_name, TORCH_CHECK(input.is_complex(), function_name, " expects a complex input tensor, but got ", input.scalar_type()); const auto input_dim = input.dim(); - const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim); + const auto dim = maybe_wrap_dim(unwrapped_dim, input_dim, /*wrap_scalar=*/false); const auto n = n_opt.value_or(input.sizes()[dim]); TORCH_CHECK(n >= 1, "Invalid number of data points (", n, ") specified"); if (n_opt) { @@ -233,7 +284,7 @@ ShapeAndDims canonicalize_fft_shape_and_dim_args( if (dim) { ret.dim.resize(dim->size()); std::copy(dim->begin(), dim->end(), ret.dim.begin()); - maybe_wrap_dims(ret.dim, input_dim); + maybe_wrap_dims(ret.dim, input_dim, /*wrap_scalars=*/false); // Check dims are unique DimVector copy = ret.dim; @@ -521,7 +572,7 @@ static Tensor fft_hfftn_impl( } const auto last_dim = desc.dim.back(); - tmp = at::conj(tmp); + tmp = tmp.conj(); return fft_c2r_maybe_out(fname, out, tmp, last_dim, norm, last_dim_size); } @@ -559,7 +610,7 @@ static Tensor fft_ihfftn_impl( const auto last_dim = desc.dim.back(); auto tmp = at::_fft_r2c(x, last_dim, norm, /*onesided=*/true); if (desc.dim.size() == 1) { - return out.defined() ? at::conj_physical_out(tmp, out) : at::conj(tmp); + return out.defined() ? at::conj_physical_out(tmp, out) : tmp.conj(); } tmp = at::conj_physical(tmp); @@ -699,7 +750,7 @@ DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim_opt) { IntArrayRef dim_unwrapped = *dim_opt; dim.resize(dim_unwrapped.size()); for (const auto i : c10::irange(dim.size())) { - dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim()); + dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false); } } else { dim.resize(self.dim()); @@ -1002,13 +1053,13 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho if (onesided) { if (n_fft / 2 + 1 != fft_size) { std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onsided=True, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onesided=True, but got " << fft_size; AT_ERROR(ss.str()); } } else { if (n_fft != fft_size) { std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onsided=False, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got " << fft_size; AT_ERROR(ss.str()); } } @@ -1044,7 +1095,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho input = input.unsqueeze(0); } - input = as_complex(input.transpose(1, 2)); // size: (channel, n_frames, fft_size, 2) + input = as_complex(input.transpose(1, 2)); // size: (channel, n_frames, fft_size) const fft_norm_mode norm = normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n; if (return_complex) { @@ -1061,26 +1112,23 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho TORCH_INTERNAL_ASSERT(input.size(2) == n_fft); Tensor y_tmp = input * window_tmp.view({1, 1, n_fft}); // size: (channel, n_frames, n_fft) - y_tmp = y_tmp.transpose(1, 2); // size: (channel, n_fft, frame) - - Tensor y = at::col2im(y_tmp, - /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft}, - /*kernel_size*/ {1, n_fft}, - /*dilation*/ {1, 1}, - /*padding*/ {0, 0}, - /*stride*/ {1, hop_length} - ).squeeze(2); - window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0); // size: (1, n_fft, n_frames) - Tensor window_envelop = at::col2im(window_tmp, - /*output_size*/ {1, (n_frames - 1) * hop_length + n_fft}, - /*kernel_size*/ {1, n_fft}, - /*dilation*/ {1, 1}, - /*padding*/ {0, 0}, - /*stride*/ {1, hop_length} - ).squeeze(2); // size: (1, 1, expected_output_signal_len) - - TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2)); - TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2)); + + Tensor y = at::unfold_backward( + y_tmp, + /*input_sizes=*/{y_tmp.size(0), expected_output_signal_len}, + /*dim=*/1, + /*size=*/n_fft, + /*step=*/hop_length); + window_tmp = window_tmp.pow(2).expand({1, n_frames, n_fft}); // size: (1, n_frames, n_fft) + Tensor window_envelop = at::unfold_backward( + window_tmp, + /*input_sizes=*/{1, expected_output_signal_len}, + /*dim=*/1, + /*size=*/n_fft, + /*step=*/hop_length); // size: (1, expected_output_signal_len) + + TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(1)); + TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(1)); // We need to trim the front padding away if centered const auto start = center ? n_fft / 2 : 0; @@ -1094,8 +1142,8 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho return expected_output_signal_len; }(); - y = y.slice(2, start, end, 1); - window_envelop = window_envelop.slice(2, start, end, 1); + y = y.slice(1, start, end, 1); + window_envelop = window_envelop.slice(1, start, end, 1); const auto window_envelop_lowest = window_envelop.abs().min().lt(1e-11); if (at::is_scalar_tensor_true(window_envelop_lowest)) { std::ostringstream ss; @@ -1103,7 +1151,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho AT_ERROR(ss.str()); } - y = (y / window_envelop).squeeze(1); // size: (channel, expected_output_signal_len) + y = (y / window_envelop); // size: (channel, expected_output_signal_len) if (input_dim == 3) { y = y.squeeze(0); } @@ -1134,7 +1182,7 @@ void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { const auto input_strides = input.strides(); TORCH_CHECK(dim_.size() > 0); DimVector dim(dim_.begin(), dim_.end()); - at::maybe_wrap_dims(dim, input_strides.size()); + at::maybe_wrap_dims(dim, input_strides.size(), /*wrap_scalars=*/false); if (input.numel() == 0 || input_sizes[dim.back()] <= 2) { return; // No elements need writing diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp index e7dbe72576721..ae0b38c96efa7 100644 --- a/aten/src/ATen/native/SummaryOps.cpp +++ b/aten/src/ATen/native/SummaryOps.cpp @@ -1,10 +1,17 @@ // Returns the frequency of elements of input non-negative integer tensor. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include +#include #include #include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 2f7dbaf45252f..3004dc1b31c79 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -47,31 +47,93 @@ // ...) // // where & and * represent the C-style address-of and indirection operations. +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include -#include -#include +#include +#include +#include +#include #include #include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include #include #include #include +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include #include -#include #include #include diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index 3e786bf7db4fc..0c0db4b83f351 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -1,5 +1,5 @@ #pragma once -#include +#include #include #include diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 856d684c52e85..5d3ee7d98d803 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -1,19 +1,73 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include +#include #include #include -#include -#include -#include #include -#include #include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif namespace at { namespace meta { diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 2af35c66a0b9e..e6c7bd3875d2a 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -1,8 +1,50 @@ +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include +#include #include #include +#include #include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif #include #include @@ -202,48 +244,52 @@ Tensor _to_copy( // memory_format is handled separately due to MemoryFormat::Preserve logic options = self.options().merge_in(options).memory_format(c10::nullopt); auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve); + // TODO: Use the dispatcher for this. // Currently there are unenumerated extensibility issues preventing this. - if (self.is_sparse_csr()) { - TORCH_CHECK( - memory_format == MemoryFormat::Preserve, - "sparse_csr only supports memory format Preserve, but got ", - memory_format, - " instead."); - - auto new_values = at::native::to( - self.values(), - dtype, - c10::kStrided, // values are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - auto new_crow_indices = at::native::to( - self.crow_indices(), - self.crow_indices().scalar_type(), // indices are integral - c10::kStrided, // indices are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - auto new_col_indices = at::native::to( - self.col_indices(), - self.col_indices().scalar_type(), // indices are integral - c10::kStrided, // indices are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - return at::native::_sparse_csr_tensor_unsafe( - new_crow_indices, - new_col_indices, + if (at::sparse_csr::is_sparse_compressed(self)) { + TORCH_CHECK( + memory_format == MemoryFormat::Preserve, + "to(options): ", at::sparse_csr::layoutToString(self.layout()), + " only supports memory format Preserve, but got ", memory_format, + " instead."); + + Tensor compressed_indices, plain_indices; + std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self); + + const auto new_values = at::native::to( + self.values(), + dtype, + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + const auto new_compressed_indices = at::native::to( + compressed_indices, + compressed_indices.scalar_type(), + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + const auto new_plain_indices = at::native::to( + plain_indices, + plain_indices.scalar_type(), + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + return at::native::_sparse_compressed_tensor_unsafe( + new_compressed_indices, + new_plain_indices, new_values, self.sizes(), new_values.scalar_type(), @@ -896,7 +942,9 @@ Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) { self.size(-1), " needs to be divisible by blocksize[1] ", blocksize[1]); - + // TODO: specify the number of dense dimensions, or equivalently, + // the number of batch dimensions. Until then, below we'll assume + // that the number of dense dimensions is 0. auto n_batch_dim = self.dim() - 2; auto values = _batch_tile_tensor(self, blocksize); @@ -1008,8 +1056,7 @@ void _check_blocksize_matches( const std::string& name) { if (blocksize_opt.has_value()) { const auto blocksize = *blocksize_opt; - const auto self_values = self.values(); - const auto self_blocksize = at::DimVector({self_values.size(-2), self_values.size(-1)}); + const auto self_blocksize = at::sparse_csr::getBlockSize(self); TORCH_CHECK(self_blocksize == blocksize, name, "(): the provided blocksize does not match the blocksize of the to be converted tensor, ", "got (", blocksize[0], ", ", blocksize[1], ") ", @@ -1098,7 +1145,7 @@ Tensor sparse_compressed_to_flipped( const auto sparse_dims = [&]() -> at::DimVector { auto sparse_dims = at::DimVector(self.sizes().slice(n_batches, 2)); if (layout == at::kSparseBsr || layout == at::kSparseBsc) { - std::array blocksize = {values.size(-2), values.size(-1)}; + auto blocksize = at::sparse_csr::getBlockSize(self); sparse_dims[0] /= blocksize[0]; sparse_dims[1] /= blocksize[1]; } @@ -1260,9 +1307,9 @@ Tensor sparse_compressed_to_sparse_csr(const Tensor& self) { Tensor coo_to_sparse_csr(const Tensor& self) { TORCH_CHECK( - self.dim() == 2, - "Only 2D tensors can be converted to the SparseCsr layout but got shape: ", - self.sizes()); + self.sparse_dim() == 2, + "Only tensors with two sparse dimensions can be converted to the SparseCsr layout, got self with ", + self.sparse_dim(), " sparse dimensions."); auto coalesced_self = self.coalesce(); auto row_indices = coalesced_self.indices()[0]; bool out_int32 = (row_indices.scalar_type() == at::kInt); @@ -1280,9 +1327,9 @@ Tensor coo_to_sparse_csr(const Tensor& self) { Tensor coo_to_sparse_csc(const Tensor& self) { TORCH_CHECK( - self.dim() == 2, - "Only 2D tensors can be converted to the SparseCsc layout but got shape: ", - self.sizes()); + self.sparse_dim() == 2, + "Only tensors with two sparse dimensions can be converted to the SparseCsc layout, got self with ", + self.sparse_dim(), " sparse dimensions."); auto coalesced_self = self.transpose(0, 1).coalesce().to_sparse_csr(); return at::native::_sparse_csc_tensor_unsafe( coalesced_self.crow_indices(), @@ -1295,15 +1342,11 @@ Tensor coo_to_sparse_csc(const Tensor& self) { } Tensor coo_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) { - AT_ERROR( - "Conversion from ", self.layout(), " to SparseBsr is currently not supported."); - return self; + return self.to_sparse_csr().to_sparse_bsr(blocksize); } Tensor coo_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize) { - AT_ERROR( - "Conversion from ", self.layout(), " to SparseBsc is currently not supported."); - return self; + return self.to_sparse_bsr(blocksize).to_sparse_bsc(blocksize); } namespace { @@ -1552,9 +1595,20 @@ Tensor _csr_to_block_csr_cpu(const Tensor& self, IntArrayRef blocksize) { input_crow_indices.data_ptr(), input_col_indices.data_ptr()); }); + DimVector values_size{num_blocks, blocksize[0], blocksize[1]}; + + // While we don't support conversion of hybrid csr-to-bsr yet, we'll + // compute hybrid compatible values sizes to meet the invariants of + // the BSR tensor when the support will be implemented. + int64_t numel_dense = 1; + for (int i=0; i( n_row, @@ -1589,9 +1643,17 @@ Tensor _csr_to_block_csr_cpu(const Tensor& self, IntArrayRef blocksize) { Tensor sparse_compressed_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) { if (self.layout() == kSparseBsc) { + DimVector self_blocksize = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(self_blocksize == blocksize, "to_sparse_bsr:", + "conversion from ", self.layout(), "[blocksize=", self_blocksize, "] to ", kSparseBsr, + "[blocksize=", DimVector(blocksize),"] is not implemented."); return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsr"); } if (self.layout() == kSparseBsr) { + DimVector self_blocksize = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(self_blocksize == blocksize, "to_sparse_bsr:", + "conversion from ", self.layout(), "[blocksize=", self_blocksize, "] to ", kSparseBsr, + "[blocksize=", blocksize,"] is not implemented."); return sparse_compressed_clone(self, blocksize, "to_sparse_bsr"); } if (self.layout() == kSparseCsr) { @@ -1633,10 +1695,18 @@ Tensor sparse_compressed_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize Tensor sparse_compressed_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize) { if (self.layout() == kSparseBsr) { - return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsr"); + DimVector self_blocksize = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(self_blocksize == blocksize, "to_sparse_bsc:", + "conversion from ", self.layout(), "[blocksize=", self_blocksize, "] to ", kSparseBsc, + "[blocksize=", blocksize,"] is not implemented."); + return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsc"); } if (self.layout() == kSparseBsc) { - return sparse_compressed_clone(self, blocksize, "to_sparse_bsr"); + DimVector self_blocksize = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(self_blocksize == blocksize, "to_sparse_bsc:", + "conversion from ", self.layout(), "[blocksize=", self_blocksize, "] to ", kSparseBsc, + "[blocksize=", blocksize,"] is not implemented."); + return sparse_compressed_clone(self, blocksize, "to_sparse_bsc"); } AT_ERROR( "sparse_compressed_to_sparse_bsc expected SparseBsr or SparseBsc layout but got ", @@ -1684,10 +1754,81 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, int64_t sparse_dim) { self.layout()); } -Tensor sparse_compressed_to_sparse(const Tensor& self) { - return sparse_compressed_to_sparse(self, 2); +Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional layout, OptionalIntArrayRef blocksize) { + Layout layout_ = layout.value_or(kSparse); + TORCH_CHECK(!blocksize.has_value() || layout_ == kSparseBsr || layout_ == kSparseBsc, + "to_sparse: ", self.layout(), " to ", layout_, + " conversion does not use the specified blocksize ", blocksize.value(), "."); + if (self.layout() == layout_ && (!blocksize.has_value() || at::sparse_csr::getBlockSize(self) == *blocksize)) { + return self; + } + switch (layout_) { + case kStrided: + return sparse_compressed_to_dense(self); + case kSparse: + return sparse_compressed_to_sparse(self, 2); + case kSparseCsr: + return sparse_compressed_to_sparse_csr(self); + case kSparseCsc: + return sparse_compressed_to_sparse_csc(self); + case kSparseBsr: + if (blocksize.has_value()) { + return sparse_compressed_to_sparse_bsr(self, *blocksize); + } else { + DimVector blocksize_ = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(blocksize_.size() == 2, "to_sparse: ", self.layout(), " to ", layout_, + " conversion requires blocksize specified."); + return sparse_compressed_to_sparse_bsr(self, blocksize_); + } + case kSparseBsc: + if (blocksize.has_value()) { + return sparse_compressed_to_sparse_bsc(self, *blocksize); + } else { + DimVector blocksize_ = at::sparse_csr::getBlockSize(self); + TORCH_CHECK(blocksize_.size() == 2, "to_sparse: ", self.layout(), " to ", layout_, + " conversion requires blocksize specified."); + return sparse_compressed_to_sparse_bsc(self, blocksize_); + } + default: + break; + } + AT_ERROR("to_sparse: ", self.layout(), " to ", layout_, " conversion not implemented."); + return Tensor(); } +Tensor sparse_coo_to_sparse(const Tensor& self, c10::optional layout, OptionalIntArrayRef blocksize) { + Layout layout_ = layout.value_or(kSparse); + TORCH_CHECK(!blocksize.has_value() || layout_ == kSparseBsr || layout_ == kSparseBsc, + "to_sparse: ", self.layout(), " to ", layout_, + " conversion does not use the specified blocksize ", blocksize.value(), "."); + if (self.layout() == layout_) { + return self; + } + switch (layout_) { + case kStrided: + return self.to_dense(); + case kSparse: + return self; + case kSparseCsr: + return self.to_sparse_csr(); + case kSparseCsc: + return self.to_sparse_csc(); + case kSparseBsr: + TORCH_CHECK(blocksize.has_value(), "to_sparse: ", self.layout(), " to ", layout_, + " conversion requires blocksize specified."); + return self.to_sparse_bsr(*blocksize); + case kSparseBsc: + TORCH_CHECK(blocksize.has_value(), "to_sparse: ", self.layout(), " to ", layout_, + " conversion requires blocksize specified."); + return self.to_sparse_bsc(*blocksize); + default: + break; + } + AT_ERROR("to_sparse not implemented for ", self.layout(), " to ", *layout, " conversion"); + return Tensor(); +} + + // Sparse layout conversions End Tensor to_meta(const Tensor& tensor) { diff --git a/aten/src/ATen/native/TensorDimApply.h b/aten/src/ATen/native/TensorDimApply.h index ad9ca857eeab8..e75cd40caf48b 100644 --- a/aten/src/ATen/native/TensorDimApply.h +++ b/aten/src/ATen/native/TensorDimApply.h @@ -1,4 +1,5 @@ -#include +#pragma once +#include #include namespace at { diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 2e01f7e8699ad..037ba84181de0 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1,31 +1,99 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include #include #include #include +#include #include #include -#include #include +#include +#include +#include #include -#include -#include -#include -#include #include -#include #include #include -#include -#include +#include + #ifndef AT_PER_OPERATOR_HEADERS #include +#include #else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #endif #include -#include -#include #include #include #include @@ -257,12 +325,6 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = self.options() .merge_in(options_) @@ -1099,6 +1161,19 @@ Tensor _efficientzerotensor(IntArrayRef size, return out; } +Tensor _efficientzerotensor_meta(IntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + auto device_ = device_or_default(device); + auto allocator = at::native::ZeroTensorAllocator(device_); + auto dtype_ = dtype_or_default(dtype); + auto zero_ks = at::DispatchKeySet(c10::DispatchKey::Meta) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor); + auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt); + return out; +} + Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) { result.sparse_resize_and_clear_(size, size.size(), 0.); return result; diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 35e058df4b3ab..2c0665518a9e3 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -1,10 +1,9 @@ #pragma once #include -#include +#include +#include #include -#include -#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp index ea772bfe7e641..606a442226876 100644 --- a/aten/src/ATen/native/TensorIteratorReduce.cpp +++ b/aten/src/ATen/native/TensorIteratorReduce.cpp @@ -1,11 +1,14 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include -#include -#include -#include -#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + #include /// Contains the implementation of parallel reductions in TensorIterator. diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 7941f2e3b758c..e37dbf56cc81a 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -1,12 +1,27 @@ -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include -#include -#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif -#include #include + namespace at { namespace native { @@ -54,7 +69,7 @@ bool cudnn_is_acceptable(const TensorBase& self) { // tensors. Maybe some cuDNN functions actually support empty tensors, but // native/THNN kernels shouldn't be much slower because the output is also // likely empty. - if (self.numel() == 0) return false; + if (self.sym_numel() == 0) return false; // NB: In the old Python code, there was also a test to see if the // cuDNN library was actually dynamically linked or not. I'm not // sure if we can actually test this. diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index a72fba7ac12e0..f2ee31fe0bcdb 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,12 +1,18 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include #include +#include #include #include #include #include -#include #include #include +#include #include #include #include @@ -26,10 +32,184 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include #include #include +#include #include -#include namespace at { namespace meta { @@ -238,7 +418,7 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st const auto itemsize = result.dtype().itemsize(); c10::SymInt size_bytes = at::detail::computeStorageNbytes( size, stride, itemsize, storage_offset); - storage.set_nbytes(size_bytes); + storage.set_nbytes(std::move(size_bytes)); } return result; } @@ -359,8 +539,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced); } -Tensor broadcast_to(const Tensor& self, IntArrayRef size) { - return self.expand(size); +Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) { + return self.expand_symint(size); } std::vector broadcast_tensors(TensorList tensors) { @@ -739,9 +919,12 @@ std::vector chunk(const Tensor& self, int64_t chunks, int64_t dim) { } } -std::vector tensor_split(const Tensor& self, int64_t sections, int64_t dim) { +std::vector tensor_split_sections_symint(const Tensor& self, c10::SymInt sym_sections, int64_t dim) { TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); int64_t dim_ = maybe_wrap_dim(dim, self.dim()); + // NB: intentional, sections specifies number of output tensors, which + // cannot be polymorphic + int64_t sections = sym_sections.guard_int(__FILE__, __LINE__); TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections); const auto dim_size = self.sym_size(dim_); std::vector splits(sections); @@ -756,21 +939,30 @@ std::vector tensor_split(const Tensor& self, int64_t sections, int64_t d return splits; } -std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) { +template +std::vector _tensor_split_indices(const Tensor& self, ArrayRef indices, int64_t dim) { TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); int64_t dim_ = maybe_wrap_dim(dim, self.dim()); int64_t num_indices = indices.size(); std::vector splits(num_indices + 1); - int64_t start_idx = 0; + T start_idx(0); for (const auto split_idx : c10::irange(num_indices)) { - int64_t end_idx = indices[split_idx]; - splits[split_idx] = at::slice(self, dim_, start_idx, end_idx); + auto end_idx = indices[split_idx]; + splits[split_idx] = at::symint::slice(self, dim_, start_idx, end_idx); start_idx = end_idx; } - splits[num_indices] = at::slice(self, dim_, start_idx, self.size(dim_)); + splits[num_indices] = at::symint::slice(self, dim_, start_idx, at::symint::size(self, dim_)); return splits; } +std::vector tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) { + return _tensor_split_indices(self, indices, dim); +} + +std::vector tensor_split_indices_symint(const Tensor& self, SymIntArrayRef indices, int64_t dim) { + return _tensor_split_indices(self, indices, dim); +} + std::vector tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) { TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims"); auto split_device = tensor_indices_or_sections.device(); @@ -996,8 +1188,8 @@ Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef return result; } -const Tensor &as_strided_(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { - auto storage_offset = storage_offset_.value_or(self.storage_offset()); +const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymIntArrayRef stride, optional storage_offset_) { + auto storage_offset = storage_offset_.value_or(self.sym_storage_offset()); setStrided(self, size, stride, storage_offset); return self; } @@ -1006,6 +1198,8 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ auto output = at::empty_like(self); return narrow_copy_dense_cpu_out(self, dim, start, length, output); @@ -1015,9 +1209,10 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ int64_t allDim = self.dim(); int64_t end = start+length; TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); TORCH_CHECK(dim >= 0 && dim < allDim, "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, "."); - TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim), + TORCH_CHECK(start >= 0 && end <= self.size(dim), "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").") Tensor indices = self._indices(); int64_t sparse_dim = self.sparse_dim(); @@ -1045,6 +1240,8 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ return newTensor._coalesced_(self.is_coalesced()); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor& narrow_copy_dense_cpu_out( const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output ) { @@ -1128,22 +1325,24 @@ Tensor& narrow_copy_dense_cpu_out( Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice(self, dim, start, start + length, 1); } Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.sym_size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice_symint(self, dim, start, start + length, 1); } @@ -1375,7 +1574,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { // // We need to do the checks here instead of in `native_functions.yaml` // to preserve backwards compatibility. - if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) { + if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) { return self._reshape_alias_symint(shape, stride.value()); } else { return self.view_symint(shape); @@ -1384,6 +1583,23 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape); } +Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { + if (self.is_sparse()) { + TORCH_CHECK(0, "_reshape_copy is not implemented for sparse tensors"); + } + c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel()); + + if (self.is_mkldnn()) { + TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tensors"); + } + + if (self.is_contiguous()) { + return self.view_symint(shape).clone(at::MemoryFormat::Contiguous); + } else { + return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape); + } +} + // Duplicate of above code for non-symbolic ints. Kept for BC purposes and to // minimize breakages. Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) { @@ -1512,22 +1728,29 @@ QuantizerPtr create_subtensor_quantizer(const Tensor& self, bool is_select, int6 return quantizer; } -Tensor select(const Tensor& self, int64_t dim, int64_t index_) { +Tensor select(const Tensor& self, int64_t dim, int64_t index) { + return at::select_symint(self, dim, c10::SymInt{index}); +} + +Tensor select(const Tensor& self, Dimname dim, int64_t index) { + return at::select_symint(self, dimname_to_position(self, dim), c10::SymInt{index}); +} + +Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) { int64_t ndim = self.dim(); if (ndim == 0) { TORCH_CHECK_INDEX(false, "select() cannot be applied to a 0-dim tensor."); } dim = maybe_wrap_dim(dim, ndim); auto size = self.sym_sizes()[dim]; - if (size < -index_ || size <= index_) { + if (size < -index || size <= index) { if (self.has_names() && self.names()[dim] != Dimname::wildcard()) { - TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ", + TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ", self.sizes(), " at dimension ", self.names()[dim]); } - TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ", + TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ", self.sizes(), " at dimension ", dim); } - SymInt index = index_; if (index < 0) { index += size; } @@ -1560,13 +1783,9 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index_) { return result; } -Tensor select(const Tensor& self, Dimname dim, int64_t index) { - return at::select(self, dimname_to_position(self, dim), index); -} - -Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.select(dim, index).copy_(grad); +Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + auto grad_input = at::zeros_symint(input_sizes, grad.options()); + grad_input.select_symint(dim, index).copy_(grad); return grad_input; } @@ -2889,7 +3108,7 @@ Tensor squeeze_qtensor(const Tensor& self, c10::optional dim) { const auto* per_channel_quantizer = static_cast(quantizer.get()); auto axis = per_channel_quantizer->axis(); int64_t shift = 0; - integer_range dims = dim.has_value() ? integer_range{dim.value(), dim.value() + 1} : c10::irange(self.dim()); + integer_range dims = dim.has_value() ? integer_range{dim.value(), dim.value() + 1} : c10::irange(0, self.dim()); for (const auto d : dims) { if (self.sizes()[d] == 1) { TORCH_CHECK(axis != d, "Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors."); @@ -3390,72 +3609,29 @@ Tensor unfold(const Tensor& self, int64_t d, int64_t size, int64_t step) { return self.as_strided(sizes, strides); } -template -void apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { - TORCH_CHECK(self.dim() == 1 || self.dim() == 2, "matrix or a vector expected"); - - auto self_data = self.data_ptr(); - if (self.dim() == 1) { - auto self_size = self.size(0); - auto self_stride = self.stride(0); - int64_t sz = self_size + std::abs(dimension); - - at::native::resize_output(result, {sz, sz}); - result.zero_(); - auto r_data = result.data_ptr(); - auto r_stride_0 = result.stride(0); - auto r_stride_1 = result.stride(1); - r_data += (dimension >= 0 ? dimension*r_stride_1 : -dimension*r_stride_0); - - for (const auto i : c10::irange(self_size)) { - r_data[i * (r_stride_0 + r_stride_1)] = self_data[i * self_stride]; - } +Tensor diag(const Tensor& self, int64_t offset) { + auto ndim = self.dim(); + TORCH_CHECK(ndim == 1 || ndim == 2, "diag(): Supports 1D or 2D tensors. Got ", self.dim(), "D"); + if (ndim == 1) { + return at::diag_embed(self, offset); } else { - auto self_stride_0 = self.stride(0); - auto self_stride_1 = self.stride(1); - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t sz; - if (dimension >= 0) { - sz = std::min(self.size(0), self.size(1) - dimension); - } else { - sz = std::min(self.size(0) + dimension, self.size(1)); - } - - at::native::resize_output(result, {sz}); - result.zero_(); - auto r_data = result.data_ptr(); - auto r_stride_0 = result.stride(0); - self_data += (dimension >= 0 ? dimension * self_stride_1 : -dimension * self_stride_0); - for (const auto i : c10::irange(sz)) { - r_data[i * r_stride_0] = self_data[i * (self_stride_0 + self_stride_1)]; - } + // We return a copy of the diagonal + return at::diagonal_copy(self, offset); } } -Tensor diag(const Tensor& self, int64_t dimension) { - Tensor result = at::empty({0}, self.options()); - at::diag_out(result, self, dimension); - return result; -} - -Tensor& diag_cpu_out(const Tensor& self, int64_t dimension, Tensor &result) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kBool, self.scalar_type(), "diag", [&] { - apply_diag(result, self, dimension); - }); - return result; -} - -Tensor diag_backward_symint(const Tensor& grad, SymIntArrayRef input_sizes, int64_t diagonal) { - auto ndimension = input_sizes.size(); - AT_ASSERT(ndimension == 1 || ndimension == 2); - - if (ndimension == 1 || input_sizes[0] == input_sizes[1]) { - return grad.diag(diagonal); +Tensor& diag_out(const Tensor& self, int64_t offset, Tensor& out) { + auto ndim = self.dim(); + TORCH_CHECK(ndim == 1 || ndim == 2, "Supports 1D or 2D tensors. Got ", self.dim(), "D"); + if (ndim == 1) { + TORCH_CHECK( + canCast(self.scalar_type(), out.scalar_type()), + "diag: result type ", self.scalar_type(), " can't be cast to the desired out= type ", + out.scalar_type()); + return at::diag_embed_out(out, self, offset); + } else { + return at::diagonal_copy_out(out, self, offset); } - - // Input was a matrix but was not square - return at::diagonal_backward_symint(grad, input_sizes, diagonal, 0, 1); } Tensor diagonal_backward_symint(const Tensor & grad, SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { @@ -3616,9 +3792,9 @@ at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t slice.copy_(src); return output; } -at::Tensor select_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, int64_t index) { +at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) { auto output = self.clone(); - auto slice = output.select(dim, index); + auto slice = output.select_symint(dim, index); TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes()); slice.copy_(src); return output; @@ -3709,8 +3885,16 @@ at::Tensor& _sparse_broadcast_to_copy_out(const at::Tensor & self, at::IntArrayR at::Tensor& diagonal_copy_out(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { - auto tmp = self.diagonal(offset, dim1, dim2); - out.copy_(tmp); + TORCH_CHECK( + out.device() == self.device(), + "diagonal_copy: Expected out and self tensors to be on the same device, but got ", + "out on ", out.device(), " and self on ", self.device()); + auto result = self.diagonal(offset, dim1, dim2); + at::native::resize_output(out, result.sizes()); + TORCH_CHECK( + canCast(result.scalar_type(), out.scalar_type()), + "diagonal_copy: result type ", result.scalar_type(), " can't be cast to the desired out= type ", out.scalar_type()); + out.copy_(result); return out; } @@ -3750,8 +3934,8 @@ at::Tensor& _reshape_alias_copy_out(const at::Tensor & self, at::IntArrayRef siz } -at::Tensor& select_copy_int_out(const at::Tensor & self, int64_t dim, int64_t index, at::Tensor & out) { - auto tmp = self.select(dim, index); +at::Tensor& select_copy_symint_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) { + auto tmp = self.select_symint(dim, index); out.copy_(tmp); return out; } diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h index 21d0ba78261ec..60e2533e9b538 100644 --- a/aten/src/ATen/native/TensorShape.h +++ b/aten/src/ATen/native/TensorShape.h @@ -53,11 +53,4 @@ inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t di return num_splits; } -/// -/// For more information, see -/// https://pytorch.org/docs/master/generated/torch.Tensor.unfold.html#torch.Tensor.unfold -/// - -Tensor unfold(const Tensor& self, int64_t dimension, int64_t size, int64_t step); - }} // namespace at::native diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index f0e2c0f02caa7..028b05e66930e 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -1,14 +1,31 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include // for flip_stub -#include -#include #include +#include #include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index a8c30f5c3ba61..3f62aa58d2593 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -1,10 +1,26 @@ // Copyright 2004-present Facebook. All Rights Reserved. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include +#include #include -#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + #include namespace at { @@ -91,6 +107,11 @@ Tensor _test_autograd_multiple_dispatch_view(const Tensor &self) { return self.view(-1); } +// Helper for inductor tests +Tensor _test_inductor_realize(const Tensor &self) { + return self.clone(); +} + } // namespace native namespace functionalization { diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp index f98018d7fe5a5..59d2b8a0d224b 100644 --- a/aten/src/ATen/native/TriangularOps.cpp +++ b/aten/src/ATen/native/TriangularOps.cpp @@ -1,22 +1,34 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include #include #include -#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) { + TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions") set_output_raw_strided(0, self.sizes(), {}, self.options()); } TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) { + TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions") set_output_raw_strided(0, self.sizes(), {}, self.options()); } diff --git a/aten/src/ATen/native/TriangularOpsUtils.h b/aten/src/ATen/native/TriangularOpsUtils.h index c5bce42ed3fd7..e380a510bddeb 100644 --- a/aten/src/ATen/native/TriangularOpsUtils.h +++ b/aten/src/ATen/native/TriangularOpsUtils.h @@ -1,4 +1,4 @@ -#include +#include #include namespace at { diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index feceb75631cec..36354c133a98e 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -1,8 +1,26 @@ -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index c301d8ecc26a2..845610ce373e7 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -1,26 +1,174 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include #include +#include +#include +#include +#include +#include #include -#include -#include -#include #include #include -#include -#include #include -#include -#include -#include -#include -#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif -#include +#include namespace at { diff --git a/aten/src/ATen/native/Unfold2d.cpp b/aten/src/ATen/native/Unfold2d.cpp index 0a3b760a33fda..60bbc8a777121 100644 --- a/aten/src/ATen/native/Unfold2d.cpp +++ b/aten/src/ATen/native/Unfold2d.cpp @@ -1,3 +1,4 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include namespace at { namespace native { diff --git a/aten/src/ATen/native/Unfold3d.cpp b/aten/src/ATen/native/Unfold3d.cpp index 3495f92dc3ce6..1a2d0ea2ae1f9 100644 --- a/aten/src/ATen/native/Unfold3d.cpp +++ b/aten/src/ATen/native/Unfold3d.cpp @@ -1,5 +1,7 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include #include #include diff --git a/aten/src/ATen/native/UnfoldBackward.cpp b/aten/src/ATen/native/UnfoldBackward.cpp index 10bee80cea23c..4941432321169 100644 --- a/aten/src/ATen/native/UnfoldBackward.cpp +++ b/aten/src/ATen/native/UnfoldBackward.cpp @@ -5,6 +5,7 @@ #include #include #else +#include #include #include #endif @@ -21,6 +22,11 @@ Tensor unfold_backward( int64_t step ) { auto grad_input = at::zeros(input_sizes, grad.options()); + if (step >= size) { + auto gI_unfolded = grad_input.unfold(dim, size, step); + gI_unfolded.copy_(grad); + return grad_input; + } unfold_backward_stub( grad.device().type(), diff --git a/aten/src/ATen/native/UnfoldBackward.h b/aten/src/ATen/native/UnfoldBackward.h index 1f6c8fa1b289c..f8099167361c2 100644 --- a/aten/src/ATen/native/UnfoldBackward.h +++ b/aten/src/ATen/native/UnfoldBackward.h @@ -1,10 +1,9 @@ #pragma once #include -#include +#include #include -#include -#include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -108,79 +107,6 @@ static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out( return iter; } -static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_in( - Tensor& grad_out, - const Tensor& grad_in, - int64_t dim, - int64_t /*size*/, - int64_t /*step*/ -) { - dim = maybe_wrap_dim(dim, grad_out.dim()); - // last dim stores the folds - auto last_dim = maybe_wrap_dim(-1, grad_in.dim()); - - auto grad_in_dim = ensure_nonempty_dim(grad_in.dim()); - auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim); - auto grad_in_last_dim_size = ensure_nonempty_size(grad_in, last_dim); - - /* prepare grad_out for TensorIterator { */ - auto grad_out_restrided = grad_out.unsqueeze(-1); - - auto grad_out_strides = ensure_nonempty_vec(grad_out_restrided.strides().vec()); - auto grad_out_sizes = ensure_nonempty_vec(grad_out_restrided.sizes().vec()); - - grad_out_strides[dim] = 0; - grad_out_strides[last_dim] = 0; - - grad_out_sizes[dim] = grad_in_dim_size; - grad_out_sizes[last_dim] = grad_in_last_dim_size; - - grad_out_restrided = grad_out_restrided.as_strided(grad_out_sizes, grad_out_strides); - /* } */ - - // for each element grad_out[i_1,...,i_dim,...,i_last_dim] - // we have to know i_dim and i_last_dim. - // This information is stored in Tensors - // idx_dim and idx_last_dim - /* prepare idx_dim and idx_last_dim for TensorIterator { */ - auto idx_dim = at::arange( - 0, grad_in_dim_size, grad_in.options().dtype(at::kLong) - ); - - auto idx_dim_strides = std::vector(grad_in_dim, 0); - auto idx_dim_sizes = std::vector(grad_in_dim, 1); - - idx_dim_strides[dim] = 1; - idx_dim_sizes[dim] = grad_in_dim_size; - - auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides); - - auto idx_last_dim = at::arange( - 0, grad_in_last_dim_size, grad_in.options().dtype(at::kLong) - ); - - auto idx_last_dim_strides = std::vector(grad_in_dim, 0); - auto idx_last_dim_sizes = std::vector(grad_in_dim, 1); - - idx_last_dim_strides[last_dim] = 1; - idx_last_dim_sizes[last_dim] = grad_in_last_dim_size; - - auto idx_last_dim_restrided = idx_last_dim.as_strided(idx_last_dim_sizes, idx_last_dim_strides); - /* } */ - - auto iter = TensorIteratorConfig() - .set_check_mem_overlap(false) - .check_all_same_dtype(false) - .resize_outputs(false) - .add_owned_output(grad_out_restrided) - .add_owned_input(grad_in) - .add_owned_input(idx_dim_restrided) - .add_owned_input(idx_last_dim_restrided) - .build(); - - return iter; -} - } }} // namespace at::native diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index f418611e08644..92b48c9f388ca 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -1,8 +1,27 @@ // Returns unique elements of input tensor. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include +#include #include #include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif #include #include diff --git a/aten/src/ATen/native/UpSample.cpp b/aten/src/ATen/native/UpSample.cpp index db75b7e99fdb1..1a6af75260300 100644 --- a/aten/src/ATen/native/UpSample.cpp +++ b/aten/src/ATen/native/UpSample.cpp @@ -1,4 +1,5 @@ // Copyright 2004-present Facebook. All Rights Reserved. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index f3dd836444d13..92ee7252d1bd6 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -2,7 +2,7 @@ #include -#include +#include #include #include #include @@ -56,7 +56,7 @@ TORCH_API c10::SmallVector compute_output_size( inline c10::optional get_scale_value(c10::optional> scales, int idx) { if (!scales) { - return nullopt; + return c10::nullopt; } return scales->at(idx); } @@ -266,15 +266,13 @@ static inline scalar_t area_pixel_compute_scale( bool align_corners, const c10::optional scale) { // see Note [area_pixel_compute_scale] - if(align_corners){ + if(align_corners) { if(output_size > 1) { return static_cast(input_size - 1) / (output_size - 1); - } - else { + } else { return static_cast(0); } - } - else{ + } else { return compute_scales_value(scale, input_size, output_size); } } @@ -447,14 +445,20 @@ static inline void compute_source_index_and_lambda( lambda0 = static_cast(1); lambda1 = static_cast(0); } else { - using accscalar_t = at::acc_type; - const accscalar_t real_input_index = - area_pixel_compute_source_index( + using opmath_t = at::opmath_type; + const auto real_input_index = + area_pixel_compute_source_index( ratio, output_index, align_corners, /*cubic=*/false); - input_index0 = static_cast(real_input_index); + // when `real_input_index` becomes larger than the range the floating point + // type can accurately represent, the type casting to `int64_t` might exceed + // `input_size - 1`, causing overflow. So we guard it with `std::min` below. + input_index0 = std::min(static_cast(real_input_index), input_size - 1); int64_t offset = (input_index0 < input_size - 1) ? 1 : 0; input_index1 = input_index0 + offset; - lambda1 = real_input_index - input_index0; + lambda1 = std::min( + std::max(real_input_index - input_index0, static_cast(0)), + static_cast(1) + ); lambda0 = static_cast(1.) - lambda1; } } diff --git a/aten/src/ATen/native/UpSampleBicubic2d.cpp b/aten/src/ATen/native/UpSampleBicubic2d.cpp index 5bf7ba6a53666..035bea5629547 100644 --- a/aten/src/ATen/native/UpSampleBicubic2d.cpp +++ b/aten/src/ATen/native/UpSampleBicubic2d.cpp @@ -1,8 +1,24 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -271,18 +287,6 @@ Tensor upsample_bicubic2d( return at::upsample_bicubic2d(input, osize, align_corners, scale_h, scale_w); } -Tensor upsample_bicubic2d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::upsample_bicubic2d_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w); -} - Tensor _upsample_bicubic2d_aa( const Tensor& input, at::OptionalIntArrayRef output_size, @@ -294,18 +298,6 @@ Tensor _upsample_bicubic2d_aa( return at::_upsample_bicubic2d_aa(input, osize, align_corners, scale_h, scale_w); } -Tensor _upsample_bicubic2d_aa_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::_upsample_bicubic2d_aa_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w); -} - DEFINE_DISPATCH(upsample_bicubic2d_kernel); DEFINE_DISPATCH(_upsample_bicubic2d_aa_kernel); DEFINE_DISPATCH(_upsample_bicubic2d_aa_backward_kernel); diff --git a/aten/src/ATen/native/UpSampleBilinear2d.cpp b/aten/src/ATen/native/UpSampleBilinear2d.cpp index 527555a066abb..5d91e93e016df 100644 --- a/aten/src/ATen/native/UpSampleBilinear2d.cpp +++ b/aten/src/ATen/native/UpSampleBilinear2d.cpp @@ -1,11 +1,26 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include +#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -154,18 +169,6 @@ Tensor upsample_bilinear2d( return at::upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w); } -Tensor upsample_bilinear2d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::upsample_bilinear2d_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w); -} - Tensor _upsample_bilinear2d_aa( const Tensor& input, at::OptionalIntArrayRef output_size, @@ -177,18 +180,6 @@ Tensor _upsample_bilinear2d_aa( return at::_upsample_bilinear2d_aa(input, osize, align_corners, scale_h, scale_w); } -Tensor _upsample_bilinear2d_aa_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::_upsample_bilinear2d_aa_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w); -} - DEFINE_DISPATCH(upsample_bilinear2d_kernel); DEFINE_DISPATCH(upsample_bilinear2d_backward_kernel); DEFINE_DISPATCH(_upsample_bilinear2d_aa_kernel); diff --git a/aten/src/ATen/native/UpSampleLinear1d.cpp b/aten/src/ATen/native/UpSampleLinear1d.cpp index b100450c2b6a7..aed082b685638 100644 --- a/aten/src/ATen/native/UpSampleLinear1d.cpp +++ b/aten/src/ATen/native/UpSampleLinear1d.cpp @@ -1,10 +1,22 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include +#include +#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -87,17 +99,6 @@ Tensor upsample_linear1d( return at::upsample_linear1d(input, osize, align_corners, scale_w); } -Tensor upsample_linear1d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_w = get_scale_value(scale_factors, 0); - return at::upsample_linear1d_backward(grad_output, osize, input_size, align_corners, scale_w); -} - DEFINE_DISPATCH(upsample_linear1d_kernel); DEFINE_DISPATCH(upsample_linear1d_backward_kernel); diff --git a/aten/src/ATen/native/UpSampleNearest1d.cpp b/aten/src/ATen/native/UpSampleNearest1d.cpp index 83121ed3be45b..1bdbda8f66c41 100644 --- a/aten/src/ATen/native/UpSampleNearest1d.cpp +++ b/aten/src/ATen/native/UpSampleNearest1d.cpp @@ -1,7 +1,23 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -125,26 +141,6 @@ Tensor _upsample_nearest_exact1d( return at::_upsample_nearest_exact1d(input, osize, scale_w); } -Tensor upsample_nearest1d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_w = get_scale_value(scale_factors, 0); - return at::upsample_nearest1d_backward(grad_output, osize, input_size, scale_w); -} - -Tensor _upsample_nearest_exact1d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_w = get_scale_value(scale_factors, 0); - return at::_upsample_nearest_exact1d_backward(grad_output, osize, input_size, scale_w); -} - DEFINE_DISPATCH(upsample_nearest1d_kernel); DEFINE_DISPATCH(_upsample_nearest_exact1d_kernel); DEFINE_DISPATCH(upsample_nearest1d_backward_kernel); diff --git a/aten/src/ATen/native/UpSampleNearest2d.cpp b/aten/src/ATen/native/UpSampleNearest2d.cpp index ee5dce4a02eff..65e20b78f868e 100644 --- a/aten/src/ATen/native/UpSampleNearest2d.cpp +++ b/aten/src/ATen/native/UpSampleNearest2d.cpp @@ -1,9 +1,24 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -152,28 +167,6 @@ Tensor _upsample_nearest_exact2d( return at::_upsample_nearest_exact2d(input, osize, scale_h, scale_w); } -Tensor upsample_nearest2d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::upsample_nearest2d_backward(grad_output, osize, input_size, scale_h, scale_w); -} - -Tensor _upsample_nearest_exact2d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_h = get_scale_value(scale_factors, 0); - auto scale_w = get_scale_value(scale_factors, 1); - return at::_upsample_nearest_exact2d_backward(grad_output, osize, input_size, scale_h, scale_w); -} - DEFINE_DISPATCH(upsample_nearest2d_kernel); DEFINE_DISPATCH(_upsample_nearest_exact2d_kernel); DEFINE_DISPATCH(upsample_nearest2d_backward_kernel); diff --git a/aten/src/ATen/native/UpSampleNearest3d.cpp b/aten/src/ATen/native/UpSampleNearest3d.cpp index 0e4040980ae26..27ca6745655c9 100644 --- a/aten/src/ATen/native/UpSampleNearest3d.cpp +++ b/aten/src/ATen/native/UpSampleNearest3d.cpp @@ -1,8 +1,23 @@ -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -147,7 +162,7 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cpu) ( using at::native::upsample::compute_output_size; using at::native::upsample::get_scale_value; -Tensor upsample_nearest3d_cpu( +Tensor upsample_nearest3d( const Tensor& input, at::OptionalIntArrayRef output_size, c10::optional> scale_factors) { @@ -158,7 +173,7 @@ Tensor upsample_nearest3d_cpu( return at::upsample_nearest3d(input, osize, scale_d, scale_h, scale_w); } -Tensor _upsample_nearest_exact3d_cpu( +Tensor _upsample_nearest_exact3d( const Tensor& input, at::OptionalIntArrayRef output_size, c10::optional> scale_factors) { @@ -169,31 +184,6 @@ Tensor _upsample_nearest_exact3d_cpu( return at::_upsample_nearest_exact3d(input, osize, scale_d, scale_h, scale_w); } -// when structured kernels can handle QuantizedCPU, update these overloads to be CompositeExplicitAutograd -Tensor upsample_nearest3d_backward_cpu( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::upsample_nearest3d_backward(grad_output, osize, input_size, scale_d, scale_h, scale_w); -} - -Tensor _upsample_nearest_exact3d_backward_cpu( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::_upsample_nearest_exact3d_backward(grad_output, osize, input_size, scale_d, scale_h, scale_w); -} - DEFINE_DISPATCH(upsample_nearest3d_kernel); DEFINE_DISPATCH(_upsample_nearest_exact3d_kernel); DEFINE_DISPATCH(upsample_nearest3d_backward_kernel); diff --git a/aten/src/ATen/native/UpSampleTrilinear3d.cpp b/aten/src/ATen/native/UpSampleTrilinear3d.cpp index 73fffbe5afe79..1bf9c8f6cb4ee 100644 --- a/aten/src/ATen/native/UpSampleTrilinear3d.cpp +++ b/aten/src/ATen/native/UpSampleTrilinear3d.cpp @@ -1,11 +1,22 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include +#include +#include #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + namespace at { namespace meta { @@ -100,19 +111,6 @@ Tensor upsample_trilinear3d( return at::upsample_trilinear3d(input, osize, align_corners, scale_d, scale_h, scale_w); } -Tensor upsample_trilinear3d_backward( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::upsample_trilinear3d_backward(grad_output, osize, input_size, align_corners, scale_d, scale_h, scale_w); -} - DEFINE_DISPATCH(upsample_trilinear3d_kernel); DEFINE_DISPATCH(upsample_trilinear3d_backward_kernel); diff --git a/aten/src/ATen/native/VariableMethodStubs.cpp b/aten/src/ATen/native/VariableMethodStubs.cpp index ce5432e677af2..6191717930aec 100644 --- a/aten/src/ATen/native/VariableMethodStubs.cpp +++ b/aten/src/ATen/native/VariableMethodStubs.cpp @@ -1,5 +1,23 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include #include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif // The stubs in here are used by dynamic dispatch. It just redirects everything // to the Tensor method we manually bind in TensorBody.h. diff --git a/aten/src/ATen/native/WeightNorm.cpp b/aten/src/ATen/native/WeightNorm.cpp index bf258d80a0fb3..8291120f19603 100644 --- a/aten/src/ATen/native/WeightNorm.cpp +++ b/aten/src/ATen/native/WeightNorm.cpp @@ -1,11 +1,21 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include -#include +#include #include -#include -#include -#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + #include namespace at { diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp index de053b353758a..144cdb292ba16 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear.cpp @@ -247,6 +247,7 @@ class QLinearInt8 final { }; TORCH_LIBRARY_IMPL(sparse, QuantizedCPU, m) { + register_linear_params(); m.impl( TORCH_SELECTIVE_NAME("sparse::qlinear"), TORCH_FN(QLinearInt8::run)); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp index c5fa0210cd581..d367dbe011031 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -209,7 +209,7 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( std::get(serialized); TORCH_CHECK( serialization_version <= SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION, - "Attemped to deserialize sparse qlinear packed params with an ", + "Attempted to deserialize sparse qlinear packed params with an ", "incompatible serialization version (", serialization_version, " > ", diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index a430e81854519..64cab80790a99 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -45,7 +45,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( const auto cols_input = static_cast(input.size(input.dim() - 1)); TORCH_CHECK( cols_input == input_channels_, - "quantized_sparse_lienar: Input tensor's last and weight tensor's" + "quantized_sparse_linear: Input tensor's last and weight tensor's" " second dimension must match."); // On empty input, no output data will be generated, diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp index 83aaf810edd72..bedf2f4461f3a 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_prepack.cpp @@ -240,6 +240,7 @@ class QLinearPackWeightInt8 final { }; TORCH_LIBRARY_IMPL(sparse, QuantizedCPU, m) { + register_linear_params(); m.impl( TORCH_SELECTIVE_NAME("sparse::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp index 14cf9521a4cdb..d66abc9d2a8a5 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp @@ -133,6 +133,7 @@ class QLinearUnpackWeightInt8 final { }; TORCH_LIBRARY_IMPL(sparse, CatchAll, m) { + register_linear_params(); m.impl( TORCH_SELECTIVE_NAME("sparse::qlinear_unpack"), TORCH_FN(QLinearUnpackWeightInt8::run)); diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 6f3eac783ccda..728ea62f1898f 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -623,7 +623,25 @@ void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambd) { } void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] { + if (iter.dtype() == kBFloat16) { + auto min_val = min.to(); + auto max_val = max.to(); + cpu_kernel_vec( + iter, + [=](BFloat16 grad_val, BFloat16 self_val) -> BFloat16 { + return (float(self_val) <= min_val || float(self_val) >= max_val) ? BFloat16(0) : grad_val; + }, + [=](Vectorized grad_val, Vectorized self_val) -> Vectorized { + Vectorized grad_val0, grad_val1, self_val0, self_val1; + std::tie(grad_val0, grad_val1) = convert_bfloat16_float(grad_val); + std::tie(self_val0, self_val1) = convert_bfloat16_float(self_val); + return convert_float_bfloat16( + ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0, + ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1 + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] { auto min_val = min.to(); auto max_val = max.to(); cpu_kernel_vec( @@ -635,6 +653,7 @@ void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Sca return ((self_val > min_val) & (self_val < max_val)) & grad_val; }); }); + } } void hardswish_kernel(TensorIterator& iter) { @@ -1035,8 +1054,23 @@ void glu_backward_kernel(TensorIterator& iter) { } void silu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( - kBFloat16, iter.dtype(), "silu_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + const Vectorized kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 x) -> BFloat16 { + return float(x) / (1.0f + std::exp(-float(x))); + }, + [kOneVec](Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + return convert_float_bfloat16( + x_vec0 / (kOneVec + x_vec0.neg().exp()), + x_vec1 / (kOneVec + x_vec1.neg().exp())); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + iter.dtype(), "silu_cpu", [&]() { const Vectorized kOneVec(scalar_t(1)); cpu_kernel_vec( iter, @@ -1047,11 +1081,34 @@ void silu_kernel(TensorIteratorBase& iter) { return x_vec / (kOneVec + x_vec.neg().exp()); }); }); + } } void silu_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( - kBFloat16, iter.dtype(), "silu_backward_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + const Vectorized kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 dy, BFloat16 x) -> BFloat16 { + const float sigmoid = + 1.0f / (1.0f + std::exp(-float(x))); + return dy * sigmoid * (1.0f + x * (1.0f - sigmoid)); + }, + [kOneVec](Vectorized dy_vec, Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1, dy_vec0, dy_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + std::tie(dy_vec0, dy_vec1) = convert_bfloat16_float(dy_vec); + const Vectorized sigmoid0 = + kOneVec / (kOneVec + x_vec0.neg().exp()); + const Vectorized sigmoid1 = + kOneVec / (kOneVec + x_vec1.neg().exp()); + return convert_float_bfloat16( + dy_vec0 * sigmoid0 * (kOneVec + x_vec0 * (kOneVec - sigmoid0)), + dy_vec1 * sigmoid1 * (kOneVec + x_vec1 * (kOneVec - sigmoid1))); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + iter.dtype(), "silu_backward_cpu", [&]() { const Vectorized kOneVec(scalar_t(1)); cpu_kernel_vec( iter, @@ -1066,10 +1123,26 @@ void silu_backward_kernel(TensorIteratorBase& iter) { return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid)); }); }); + } } void mish_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + cpu_kernel_vec( + iter, + [](BFloat16 x) -> BFloat16{ + return static_cast(float(x) * std::tanh(std::log1p(std::exp(float(x))))); + }, + [](Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + return convert_float_bfloat16( + x_vec0 * x_vec0.exp().log1p().tanh(), + x_vec1 * x_vec1.exp().log1p().tanh() + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() { using Vec = Vectorized; cpu_kernel_vec( iter, @@ -1080,10 +1153,36 @@ void mish_kernel(TensorIteratorBase& iter) { return x_vec * x_vec.exp().log1p().tanh(); }); }); + } } void mish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + using Vec = Vectorized; + const Vec kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 dy, BFloat16 x) -> BFloat16 { + const float sigmoid = + 1.0f / (1.0f + std::exp(-float(x))); + const float tanh_softplus = std::tanh(std::log1p(std::exp(float(x)))); + return dy * (tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus)); + }, + [kOneVec](Vectorized dy_vec, Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1, dy_vec0, dy_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + std::tie(dy_vec0, dy_vec1) = convert_bfloat16_float(dy_vec); + const Vec sigmoid0 = kOneVec / (kOneVec + x_vec0.neg().exp()); + const Vec sigmoid1 = kOneVec / (kOneVec + x_vec1.neg().exp()); + const Vec tanh_softplus0 = x_vec0.exp().log1p().tanh(); + const Vec tanh_softplus1 = x_vec1.exp().log1p().tanh(); + return convert_float_bfloat16( + dy_vec0 * (tanh_softplus0 + x_vec0 * sigmoid0 * (kOneVec - tanh_softplus0 * tanh_softplus0)), + dy_vec1 * (tanh_softplus1 + x_vec1 * sigmoid1 * (kOneVec - tanh_softplus1 * tanh_softplus1)) + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() { using Vec = Vectorized; const Vec kOneVec(scalar_t(1)); cpu_kernel_vec( @@ -1100,6 +1199,7 @@ void mish_backward_kernel(TensorIterator& iter) { return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus)); }); }); + } } void prelu_cpu_kernel(TensorIterator& iter) { diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index a5dde16024ab6..9b5f442ef02cc 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -314,10 +314,13 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) { void lshift_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return static_cast>(a) << b; - }); + cpu_kernel_vec(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return static_cast>(a) << b; + }, + [](Vectorized a, Vectorized b) { + return a << b; + }); }); } @@ -380,10 +383,13 @@ void logical_xor_kernel(TensorIterator& iter) { void rshift_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return a >> b; - }); + cpu_kernel_vec(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return a >> b; + }, + [](Vectorized a, Vectorized b) { + return a >> b; + }); }); } diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 47b20b2ca4c18..c80c5d2f000d6 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -511,8 +511,8 @@ struct ApplyGridSample(x_w); + auto i_y_n = convert_to_int_of_same_size(y_n); auto i_x_e = i_x_w + iVec(1); auto i_y_s = i_y_n + iVec(1); diff --git a/aten/src/ATen/native/cpu/HistogramKernel.cpp b/aten/src/ATen/native/cpu/HistogramKernel.cpp index 932bf9beb4993..83011aa2e9a79 100644 --- a/aten/src/ATen/native/cpu/HistogramKernel.cpp +++ b/aten/src/ATen/native/cpu/HistogramKernel.cpp @@ -166,8 +166,8 @@ void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges, * the appropriate bin via simple division. */ pos = static_cast((elt - leftmost_edge[dim]) - / (rightmost_edge[dim] - leftmost_edge[dim]) - * (num_bin_edges[dim] - 1)); + * (num_bin_edges[dim] - 1) + / (rightmost_edge[dim] - leftmost_edge[dim])); /* Ensures consistency with bin_edges by checking the bins to the left and right * of the selected position. Necessary for cases in which an element very close diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index be0dc3301a006..81e135d1e7498 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -457,6 +457,75 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) { }); } + +template +void cpu_hflip_vec(at::TensorIterator& iter) { + + auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) { + + static constexpr int ntensors = 3; + std::array data_arr; + std::copy_n(base, ntensors, data_arr.data()); + const int64_t *outer_strides = &strides[ntensors]; + + using Vec = Vectorized; + + constexpr auto stride = sizeof(scalar_t); + TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]); + + for (const auto j C10_UNUSED : c10::irange(size1)) { + + // vectorized loop with negative stride for output + char** C10_RESTRICT data_ = data_arr.data(); + int64_t n = size0; + + char* C10_RESTRICT data[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + data[arg] = data_[arg]; + } + + int64_t i = 0; + + // data[0] unaligned pre-pass + int64_t offset = (j * n + (n - i - Vec::size())) % 32; + offset = (offset >= n) ? n : offset; + for (; i < offset; i++) { + scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); + *out_ptr = *(scalar_t *)(data[1] + i * stride); + } + // Empirically found that it is faster to process 3 data items together vs 2 or 4 + for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) { + auto out1 = Vec::loadu(data[1] + i * stride); + auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride); + auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride); + // flip the vector: 1234 -> 4321 + out1 = flip(out1); + out2 = flip(out2); + out3 = flip(out3); + out1.store(data[0] - (i + Vec::size() - 1) * stride); + out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride); + out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride); + } + if (i < n) { + for (; i < n; i++) { + scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); + *out_ptr = *(scalar_t *)(data[1] + i * stride); + } + } + + // advance: + for (const auto arg : c10::irange(data_arr.size())) { + data_arr[arg] += outer_strides[arg]; + } + } + }; + + int64_t grain_size = at::internal::GRAIN_SIZE; + iter.for_each(loop2d, grain_size); + iter.cast_outputs(); +} + + void flip_kernel(TensorIterator& iter, const bool quantized) { if (quantized) { AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu", @@ -466,6 +535,29 @@ void flip_kernel(TensorIterator& iter, const bool quantized) { }); }); } else { + // Special case: horizontal flip with vectorization and input is contiguous + // Context: horizontal flip leads to strides[0] < 0 and + // thus is_contiguous condition is not satisfied and non-vectorized code path is taken. + auto output_strides = iter.strides(0); + auto input_strides = iter.strides(1); + if (iter.ndim() > 0 && output_strides[0] < 0 && input_strides[0] == iter.element_size(1)) { + auto iter_dtype = iter.dtype(); + if (iter_dtype == kByte) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kFloat) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kInt) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kShort) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kLong) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kDouble) { + return cpu_hflip_vec(iter); + } + // other dtypes are handled below with cpu_kernel_vec + } + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu", [&iter] { cpu_kernel_vec(iter, [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t { diff --git a/aten/src/ATen/native/cpu/README.md b/aten/src/ATen/native/cpu/README.md index ab2f9d3d02609..2cf6fa0a13320 100644 --- a/aten/src/ATen/native/cpu/README.md +++ b/aten/src/ATen/native/cpu/README.md @@ -64,7 +64,7 @@ within 256bit & 512bits registers. vec defines various operators such as As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces an entire array using a given associative binary operation such as +. -More explicity, calling `kernel_` with template argument `std::plus` will cause +More explicitly, calling `kernel_` with template argument `std::plus` will cause it to sum up the entire array into a single value. `ReduceOpsKernel.cpp` uses the `CPU_CAPABILITY_*` macros to "know" under which @@ -73,7 +73,7 @@ generic code, which will be compiled under multipled compilation settings. `../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains a generic definition of `sumImplAll`. This function allows the user to reduce -over a dimension or all dimensions. The appropiate capability is chosen at +over a dimension or all dimensions. The appropriate capability is chosen at runtime using cpuinfo. If the current platform has AVX2, `sumImpl` will be set to `sumImplAll`. diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index a4345c3fd5d86..bbf45ba2ecd0b 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -184,7 +184,7 @@ static void prod_kernel_impl(TensorIterator& iter) { // NOLINTNEXTLINE(bugprone-argument-comment) /*identity=*/1); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "prod_out_cpu", [&] { binary_kernel_reduce_vec( iter, [=](scalar_t a, scalar_t b) @@ -333,20 +333,9 @@ static void and_kernel_impl(TensorIterator& iter) { binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a && b) ? 1 : 0; }, -#if defined(CPU_CAPABILITY_ZVECTOR) [=](Vectorized a, Vectorized b) { return a & b; }, -#else - [=](Vectorized a, Vectorized b) { - Vectorized c = Vectorized(); - - for (decltype(c.size()) i = 0; i != Vectorized::size(); i++) { - c[i] = (a[i] && b[i]) ? 1 : 0; - } - return c; - }, -#endif /*ident=*/true); } else { binary_kernel_reduce_vec( @@ -380,20 +369,9 @@ static void or_kernel_impl(TensorIterator& iter) { binary_kernel_reduce_vec( iter, [=](uint8_t a, uint8_t b) -> uint8_t { return (a || b) ? 1 : 0; }, -#if defined(CPU_CAPABILITY_ZVECTOR) [=](Vectorized a, Vectorized b) { return a | b; }, -#else - [=](Vectorized a, Vectorized b) { - Vectorized c = Vectorized(); - - for (decltype(c.size()) i = 0; i != Vectorized::size(); i++) { - c[i] = (a[i] || b[i]) ? 1 : 0; - } - return c; - }, -#endif /*ident=*/false); } else { binary_kernel_reduce_vec( diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 8a0534fd3da5f..898f736fabe86 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -780,7 +780,7 @@ IMPLEMENT_COMPLEX_KERNEL(log) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) IMPLEMENT_COMPLEX_KERNEL(log10) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) -IMPLEMENT_FLOAT_KERNEL(log1p) +IMPLEMENT_COMPLEX_KERNEL(log1p) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) IMPLEMENT_COMPLEX_KERNEL(log2) // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) diff --git a/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp b/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp index 129ab3a973e3a..aa5dfb0143801 100644 --- a/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp +++ b/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp @@ -1,5 +1,6 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include @@ -65,8 +66,7 @@ void _unfold_backward_internal_kernel( int64_t grad_in_dim_stride, int64_t grad_in_last_dim_stride, int64_t grad_in_dim_size, - int64_t grad_out_dim_stride, - bool is_step_ge_size + int64_t grad_out_dim_stride ) { if (iter.numel() == 0) { return; @@ -77,53 +77,32 @@ void _unfold_backward_internal_kernel( auto* RESTRICT grad_in_ptr = data[1]; auto* RESTRICT idx_dim_ptr = data[2]; - if (is_step_ge_size) { - auto* RESTRICT idx_last_dim_ptr = data[3]; + for (const auto elem C10_UNUSED : c10::irange(nelems)) { + auto* RESTRICT grad_out_data = reinterpret_cast(grad_out_ptr); + auto* RESTRICT grad_in_data = reinterpret_cast(grad_in_ptr); - for (const auto elem C10_UNUSED : c10::irange(nelems)) { - auto* RESTRICT grad_out_data = reinterpret_cast(grad_out_ptr); - auto* RESTRICT grad_in_data = reinterpret_cast(grad_in_ptr); + auto idx_dim = *reinterpret_cast(idx_dim_ptr); - auto idx_dim = *reinterpret_cast(idx_dim_ptr); - auto idx_last_dim = *reinterpret_cast(idx_last_dim_ptr); + // left_fold potentially intersecting with idx_dim + // is either (idx_dim - size) / step or the next integer. + int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0; + if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) { + ++left_fold_idx; + } - auto grad_out_idx_dim = idx_dim * step + idx_last_dim; - grad_out_data[grad_out_idx_dim * grad_out_dim_stride] = *grad_in_data; + auto right_fold_idx = idx_dim / step; + right_fold_idx = (right_fold_idx >= grad_in_dim_size) + ? (grad_in_dim_size - 1) : right_fold_idx; - grad_out_ptr += strides[0]; - grad_in_ptr += strides[1]; - idx_dim_ptr += strides[2]; - idx_last_dim_ptr += strides[3]; - } - } - else { - for (const auto elem C10_UNUSED : c10::irange(nelems)) { - auto* RESTRICT grad_out_data = reinterpret_cast(grad_out_ptr); - auto* RESTRICT grad_in_data = reinterpret_cast(grad_in_ptr); - - auto idx_dim = *reinterpret_cast(idx_dim_ptr); - - // left_fold potentially intersecting with idx_dim - // is either (idx_dim - size) / step or the next integer. - int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0; - if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) { - ++left_fold_idx; - } - - auto right_fold_idx = idx_dim / step; - right_fold_idx = (right_fold_idx >= grad_in_dim_size) - ? (grad_in_dim_size - 1) : right_fold_idx; - - for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) { - auto idx_last_dim = idx_dim - fold_idx * step; - *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride - + idx_last_dim * grad_in_last_dim_stride]; - } - - grad_out_ptr += strides[0]; - grad_in_ptr += strides[1]; - idx_dim_ptr += strides[2]; + for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) { + auto idx_last_dim = idx_dim - fold_idx * step; + *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride + + idx_last_dim * grad_in_last_dim_stride]; } + + grad_out_ptr += strides[0]; + grad_in_ptr += strides[1]; + idx_dim_ptr += strides[2]; } }; @@ -147,16 +126,8 @@ void unfold_backward_cpu_kernel( auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim); - auto is_step_ge_size = (step >= size); - - TensorIterator iter = - is_step_ge_size ? - _make_unfold_backward_iter_over_grad_in( - grad_out, grad_in, dim, size, step - ) : - _make_unfold_backward_iter_over_grad_out( - grad_out, grad_in, dim, size, step - ); + TensorIterator iter = _make_unfold_backward_iter_over_grad_out( + grad_out, grad_in, dim, size, step); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, @@ -169,8 +140,7 @@ void unfold_backward_cpu_kernel( grad_in_dim_stride, grad_in_last_dim_stride, grad_in_dim_size, - grad_out_dim_stride, - is_step_ge_size + grad_out_dim_stride ); } ); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 75cefe425ebbc..8d418c2645040 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -471,12 +471,12 @@ void cpu_upsample_linear_channels_last( TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels); int64_t output_slice_size = output_depth * output_height * output_width * channels; - using accscalar_t = at::acc_type; + using opmath_t = at::opmath_type; using Vec = vec::Vectorized; auto loop2d = [&](int64_t begin, int64_t end) { - const scalar_t height_scale = area_pixel_compute_scale( + const auto height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales[0]); - const scalar_t width_scale = area_pixel_compute_scale( + const auto width_scale = area_pixel_compute_scale( input_width, output_width, align_corners, scales[1]); auto input_indexr = [=](int64_t n, int64_t h, int64_t w) { @@ -486,7 +486,7 @@ void cpu_upsample_linear_channels_last( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t ih0, ih1, iw0, iw1; - scalar_t h0lambda, h1lambda, w0lambda, w1lambda; + opmath_t h0lambda, h1lambda, w0lambda, w1lambda; for (const auto n : c10::irange(begin, end)) { for (const auto oh : c10::irange(output_height)) { compute_source_index_and_lambda( @@ -501,10 +501,10 @@ void cpu_upsample_linear_channels_last( scalar_t* i01 = input_indexr(n, ih0, iw1); scalar_t* i10 = input_indexr(n, ih1, iw0); scalar_t* i11 = input_indexr(n, ih1, iw1); - accscalar_t w00 = h0lambda * w0lambda; - accscalar_t w01 = h0lambda * w1lambda; - accscalar_t w10 = h1lambda * w0lambda; - accscalar_t w11 = h1lambda * w1lambda; + opmath_t w00 = h0lambda * w0lambda; + opmath_t w01 = h0lambda * w1lambda; + opmath_t w10 = h1lambda * w0lambda; + opmath_t w11 = h1lambda * w1lambda; int64_t size = channels; int64_t d = 0; @@ -521,11 +521,11 @@ void cpu_upsample_linear_channels_last( }; auto loop3d = [&](int64_t begin, int64_t end) { - const scalar_t depth_scale = area_pixel_compute_scale( + const auto depth_scale = area_pixel_compute_scale( input_depth, output_depth, align_corners, scales[0]); - const scalar_t height_scale = area_pixel_compute_scale( + const auto height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales[1]); - const scalar_t width_scale = area_pixel_compute_scale( + const auto width_scale = area_pixel_compute_scale( input_width, output_width, align_corners, scales[2]); auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) { @@ -536,7 +536,7 @@ void cpu_upsample_linear_channels_last( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t id0, id1, ih0, ih1, iw0, iw1; - scalar_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda; + opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda; for (const auto n : c10::irange(begin, end)) { for (const auto od : c10::irange(output_depth)) { compute_source_index_and_lambda( @@ -559,14 +559,14 @@ void cpu_upsample_linear_channels_last( scalar_t* i101 = input_indexr(n, id1, ih0, iw1); scalar_t* i110 = input_indexr(n, id1, ih1, iw0); scalar_t* i111 = input_indexr(n, id1, ih1, iw1); - accscalar_t w000 = d0lambda * h0lambda * w0lambda; - accscalar_t w001 = d0lambda * h0lambda * w1lambda; - accscalar_t w010 = d0lambda * h1lambda * w0lambda; - accscalar_t w011 = d0lambda * h1lambda * w1lambda; - accscalar_t w100 = d1lambda * h0lambda * w0lambda; - accscalar_t w101 = d1lambda * h0lambda * w1lambda; - accscalar_t w110 = d1lambda * h1lambda * w0lambda; - accscalar_t w111 = d1lambda * h1lambda * w1lambda; + opmath_t w000 = d0lambda * h0lambda * w0lambda; + opmath_t w001 = d0lambda * h0lambda * w1lambda; + opmath_t w010 = d0lambda * h1lambda * w0lambda; + opmath_t w011 = d0lambda * h1lambda * w1lambda; + opmath_t w100 = d1lambda * h0lambda * w0lambda; + opmath_t w101 = d1lambda * h0lambda * w1lambda; + opmath_t w110 = d1lambda * h1lambda * w0lambda; + opmath_t w111 = d1lambda * h1lambda * w1lambda; int64_t size = channels; int64_t d = 0; @@ -775,10 +775,10 @@ struct HelperInterpNearest : public HelperInterpBase { // index_f32 = (output_index) * scale // input_index = floor(index_f32) // Same as OpenCV INTER_NEAREST - using accscalar_t = at::acc_type; + using opmath_t = at::opmath_type; for (const auto i : c10::irange(output_size)) { - const accscalar_t real_input_index = - area_pixel_compute_source_index( + const auto real_input_index = + area_pixel_compute_source_index( scale, i, /*align_corners=*/true, /*cubic=*/false); input_index = static_cast(floorf(real_input_index)); input_index_ptr[i] = static_cast(std::min(input_index, input_size - 1)) * stride; @@ -826,10 +826,10 @@ struct HelperInterpNearestExact : public HelperInterpNearest { // index_f32 = (output_index + 0.5) * scale - 0.5 // input_index = round(index_f32) // Same as Pillow and Scikit-Image/Scipy ndi.zoom - using accscalar_t = at::acc_type; + using opmath_t = at::opmath_type; for (const auto i : c10::irange(output_size)) { - const accscalar_t real_input_index = - area_pixel_compute_source_index( + const auto real_input_index = + area_pixel_compute_source_index( scale, i, /*align_corners=*/align_corners, /*cubic=*/false); input_index = static_cast(floorf(real_input_index + 0.5)); input_index_ptr[i] = static_cast(std::min(input_index, input_size - 1)) * stride; @@ -975,10 +975,10 @@ struct HelperInterpCubic : public HelperInterpBase { int64_t * idx_ptr; scalar_t * wt_ptr; - using accscalar_t = at::acc_type; + using opmath_t = at::opmath_type; for (const auto i : c10::irange(output_size)) { - const accscalar_t real_input_index = - area_pixel_compute_source_index( + const auto real_input_index = + area_pixel_compute_source_index( scale, i, align_corners, /*cubic=*/true); input_index = static_cast(floorf(real_input_index)); get_cubic_upsample_coefficients(coeffs, real_input_index - input_index); diff --git a/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp b/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp index a26cef72bb10c..c73e0249dee82 100644 --- a/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp @@ -441,9 +441,9 @@ void cpu_upsample_linear_backward_channels_last( int64_t input_width = input_sizes[ndim - 1]; int64_t output_width = output_sizes[ndim - 1]; - using accscalar_t = at::acc_type; + using opmath_t = at::opmath_type; using Vec = vec::Vectorized; - auto acc = [](scalar_t* gin, scalar_t* gout, accscalar_t w, int64_t size) { + auto acc = [](scalar_t* gin, scalar_t* gout, opmath_t w, int64_t size) { int64_t d = 0; for (; d < size - (size % Vec::size()); d += Vec::size()) { Vec gin_vec = Vec::loadu(gin + d) + Vec(w) * Vec::loadu(gout + d); diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp index c00b764f08055..7c8b22210e238 100644 --- a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp @@ -789,15 +789,6 @@ void batch_norm_cpu_collect_stats_contiguous_impl( } } -static inline std::tuple, Vectorized> load2f(const BFloat16* ptr) { - return convert_bfloat16_float(Vectorized::loadu(ptr)); -} - -static inline std::tuple, Vectorized> load2f(const float* ptr) { - using Vec = Vectorized; - return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size())); -} - template inline void batch_norm_cpu_collect_stats_channels_last_internal( Tensor& mean, Tensor& var_sum, const Tensor& input) { diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index ff84f9b60784e..6f40e13f3256f 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -52,13 +52,15 @@ void GroupNormKernelImplInternal( const bool beta_null = beta_data == nullptr; const int64_t inner_size = D * HxW; + using T_ACC = vec::vec_scalar_t; + at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) { for (const auto i : c10::irange(start, end)) { const T* X_ptr = X_data + i * inner_size; - T mean_val; - T rstd_val; - std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, inner_size); - rstd_val = T(1) / std::sqrt(std::max(rstd_val, T(0)) + eps); + T_ACC mean_val; + T_ACC rstd_val; + std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, inner_size); + rstd_val = T_ACC(1) / std::sqrt(std::max(rstd_val, T_ACC(0)) + eps); if (gamma_null && beta_null) { T* Y_ptr = Y_data + i * inner_size; for (const auto j : c10::irange(inner_size)) { @@ -68,8 +70,8 @@ void GroupNormKernelImplInternal( const int64_t g = i % G; for (const auto j : c10::irange(D)) { const int64_t c = g * D + j; - const T scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]); - const T bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]); + const T_ACC scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]); + const T_ACC bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]); X_ptr = X_data + (i * D + j) * HxW; T* Y_ptr = Y_data + (i * D + j) * HxW; for (const auto k : c10::irange(HxW)) { diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp index f7104875b8247..5fbbf2597529c 100644 --- a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -22,22 +23,18 @@ namespace native { namespace { -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, const Tensor& beta, int64_t M, int64_t N, - T eps, + T_ACC eps, Tensor* Y, Tensor* mean, Tensor* rstd) { - using T_ACC = vec::vec_scalar_t; - using Vec = vec::Vectorized; - TORCH_DCHECK_EQ(X.numel(), M * N); - DCHECK(!gamma.defined() || gamma.numel() == N); - DCHECK(!beta.defined() || beta.numel() == N); + using Vec = vec::Vectorized; const T* X_data = X.data_ptr(); const T* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.data_ptr() : nullptr; @@ -55,10 +52,10 @@ void LayerNormKernelImplInternal( T* Y_ptr = Y_data + i * N; T mean_val; T rstd_val; - std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, N); + std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N); rstd_val = T(1) / std::sqrt(rstd_val + eps); - const T_ACC scale = rstd_val; - const T_ACC bias = -rstd_val * mean_val; + const T scale = rstd_val; + const T bias = -rstd_val * mean_val; if (gamma_null || beta_null) { for (const auto j : c10::irange(N)) { const T gamma_v = gamma_null ? T(1) : gamma_data[j]; @@ -86,6 +83,94 @@ void LayerNormKernelImplInternal( }); } +template +void layer_norm_kernel_mixed_type( + const Tensor& X, + const Tensor& gamma, + const Tensor& beta, + int64_t M, + int64_t N, + float eps, + Tensor* Y, + Tensor* mean, + Tensor* rstd) { + using bVec = Vectorized; + using fVec = Vectorized; + const BFloat16* X_data = X.data_ptr(); + const param_t* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; + const param_t* beta_data = beta.defined() ? beta.data_ptr() : nullptr; + BFloat16* Y_data = Y->data_ptr(); + param_t* mean_data = mean ? mean->data_ptr() : nullptr; + param_t* rstd_data = rstd ? rstd->data_ptr() : nullptr; + + const bool gamma_null = gamma_data == nullptr; + const bool beta_null = beta_data == nullptr; + const bool mean_null = mean_data == nullptr; + const bool rstd_null = rstd_data == nullptr; + at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) { + for (const auto i : c10::irange(start, end)) { + const BFloat16* X_ptr = X_data + i * N; + BFloat16* Y_ptr = Y_data + i * N; + float mean_val; + float rstd_val; + std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N); + rstd_val = float(1) / std::sqrt(rstd_val + eps); + const float scale = rstd_val; + const float bias = -rstd_val * mean_val; + if (gamma_null || beta_null) { + for (const auto j : c10::irange(N)) { + const param_t gamma_v = gamma_null ? param_t(1) : gamma_data[j]; + const param_t beta_v = beta_null ? param_t(0) : beta_data[j]; + Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v; + } + } else { + int64_t d = 0; + for (; d < N - (N % bVec::size()); d += bVec::size()) { + bVec x_bvec = bVec::loadu(X_ptr + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec); + fVec gamma_fvec0, gamma_fvec1; + std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d); + fVec beta_fvec0, beta_fvec1; + std::tie(beta_fvec0, beta_fvec1) = load2f(beta_data + d); + fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0; + fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1; + bVec y_bvec = convert_float_bfloat16(y_fvec0, y_fvec1); + y_bvec.store(Y_ptr + d); + } + for (; d < N; d++) { + Y_ptr[d] = (X_ptr[d] * scale + bias) * gamma_data[d] + beta_data[d]; + } + } + if (!mean_null) { + mean_data[i] = mean_val; + } + if (!rstd_null) { + rstd_data[i] = rstd_val; + } + } + }); +} + +template <> +void LayerNormKernelImplInternal( + const Tensor& X, + const Tensor& gamma, + const Tensor& beta, + int64_t M, + int64_t N, + float eps, + Tensor* Y, + Tensor* mean, + Tensor* rstd) { + const bool mixed_type = is_mixed_type(X, gamma, beta); + if (mixed_type) { + layer_norm_kernel_mixed_type(X, gamma, beta, M, N, eps, Y, mean, rstd); + } else { + layer_norm_kernel_mixed_type(X, gamma, beta, M, N, eps, Y, mean, rstd); + } +} + void LayerNormKernelImpl( const Tensor& X, const Tensor& gamma, @@ -96,10 +181,14 @@ void LayerNormKernelImpl( Tensor* Y, Tensor* mean, Tensor* rstd) { + TORCH_DCHECK_EQ(X.numel(), M * N); + DCHECK(!gamma.defined() || gamma.numel() == N); + DCHECK(!beta.defined() || beta.numel() == N); AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(), "LayerNormKernelImpl", [&]() { - LayerNormKernelImplInternal( - X, gamma, beta, M, N, static_cast(eps), Y, mean, rstd); + using acc_t = vec::vec_scalar_t; + LayerNormKernelImplInternal( + X, gamma, beta, M, N, static_cast(eps), Y, mean, rstd); }); } diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h index 18e6899619046..8afd3612abb64 100644 --- a/aten/src/ATen/native/cpu/moments_utils.h +++ b/aten/src/ATen/native/cpu/moments_utils.h @@ -14,7 +14,9 @@ namespace at { namespace native { -namespace utils { +inline namespace CPU_CAPABILITY { + +template using acc_t = vec::vec_scalar_t; constexpr int64_t kChunkSize = 16; @@ -52,20 +54,71 @@ C10_ALWAYS_INLINE void AddMomentsVec( m0 = n; } +template +inline void UpdateMomentsVec( + int64_t m0, + const T* X_ptr, + const std::array>, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized>& m1_stk0, + vec::Vectorized>& m2_stk0) { + using Vec = vec::Vectorized>; + Vec m1_vec(0); + Vec m2_vec(0); + for (const auto j : c10::irange(m0)) { + const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec delta_vec = x_vec - m1_vec; + m1_vec += delta_vec * c_vecs[j]; + m2_vec += delta_vec * (x_vec - m1_vec); + } + AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); +} + +// each bfloat16 vector will be converted to two float vectors, +// and accumulated successively on m1_stk0/m2_stk0. +template <> +inline void UpdateMomentsVec( + int64_t m0, + const BFloat16* X_ptr, + const std::array, kChunkSize>& c_vecs, + int64_t& m0_stk0, + vec::Vectorized& m1_stk0, + vec::Vectorized& m2_stk0) { + using bVec = vec::Vectorized; + using fVec = vec::Vectorized; + fVec m1_fvec0(0), m1_fvec1(0); + fVec m2_fvec0(0), m2_fvec1(0); + for (const auto j : c10::irange(m0)) { + const bVec x_bvec = bVec::loadu(X_ptr + j * bVec::size()); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec); + const fVec delta_fvec0 = x_fvec0 - m1_fvec0; + const fVec delta_fvec1 = x_fvec1 - m1_fvec1; + m1_fvec0 += delta_fvec0 * c_vecs[j]; + m1_fvec1 += delta_fvec1 * c_vecs[j]; + m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); + m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + } + AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); + AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); +} + // Compute rowwise moments by Welford algorithm and cascade sum to improve // numerical stability. // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance // https://en.wikipedia.org/wiki/Pairwise_summation template -std::pair RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { - using Vec = vec::Vectorized; +std::pair, acc_t> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { + using T_ACC = acc_t; - constexpr int64_t kVecSize = Vec::size(); + constexpr int64_t kVecSize = vec::Vectorized::size(); + constexpr int64_t kAccVecSize = vec::Vectorized::size(); const int64_t n = N / kVecSize; const int64_t m = divup(n, kChunkSize); - const int64_t depth = CeilLog2(m); + const int64_t depth = utils::CeilLog2(m); - const Vec kZeroVec(T(0)); + using Vec = vec::Vectorized; + const Vec kZeroVec(T_ACC(0)); c10::SmallVector m0_stk(depth, 0); c10::SmallVector m1_stk(depth, kZeroVec); c10::SmallVector m2_stk(depth, kZeroVec); @@ -76,19 +129,12 @@ std::pair RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { static std::array c_vecs = ([]() { std::array result; for (const auto i : c10::irange(kChunkSize)) { - result[i] = Vec(T(1) / static_cast(i + 1)); + result[i] = Vec(T_ACC(1) / static_cast(i + 1)); } return result; })(); - Vec m1_vec(0); - Vec m2_vec(0); - for (const auto j : c10::irange(m0)) { - const Vec x_vec = Vec::loadu(X_ptr + j * kVecSize); - const Vec delta_vec = x_vec - m1_vec; - m1_vec += delta_vec * c_vecs[j]; - m2_vec += delta_vec * (x_vec - m1_vec); - } - AddMomentsVec(m0, m1_vec, m2_vec, m0_stk[0], m1_stk[0], m2_stk[0]); + UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]); + int64_t mask = i + 1; for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) { AddMomentsVec( @@ -109,34 +155,37 @@ std::pair RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) { m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]); } - std::array m1_arr{}; - std::array m2_arr{}; + std::array m1_arr{}; + std::array m2_arr{}; m1_stk[0].store(m1_arr.data()); m2_stk[0].store(m2_arr.data()); int64_t m0 = 0; - T m1 = 0; - T m2 = 0; + T_ACC m1 = 0; + T_ACC m2 = 0; for (int64_t i = n * kVecSize; i < N; ++i) { - const T delta = X[i] - m1; + T_ACC x = static_cast(X[i]); + const T_ACC delta = x - m1; ++m0; - m1 += delta / static_cast(m0); - m2 += delta * (X[i] - m1); + m1 += delta / static_cast(m0); + m2 += delta * (x - m1); } - for (const auto i : c10::irange(kVecSize)) { - AddMoments(n, m1_arr[i], m2_arr[i], m0, m1, m2); + // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result + int64_t m0_add = n * kVecSize / kAccVecSize; + for (const auto i : c10::irange(kAccVecSize)) { + AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2); } - return std::make_pair(m1, m2 / static_cast(N - ddof)); + return std::make_pair(m1, m2 / static_cast(N - ddof)); } template -std::pair RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { +std::pair, acc_t> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { using Vec = vec::Vectorized; constexpr int64_t kVecSize = Vec::size(); const int64_t n = N / kVecSize; const int64_t m = divup(n, kChunkSize); - const int64_t depth = CeilLog2(m); + const int64_t depth = utils::CeilLog2(m); if (depth <= 4) { return RowwiseMomentsImpl(X, N, ddof); } else if (depth <= 8) { @@ -150,6 +199,6 @@ std::pair RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) { } } -} // namespace utils +} // namespace CPU_CAPABILITY } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index 5c607f06b3a5a..1fd30475e9ff4 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -61,6 +61,16 @@ template struct VectorizedType { using type = Vectorized struct VectorizedType { using type = Vec2; }; template using VecType = typename VectorizedType::type; +// Helper for mixed data type parameter Vec::load +inline std::tuple, Vectorized> load2f(const BFloat16* ptr) { + return convert_bfloat16_float(Vectorized::loadu(ptr)); +} + +inline std::tuple, Vectorized> load2f(const float* ptr) { + using Vec = Vectorized; + return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size())); +} + } // namespace namespace utils { diff --git a/aten/src/ATen/native/cuda/Activation.cpp b/aten/src/ATen/native/cuda/Activation.cpp index 4360f8b5c3efc..31926b353b4a3 100644 --- a/aten/src/ATen/native/cuda/Activation.cpp +++ b/aten/src/ATen/native/cuda/Activation.cpp @@ -114,7 +114,7 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) { Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); TORCH_CHECK(weight_dim == 0 || weight_dim == 1, - "prelu: Expected `weight` to be a scalar or 1D tensor, but got ndim = ", + "prelu: Expected `weight` to be a scalar or 1D tensor, but got: ndim = ", weight_dim); // case1: shared weight for all channels diff --git a/aten/src/ATen/native/cuda/Bucketization.cu b/aten/src/ATen/native/cuda/Bucketization.cu index 2a3d5730d7860..21c582216628e 100644 --- a/aten/src/ATen/native/cuda/Bucketization.cu +++ b/aten/src/ATen/native/cuda/Bucketization.cu @@ -10,7 +10,6 @@ #include #include #else -#include #include #include #include @@ -191,11 +190,6 @@ Tensor searchsorted_cuda( return result; } -// See [Note about _torch_cuda_cu_linker_symbol_op and torch_cuda_cu] in native_functions.yaml -Tensor _torch_cuda_cu_linker_symbol_op_cuda(const Tensor& self) { - return self; -} - Tensor searchsorted_cuda( const Tensor& sorted_sequence, const Scalar& self, diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index 98d1950004ef2..53eb2df3013eb 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -101,7 +101,6 @@ void col2im_out_cuda_template( int64_t input_batch_stride = input.stride(0); output.resize_({batch_size, n_output_plane, output_height, output_width}); - output.zero_(); int64_t output_batch_stride = output.stride(0); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, diff --git a/aten/src/ATen/native/cuda/DepthwiseConv2d.cu b/aten/src/ATen/native/cuda/DepthwiseConv2d.cu index 8f0f9b99903a7..20748837bbaf7 100644 --- a/aten/src/ATen/native/cuda/DepthwiseConv2d.cu +++ b/aten/src/ATen/native/cuda/DepthwiseConv2d.cu @@ -236,7 +236,6 @@ __global__ void conv_depthwise2d_grad_weight_kernel( } } } - __syncthreads(); // At this point each thread in the block has a local gradient, which we need to // accumulate prior to writing the global value diff --git a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu index 3b04b68b0f391..27b3d77ad4d6c 100644 --- a/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu +++ b/aten/src/ATen/native/cuda/ForeachPointwiseOp.cu @@ -160,10 +160,45 @@ void foreach_tensor_##NAME##_scalarlist_cuda_(TensorList input, TensorList tenso foreach_pointwise_op_(input, tensors1, tensors2, scalars); \ } +#define FOREACH_POINTWISE_OP_TENSOR(NAME, OP) \ + std::vector foreach_tensor_##NAME##_tensor_cuda( \ + TensorList input, \ + TensorList tensors1, \ + TensorList tensors2, \ + const Tensor& scalars_) { \ + auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + if (!can_use_fast_route({input, tensors1, tensors2}) || \ + has_integral_tensor(input, /* includeBool */ true)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_slow( \ + input, tensors1, tensors2, scalars); \ + } \ + \ + return foreach_pointwise_op(input, tensors1, tensors2, scalars); \ + } \ + \ + void foreach_tensor_##NAME##_tensor_cuda_( \ + TensorList input, \ + TensorList tensors1, \ + TensorList tensors2, \ + const Tensor& scalars_) { \ + auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \ + check_foreach_api_restrictions(input, tensors1, tensors2, scalars); \ + if (!can_use_fast_route({input, tensors1, tensors2}, scalars) || \ + has_integral_tensor(input, /* includeBool */ true)) { \ + return at::native::foreach_tensor_##NAME##_scalarlist_slow_( \ + input, tensors1, tensors2, scalars); \ + } \ + \ + foreach_pointwise_op_(input, tensors1, tensors2, scalars); \ + } + FOREACH_POINTWISE_OP_SCALAR(addcmul, std::multiplies); FOREACH_POINTWISE_OP_SCALAR(addcdiv, std::divides); FOREACH_POINTWISE_OP_SCALARLIST(addcmul, std::multiplies); FOREACH_POINTWISE_OP_SCALARLIST(addcdiv, std::divides); +FOREACH_POINTWISE_OP_TENSOR(addcdiv, std::divides); +FOREACH_POINTWISE_OP_TENSOR(addcmul, std::multiplies); // Why bool tensors are pushed to slowpath? diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index 09a29e0c62db3..29b2a07a82441 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -317,7 +317,7 @@ void foreach_tensor_zero_cuda_(TensorList tensors) { std::vector> tensor_lists; tensor_lists.emplace_back(tensors.vec()); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, tensors[0].scalar_type(), "foreach_zero_cuda_", [&]() { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, tensors[0].scalar_type(), "foreach_zero_cuda_", [&]() { multi_tensor_apply<1>(tensor_lists, ZeroFunctor #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -139,6 +140,8 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_out_cuda) ( const Tensor& output, const Tensor& indices ) { + fractional_max_pool_check_shape(input, randomSamples); + int planeDim = 0; int dimh = 1; int dimw = 2; diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu index 92a77dc00af53..971905d291065 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -258,6 +259,7 @@ TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) ( int64_t inputW, const Tensor& output, const Tensor& indices) { + fractional_max_pool_check_shape(input, randomSamples); auto output_ = output; auto indices_ = indices; diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index a18d4d822c659..a209aa2764639 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -102,7 +102,6 @@ static void im2col_out_cuda_template( int64_t output_length = output_height * output_width; output.resize_({batch_size, n_output_plane, output_length}); - output.zero_(); // Launch kernel AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 1e36e2db74d54..d2e956d1a3e44 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -1,6 +1,10 @@ #pragma once #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + namespace at { namespace native { @@ -66,7 +70,49 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd( template < typename scalar_t, typename index_t, - typename std::enable_if::value>::type* = + typename std::enable_if::value>::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM)) || \ + (defined(CUDA_VERSION) && (CUDA_VERSION < 11000)) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if (low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = __int2bfloat16_rz(0); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = __int2bfloat16_rz(0); + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value && !std::is_same::value >::type* = nullptr> __device__ __forceinline__ void fastSpecializedAtomicAdd( scalar_t* tensor, diff --git a/aten/src/ATen/native/cuda/Lerp.cu b/aten/src/ATen/native/cuda/Lerp.cu index c1adb5b6fc030..697b61aa7866c 100644 --- a/aten/src/ATen/native/cuda/Lerp.cu +++ b/aten/src/ATen/native/cuda/Lerp.cu @@ -3,16 +3,54 @@ #include #include #include +#include #include namespace at { namespace native { namespace { +const char lerp_tensor_name[] = "lerp_tensor"; void lerp_tensor_kernel(at::TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + auto dtype = iter.common_dtype(); + if(at::isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto lerp_tensor_string = jiterator_stringify( + template + T lerp_tensor(T self_val, T end_val, T weight_val) { + return (std::abs(weight_val) < 0.5) + ? self_val + weight_val * (end_val - self_val) + : end_val - + (end_val - self_val) * (static_cast(1) - weight_val); + } + ); // lerp_tensor_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] { + jitted_gpu_kernel< + /*name=*/ lerp_tensor_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 3>(iter, lerp_tensor_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] { + using opmath_t = at::opmath_type; + at::native::gpu_kernel( + iter, + [] GPU_LAMBDA( + scalar_t self_val, + scalar_t end_val, + scalar_t weight_val) -> scalar_t { + opmath_t self_val_f = self_val; + opmath_t end_val_f = end_val; + opmath_t weight_val_f = weight_val; + return lerp(self_val, end_val, weight_val); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, - iter.common_dtype(), "lerp_cuda", + dtype, "lerp_cuda", [&] { at::native::gpu_kernel( iter, @@ -23,12 +61,54 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) { return lerp(self_val, end_val, weight_val); }); }); + } } +const char lerp_scalar_name[] = "lerp_scalar"; void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto lerp_scalar_string = jiterator_stringify( + template + T lerp_scalar(T self_val, T end_val, T weight_val) { + return (std::abs(weight_val) < 0.5) + ? self_val + weight_val * (end_val - self_val) + : end_val - + (end_val - self_val) * (static_cast(1) - weight_val); + } + ); // lerp_scalar_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] { + using opmath_t = at::opmath_type; + auto weight_val = weight.to(); + jitted_gpu_kernel< + /*name=*/ lerp_scalar_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>( + iter, + lerp_scalar_string, + /*scalar_pos=*/ at::cuda::jit::BinaryFuncVariant::NoScalar, + /*scalar_val=*/ 0, + /*extra_args=*/ std::make_tuple(weight_val)); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_cuda", [&] { + using opmath_t = at::opmath_type; + auto weight_val = weight.to(); + at::native::gpu_kernel( + iter, + [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) { + opmath_t self_val_f = self_val; + opmath_t end_val_f = end_val; + return lerp(self_val, end_val, weight_val); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, - iter.common_dtype(), "lerp_cuda", + dtype, "lerp_cuda", [&]{ using opmath_t = at::opmath_type; auto weight_val = weight.to(); @@ -38,6 +118,8 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) }); }); } +} + } // anonymous namespace REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_kernel); diff --git a/aten/src/ATen/native/cuda/MultiMarginLoss.cu b/aten/src/ATen/native/cuda/MultiMarginLoss.cu index 15e6d1e9dc0c3..26f21cfa59a22 100644 --- a/aten/src/ATen/native/cuda/MultiMarginLoss.cu +++ b/aten/src/ATen/native/cuda/MultiMarginLoss.cu @@ -31,6 +31,7 @@ __global__ void MultiMarginLoss_forward_kernel( scalar_t *input_k = input + k*dim; scalar_t *output_k = output + k; int target_k = static_cast(target[k]); + CUDA_KERNEL_ASSERT(target_k >= 0 && target_k < dim && "target index is out of bounds"); scalar_t input_target_k = input_k[target_k]; int i_start = threadIdx.x; diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu index de8e8404ac2dd..c8473245604c0 100644 --- a/aten/src/ATen/native/cuda/MultinomialKernel.cu +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -80,7 +80,7 @@ void renormRows(Tensor& t) { int64_t cols = t.size(1); auto props = at::cuda::getCurrentDeviceProperties(); - CUDA_KERNEL_ASSERT(props != NULL); + TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; const int64_t maxThreads = std::min( props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads); @@ -342,7 +342,7 @@ void multinomial_with_replacement_kernel_impl( AT_DISPATCH_FLOATING_TYPES_AND_HALF(self_v.scalar_type(), "multinomial_kernel_cuda", [&] { using accscalar_t = at::acc_type; auto props = at::cuda::getCurrentDeviceProperties(); - CUDA_KERNEL_ASSERT(props != NULL); + TORCH_CHECK(props != nullptr); int numSM = props->multiProcessorCount; int maxThreads = props->maxThreadsPerBlock; int maxShared = props->sharedMemPerBlock; diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 3b27ebfc7d922..a8eff154c3505 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -48,8 +48,11 @@ bool is_mixed_type(const Tensor& input, const Args&... parameters) { } inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) { - return (self.is_contiguous(at::MemoryFormat::ChannelsLast) || - (self.is_contiguous() && self.strides()[1] == 1)); + return ( + self.is_contiguous(at::MemoryFormat::ChannelsLast) || + self.is_contiguous(at::MemoryFormat::ChannelsLast3d) || + (self.is_contiguous() && self.strides()[1] == 1) + ); } enum class Impl { @@ -470,6 +473,22 @@ std::tuple batch_norm_cuda(const Tensor& self, const c10 return std::make_tuple(output, save_mean, save_invstd); } +std::tuple _batch_norm_legit_cuda(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon) { + return batch_norm_cuda(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon); +} + +std::tuple _batch_norm_legit_no_stats_cuda(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, bool train, double momentum, double epsilon) { + return batch_norm_cuda(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon); +} + +std::tuple _batch_norm_legit_cuda_out(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) { + return batch_norm_cuda_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_invstd); +} + +std::tuple _batch_norm_legit_no_stats_cuda_out(const Tensor& self, const c10::optional& weight_opt, const c10::optional& bias_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) { + return batch_norm_cuda_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_invstd); +} + std::tuple batch_norm_backward_cuda(const Tensor& grad_out, const Tensor& input, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, double epsilon, std::array grad_input_mask) { // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight = at::borrow_from_optional_tensor(weight_opt); diff --git a/aten/src/ATen/native/cuda/TriangularOps.cu b/aten/src/ATen/native/cuda/TriangularOps.cu index f87d821f396ce..a079ec6849888 100644 --- a/aten/src/ATen/native/cuda/TriangularOps.cu +++ b/aten/src/ATen/native/cuda/TriangularOps.cu @@ -102,137 +102,9 @@ TORCH_IMPL_FUNC(triu_cuda)(const Tensor& self, int64_t k, const Tensor &result) } } -// Copy the kth diagonal of a matrix B to a vector A. -template -C10_LAUNCH_BOUNDS_1(1024) -__global__ void copy_from_diagonal_kernel( - scalar_t* a, - scalar_t* b, - std::ptrdiff_t start, - std::ptrdiff_t size, - std::ptrdiff_t strideSum, - std::ptrdiff_t strideA) { - for (std::ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < size; - linearIndex += gridDim.x * blockDim.x) { - const std::ptrdiff_t bOffset = start + strideSum * linearIndex; - a[strideA * linearIndex] = b[bOffset]; - } -} - -// Copy vector B to the kth diagonal of a matrix A -template -C10_LAUNCH_BOUNDS_1(1024) -__global__ void copy_to_diagonal_kernel( - scalar_t* a, - scalar_t* b, - std::ptrdiff_t start, - std::ptrdiff_t size, - std::ptrdiff_t strideSum, - std::ptrdiff_t strideB) { - for (std::ptrdiff_t linearIndex = blockIdx.x * blockDim.x + threadIdx.x; - linearIndex < size; - linearIndex += gridDim.x * blockDim.x) { - const std::ptrdiff_t aOffset = start + strideSum * linearIndex; - a[aOffset] = b[strideB * linearIndex]; - } -} - -template -Tensor& apply_diag(Tensor& result, const Tensor& self, int64_t dimension) { - TORCH_CHECK( - self.dim() == 1 || self.dim() == 2, "matrix or a vector expected"); - - TensorArg result_arg{result, "result", 1}; - TensorArg self_arg{self, "self", 2}; - checkAllSameGPU(__func__, {result_arg, self_arg}); - checkSameType(__func__, result_arg, self_arg); - - int nDimension = self.dim(); - if (nDimension == 2) { - auto self_stride_0 = self.stride(0); - auto self_stride_1 = self.stride(1); - - int sz; - if (dimension > 0) { - sz = std::min(self.size(0), self.size(1) - dimension); - } else { - sz = std::min(self.size(0) + dimension, self.size(1)); - } - - at::native::resize_output(result, {sz}); - if (sz > 0) { - at::assert_no_internal_overlap(result); - auto result_stride = result.stride(0); - const dim3 threads(std::min( - int(sz), - int(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock))); - const dim3 grid( - std::min(int(1024), ceil_div(int(sz), int(threads.x)))); - auto start = - (dimension >= 0 ? dimension * self_stride_1 - : -dimension * self_stride_0); - - // Kernel Launch - copy_from_diagonal_kernel - <<>>( - result.data_ptr(), - self.data_ptr(), - start, - sz, - self_stride_0 + self_stride_1, - result_stride); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } else { - auto n_elems = self.numel(); - auto sz = (dimension > 0) ? n_elems + dimension : n_elems - dimension; - auto self_stride = self.stride(0); - at::native::resize_output(result, {sz, sz}); - result.zero_(); - if (sz > 0) { - at::assert_no_internal_overlap(result); - auto result_stride_0 = result.stride(0); - auto result_stride_1 = result.stride(1); - const dim3 threads(std::min( - int(sz), at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock)); - const dim3 grid( - std::min(int(1024), ceil_div(int(sz), int(threads.x)))); - auto start = - (dimension >= 0 ? dimension * result_stride_1 - : -dimension * result_stride_0); - - // Kernel Launch - copy_to_diagonal_kernel - <<>>( - result.data_ptr(), - self.data_ptr(), - start, - n_elems, - result_stride_0 + result_stride_1, - self_stride); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } - - return result; -} - -Tensor& diag_cuda_out(const Tensor& self, int64_t dimension, Tensor& result) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( - kComplexHalf, ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, - self.scalar_type(), "diag_cuda", - [&] { - apply_diag(result, self, dimension); - }); - return result; -} - Tensor trace_cuda(const Tensor& self) { TORCH_CHECK(self.dim() == 2, "expected a matrix"); - int dimension = 0; - auto result = at::diag(self, dimension); - return result.sum(); + return self.diagonal().sum(); } } // namespace native diff --git a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu index 87aa784b7d5d3..ae4d4a01aa00d 100644 --- a/aten/src/ATen/native/cuda/UnaryFractionKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryFractionKernels.cu @@ -122,7 +122,7 @@ __host__ __device__ static inline c10::complex nearbyint_wrapper(c10::com } #pragma push -#pragma diag_suppress 177 // Function was declared but never referenced +#pragma nv_diag_suppress 177 // Function was declared but never referenced __host__ __device__ static inline c10::complex nearbyint_wrapper(c10::complex a) { return c10::complex(::nearbyint(static_cast(a.real())), ::nearbyint(static_cast(a.imag()))); } diff --git a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu index 90f5238d0180d..d75de2a6e90fb 100644 --- a/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/UnfoldBackwardKernel.cu @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include @@ -57,8 +58,7 @@ void _unfold_backward_internal_kernel( int64_t grad_in_dim_stride, int64_t grad_in_last_dim_stride, int64_t grad_in_dim_size, - int64_t grad_out_dim_stride, - bool is_step_ge_size + int64_t grad_out_dim_stride ) { if (iter.numel() == 0) { return; @@ -73,8 +73,7 @@ void _unfold_backward_internal_kernel( grad_in_dim_stride, grad_in_last_dim_stride, grad_in_dim_size, - grad_out_dim_stride, - is_step_ge_size + grad_out_dim_stride ); } return; @@ -84,63 +83,39 @@ void _unfold_backward_internal_kernel( char* __restrict__ grad_in_ptr = reinterpret_cast(iter.data_ptr(1)); char* __restrict__ idx_dim_ptr = reinterpret_cast(iter.data_ptr(2)); - if (is_step_ge_size) { - char* __restrict__ idx_last_dim_ptr = reinterpret_cast(iter.data_ptr(3)); + auto offset_calc = make_offset_calculator<3>(iter); - auto offset_calc = make_offset_calculator<4>(iter); + // The algorithm is: for each index in grad_out find + // the elements contributing to it and sum them up. + // Note: the algorithm does not require any synchronization. + auto loop = [=]C10_DEVICE(int i) { + auto offsets = offset_calc.get(i); - // this loop simply copies the data - // from proper places in grad_out to grad_in - auto loop = [=]C10_DEVICE(int i) { - auto offsets = offset_calc.get(i); + auto* __restrict__ grad_out_data = reinterpret_cast(grad_out_ptr + offsets[0]); + auto* __restrict__ grad_in_data = reinterpret_cast(grad_in_ptr + offsets[1]); - auto* __restrict__ grad_out_data = reinterpret_cast(grad_out_ptr + offsets[0]); - auto* __restrict__ grad_in_data = reinterpret_cast(grad_in_ptr + offsets[1]); + auto idx_dim = *reinterpret_cast(idx_dim_ptr + offsets[2]); - auto idx_dim = *reinterpret_cast(idx_dim_ptr + offsets[2]); - auto idx_last_dim = *reinterpret_cast(idx_last_dim_ptr + offsets[3]); - - auto grad_out_idx_dim = idx_dim * step + idx_last_dim; - grad_out_data[grad_out_idx_dim * grad_out_dim_stride] = *grad_in_data; - }; - - _launch_unfold_backward_kernel(iter.numel(), loop); - } - else { - auto offset_calc = make_offset_calculator<3>(iter); - - // The algorithm is: for each index in grad_out find - // the elements contributing to it and sum them up. - // Note: the algorithm does not require any synchronization. - auto loop = [=]C10_DEVICE(int i) { - auto offsets = offset_calc.get(i); - - auto* __restrict__ grad_out_data = reinterpret_cast(grad_out_ptr + offsets[0]); - auto* __restrict__ grad_in_data = reinterpret_cast(grad_in_ptr + offsets[1]); - - auto idx_dim = *reinterpret_cast(idx_dim_ptr + offsets[2]); - - // left_fold potentially intersecting with idx_dim - // is either (idx_dim - size) / step or the next integer. - int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0; - if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) { - ++left_fold_idx; - } + // left_fold potentially intersecting with idx_dim + // is either (idx_dim - size) / step or the next integer. + int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0; + if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) { + ++left_fold_idx; + } - auto right_fold_idx = idx_dim / step; - right_fold_idx = (right_fold_idx >= grad_in_dim_size) ? - (grad_in_dim_size - 1) : right_fold_idx; + auto right_fold_idx = idx_dim / step; + right_fold_idx = (right_fold_idx >= grad_in_dim_size) ? + (grad_in_dim_size - 1) : right_fold_idx; - for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) { - auto idx_last_dim = idx_dim - fold_idx * step; - *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride - + idx_last_dim * grad_in_last_dim_stride]; - } + for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) { + auto idx_last_dim = idx_dim - fold_idx * step; + *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride + + idx_last_dim * grad_in_last_dim_stride]; + } - }; + }; - _launch_unfold_backward_kernel(iter.numel(), loop); - } + _launch_unfold_backward_kernel(iter.numel(), loop); } void unfold_backward_cuda_kernel( @@ -160,16 +135,8 @@ void unfold_backward_cuda_kernel( auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim); - auto is_step_ge_size = (step >= size); - - TensorIterator iter = - is_step_ge_size ? - _make_unfold_backward_iter_over_grad_in( - grad_out, grad_in, dim, size, step - ) : - _make_unfold_backward_iter_over_grad_out( - grad_out, grad_in, dim, size, step - ); + TensorIterator iter = _make_unfold_backward_iter_over_grad_out( + grad_out, grad_in, dim, size, step); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, @@ -182,8 +149,7 @@ void unfold_backward_cuda_kernel( grad_in_dim_stride, grad_in_last_dim_stride, grad_in_dim_size, - grad_out_dim_stride, - is_step_ge_size + grad_out_dim_stride ); } ); diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index 8aa4f68aeda64..f223655daca15 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -94,13 +94,13 @@ __global__ void upsample_nearest2d_nhwc_out_frame( float width_scale, const size_t out_numel) { - const int index = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t index = blockIdx.x * blockDim.x + threadIdx.x; if (index < out_numel) { - const int c = index % channels; - const int w2 = (index / channels) % width2; - const int h2 = (index / channels / width2) % height2; - const int n = index / channels / width2 / height2; + const auto c = index % channels; + const auto w2 = (index / channels) % width2; + const auto h2 = (index / channels / width2) % height2; + const auto n = index / channels / width2 / height2; const size_t h1 = height1 == height2 ? h2 : nn_compute_source_index_fn(height_scale, h2, height1); const size_t w1 = width1 == width2 ? w2 : nn_compute_source_index_fn(width_scale, w2, width1); @@ -240,13 +240,13 @@ static void upsample_nearest2d_out_cuda_template( output.is_contiguous(memory_format)) { at::Tensor input = input_.contiguous(at::MemoryFormat::ChannelsLast); - TORCH_CHECK(input.numel() < std::numeric_limits::max(), - "upsample_nearest_nhwc only supports input tensors with less than INT_MAX elements"); - TORCH_CHECK(output.numel() < std::numeric_limits::max(), - "upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements"); + TORCH_CHECK(input.numel() < std::numeric_limits::max(), + "upsample_nearest_nhwc only supports input tensors with less than 2^63 - 1 elements"); + TORCH_CHECK(output.numel() < std::numeric_limits::max(), + "upsample_nearest_nhwc only supports output tensors with less than 2^63 - 1 elements"); - const int num_kernels = output.numel(); - const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); + const int64_t num_kernels = output.numel(); + const int64_t num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::Byte, input.scalar_type(), "upsample_nearest2d_nhwc_out_frame", [&] { const scalar_t* idata = input.data_ptr(); diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index 1a4afa012d780..58f14ad491a69 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -337,52 +337,5 @@ TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cuda) ( using at::native::upsample::compute_output_size; using at::native::upsample_cuda::get_scale_value; -Tensor upsample_nearest3d_cuda( - const Tensor& input, - at::OptionalIntArrayRef output_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::upsample_nearest3d(input, osize, scale_d, scale_h, scale_w); -} - -Tensor _upsample_nearest_exact3d_cuda( - const Tensor& input, - at::OptionalIntArrayRef output_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::_upsample_nearest_exact3d(input, osize, scale_d, scale_h, scale_w); -} - -// when structured kernels can handle QuantizedCPU, update these overloads to be CompositeExplicitAutograd -Tensor upsample_nearest3d_backward_cuda( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::upsample_nearest3d_backward(grad_output, osize, input_size, scale_d, scale_h, scale_w); -} - -Tensor _upsample_nearest_exact3d_backward_cuda( - const Tensor& grad_output, - at::OptionalIntArrayRef output_size, - IntArrayRef input_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input_size, output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return at::_upsample_nearest_exact3d_backward(grad_output, osize, input_size, scale_d, scale_h, scale_w); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index b292d488708bf..870e980bb69ee 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -893,6 +893,8 @@ void codegenOutputQuery( max_dev_version = CUDAVersion(7, 5); } else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0 max_dev_version = CUDAVersion(8, 0); + } else if (nvrtc_major == 11 && nvrtc_minor < 8) { + max_dev_version = CUDAVersion(8, 6); } else { // If the driver version is unknown (i.e. newer than this code) // assume the driver supports this device @@ -1530,7 +1532,7 @@ NvrtcFunction jit_pwise_function( &program, code.c_str(), nullptr, 0, nullptr, nullptr)); #ifdef USE_ROCM - std::vector args = {"--std=c++14"}; + std::vector args = {"--std=c++17"}; #else // Constructs nvrtc build arguments // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_) @@ -1545,7 +1547,7 @@ NvrtcFunction jit_pwise_function( std::to_string(cuda_minor); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector args = { - "--std=c++14", compute.c_str(), "-default-device"}; + "--std=c++17", compute.c_str(), "-default-device"}; #endif #ifndef NDEBUG diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index ae09f0aaad8f8..3fb041c61d454 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -25,6 +25,8 @@ #endif #include +#include + namespace at { namespace native { @@ -33,6 +35,7 @@ namespace { constexpr int kCUDANumThreads = 256; constexpr int kColwiseReduceTileSize = 32; +constexpr unsigned int kWarpSize = 32; constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh) @@ -555,8 +558,108 @@ __global__ void GammaBetaBackwardCUDAKernel1( } } +template +__global__ void GammaBetaBackwardCUDAKernel_32x32( + int64_t M, + int64_t N, + const T* dY, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + T* dg, + T* db) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; + T_ACC dg_sum = 0; + T_ACC db_sum = 0; + + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + + if (j < N) { + constexpr int unroll_factor = 8; + int laneId = threadIdx.x & 0x1f; + + T_ACC mean_reg, mean_reg_tmp; + T_ACC rstd_reg, rstd_reg_tmp; + T dY_reg; + T X_reg; + // Main loop + int bcounter; + for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); + bcounter++) { + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + + if (laneId < unroll_factor) { + mean_reg_tmp = mean[offset + laneId]; + rstd_reg_tmp = rstd[offset + laneId]; + } +#if !defined(USE_ROCM) + // Volta and newer architectures allow lane divergence within a warp. + __syncwarp(); +#endif + + #pragma unroll + for (int ii = 0; ii < unroll_factor; ++ii) { + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); + rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } + } + + // Remainder loop + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + for (int ii = 0; ii < unroll_factor; ii++) { + if ((offset + ii) < M) { + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } + } + + // This kernel uses a block of (32 x 32) and gets called when M; N + // divide by 32. We can use warp shuffles for the final reduction + // step. This removes 4 shmem loads and stores with their + // corresponding __syncthreads() + + // This greatly reduces bank conflicts at the expense of a little + // extra shared memory. It does not impact occupancy + int padded_bx = (1 + blockDim.x); + + s_dg = s_data_typed; + s_db = s_data_typed + (padded_bx * blockDim.y); + s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; + s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; + __syncthreads(); + + // Load transposed so that a warp holds an entire column + T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; + T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; + for (int delta = 16; delta >= 1; delta /= 2) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + + if (threadIdx.x == 0) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + if (dg) { + dg[j] = reg_dg; + } + if (db) { + db[j] = reg_db; + } + } + } +} template __global__ void GammaBetaBackwardCUDAKernel( @@ -569,66 +672,75 @@ __global__ void GammaBetaBackwardCUDAKernel( T* dg, T* db) { alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC * s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; - constexpr int unroll = 8; - T dYs[unroll]; - T Xs[unroll]; - T_ACC * means = s_data_typed; - T_ACC * rstds = s_data_typed + unroll * blockDim.y; + T_ACC dg_sum = 0; T_ACC db_sum = 0; + if (j < N) { + constexpr int unroll_factor = 8; + + T_ACC mean_reg; + T_ACC rstd_reg; + T dY_reg; + T X_reg; + + // Main Loop int bcounter; - for (bcounter = 0; bcounter < M/(blockDim.y * unroll); bcounter++){ - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; - #pragma unroll - for (int ii=0; ii=1; offset /= 2){ + + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { if (threadIdx.y < offset) { - s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] += - s_data_typed[blockDim.x * blockDim.y + (threadIdx.y + offset) * blockDim.x + threadIdx.x]; - } + s_dg[threadIdx.y * blockDim.x + threadIdx.x] += + s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + s_db[threadIdx.y * blockDim.x + threadIdx.x] += + s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + } __syncthreads(); } + if (threadIdx.y == 0) { if (dg) { - dg[j] = s_data_typed[threadIdx.x]; + dg[j] = s_dg[threadIdx.x]; } if (db) { - db[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y]; + db[j] = s_db[threadIdx.x]; } } } @@ -722,6 +834,305 @@ void LayerNormKernelImpl( }); } +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + T_ACC* warp_buf1, + T_ACC* warp_buf2, + const T* input, + const T* dout, + const int i1_end, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + T curr_mean = mean[i1]; + T curr_rstd = rstd[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*N+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + T curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd; + } else { + warp_buf1[write_idx] = T(0); + warp_buf2[write_idx] = T(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = T(0); + warp_buf2[write_idx] = T(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + T_ACC* warp_buf1, + T_ACC* warp_buf2, + const T* input, + const T* dout, + const int i1_end, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + T_ACC curr_mean = mean[i1]; + T_ACC curr_rstd = rstd[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*N+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + T_ACC curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd; + } + } + } +} + +template __global__ +void cuComputePartGradGammaBeta( + const T* __restrict__ dout, + const T* __restrict__ input, + const int64_t M, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T_ACC* part_grad_gamma, + T_ACC* part_grad_beta) +{ + const int numsegs_M = (M+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + alignas(sizeof(double)) extern __shared__ char shared[]; + T_ACC * buf = reinterpret_cast(&shared); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + T_ACC* warp_buf1 = (T_ACC*)buf; + T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + T_ACC acc1 = T_ACC(0); + T_ACC acc2 = T_ACC(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < N) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*N+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*N+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const T_ACC* part_grad_gamma, + const T_ACC* part_grad_beta, + const int part_size, + const int64_t M, + const int64_t N, + T* grad_gamma, + T* grad_beta) +{ + // sum partial gradients for gamma and beta + alignas(sizeof(double)) extern __shared__ char shared[]; + T_ACC * buf = reinterpret_cast(&shared); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < N) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + T_ACC sum_gamma = T_ACC(0); + T_ACC sum_beta = T_ACC(0); + const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2; + const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*N]; + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const T* __restrict__ dout, + const T* __restrict__ input, + const int64_t M, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + const T* gamma, + T* grad_input) +{ + for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { + T_ACC sum_loss1 = T_ACC(0); + T_ACC sum_loss2 = T_ACC(0); + T_ACC c_mean = mean[i1]; + const T_ACC c_rstd = rstd[i1]; + const T* k_input = input + i1*N; + const T* k_dout = dout + i1*N; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + // Optimization for ROCm MI100 + for( int l = 0; l < N ; l += numx) { + int idx = l + thrx; + const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + alignas(sizeof(double)) extern __shared__ char shared[]; + T_ACC * buf = reinterpret_cast(&shared); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + T_ACC fH = (T_ACC)N; + T_ACC term1 = (T_ACC(1) / fH) * c_rstd; + T* k_grad_input = grad_input + i1*N; + if (gamma != NULL) { + for (int l = thrx; l < N; l+=numx) { + const T_ACC c_h = static_cast(k_input[l]); + const T_ACC c_loss = static_cast(k_dout[l]); + T_ACC f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < N; l+=numx) { + const T_ACC c_h = static_cast(k_input[l]); + const T_ACC c_loss = static_cast(k_dout[l]); + T_ACC f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + template void LayerNormBackwardKernelImplInternal( const Tensor& dY, @@ -750,20 +1161,49 @@ void LayerNormBackwardKernelImplInternal( gamma.defined() ? gamma.template data_ptr() : nullptr; T* dX_data = dX->defined() ? dX->template data_ptr() : nullptr; cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + const int warp_size = at::cuda::warp_size(); if (dX_data != nullptr) { - const int warp_size = at::cuda::warp_size(); +#if defined __HIP_PLATFORM_HCC__ + if (M >= 32768) { + const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)M, maxGridY), 1); + dim3 threads1(warp_size, 4, 1); + threads1.y = 2; // Optimization for ROCm + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(T_ACC) : + 0; + cuComputeGradInput<<>>( + dY_data, + X_data, + M, N, + mean_data, + rstd_data, + gamma_data, + dX_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + const dim3 blocks(M); + int nshared = (num_threads()/warp_size) * sizeof(T_ACC); + layer_norm_grad_input_kernel<<>>(dY_data, + X_data, mean_data, rstd_data, gamma_data, dX_data, N); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +#else const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); +#endif } if (dgamma->defined() || dbeta->defined()) { T* dgamma_data = dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; - if (M < 512) { + + if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; GammaBetaBackwardSimpleCUDAKernel @@ -778,19 +1218,77 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - dim3 threads{16, 32}; - int blocks = (N + threads.x-1)/threads.x; - GammaBetaBackwardCUDAKernel - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); +#if defined(USE_ROCM) + // For small batch size, do colwise reduce directly. + const int part_size = warp_size; + const dim3 threads2(warp_size, 4, 1); + const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + + const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true); + Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); + Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dY_data, + X_data, + M,N, + mean_data, + rstd_data, + part_grad_gamma.template data_ptr(), + part_grad_beta.template data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm + const dim3 blocks3((N + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(T); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.template data_ptr(), + part_grad_beta.template data_ptr(), + part_size, + M,N, + dgamma_data, + dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else + if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { + // This implementation relies on warp primitives and requires that M and N divide + // exactly to warp size. + dim3 threads{kWarpSize, kWarpSize}; + int blocks = (N + threads.x - 1) / threads.x; + + // If M and N divide by 32, we can use warp shuffles for the final reduction. That requires + // transposing values in shared memory, so we apply a padding to reduce bank conflicts. + size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; + GammaBetaBackwardCUDAKernel_32x32 + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dim3 threads{16, 32}; + int blocks = (N + threads.x - 1) / threads.x; + size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; + GammaBetaBackwardCUDAKernel + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +#endif } } } diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp index 01788e0bdffee..89c1246a32d14 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp @@ -656,23 +656,21 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso using value_t = typename c10::scalar_value_type::type; int m = cuda_int_cast(A.size(-2), "m"); int n = cuda_int_cast(A.size(-1), "n"); - int k = std::min(m, n); int batchsize = cuda_int_cast(batchCount(A), "batch size"); + int lda = A.stride(-1); + int ldu = compute_uv ? U.stride(-1) : m; + int ldv = compute_uv ? V.stride(-1) : n; // Need to pass allocated memory to the function, otherwise it fails auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * k) : c10::DataPtr{}; - auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * k) : c10::DataPtr{}; + auto dataPtr_U = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * m * ldu) : c10::DataPtr{}; + auto dataPtr_V = !compute_uv ? allocator.allocate(sizeof(scalar_t) * batchsize * n * ldv) : c10::DataPtr{}; auto A_data = A.data_ptr(); auto U_data = compute_uv ? U.data_ptr() : reinterpret_cast(dataPtr_U.get()); auto S_data = S.data_ptr(); auto V_data = compute_uv ? V.data_ptr() : reinterpret_cast(dataPtr_V.get()); - int lda = A.stride(-1); - int ldu = compute_uv ? U.stride(-1) : m; - int ldv = compute_uv ? V.stride(-1) : n; - TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got " "m = ", m, " n = ", n); @@ -695,10 +693,42 @@ inline static void apply_svd_cusolver_gesvdjBatched(const Tensor& A, const Tenso TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params)); } -inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool compute_uv) { +inline static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& infos, bool full_matrices, bool compute_uv) { + auto m = A.size(-2); + auto n = A.size(-1); + auto k = std::min(m, n); + // The kernel assumes full_matrices == true + // If full_matrices == false and m != n, we create auxiliary tensors of the right size and copy the results back + auto U_ = U; + auto V_ = V; + if (compute_uv && !full_matrices) { + auto sizes = A.sizes().vec(); + if (m > n) { + // Size of U with full_matrices == True + sizes.end()[-1] = m; + // U, V should be a batch of Fortran contiguous arrays + U_ = U.new_empty(sizes).mT(); + } else if (m < n) { + // Size of V with full_matrices == True + sizes.end()[-2] = n; + V_ = V.new_empty(sizes).mT(); + } + } + // Here U_ and V_ are batches of F-contig square matrices + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "svd_cuda_gesvdjBatched", [&] { - apply_svd_cusolver_gesvdjBatched(A, U, S, V, infos, compute_uv); + apply_svd_cusolver_gesvdjBatched(A, U_, S, V_, infos, compute_uv); }); + + // Copy the result back if we created any new matrix + if (compute_uv && !full_matrices) { + if (!U_.is_alias_of(U)) { + U.copy_(U_.narrow(-1, 0, k)); + } + if (!V_.is_alias_of(V)) { + V.copy_(V_.narrow(-1, 0, k)); + } + } } template @@ -832,21 +862,23 @@ void svd_cusolver(const Tensor& A, const Tensor& V, const Tensor& info) { // Here U and V are F-contig whenever they are defined (i.e. whenever compute_uv=true) - const auto batch_size = batchCount(A); const auto m = A.size(-2); const auto n = A.size(-1); const auto k = std::min(m, n); static const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html"; - // The default heuristic is to use gesvdj driver + // The default heuristic is to use the gesvdj driver const auto driver_v = driver.value_or("gesvdj"); if (driver_v == "gesvd") { svd_cusolver_gesvd(A, U, S, V, info, full_matrices, compute_uv); } else if (driver_v == "gesvdj") { - if (m <= 32 && n <= 32 && batch_size > 1 && (full_matrices || m == n)) { - svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, compute_uv); + // See the benchmarks in + // https://github.com/pytorch/pytorch/pull/88502#issuecomment-1303860789 + // The m <= 32 && n <= 32 restrictions come from the limitations of the cusolver backend. See the cusolver docs + if (m <= 32 && n <= 32) { + svd_cusolver_gesvdjBatched(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv); } else { // gesvdj driver may be numerically unstable for large sized matrix svd_cusolver_gesvdj(cloneBatchedColumnMajor(A), U, S, V, info, full_matrices, compute_uv); diff --git a/aten/src/ATen/native/cuda/vol2col.cuh b/aten/src/ATen/native/cuda/vol2col.cuh index 7ab719bc819eb..51dbe1c744053 100644 --- a/aten/src/ATen/native/cuda/vol2col.cuh +++ b/aten/src/ATen/native/cuda/vol2col.cuh @@ -15,7 +15,7 @@ using namespace at::cuda::detail; // Kernel for fast unfold+copy on volumes template __global__ void vol2col_kernel( - const int n, + const int64_t n, const T* data_vol, const int depth, const int height, @@ -37,16 +37,16 @@ __global__ void vol2col_kernel( const int width_col, T* data_col) { CUDA_KERNEL_LOOP(index, n) { - int w_out = index % width_col; + auto w_out = index % width_col; index /= width_col; - int h_out = index % height_col; + auto h_out = index % height_col; index /= height_col; - int t_out = index % depth_col; - int channel_in = index / depth_col; - int channel_out = channel_in * ksize_t * ksize_h * ksize_w; - int t_in = t_out * stride_t - pad_t; - int h_in = h_out * stride_h - pad_h; - int w_in = w_out * stride_w - pad_w; + auto t_out = index % depth_col; + auto channel_in = index / depth_col; + auto channel_out = channel_in * ksize_t * ksize_h * ksize_w; + auto t_in = t_out * stride_t - pad_t; + auto h_in = h_out * stride_h - pad_h; + auto w_in = w_out * stride_w - pad_w; data_col += ((channel_out * depth_col + t_out) * height_col + h_out) * width_col + w_out; @@ -54,9 +54,9 @@ __global__ void vol2col_kernel( for (int i = 0; i < ksize_t; ++i) { for (int j = 0; j < ksize_h; ++j) { for (int k = 0; k < ksize_w; ++k) { - int t = t_in + i * dilation_t; - int h = h_in + j * dilation_h; - int w = w_in + k * dilation_w; + auto t = t_in + i * dilation_t; + auto h = h_in + j * dilation_h; + auto w = w_in + k * dilation_w; *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height && w < width) ? data_vol @@ -126,7 +126,7 @@ void vol2col( template __global__ void vol2im_kernel( - const unsigned n, + const int64_t n, const T* data_col, const unsigned depth, const unsigned height, @@ -150,30 +150,30 @@ __global__ void vol2im_kernel( T* data_vol) { CUDA_KERNEL_LOOP(index, n) { accT val = static_cast(0); - const unsigned w_im = index % width + pad_w; - const unsigned h_im = (index / width) % height + pad_h; - const unsigned t_im = (index / width / height) % depth + pad_t; - const unsigned c_im = index / (width * height * depth); - unsigned kernel_extent_w = (kernel_w - 1) * dilation_w + 1; - unsigned kernel_extent_h = (kernel_h - 1) * dilation_h + 1; - unsigned kernel_extent_t = (kernel_t - 1) * dilation_t + 1; + const auto w_im = index % width + pad_w; + const auto h_im = (index / width) % height + pad_h; + const auto t_im = (index / width / height) % depth + pad_t; + const auto c_im = index / (width * height * depth); + auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1; // compute the start and end of the output - const unsigned w_col_start = + const auto w_col_start = (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1; - const unsigned w_col_end = std::min(w_im / stride_w + 1, width_col); - const unsigned h_col_start = + const auto w_col_end = std::min(w_im / stride_w + 1, width_col); + const auto h_col_start = (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1; - const unsigned h_col_end = std::min(h_im / stride_h + 1, height_col); - const unsigned t_col_start = + const auto h_col_end = std::min(h_im / stride_h + 1, height_col); + const auto t_col_start = (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1; - const unsigned t_col_end = std::min(t_im / stride_t + 1, depth_col); + const auto t_col_end = std::min(t_im / stride_t + 1, depth_col); // TODO: use LCM of stride and dilation to avoid unnecessary loops for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) { for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) { for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) { - unsigned t_k = (t_im - t_col * stride_t); - unsigned h_k = (h_im - h_col * stride_h); - unsigned w_k = (w_im - w_col * stride_w); + uint64_t t_k = (t_im - t_col * stride_t); + uint64_t h_k = (h_im - h_col * stride_h); + uint64_t w_k = (w_im - w_col * stride_w); if (t_k % dilation_t == 0 && h_k % dilation_h == 0 && w_k % dilation_w == 0) { t_k /= dilation_t; diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index ded4d2385c2ce..11fe5be8298e1 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -54,9 +54,12 @@ uint8_t getAlignment(const Tensor &t) { return alignment; } -cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(const Tensor &t, const int64_t id, const uint8_t alignment, const cudnnDataType_t dataType, const bool _virtual) { +cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(const Tensor &t, const int64_t id, const uint8_t alignment, const cudnnDataType_t dataType, const at::MemoryFormat memory_format, const bool _virtual) { auto sizes = t.sizes(); auto strides = t.strides(); + bool channels_last = memory_format == at::MemoryFormat::ChannelsLast || + memory_format == at::MemoryFormat::ChannelsLast3d; + fixSizeOneDimStride(sizes.size(), &sizes[0], (int64_t *) &strides[0], channels_last); auto r = cudnn_frontend::TensorBuilder() .setDim(sizes.size(), sizes.data()) .setStrides(strides.size(), strides.data()) @@ -68,8 +71,8 @@ cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(const Tensor &t, const return r; } -cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, const int64_t id, const uint8_t alignment) { - return getTensorDescriptorWithTypeVirtual(t, id, alignment, getCudnnDataType(t), false); +cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, const int64_t id, const uint8_t alignment, const at::MemoryFormat memory_format) { + return getTensorDescriptorWithTypeVirtual(t, id, alignment, getCudnnDataType(t), memory_format, false); } cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, const at::ScalarType scalar_type) { @@ -159,7 +162,8 @@ BenchmarkCache benchmark_cache_fus // would not be a POD anymore. void setCacheKey(CacheKey& key, const cudnnBackendDescriptorType_t operation, const Tensor& y, const Tensor& x, const Tensor& w, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, int64_t groups, bool deterministic, bool allow_tf32) { memset(&key, 0, sizeof(key)); - setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, x.suggest_memory_format()); + at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w); + setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format); key.operation = operation; key.x_alignment = getAlignment(x); key.y_alignment = getAlignment(y); @@ -168,7 +172,8 @@ void setCacheKey(CacheKey& key, const cudnnBackendDescriptorType_t operation, co void setCacheKeyFused(CacheKeyFused& key, const Tensor& y, const Tensor& x, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, int64_t groups, bool deterministic, bool allow_tf32) { memset(&key, 0, sizeof(key)); - setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, x.suggest_memory_format()); + at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w); + setConvolutionParams(&key.params, x, w, padding, stride, dilation, groups, deterministic, allow_tf32, memory_format); key.x_alignment = getAlignment(x); key.y_alignment = getAlignment(y); key.w_alignment = getAlignment(w); @@ -207,9 +212,9 @@ void run_conv_plan_fused(cudnnHandle_t handle, const Tensor& x, const Tensor& y, auto build_opgraph(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation) { auto op = cudnn_frontend::OperationBuilder(desc) - .setxDesc(getTensorDescriptor(x, 'x', key.x_alignment)) - .setyDesc(getTensorDescriptor(y, 'y', key.y_alignment)) - .setwDesc(getTensorDescriptor(w, 'w', key.w_alignment)) + .setxDesc(getTensorDescriptor(x, 'x', key.x_alignment, key.params.memory_format)) + .setyDesc(getTensorDescriptor(y, 'y', key.y_alignment, key.params.memory_format)) + .setwDesc(getTensorDescriptor(w, 'w', key.w_alignment, key.params.memory_format)) .setcDesc(getConvDescriptor(key.params.dataType, padding, stride, dilation, x.scalar_type())) .build(); std::array ops = {&op}; @@ -239,33 +244,33 @@ auto build_opgraph_fused(const cudnnHandle_t handle, const Tensor & x, const Ten const float alpha1 = 1.0; const float alpha2 = alpha; auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(getTensorDescriptor(x, 'x', key.x_alignment)) + .setxDesc(getTensorDescriptor(x, 'x', key.x_alignment, key.params.memory_format)) // virtual output of conv - .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'C', key.y_alignment, precision, true)) - .setwDesc(getTensorDescriptor(w, 'w', key.w_alignment)) + .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'C', key.y_alignment, precision, key.params.memory_format, true)) + .setwDesc(getTensorDescriptor(w, 'w', key.w_alignment, key.params.memory_format)) .setAlpha(alpha1) .setcDesc(convDesc) .build(); auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(conv_op.getOutputTensor()) - .setbDesc(getTensorDescriptor(z, 'z', key.z_alignment)) + .setbDesc(getTensorDescriptor(z, 'z', key.z_alignment, key.params.memory_format)) // another virtual output (of add) - .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'A', key.y_alignment, precision, true)) + .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'A', key.y_alignment, precision, key.params.memory_format, true)) .setpwDesc(addDesc) .setAlpha(alpha1) .setAlpha2(alpha2) .build(); auto add_bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(add_op.getOutputTensor()) - .setbDesc(getTensorDescriptor(b, 'b', key.b_alignment)) + .setbDesc(getTensorDescriptor(b, 'b', key.b_alignment, key.params.memory_format)) // another virtual output (of add bias) - .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'B', key.y_alignment, precision, true)) + .setyDesc(getTensorDescriptorWithTypeVirtual(y, 'B', key.y_alignment, precision, key.params.memory_format, true)) .setpwDesc(addBiasDesc) .build(); auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) .setxDesc(add_bias_op.getOutputTensor()) // final output is in original datatype - .setyDesc(getTensorDescriptor(y, 'y', key.y_alignment)) + .setyDesc(getTensorDescriptor(y, 'y', key.y_alignment, key.params.memory_format)) .setpwDesc(actDesc) .build(); std::array ops = {&conv_op, &add_op, &add_bias_op, &act_op}; diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp index a741816424a7f..7737e91d44177 100644 --- a/aten/src/ATen/native/cudnn/LossCTC.cpp +++ b/aten/src/ATen/native/cudnn/LossCTC.cpp @@ -88,13 +88,13 @@ bool _use_cudnn_ctc_loss( // (they should, but we didn't check yet) int64_t max_input_length = log_probs.size(0); for (const auto input_length : input_lengths) { - use_cudnn &= ((input_length == max_input_length) ? 1 : 0); + use_cudnn = use_cudnn && ((input_length == max_input_length) ? 1 : 0); } for (const auto b : c10::irange(target_lengths.size())) { // target length < 256 is documented, but we see illegal memory accesses // when target lengths > input lengths for CuDNN - use_cudnn &= - (target_lengths[b] < 256) & (target_lengths[b] <= input_lengths[b]); + use_cudnn = + use_cudnn && (target_lengths[b] < 256) && (target_lengths[b] <= input_lengths[b]); } } return use_cudnn; diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index c08c5d26b63c7..426243392b6fc 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -70,7 +70,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see c10::optional device, c10::optional pin_memory) { // See [Note: hacky wrapper removal for TensorOptions] - TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); } diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 24a23577e490e..22ff9ea5f0e86 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -1,26 +1,37 @@ -#include -#include -#include -#include -#include -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#endif + #include #include -#include #include #include namespace at { + namespace native { +template void check_group_norm_inputs( const Tensor& input, const Tensor& weight, const Tensor& bias, - int64_t C, + T C, int64_t num_groups) { TORCH_CHECK( num_groups > 0, @@ -34,14 +45,14 @@ void check_group_norm_inputs( "num_groups=", num_groups); TORCH_CHECK( - !weight.defined() || (weight.dim() == 1 && weight.numel() == C), + !weight.defined() || (weight.dim() == 1 && at::symint::numel(weight) == C), "Expected weight to be a vector of size equal to the number of ", "channels in input, but got weight of shape ", weight.sizes(), " and input of shape ", input.sizes()); TORCH_CHECK( - !bias.defined() || (bias.dim() == 1 && bias.numel() == C), + !bias.defined() || (bias.dim() == 1 && at::symint::numel(bias) == C), "Expected bias to be a vector of size equal to the number of ", "channels in input, but got bias of shape ", weight.sizes(), @@ -162,24 +173,24 @@ Tensor group_norm( const Tensor& weight = *weight_maybe_owned; const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); }); - const int64_t N = input.size(0); - const int64_t C = input.size(1); + const auto N = input.sym_size(0); + const auto C = input.sym_size(1); check_group_norm_inputs(input, weight, bias, C, num_groups); - const auto input_shape = input.sizes(); - const int64_t HxW = - c10::multiply_integers(input_shape.cbegin() + 2, input_shape.cend()); + const auto input_shape = input.sym_sizes(); + const auto HxW = + c10::multiply_integers(input_shape.slice(2)); const Tensor kEmpty; auto memory_format = input.suggest_memory_format(); - const auto& X = input.device().is_cpu() ? + const auto& X = input.device().is_cpu() || input.device().is_xpu() ? input.contiguous(memory_format) : input.contiguous(); const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; - TORCH_CHECK(!gamma.defined() || gamma.numel() == C); - TORCH_CHECK(!beta.defined() || beta.numel() == C); + TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C); + TORCH_CHECK(!beta.defined() || beta.sym_numel() == C); return std::get<0>( - at::native_group_norm(X, gamma, beta, N, C, HxW, num_groups, eps)); + at::native_group_norm_symint(X, gamma, beta, N, C, HxW, num_groups, eps)); } DEFINE_DISPATCH(GroupNormKernel); diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index 80a7bb6111f23..37a3f1a750ab2 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -1,17 +1,27 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include -#include -#include -#include -#include -#include +#include #include +#include #include -#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif #include -#include -#include #include #include @@ -69,6 +79,10 @@ std::tuple layer_norm_cpu( c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; + bool mixed_type = is_mixed_type(input, weight, bias); + if (mixed_type) { + check_mixed_data_type(input, weight, bias); + } auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, bias); auto M = M_N.first; @@ -84,8 +98,9 @@ std::tuple layer_norm_cpu( c10::nullopt /* device */, c10::nullopt /* pin_memory */, at::MemoryFormat::Contiguous); - Tensor mean = at::empty({M}, X->options()); - Tensor rstd = at::empty({M}, X->options()); + const auto dtype = param_scalar_type(input, mixed_type); + Tensor mean = at::empty({M}, X->options().dtype(dtype)); + Tensor rstd = at::empty({M}, X->options().dtype(dtype)); layer_norm_with_mean_rstd_out(Y, mean, rstd, *X, normalized_shape, *gamma, *beta, eps, M, N); return std::make_tuple(std::move(Y), std::move(mean), std::move(rstd)); @@ -166,9 +181,9 @@ std::tuple layer_norm_backward_cpu( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } -Tensor layer_norm( +Tensor layer_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, const c10::optional& weight_opt /* optional */, const c10::optional& bias_opt /* optional */, + c10::SymIntArrayRef normalized_shape, const c10::optional& weight_opt /* optional */, const c10::optional& bias_opt /* optional */, double eps, bool /* cudnn_enable, deprecated */) { // See [Note: hacky wrapper removal for optional tensor] @@ -177,8 +192,7 @@ Tensor layer_norm( c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); const Tensor& bias = *bias_maybe_owned; - - return std::get<0>(at::native_layer_norm(input, normalized_shape, weight, bias, eps)); + return std::get<0>(at::native_layer_norm_symint(input, normalized_shape, weight, bias, eps)); } DEFINE_DISPATCH(LayerNormKernel); diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 677a711ce7a6b..060a97d6fc1c1 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -187,7 +187,7 @@ struct ConvolutionParams }; // ConvolutionParams must be a POD because we read out its memory // contenst as char* when hashing -static_assert(std::is_pod::value, "ConvolutionParams not POD"); +static_assert(std::is_standard_layout::value, "ConvolutionParams not POD"); void setConvolutionParams( ConvolutionParams* params, miopenHandle_t handle, diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 508aefe787ad7..3d8188c003e1d 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -8,9 +8,9 @@ #include #include #else +#include #include -#include -#include +#include #include #include #include @@ -175,51 +175,23 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo return memory_format; } -Tensor _mkldnn_convolution( +void _mkldnn_convolution_out ( const Tensor& input_t, const Tensor& weight_t, - const c10::optional& bias_opt, - IntArrayRef padding, + const Tensor& bias, + std::vector& output_sizes, + ideep::tensor& y, IntArrayRef stride, IntArrayRef dilation, + IntArrayRef padding, int64_t groups, - c10::string_view attr = "none", - torch::List> scalars = - torch::List>(), - c10::optional algorithm = c10::nullopt) { - ideep::attr_t op_attr = ideep::attr_t(); - if (attr != "none") { - auto it = fx_fusion_attr_map().find(attr); - TORCH_CHECK(it != fx_fusion_attr_map().end(), "Fusion behavior undefined."); - op_attr = it->second(scalars, algorithm); - } - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - if (input_t.scalar_type() == ScalarType::BFloat16) { - TORCH_CHECK(mkldnn_bf16_device_check(), - "mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); - } - - check_shape_forward(input_t, weight_t, bias, padding, stride, dilation, groups); - - bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t); + bool is_channels_last, + const ideep::attr_t& op_attr) { auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last); - auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format); auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format); - auto output_sizes = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - auto output = at::empty({0}, input.options()); - const ideep::tensor x = itensor_from_tensor(input); const ideep::tensor w = itensor_from_tensor(weight); - - ideep::tensor y; - if (is_channels_last) { - output.resize_(output_sizes, memory_format); - y = itensor_from_tensor(output); - } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_v3( @@ -249,11 +221,66 @@ Tensor _mkldnn_convolution( is_channels_last, op_attr); } +} + +Tensor _mkldnn_convolution( + const Tensor& input_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool use_channels_last, + c10::string_view attr = "none", + torch::List> scalars = + torch::List>(), + c10::optional algorithm = c10::nullopt) { + ideep::attr_t op_attr = ideep::attr_t(); + if (attr != "none") { + auto it = fusion_unary_attr_map().find(attr); + TORCH_CHECK( + it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); + op_attr = it->second(scalars, algorithm); + } + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + if (input_t.scalar_type() == ScalarType::BFloat16) { + TORCH_CHECK(mkldnn_bf16_device_check(), + "mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); + } + + check_shape_forward(input_t, weight_t, bias, padding, stride, dilation, groups); + + auto memory_format = + mkldnn_convolution_memory_format(input_t.ndimension(), use_channels_last); + + auto output_sizes = conv_output_size(input_t.sizes(), weight_t.sizes(), padding, stride, dilation); + auto output = at::empty({0}, input_t.options()); + ideep::tensor y; + if (use_channels_last) { + output.resize_(output_sizes, memory_format); + y = itensor_from_tensor(output); + } + _mkldnn_convolution_out( + input_t, + weight_t, + bias, + output_sizes, + y, + stride, + dilation, + padding, + groups, + use_channels_last, + op_attr); - if (input.is_mkldnn()) { - return MKLDNNTensor(y, input.options()); - } else if (!is_channels_last) { - return mkldnn_to_dense(MKLDNNTensor(y, input.options())); + if (input_t.is_mkldnn()) { + return MKLDNNTensor(y, input_t.options()); + } else if (!use_channels_last) { + return mkldnn_to_dense(MKLDNNTensor(y, input_t.options())); } else { TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc()); return output; @@ -268,8 +295,16 @@ Tensor mkldnn_convolution( IntArrayRef stride, IntArrayRef dilation, int64_t groups) { + bool use_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t); return _mkldnn_convolution( - input_t, weight_t, bias_opt, padding, stride, dilation, groups); + input_t, + weight_t, + bias_opt, + padding, + stride, + dilation, + groups, + use_channels_last); } Tensor mkldnn_convolution_pointwise( @@ -284,6 +319,8 @@ Tensor mkldnn_convolution_pointwise( torch::List> scalars, c10::optional algorithm) { c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + bool use_channels_last = + weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t); return _mkldnn_convolution( input_t, weight_t, @@ -292,11 +329,20 @@ Tensor mkldnn_convolution_pointwise( stride, dilation, groups, + use_channels_last, attr, scalars, algorithm); } +// Fuse convolution+binary_op+unary_op for good performance, which doing such +// operation: output=unary_op(binary_op(conv(input_t, ...), other_t, alpha)). +// The binary_attr means which binary_op is, it can be "add", or +// other binary operation. the unary_attr means which unary_op is, +// it can be "relu" or other unary operation, if it is none, meaning that +// there doesn't have a unary post op. unary_scalars and unary_algorithm +// are the parameters of the unary op, such as "hardtanh" has scalar parameters, +// "gelu" has algorithm parameters. Tensor mkldnn_convolution_pointwise_binary( const Tensor& input_t, const Tensor& other_t, @@ -306,10 +352,17 @@ Tensor mkldnn_convolution_pointwise_binary( IntArrayRef stride, IntArrayRef dilation, int64_t groups, - c10::string_view attr) { + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm) { TORCH_CHECK( input_t.ndimension() == 4 || input_t.ndimension() == 5, "mkldnn_convolution_pointwise_binary: currently only support 2d and 3d") + TORCH_CHECK( + !alpha.has_value() || alpha.value().to() == 1.0, + "mkldnn_convolution_pointwise_binary: the alpha value should be none or 1.0"); c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); @@ -331,18 +384,32 @@ Tensor mkldnn_convolution_pointwise_binary( // Only calling fusion path for channels_last path. // TODO: OneDNN doesn't optimize well for groups > 1 case, it will be enabled // at next OneDNN release. - bool can_be_fused = - groups == 1 && mkldnn_conv_use_channels_last(input_t, weight_t); - - auto it_binary = fusion_binary_alg_map().find(attr); + bool use_channels_last = + weight_t.is_mkldnn() || mkldnn_conv_use_channels_last(input_t, weight_t); + bool can_be_fused = groups == 1 && use_channels_last; + + c10::string_view unary_attr_value = "none"; + ideep::algorithm unary_alg; + if (unary_attr.has_value()) { + auto it_unary = fusion_unary_alg_map().find(unary_attr.value()); + // Now, we only support conv+binary+relu. + TORCH_CHECK( + it_unary != fusion_unary_alg_map().end(), + "Unary Fusion behavior undefined."); + unary_attr_value = unary_attr.value(); + unary_alg = it_unary->second; + } + auto it_binary = fusion_binary_alg_map().find(binary_attr); TORCH_CHECK( - it_binary != fusion_binary_alg_map().end(), "Fusion behavior undefined."); + it_binary != fusion_binary_alg_map().end(), + "Binary Fusion behavior undefined."); + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); if (can_be_fused) { - c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), true); auto input = input_t.contiguous(memory_format); - auto weight = weight_t.contiguous(memory_format); + auto weight = + weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format); auto other = other_t.contiguous(memory_format); auto output = at::empty_like(other); const ideep::tensor x = itensor_from_tensor(input); @@ -356,7 +423,15 @@ Tensor mkldnn_convolution_pointwise_binary( } auto other_desc = ideep::tensor::desc( output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag); - auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc); + + ideep::attr_t op_attr; + ideep::post_ops po; + po.append_binary(it_binary->second, other_desc); + if (unary_attr_value != "none") { + po.append_eltwise(1.0, unary_alg, 0.f, 0.f); + } + op_attr.set_post_ops(po); + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_binary( @@ -393,26 +468,131 @@ Tensor mkldnn_convolution_pointwise_binary( // Fallback case, if inputs are not channels last or have different dtype, // OneDNN fusion may have performance regression. Tensor output; - if (input_t.ndimension() == 4) { - output = at::conv2d( - input_t, weight_t, bias_opt, stride, padding, dilation, groups); + if (weight_t.is_mkldnn()) { + output = _mkldnn_convolution( + input_t, weight_t, bias, padding, stride, dilation, groups, true); } else { - output = at::conv3d( - input_t, weight_t, bias_opt, stride, padding, dilation, groups); + output = at::convolution( + input_t, weight_t, bias, stride, padding, dilation, false, 0, groups); } - if (attr == "add") { + if (binary_attr == "add" && unary_attr_value != "none") { + output = at::native::add_relu_(output, other_t); + return output; + } + if (binary_attr == "add") { output.add_(other_t); - } else if (attr == "sub") { + } else if (binary_attr == "sub") { output.sub_(other_t); - } else if (attr == "mul") { + } else if (binary_attr == "mul") { output.mul_(other_t); } else { output.div_(other_t); } + if (unary_attr_value != "none") { + output.relu_(); + } return output; } } +// Fuse convolution+binary_op+unary_op for good performance, which doing +// such operation: other_t=unary_op(binary_op(conv(input_t, ...), other_t, +// alpha)). The binary_attr means which binary_op is, it can be "add", or other +// binary operation. the unary_attr means which unary_op is, it can be "relu" or +// other unary operation, if it is none, meaning that there doesn't have a unary +// post op. unary_scalars and unary_algorithm are the parameters of the unary +// op, such as "hardtanh" has scalar parameters "gelu" has algorithm parameters. + +Tensor& mkldnn_convolution_pointwise_binary_( + const Tensor& input_t, + Tensor& other_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm) { + // other_t += convolution(...), other_t = unary(other_t) + TORCH_CHECK( + input_t.ndimension() == 4 || input_t.ndimension() == 5, + "mkldnn_convolution_add_: currently only support 2d and 3d") + TORCH_CHECK( + binary_attr == "add", + "mkldnn_convolution_pointwise_binary_: only support binary op fusion") + TORCH_CHECK( + !alpha.has_value() || alpha.value().to() == 1.0, + "mkldnn_convolution_pointwise_binary: the alpha value for the binary op should be none(meaning 1.0) or 1.0"); + TORCH_CHECK( + !unary_attr.has_value() || unary_attr.value() == "relu", + "mkldnn_convolution_pointwise_binary: only support none or relu unary op fusion after binary op"); + + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + // Make sure inputs have same type(device, layout, dtype), device is cpu and + // dtype is float or bfloat16. + check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias); + + check_shape_forward( + input_t, weight_t, bias, padding, stride, dilation, groups); + + auto output_sizes = conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation); + TORCH_CHECK( + output_sizes == other_t.sizes(), + "Add Fusion's inputs should have same shape"); + // Only calling fusion path for channels_last path and the output is contiguous tensor(channels_last). + bool can_be_fused = (weight_t.is_mkldnn() || + mkldnn_conv_use_channels_last(input_t, weight_t)) && + (other_t.is_contiguous(at::MemoryFormat::ChannelsLast) || + other_t.is_contiguous(at::MemoryFormat::ChannelsLast3d)); + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + if (can_be_fused) { + ideep::tensor y = itensor_from_tensor(other_t); + ideep::attr_t op_attr; + if (unary_attr.has_value()) { + op_attr = ideep::attr_t::residual(); + } else { + op_attr = ideep::attr_t::fuse_sum(); + } + _mkldnn_convolution_out( + input_t, + weight_t, + bias, + output_sizes, + y, + stride, + dilation, + padding, + groups, + true, + op_attr); + } else { + // Fallback case, if inputs are not channels last or have different dtype, + // OneDNN fusion may have performance regression. + Tensor output; + if (weight_t.is_mkldnn()) { + output = _mkldnn_convolution( + input_t, weight_t, bias, padding, stride, dilation, groups, true); + } else { + output = at::convolution( + input_t, weight_t, bias, stride, padding, dilation, false, 0, groups); + } + if (unary_attr.has_value()) { + other_t = at::native::add_relu_(other_t, output); + } else { + other_t.add_(output); + } + } + return other_t; +} + Tensor mkldnn_convolution_backward_input( IntArrayRef input_size, const Tensor& grad_output, @@ -540,8 +720,22 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"), TORCH_FN(mkldnn_convolution_pointwise_binary)); + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"), + TORCH_FN(mkldnn_convolution_pointwise_binary_)); } +TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise"), + TORCH_FN(mkldnn_convolution_pointwise)); + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"), + TORCH_FN(mkldnn_convolution_pointwise_binary)); + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"), + TORCH_FN(mkldnn_convolution_pointwise_binary_)); +} }} // namespace at::native #endif diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index b57d8e56a16d1..894e54eefb1c1 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -10,6 +10,7 @@ #else #include #include +#include #include #include #include @@ -215,8 +216,9 @@ Tensor mkldnn_linear_pointwise( } const ideep::tensor w = itensor_from_tensor(weight_t); - auto it = fx_fusion_attr_map().find(attr); - TORCH_CHECK(it != fx_fusion_attr_map().end(), "Fusion behavior undefined."); + auto it = fusion_unary_attr_map().find(attr); + TORCH_CHECK( + it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); ideep::attr_t op_attr = it->second(scalars, algorithm); if (mkldnn_bias.has_value()) { @@ -335,3 +337,100 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) { } // namespace at #endif // AT_MKLDNN_ENABLED + +#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED() +#include + +namespace at { +namespace native { + +Tensor mkl_linear( + const Tensor& self, + const Tensor& mkl_weight_t, + const Tensor& origin_weight_t, + const c10::optional& bias_opt, + const int64_t prepack_batch_size) { + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + TORCH_CHECK( + self.options().type_equal(origin_weight_t.options()), + "Input type (", + self.toString(), + ") and weight type (", + origin_weight_t.toString(), + ") should be the same"); + TORCH_CHECK( + !bias.defined() || (self.options().type_equal(bias.options())), + "Input type (", + self.toString(), + ") and bias type (", + bias.toString(), + ") should be the same"); + TORCH_CHECK( + mkl_weight_t.scalar_type() == origin_weight_t.scalar_type() && + origin_weight_t.scalar_type() == kFloat, + "mkl_linear: weight dtype should be float"); + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + auto input_size = self.sizes(); + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(origin_weight_t.size(0)); + auto output = at::empty(output_size, self.options()); + int64_t M = self.numel() / self.size(self.dim() - 1); + if (M == prepack_batch_size && mkl_weight_t.is_mkldnn()) { + auto self_ = self.is_contiguous() ? self : self.contiguous(); + auto K = origin_weight_t.size(1); + auto N = origin_weight_t.size(0); + const ideep::tensor& w = itensor_from_mkldnn(mkl_weight_t); + auto in_ptr = self_.data_ptr(); + auto weight_ptr = (float*)(w.get_data_handle()); + auto out_ptr = output.data_ptr(); + if (bias.defined()) { + auto bias_ = bias.is_contiguous() ? bias : bias.contiguous(); + auto bias_ptr = bias_.data_ptr(); +#ifdef _OPENMP +#if (_OPENMP >= 201307) +#pragma omp parallel for simd schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#else +#pragma omp parallel for schedule( \ + static) if (omp_get_max_threads() > 1 && !omp_in_parallel()) +#endif +#endif + for (int64_t i = 0; i < M; ++i) { + memcpy(out_ptr + i * N, bias_ptr, sizeof(float) * N); + } + } + cblas_sgemm_compute( + CblasRowMajor, + CblasNoTrans, + CblasPacked, + M, + N, + K, + in_ptr, + K, + weight_ptr, + K, + bias.defined() ? 1.f : 0.f, + out_ptr, + N); + } else { + output = at::linear_out(output, self, origin_weight_t, bias_opt); + } + return output; +} + +TORCH_LIBRARY_IMPL(mkl, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear)); +} + +TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear)); +} + +} // namespace native +} // namespace at + +#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp index f1ac8f9d53830..d643fae22ca26 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp @@ -1,9 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include +#include #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -88,7 +89,8 @@ Tensor mkldnn_reorder_conv2d_weight( IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups) { + int64_t groups, + c10::OptionalArrayRef input_size) { if (self.scalar_type() == ScalarType::BFloat16) { TORCH_CHECK(mkldnn_bf16_device_check(), "mkldnn_reorder_conv2d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); @@ -106,16 +108,28 @@ Tensor mkldnn_reorder_conv2d_weight( w.reshape({wdims[0] * wdims[1], wdims[2], wdims[3], wdims[4]}); } - auto desc = - ideep::convolution_forward::expected_weights_desc( - w.get_dims(), - w.get_data_type(), - {stride.begin(), stride.end()}, - {padding.begin(), padding.end()}, - {padding.begin(), padding.end()}, - {dilation.begin(), dilation.end()}, - groups, - ideep::algorithm::convolution_direct); + ideep::dims src_dims = ideep::dims(); + bool is_channels_last = false; + if (input_size.has_value()) { + src_dims = input_size.value().vec(); + // if has input size, we always use channels last. + is_channels_last = true; + } + + auto desc = ideep::convolution_forward::expected_weights_desc( + w.get_dims(), + w.get_data_type(), + {stride.begin(), stride.end()}, + {padding.begin(), padding.end()}, + {padding.begin(), padding.end()}, + {dilation.begin(), dilation.end()}, + groups, + ideep::algorithm::convolution_direct, + ideep::prop_kind::forward, + w.get_data_type(), + src_dims, + ideep::attr_t(), + is_channels_last); ideep::tensor result; result.init(desc); result.feed_from(w); @@ -169,7 +183,8 @@ Tensor mkldnn_reorder_conv2d_weight( IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, - int64_t groups) { + int64_t groups, + c10::OptionalArrayRef input_size) { TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled"); } @@ -184,4 +199,48 @@ Tensor mkldnn_reorder_conv3d_weight( #endif // AT_MKLDNN_ENABLED() +#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED() +#include + +Tensor mkl_reorder_linear_weight( + const Tensor& weight, + const int64_t batch_size) { + TORCH_CHECK( + weight.scalar_type() == ScalarType::Float, + "reorder_linear_weight: weight's dtype should be float"); + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + auto M = batch_size; + auto N = weight.size(0); + auto K = weight.size(1); + int64_t pack_size = + (int64_t)(cblas_sgemm_pack_get_size(CblasBMatrix, M, N, K) / sizeof(float) + 1); + auto packed_weight = empty_mkldnn( + {pack_size, 1}, + weight.scalar_type(), + weight.options().layout_opt(), + weight.options().device_opt(), + weight.options().pinned_memory_opt()); + ideep::tensor& mkl_weight = itensor_from_mkldnn(packed_weight); + ideep::tensor& orig_w = itensor_from_mkldnn(weight); + cblas_sgemm_pack( + CblasRowMajor, + CblasBMatrix, + CblasTrans, + M, + N, + K, + 1.0f, + (float*)(orig_w.get_data_handle()), + K, + (float*)(mkl_weight.get_data_handle())); + return packed_weight; +} + +TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("mkl::_mkl_reorder_linear_weight"), + TORCH_FN(mkl_reorder_linear_weight)); +} + +#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED }} diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index 1be6224a23c42..d0171865fac61 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -41,6 +41,23 @@ std::tuple mkldnn_layer_norm_last_index_weight_bias_f32( TORCH_CHECK(false, "mkldnn_layer_norm_last_index_weight_bias_f32: ATen not compiled with MKLDNN support"); } +std::tuple _mkldnn_batch_norm_legit( + const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, + bool train, + double momentum, + double eps) { + TORCH_CHECK(false, "_mkldnn_batch_norm_legit: ATen not compiled with MKLDNN support"); +} + + +std::tuple _mkldnn_batch_norm_legit_no_stats( + const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, + bool train, + double momentum, + double eps) { + TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support"); +} + } // namespace native } // namespace at @@ -173,6 +190,25 @@ std::tuple mkldnn_batch_norm( } } + +std::tuple _mkldnn_batch_norm_legit( + const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, Tensor& running_mean, Tensor& running_var, + bool train, + double momentum, + double eps) { + return mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps); +} + + +std::tuple _mkldnn_batch_norm_legit_no_stats( + const Tensor& input, const c10::optional& weight_opt, const c10::optional& bias_opt, + bool train, + double momentum, + double eps) { + return mkldnn_batch_norm(input, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps); +} + + std::tuple mkldnn_batch_norm_backward(const Tensor& grad_output, const Tensor& input, const c10::optional& weight_opt, const c10::optional& running_mean_opt, const c10::optional& running_var_opt, const c10::optional& save_mean_opt, const c10::optional& save_invstd_opt, bool train, diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp index a0f9207e2faed..30ff49f49dd3b 100644 --- a/aten/src/ATen/native/mkldnn/Pooling.cpp +++ b/aten/src/ATen/native/mkldnn/Pooling.cpp @@ -518,7 +518,7 @@ Tensor mkldnn_adaptive_avg_pool2d( /*padding*/ {0, 0}, /*dilation*/ {1, 1}, /*ceil_mode*/ false, - /*algo*/ ideep::algorithm::pooling_avg); + /*algo*/ ideep::algorithm::pooling_avg_exclude_padding); } Tensor& mkldnn_adaptive_avg_pool2d_out_stub(const Tensor& input, diff --git a/aten/src/ATen/native/mkldnn/Prelu.cpp b/aten/src/ATen/native/mkldnn/Prelu.cpp index acc78211d83cc..dc7d239da7b68 100644 --- a/aten/src/ATen/native/mkldnn/Prelu.cpp +++ b/aten/src/ATen/native/mkldnn/Prelu.cpp @@ -17,7 +17,7 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons }} -#else // AT_MKLDNN_EBABLED +#else // AT_MKLDNN_ENABLED #include #include @@ -76,4 +76,4 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons } }} -#endif // AT_MKLDNN_EBABLED +#endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index 0be8d8a100cd6..8841d65a2e782 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -42,7 +42,9 @@ TORCH_LIBRARY(mkldnn, m) { m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn::_convolution_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA( - "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr) -> Tensor Y")); + "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y")); } TORCH_LIBRARY(mkldnn_prepacked, m) { @@ -67,3 +69,22 @@ TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) { } // namespace at #endif // AT_MKLDNN_ENABLED() + +#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED() + +namespace at { +namespace native { +namespace mkl { + +TORCH_LIBRARY(mkl, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "mkl::_mkl_reorder_linear_weight(Tensor X, int batch_size) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "mkl::_mkl_linear(Tensor X, Tensor MKL_W, Tensor ORI_W, Tensor? B, int batch_size) -> Tensor")); +} + +} // namespace mkl +} // namespace native +} // namespace at + +#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index fbf1e96bf14da..1e54aae9d6601 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -1,7 +1,8 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include +#include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -78,6 +79,9 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional optiona } Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { + auto ndims = self.dim(); + dim0 = maybe_wrap_dim(dim0, ndims); + dim1 = maybe_wrap_dim(dim1, ndims); const ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y; std::vector axes(x.ndims()); diff --git a/aten/src/ATen/native/mkldnn/Utils.cpp b/aten/src/ATen/native/mkldnn/Utils.cpp index 42f855d75cbe8..2c9bcc016e47d 100644 --- a/aten/src/ATen/native/mkldnn/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/Utils.cpp @@ -38,13 +38,19 @@ void check_mkldnn_binary_fusion_inputs( const Tensor& other, const Tensor& weight, const Tensor& bias) { - TORCH_CHECK( - input.options().type_equal(weight.options()), - "Input type (", - input.toString(), - ") and weight type (", - weight.toString(), - ") should be the same"); + if (!weight.is_mkldnn()) { + TORCH_CHECK( + input.options().type_equal(weight.options()), + "Input type (", + input.toString(), + ") and weight type (", + weight.toString(), + ") should be the same"); + } else { + TORCH_CHECK( + input.scalar_type() == input.scalar_type(), + "mkldnn pointwise binary: input dtype and weight dtype should be the same"); + } TORCH_CHECK( input.options().type_equal(other.options()), "Input type (", @@ -61,11 +67,11 @@ void check_mkldnn_binary_fusion_inputs( ") should be the same"); TORCH_CHECK( input.device().is_cpu(), - "mkldnn pointwise binary fusion: inputs' device should be CPU") + "mkldnn pointwise binary fusion: input's device should be CPU"); TORCH_CHECK( input.scalar_type() == ScalarType::Float || input.scalar_type() == ScalarType::BFloat16, - "mkldnn pointwise binary: inputs' dtypoe should be float or bfloat16") + "mkldnn pointwise binary: input's dtype should be float or bfloat16"); if (input.scalar_type() == ScalarType::BFloat16) { TORCH_CHECK( mkldnn_bf16_device_check(), @@ -127,11 +133,12 @@ AttrFunction attr_func_gelu = [](torch::List> scalars, return ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type); }; -const std::map& fx_fusion_attr_map() { +const std::map& fusion_unary_attr_map() { static const std::map fusion_attr_map{ {"relu", ATTR_FUNC(relu)}, {"sigmoid", ATTR_FUNC(sigmoid)}, {"tanh", ATTR_FUNC(tanh)}, + {"swish", ATTR_FUNC(swish)}, {"hardswish", ATTR_FUNC(hardswish)}, {"leaky_relu", attr_func_leaky_relu}, {"hardtanh", attr_func_hardtanh}, @@ -140,6 +147,13 @@ const std::map& fx_fusion_attr_map() { return fusion_attr_map; }; +const std::map& fusion_unary_alg_map() { + static const std::map fusion_attr_map{ + {"relu", {ideep::algorithm::eltwise_relu}}, + }; + return fusion_attr_map; +}; + const std::map& fusion_binary_alg_map() { static const std::map fusion_attr_map{ {"add", {ideep::algorithm::binary_add}}, diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index 314a7efc950ef..a25be13c46dab 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -39,7 +39,9 @@ using AttrFunction = std::function>, c10::optional)>; -const std::map& fx_fusion_attr_map(); +const std::map& fusion_unary_attr_map(); + +const std::map& fusion_unary_alg_map(); const std::map& fusion_binary_alg_map(); diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h new file mode 100644 index 0000000000000..5f581dbbb78d6 --- /dev/null +++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h @@ -0,0 +1,114 @@ +#pragma once +#include + +// TODO: Remove me when moved to MacOS 13 +@interface MPSGraph (VenturaOps) + +#if !defined(__MAC_13_0) && \ + (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0)) + +API_AVAILABLE(macos(13.0)) +typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) +{ + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L, + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L, + MPSGraphResizeNearestRoundingModeCeil = 2L, + MPSGraphResizeNearestRoundingModeFloor = 3L, + MPSGraphResizeNearestRoundingModeRoundToEven = 4L, + MPSGraphResizeNearestRoundingModeRoundToOdd = 5L, +}; +#endif + +- (MPSGraphTensor * _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull)sortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull)argSortWithTensor:(MPSGraphTensor * _Nonnull)tensor + axis:(NSInteger)axis + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor + name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + samplingMode:(MPSGraphResizeMode) samplingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) sampleGridWithSourceTensor:(MPSGraphTensor * _Nonnull) source + coordinateTensor:(MPSGraphTensor * _Nonnull) coordinates + layout:(MPSGraphTensorNamedDataLayout) layout + normalizeCoordinates:(BOOL) normalizeCoordinates + relativeCoordinates:(BOOL) relativeCoordinates + alignCorners:(BOOL) alignCorners + paddingMode:(MPSGraphPaddingMode) paddingMode + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + constantValue:(double) constantValue + name:(NSString * _Nullable) name; +@end diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 13c817b4c45bf..b24f0c633eb42 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -19,24 +19,6 @@ namespace at { namespace native { namespace mps { -struct TORCH_CUDA_CPP_API MPSGeneratorImpl : public c10::GeneratorImpl { - MPSGeneratorImpl(DeviceIndex device_index = -1); - ~MPSGeneratorImpl() = default; - - void set_current_seed(uint64_t seed) override; - uint64_t current_seed() const override; - uint64_t seed() override; - void set_state(const c10::TensorImpl& new_state) override; - c10::intrusive_ptr get_state() const override; - static DeviceType device_type(); - -private: - MPSGeneratorImpl* clone_impl() const override; - uint64_t seed_ = default_rng_seed_val; -}; - -const Generator& getDefaultMPSGenerator(); - struct MPSScalar { id getMTLBuffer() const { return __builtin_bit_cast(id, buffer.get()); } @@ -60,17 +42,22 @@ void runMPSGraph( MPSDataType getMPSDataType(ScalarType scalar_type); MPSDataType getMPSScalarType(ScalarType scalar_type); MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); -std::string getMPSTypeString(ScalarType scalar_type); +std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); +std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); +NSArray* getTensorAxes(const Tensor& t); +NSArray* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim); std::string getMPSShapeString(MPSShape* shape); -std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false); +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false); std::string getArrayRefString(const IntArrayRef s); // use has_storage() on the returned tensor to determine if src actually is a view Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst); -Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output); +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id updatesBuffer = nil); +bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape); +MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType); -MPSShape* getMPSShape(const Tensor& t); -MPSShape* getMPSShape(IntArrayRef sizes); -MPSShape* getMPSShape(c10::MaybeOwned t); +// The MPSShape could vary based on memory format +MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); static inline id getMTLBufferStorage(const at::Tensor& tensor) { return __builtin_bit_cast(id, tensor.storage().data()); @@ -80,7 +67,8 @@ class Placeholder { public: Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {} Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {} - Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr); + Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr, + bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid); MPSGraphTensor* getMPSGraphTensor() { return _placeholder; } @@ -99,16 +87,19 @@ class Placeholder { void resize_tensor(Tensor* output); MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor); MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType); MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor); MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); MPSGraph* make_mps_graph(); void printTensorNDArray(const Tensor& t); +MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); string get_mem_format_string(c10::MemoryFormat memory_format); @@ -143,6 +134,13 @@ struct MPSUnaryCachedGraph : public MPSCachedGraph MPSGraphTensor *outputTensor_ = nil; }; +struct MPSBinaryCachedGraph : public MPSCachedGraph +{ + MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *otherTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; +}; // TODO: Improve the overall design of MPSGraphCache. // https://github.com/pytorch/pytorch/issues/77176 @@ -178,7 +176,7 @@ struct MPSGraphCache MPSGraphCache(const MPSGraphCache&) = delete; void operator=(const MPSGraphCache&) = delete; - MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) { + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { __block MPSCachedGraph * result = nil; @@ -196,17 +194,14 @@ struct MPSGraphCache result = createCacheBlock(); CacheEntry entry(key, result); cache_.emplace(hash, entry); - if (view_ptr) { - views_list.insert(std::make_pair(view_ptr, hash)); - } } }); return result; } template - inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) { - return static_cast(CreateCachedGraph(key, createCacheBlock, view_ptr)); + inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + return static_cast(CreateCachedGraph(key, createCacheBlock)); } MPSCachedGraph* LookUp(const std::string& key) const { @@ -231,24 +226,6 @@ struct MPSGraphCache return static_cast(LookUp(key)); } - void FindAndRemoveViewEntry(void* ptr) { - // this may find multiple view entries with the same buffer pointers - auto views_range = views_list.equal_range(ptr); - if (views_range.first == views_range.second) - return; - for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) { - MPSCacheKey hash = view_it->second; - // find the cache entry associated with the hash - auto cache_it = cache_.find(hash); - if (cache_it != cache_.end()) { - cache_.erase(cache_it); - delete cache_it->second.cachedGraph_; - } - } - // this erase-by-key will remove all pairs in the list with the same key - views_list.erase(ptr); - } - private: MPSGraphCache() { serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); @@ -256,13 +233,11 @@ struct MPSGraphCache static MPSGraphCache* _instance_cache; std::unordered_map cache_; - // list of buffers associated with view entries in the cache - // note that multiple view cache entries could use the same buffer pointer - std::unordered_multimap views_list; dispatch_queue_t serialQueue_ = nullptr; }; + } // namespace mps } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index fd18c1f4a95e4..3ba4146ae3cf3 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -1,65 +1,12 @@ // Copyright © 2022 Apple Inc. #include -#include +#include namespace at { namespace native { namespace mps { -uint64_t MPSGeneratorImpl::seed() { - auto random = c10::detail::getNonDeterministicRandom(true); - this->set_current_seed(random); - return random; -} -uint64_t MPSGeneratorImpl::current_seed() const { - return seed_; -} - -void MPSGeneratorImpl::set_current_seed(uint64_t seed) { - seed_ = seed; -} - -MPSGeneratorImpl::MPSGeneratorImpl(DeviceIndex device_index) - : c10::GeneratorImpl{Device(DeviceType::MPS, device_index), - DispatchKeySet(c10::DispatchKey::MPS)} { -} - -const Generator& getDefaultMPSGenerator() { - static auto gen = make_generator(0); - gen.seed(); - return gen; -} -DeviceType MPSGeneratorImpl::device_type() { - return DeviceType::MPS; -} -c10::intrusive_ptr MPSGeneratorImpl::get_state() const { - static const size_t seed_size = sizeof(uint64_t); - static const size_t offset_size = sizeof(int64_t); - static const size_t total_size = seed_size + offset_size; - - auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); - - return state_tensor.getIntrusivePtr(); -} - -void MPSGeneratorImpl::set_state(const c10::TensorImpl& new_state) { - static const size_t seed_size = sizeof(uint64_t); - - detail::check_rng_state(new_state); - - uint64_t input_seed; - auto new_rng_state = new_state.data(); - memcpy(&input_seed, new_rng_state, seed_size); - this->set_current_seed(input_seed); -} - -MPSGeneratorImpl* MPSGeneratorImpl::clone_impl() const { - auto gen = new MPSGeneratorImpl(0); - gen->set_current_seed(this->seed_); - return gen; -} - void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) { mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE); } @@ -116,30 +63,80 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { } } -std::string getMPSTypeString(ScalarType scalar_type) { +// use short_name to avoid getting extra long cached graph keys with ops such as cat_out(), etc. +std::string getMPSTypeString(ScalarType scalar_type, bool short_name) { switch (scalar_type) { case ScalarType::Double: case ScalarType::Float: - return "Float32"; + return short_name ? "f32" : "Float32"; + case ScalarType::Half: + return short_name ? "f16" : "Float16"; + case ScalarType::Int: + return short_name ? "i32" : "Int32"; + case ScalarType::Long: + return short_name ? "i64" : "Int64"; + case ScalarType::Short: + return short_name ? "i16" : "Int16"; + case ScalarType::Char: + return short_name ? "i8" : "Int8"; + case ScalarType::Byte: + return short_name ? "u8" : "UInt8"; + case ScalarType::Bool: + return short_name ? "b8" : "Bool"; + default: + return "Undefined"; + } +} + +std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) { + switch (scalar_type) { + case ScalarType::Float: + return "float"; case ScalarType::Half: - return "Float16"; + return "half"; case ScalarType::Int: - return "Int32"; + return "int"; case ScalarType::Long: - return "Int64"; + return "long"; case ScalarType::Short: - return "Int16"; + return "short"; case ScalarType::Char: - return "Int8"; + return "char"; case ScalarType::Byte: - return "UInt8"; + return "uchar"; case ScalarType::Bool: - return "Bool"; + return "bool"; default: + TORCH_CHECK(false, "Undefined type ", scalar_type); return "Undefined"; } } + +NSArray* getTensorAxes(const Tensor& t) { + int64_t ndim = t.dim(); + auto axes = [NSMutableArray arrayWithCapacity:ndim]; + for (const auto i: c10::irange(ndim)) { + axes[i] = [NSNumber numberWithInteger:i]; + } + return axes; +} + +NSArray* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) { + if (dim.has_value() && dim.value().size() != 0) { + IntArrayRef dimValues = dim.value(); + int ndim = dimValues.size(); + auto axes = [NSMutableArray arrayWithCapacity:ndim]; + for (const auto i: c10::irange(ndim)) { + axes[i] = [NSNumber numberWithInteger:dimValues[i]]; + } + + return axes; + } + + return getTensorAxes(t); +} + std::string getMPSShapeString(MPSShape* shape) { std::string str; for(NSNumber *elem in shape) { @@ -154,16 +151,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return ss.str(); } -std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value) { +std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) { std::string str; - // The key format per tensor would look like ":MPSDataTypeFloat32[1,1,1,10]:" + // The key format per tensor would look like ":Float32[1,1,1,10]:" for (const Tensor& tensor: tensors) { str += ":"; if (tensor.defined()) { - str += getMPSTypeString(tensor.scalar_type()) + "["; + str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "["; // if tensor is a scalar if (tensor.dim() == 0) { - str += (use_scalar_value ? std::to_string(tensor.item().to()) : "Scalar"); + str += "Scalar"; } else { const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","]; str += std::string(ns_shape_key.UTF8String); @@ -176,16 +173,15 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return str; } -MPSShape* getMPSShape(const Tensor& t) { - return getMPSShape(t.sizes()); +MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) { + return getMPSShape(t.sizes(), memory_format); } -MPSShape* getMPSShape(c10::MaybeOwned t) { - const Tensor& t_ = *t; - return getMPSShape(t_); -} - -MPSShape* getMPSShape(IntArrayRef sizes) { +MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format) { + if (memory_format == MemoryFormat::ChannelsLast) { + TORCH_INTERNAL_ASSERT(sizes.size() == 4, "ChannelsLast memory format must have 4 dimensions!"); + return @[@(sizes[0]), @(sizes[2]), @(sizes[3]), @(sizes[1])]; + } const int sz = sizes.size(); const int sz_ = (sz > 0) ? sz : 1; @@ -219,13 +215,25 @@ void printTensorNDArray(const Tensor& t) { C10_CLANG_DIAGNOSTIC_POP() } -Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape) : _tensor(src) +MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType) +{ + id buffer = getMTLBufferStorage(tensor); + MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer + shape:shape + dataType:mpsType] autorelease]; + + return [tmpGraphTensorData mpsndarray]; +} + +Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape, + bool gatherTensorData, MPSDataType dataType) : _tensor(src) { TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!"); // extract the pointer to MTLBuffer from the Tensor's storage id srcBuf = getMTLBufferStorage(src); + bool sliceViewTensor = canSliceViewTensor(src, mpsShape); // a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose()) - if (src.is_view() || !src.is_contiguous()) { + if ((!src.is_contiguous() || (src.is_view() && src.storage_offset() && !sliceViewTensor)) && gatherTensorData) { Tensor emptyShell = Tensor(); // use "_tensor" from Placeholder to retain view's output during its usage in other ops _tensor = gatherViewTensor(src, emptyShell); @@ -236,18 +244,26 @@ void printTensorNDArray(const Tensor& t) { } srcBuf = getMTLBufferStorage(_tensor); } + // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero. // if buffer size is zero in here, it's not a user error. It could be a missing check for // tensor.numel() == 0 in our internal implementations of ops. TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); + const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType : + _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type()); - const MPSDataType mpsDataType = _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type()); - if (!mpsShape) - mpsShape = getMPSShape(_tensor); + if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) { + _value = getMPSGraphTensorDataForView(src, mpsShape, mpsDataType); + } else { + if (!mpsShape) { + mpsShape = getMPSShape(_tensor); + } + + _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf + shape:mpsShape + dataType:mpsDataType] autorelease]; + } - _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf - shape:mpsShape - dataType:mpsDataType] autorelease]; TORCH_INTERNAL_ASSERT(_value); _placeholder = mpsGraphTensor; } @@ -297,7 +313,7 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { MPSGraphTensorData *result = nullptr; // Scalar pools are only supported on devices with unified memory if (mpsStream->device().hasUnifiedMemory) { - scalar.buffer = at::mps::allocate_scalar_buffer(&scalar.value, scalar.size); + scalar.buffer = getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size); result = [[[MPSGraphTensorData alloc] initWithMTLBuffer: scalar.getMTLBuffer() shape: @[@1] dataType: getMPSScalarType(scalar.type)] autorelease]; @@ -316,7 +332,6 @@ void resize_tensor(Tensor* output) { MPSGraph* make_mps_graph() { MPSGraph* mpsGraph = [[MPSGraph new] autorelease]; - mpsGraph.options = MPSGraphOptionsNone; return mpsGraph; } @@ -338,6 +353,12 @@ void resize_tensor(Tensor* output) { name:nil]; } +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) { + return [mpsGraph placeholderWithShape:@[@1] + dataType:dataType + name:nil]; +} + MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) { return [mpsGraph placeholderWithShape:@[@1] dataType:getMPSScalarType(scalar.type()) @@ -350,6 +371,12 @@ void resize_tensor(Tensor* output) { return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"]; } +MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor) { + TORCH_INTERNAL_ASSERT(tensor.shape.count == 4, "Tensor must have 4 dimensions!"); + return [mpsGraph transposeTensor:[mpsGraph transposeTensor:tensor dimension:3 withDimension:2 name:nil] + dimension:2 withDimension:1 name: nil]; +} + string get_mem_format_string(c10::MemoryFormat memory_format) { string mem_format_key; switch(memory_format) { diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index fca3f3f81b33b..69be087ee2aaf 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -410,65 +410,6 @@ Tensor relu_mps(const Tensor& self) { } -TORCH_IMPL_FUNC(sigmoid_out_mps)( - const Tensor& self, - const Tensor& output) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - TORCH_CHECK(output.is_mps()); - - if(output.numel() == 0) { - return; - } - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - string key = "sigmoid_out_mps" + getTensorsStringKey({self}); - CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - // Initialize graph - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* outputTensor = [mpsGraph sigmoidWithTensor:inputTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - // Create dictionary of inputs and outputs - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - - } - -} - TORCH_IMPL_FUNC(sigmoid_backward_out_mps)( const Tensor& grad_output, const Tensor& output, @@ -803,6 +744,51 @@ Tensor relu_mps(const Tensor& self) { return erfTensor; } +MPSGraphTensor* tanh (MPSGraph* mpsGraph, MPSGraphTensor *inputTensor) { + // 0.5 * x * (1 + text{Tanh}(sqrt(2 / pi) * (x + 0.044715 * x^3))) + auto dataType = [inputTensor dataType]; + const float SQRT2_PI = 0.797884523868560791015625f; + const float VAL = 0.044715f; + MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f + shape: @[@1] + dataType: dataType]; + MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f + shape: @[@1] + dataType: dataType]; + MPSGraphTensor *sqrt2_pi = [mpsGraph constantWithScalar: SQRT2_PI + shape: @[@1] + dataType: dataType]; + MPSGraphTensor *valf = [mpsGraph constantWithScalar: VAL + shape: @[@1] + dataType: dataType]; + + MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor + secondaryTensor: inputTensor + name : nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor + secondaryTensor: inputTensor + name : nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor + secondaryTensor: valf + name : nil]; + erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor + secondaryTensor: inputTensor + name : nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor + secondaryTensor: sqrt2_pi + name : nil]; + erfTensor = [mpsGraph tanhWithTensor: erfTensor + name : nil]; + erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor + secondaryTensor: onef + name : nil]; + erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor + secondaryTensor: halff + name : nil]; + + return erfTensor; +} + TORCH_IMPL_FUNC(gelu_out_mps) ( const Tensor& self, c10::string_view approximate, const Tensor& output ) { @@ -826,7 +812,7 @@ Tensor relu_mps(const Tensor& self) { MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "gelu_out_mps" + getTensorsStringKey({self}); + string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -841,7 +827,12 @@ Tensor relu_mps(const Tensor& self) { getMPSDataType(self.scalar_type()), getMPSShape(self)); - MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor); + MPSGraphTensor* outputTensor = nil; + if(approximate == "tanh") { + outputTensor = tanh(mpsGraph, inputTensor); + } else { + outputTensor = normcdf(mpsGraph, inputTensor); + } outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor secondaryTensor:inputTensor name:nil]; @@ -1464,16 +1455,18 @@ Tensor glu_backward_mps (const Tensor& grad_output, CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor_ = nil; MPSGraphTensor *betaTensor_ = nil; + MPSGraphTensor *thresholdTensor_ = nil; MPSGraphTensor *outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);; + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); @autoreleasepool { - string key = "softplus_out_mps:" + getTensorsStringKey({self}); + string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to()) + ":" + std::to_string(threshold.to()); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -1486,7 +1479,9 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, beta); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); + + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; @@ -1499,9 +1494,6 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:betaTensor name:nil]; - MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() - shape:@[@1] - dataType:getMPSDataType(self.scalar_type())]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor secondaryTensor:thresholdTensor name:nil]; @@ -1524,6 +1516,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -1536,7 +1529,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, // Create dictionary of inputs and outputs NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar) + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), }; NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() @@ -1559,7 +1553,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, if(grad_input.numel() == 0) return; - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);; + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); struct CachedGraph : public MPSCachedGraph { @@ -1567,6 +1562,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor *gradOutputTensor_ = nil; MPSGraphTensor *inputTensor_ = nil; MPSGraphTensor *betaTensor_ = nil; + MPSGraphTensor *thresholdTensor_ = nil; MPSGraphTensor *outputTensor_ = nil; }; @@ -1575,7 +1571,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSStream* stream = getCurrentMPSStream(); @autoreleasepool { - string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}); + string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(beta.to()) + ":" + std::to_string(threshold.to()); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -1590,7 +1586,9 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, beta); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); + + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[@1] @@ -1609,9 +1607,6 @@ Tensor glu_backward_mps (const Tensor& grad_output, rTensor = [mpsGraph divisionWithPrimaryTensor:rTensor secondaryTensor:unitExpBxTensor name:nil]; - MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() - shape:@[@1] - dataType:getMPSDataType(self.scalar_type())]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor secondaryTensor:thresholdTensor name:nil]; @@ -1623,6 +1618,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -1637,7 +1633,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar) + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), }; NSDictionary* results = @{ gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() @@ -2196,5 +2193,245 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { return grad_input; } +Tensor& hardswish_out_mps(const Tensor& self, Tensor& output) { + using namespace mps; + using CachedGraph = MPSUnaryCachedGraph; + + TORCH_CHECK(self.is_mps()); + + if (output.numel() == 0) { + return output; + } + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + MPSStream* stream = at::mps::getCurrentMPSStream(); + + @autoreleasepool { + string key = "hardswish_out_mps" + getTensorsStringKey({self}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + MPSCachedGraph* tmpCachedGraph = + cache_->CreateCachedGraph(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor = + mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor* zeroTensor = [mpsGraph + constantWithScalar:0.0f + shape:@[ @1 ] + dataType:getMPSDataType(self.scalar_type())]; + + MPSGraphTensor* threeTensor = [mpsGraph + constantWithScalar:3.0f + shape:@[ @1 ] + dataType:getMPSDataType(self.scalar_type())]; + + MPSGraphTensor* negativeThreeTensor = [mpsGraph + constantWithScalar:-3.0f + shape:@[ @1 ] + dataType:getMPSDataType(self.scalar_type())]; + + MPSGraphTensor* sixTensor = [mpsGraph + constantWithScalar:6.0f + shape:@[ @1 ] + dataType:getMPSDataType(self.scalar_type())]; + + MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph + lessThanOrEqualToWithPrimaryTensor:inputTensor + secondaryTensor:negativeThreeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxPredicateTensor = + [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* inputPlusThreeTensor = + [mpsGraph additionWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* inputDivSixTensor = + [mpsGraph divisionWithPrimaryTensor:inputPlusThreeTensor + secondaryTensor:sixTensor + name:nil]; + + MPSGraphTensor* weightedTensor = + [mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:inputDivSixTensor + name:nil]; + + MPSGraphTensor* tempTensor = + [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor + truePredicateTensor:weightedTensor + falsePredicateTensor:inputTensor + name:nil]; + + MPSGraphTensor* outputTensor = + [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:tempTensor + name:nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, output); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : + selfPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : + outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return output; +} + +Tensor hardswish_mps(const Tensor& self) { + using namespace mps; + Tensor output = at::empty_like(self, self.suggest_memory_format()); + + return hardswish_out_mps(self, output); +} + +Tensor& hardswish_mps_(Tensor& self) { + using namespace mps; + Tensor& output = self; + + return hardswish_out_mps(self, output); +} + +Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) { + using namespace mps; + + Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); + if (grad_input.numel() == 0) { + return grad_input; + } + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* gradOutputTensor_ = nil; + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gradInputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = "hardswish_backward_mps" + getTensorsStringKey({self}); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^MPSCachedGraph*() { + CachedGraph* newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + + MPSGraphTensor* zeroTensor = [mpsGraph + constantWithScalar:0.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* unitTensor = [mpsGraph + constantWithScalar:1.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* threeTensor = [mpsGraph + constantWithScalar:3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* negativeThreeTensor = [mpsGraph + constantWithScalar:-3.0f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* halfTensor = [mpsGraph + constantWithScalar:0.5f + shape:@[ @1 ] + dataType:getMPSDataType(grad_output.scalar_type())]; + + MPSGraphTensor* tempTensor = + [mpsGraph divisionWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* weightedTensor = + [mpsGraph additionWithPrimaryTensor:tempTensor + secondaryTensor:halfTensor + name:nil]; + + MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph + lessThanOrEqualToWithPrimaryTensor:inputTensor + secondaryTensor:negativeThreeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxPredicateTensor = + [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:threeTensor + name:nil]; + + MPSGraphTensor* lessThanMaxGradTensor = + [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor + truePredicateTensor:weightedTensor + falsePredicateTensor:unitTensor + name:nil]; + + MPSGraphTensor* gradTensor = + [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor + truePredicateTensor:zeroTensor + falsePredicateTensor:lessThanMaxGradTensor + name:nil]; + MPSGraphTensor* gradInputTensor = + [mpsGraph multiplicationWithPrimaryTensor:gradTensor + secondaryTensor:gradOutputTensor + name:nil]; + + newCachedGraph->gradOutputTensor_ = gradOutputTensor; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->gradInputTensor_ = gradInputTensor; + } + return newCachedGraph; + }); + } + + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + return grad_input; +} } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index a246bb0c50f07..83f2535d0188e 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -27,6 +27,10 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output_, std::string op_name, BinaryOpBlock binaryBlock) { + TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(2) && + (self.scalar_type() == ScalarType::Long || + (other.scalar_type() == ScalarType::Long && (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))), + "MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.2"); MPSStream* mpsStream = getCurrentMPSStream(); const bool is_self_scalar = self.dim() == 0; @@ -54,9 +58,26 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha needsCopyToOutput = true; } + auto inputDataType = self.scalar_type(); + auto otherDataType = other.scalar_type(); + auto outputDataType = output_.scalar_type(); + if (!is_macos_13_or_newer()) { + // workaround for signed vs. unsigned comparison issue in MacOS 12 + if (outputDataType == kBool && (inputDataType == kByte || otherDataType == kByte)) { + inputDataType = otherDataType = kByte; + } else { + if (inputDataType == kBool || inputDataType == kByte) { + inputDataType = kChar; + } + if (otherDataType == kBool || otherDataType == kByte) { + otherDataType = kChar; + } + } + } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({self, other, output_}, /*use_scalar_value*/ false); + string key = op_name + getTensorsStringKey({self, other, output_}); BinaryOpCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -65,46 +86,37 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new BinaryOpCachedGraph(mpsGraph); - newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); + newCachedGraph->primaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self)); + newCachedGraph->secondaryTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(otherDataType), getMPSShape(other)); MPSGraphTensor* primaryCastTensor = newCachedGraph->primaryTensor; MPSGraphTensor* secondaryCastTensor = newCachedGraph->secondaryTensor; // this type inference is only required at the time of graph creation - const ScalarType common_dtype = c10::promoteTypes(self.scalar_type(), other.scalar_type()); - - // Condition - - // 1. Division operation - // 2. Inputs are not float - bool div_condition = op_name.rfind("div", 0) == 0 - && (!(common_dtype == ScalarType::Float || common_dtype == ScalarType::Half)); - - auto compute_type = ScalarType::Float; - - if(div_condition) { - - if(output_.scalar_type() == ScalarType::Float || output_.scalar_type() == ScalarType::Half) - compute_type = output_.scalar_type(); - - primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, compute_type); - secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, compute_type); - } - else { - if (self.scalar_type() != common_dtype) { - primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype); - } - if (other.scalar_type() != common_dtype) { - secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); + ScalarType common_dtype = c10::promoteTypes(inputDataType, otherDataType); + if (isIntegralType(common_dtype, true)) { + // integer inputs must be cast to float, if output is float + if (isFloatingType(outputDataType)) { + common_dtype = outputDataType; + // in boolean comparison ops with signed vs. unsigned integers, we always cast to the unsigned type + } else if (outputDataType == ScalarType::Bool && + (inputDataType == ScalarType::Byte || + otherDataType == ScalarType::Byte)) { + common_dtype = ScalarType::Byte; } } + if (inputDataType != common_dtype) { + primaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->primaryTensor, common_dtype); + } + if (otherDataType != common_dtype) { + secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype); + } newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor); // Cast output tensor to an expected type if needed, which addresses discrepancy when int64 scalar is added to int32 tensor // Output tensor should have been promoted but it remains an int32 tensor - - if ((div_condition && compute_type != output_.scalar_type()) || - output_.scalar_type() != common_dtype) { - newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output_.scalar_type()); + if (outputDataType != common_dtype || + [newCachedGraph->outputTensor dataType] != getMPSDataType(outputDataType)) { + newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, outputDataType); } } return newCachedGraph; @@ -120,17 +132,19 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha MPSScalar alpha_scalar; if (is_self_scalar && !self.is_mps()) { - self_scalar = getMPSScalar(self.item(), self.scalar_type()); + self_scalar = getMPSScalar(self.item(), inputDataType); feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self_scalar); } else { - selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self); + selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, /*mpsShape*/nil, + /*gatherTensorData=*/true, getMPSScalarType(inputDataType)); feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData(); } if (is_other_scalar && !other.is_mps()) { - other_scalar = getMPSScalar(other.item(), other.scalar_type()); + other_scalar = getMPSScalar(other.item(), otherDataType); feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other_scalar); } else { - otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other); + otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, /*mpsShape*/nil, + /*gatherTensorData=*/true, getMPSScalarType(otherDataType)); feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData(); } @@ -162,8 +176,21 @@ void div_mode_template(const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& output, const string op_name) { + if(rounding_mode.has_value() && *rounding_mode == "floor"){ + TORCH_CHECK(self.scalar_type() != ScalarType::Long, + "MPS: does not support floor_divide op with int64 input"); + } BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); + bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0; + if(!isFloatInput && rounding_mode.has_value() && *rounding_mode == "floor") { + primaryCastTensor = [mpsGraph castTensor:primaryCastTensor + toType:MPSDataTypeFloat32 + name:@"primaryCastTensor"]; + secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor + toType:MPSDataTypeFloat32 + name:@"secondaryCastTensor"]; + } MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor secondaryTensor:secondaryCastTensor name:nil]; @@ -181,7 +208,7 @@ void div_mode_template(const Tensor& self, const Tensor& other, assert(0 && "Invalid rounding mode\n"); return nullptr; }; - binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_out_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block); + binaryOpTensor(self, other, Scalar(1.0), output, op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""), div_mode_op_block); } void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output, std::string op_name) @@ -237,7 +264,7 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp #define CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(func_out, func_stub, other_type) \ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ TORCH_CHECK(!(self.scalar_type() == ScalarType::Long && \ - (std::string(#func_stub) == "power" || std::string(#func_stub) == "atan2")), \ + std::string(#func_stub) == "atan2"), \ "MPS does not support ", #func_stub, " op with int64 input") \ mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ @@ -247,16 +274,15 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp name:nil]; }); \ } -// Boolean Ops require casting output to "MPSDataTypeBool" +// output of Boolean Ops will be cast to "MPSDataTypeBool" at the end of binaryOpTensor() #define CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(func_out, func_stub, other_type) \ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const other_type& other, const Tensor& output) { \ mps::binaryOp##other_type(self, other, Scalar(1.0), output, #func_stub, \ ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { \ MPSGraph* mpsGraph = cachedGraph->graph(); \ - MPSGraphTensor* outputTensor = [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - return mps::castMPSTensor(mpsGraph, outputTensor, ScalarType::Bool); }); \ + return [mpsGraph func_stub##WithPrimaryTensor:primaryCastTensor \ + secondaryTensor:secondaryCastTensor \ + name:nil]; }); \ } // Boolean Binary Ops @@ -287,11 +313,11 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp TORCH_IMPL_FUNC(div_out_mode_mps) (const Tensor& self, const Tensor& other, c10::optional rounding_mode, const Tensor& output) { - mps::div_mode_template(self, other, rounding_mode, output, "div_mode"); + mps::div_mode_template(self, other, rounding_mode, output, "div_mode_out"); } TORCH_IMPL_FUNC(div_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { - mps::div_mode_template(self, other, c10::nullopt, output, "div"); + mps::div_mode_template(self, other, c10::nullopt, output, "div_out"); } TORCH_IMPL_FUNC(add_out_mps) (const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) { @@ -302,142 +328,69 @@ void add_sub_template(const Tensor& self, const Tensor& other, const Scalar& alp mps::add_sub_template(self, other, alpha, output, "sub"); } +Tensor& floor_divide_out_mps(const Tensor& self, const Tensor& other, Tensor& result) { + mps::div_mode_template(self, other, "floor", result, "floor_divide_out"); + return result; +} -TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) -{ - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - - if (&output != &self) { - output.resize_(self.sizes());; - } +Tensor floor_divide_mps(const Tensor& self, const Tensor& other) { + Tensor output = at::empty_like(self); + mps::div_mode_template(self, other, "floor", output, "floor_divide"); + return output; +} - // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = "log_base_e_out_mps:" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor* ePowXTensor = [mpsGraph exponentWithTensor:xTensor - name:nil]; - MPSGraphTensor* ePowYTensor = [mpsGraph exponentWithTensor:yTensor - name:nil]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:ePowXTensor - secondaryTensor:ePowYTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph logarithmWithTensor:sumTensor - name:nil]; - - newCachedGraph->inputTensor_ = xTensor; - newCachedGraph->otherTensor_ = yTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } +Tensor& floor_divide_mps_(Tensor& self, const Tensor& other) { + return floor_divide_out_mps(self, other, self); +} - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); +TORCH_IMPL_FUNC(remainder_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { + // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + mps::BinaryOpBlock remainder_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + // Rounding is a no-op for integral types, and also a reasonable workaround + // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library` + // See https://github.com/pytorch/pytorch/issues/84995 - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + auto divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor + secondaryTensor:secondaryCastTensor + name:nil]; + bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0; + if (isFloatOutput) { + divTensor = [mpsGraph floorWithTensor:divTensor name:nil]; + } - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:divTensor + secondaryTensor:secondaryCastTensor + name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor + secondaryTensor:mulTensor + name: nil]; + }; + mps::binaryOpTensor(self, other, Scalar(1.0), output, "remainder_out_mps", remainder_op_block); +} - } +TORCH_IMPL_FUNC(logaddexp_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) +{ + mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil] + name:nil]; + return [mpsGraph logarithmWithTensor:sumTensor name:nil]; + }; + mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp_out_mps", logaddexp_op_block); +} TORCH_IMPL_FUNC(logaddexp2_out_mps) (const Tensor& self, const Tensor& other, const Tensor& output) { - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - - if (&output != &self) { - output.resize_(self.sizes());; - } - - // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *otherTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - string key = "log_base_two_out_mps:" + getTensorsStringKey({self, other}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* xTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* yTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); - MPSGraphTensor* twoPowXTensor = [mpsGraph exponentBase2WithTensor:xTensor - name:nil]; - MPSGraphTensor* twoPowYTensor = [mpsGraph exponentBase2WithTensor:yTensor - name:nil]; - MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:twoPowXTensor - secondaryTensor:twoPowYTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph logarithmBase2WithTensor:sumTensor - name:nil]; - - newCachedGraph->inputTensor_ = xTensor; - newCachedGraph->otherTensor_ = yTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor* sumTensor = [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil] + secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil] + name:nil]; + return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil]; + }; + mps::binaryOpTensor(self, other, Scalar(1.0), output, "logaddexp2_out_mps", logaddexp2_op_block); } } // namespace native diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 20a3ec5eb6db4..31b0592620018 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -21,6 +21,9 @@ Tensor dot_mps( const Tensor &self, const Tensor &other) { + + TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS: dot op doesn't support int64 input") + using namespace mps; auto output = at::native::empty_mps({}, self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 88bad9a5872a4..20432d2933e65 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -41,7 +41,7 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_, static MPSShape* get_mps_conv_shape(const Tensor& tensor, bool is_channels_last) { - if (is_channels_last) { + if (is_channels_last && tensor.is_contiguous() && !tensor.is_view()) { const auto tensorSizes = tensor.sizes(); const NSUInteger N = tensorSizes[0]; const NSUInteger C = tensorSizes[1]; @@ -226,6 +226,8 @@ Tensor mps_convolution_backward_input( checkAllSameType(c, {grad_output, weight}); checkAllSameGPU(c, {grad_output, weight}); auto memory_format = grad_output_t.suggest_memory_format(); + bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast); + auto grad_input_t = at::empty( input_size, grad_output->scalar_type(), @@ -266,8 +268,9 @@ Tensor mps_convolution_backward_input( assert(0 && "Check should have been done earlier\n"); } + MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_t, is_channels_last); MPSShape* mps_input_shape = getMPSShape(input_size); - NSString* ns_shape_key = [[mps_input_shape valueForKey:@"description"] componentsJoinedByString:@","]; + NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" @@ -289,12 +292,21 @@ Tensor mps_convolution_backward_input( fill_conv_desc(descriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0], - memory_format, groups); + at::MemoryFormat::Contiguous, groups); - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t); + MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t); - MPSGraphTensor* gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensor + MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor; + if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) { + // NHWC -> NCHW + gradOutputTensorTranspose = [mpsGraph transposeTensor: [mpsGraph transposeTensor:gradOutputTensor dimension:-1 withDimension:-2 name:nil] + dimension: -2 + withDimension: -3 + name: nil]; + } + + MPSGraphTensor* gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose weightsTensor:weightTensor outputShape:mps_input_shape forwardConvolutionDescriptor:descriptor_ @@ -309,7 +321,7 @@ Tensor mps_convolution_backward_input( cachedGraph = static_cast(tmpCachedGraph); } - auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t); + auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape); auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t); auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input); @@ -328,14 +340,17 @@ Tensor mps_convolution_backward_input( } Tensor mps_convolution_backward_weights( - IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t, + IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { namespace native_mps = at::native::mps; using namespace mps; CheckedFrom c = "mps_convolution_backward_weights"; - auto memory_format = input_t.suggest_memory_format(); + auto memory_format = input_.suggest_memory_format(); bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast); - MPSShape* inputShape = get_mps_conv_shape(input_t, is_channels_last); + + auto grad_output_t = grad_output_.to(memory_format); + auto input_t = input_.to(memory_format); + MPSShape* gradOutputShape = get_mps_conv_shape(grad_output_t, is_channels_last); // For uniformity with everything else, although it seems grad_weight @@ -346,7 +361,13 @@ Tensor mps_convolution_backward_weights( checkAllSameType(c, {grad_output, input}); checkAllSameGPU(c, {grad_output, input}); - auto grad_weight_t = at::empty(weight_size, grad_output_t.options(), c10::nullopt); + auto grad_weight_t = at::empty( + weight_size, + grad_output_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + memory_format); TensorArg grad_weight{ grad_weight_t, "result", 0 }; convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups); @@ -377,9 +398,8 @@ Tensor mps_convolution_backward_weights( default: assert(0 && "Check should have been done earlier\n"); } - MPSShape* mps_weight_shape = getMPSShape(weight_size); - NSString* ns_shape_key = [[mps_weight_shape valueForKey:@"description"] componentsJoinedByString:@","]; + NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; string key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" @@ -401,12 +421,20 @@ Tensor mps_convolution_backward_weights( fill_conv_desc(descriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0], - memory_format, groups); + at::MemoryFormat::Contiguous, groups); MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape); - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape); + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensor + MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor; + if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) { + // NHWC -> NCHW + gradOutputTensorTranspose = [mpsGraph transposeTensor: [mpsGraph transposeTensor:gradOutputTensor dimension:-1 withDimension:-2 name:nil] + dimension: -2 + withDimension: -3 + name: nil]; + } + MPSGraphTensor* gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose sourceTensor:inputTensor outputShape:mps_weight_shape forwardConvolutionDescriptor:descriptor_ @@ -422,7 +450,7 @@ Tensor mps_convolution_backward_weights( } auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, inputShape); + auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t); NSDictionary *feeds = @{ @@ -481,6 +509,7 @@ Tensor _mps_convolution_transpose( const Tensor& input_t, const Tensor& weight_t, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) { + TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS"); auto output_t = mps_convolution_transpose_forward( input_t, weight_t, padding, output_padding, stride, dilation, groups); diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 2bfee3f9a393e..9eae9d409e41e 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -1,17 +1,7 @@ // Copyright © 2022 Apple Inc. -#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include - namespace at { namespace native { @@ -66,7 +56,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, src); - MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputTensor toType:dstDType name:@"cast"]; + MPSGraphTensor* inputCastTensor = inputTensor; + if (isFloatingType(src.scalar_type()) && dstDType == MPSDataTypeUInt8) { + inputCastTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"cast"]; + } + MPSGraphTensor* outputTensor = [mpsGraph castTensor:inputCastTensor toType:dstDType name:@"cast"]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -115,8 +109,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, src = src_; } id sourceBuffer = getMTLBufferStorage(src); - size_t dst_tensor_nbytes = dst.nbytes(); - + size_t dst_tensor_nbytes = dst.is_view() ? at::detail::computeStorageNbytesContiguous(dst.sizes(), dst.element_size(), dst.storage_offset()) : + dst.nbytes(); @autoreleasepool { MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared; NSUInteger alignedLength = 0; @@ -186,7 +180,6 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, // For View tensors, the storage offset can be bigger than what's being reported by nbytes src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()); } else { - TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides()); src = src_; if (src.dtype() != dst_.dtype()) { // In case of dtype change, perform conversion on source device @@ -205,14 +198,30 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, NSUInteger alignedLength = 0; void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength); - id sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr - length:alignedLength - options:options - deallocator:nil]; sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr); sourceOffset += src_.storage_offset() * src_.itemsize(); - stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking); + id sourceBuffer = nil; + // If the destination is a strided MPS tensor, we cannot perform a blit directly to copy the + // memory from the CPU tensor into the MPS tensor. We need to scatter the data into the right indices + bool doScatter = (!dst_.is_contiguous() && src.is_contiguous()); + if (doScatter) { + sourceBuffer = [device newBufferWithBytes:(void*)((uint8_t*)host_src + (src_.storage_offset() * src_.itemsize())) + length:size_to_copy + options:options]; + } + else { + sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr + length:alignedLength + options:options + deallocator:nil]; + } + + if (doScatter) { + scatterViewTensor(src, dst_, sourceBuffer); + } else { + stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking); + } [sourceBuffer release]; } @@ -231,10 +240,10 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // If dst is contiguous and there is no byte offset, we can save directly the result of // gather into dst. This reduces the overhead of doing an additional blit for most cases - bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset); + bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset && src_.dtype() == dst_.dtype()); Tensor src; - if (!src_.is_contiguous()) { + if (src_.is_view() || !src_.is_contiguous()) { Tensor emptyShell = Tensor(); src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell); @@ -257,18 +266,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { // If the memory is not contiguous, it means that the tensor has strides and we would not be // able to do the copy using a single blit if (!dst_.is_contiguous()) { - Tensor tmp; - if (src.dtype() != dst_.dtype()) { - id tmpBuffer = sourceBuffer; - if (src.element_size() < dst_.element_size()) { - tmp = at::native::empty_mps(dst_.sizes(), dst_.scalar_type(), c10::nullopt, kMPS); - tmpBuffer = getMTLBufferStorage(tmp); - } - - copy_cast_mps(dst_, src, tmpBuffer, sourceBuffer); - } - - return scatterViewTensor((src.dtype() != dst_.dtype() && tmp.has_storage()) ? tmp : src, dst_); + return scatterViewTensor(src, dst_); } src._set_conj(src_.is_conj()); src._set_neg(src_.is_neg()); diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm new file mode 100644 index 0000000000000..22b715f5f9153 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm @@ -0,0 +1,207 @@ +// Copyright © 2022 Apple Inc. + +#include +#include + +namespace at { +namespace native { + +static const char* METAL_CROSS = R"CROSS_METAL( + +#include +using namespace metal; + +#define REGISTER_CROSS_FUNC(DTYPE) \ +static inline DTYPE ## 3 cross(DTYPE ## 3 x, DTYPE ## 3 y) { \ + DTYPE ## 3 out; \ + out.x = x.y * y.z - x.z * y.y; \ + out.y = x.z * y.x - x.x * y.z; \ + out.z = x.x * y.y - x.y * y.x; \ + return out; \ +} + +// Metal only supports half and float for native cross implementation. +// For all the the other data types, implement cross manually. +REGISTER_CROSS_FUNC(int); +REGISTER_CROSS_FUNC(long); +REGISTER_CROSS_FUNC(short); +REGISTER_CROSS_FUNC(char); +REGISTER_CROSS_FUNC(uchar); +REGISTER_CROSS_FUNC(bool); + +template +kernel void cross(constant void * input_ [[buffer(0)]], + constant void * other_ [[buffer(1)]], + device void * out_ [[buffer(2)]], + constant uint3 * offsets [[buffer(3)]], + constant int64_t & outStride [[buffer(4)]], + constant int64_t & inputStride [[buffer(5)]], + constant int64_t & otherStride [[buffer(6)]], + uint tid [[thread_position_in_grid]]) { + device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x); + constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y); + constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z); + + const U x = {input[0 * inputStride], input[1 * inputStride], input[2 * inputStride]}; + const U y = {other[0 * otherStride], other[1 * otherStride], other[2 * otherStride]}; + const U res = cross(x, y); + + out[0 * outStride] = res.x; + out[1 * outStride] = res.y; + out[2 * outStride] = res.z; +} + +#define REGISTER_CROSS_OP(DTYPE) \ +template \ +[[host_name("cross_" #DTYPE)]] \ +kernel void cross( \ + constant void * input_ [[buffer(0)]], \ + constant void * other_ [[buffer(1)]], \ + device void * out_ [[buffer(2)]], \ + constant uint3 * offsets [[buffer(3)]], \ + constant int64_t & outStride [[buffer(4)]], \ + constant int64_t & inputStride [[buffer(5)]], \ + constant int64_t & otherStride [[buffer(6)]], \ + uint tid [[thread_position_in_grid]]); + +REGISTER_CROSS_OP(float); +REGISTER_CROSS_OP(half); +REGISTER_CROSS_OP(int); +REGISTER_CROSS_OP(long); +REGISTER_CROSS_OP(short); +REGISTER_CROSS_OP(char); +REGISTER_CROSS_OP(uchar); +REGISTER_CROSS_OP(bool); + +)CROSS_METAL"; + +using namespace mps; + +static id compileCrossOpLibrary(id device) { + static id crossLibrary = nil; + if (crossLibrary) { + return crossLibrary; + } + + NSError *error = nil; + MTLCompileOptions *options = [[MTLCompileOptions new] autorelease]; + [options setLanguageVersion: MTLLanguageVersion2_3]; + crossLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_CROSS encoding:NSASCIIStringEncoding] + options:options + error:&error]; + TORCH_CHECK(crossLibrary, "Failed to create metal cross library, error: ", [[error description] UTF8String]); + return crossLibrary; +} + +static id crossPipelineState(id device, ScalarType scalar_type) { + std::string kernel = "cross_" + scalarToMetalTypeString(scalar_type); + static std::unordered_map> psoCache; + id pso = psoCache[kernel]; + if (pso) { + return pso; + } + + NSError* error = nil; + id crossLib = compileCrossOpLibrary(device); + id crossFunc = [crossLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]]; + TORCH_CHECK(crossFunc, "Failed to create function state object for: ", kernel); + pso = [device newComputePipelineStateWithFunction:crossFunc error:&error]; + TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + + psoCache[kernel] = pso; + return pso; +} + +void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, int64_t dim) { + TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS"); + + auto iter = TensorIteratorConfig() + .add_output(out) + .add_input(input) + .add_input(other) + .resize_outputs(false) + .declare_static_shape(out.sizes(), /*squash_dims=*/dim) + .build(); + + id inputBuffer = getMTLBufferStorage(input); + id otherBuffer = getMTLBufferStorage(other); + id outputBuffer = getMTLBufferStorage(out); + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + const int64_t out_dim_stride = out.stride(dim); + const int64_t input_dim_stride = input.stride(dim); + const int64_t other_dim_stride = other.stride(dim); + const uint32_t nDim = iter.ndim(); + constexpr uint32_t nOffsets = 3; + const uint32_t numThreads = iter.numel(); + dispatch_sync(mpsStream->queue(), ^(){ + @autoreleasepool { + NSError* error = nil; + id commandBuffer = mpsStream->commandBuffer(); + id computeEncoder = [commandBuffer computeCommandEncoder]; + MTLSize gridSize = MTLSizeMake(numThreads, 1, 1); + const IntArrayRef& iterShape = iter.shape(); + std::vector iterShapeData(iterShape.size()); + std::vector> strides(nDim); + + for (const auto i: c10::irange(iterShape.size())) { + TORCH_CHECK(i <= UINT32_MAX); + iterShapeData[i] = (uint32_t)(iterShape[i]); + } + + for (const auto i: c10::irange(nDim)) { + for (const auto offset: c10::irange(nOffsets)) { + strides[i][offset] = iter.strides(offset)[i]; + } + } + + id kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil); + id kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction + error: &error] autorelease]; + id kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3) + options: 0] autorelease]; + TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]); + [computeEncoder setComputePipelineState:kernelDataOffsetsPSO]; + [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0]; + [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1]; + [computeEncoder setBytes:iterShapeData.data() length:sizeof(uint32_t) * iterShape.size() atIndex:2]; + [computeEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3]; + [computeEncoder setBytes:&nOffsets length:sizeof(uint32_t) atIndex:4]; + + NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup; + if (kernelOffsetsTGSize > numThreads) + kernelOffsetsTGSize = numThreads; + + MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1); + [computeEncoder dispatchThreads: gridSize + threadsPerThreadgroup: kernelOffsetsThreadGroupSize]; + + id crossPSO = crossPipelineState(device, out.scalar_type()); + [computeEncoder setComputePipelineState:crossPSO]; + [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0]; + [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1]; + [computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2]; + [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3]; + [computeEncoder setBytes:&out_dim_stride length:sizeof(int64_t) atIndex:4]; + [computeEncoder setBytes:&input_dim_stride length:sizeof(int64_t) atIndex:5]; + [computeEncoder setBytes:&other_dim_stride length:sizeof(int64_t) atIndex:6]; + + NSUInteger tgSize = crossPSO.maxTotalThreadsPerThreadgroup; + if (tgSize > numThreads) { + tgSize = numThreads; + } + + MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1); + [computeEncoder dispatchThreads: gridSize + threadsPerThreadgroup: threadGroupSize]; + + [computeEncoder endEncoding]; + mpsStream->commit(true); + } + }); +} + +REGISTER_DISPATCH(cross_stub, &cross_mps_impl); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index d26b25e8c352d..1b395a3b9071d 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -3,7 +3,9 @@ #include #include #include -#include +#include +#include +#include namespace at { namespace native { @@ -11,32 +13,13 @@ struct RandomCachedGraph : public MPSCachedGraph { - RandomCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { - // initialize Philox state values (only required once when graph is created) - const auto seed = c10::detail::getNonDeterministicRandom(); - const auto subsequence = c10::detail::getNonDeterministicRandom(); - philoxState = at::Philox4_32(seed, subsequence); - // the two last state values are the Philox keys which are initialized once only - stateValues[5] = static_cast(seed); - stateValues[6] = static_cast(seed >> 32); - } + RandomCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { } // Only relevant for multinomial MPSGraphTensor *probTensor = nil; MPSGraphTensor *resultTensor = nil; MPSGraphTensor *stateTensor = nil; // used for Normal distributions only MPSGraphTensor *meanTensor = nil, *stdTensor = nil; - // we initialize and keep the philox's state in the graph. This would - // guarantee producing new random values each time the same graph is reused. - at::Philox4_32 philoxState; - std::array stateValues = {1}; - - void updatePhiloxCounters() { - // calling philoxState() would call operator() of philox_engine class to - // get each of the four newly generated counter values (see PhiloxRNGEngine.h). - for (int i = 1; i <= 4; i++) - stateValues[i] = philoxState(); - } }; typedef MPSGraphTensor* (^RandomOpBlock)(RandomCachedGraph*, MPSGraphTensor*); @@ -49,11 +32,13 @@ void updatePhiloxCounters() { const c10::optional& mean_opt, const c10::optional& std_opt, MPSGraphRandomDistribution distribution, + c10::optional gen, std::string op_name, RandomOpBlock randomBlock) { if (self.numel() == 0) { return self; } + auto mps_gen = get_generator_or_default(gen, at::mps::detail::getDefaultMPSGenerator()); MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); @@ -68,7 +53,7 @@ void updatePhiloxCounters() { @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new RandomCachedGraph(mpsGraph); - newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@7]); + newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(at::mps::detail::PHILOX_STATE_N)]); // FP16, FP32 and Int32 are the only data types supported for distributions on MPS backend. const MPSDataType inputDataType = [&] { @@ -95,7 +80,7 @@ void updatePhiloxCounters() { desc.standardDeviation = static_cast(val2); } // we don't use the output state tensor from the MPSGraph API as it requires reading back from GPU to CPU. - // Instead, we keep the Philox state in the cached graph and use the PyTorch's philox_engine to maintain + // Instead, we keep the Philox state in the MPSGenerator and use the PyTorch's philox_engine to maintain // the counters, and feed them to the graph manually NSArray *resultTensors = [mpsGraph randomTensorWithShape: getMPSShape(self) descriptor: desc @@ -109,12 +94,16 @@ void updatePhiloxCounters() { return newCachedGraph; }); } - // update the Philox state values on each run of the same graph - cachedGraph->updatePhiloxCounters(); // feed the updated state values to the graph - MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]]; + MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]]; MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease]; - [stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil]; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(mps_gen->mutex_); + // update the Philox state values on each run + mps_gen->update_philox_counters(); + [stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil]; + } MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease]; Placeholder meanPlaceholder, stdPlaceholder; @@ -146,6 +135,7 @@ void updatePhiloxCounters() { Tensor& normal_mps_impl(Tensor& self, double mean_s, double std_s, const c10::optional& mean_opt, const c10::optional& std_opt, + c10::optional gen, std::string op_name) { const Tensor& std_t = *(at::borrow_from_optional_tensor(std_opt)); @@ -177,12 +167,12 @@ void updatePhiloxCounters() { return resultTensor; }; return random_mps_impl(self, mean_s, std_s, mean_opt, std_opt, - MPSGraphRandomDistributionNormal, + MPSGraphRandomDistributionNormal, gen, op_name + getTensorsStringKey({mean_t, std_t}), random_op_block); } -Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, std::string op_name) +Tensor& bernoulli_mps_impl(Tensor& self, const Tensor& prob_t, c10::optional gen, std::string op_name) { TORCH_CHECK(prob_t.is_same_size(self), op_name, ": probability and self tensor should be of the same shape") @@ -195,7 +185,7 @@ void updatePhiloxCounters() { }; // Bernoulli generates binary output so we use bool type return mps::random_mps_impl(self, 0.0, 1.0, c10::nullopt, prob_t, - MPSGraphRandomDistributionUniform, + MPSGraphRandomDistributionUniform, gen, op_name + getTensorsStringKey({prob_t}), random_op_block); } @@ -215,16 +205,16 @@ void updatePhiloxCounters() { }); return mps::random_mps_impl(self, from, to, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, __func__, nullptr); + MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& normal_mps_(Tensor& self, double mean, double std, c10::optional gen) { - return mps::normal_mps_impl(self, mean, std, c10::nullopt, c10::nullopt, __func__); + return mps::normal_mps_impl(self, mean, std, c10::nullopt, c10::nullopt, gen, __func__); } Tensor normal_mps(const Tensor& mean, double std, c10::optional gen) { Tensor self = empty_mps(mean.sizes(), mean.scalar_type(), c10::nullopt, kMPS); - return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, __func__); + return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, gen, __func__); } Tensor normal_mps(double mean, const Tensor& std, c10::optional gen) { @@ -232,48 +222,48 @@ Tensor normal_mps(double mean, const Tensor& std, c10::optional gen) // when there's no tensor-type mean, we cannot pass scalar mean value due to the order of // multiply/add ops in random computation. So we create a mean tensor instead. Tensor mean_t = at::full_like(self, Scalar(mean)); - return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, gen, __func__); } Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen) { auto shape = at::infer_size(mean.sizes(), std.sizes()); Tensor self = empty_mps(shape, mean.scalar_type(), c10::nullopt, kMPS); - return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, gen, __func__); } Tensor& normal_mps_out(const Tensor& mean, double std, c10::optional gen, Tensor& self) { - return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, __func__); + return mps::normal_mps_impl(self, 0.0, std, mean, c10::nullopt, gen, __func__); } Tensor& normal_mps_out(double mean, const Tensor& std, c10::optional gen, Tensor& self) { // when there's no tensor-type mean, we cannot pass scalar mean value due to the order of // multiply/add ops in random computation. So we create a mean tensor instead. Tensor mean_t = at::full_like(self, Scalar(mean)); - return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean_t, std, gen, __func__); } Tensor& normal_mps_out(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& self) { TORCH_CHECK(mean.numel() == std.numel(), "normal_mps_out: mean and std must have same number of elements") - return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, __func__); + return mps::normal_mps_impl(self, 0.0, 1.0, mean, std, gen, __func__); } Tensor& bernoulli_out_mps(const Tensor& p_, c10::optional gen, Tensor& result) { result.resize_(p_.sizes()); - return mps::bernoulli_mps_impl(result, p_, __func__); + return mps::bernoulli_mps_impl(result, p_, gen, __func__); } Tensor& bernoulli_mps_(Tensor& self, double p, c10::optional gen) { TORCH_CHECK(0.0 <= p && p <= 1.0, "bernoulli_mps_ expects p to be in [0, 1], but got p=", p); Tensor prob_t = at::full_like(self, Scalar(p)); - return mps::bernoulli_mps_impl(self, prob_t, __func__); + return mps::bernoulli_mps_impl(self, prob_t, gen, __func__); } Tensor& bernoulli_mps_(Tensor& self, const Tensor& p_, c10::optional gen) { - return mps::bernoulli_mps_impl(self, p_, __func__); + return mps::bernoulli_mps_impl(self, p_, gen, __func__); } // random_.from -Tensor& random_mps_(Tensor& self, int64_t from, optional to_opt, c10::optional gen) { +Tensor& random_mps_(Tensor& self, int64_t from, c10::optional to_opt, c10::optional gen) { auto input_dtype = self.scalar_type(); int64_t to = 0; @@ -321,7 +311,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(self, from, to - 1, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, __func__, nullptr); + MPSGraphRandomDistributionUniform, gen, __func__, nullptr); } Tensor& random_mps_(Tensor& self, int64_t to, c10::optional gen) { @@ -348,10 +338,51 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(self, 0.0, 1.0, c10::nullopt, c10::nullopt, - MPSGraphRandomDistributionUniform, + MPSGraphRandomDistributionUniform, gen, "exponential_mps_:" + std::to_string(lambda), random_op_block); } +Tensor& randperm_out_mps(int64_t n, c10::optional generator, Tensor& result) { + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("MPS: randperm op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performance implications."); + + auto result_cpu = result.to("cpu"); + at::randperm_out(result_cpu, n); + result.resize_as_(result_cpu); + result.copy_(result_cpu); + return result; + } + + TORCH_CHECK(n >= 0, "n must be non-negative, got", n); + TORCH_CHECK(!generator.has_value() || + (generator.has_value() && result.device() == generator->device()), + "Expected a '", result.device(), "' generator device but found '", generator->device(), "'"); + check_supported_max_int_with_precision(n, result); + + result.resize_({n}); + if (n == 0) { + return result; + } + + mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor + axis:0 + name:nil]; + if (result.scalar_type() != kInt) { + argsortTensor = [mpsGraph castTensor:argsortTensor + toType:mps::getMPSDataType(result.scalar_type()) + name:@"castOutput"]; + } + return argsortTensor; + }; + + return mps::random_mps_impl(result, 0.0, 1.0, c10::nullopt, c10::nullopt, + MPSGraphRandomDistributionUniform, generator, + "ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block); +} + Tensor& multinomial_with_replacement_mps_kernel( const Tensor& self, const int64_t n_sample, @@ -360,6 +391,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(generator, at::mps::detail::getDefaultMPSGenerator()); int inputSize = self.dim(); int numDist = inputSize == 1 ? 1 : self.size(0); @@ -405,9 +437,14 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(n_sample), numCategories}; MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count] shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]] dataType:MPSDataTypeUInt32]; @@ -473,11 +510,15 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optionalupdatePhiloxCounters(); - // feed the updated state values to the graph - MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@7]]; + MPSNDArrayDescriptor *stateDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(at::mps::detail::PHILOX_STATE_N)]]; MPSNDArray *stateNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: stateDesc] autorelease]; - [stateNDArray writeBytes: &cachedGraph->stateValues[0] strideBytes: nil]; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(mps_gen->mutex_); + // update the Philox state values on each run + mps_gen->update_philox_counters(); + [stateNDArray writeBytes: mps_gen->state_data() strideBytes: nil]; + } MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: stateNDArray] autorelease]; auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v); diff --git a/aten/src/ATen/native/mps/operations/Eye.mm b/aten/src/ATen/native/mps/operations/Eye.mm index 45b3fdf68b07f..6b72c0686caa4 100644 --- a/aten/src/ATen/native/mps/operations/Eye.mm +++ b/aten/src/ATen/native/mps/operations/Eye.mm @@ -70,9 +70,9 @@ @autoreleasepool { // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph string key = "eye_out_mps:" + getTensorsStringKey({result}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -94,7 +94,6 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm new file mode 100644 index 0000000000000..b37b956fc9020 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -0,0 +1,150 @@ +#include +#include +#include + +namespace at { +namespace native { + +void grid_sampler_2d_mps_impl(Tensor &output, const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { +// Grid Sampler support has been added in macOS 13.1 +#if !defined(__MAC_13_1) && !defined(MAC_OS_X_VERSION_13_1) + using namespace mps; + check_grid_sampler_common(input, grid); + check_grid_sampler_2d(input, grid); + + MPSGraphResizeMode samplingMode; + MPSGraphPaddingMode paddingMode; + + auto memory_format = input.suggest_memory_format(); + MPSGraphTensorNamedDataLayout inputTensorLayout = + (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC; + + switch (static_cast(padding_mode)) { + case GridSamplerPadding::Zeros: + paddingMode = MPSGraphPaddingModeZero; break; + case GridSamplerPadding::Border: + TORCH_CHECK(false, "MPS: Unsupported Border padding mode"); break; + case GridSamplerPadding::Reflection: + paddingMode = align_corners == true ? MPSGraphPaddingModeReflect : MPSGraphPaddingModeSymmetric; break; + default: + TORCH_CHECK(false, "MPS: Unrecognised Padding Mode: ", padding_mode); + } + + switch (static_cast(interpolation_mode)) { + case GridSamplerInterpolation::Bilinear: + samplingMode = MPSGraphResizeBilinear; break; + case GridSamplerInterpolation::Nearest: + samplingMode = MPSGraphResizeNearest; break; + case GridSamplerInterpolation::Bicubic: + TORCH_CHECK(false, "MPS: Unsupported Bicubic interpolation"); break; + default: + TORCH_CHECK(false, "MPS: Unrecognised interpolation mode: ", interpolation_mode); break; + } + + MPSStream *stream = getCurrentMPSStream(); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* gridTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = "grid_sampler_2d_mps" + + getTensorsStringKey({input, grid}) + + ":" + std::to_string(interpolation_mode) + + ":" + std::to_string(padding_mode) + + ":" + std::to_string(align_corners); + + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + MPSGraphTensor* gridTensor = mpsGraphRankedPlaceHolder(mpsGraph, grid); + + MPSGraphTensor* outputTensor = nil; + if (static_cast(interpolation_mode) == GridSamplerInterpolation::Nearest) { + outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor + coordinateTensor: gridTensor + layout: inputTensorLayout + normalizeCoordinates: TRUE + relativeCoordinates: FALSE + alignCorners: align_corners + paddingMode: paddingMode + nearestRoundingMode: MPSGraphResizeNearestRoundingModeRoundToEven + constantValue: 0.0f + name: nil]; + } else { + outputTensor = [mpsGraph sampleGridWithSourceTensor: inputTensor + coordinateTensor: gridTensor + layout: inputTensorLayout + normalizeCoordinates: TRUE + relativeCoordinates: FALSE + alignCorners: align_corners + paddingMode: paddingMode + samplingMode: samplingMode + constantValue: 0.0f + name: nil]; + } + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->gridTensor_ = gridTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); + Placeholder gridPlaceholder = Placeholder(cachedGraph->gridTensor_, grid); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + gridPlaceholder.getMPSGraphTensor() : gridPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +#endif // !defined(__MAC_13_1) && !defined(MAC_OS_X_VERSION_13_1) +} + +Tensor grid_sampler_2d_mps(const Tensor& input, const Tensor& grid, + int64_t interpolation_mode, int64_t padding_mode, + bool align_corners) { + if (!is_macos_13_or_newer(/*subVersion=*/1)) { + TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.1. ", + "Falling back on CPU. This may have performance implications."); + + return at::grid_sampler_2d( + input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners).clone().to("mps"); + } + + auto in_size = input.sizes(); + auto grid_size = grid.sizes(); + auto output = at::empty( + {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options()); + + grid_sampler_2d_mps_impl( + output, input, grid, interpolation_mode, padding_mode, align_corners); + return output; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 65d27ba757935..28bb6e8c84f98 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -49,7 +50,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, dispatch_sync(mpsStream->queue(), ^(){ @autoreleasepool { - NSError* error = nil; + NSError* error = nil; constexpr uint32_t nOffsets = 3; const int64_t num_indices = index_size.size(); const uint32_t numThreads = iter.numel(); @@ -139,7 +140,7 @@ bool dispatchIndexKernel(TensorIteratorBase& iter, threadsPerThreadgroup: threadGroupSize]; [computeEncoder endEncoding]; - mpsStream->commit(true); + mpsStream->synchronize(SyncType::COMMIT); } }); @@ -211,6 +212,189 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray return result; } +static +Tensor nonzero_fallback(const Tensor& self) { + TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performance implications."); + + return at::nonzero(self.to("cpu")).clone().to("mps"); +} + +Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){ + if (!is_macos_13_or_newer()) { + Tensor out_fallback = nonzero_fallback(self); + at::native::resize_output(out_, out_fallback.sizes()); + out_.copy_(out_fallback.to("mps")); + return out_; + } + + int64_t nDim = self.dim(); + if (self.numel() == 0) { + at::native::resize_output(out_, {0, nDim}); + return out_; + } + + using namespace mps; + const uint32_t maxDimensions = 16; + + TORCH_CHECK(self.numel() < std::numeric_limits::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \ + file a support request"); + TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype()); + TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ", + out_.device(), " and self on ", self.device()); + TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", 16, " dimensions"); + TORCH_CHECK(out_.is_mps()); + + MPSStream *stream = getCurrentMPSStream(); + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* scatterDataTensor_ = nil; + MPSGraphTensor* countNonzeroTensor_ = nil; + }; + + stream->synchronize(SyncType::COMMIT_AND_WAIT); + Tensor count_nonzero = at::empty({1}, self.options().dtype(kInt)); + Tensor out = at::native::empty_mps( + {self.numel(), nDim == 0 ? 1 : nDim}, + out_.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + int64_t _apparentInputShape = 1; + for (auto dim : self.sizes()) { + _apparentInputShape *= dim; + } + MPSShape *apparentOutputShape = @[@(self.numel() * nDim)]; + MPSShape *apparentInputShape = @[@(_apparentInputShape)]; + + // Pseudocode: + // + // inputTensor = [1, 0, 0, 3] + // inputNonZero = [1, 0, 0, 1] + // indices = [1, 1, 1, 2] + // maskedIndices = [0, -1, -1, 1] + // coordinates = [0, 1, 2, 3] + // scatterResult = [0, 3] + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + @autoreleasepool { + string key = "nonzero_out_mps" + getTensorsStringKey(self); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSDataType inputDataType = getMPSDataType(self.scalar_type()); + MPSShape* inputShape = getMPSShape(self); + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape); + MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type())); + MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType]; + MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32]; + MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32]; + MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor + secondaryTensor:zeroTensor + name:nil]; + MPSGraphTensor *countNonzero = [mpsGraph reductionSumWithTensor:inputNotEqualToZeroTensor + axis:0 + name:nil]; + MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor + toType:MPSDataTypeInt32 + name:@"castToInt32"]; + MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor + axis:0 + name:nil]; + MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor + secondaryTensor:oneTensor + name:nil]; + MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor + truePredicateTensor:indicesMinusOneTensor + falsePredicateTensor:minusMaxDimTensor + name:nil]; + MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil] + withShape:@[@-1] + name:nil]; + if (nDim > 1) { + NSMutableArray *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; + NSMutableArray *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim]; + + MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim + dataType:MPSDataTypeInt32]; + maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor + secondaryTensor:constantRankTensor + name:nil]; + coordinatesTensorArray[0] = coordinatesTensor; + for (int i = 1; i < nDim; i++){ + maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1] + secondaryTensor:oneTensor + name:nil]; + coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil] + withShape:@[@-1] + name:nil]; + } + maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil]; + coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil]; + } + + MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor + updatesTensor:coordinatesTensor + indicesTensor:maskedIndicesTensor + axis:0 + mode:MPSGraphScatterModeSet + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->scatterDataTensor_ = scatterDataTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->countNonzeroTensor_ = countNonzero; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape); + Placeholder countNonzeroPlaceholder = Placeholder(cachedGraph->countNonzeroTensor_, count_nonzero); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape); + Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape); + + // Create dictionary of inputs and outputs + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + scatterPlaceholder.getMPSGraphTensor() : scatterPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), + countNonzeroPlaceholder.getMPSGraphTensor() : countNonzeroPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + + int32_t total_nonzero = count_nonzero.item(); + at::native::resize_output(out_, {total_nonzero, nDim}); + out_.copy_(out.resize_({total_nonzero, nDim})); + return out_; +} + +Tensor nonzero_mps(const Tensor& self){ + if (!is_macos_13_or_newer()) { + return nonzero_fallback(self); + } + + Tensor out = at::empty({0}, self.options().dtype(kLong)); + return nonzero_out_mps(self, out); +} + Tensor masked_select_mps(const Tensor & self, const Tensor & mask) { namedinference::compute_broadcast_outnames(self, mask); Tensor result = at::empty({0}, self.options()); @@ -252,22 +436,26 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSStream* stream = getCurrentMPSStream(); - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; + using CachedGraph = mps::MPSUnaryCachedGraph; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - + MPSDataType inputDataType = getMPSScalarType(self.scalar_type()); + MPSDataType outputDataType = getMPSScalarType(self.scalar_type()); + if (!is_macos_13_or_newer()) { + if (self.scalar_type() == kBool) { + inputDataType = MPSDataTypeInt8; + } + if (result.scalar_type() == kBool) { + outputDataType = MPSDataTypeInt8; + } + } @autoreleasepool { NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","]; // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -275,7 +463,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self)); MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor axes:ns_dims name:nil]; @@ -284,12 +472,13 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); + Placeholder inputPlaceholder = Placeholder( + cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType); + Placeholder outputPlaceholder = Placeholder( + cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType); NSDictionary* feeds = @{ @@ -320,12 +509,14 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSStream* stream = getCurrentMPSStream(); dim = maybe_wrap_dim(dim, self.dim()); auto numel = index.numel(); - auto alpha_f = alpha.to(); if (numel == 0) { return; } + TORCH_CHECK(self.scalar_type() != ScalarType::Long, + "MPS: does not support index_add op with int64 input"); + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -341,10 +532,10 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { @autoreleasepool { string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { @@ -354,16 +545,46 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); MPSGraphTensor* sourceTensor = mpsGraphRankedPlaceHolder(mpsGraph, source); - MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder(mpsGraph, alpha_f); - MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:sourceTensor - secondaryTensor:alphaTensor + MPSGraphTensor* alphaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type())); + + MPSGraphTensor* castInputTensor = inputTensor; + MPSGraphTensor* castSourceTensor = sourceTensor; + MPSGraphTensor* castAlphaTensor = alphaTensor; + + MPSDataType dataType = [inputTensor dataType]; + + // failure due to issue #104289647: Wrong results from scatterWithDataTensor + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + castInputTensor = [mpsGraph castTensor:inputTensor + toType:dataType + name:@"castInputTensor"]; + castSourceTensor = [mpsGraph castTensor:sourceTensor + toType:dataType + name:@"castSourceTensor"]; + castAlphaTensor = [mpsGraph castTensor:alphaTensor + toType:dataType + name:@"castAlphaTensor"]; + } + + MPSGraphTensor* alphaSourceSlice = [mpsGraph multiplicationWithPrimaryTensor:castSourceTensor + secondaryTensor:castAlphaTensor name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:inputTensor + MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:castInputTensor updatesTensor:alphaSourceSlice indicesTensor:indexTensor axis:dim mode:MPSGraphScatterModeAdd name:nil]; + dataType = [inputTensor dataType]; + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32) { + outputTensor = [mpsGraph castTensor:outputTensor + toType:[inputTensor dataType] + name:@"castOutputTensor"]; + } + newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->indexTensor_ = indexTensor; newCachedGraph->sourceTensor_ = sourceTensor; @@ -372,14 +593,13 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) { } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index); Placeholder sourcePlaceholder = Placeholder(cachedGraph->sourceTensor_, source); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - MPSScalar alpha_scalar = getMPSScalar(alpha_f, source.scalar_type()); + MPSScalar alpha_scalar = getMPSScalar(alpha, self.scalar_type()); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), @@ -456,21 +676,31 @@ Tensor index_select_mps(const Tensor & self, }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + auto inputType = getMPSDataType(self.scalar_type()); + auto outputType = getMPSDataType(output.scalar_type()); + if (inputType == MPSDataTypeUInt8 || + (!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) { + inputType = MPSDataTypeInt8; + } + if (outputType == MPSDataTypeUInt8 || + (!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) { + outputType = MPSDataTypeInt8; + } @autoreleasepool { string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self)); MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); MPSGraphTensor* outputTensor = [mpsGraph gatherWithUpdatesTensor:inputTensor @@ -485,12 +715,13 @@ Tensor index_select_mps(const Tensor & self, } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, + /*mpsShape=*/nullptr, /*gatherTensorData=*/true, /*dataType=*/inputType); Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, + /*mpsShape=*/nullptr, /*gatherTensorData=*/false, /*dataType=*/outputType); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), @@ -522,17 +753,32 @@ Tensor index_select_mps(const Tensor & self, CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor_ = nil; MPSGraphTensor *maskTensor_ = nil; + MPSGraphTensor *valueTensor_ = nil; MPSGraphTensor *outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + MPSDataType inputDataType = getMPSScalarType(self.scalar_type()); + MPSDataType maskDataType = getMPSScalarType(b_mask->scalar_type()); + // Workaround for `selectWithPredicateTensor` on macOS Monterey where bool data type may cause a hang + // The issue is fixed in macOS Ventura (13.0) + if (!is_macos_13_or_newer()) { + if (self.scalar_type() == kBool) { + inputDataType = MPSDataTypeInt8; + } + if (mask.scalar_type() == kBool) { + maskDataType = MPSDataTypeInt8; + } + } + MPSStream* stream = getCurrentMPSStream(); + MPSScalar valueScalar = getMPSScalar(value, value.type()); @autoreleasepool { - string key = "masked_fill" + getTensorsStringKey({self, mask}) + ":" + std::to_string(value.toDouble()); + string key = "masked_fill" + getTensorsStringKey({self, *b_mask}) + ":" + getMPSTypeString(value.type()); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -540,43 +786,44 @@ Tensor index_select_mps(const Tensor & self, MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, mask); - MPSDataType valueType = getMPSScalarType(value.type()); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self)); + MPSGraphTensor* maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, maskDataType, getMPSShape(*b_mask)); + MPSGraphTensor* valueTensor = mpsGraphScalarPlaceHolder(mpsGraph, value); - // constantWithScalar doesn't like Bool constants getting created so - // mapping them to int8 - if (valueType == MPSDataTypeBool) { - valueType = MPSDataTypeInt8; + MPSDataType valueType = getMPSScalarType(value.type()); + MPSGraphTensor* castValueTensor = valueTensor; + if (valueType != inputDataType) { + castValueTensor = [mpsGraph castTensor:valueTensor + toType:inputDataType + name:@"castValueTensor"]; } - MPSGraphTensor* valueTensor = [mpsGraph constantWithScalar:value.to() - dataType:valueType]; - valueTensor = [mpsGraph castTensor:valueTensor - toType:getMPSDataType(self.scalar_type()) - name : @"castTensorEq"]; MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor - truePredicateTensor:valueTensor + truePredicateTensor:castValueTensor falsePredicateTensor:inputTensor name:nil]; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->maskTensor_ = maskTensor; + newCachedGraph->valueTensor_ = valueTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder maskPlaceholder = Placeholder(cachedGraph->maskTensor_, mask); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, self); + Placeholder selfPlaceholder = Placeholder( + cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType); + Placeholder maskPlaceholder = Placeholder( + cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType); + Placeholder outputPlaceholder = Placeholder( + cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType); // Create dictionary of inputs and outputs NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - maskPlaceholder.getMPSGraphTensor() : maskPlaceholder.getMPSGraphTensorData() + maskPlaceholder.getMPSGraphTensor() : maskPlaceholder.getMPSGraphTensorData(), + cachedGraph->valueTensor_ : getMPSGraphTensorFromScalar(stream, valueScalar) }; NSDictionary* results = @{ @@ -584,7 +831,6 @@ Tensor index_select_mps(const Tensor & self, }; runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } namedinference::propagate_names_if_nonempty(self, maybe_outnames); return self; @@ -615,7 +861,7 @@ Tensor embedding_dense_backward_mps( int64_t D = incoming_gradient_shape[num_incoming_gradient_dims - 1]; c10::SmallVector outgoing_gradient_shape{num_weights, D}; Tensor outgoing_gradient = at::native::empty_mps( - IntArrayRef(outgoing_gradient_shape.data(), outgoing_gradient_shape.size()), + IntArrayRef(outgoing_gradient_shape), grad_.scalar_type(), c10::nullopt, kMPS, @@ -630,10 +876,10 @@ Tensor embedding_dense_backward_mps( @autoreleasepool { string key = "edb_mps:" + native_mps::getMPSTypeString(grad_.scalar_type()) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" + std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -645,17 +891,20 @@ Tensor embedding_dense_backward_mps( MPSGraphTensor* indicesTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(indices.scalar_type())); - MPSGraphTensor *reshapedIndicesTensor = [mpsGraph expandDimsOfTensor:indicesTensor - axes:@[@-1] - name:nil]; + MPSGraphTensor* reshapedIndicesTensor = indicesTensor; + + if (num_indices_dims != 0) { + reshapedIndicesTensor = [mpsGraph expandDimsOfTensor: indicesTensor + axes: @[@-1] + name: nil]; + } - MPSGraphTensor *outgoingGradTensor; - outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor:incomingGradTensor - indicesTensor:reshapedIndicesTensor - shape:native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape.data(), outgoing_gradient_shape.size())) - batchDimensions:0 - mode:MPSGraphScatterModeAdd - name:@"edb"]; + auto outgoingGradTensor = [mpsGraph scatterNDWithUpdatesTensor: incomingGradTensor + indicesTensor: reshapedIndicesTensor + shape: native_mps::getMPSShape(IntArrayRef(outgoing_gradient_shape)) + batchDimensions: 0 + mode: MPSGraphScatterModeAdd + name: @"edb"]; newCachedGraph->incomingGradTensor_ = incomingGradTensor; newCachedGraph->indicesTensor_ = indicesTensor; @@ -664,7 +913,6 @@ Tensor embedding_dense_backward_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } auto incomingGradPlaceholder = native_mps::Placeholder(cachedGraph->incomingGradTensor_, grad_); auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices); diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm new file mode 100644 index 0000000000000..2975fd9875949 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include + + +namespace at { +namespace native { + +TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU."); + auto cpu_info = at::empty({0}, kInt, c10::nullopt, kCPU, c10::nullopt, c10::nullopt); + auto cpu_result = result.clone().to("cpu"); + at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu")); + info.copy_(cpu_info); + result.copy_(cpu_result); + return; + } + + using namespace mps; + MPSStream* stream = getCurrentMPSStream(); + info.zero_(); + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + Tensor output = result; + bool isContiguous = true; + if (!result.is_contiguous()) { + output = result.contiguous(); + isContiguous = false; + } + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = "inv_out_mps" + getTensorsStringKey({A}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) + { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor= mpsGraphRankedPlaceHolder(mpsGraph, A); + MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor: inputTensor + name: nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + + return newCachedGraph; + + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, isContiguous ? result : output); + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() + }; + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + if (!isContiguous) { + result.copy_(output); + } + } +} +} +} \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index ddaa6ce979638..b524a04aad58d 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -1,21 +1,12 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include #include -#include - -#ifdef __OBJC__ -#include -#endif - -using namespace at::mps; namespace at { namespace native { +using namespace mps; + Tensor _mps_linear( const Tensor& input, const Tensor& weight_arg, @@ -23,18 +14,13 @@ Tensor _mps_linear( // wT = transpose(weight); // y=x*wT+b - using namespace mps; - auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg; - TORCH_CHECK(input.scalar_type() == ScalarType::Double - || input.scalar_type() == ScalarType::Float - || input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); + TORCH_CHECK(input.scalar_type() == ScalarType::Float || + input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); - // See [Note: hacky wrapper removal for optional tensor] - auto bias = bias_opt.has_value() - ? c10::MaybeOwned::borrowed(*bias_opt) - : c10::MaybeOwned::owned(c10::in_place); + const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt)); + bool is_bias_defined = bias.defined(); auto input_size = input.sizes(); std::vector output_size(input_size.begin(), input_size.end() - 1); @@ -65,24 +51,11 @@ Tensor _mps_linear( MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - bool is_bias_defined = bias->defined(); - @autoreleasepool { - - MPSShape* wt_shape = getMPSShape(weight); - string wt_key = string([[[wt_shape valueForKey:@"description"] componentsJoinedByString:@","] UTF8String]); - string bias_key = "nobias"; - if(is_bias_defined) { - bias_key = "bias"; - } - - string key = "mps_linear" + getTensorsStringKey({input, weight}) + ":" + bias_key; - - + string key = "mps_linear" + getTensorsStringKey({input, weight, bias}) ; CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -93,17 +66,11 @@ Tensor _mps_linear( MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); - MPSGraphTensor* biasTensor = nil; - - if(is_bias_defined) { - biasTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType((*bias).scalar_type())); - } MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor dimension:-1 withDimension:-2 name:nil]; - MPSGraphTensor* outputTensor = nil; if (!is_bias_defined) @@ -114,17 +81,26 @@ Tensor _mps_linear( } else { - MPSGraphTensor* xMulWTTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor + MPSGraphTensor* inputFlattened = inputTensor; + bool doReshape = false; + // workaround to improve the performance with 3D+ inputs + if (input_size.size() > 2 && input_size[0] > 1 && input_size[1] >= 1 && input_size[1] <= 32) { + doReshape = true; + inputFlattened = [mpsGraph flatten2DTensor:inputTensor axis:-1 name:nil]; + } + + newCachedGraph->biasTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, bias); + MPSGraphTensor* xMulWTTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputFlattened secondaryTensor:weightTransposeTensor name:nil]; - outputTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor - secondaryTensor:biasTensor + MPSGraphTensor* biasedTensor = [mpsGraph additionWithPrimaryTensor:xMulWTTensor + secondaryTensor:newCachedGraph->biasTensor_ name:nil]; + outputTensor = doReshape ? [mpsGraph reshapeTensor:biasedTensor withShape:getMPSShape(output_size) name:nil] : biasedTensor; } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->weightTensor_ = weightTensor; - newCachedGraph->biasTensor_ = biasTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -135,21 +111,20 @@ Tensor _mps_linear( Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); Placeholder biasPlaceholder = Placeholder(); - if(is_bias_defined) - biasPlaceholder = Placeholder(cachedGraph->biasTensor_, *bias); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); feeds[weightPlaceholder.getMPSGraphTensor()] = weightPlaceholder.getMPSGraphTensorData(); - if (is_bias_defined) - feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); - + if (is_bias_defined) { + biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias); + feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); + } NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(stream, cachedGraph->graph(), feeds, results); } // Shave off '1' present at the end of the shape @@ -159,8 +134,7 @@ Tensor _mps_linear( std::vector out_shape(output_sizes.begin(), output_sizes.end()-1); return output.view(IntArrayRef(out_shape)); } - else - return output; + return output; } Tensor _mps_linear_backward_input( @@ -170,8 +144,9 @@ Tensor _mps_linear_backward_input( { TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout"); - TORCH_CHECK(weight.device().is_mps() && weight.scalar_type() == kFloat, - "mps_linear_backward: weight needs to be a dense tensor"); + TORCH_CHECK(weight.device().is_mps() && + (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)), + "mps_linear_backward: unsupported weights data type: ", weight.scalar_type()); TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float @@ -179,7 +154,7 @@ Tensor _mps_linear_backward_input( const Tensor weight_reshaped = weight.is_contiguous() ? weight : weight.contiguous(); - struct CachedGraph : public mps::MPSCachedGraph + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor *weightTensor_ = nil; @@ -195,24 +170,24 @@ Tensor _mps_linear_backward_input( grad_output.suggest_memory_format()); TORCH_CHECK(output.is_mps()); - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + MPSGraphCache *cache_ = MPSGraphCache::getInstance(); MPSStream *stream= getCurrentMPSStream(); @autoreleasepool { - string key = "mps_linear_backward_input" + mps::getTensorsStringKey({grad_output, weight_reshaped}); + string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = mps::make_mps_graph(); + MPSGraph *mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *weightTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped); - MPSGraphTensor *gradOutputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output); + MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped); + MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); MPSGraphTensor *outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor: gradOutputTensor @@ -228,9 +203,9 @@ Tensor _mps_linear_backward_input( cachedGraph = static_cast(tmpCachedGraph); } - mps::Placeholder weightPlaceholder = mps::Placeholder(cachedGraph->weightTensor_, weight_reshaped); - mps::Placeholder gradOutputPlaceholder = mps::Placeholder(cachedGraph->gradOutputTensor_, grad_output); - mps::Placeholder outputPlaceholder = mps::Placeholder(cachedGraph->outputTensor_, output); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_reshaped); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); NSDictionary* feeds = @{ weightPlaceholder.getMPSGraphTensor() : weightPlaceholder.getMPSGraphTensorData(), @@ -241,7 +216,7 @@ Tensor _mps_linear_backward_input( outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(stream, cachedGraph->graph(), feeds, results); return output; } @@ -253,11 +228,10 @@ Tensor _mps_linear_backward_input( TORCH_CHECK(grad_output.is_mps() && input.is_mps(), "_mps_linear_backward: grad_output and input needs to be mps layout"); - TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double - || grad_output.scalar_type() == ScalarType::Float - || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || + grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); - struct CachedGraph : public mps::MPSCachedGraph + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor_ = nil; @@ -289,26 +263,26 @@ Tensor _mps_linear_backward_input( TORCH_CHECK(output.is_mps()); TORCH_CHECK(bias.is_mps()); - mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance(); + MPSGraphCache *cache_ = MPSGraphCache::getInstance(); MPSStream *stream= getCurrentMPSStream(); @autoreleasepool { string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + - mps::getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); + getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { - mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { - MPSGraph *mpsGraph = mps::make_mps_graph(); + MPSGraph *mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); - MPSGraphTensor *weightTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, weight); - MPSGraphTensor *gradOutputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped); + MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); + MPSGraphTensor *weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); + MPSGraphTensor *gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_reshaped); MPSGraphTensor *gradOutputTransposeTensor = [mpsGraph transposeTensor: gradOutputTensor @@ -342,11 +316,11 @@ Tensor _mps_linear_backward_input( cachedGraph = static_cast(tmpCachedGraph); } - mps::Placeholder inputPlaceholder = mps::Placeholder(cachedGraph->inputTensor_, input_reshaped); - mps::Placeholder weightPlaceholder = mps::Placeholder(cachedGraph->weightTensor_, weight); - mps::Placeholder gradOutputPlaceholder = mps::Placeholder(cachedGraph->gradOutputTensor_, grad_output_reshaped); - mps::Placeholder outputPlaceholder = mps::Placeholder(cachedGraph->outputTensor_, output); - mps::Placeholder biasPlaceholder = mps::Placeholder(cachedGraph->biasTensor_, bias); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_reshaped); + Placeholder weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); + Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_reshaped); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias); NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), @@ -359,7 +333,7 @@ Tensor _mps_linear_backward_input( if (bias_defined) results[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData(); - mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(stream, cachedGraph->graph(), feeds, results); return std::tuple{ output, bias }; } diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 31c8c88248d6a..c5f79a07d1982 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef __OBJC__ #include @@ -597,5 +598,109 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2 return addbmm_out_mps(self, batch1, batch2, beta, alpha, self); } +Tensor& linalg_solve_triangular_mps_impl( const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) { + using namespace mps; + + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("MPS: linalg_solve_triangular_out op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performance implications."); + + Tensor cpu_out = out.cpu(); + Tensor A_cpu = A.cpu(); + Tensor B_cpu = B.cpu(); + at::linalg_solve_triangular_out( + cpu_out, A_cpu, B_cpu, upper, left, unitriangular); + out.resize_(cpu_out.sizes(), cpu_out.suggest_memory_format()); + out.copy_(cpu_out); + + return out; + } + + checkInputsSolver(A, B, left, "linalg.solve_triangular"); + Tensor A_, B_; + std::tie(B_, A_) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr); + at::native::resize_output(out, B_.sizes()); + + if (A.numel() == 0 || B.numel() == 0 || out.numel() == 0) { + return out; + } + + id aBuffer = getMTLBufferStorage(A_); + id bBuffer = getMTLBufferStorage(B_); + id outBuffer = getMTLBufferStorage(out); + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + + dispatch_sync(mpsStream->queue(), ^(){ + @autoreleasepool { + id commandBuffer = mpsStream->commandBuffer(); + MPSMatrixSolveTriangular *filter = [[[MPSMatrixSolveTriangular alloc] initWithDevice:device + right:!left + upper:upper + transpose:transpose + unit:unitriangular + order:left ? B_.size(-2) : B_.size(-1) + numberOfRightHandSides:left ? B_.size(-1) : B_.size(-2) + alpha:1.0f] autorelease]; + uint64_t batchSize = A_.sizes().size() > 2 ? A_.size(0) : 1; + uint64_t aRows = A_.size(-2); + uint64_t bRows = B_.size(-2); + uint64_t aCols = A_.size(-1); + uint64_t bCols = B_.size(-1); + uint64_t aElemSize = A_.element_size(); + uint64_t bElemSize = B_.element_size(); + + MPSMatrixDescriptor* sourceMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:aRows + columns:aCols + matrices:batchSize + rowBytes:aCols * aElemSize + matrixBytes:aRows * aCols * aElemSize + dataType:getMPSDataType(A_.scalar_type())]; + MPSMatrixDescriptor* rightHandSideMatrixDesc = [MPSMatrixDescriptor matrixDescriptorWithRows:bRows + columns:bCols + matrices:batchSize + rowBytes:bCols * bElemSize + matrixBytes:bRows * bCols * bElemSize + dataType:getMPSDataType(B_.scalar_type())]; + for (const auto i: c10::irange(batchSize)) { + MPSMatrix* sourceMatrix = [[[MPSMatrix alloc] initWithBuffer:aBuffer + offset:i * aRows * aCols * aElemSize + descriptor:sourceMatrixDesc] autorelease]; + MPSMatrix* rightHandSideMatrix = [[[MPSMatrix alloc] initWithBuffer:bBuffer + offset:i * bRows * bCols * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + MPSMatrix *solutionMatrix = [[[MPSMatrix alloc] initWithBuffer:outBuffer + offset:i * bRows * bCols * bElemSize + descriptor:rightHandSideMatrixDesc] autorelease]; + + [filter encodeToCommandBuffer:commandBuffer + sourceMatrix:sourceMatrix + rightHandSideMatrix:rightHandSideMatrix + solutionMatrix:solutionMatrix]; + } + mpsStream->commit(true); + } + }); + return out; +} + +Tensor& linalg_solve_triangular_mps_out( const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular, Tensor& out) { + return linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out); +} + +Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, bool left, bool unitriangular) { + Tensor out = at::empty({0}, A.options()); + linalg_solve_triangular_mps_impl(A, B, upper, /*transpose=*/false, left, unitriangular, out); + return out; +} + +TORCH_IMPL_FUNC(triangular_solve_mps_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) { + clone_A.copy_(A); + Tensor out = at::empty({0}, A.options()); + linalg_solve_triangular_mps_impl(A, self, upper, transpose, /*left=*/true, unitriangular, out); + result.resize_(out.sizes()); + result.copy_(out); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 3430af0434dec..8af47f86ef542 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -356,19 +356,12 @@ void nllnd_loss_backward_impl( MPSShape* weight_shape = getMPSShape(weight); MPSShape* total_weight_shape = getMPSShape(total_weight); - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "nllnd_loss_backward_impl:" + to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + reductionToString(reduction) + ":" + - [ns_shape_key UTF8String] + ":" + - getMPSTypeString(input.scalar_type()) + ":" + - getMPSTypeString(target.scalar_type()) + ":" + - getMPSTypeString(weight.scalar_type()) + ":" + - getMPSTypeString(total_weight.scalar_type()); + getTensorsStringKey({input, target, weight, total_weight}); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -409,9 +402,7 @@ void nllnd_loss_backward_impl( float onValue = -1.0f; - MPSGraphTensor *oneHotTensor; - - oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor + MPSGraphTensor *oneHotTensor = [mpsGraph oneHotWithIndicesTensor:udpatedTargetTensor depth:numClasses axis:1 dataType:inputTensor.dataType @@ -419,10 +410,22 @@ void nllnd_loss_backward_impl( offValue:0.0f name:nil]; - if(isWeightsArrayValid) - { + if(isWeightsArrayValid) { + int64_t nDim = input.sizes().size(); + IntArrayRef sizes = input.sizes(); + std::vector numbers(nDim); + for (const auto i: c10::irange(nDim)) { + NSInteger sz_i = (i == 1) ? sizes[i] : 1; + NSNumber* number = [NSNumber numberWithInteger:sz_i]; + numbers[i] = number; + } + + MPSGraphTensor *weightTensorReshaped = [mpsGraph reshapeTensor:weightTensor + withShape:[NSArray arrayWithObjects:numbers.data() count:numbers.size()] + name:nil]; + oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor - secondaryTensor:weightTensor + secondaryTensor:weightTensorReshaped name:@"scaleByWeightTensor"]; } @@ -1077,12 +1080,14 @@ void smooth_l1_loss_backward_template( MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); + + MPSDataType input_type = getMPSScalarType(input.scalar_type()); MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta shape:@[@1] - dataType:MPSDataTypeFloat32]; + dataType:input_type]; MPSGraphTensor* halfTensor = [mpsGraph constantWithScalar:.5f shape:@[@1] - dataType:MPSDataTypeFloat32]; + dataType:input_type]; MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor: inputTensor secondaryTensor: targetTensor @@ -1211,7 +1216,7 @@ Tensor huber_loss_mps(const Tensor& input, const Tensor& target, int64_t reducti name:nil]; MPSGraphTensor* deltaTensor = [mpsGraph constantWithScalar:delta shape:getMPSShape(target) - dataType:MPSDataTypeFloat32]; + dataType:getMPSDataType(target.scalar_type())]; MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:targetTensor name:nil]; diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 5384ee666fead..fd0a8471c7545 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -135,7 +135,9 @@ void get_shapes(MPSShape* input_shape_readonly, + std::to_string(momentum) + ":" + std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" + std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":" - + [ns_shape_key UTF8String] + ":" + native_mps::getMPSTypeString(self.scalar_type()); + + [ns_shape_key UTF8String] + ":" + + native_mps::getTensorsStringKey({ + self, weight_opt.value_or(Tensor()), bias_opt.value_or(Tensor()), running_mean_opt.value_or(Tensor()), running_var_opt.value_or(Tensor())}); auto input_mps_dtype = native_mps::getMPSDataType(self.scalar_type()); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); @@ -180,6 +182,7 @@ void get_shapes(MPSShape* input_shape_readonly, MPSGraphTensor* updatedRunningMeanTensor = nil; MPSGraphTensor* updatedRunningVarTensor = nil; + MPSGraphTensor *scaledInverseSqrtVariance = nil; /* If train: @@ -195,6 +198,7 @@ Check if running mean exists (maybe do this check before making graph) Compute the batch norm output and stats to be saved */ + MPSGraphTensor *varTensor = nil; if(train) { // Compute mean and variance of the current batch @@ -204,6 +208,7 @@ Check if running mean exists (maybe do this check before making graph) MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor axes:axes name:nil]; + varTensor = batchVarianceTensor; if(has_running_mean) { // TODO: This is not the formula used in PyTorch, is this OK? Seems more robust // float besselCorrectionTerm = float(N) / std::max(N - 1.0f, 1.0f); @@ -240,14 +245,27 @@ Check if running mean exists (maybe do this check before making graph) updatedRunningVarTensor = [mpsGraph additionWithPrimaryTensor:scaledCorrectedBatchVar secondaryTensor:scaledRunningVar name:nil]; - // Update saved mean and inverse std tensor - saveMeanTensor = batchMeanTensor; - saveVarTensor = batchVarianceTensor; - } - else { - saveMeanTensor = batchMeanTensor; - saveVarTensor = batchVarianceTensor; } + // Update saved mean and inverse std tensor + MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(double)epsilon + shape:@[@1] + dataType:MPSDataTypeFloat32]; + + MPSGraphTensor *varianceEps = [mpsGraph additionWithPrimaryTensor:batchVarianceTensor + secondaryTensor:epsilonTensor + name:@"varianceEps"]; + + MPSGraphTensor *sqrtVariance = [mpsGraph squareRootWithTensor:varianceEps + name:@"sqrtVariance"]; + float primary = 1.0f; + MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32]; + + scaledInverseSqrtVariance = [mpsGraph divisionWithPrimaryTensor:primaryTensor + secondaryTensor:sqrtVariance + name:nil]; + // Update saved mean and inverse std tensor + saveMeanTensor = batchMeanTensor; + saveVarTensor = scaledInverseSqrtVariance; } else { // Test TORCH_CHECK(has_running_mean); @@ -255,12 +273,13 @@ Check if running mean exists (maybe do this check before making graph) name:nil]; saveVarTensor = [mpsGraph identityWithTensor:runningVarTensor name:nil]; + varTensor = saveVarTensor; } // Compute output of batch norm MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor:inputTensor meanTensor:saveMeanTensor - varianceTensor:saveVarTensor + varianceTensor:varTensor gammaTensor:weightTensor betaTensor:biasTensor epsilon:(float)epsilon @@ -352,6 +371,10 @@ Check if running mean exists (maybe do this check before making graph) } + if(!train) { + save_mean.resize_({0}); + save_var.resize_({0}); + } return std::tuple(output, save_mean, save_var); } @@ -411,6 +434,54 @@ Check if running mean exists (maybe do this check before making graph) return std::make_tuple(output, save_mean, save_var); } +std::tuple _batch_norm_legit_mps + (const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + bool train, + double momentum, + double epsilon) { + + return batch_norm_mps(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon); +} + +std::tuple _batch_norm_legit_no_stats_mps + (const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + bool train, + double momentum, + double epsilon) { + + return batch_norm_mps(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon); +} + +std::tuple _batch_norm_legit_mps_out + (const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + bool train, double momentum, double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_var) { + return batch_norm_mps_out(self, weight_opt, bias_opt, running_mean, running_var, train, momentum, epsilon, output, save_mean, save_var); +} + +std::tuple _batch_norm_legit_no_stats_mps_out + (const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + bool train, double momentum, double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_var) { + return batch_norm_mps_out(self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var); +} + string get_mem_string(c10::MemoryFormat memory_format) { string mem_format_key; switch(memory_format) { @@ -602,11 +673,24 @@ string get_mem_string(c10::MemoryFormat memory_format) { if(train) { // Use save_mean and save_var + float primary = 1.0f; + MPSGraphTensor *primaryTensor = [mpsGraph constantWithScalar:primary dataType:MPSDataTypeFloat32]; + MPSGraphTensor *epsilonTensor = [mpsGraph constantWithScalar:(float)epsilon dataType:MPSDataTypeFloat32]; + MPSGraphTensor *revertSaveVarTensor = saveVarTensor; + revertSaveVarTensor = [mpsGraph divisionWithPrimaryTensor: primaryTensor + secondaryTensor: revertSaveVarTensor + name: nil]; + revertSaveVarTensor = [mpsGraph multiplicationWithPrimaryTensor: revertSaveVarTensor + secondaryTensor: revertSaveVarTensor + name: nil]; + revertSaveVarTensor = [mpsGraph subtractionWithPrimaryTensor: revertSaveVarTensor + secondaryTensor: epsilonTensor + name: nil]; if(grad_input_mask[1]) { gradWeightTensor = [mpsGraph normalizationGammaGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor meanTensor:saveMeanTensor - varianceTensor:saveVarTensor + varianceTensor:revertSaveVarTensor reductionAxes:axes epsilon:(float)epsilon name:nil]; @@ -621,7 +705,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { gradInputTensor = [mpsGraph normalizationGradientWithIncomingGradientTensor:gradOutputTensor sourceTensor:inputTensor meanTensor:saveMeanTensor - varianceTensor:saveVarTensor + varianceTensor:revertSaveVarTensor gammaTensor:weightTensor gammaGradientTensor:gradWeightTensor betaGradientTensor:gradBiasTensor @@ -823,7 +907,7 @@ string get_mem_string(c10::MemoryFormat memory_format) { const int normalized_ndim = normalized_shape.size(); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int axis = input_ndim - normalized_ndim; - at::Tensor input_reshaped = input.reshape({1, M, -1}); + at::Tensor input_reshaped = input.numel() == 0 ? input.reshape({1, M, 0}) : input.reshape({1, M, -1}); // Unlike Batch Normalization, which applies scalar scale and bias for each // entire channel/plane with the affine option, Layer Normalization applies // per-element scale and bias. E.g. For input {N, C, H, W}, weight for @@ -843,8 +927,6 @@ string get_mem_string(c10::MemoryFormat memory_format) { at::Tensor mean = std::get<1>(outputs); at::Tensor variance = std::get<2>(outputs); - at::Tensor rstd = at::rsqrt(at::add(variance, eps)); - std::vector stat_shape; for (const auto idx : c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); @@ -854,8 +936,8 @@ string get_mem_string(c10::MemoryFormat memory_format) { stat_shape.push_back(1); } mean = mean.view(stat_shape); - rstd = rstd.view(stat_shape); - return std::make_tuple(out, mean, rstd); + variance = variance.view(stat_shape); + return std::make_tuple(out, mean, variance); } std::tuple layer_norm_backward_mps( diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm index 63a26e66288be..c6029e3d7b2ae 100644 --- a/aten/src/ATen/native/mps/operations/Pad.mm +++ b/aten/src/ATen/native/mps/operations/Pad.mm @@ -13,7 +13,7 @@ MPSGraphPaddingMode mode, double constantValue, const string op_name) { const int padding_size = (int) padding.size(); - const int padding_dim = padding_size / 2; // either 1D, 2D, or 3D + int padding_dim = padding_size / 2; // either 1D, 2D, or 3D TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size); @@ -23,33 +23,44 @@ int64_t nbatch = 1; int64_t ndims = input_.ndimension(); + + TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of " + "dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions."); + // number of input dims with ConstantPad could be less than 2 - int dim_w = ndims > 1 ? padding_dim : 0; + int dim_w = padding_dim; int dim_h = padding_dim - 1; int dim_d = padding_dim - 2; int dim_slices = 0; - if (!is_backward_pass && ndims > 1) { + if (!is_backward_pass && mode != MPSGraphPaddingModeConstant && ndims > padding_dim) { bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0; TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) || (ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0), "3D or 4D (batch mode) tensor expected for input, but got: ", input_); } - if (ndims == 2 + padding_dim) { - nbatch = input_.size(0); - dim_w++; - dim_h++; - dim_d++; + if (ndims == padding_dim) { + dim_w--; + dim_h--; + dim_d--; + } else if (ndims > padding_dim + 1) { + const int dim_diff = (int)ndims - padding_dim - 1; + // this virtually inflates the padding with zeros if ndims > padding_dim + 2 + padding_dim += dim_diff - 1; + dim_w += dim_diff; + dim_h += dim_diff; + dim_d += dim_diff; dim_slices++; + nbatch = input_.size(0); } int64_t pad_l = padding[0]; int64_t pad_r = padding[1]; - int64_t pad_t = padding_dim > 1 ? padding[2] : 0; - int64_t pad_b = padding_dim > 1 ? padding[3] : 0; - int64_t pad_front = padding_dim > 2 ? padding[4] : 0; - int64_t pad_back = padding_dim > 2 ? padding[5] : 0; + int64_t pad_t = padding_size > 2 ? padding[2] : 0; + int64_t pad_b = padding_size > 2 ? padding[3] : 0; + int64_t pad_front = padding_size > 4 ? padding[4] : 0; + int64_t pad_back = padding_size > 4 ? padding[5] : 0; int64_t nplane = input_.size(dim_slices); int64_t input_w = input_.size(dim_w); @@ -62,40 +73,64 @@ Tensor grad_output, input = input_; if (!is_backward_pass) { - TORCH_CHECK(pad_l < input_w && pad_r < input_w, - "Argument #4: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_l, ", ", pad_r, - ") at dimension ", dim_w, " of input ", ndims); - - if (padding_dim > 1) { - TORCH_CHECK(pad_t < input_h && pad_b < input_h, - "Argument #6: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_t, ", ", pad_b, - ") at dimension ", dim_h, " of input ", ndims); - } TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1, "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated " "output H: ", output_h, " W: ", output_w); - if (ndims == 1 + padding_dim) { - if (padding_dim == 3) - output.resize_({nplane, output_d, output_h, output_w}); - else if (padding_dim == 2) - output.resize_({nplane, output_h, output_w}); - else - output.resize_({nplane, output_w}); + std::vector outputSizes; + if (mode == MPSGraphPaddingModeConstant) { + // support arbitrary input dimensions for constant pad. + auto input_sizes = input_.sizes(); + auto ori_padding_dim = padding_size / 2; + auto l_diff = ndims - ori_padding_dim; + + for (size_t i = 0; i < (size_t)l_diff; i ++) { + outputSizes.emplace_back(input_sizes[i]); + } + for (const auto i : c10::irange((size_t)ori_padding_dim)) { + auto pad_idx = padding.size() - ((i + 1) * 2); + auto new_dim = input_sizes[l_diff + i] + padding[pad_idx] + padding[pad_idx + 1]; + outputSizes.emplace_back(new_dim); + } } else { - if (padding_dim == 3) - output.resize_({nbatch, nplane, output_d, output_h, output_w}); - else if (padding_dim == 2) - output.resize_({nbatch, nplane, output_h, output_w}); - else if (ndims > 1) - output.resize_({nbatch, nplane, output_w}); - else - output.resize_({output_w}); + // these checks aren't relevant for constant pad + TORCH_CHECK(pad_l < input_w && pad_r < input_w, + "Argument #4: Padding size should be less than the corresponding " + "input dimension, but got: padding (", pad_l, ", ", pad_r, + ") at dimension ", dim_w, " of input ", ndims); + + if (padding_dim > 1) { + TORCH_CHECK(pad_t < input_h && pad_b < input_h, + "Argument #6: Padding size should be less than the corresponding " + "input dimension, but got: padding (", pad_t, ", ", pad_b, + ") at dimension ", dim_h, " of input ", ndims); + } + if (padding_dim > 2) { + TORCH_CHECK(pad_front < input_d && pad_back < input_d, + "Argument #8: Padding size should be less than the corresponding " + "input dimension, but got: padding (", pad_front, ", ", pad_back, + ") at dimension ", dim_d, " of input ", ndims); + } + outputSizes.insert(outputSizes.begin(), output_w); + if (padding_dim >= 2) + outputSizes.insert(outputSizes.begin(), output_h); + if (padding_dim >= 3) + outputSizes.insert(outputSizes.begin(), output_d); + if (ndims >= 1 + padding_dim) + outputSizes.insert(outputSizes.begin(), nplane); + if (ndims >= 2 + padding_dim) + outputSizes.insert(outputSizes.begin(), nbatch); + } + + output.resize_(outputSizes); + + if (output.numel() == 0) { + return output; } - if (output.numel() == 0 || input_.numel() == 0) + if (input_.numel() == 0) { + output.fill_(constantValue); return output; + } input = input_.contiguous(); } else { TORCH_CHECK(output_w == grad_output_.size(dim_w), @@ -104,24 +139,57 @@ TORCH_CHECK(output_h == grad_output_.size(dim_h), "gradOutput height unexpected. Expected: ", output_h, ", Got: ", grad_output_.size(dim_h)); } + output.resize_as_(input); + if (output.numel() == 0 || grad_output_.numel() == 0) + return output; grad_output = grad_output_.contiguous(); } + const uint32_t dims_mask = (1U << ndims) - 1; + uint32_t startMask = dims_mask, endMask = dims_mask; std::vector leftPadVec(ndims, @(0)); std::vector rightPadVec(ndims, @(0)); - leftPadVec [ndims - 1] = @(pad_l); - rightPadVec[ndims - 1] = @(pad_r); - if (padding_dim >= 2) { - leftPadVec [ndims - 2] = @(pad_t); - rightPadVec[ndims - 2] = @(pad_b); - } - if (padding_dim >= 3) { - leftPadVec [ndims - 3] = @(pad_front); - rightPadVec[ndims - 3] = @(pad_back); + std::vector startsVec(ndims, @(0)); + std::vector endsVec(ndims, @(0)); + std::vector stridesVec(ndims, @(1)); + + for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) { + const int64_t leftIdx = pdim * 2; + const int64_t rightIdx = pdim * 2 + 1; + const int64_t padIdx = ndims - pdim - 1; + + leftPadVec [padIdx] = @(padding[leftIdx]); + rightPadVec[padIdx] = @(padding[rightIdx]); + // workaround for negative padding issue in backward pass + if (is_backward_pass) { + if (padding[leftIdx] < 0) { + leftPadVec[padIdx] = @(0); + startsVec[padIdx] = @(-padding[leftIdx]); + startMask &= ~(1U << padIdx); + } + if (padding[rightIdx] < 0) { + rightPadVec[padIdx] = @(0); + endsVec[padIdx] = @(input.size(padIdx) + padding[rightIdx]); + endMask &= ~(1U << padIdx); + } + // workaround for the right padding bug in Monterey + } else if (!is_macos_13_or_newer()) { + if (padding[rightIdx] == 1 && padding[leftIdx] == 0) { + rightPadVec[padIdx] = @(2); + endsVec[padIdx] = @(input.size(padIdx) + 2); + endMask &= ~(1U << padIdx); + } + } } MPSShape *leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims]; MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims]; + MPSDataType dataType = getMPSScalarType(input.scalar_type()); + // workaround for Bool type assert with Constant padding + if (input.scalar_type() == kBool) { + dataType = MPSDataTypeInt8; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { } MPSGraphTensor *inputTensor = nil, *outputTensor = nil; @@ -130,50 +198,78 @@ MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({input, grad_output}) + - ":L" + to_string(pad_l) + ":R" + to_string(pad_r) + - ":T" + to_string(pad_t) + ":B" + to_string(pad_b) + - ":F" + to_string(pad_front) + ":K" + to_string(pad_back); + string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + + getArrayRefString(padding) + "]:" + std::to_string(constantValue); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); - if (!is_backward_pass) { - newCachedGraph->outputTensor = [mpsGraph padTensor:newCachedGraph->inputTensor - withPaddingMode:mode - leftPadding:leftPadding - rightPadding:rightPadding - constantValue:constantValue - name:nil]; + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input)); + const bool needsSlice = startMask != dims_mask || endMask != dims_mask; + + if (!is_backward_pass) { + MPSGraphTensor *padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor + withPaddingMode:mode + leftPadding:leftPadding + rightPadding:rightPadding + constantValue:constantValue + name:nil]; + // workaround for the right padding bug in Monterey + if (needsSlice) { + newCachedGraph->outputTensor = [mpsGraph sliceTensor:padTensor + starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] + ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] + strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] + startMask:startMask + endMask:endMask + squeezeMask:0 + name:nil]; + } else { + newCachedGraph->outputTensor = padTensor; + } + } else { + newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output)); + MPSGraphTensor *padGradTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor + sourceTensor:newCachedGraph->inputTensor + paddingMode:mode + leftPadding:leftPadding + rightPadding:rightPadding + name:nil]; + // workaround for negative padding issue with padGradientWithIncomingGradientTensor() + if (needsSlice) { + newCachedGraph->outputTensor = [mpsGraph sliceGradientTensor:padGradTensor + fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor name:nil] + starts:[NSArray arrayWithObjects:startsVec.data() count:ndims] + ends:[NSArray arrayWithObjects:endsVec.data() count:ndims] + strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims] + startMask:startMask + endMask:endMask + squeezeMask:0 + name:nil]; } else { - newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - newCachedGraph->outputTensor = [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor - sourceTensor:newCachedGraph->inputTensor - paddingMode:mode - leftPadding:leftPadding - rightPadding:rightPadding - name:nil]; + newCachedGraph->outputTensor = padGradTensor; } + } } return newCachedGraph; - })); + }); } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, nullptr, true, dataType); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr, true, dataType); + Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() : + Placeholder(cachedGraph->gradOutputTensor, grad_output, nullptr, true, dataType); NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); if (is_backward_pass) { - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); - feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); + feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); } NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm index eb68239ecedd6..9ed6298368716 100644 --- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm @@ -15,8 +15,9 @@ const bool is_div, const string op_name) { - if (&output != &self) { - output.resize_(output.sizes()); + if (value_opt.toDouble() == 0.0) { + output.copy_(self); + return output; } if(output.numel() == 0) { @@ -34,12 +35,12 @@ MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false); + string key = op_name + getTensorsStringKey({self, tensor1, tensor2}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph* newCachedGraph = nil; @autoreleasepool { @@ -49,7 +50,7 @@ newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1); newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2); - newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type())); + newCachedGraph->valueTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), @[@1]); // the tensor to be optionally multiplied by value_scalar MPSGraphTensor *multiplicandTensor = nil; @@ -72,7 +73,6 @@ } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } // Inputs as placeholders diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 1df24e073239e..89f8c22ee96a4 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -1,867 +1,349 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include -#include #include #include -#include namespace at { namespace native { +namespace mps { + +struct PoolingCachedGraph : public MPSCachedGraph +{ + PoolingCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor = nil; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* indicesTensor = nil; + MPSGraphTensor* gradOutputTensor = nil; + MPSGraphTensor* divisorTensor = nil; +}; + +typedef MPSGraphTensor* (^PoolingOpBlock)(PoolingCachedGraph&, MPSGraphPooling2DOpDescriptor*); +#define PoolingOpFn(graph, desc) MPSGraphTensor* (mps::PoolingCachedGraph& graph, MPSGraphPooling2DOpDescriptor* desc) + +// Pooling ops (1D/2D forward and backward Max and Average pooling) +static void pool2d_template(const Tensor& input, const Tensor& output, + const c10::optional& indices_opt, + const c10::optional& grad_output_opt, + IntArrayRef kernel_size, IntArrayRef stride, + IntArrayRef padding, IntArrayRef dilation, + bool ceil_mode, const c10::optional divisor, + PoolingOpBlock poolingBlock, const c10::string& op_name) +{ + if (input.numel() == 0) + return; -// Create pooling descriptor -void fill_pool_desc(MPSGraphPooling2DOpDescriptor* desc, - NSUInteger kW, NSUInteger kH, - NSUInteger dW, NSUInteger dH, - NSUInteger dilationW, NSUInteger dilationH, - NSUInteger padW, NSUInteger padH, - bool ceil_mode, c10::MemoryFormat memory_format) { - desc.kernelWidth = kW; - desc.kernelHeight = kH; - desc.strideInX = dW; - desc.strideInY = dH; - desc.dilationRateInX = dilationW; - desc.dilationRateInY = dilationH; - desc.paddingLeft = padW; - desc.paddingRight = padW; - desc.paddingTop = padH; - desc.paddingBottom = padH; - desc.ceilMode = ceil_mode; - desc.paddingStyle = MPSGraphPaddingStyleExplicit; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - desc.dataLayout = MPSGraphTensorNamedDataLayoutNCHW; - break; - case at::MemoryFormat::ChannelsLast: - desc.dataLayout = MPSGraphTensorNamedDataLayoutNHWC; - break; - default: - assert(0 && "Check should have been done earlier\n"); + if (!is_macos_13_or_newer()) { + TORCH_CHECK(input.scalar_type() != ScalarType::Long, + "MPS: ", op_name, " op with int64 input is supported natively starting from macOS 13.0."); + } + const int64_t ndims = input.ndimension(); + const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); + const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt)); + const bool is_backward_pass = grad_output.defined(); + const bool has_indices = indices.defined(); + const bool has_divisor = divisor.has_value(); + const auto suggested_memory_format = input.suggest_memory_format(); + // for max_pool2d_with_indices() we cannot pass ChannelsLast (i.e., NHWC) to 'desc.dataLayout' in MPSGraph. + // Because the returned indices will be selected based on NHWC memory layout which will + // be incompatible with the PyTorch's global NCHW layout. + const auto memory_format = has_indices ? MemoryFormat::Contiguous : suggested_memory_format; + + TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, op_name, + ": kernel_size must either be a single int, or a tuple of two ints") + TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, op_name, + ": stride must either be omitted, a single int, or a tuple of two ints") + TORCH_CHECK(padding.size() == 1 || padding.size() == 2, op_name, + ": padding must be either be a single int, or a tuple of two ints"); + TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, op_name, + ": dilation must be either a single int, or a tuple of two ints"); + + if (suggested_memory_format == at::MemoryFormat::ChannelsLast) { + TORCH_CHECK(ndims == 4, "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); + } else if (suggested_memory_format == at::MemoryFormat::Contiguous) { + TORCH_CHECK((ndims == 3 || ndims == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); + } else { + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } -} - -Tensor _mps_max_pool2d( - const Tensor& input_t, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode) { - // #20866, #22032: Guarantee this for the official C++ API? - TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, - "max_pool2d: kernel_size must either be a single int, or a tuple of two ints") const int kH = safe_downcast(kernel_size[0]); const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - // NB: stride default is not expressible as an integer constant, so we accept - // empty stride for this case - TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, - "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints") const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - TORCH_CHECK(padding.size() == 1 || padding.size() == 2, - "max_pool2d: padding must be either be a single int, or a tuple of two ints"); + const int dW = stride.empty() ? kW : stride.size() == 1 ? dH : safe_downcast(stride[1]); const int padH = safe_downcast(padding[0]); const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - - TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, - "max_pool2d: dilation must be either a single int, or a tuple of two ints"); const int dilationH = safe_downcast(dilation[0]); const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); - - const auto memory_format = input_t.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast) { - TORCH_CHECK(input_t.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else if (memory_format == at::MemoryFormat::Contiguous) { - TORCH_CHECK((input_t.ndimension() == 3 || input_t.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } else { - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); - } - - /* sizes */ - const int64_t nbatch = input_t.ndimension() == 4 ? input_t.size(-4) : 1; - const int64_t nInputPlane = input_t.size(-3); - const int64_t inputHeight = input_t.size(-2); - const int64_t inputWidth = input_t.size(-1); - + const int64_t nbatch = ndims == 4 ? input.size(-4) : 1; + const int64_t nInputPlane = input.size(-3); + const int64_t inputHeight = input.size(-2); + const int64_t inputWidth = input.size(-1); const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode); const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode); - pool2d_shape_check( - input_t, - kH, kW, dH, dW, padH, padW, dilationH, dilationW, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, memory_format); - - namespace native_mps = at::native::mps; - using CachedGraph = native_mps::MPSUnaryCachedGraph; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + pool2d_shape_check(input, kH, kW, dH, dW, padH, padW, dilationH, dilationW, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format); - Tensor output_t; - - if (input_t.ndimension() == 3) { - output_t = at::native::empty_mps( - {nInputPlane, outputHeight, outputWidth}, - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); - } else { - output_t = at::native::empty_mps( - {nbatch, nInputPlane, outputHeight, outputWidth}, - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); + // the output and indices are 'empty', so we could avoid unnecessary gatherView on empty tensors + // by simply restriding them (instead of calling the costly Contiguous()). + if (indices.suggest_memory_format() == MemoryFormat::ChannelsLast) { + indices.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); } - - if (output_t.numel() == 0) { - return output_t; + if (output.numel() == 0) { + std::vector outputSizes {nInputPlane, outputHeight, outputWidth}; + if (ndims == 4) { + outputSizes.insert(outputSizes.begin(), nbatch); + } + output.resize_(outputSizes); + } else if (output.suggest_memory_format() == MemoryFormat::ChannelsLast) { + output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); } - auto stream = at::mps::getCurrentMPSStream(); + if (output.numel() == 0 || (is_backward_pass && grad_output.numel() == 0)) { + return; + } + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { + string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + + getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" + + getArrayRefString(padding) + "]:D[" + getArrayRefString(dilation) + "]" + + (ceil_mode ? ":ceil" : "") + ":" + (suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "mps_max_pool2d:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(dilationW) + ":" + to_string(dilationH) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + - mps::getTensorsStringKey({input_t}); - CachedGraph* cachedGraph = cache_->LookUpAs(key); + MPSShape* inputShape = getMPSShape(input, memory_format); + MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr; + PoolingCachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { + PoolingCachedGraph *newCachedGraph = nil; @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, dilationW, dilationH, padW, padH, ceil_mode, memory_format); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* outputTensor = [mpsGraph maxPooling2DWithSourceTensor:inputTensor - descriptor:desc - name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new PoolingCachedGraph(mpsGraph); + + MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor + descriptorWithKernelWidth: kW + kernelHeight: kH + strideInX: dW + strideInY: dH + dilationRateInX: dilationW + dilationRateInY: dilationH + paddingLeft: padW + paddingRight: ceil_mode ? padW * dW : padW + paddingTop: padH + paddingBottom: ceil_mode ? padH * dH : padH + paddingStyle: MPSGraphPaddingStyleExplicit + dataLayout: memory_format == MemoryFormat::ChannelsLast ? + MPSGraphTensorNamedDataLayoutNHWC : + MPSGraphTensorNamedDataLayoutNCHW]; + desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false; + if (has_indices) { + desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D; + desc.returnIndicesDataType = MPSDataTypeInt32; + } + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input.scalar_type()), inputShape); + if (is_backward_pass) { + newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output.scalar_type()), gradOutputShape); + } + if (has_divisor) { + newCachedGraph->divisorTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()), @[@1]); + } + MPSGraphTensor* outputTensor = poolingBlock(*newCachedGraph, desc); + // with desc.dataLayout = NHWC (i.e., ChannelsLast), the results need to be converted back to NCHW + newCachedGraph->outputTensor = memory_format == MemoryFormat::ChannelsLast ? + convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t); + MPSStream* mpsStream = getCurrentMPSStream(); + // in case of ChannelsLast we don't perform gather() in placeholder to avoid implicit conversion to NCHW + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input, inputShape, memory_format != MemoryFormat::ChannelsLast); + Placeholder gradOutputPlaceholder = !is_backward_pass ? Placeholder() : + Placeholder(cachedGraph->gradOutputTensor, grad_output, + gradOutputShape, memory_format != MemoryFormat::ChannelsLast); + Placeholder indicesPlaceholder = has_indices ? Placeholder(cachedGraph->indicesTensor, indices) : Placeholder(); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; + NSMutableDictionary *results = [[NSMutableDictionary new] autorelease]; + + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + results[outputPlaceholder.getMPSGraphTensor()] = outputPlaceholder.getMPSGraphTensorData(); + + if (cachedGraph->gradOutputTensor) { + feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData(); + } + if (cachedGraph->indicesTensor) { + if (is_backward_pass) { + feeds[indicesPlaceholder.getMPSGraphTensor()] = indicesPlaceholder.getMPSGraphTensorData(); + } else { + results[indicesPlaceholder.getMPSGraphTensor()] = indicesPlaceholder.getMPSGraphTensorData(); + } + } + MPSScalar divisor_scalar; + if (cachedGraph->divisorTensor) { + divisor_scalar = getMPSScalar(divisor.value(), output.scalar_type()); + feeds[cachedGraph->divisorTensor] = getMPSGraphTensorFromScalar(mpsStream, divisor_scalar); + } - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; + runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results); + } +} - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; +} // namespace mps - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } +Tensor _mps_max_pool2d( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) +{ + Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous); + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DWithSourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; + }; + mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, kernel_size, stride, + padding, dilation, ceil_mode, c10::nullopt, pooling_op_block, "max_pool2d"); - return output_t; + return output; } Tensor mps_max_pool2d_backward( const Tensor& grad_output, - const Tensor& input_t, + const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool ceil_mode) { - - // #20866, #22032: Guarantee this for the official C++ API? - TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, - "max_pool2d: kernel_size must either be a single int, or a tuple of two ints") - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - // NB: stride default is not expressible as an integer constant, so we accept - // empty stride for this case - TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, - "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints") - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - TORCH_CHECK(padding.size() == 1 || padding.size() == 2, - "max_pool2d: padding must be either be a single int, or a tuple of two ints"); - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - - TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, - "max_pool2d: dilation must be either a single int, or a tuple of two ints"); - const int dilationH = safe_downcast(dilation[0]); - const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); - - const auto memory_format = input_t.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast) { - TORCH_CHECK(input_t.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else if (memory_format == at::MemoryFormat::Contiguous) { - TORCH_CHECK((input_t.ndimension() == 3 || input_t.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } else { - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); - } - - namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + bool ceil_mode) +{ + Tensor grad_input = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous); + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor + sourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - Tensor grad_input; - grad_input = at::native::empty_mps( - input_t.sizes(), - input_t.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - memory_format); - - if (grad_input.numel() == 0) { - return grad_input; - } - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "mps_max_pool2d_backward:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(dilationW) + ":" + to_string(dilationH) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + - mps::getTensorsStringKey({input_t, grad_output}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, dilationW, dilationH, padW, padH, ceil_mode, memory_format); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* gradInputTensor = [mpsGraph maxPooling2DGradientWithGradientTensor:gradOutputTensor - sourceTensor:inputTensor - descriptor:desc - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_output); - auto gradInputPlaceholder = native_mps::Placeholder(cachedGraph->gradInputTensor_, grad_input); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary *results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + mps::pool2d_template(input, grad_input, c10::nullopt, grad_output, kernel_size, stride, + padding, dilation, ceil_mode, c10::nullopt, pooling_op_block, "max_pool2d_backward"); return grad_input; } TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)( - const Tensor& input_t, + const Tensor& input, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, - const Tensor& output_t, - const Tensor& indices) { - - // #20866, #22032: Guarantee this for the official C++ API? - TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, - "max_pool2d: kernel_size must either be a single int, or a tuple of two ints") - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - // NB: stride default is not expressible as an integer constant, so we accept - // empty stride for this case - TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, - "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints") - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - TORCH_CHECK(padding.size() == 1 || padding.size() == 2, - "max_pool2d: padding must be either be a single int, or a tuple of two ints"); - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - - TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, - "max_pool2d: dilation must be either a single int, or a tuple of two ints"); - const int dilationH = safe_downcast(dilation[0]); - const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); - - const auto memory_format = input_t.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast) { - TORCH_CHECK(input_t.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else if (memory_format == at::MemoryFormat::Contiguous) { - TORCH_CHECK((input_t.ndimension() == 3 || input_t.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } else { - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); - } - - /* sizes */ - const int64_t nInputPlane = input_t.size(-3); - const int64_t inputHeight = input_t.size(-2); - const int64_t inputWidth = input_t.size(-1); - - const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode); - const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode); - - pool2d_shape_check( - input_t, - kH, kW, dH, dW, padH, padW, dilationH, dilationW, - nInputPlane, - inputHeight, inputWidth, - outputHeight, outputWidth, memory_format); - - namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - MPSGraphTensor* indicesTensor_ = nil; + const Tensor& output, + const Tensor& indices) +{ + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; + cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long); + return poolOutputs[0]; }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - if (output_t.numel() == 0) { - return; - } - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "max_pool2d_with_indices_out_mps:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(dilationW) + ":" + to_string(dilationH) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + - mps::getTensorsStringKey({input_t}) + ":" + - native_mps::getMPSTypeString(indices.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, dilationW, dilationH, padW, padH, ceil_mode, memory_format); - desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D; - desc.returnIndicesDataType = MPSDataTypeInt32; - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - NSArray* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:inputTensor - descriptor:desc - name:nil]; - - MPSGraphTensor* indicesTensor = poolOutputs[1]; - if(mps::getMPSDataType(indices.scalar_type()) == MPSDataTypeInt64) { - indicesTensor = [mpsGraph castTensor:indicesTensor - toType:MPSDataTypeInt64 - name:@"castToI64"]; - } - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = poolOutputs[0]; - newCachedGraph->indicesTensor_ = indicesTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t); - auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), - indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - + mps::pool2d_template(input, output, indices, c10::nullopt, kernel_size, stride, + padding, dilation, ceil_mode, c10::nullopt, pooling_op_block, "max_pool2d_indices"); } -TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps) -(const Tensor& grad_output, -const Tensor& input_t, -IntArrayRef kernel_size, -IntArrayRef stride, -IntArrayRef padding, -IntArrayRef dilation, -bool ceil_mode, -const Tensor& indices, -const Tensor& grad_input) { - - // #20866, #22032: Guarantee this for the official C++ API? - TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, - "max_pool2d: kernel_size must either be a single int, or a tuple of two ints") - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - // NB: stride default is not expressible as an integer constant, so we accept - // empty stride for this case - TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2, - "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints") - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - TORCH_CHECK(padding.size() == 1 || padding.size() == 2, - "max_pool2d: padding must be either be a single int, or a tuple of two ints"); - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - - TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, - "max_pool2d: dilation must be either a single int, or a tuple of two ints"); - const int dilationH = safe_downcast(dilation[0]); - const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); - - const auto memory_format = input_t.suggest_memory_format(); - if (memory_format == at::MemoryFormat::ChannelsLast) { - TORCH_CHECK(input_t.ndimension() == 4, - "non-empty 4D (batch mode) tensor expected for input with channels_last layout"); - } else if (memory_format == at::MemoryFormat::Contiguous) { - TORCH_CHECK((input_t.ndimension() == 3 || input_t.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - } else { - TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); - } - - namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; +TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_mps)( + const Tensor& grad_output, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + const Tensor& indices, + const Tensor& grad_input) +{ + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + return [mpsGraph maxPooling2DGradientWithGradientTensor: cachedGraph.gradOutputTensor + sourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - if (grad_input.numel() == 0) { - return; - } - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "max_pool2d_with_indices_backward_out_mps:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(dilationW) + ":" + to_string(dilationH) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + - mps::getTensorsStringKey({input_t, grad_output}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, dilationW, dilationH, padW, padH, ceil_mode, memory_format); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* gradInputTensor = [mpsGraph maxPooling2DGradientWithGradientTensor:gradOutputTensor - sourceTensor:inputTensor - descriptor:desc - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, grad_output); - auto gradInputPlaceholder = native_mps::Placeholder(cachedGraph->gradInputTensor_, grad_input); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary *results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + mps::pool2d_template(input, grad_input, indices, grad_output, kernel_size, stride, + padding, dilation, ceil_mode, c10::nullopt, pooling_op_block, "max_pool2d_indices_backward"); } TORCH_IMPL_FUNC(avg_pool2d_out_mps) ( - const Tensor& input_, - int64_t kH_, - int64_t kW_, - int64_t dH_, - int64_t dW_, - int64_t padH_, - int64_t padW_, + const Tensor& input, + int64_t kH, + int64_t kW, + int64_t dH, + int64_t dW, + int64_t padH, + int64_t padW, bool ceil_mode, bool count_include_pad, c10::optional divisor_override, - const Tensor& output) { - namespace native_mps = at::native::mps; - - TensorArg output_arg{ output, "output", 1 }; - TensorArg input_arg{ input_, "input_", 2 }; - - checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg}); - - const int kH = safe_downcast(kH_); - const int kW = safe_downcast(kW_); - - const int dH = safe_downcast(dH_); - const int dW = safe_downcast(dW_); - - const int padH = safe_downcast(padH_); - const int padW = safe_downcast(padW_); - - /* sizes */ - - const auto memory_format = input_.suggest_memory_format(); - - Tensor input = input_.contiguous(memory_format); - - const int32_t count = safe_downcast(output.numel()); - - bool use_divisor = divisor_override.has_value(); - const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; - - if (count != 0) { - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - MPSGraphTensor* indicesTensor_ = nil; - }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "mps_avg_pool2d:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + ":" + - to_string(divisor_override_value) + - mps::getTensorsStringKey({input}); - CachedGraph* cachedGraph = cache_->LookUpAs(key); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, 1, 1, padW, padH, ceil_mode, memory_format); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor* outputTensor = [mpsGraph avgPooling2DWithSourceTensor:inputTensor - descriptor:desc - name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input); - auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - } + const Tensor& output) +{ + const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; + float divisor = use_divisor ? float(kH * kW) / (float) divisor_override.value() : 1.0f; + count_include_pad = use_divisor ? use_divisor : count_include_pad; + + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + desc.includeZeroPadToAverage = count_include_pad; + MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; + // workaround: custom divisor isn't supported by MPS backend, so we scale manually + return [mpsGraph multiplicationWithPrimaryTensor: avgPoolTensor + secondaryTensor: cachedGraph.divisorTensor + name: nil]; + }; + mps::pool2d_template(input, output, c10::nullopt, c10::nullopt, {kH, kW}, {dH, dW}, + {padH, padW}, {1, 1}, ceil_mode, divisor, pooling_op_block, + std::string("avg_pool2d") + (count_include_pad ? "_include_pad" : "")); } TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps) ( - const Tensor& gradOutput_, - const Tensor& input_, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override, - const Tensor& gradInput -) { - TensorArg gradInput_arg{ gradInput, "gradInput", 1 }; - TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 }; - TensorArg input_arg{ input_, "input_", 3 }; - - checkAllSameGPU("avg_pool2d_backward_out_cuda", - {gradInput_arg, gradOutput_arg, input_arg}); - namespace native_mps = at::native::mps; - - const int kH = safe_downcast(kernel_size[0]); - const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); - - const int dH = stride.empty() ? kH : safe_downcast(stride[0]); - const int dW = stride.empty() ? kW : - stride.size() == 1 ? dH : safe_downcast(stride[1]); - - const int padH = safe_downcast(padding[0]); - const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); - - const auto memory_format = input_.suggest_memory_format(); - const Tensor input = input_.contiguous(memory_format); - const Tensor gradOutput = gradOutput_.contiguous(memory_format); - - const int64_t inputHeight = input.size(-2); - const int64_t inputWidth = input.size(-1); - - const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); - const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); - - - const int32_t count = safe_downcast(input.numel()); - if (count == 0) { - return; - } - - namespace native_mps = at::native::mps; - - // Derive from MPSCachedGraph - struct CachedGraph : public native_mps::MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *gradOutputTensor_ = nil; - MPSGraphTensor *gradInputTensor_ = nil; + const Tensor& gradOutput, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + c10::optional divisor_override, + const Tensor& gradInput) +{ + const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; + float divisor = use_divisor ? float(kernel_size[0] * kernel_size[1]) / (float) divisor_override.value() : 1.0f; + count_include_pad = use_divisor ? use_divisor : count_include_pad; + + mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { + MPSGraph* mpsGraph = cachedGraph.graph(); + desc.includeZeroPadToAverage = count_include_pad; + // workaround: custom divisor isn't supported by MPS backend, so we scale manually + MPSGraphTensor* scaledGradTensor = [mpsGraph multiplicationWithPrimaryTensor: cachedGraph.gradOutputTensor + secondaryTensor: cachedGraph.divisorTensor + name: nil]; + return [mpsGraph avgPooling2DGradientWithGradientTensor: scaledGradTensor + sourceTensor: cachedGraph.inputTensor + descriptor: desc + name: nil]; }; - - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - if (gradInput.numel() == 0) { - return; - } - - auto stream = at::mps::getCurrentMPSStream(); - - @autoreleasepool { - - string mem_format_key; - switch(memory_format) { - case at::MemoryFormat::Contiguous: - mem_format_key = "Contiguous"; - break; - case at::MemoryFormat::ChannelsLast: - mem_format_key = "ChannelsLast"; - break; - default: - assert(0 && "Check should have been done earlier\n"); - } - - string key = "avg_pool2d_backward_out_mps:" + to_string(kW) + ":" + to_string(kH) + ":" + - to_string(dW) + ":" + to_string(dH) + ":" + - to_string(outputWidth) + ":" + to_string(outputHeight) + ":" + - to_string(padW) + ":" + to_string(padH) + ":" + - to_string(ceil_mode) + ":" + mem_format_key + - mps::getTensorsStringKey({input, gradOutput}); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - MPSGraphPooling2DOpDescriptor* desc = [[MPSGraphPooling2DOpDescriptor new] autorelease]; - fill_pool_desc(desc, kW, kH, dW, dH, 1, 1, padW, padH, ceil_mode, memory_format); - - MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input); - MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, gradOutput); - MPSGraphTensor *gradInputTensor = [mpsGraph avgPooling2DGradientWithGradientTensor:gradOutputTensor - sourceTensor:inputTensor - descriptor : desc - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input); - auto gradOutputPlaceholder = native_mps::Placeholder(cachedGraph->gradOutputTensor_, gradOutput); - auto gradInputPlaceholder = native_mps::Placeholder(cachedGraph->gradInputTensor_, gradInput); - - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData() - }; - - NSDictionary *results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + mps::pool2d_template(input, gradInput, c10::nullopt, gradOutput, kernel_size, stride, + padding, {1, 1}, ceil_mode, divisor, pooling_op_block, + std::string("avg_pool2d_backward") + (count_include_pad ? "_include_pad" : "")); } } // namespace native diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm index 403ae4748f0ff..4533ad1578556 100644 --- a/aten/src/ATen/native/mps/operations/RangeFactories.mm +++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm @@ -88,6 +88,11 @@ } result.resize_({size}); } + + if (result.numel() == 0) { + return; + } + bool is_contiguous = result.is_contiguous(); Tensor r = !is_contiguous ? at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : result; using namespace mps; diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 36a68fc5331c0..e2bd1ab6da9cf 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -9,10 +9,14 @@ #include #include #include +#include namespace at { namespace native { +typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*); +#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary) + enum StdVarType { STANDARD_VARIANCE, STANDARD_DEVIATION @@ -26,10 +30,10 @@ SUM, PROD, MEAN, - COUNT_NONZERO + COUNT_NONZERO, + TRACE }; - void set_apparent_shapes(NSMutableArray * &apparent_out_shape, NSMutableArray * &apparent_in_shape, int64_t num_reduce_dims, @@ -78,7 +82,6 @@ void set_apparent_shapes(NSMutableArray * &apparent_out_shape, } } } - } // Helper function to set the axes of reduction @@ -137,16 +140,14 @@ void set_axes_and_shapes(const Tensor& input_t, } } -void reduction_out_mps - (const Tensor& input_tensor, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t, - MPSReductionType reduction_type, - const std::string& func_name) { - - auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor; +void reduction_out_mps( + const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t, + MPSReductionType reduction_type, + const std::string& func_name) { IntArrayRef input_shape = input_t.sizes(); @@ -154,7 +155,7 @@ void set_axes_and_shapes(const Tensor& input_t, IntArrayRef dim = opt_dim.value(); for(int i = 0; i < dim.size(); i++) { auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); - TORCH_CHECK(wrap_dim < input_shape.size(), + TORCH_CHECK(wrap_dim < (input_shape.size() == 0 ? input_t.numel() : input_shape.size()), func_name+": reduction dim must be in the range of input shape") } } @@ -167,56 +168,71 @@ void set_axes_and_shapes(const Tensor& input_t, NSMutableArray *output_shape = nil; set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape); - - auto cache_ = native_mps::MPSGraphCache::getInstance(); + NSArray* wrappedAxes = mps::getTensorAxes(input_t, opt_dim); + auto cache_ = native_mps::MPSGraphCache::getInstance(); if (output_t.numel() == 0 || input_t.numel() == 0) { + if (reduction_type == MPSReductionType::PROD) { + output_t.fill_(1); + } return; } auto stream = at::mps::getCurrentMPSStream(); - @autoreleasepool { - - // TODO: Make this key proper - NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","]; - string key = func_name+":" + string([ns_key UTF8String]) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":" + native_mps::getMPSTypeString(output_t.scalar_type()); + std::string dtype_str = dtype.has_value() ? mps::getMPSTypeString(dtype.value()) : ""; + NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; + string key = func_name + ":" + + string([ns_key UTF8String]) + ":" + + native_mps::getTensorsStringKey(input_t) + ":" + + std::to_string(keepdim) + ":" + + std::to_string(reduction_type) + ":" + + native_mps::getTensorsStringKey(output_t) + ":" + + dtype_str; using CachedGraph = native_mps::MPSUnaryCachedGraph; auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); + MPSDataType input_type = native_mps::getMPSDataType(input_t.scalar_type()); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); + MPSGraphTensor* castInputTensor = inputTensor; + MPSDataType inputCastDtype = MPSDataTypeInvalid; + if (dtype.has_value() && + (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt)) { + inputCastDtype = native_mps::getMPSDataType(dtype.value()); + } else if (input_type != MPSDataTypeInt32 && + input_type != MPSDataTypeFloat32 && + input_type != MPSDataTypeFloat16) { + inputCastDtype = MPSDataTypeFloat32; + } - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); - - MPSGraphTensor* castInputTensor = nil; - - if(input_t.scalar_type() != ScalarType::Float && input_t.scalar_type() != ScalarType::Int) - castInputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeFloat32 - name:@"castInputTensor"]; - else - castInputTensor = inputTensor; + if (inputCastDtype != MPSDataTypeInvalid) { + castInputTensor = [mpsGraph castTensor:inputTensor + toType:inputCastDtype + name:@"castInputTensor"]; + } MPSGraphTensor* castOutputTensor = nil; if(reduction_type == MPSReductionType::SUM) { castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor - axes:axes + axes:wrappedAxes name:nil]; } else if(reduction_type == MPSReductionType::PROD) { castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor - axes:axes + axes:wrappedAxes name:nil]; } else if(reduction_type == MPSReductionType::MEAN) { - castOutputTensor = [mpsGraph meanOfTensor:inputTensor - axes:axes + castOutputTensor = [mpsGraph meanOfTensor:castInputTensor + axes:wrappedAxes name:nil]; } else if(reduction_type == MPSReductionType::COUNT_NONZERO) { MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0 @@ -227,17 +243,25 @@ void set_axes_and_shapes(const Tensor& input_t, name:nil]; castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros - axes:axes + axes:wrappedAxes name:nil]; } else if(reduction_type == MPSReductionType::AMAX) { - castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor - axes:axes + castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor + axes:wrappedAxes name:nil]; } else if(reduction_type == MPSReductionType::AMIN) { - castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor - axes:axes + castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor + axes:wrappedAxes name:nil]; + } else if(reduction_type == MPSReductionType::TRACE) { + MPSGraphTensor *bandPartWithTensor = [mpsGraph bandPartWithTensor:inputTensor + numLower:0 + numUpper:0 + name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor + axes:@[@0, @1] + name:nil]; } MPSGraphTensor* outputTensor = nil; @@ -254,15 +278,9 @@ void set_axes_and_shapes(const Tensor& input_t, } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } - auto inputPlaceholder = native_mps::Placeholder(); - - if(apparent_input_shape) - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape); - else - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); NSDictionary *feeds = @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), @@ -273,17 +291,36 @@ void set_axes_and_shapes(const Tensor& input_t, }; native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); } +} + +TORCH_IMPL_FUNC(sum_out_mps)( + const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t) { + reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } -TORCH_IMPL_FUNC(sum_out_mps) - (const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t) { +Tensor trace_mps_out(const Tensor& self) { + + Tensor output_t = at::native::empty_mps( + {}, + self.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + std::vector dims(self.dim()); + std::iota(dims.begin(), dims.end(), 0); + + reduction_out_mps(self, IntArrayRef(dims), false, c10::nullopt, const_cast(output_t), MPSReductionType::TRACE, "trace_mps_out"); + + return output_t; + - reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } TORCH_IMPL_FUNC(prod_out_mps) @@ -292,20 +329,19 @@ void set_axes_and_shapes(const Tensor& input_t, bool keepdim, c10::optional dtype, const Tensor& output_t) { - - int64_t dims[1] = {dim}; - - reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps"); + int64_t dims[1] = {dim}; + reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps"); } // Taken from ReduceOps.cpp inline ScalarType get_dtype_from_self( const Tensor& self, - const optional& dtype, + const c10::optional& dtype, bool promote_integers) { if (dtype.has_value()) { return dtype.value(); } + ScalarType src_type = self.scalar_type(); if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) { return kLong; @@ -313,26 +349,25 @@ inline ScalarType get_dtype_from_self( return src_type; } -TORCH_IMPL_FUNC(amax_out_mps) - (const Tensor& input_t, - IntArrayRef dim, - bool keepdim, - const Tensor& output_t) { +TORCH_IMPL_FUNC(amax_out_mps)( + const Tensor& input_t, + IntArrayRef dim, + bool keepdim, + const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); + reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); } -TORCH_IMPL_FUNC(amin_out_mps) - (const Tensor& input_t, - IntArrayRef dim, - bool keepdim, - const Tensor& output_t) { +TORCH_IMPL_FUNC(amin_out_mps)( + const Tensor& input_t, + IntArrayRef dim, + bool keepdim, + const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); + reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); } Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { - std::vector dims(self.dim()); std::iota(dims.begin(), dims.end(), 0); @@ -351,65 +386,74 @@ Tensor prod_mps(const Tensor &self, c10::optional opt_dtype) { Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ - NSMutableArray *axes = nil; - NSMutableArray *apparent_input_shape = nil; - NSMutableArray *apparent_output_shape = nil; - NSMutableArray *output_shape = nil; - - set_axes_and_shapes(self, dims, axes, apparent_input_shape, apparent_output_shape, output_shape); - - int64_t* raw_output_shape = (int64_t *)malloc([output_shape count] * sizeof(int64_t)); - for(int i=0; i < [output_shape count]; i++) { - raw_output_shape[i] = [output_shape[i] longValue]; + int64_t shape_size = dims.size() == 0 ? 0 : self.sizes().size() - dims.size(); + int64_t out_shape = std::max(shape_size, 0LL); + std::vector output_shape(out_shape); + std::vector dims_vec = dims.vec(); + std::for_each(dims_vec.begin(), dims_vec.end(), [&](int64_t &n){ n = maybe_wrap_dim(n, self); }); + + if (out_shape != 0) { + int out_dim = 0; + for (const auto self_dim: c10::irange((self.sizes().size()))) { + if (std::find(dims_vec.begin(), dims_vec.end(), self_dim) == dims_vec.end()) { + output_shape[out_dim++] = (self.sizes()[self_dim]); + } + } } Tensor output_t = at::native::empty_mps( - IntArrayRef(raw_output_shape, [output_shape count]), + IntArrayRef(output_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - reduction_out_mps(self, dims, false, self.scalar_type(), const_cast(output_t), MPSReductionType::COUNT_NONZERO, "count_nonzero_mps"); - free(raw_output_shape); - return output_t; } -TORCH_IMPL_FUNC(mean_out_mps) - (const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - c10::optional dtype, - const Tensor& output_t) { +TORCH_IMPL_FUNC(mean_out_mps)( + const Tensor& input_t, + OptionalIntArrayRef opt_dim, + bool keepdim, + c10::optional dtype, + const Tensor& output_t) { - reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps"); + reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps"); } -TORCH_IMPL_FUNC(norm_out_mps) -(const Tensor& input_tensor, - const OptionalScalarRef opt_p, - IntArrayRef dim, - bool keepdim, - const Tensor& output_t) -{ - if (input_tensor.numel() == 0) +void impl_func_norm_mps( + const Tensor& input_tensor, + const Tensor& other_tensor, + const OptionalScalarRef& opt_p, + IntArrayRef dim, + bool keepdim, + c10::optional opt_dtype, + const Tensor& output_t, + bool cdist = false, + c10::optional input_broadcasted_shape = c10::nullopt, + NormOpBlock normOpBlock = nullptr + ) { + + namespace native_mps = at::native::mps; + if (input_tensor.numel() == 0) { return; + } auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor; + auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type()); + auto mps_input_dtype = native_mps::getMPSDataType(in_dtype); - IntArrayRef input_shape = input_t.sizes(); + IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes(); - for(int i = 0; i < dim.size(); i++) { + for (const auto i : c10::irange(dim.size())) { auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); TORCH_CHECK(wrap_dim < input_shape.size(), "norm_out_mps: reduction dim must be in the range of input shape") } - namespace native_mps = at::native::mps; - using CachedGraph = native_mps::MPSUnaryCachedGraph; + using CachedGraph = native_mps::MPSBinaryCachedGraph; native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); @@ -439,78 +483,91 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ num_output_dims, input_shape, axes); + + NSArray* wrappedAxes = mps::getTensorAxes(input_t, dim); + if (cdist) { + apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy]; + apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy]; + } + if (output_t.numel() == 0) { return; } auto stream = at::mps::getCurrentMPSStream(); - @autoreleasepool { NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","]; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info; + string tensor_key = cdist ? native_mps::getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t}); + string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info; auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); + newCachedGraph->inputTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); + + if (cdist) { + newCachedGraph->otherTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, other_tensor); + } - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) : + newCachedGraph->inputTensor_; + if (opt_dtype.has_value()) { + inputTensor = [mpsGraph castTensor:inputTensor + toType:mps_input_dtype + name:@"any_all"]; + } MPSGraphTensor *outputTensor; - if (pIsZero) - { + if (pIsZero) { MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p - dataType:native_mps::getMPSDataType(input_t.scalar_type())]; + dataType:mps_input_dtype]; MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor secondaryTensor:powerValTensor name:nil]; outputTensor = [mpsGraph reductionSumWithTensor:powerTensor - axes:axes + axes:wrappedAxes name:nil]; } - else if (pIsPosInf) - { + else if (pIsPosInf) { MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor - axes:axes + axes:wrappedAxes name:nil]; } - else if (pIsNegInf) - { + else if (pIsNegInf) { MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor - axes:axes + axes:wrappedAxes name:nil]; - } - else - { + } else { MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p - dataType:native_mps::getMPSDataType(input_t.scalar_type())]; + dataType:mps_input_dtype]; MPSGraphTensor *reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p - dataType:native_mps::getMPSDataType(input_t.scalar_type())]; + dataType:mps_input_dtype]; MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor secondaryTensor:powerValTensor name:nil]; MPSGraphTensor *reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor - axes:axes + axes:wrappedAxes name:nil]; outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor @@ -518,37 +575,132 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){ name:nil]; } - newCachedGraph->inputTensor_ = inputTensor; + if (cdist) { + outputTensor= [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name: nil]; + } + newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } - auto inputPlaceholder = native_mps::Placeholder(); - - if(apparent_input_shape) - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape); - else - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - + auto otherPlaceholder = native_mps::Placeholder(); + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); + NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; + if (cdist) { + otherPlaceholder = native_mps::Placeholder(cachedGraph->otherTensor_, other_tensor); + feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData(); + } NSDictionary *results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } } +TORCH_IMPL_FUNC(norm_out_mps) +(const Tensor& self, + const OptionalScalarRef opt_p, + IntArrayRef dim, + bool keepdim, + const Tensor& result) { + impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false); +} + +TORCH_IMPL_FUNC(norm_dtype_out_mps) +(const Tensor& self, + const OptionalScalarRef opt_p, + IntArrayRef dim, + bool keepdim, + ScalarType dtype, + const Tensor& result) { + impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false); +} + +Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional compute_mode) { + using namespace mps; + TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D"); + TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D"); + TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1)); + TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); + auto device1 = x1.device().type(); + TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); + auto device2 = x2.device().type(); + TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); + TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2); + TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")"); + + int64_t c1 = x1.size(-1); + int64_t c2 = x2.size(-1); + + auto dim1 = x1.dim(); + auto dim2 = x2.dim(); + int64_t mode = compute_mode.value_or(0); + TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode); + + int64_t r1 = x1.size(-2); + int64_t r2 = x2.size(-2); + + //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them. + //The last two dimensions will stay the same + IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); + IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); + std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); + std::vector tensor1_expand_size(expand_batch_portion); + tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1}); + std::vector tensor2_expand_size(expand_batch_portion); + tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); + + const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion); + std::vector tensor1_view{expand_batch_product, r1, c1}; + std::vector tensor2_view{expand_batch_product, r2, c2}; + + std::vector output_shape(expand_batch_portion); + output_shape.insert(output_shape.end(), {r1, r2}); + Tensor result = at::empty(output_shape, x1.options()); + + NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + + MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil]; + MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil]; + + MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil]; + MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil]; + + NSMutableArray *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]]; + NSMutableArray *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]]; + + for (const auto i : c10::irange(tensor2_view[1])) { + inputArray[i] = inputBroadcastReshape; + } + + for (const auto i : c10::irange(tensor1_view[1])) { + otherArray[i] = otherBroadcastReshape; + } + + MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil]; + MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil]; + + + MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped + secondaryTensor: otherTensorReshaped + name: nil]; + return inputTensorPNorm; + }; + + c10::optional inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size())); + impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block); + return result; +} + Tensor std_var_common_impl_mps( const Tensor & input_t, at::OptionalIntArrayRef dim, @@ -565,22 +717,23 @@ Tensor std_var_common_impl_mps( bool use_dim = dim.has_value(); IntArrayRef dim_value = use_dim ? dim.value() : NULL; - if (use_dim) - { - string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; - errMessage += ": reduction dim must be in the range of input shape"; - for(int i = 0; i < dim_value.size(); i++) { - auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); - TORCH_CHECK(wrap_dim < input_shape.size(), errMessage.c_str()) + if (use_dim){ + string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; + errMessage += ": reduction dim must be in the range of input shape"; + for (const int i : c10::irange(dim_value.size())) { + auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); + TORCH_CHECK(wrap_dim < input_shape.size(), errMessage.c_str()) } } - bool use_correction = correction.has_value(); - const auto correction_value = use_correction ? correction.value() : false; + bool use_correction = !(correction.has_value() && correction.value() == 0); + const auto correction_value = correction.value_or(1); int64_t correction_n = 1; native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + NSArray* wrappedAxes = mps::getTensorAxes(input_t, dim); + int64_t num_output_dims = 0; NSMutableArray *axes = nil; NSMutableArray *apparent_output_shape = nil; @@ -610,97 +763,90 @@ Tensor std_var_common_impl_mps( axes[0] = @0; } - else if (!keepdim && use_dim && dim_value.size() > 0) - { - int64_t num_reduce_dims = dim_value.size(); - num_output_dims = num_input_dims; - - set_axes(axes, num_reduce_dims, dim_value, num_input_dims); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_input_dims, - num_output_dims, - input_shape, - axes); - - num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims; - - unsigned int curr_i = 0; - for (int i = 0; i < num_input_dims; i++) - { - bool found = false; - for (int j = 0; j < num_reduce_dims; j++) - { - if (i == dim_value[j]) - { - found = true; - break; - } - } - if (found) continue; - output_shape.push_back(input_shape[i]); - curr_i += 1; - // End loop when output shape is filled - if (curr_i == num_output_dims) - break; - } + else if (!keepdim && use_dim && dim_value.size() > 0) { + int64_t num_reduce_dims = dim_value.size(); + num_output_dims = num_input_dims; - for(int i = 0; i < num_reduce_dims; i++) - { - auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); - correction_n *= input_shape[wrap_dim]; - } - // (3, 4, 5) --> (3, 5) - } - else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.size() <= 0)) - { - num_output_dims = 0; - int64_t num_reduce_dims = 0; - set_axes(axes, num_reduce_dims, dim_value, input_shape.size()); - set_apparent_shapes(apparent_output_shape, + set_axes(axes, num_reduce_dims, dim_value, num_input_dims); + set_apparent_shapes(apparent_output_shape, apparent_input_shape, - num_reduce_dims, - num_input_dims, - num_output_dims, - input_shape, - axes); - num_output_dims = num_input_dims; - for (int i = 0; i < num_input_dims; i++) - { - output_shape.push_back((int64_t) 1); - correction_n *= input_shape[i]; + num_reduce_dims, + num_input_dims, + num_output_dims, + input_shape, + axes); + + num_output_dims = (num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; //num_input_dims; + + unsigned int curr_i = 0; + for (const int i : c10::irange(num_input_dims)) { + bool found = false; + for (const int j : c10::irange(num_reduce_dims)) { + if (i == dim_value[j]) { + found = true; + break; + } } - // scalar --> vector case [[1.0034567]] - } - else if (keepdim && use_dim && dim_value.size() > 0) - { - int64_t num_reduce_dims = dim_value.size(); - num_output_dims = num_input_dims; - - set_axes(axes, num_reduce_dims, dim_value, num_input_dims); - set_apparent_shapes(apparent_output_shape, - apparent_input_shape, - num_reduce_dims, - num_input_dims, - num_output_dims, - input_shape, - axes); - - num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; - - for(int i = 0; i < num_reduce_dims; i++) - { - auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); - correction_n *= input_shape[wrap_dim]; + if (found) { + continue; } - for (int i = 0; i < num_input_dims; i++) - { - output_shape.push_back([apparent_output_shape[i] longValue]); - } + output_shape.push_back(input_shape[i]); + curr_i += 1; + // End loop when output shape is filled + if (curr_i == num_output_dims) + break; + } + + for(int i = 0; i < num_reduce_dims; i++) { + auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); + correction_n *= input_shape[wrap_dim]; + } + // (3, 4, 5) --> (3, 5) } + else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.size() <= 0)) { + num_output_dims = 0; + int64_t num_reduce_dims = 0; + set_axes(axes, num_reduce_dims, dim_value, input_shape.size()); + set_apparent_shapes(apparent_output_shape, + apparent_input_shape, + num_reduce_dims, + num_input_dims, + num_output_dims, + input_shape, + axes); + num_output_dims = num_input_dims; + for (const int i : c10::irange(num_input_dims)) + { + output_shape.push_back((int64_t) 1); + correction_n *= input_shape[i]; + } + // scalar --> vector case [[1.0034567]] + } + else if (keepdim && use_dim && dim_value.size() > 0) { + int64_t num_reduce_dims = dim_value.size(); + num_output_dims = num_input_dims; + + set_axes(axes, num_reduce_dims, dim_value, num_input_dims); + set_apparent_shapes(apparent_output_shape, + apparent_input_shape, + num_reduce_dims, + num_input_dims, + num_output_dims, + input_shape, + axes); + + num_output_dims = num_input_dims;//(num_input_dims >= num_reduce_dims) ? (num_input_dims - num_reduce_dims) : 0; + + for(const int i : c10::irange(num_reduce_dims)) { + auto wrap_dim = maybe_wrap_dim(dim_value[i], input_shape.size()); + correction_n *= input_shape[wrap_dim]; + } + for (const int i : c10::irange(num_input_dims)) { + output_shape.push_back([apparent_output_shape[i] longValue]); + } + } Tensor output_t = at::native::empty_mps( IntArrayRef(output_shape.data(), num_output_dims), @@ -710,26 +856,30 @@ Tensor std_var_common_impl_mps( c10::nullopt, c10::nullopt); - if (output_t.numel() == 0 || input_t.numel() == 0) - { - return output_t; + if (output_t.numel() == 0 || input_t.numel() == 0) { + return output_t; } - double bessel_correction = ((double) correction_n) / ((double) (correction_n-1)); - + double bessel_correction = ((double) correction_n) / ((double) (correction_n-correction_value)); auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; - NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","]; + NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - string key = op_key + use_dim_info + ":" + keepdim_info + ":" + string([ns_key UTF8String]) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + bessel_corrected; + string key = op_key + ":" + + native_mps::getTensorsStringKey(input_t) + ":" + + use_dim_info + ":" + + keepdim_info + ":" + + string([ns_key UTF8String]) + ":" + + bessel_corrected + ":" + + std::to_string(correction_value); auto cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache - if(!cachedGraph) { + if(!cachedGraph) { native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -738,24 +888,22 @@ Tensor std_var_common_impl_mps( MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor *inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor *inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor *outputVarTensor = [mpsGraph varianceOfTensor:inputTensor - axes:axes - name:nil]; - MPSGraphTensor *outputTensor; + axes:wrappedAxes + name:nil]; + MPSGraphTensor *outputTensor = nil; if (use_correction && correction_value) { MPSGraphTensor *besselTensor= [mpsGraph constantWithScalar:bessel_correction - dataType:MPSDataTypeFloat32]; - MPSGraphTensor *correctedTensor = [mpsGraph multiplicationWithPrimaryTensor: outputVarTensor - secondaryTensor: besselTensor - name: nil]; + dataType:native_mps::getMPSDataType(input_t.scalar_type())]; + MPSGraphTensor *correctedTensor = [mpsGraph multiplicationWithPrimaryTensor:outputVarTensor + secondaryTensor:besselTensor + name:nil]; outputTensor = (stdVarType == STANDARD_DEVIATION) ? [mpsGraph squareRootWithTensor:correctedTensor name:nil] : correctedTensor; - } - else - { + } else { outputTensor = (stdVarType == STANDARD_DEVIATION) ? [mpsGraph squareRootWithTensor:outputVarTensor name:nil] : outputVarTensor; } @@ -765,18 +913,10 @@ Tensor std_var_common_impl_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); + cachedGraph = static_cast(tmpCachedGraph); } - auto inputPlaceholder = native_mps::Placeholder(); - if(apparent_input_shape) - { - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape); - } - else - { - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); - } + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); NSDictionary *feeds = @{ @@ -848,7 +988,7 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { @@ -861,8 +1001,7 @@ Tensor std_mps( if (input_type != MPSDataTypeInt32 && input_type != MPSDataTypeFloat32 && - input_type != MPSDataTypeFloat16 ) - { + input_type != MPSDataTypeFloat16) { MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor toType:MPSDataTypeInt32 name:@"any_all"]; @@ -888,7 +1027,6 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -923,7 +1061,7 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -964,7 +1102,6 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1019,7 +1156,7 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { @@ -1059,7 +1196,6 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1094,7 +1230,7 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -1135,7 +1271,6 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1159,26 +1294,15 @@ Tensor std_mps( (const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) { + TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input"); namespace native_mps = at::native::mps; using CachedGraph = native_mps::MPSUnaryCachedGraph; native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - IntArrayRef input_shape = input_t.sizes(); - int64_t num_input_dims = input_shape.size(); - - // Flatten the input tensor to reduce it to one value - NSMutableArray *apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; - int64_t num_in_elements = 1; - for(int i = 0; i < num_input_dims; i++) { - num_in_elements *= input_shape[i]; - } - apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; - Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - - if (output_t.numel() == 0 || num_in_elements == 0) { + if (output_t.numel() == 0 || input_t.numel() == 0) { return output_t; } @@ -1187,7 +1311,7 @@ Tensor std_mps( CachedGraph* cachedGraph = cache_->LookUpAs(key); // Initialize once if configuration not found in cache if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -1195,17 +1319,29 @@ Tensor std_mps( MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* castInputTensor = nil; + + if(input_t.scalar_type() != ScalarType::Float && + input_t.scalar_type() != ScalarType::Int && + input_t.scalar_type() != ScalarType::Half) { + castInputTensor = [mpsGraph castTensor:inputTensor + toType:MPSDataTypeInt32 + name:@"castInputTensor"]; + } else { + castInputTensor = inputTensor; + } + NSArray* axes = mps::getTensorAxes(input_t); if(reduction_type == MPSReductionType::MAX) - outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor - axes:@[@0] + outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor + axes:axes name:nil]; else if(reduction_type == MPSReductionType::MIN) - outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor - axes:@[@0] + outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor + axes:axes name:nil]; newCachedGraph->inputTensor_ = inputTensor; @@ -1214,10 +1350,9 @@ Tensor std_mps( } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } - auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape); + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]); NSDictionary *feeds = @{ @@ -1254,6 +1389,7 @@ Tensor min_mps(const Tensor& input_t) { const Tensor& indices_t, MPSReductionType reduction_type, const std::string& func_name) { + TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input"); namespace native_mps = at::native::mps; @@ -1297,11 +1433,11 @@ Tensor min_mps(const Tensor& input_t) { auto stream = at::mps::getCurrentMPSStream(); @autoreleasepool { - string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = func_name + native_mps::getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_); + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @@ -1309,28 +1445,29 @@ Tensor min_mps(const Tensor& input_t) { MPSGraph* mpsGraph = native_mps::make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); MPSGraphTensor* outputTensor = nil; + + MPSGraphTensor* castInputTensor = inputTensor; + bool castOutput = false; + if(input_t.scalar_type() != ScalarType::Float && + input_t.scalar_type() != ScalarType::Int && + input_t.scalar_type() != ScalarType::Half) { + castInputTensor = [mpsGraph castTensor:inputTensor + toType:MPSDataTypeInt32 + name:@"castInputTensor"]; + castOutput = true; + } + if(reduction_type == MPSReductionType::MAX) - outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor + outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; else if(reduction_type == MPSReductionType::MIN) - outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor + outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil]; - MPSGraphTensor* castInputTensor = nil; - - if(input_t.scalar_type() != ScalarType::Float && - input_t.scalar_type() != ScalarType::Int && - input_t.scalar_type() != ScalarType::Half) - castInputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeFloat32 - name:@"castInputTensor"]; - else - castInputTensor = inputTensor; - MPSGraphTensor* argreduceOutTensor = nil; if(reduction_type == MPSReductionType::MAX) argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor @@ -1345,13 +1482,17 @@ Tensor min_mps(const Tensor& input_t) { toType:MPSDataTypeInt64 name:@"cast_out"]; + if (castOutput) { + outputTensor = [mpsGraph castTensor:outputTensor + toType:native_mps::getMPSDataType(output_t.scalar_type()) + name:@"cast_out"]; + } newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->outputTensor_ = outputTensor; newCachedGraph->indicesTensor_ = indicesTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); @@ -1411,122 +1552,126 @@ Tensor min_mps(const Tensor& input_t) { namespace native_mps = at::native::mps; using CachedGraph = native_mps::MPSUnaryCachedGraph; - native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - - int64_t dim_; + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); - if (dim.has_value()) { - dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); - zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()"); - } else { - TORCH_CHECK_INDEX( - input_t.numel() != 0, - reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()" , ": Expected reduction dim to be specified for input.numel() == 0."); - // Since input will be flattened, take argmax or argmin along 0'th dimension - dim_ = 0; - } + int64_t dim_; + + if (dim.has_value()) { + dim_ = maybe_wrap_dim(dim.value(), input_t.dim()); + zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()"); + } else { + TORCH_CHECK_INDEX( + input_t.numel() != 0, + reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()" , ": Expected reduction dim to be specified for input.numel() == 0."); + // Since input will be flattened, take argmax or argmin along 0'th dimension + dim_ = 0; + } - // Calculate the output shape according to keepdim=True - // If there is no dim argument, the input shape is flattened - IntArrayRef input_shape = input_t.sizes(); - int64_t num_input_dims = input_shape.size(); - NSMutableArray *apparent_in_shape = nil; - NSMutableArray *apparent_out_shape = nil; + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_in_shape = nil; + NSMutableArray *apparent_out_shape = nil; - if(dim.has_value()) { - apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; - for(int i = 0; i < num_input_dims; i++) { - if(dim_ == i) - apparent_out_shape[i] = @1; - else - apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; - } + if(dim.has_value()) { + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) + apparent_out_shape[i] = @1; + else + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; } - else { - apparent_in_shape = [NSMutableArray arrayWithCapacity:1]; - int64_t num_in_elements = 1; - for(int i = 0; i < num_input_dims; i++) { - num_in_elements *= input_shape[i]; - } - apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements]; - - apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; - apparent_out_shape[0] = @1; + } else { + apparent_in_shape = [NSMutableArray arrayWithCapacity:1]; + int64_t num_in_elements = 1; + for(int i = 0; i < num_input_dims; i++) { + num_in_elements *= input_shape[i]; } + apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements]; - if (output_t.numel() == 0) { - return; - } + apparent_out_shape = [NSMutableArray arrayWithCapacity:1]; + apparent_out_shape[0] = @1; + } - auto stream = at::mps::getCurrentMPSStream(); + if (output_t.numel() == 0) { + return; + } - @autoreleasepool { - string key = func_name + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t); - CachedGraph* cachedGraph = cache_->LookUpAs(key); + if (!apparent_in_shape) { + apparent_in_shape = [native_mps::getMPSShape(input_t.sizes()) mutableCopy]; + } - if(!cachedGraph) { - native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + auto stream = at::mps::getCurrentMPSStream(); + @autoreleasepool { + NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; + string key = func_name + ":" + + to_string(dim_) + ":" + + native_mps::getTensorsStringKey(input_t) + ":" + + string([ns_key UTF8String]); + CachedGraph* cachedGraph = cache_->LookUpAs(key); - CachedGraph *newCachedGraph = nil; + if(!cachedGraph) { + native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { - @autoreleasepool { - MPSGraph* mpsGraph = native_mps::make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); + CachedGraph *newCachedGraph = nil; - MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* castInputTensor = nil; - MPSGraphTensor* argreduceOutTensor = nil; + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()), apparent_in_shape); - if(input_t.scalar_type() != ScalarType::Float && - input_t.scalar_type() != ScalarType::Int && - input_t.scalar_type() != ScalarType::Half) - castInputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeFloat32 - name:@"castInputTensor"]; - else - castInputTensor = inputTensor; + MPSGraphTensor* castInputTensor = inputTensor; + MPSGraphTensor* argreduceOutTensor = nil; - if (reduction_type == MPSReductionType::MAX) { - argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:nil]; - } - else { - argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor - axis:(NSInteger)dim_ - name:nil]; - } - MPSGraphTensor* outputTensor = [mpsGraph castTensor:argreduceOutTensor - toType:MPSDataTypeInt64 - name:@"castOutpuTensor"]; + if(input_t.scalar_type() != ScalarType::Float && + input_t.scalar_type() != ScalarType::Int && + input_t.scalar_type() != ScalarType::Half) { + castInputTensor = [mpsGraph castTensor:inputTensor + toType:MPSDataTypeFloat32 + name:@"castInputTensor"]; + } - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; + if (reduction_type == MPSReductionType::MAX) { + argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:nil]; + } else { + argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:nil]; } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + MPSGraphTensor* outputTensor = [mpsGraph castTensor:argreduceOutTensor + toType:MPSDataTypeInt64 + name:@"castOutputTensor"]; - native_mps::Placeholder inputPlaceholder = native_mps::Placeholder(); - if(apparent_in_shape) - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape); - else - inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + MPSGraphTensor* outputClampedTensor = [mpsGraph clampWithTensor:outputTensor + minValueTensor:[mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64] + maxValueTensor:[mpsGraph constantWithScalar:LLONG_MAX dataType:MPSDataTypeInt64] + name: nil]; - auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputClampedTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } - NSDictionary *feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape); + auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); - NSDictionary *results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; - native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } } TORCH_IMPL_FUNC(argmax_out_mps) @@ -1547,7 +1692,6 @@ Tensor min_mps(const Tensor& input_t) { argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps"); } - // Min/Max with dim std::tuple min_max_mps (const Tensor& input_t, @@ -1569,8 +1713,8 @@ Tensor min_mps(const Tensor& input_t) { // Use this if keepdim is false int64_t num_output_dims = num_input_dims - 1; - int64_t* malloc_apparent_out_shape = (int64_t *)malloc(num_input_dims * sizeof(int64_t)); - int64_t* malloc_out_shape = (int64_t *)malloc(num_output_dims * sizeof(int64_t)); + std::vector vec_apparent_out_shape(num_input_dims); + std::vector vec_out_shape(num_output_dims); apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; // Counter for shape when keepdim is false @@ -1578,12 +1722,12 @@ Tensor min_mps(const Tensor& input_t) { for(int i = 0; i < num_input_dims; i++) { if(dim_ == i) { apparent_out_shape[i] = @1; - malloc_apparent_out_shape[i] = 1; + vec_apparent_out_shape[i] = 1; } else { apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; - malloc_apparent_out_shape[i] = input_shape[i]; - malloc_out_shape[out_i] = input_shape[i]; + vec_apparent_out_shape[i] = input_shape[i]; + vec_out_shape[out_i] = input_shape[i]; out_i++; } } @@ -1592,30 +1736,29 @@ Tensor min_mps(const Tensor& input_t) { Tensor indices_t; if(!keepdim) { output_t = at::native::empty_mps( - IntArrayRef(malloc_out_shape, num_output_dims), + IntArrayRef(vec_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices_t = at::native::empty_mps( - IntArrayRef(malloc_out_shape, num_output_dims), + IntArrayRef(vec_out_shape), ScalarType::Long, c10::nullopt, kMPS, c10::nullopt, c10::nullopt); - } - else { + } else { output_t = at::native::empty_mps( - IntArrayRef(malloc_apparent_out_shape, num_input_dims), + IntArrayRef(vec_apparent_out_shape), input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); indices_t = at::native::empty_mps( - IntArrayRef(malloc_apparent_out_shape, num_input_dims), + IntArrayRef(vec_apparent_out_shape), ScalarType::Long, c10::nullopt, kMPS, @@ -1624,15 +1767,11 @@ Tensor min_mps(const Tensor& input_t) { } if (output_t.numel() == 0 || input_t.numel() == 0) { - free(malloc_out_shape); - free(malloc_apparent_out_shape); return std::tuple{output_t, indices_t}; } min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, reduction_type, func_name); - free(malloc_out_shape); - free(malloc_apparent_out_shape); return std::tuple{output_t, indices_t}; } @@ -1654,5 +1793,341 @@ Tensor min_mps(const Tensor& input_t) { return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps"); } +// Median of entire tensor into scalar result +Tensor median_mps(const Tensor& input_t) { + + if(!is_macos_13_or_newer()){ + TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performace implications."); + return at::median(input_t.to("cpu")); + } + + TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median op with int64 input"); + + namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; + + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + + // calculate total no. of elements in the input tensor to reduce it to one dimension + NSMutableArray *apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; + int64_t num_in_elements = 1; + for(int i = 0; i < num_input_dims; i++) { + num_in_elements *= input_shape[i]; + } + + apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; + + Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); + + if (output_t.numel() == 0 || num_in_elements == 0) { + return output_t; + } + + @autoreleasepool { + string key = "median_mps:"+ mps::getMPSTypeString(input_t.scalar_type()) + mps::getTensorsStringKey(input_t); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + // Initialize once if configuration not found in cache + if(!cachedGraph) { + native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); + + MPSGraphTensor* outputTensor = nil; + + MPSGraphTensor * reshapedTensor = [mpsGraph reshapeTensor:inputTensor + withShape:@[@-1] + name:nil]; + MPSDataType dataType = [inputTensor dataType]; + // #issue 104398441 sortWithTensor only supports following types, cast if necessary + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32 && + dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + reshapedTensor = [mpsGraph castTensor:reshapedTensor + toType:dataType + name:@"castReshapedTensor"]; + } + + MPSGraphTensor * sortedTensor = [mpsGraph + sortWithTensor:reshapedTensor + axis:((NSUInteger) (int)0) + name:nil]; + + outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:0 + start:((NSUInteger) (int)((num_in_elements+1)/2 ) - 1) + length:1 + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]); + + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + + return output_t; +} + + +void median_out_mps + (const Tensor& input_t, + int64_t dim, + bool keepdim, + const Tensor& output_t, + const Tensor& indices_t, + const std::string& func_name) { + + namespace native_mps = at::native::mps; + + if (output_t.numel() == 0) { + return; + } + if (input_t.numel() == 1 && input_t.dim() == 0) { + output_t.fill_(input_t); + indices_t.fill_(0); + return; + } + + // Derive from MPSCachedGraph + struct CachedGraph : public native_mps::MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + MPSGraphTensor *indicesTensor_ = nil; + }; + + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_out_shape = nil; + + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) + apparent_out_shape[i] = @1; + else + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; + } + int dim_total_elements = input_shape[dim_]; + + auto stream = at::mps::getCurrentMPSStream(); + + @autoreleasepool { + string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t) + ":" + native_mps::getTensorsStringKey(indices_t); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + + if(!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ native_mps::MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* castInputTensor = inputTensor; + MPSDataType dataType = native_mps::getMPSDataType(input_t.scalar_type()); + // #issue 104398441 sortWithTensor only supports following types, cast if necessary + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32 && + dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + castInputTensor = [mpsGraph castTensor:inputTensor + toType:dataType + name:@"castInputTensor"]; + } + + MPSGraphTensor * sortedTensor = [mpsGraph + sortWithTensor:castInputTensor + axis:((NSUInteger) (int)dim_) + name:nil]; + + outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; + MPSGraphTensor* argreduceOutTensor = nil; + argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor + axis:(NSInteger)dim_ + name:@"argmax_out"]; + MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->indicesTensor_ = argOutputTensor; + } + return newCachedGraph; + }); + } + + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); + auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape); + + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), + indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + + } + +} + +// in case mps sortWithTensor do not supported on macOS +std::tuple median_from_cpu( + const Tensor& self, + int64_t dim, + bool keepdim, Tensor & valuesI, Tensor & indicesI, IntArrayRef vec_out_shape, IntArrayRef vec_apparent_out_shape) { + // Tensor a = at::median(self.to("cpu")); + Tensor values; + Tensor indices; + if (!keepdim){ + values = at::empty({vec_out_shape}, self.options()); + indices = at::empty({vec_out_shape}, self.options().dtype(kLong)); + + } + else{ + values = at::empty({vec_apparent_out_shape}, self.options()); + indices = at::empty({vec_apparent_out_shape}, self.options().dtype(kLong)); + } + at::median_out(values, indices, self, dim, keepdim); + + valuesI.copy_(values); + indicesI.copy_(indices); + return std::forward_as_tuple(valuesI, indicesI); +} + +TORCH_API ::std::tuple median_out_mps + (const at::Tensor & input_t, + int64_t dim, + bool keepdim, + at::Tensor & values, + at::Tensor & indices){ + + TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median ops with int64 input"); + + namespace native_mps = at::native::mps; + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + native::zero_numel_check_dims(input_t, dim_, "max()"); + + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_out_shape = nil; + // Use this if keepdim is false + int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1; + + std::vector vec_apparent_out_shape(num_input_dims); + std::vector vec_out_shape(num_output_dims); + + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + // Counter for shape when keepdim is false + int out_i = 0; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) { + apparent_out_shape[i] = @1; + vec_apparent_out_shape[i] = 1; + } + else { + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; + vec_apparent_out_shape[i] = input_shape[i]; + vec_out_shape[out_i] = input_shape[i]; + out_i++; + } + } + + if(!keepdim) { + values = at::native::empty_mps( + IntArrayRef(vec_out_shape), + input_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + indices = at::native::empty_mps( + IntArrayRef(vec_out_shape), + ScalarType::Long, + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + } else { + values = at::native::empty_mps( + IntArrayRef(vec_apparent_out_shape), + input_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + indices = at::native::empty_mps( + IntArrayRef(vec_apparent_out_shape), + ScalarType::Long, + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + } + + if (values.numel() == 0 || input_t.numel() == 0) { + return std::tuple{values, indices}; + } + + if(!is_macos_13_or_newer()){ + TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0.", + "Falling back on CPU. This may have performace implications."); + return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices, IntArrayRef(vec_out_shape),IntArrayRef(vec_apparent_out_shape) ); + } + + median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps"); + + return std::tuple{values, indices}; +} + } // native } // at diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index 8b6b709da6427..6dd041c542b68 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -36,48 +36,6 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) { return self.as_strided(newSizes, newStrides); } -void set_apparent_shapes(NSArray * input_shape, - NSArray * &apparent_input_shape, - int64_t num_input_dims, - IntArrayRef repeats, - NSMutableArray * &repeats_shape, - int64_t num_repeat_dims) { - - - bool repeat_empty = false; - if(num_repeat_dims == 0) { - num_repeat_dims = num_input_dims; - repeat_empty = true; - } - - // Set repeats_shape - repeats_shape = [NSMutableArray arrayWithCapacity:num_repeat_dims]; - - for(int i = 0; i < num_repeat_dims; i++) { - if(repeat_empty) - repeats_shape[i] = [NSNumber numberWithInteger:1]; - else - repeats_shape[i] = [NSNumber numberWithInteger:repeats[i]]; - } - - // If no extension of the shape is needed - if(num_repeat_dims == num_input_dims) { - apparent_input_shape = input_shape; - } - // num_repeat_dims > num_input_dims - else { - auto rc = [NSMutableArray arrayWithCapacity:num_repeat_dims]; - - for(int i = 0; i < num_repeat_dims - num_input_dims; i++) - rc[i] = @1; - - for(int i = num_repeat_dims - num_input_dims; i < num_repeat_dims; i++) - rc[i] = input_shape[i + num_input_dims - num_repeat_dims]; - apparent_input_shape = rc; - } - -} - Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { using namespace mps; @@ -91,54 +49,42 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { MPSGraphTensor *outputTensor_ = nil; }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - NSArray *apparent_input_shape = nil; - NSMutableArray *repeats_shape = nil; - - auto input_shape = getMPSShape(self); - auto num_input_dims = [input_shape count]; - auto num_repeat_dims = repeats.size(); - - set_apparent_shapes(input_shape, - apparent_input_shape, - num_input_dims, - repeats, - repeats_shape, - num_repeat_dims); - - // Set output shape - std::vector output_shape(num_repeat_dims); + // Add new leading dimensions to the tensor if the + // number of target dimensions is larger than the + // number of source dimensions. + int64_t num_new_dimensions = repeats.size() - self.dim(); + DimVector padded_size(num_new_dimensions, 1); + padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end()); + DimVector target_size(repeats.size()); bool zero_tensor = false; - for(auto i : c10::irange(num_repeat_dims)) { - output_shape[i] = repeats[i] * [apparent_input_shape[i] intValue]; - if(output_shape[i] == 0) { + for(const auto idx : c10::irange(repeats.size())) { + if (repeats[idx] == 0) { zero_tensor = true; } + target_size[idx] = padded_size[idx] * repeats[idx]; } - Tensor output = at::native::empty_mps( - IntArrayRef(output_shape), - self.scalar_type(), - c10::nullopt, - kMPS, - c10::nullopt, - c10::nullopt); - - // Empty output - if(zero_tensor || output.numel() == 0) - return output; + Tensor expanded_tensor = self.expand(padded_size); + Tensor result = at::empty(target_size, self.options()); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + if(zero_tensor || result.numel() == 0) { + return result; + } auto stream = at::mps::getCurrentMPSStream(); + auto inputDataType = getMPSDataType(expanded_tensor.scalar_type()); + auto outputDataType = getMPSDataType(result.scalar_type()); + if (!is_macos_13_or_newer()) { + if (expanded_tensor.scalar_type() == kBool) { + inputDataType = MPSDataTypeInt8; + } + if (result.scalar_type() == kBool) { + outputDataType = MPSDataTypeInt8; + } + } @autoreleasepool { - - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - NSString* ns_repeats_key = [[repeats_shape valueForKey:@"description"] componentsJoinedByString:@","]; - - string key = "repeat_mps:" + getMPSTypeString(self.scalar_type()) - + ":" + string([ns_shape_key UTF8String]) - + ":" + string([ns_repeats_key UTF8String]); + string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -149,9 +95,9 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), apparent_input_shape); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor)); MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor - withMultiplier:repeats_shape + withMultiplier:getMPSShape(repeats) name:nil]; newCachedGraph->inputTensor_ = inputTensor; @@ -162,8 +108,10 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparent_input_shape); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder selfPlaceholder = Placeholder( + cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType); + Placeholder outputPlaceholder = Placeholder( + cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() @@ -175,9 +123,8 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { runMPSGraph(stream, cachedGraph->graph(), feeds, results); } - return output; - + return result; } -} -} +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/ScatterGather.mm b/aten/src/ATen/native/mps/operations/ScatterGather.mm index cf8d8a1fef7e3..fca74314a4fd9 100644 --- a/aten/src/ATen/native/mps/operations/ScatterGather.mm +++ b/aten/src/ATen/native/mps/operations/ScatterGather.mm @@ -1,15 +1,6 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include - -#include -#include #include -#include - -#include namespace at { namespace native { @@ -19,25 +10,22 @@ int64_t dim, const Tensor & index, bool sparse_grad, - const Tensor & output) { - + const Tensor & output) +{ using namespace mps; - MPSStream* stream = getCurrentMPSStream(); + if (self_arg.numel() == 0 || index.numel() == 0) { + return; + } auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg; - dim = at::maybe_wrap_dim(dim, self.dim()); TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet") - - TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type"); TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor"); - - // Derive from MPSCachedGraph struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -51,29 +39,28 @@ @autoreleasepool { MPSShape* input_shape = getMPSShape(self); - NSString* ns_input_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; MPSShape* index_shape = getMPSShape(index); - NSString* ns_index_shape_key = [[index_shape valueForKey:@"description"] componentsJoinedByString:@","]; - - int num_input_dims = [input_shape count]; - int num_index_dims = [index_shape count]; - + uint32_t num_input_dims = [input_shape count]; + uint32_t num_index_dims = [index_shape count]; TORCH_CHECK(num_input_dims == num_index_dims, "Input and index must have same rank") // Determine if we need to slice into the input tensor bool needSlice = false; - for(int i = 0; i < num_input_dims; i++) { + for(uint32_t i = 0; i < num_input_dims; i++) { TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) needSlice = true; } - - string key = "gather_out_mps:" + getMPSTypeString(self.scalar_type()) + ":" - + getMPSTypeString(index.scalar_type()) + ":" - + std::to_string(dim) + ":" - + [ns_input_shape_key UTF8String] + ":" - + [ns_index_shape_key UTF8String]; + auto input_type = getMPSDataType(self.scalar_type()); + auto output_type = getMPSDataType(output.scalar_type()); + if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { + input_type = MPSDataTypeInt8; + } + if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) { + output_type = MPSDataTypeInt8; + } + string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { @@ -84,10 +71,10 @@ MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape); - MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(index.scalar_type()), index_shape); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self)); + MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); - MPSGraphTensor* getInput = nil; + MPSGraphTensor* getInput = inputTensor; // Slice into the input tensor IF NEEDED if(needSlice) { @@ -100,31 +87,24 @@ strides[i] = @1; // All starts are 0 starts[i] = @0; - if(i != dim) - ends[i] = index_shape[i]; - else - ends[i] = input_shape[i]; + ends[i] = (i != dim) ? index_shape[i] : input_shape[i]; } getInput = [mpsGraph sliceTensor:inputTensor - starts:starts - ends:ends - strides:strides - name:nil]; - + starts:starts + ends:ends + strides:strides + name:nil]; } - else - getInput = inputTensor; MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor - toType:getMPSDataType(ScalarType::Int) + toType:MPSDataTypeInt32 name:(NSString * _Nonnull)nil]; MPSGraphTensor* outputTensor = [mpsGraph gatherAlongAxis: (NSInteger) dim withUpdatesTensor: getInput indicesTensor: castIndexTensor name: nil]; - newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->indexTensor_ = indexTensor; newCachedGraph->outputTensor_ = outputTensor; @@ -134,9 +114,9 @@ cachedGraph = static_cast(tmpCachedGraph); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape); + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type); Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, output_type); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), @@ -146,9 +126,8 @@ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } - } void scatter_mps_general @@ -158,23 +137,21 @@ const Tensor& src, const Tensor& output, string func_name, - const c10::string_view reduce) { - + const c10::string_view reduce) +{ using namespace mps; - MPSStream* stream = getCurrentMPSStream(); + if (self_arg.numel() == 0 || index.numel() == 0 || src.numel() == 0) { + return; + } auto self = self_arg.dim() == 0 ? self_arg.view({1}) : self_arg; - dim = at::maybe_wrap_dim(dim, self.dim()); - TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index"); TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(), "scatter(): self, src and output must have the same scalar type"); TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor"); - - // Derive from MPSCachedGraph struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} @@ -189,23 +166,20 @@ @autoreleasepool { MPSShape* input_shape = getMPSShape(self); - NSString* ns_input_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; MPSShape* index_shape = getMPSShape(index); - NSString* ns_index_shape_key = [[index_shape valueForKey:@"description"] componentsJoinedByString:@","]; MPSShape* src_shape = getMPSShape(src); - NSString* ns_src_shape_key = [[src_shape valueForKey:@"description"] componentsJoinedByString:@","]; - - int num_input_dims = [input_shape count]; - int num_index_dims = [index_shape count]; - int num_src_dims = [src_shape count]; + uint32_t num_input_dims = [input_shape count]; + uint32_t num_index_dims = [index_shape count]; + uint32_t num_src_dims = [src_shape count]; TORCH_CHECK(num_input_dims == num_index_dims && num_index_dims == num_src_dims, "Input, index and src must have same rank") // Do we need to slice into the src tensor? bool needSlice = false; bool inputNeedSlice = false; + bool needsCast = false; - for(int i = 0; i < num_input_dims; i++) { + for(uint32_t i = 0; i < num_input_dims; i++) { TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") TORCH_CHECK([index_shape[i] intValue] <= [src_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") if([index_shape[i] intValue] < [src_shape[i] intValue]) @@ -213,33 +187,15 @@ if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue]) inputNeedSlice = true; } - TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS") - string reduce_key; - - if(reduce == "set") - reduce_key = "set"; - else if(reduce == "sum") - reduce_key = "sum"; - else if(reduce == "add") - reduce_key = "add"; - else if(reduce == "prod") - reduce_key = "prod"; - else if(reduce == "multiply") - reduce_key = "multiply"; - else if(reduce == "amax") - reduce_key = "amax"; - else if(reduce == "amin") - reduce_key = "amin"; - - string key = func_name + ":" + getMPSTypeString(self.scalar_type()) + ":" - + getMPSTypeString(index.scalar_type()) + ":" - + std::to_string(dim) + ":" - + [ns_input_shape_key UTF8String] + ":" - + [ns_index_shape_key UTF8String] + ":" - + [ns_src_shape_key UTF8String] + ":" - + reduce_key; + MPSDataType src_type = getMPSDataType(src.scalar_type()); + if (reduce != "set" || self.scalar_type() == ScalarType::Byte) { + src_type = isFloatingType(src.scalar_type()) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + needsCast = true; + } + + string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce); CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if(!cachedGraph) { MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { @@ -249,112 +205,72 @@ MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape); - MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(index.scalar_type()), index_shape); - MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(src.scalar_type()), src_shape); - - MPSGraphTensor* getSrc = nil; - MPSGraphTensor* getInput = nil; - - // Slice into the src tensor IF NEEDED - if(needSlice) { - NSMutableArray *starts = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *ends = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *strides = [NSMutableArray arrayWithCapacity:num_input_dims]; + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); + MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src); - for(int i = 0; i < num_input_dims; i++) { - // All strides are 1 - strides[i] = @1; - // All starts are 0 - starts[i] = @0; - ends[i] = index_shape[i]; - } - - getSrc = [mpsGraph sliceTensor:srcTensor - starts:starts - ends:ends - strides:strides - name:nil]; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* castSrcTensor = srcTensor; + MPSGraphTensor* castInputTensor = inputTensor; + if (needsCast) { + castSrcTensor = [mpsGraph castTensor:srcTensor toType:src_type name:@"cast"]; + castInputTensor = [mpsGraph castTensor:inputTensor toType:src_type name:@"cast"]; } - else - getSrc = srcTensor; + MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 name:@"cast"]; + + MPSGraphTensor* slicedSrc = castSrcTensor; + MPSGraphTensor* slicedInput = castInputTensor; // Use in case input needs to be smaller to get scatter - NSArray* scatterInputShape = nil; + NSMutableArray* scatterInputShape = [NSMutableArray arrayWithArray:input_shape]; - // Slice into the input tensor IF NEEDED - if(inputNeedSlice) { + // Slice into the src or input tensors IF NEEDED + if (needSlice || inputNeedSlice) { NSMutableArray *starts = [NSMutableArray arrayWithCapacity:num_input_dims]; - NSMutableArray *ends = [NSMutableArray arrayWithCapacity:num_input_dims]; NSMutableArray *strides = [NSMutableArray arrayWithCapacity:num_input_dims]; - - auto rc = [NSMutableArray arrayWithCapacity:num_input_dims]; + NSMutableArray *ends_src = [NSMutableArray arrayWithCapacity:num_input_dims]; for(int i = 0; i < num_input_dims; i++) { - // All strides are 1 strides[i] = @1; - // All starts are 0 starts[i] = @0; - if(i != dim) { - ends[i] = index_shape[i]; - rc[i] = index_shape[i]; - } - else { - ends[i] = input_shape[i]; - rc[i] = input_shape[i]; - } + ends_src[i] = index_shape[i]; + scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i]; } - scatterInputShape = rc; - - getInput = [mpsGraph sliceTensor:inputTensor + if (needSlice) { + slicedSrc = [mpsGraph sliceTensor:castSrcTensor starts:starts - ends:ends + ends:ends_src strides:strides name:nil]; - - } - else { - getInput = inputTensor; - scatterInputShape = input_shape; + } + if (inputNeedSlice) { + slicedInput = [mpsGraph sliceTensor:castInputTensor + starts:starts + ends:scatterInputShape + strides:strides + name:nil]; + } } + MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet; - MPSGraphTensor* outputTensor = nil; - - MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor - toType:getMPSDataType(ScalarType::Int) - name:(NSString * _Nonnull)nil]; - - MPSGraphScatterMode scatter_mode; - - if(reduce_key == "set") - scatter_mode = MPSGraphScatterModeSet; - else if(reduce_key == "sum" || reduce_key == "add") + if(reduce == "sum" || reduce == "add") scatter_mode = MPSGraphScatterModeAdd; - else if(reduce_key == "prod" || reduce_key == "multiply") + else if(reduce == "prod" || reduce == "multiply") scatter_mode = MPSGraphScatterModeMul; - else if(reduce_key == "amax") + else if(reduce == "amax") scatter_mode = MPSGraphScatterModeMax; - else if(reduce_key == "amin") + else if(reduce == "amin") scatter_mode = MPSGraphScatterModeMin; - if(!inputNeedSlice) { - outputTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim - withDataTensor: getInput - updatesTensor: getSrc - indicesTensor: castIndexTensor - mode: scatter_mode - name: nil]; - } - else { - // Scatter this into the input with set mode - MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim - withDataTensor: getInput - updatesTensor: getSrc - indicesTensor: castIndexTensor - mode: scatter_mode - name: nil]; - + // Scatter this into the input with set mode + MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim + withDataTensor: slicedInput + updatesTensor: slicedSrc + indicesTensor: castIndexTensor + mode: scatter_mode + name: nil]; + if(inputNeedSlice) { // Make an array of scatter indices tensors NSMutableArray* indicesTensors = [NSMutableArray arrayWithCapacity:num_input_dims]; @@ -369,7 +285,7 @@ } MPSGraphTensor* scatterInputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data.data() length:num_input_dims * sizeof(int)] - shape:@[[NSNumber numberWithInt:num_input_dims]] + shape:@[[NSNumber numberWithUnsignedInt:num_input_dims]] dataType:MPSDataTypeInt32]; for(int i = 0; i < num_input_dims; i++) { @@ -392,18 +308,19 @@ withShape:@[@-1] name:nil]; - outputTensor = [mpsGraph scatterNDWithDataTensor:inputTensor + outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor updatesTensor:flatValuesTensor indicesTensor:scatter_fullIndexTensor batchDimensions:0 mode:MPSGraphScatterModeSet name:nil]; + } else { + outputTensor = scatterTensor; } - newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->srcTensor_ = srcTensor; newCachedGraph->indexTensor_ = indexTensor; - newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->outputTensor_ = needsCast ? castMPSTensor(mpsGraph, outputTensor, output.scalar_type()) : outputTensor; } return newCachedGraph; }); @@ -424,9 +341,8 @@ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } - } TORCH_IMPL_FUNC(scatter_src_out_mps) diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index f491f2ff823ad..de90af408d112 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -1,18 +1,10 @@ // Copyright © 2022 Apple Inc. -#include #include -#include -#include -#include #include -#include #include #include #include -#include -#include -#include namespace at { namespace native { @@ -190,32 +182,6 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, } } -inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) { - c10::optional format = c10::nullopt; - for (auto &t : inputs) { - auto f = t.suggest_memory_format(); - if (!format.has_value()) { - format = f; - continue; - } - if (format.value() == f) { - continue; - } - bool contiguous = (format.value() == c10::MemoryFormat::Contiguous || f == c10::MemoryFormat::Contiguous || format.value() != f); - if (contiguous) { - return c10::MemoryFormat::Contiguous; - } - } - return format.value(); -} - -//Tensor cat_mps(TensorList inputs, int64_t dimension) { - //ScalarType high_type = result_type(inputs); - //Tensor out = at::empty({0}, inputs.front().options().dtype(high_type)); - //at::native::cat_out_mps(inputs, dimension, out); - //return out; -//} - TORCH_IMPL_FUNC(cat_out_mps) (const ITensorListRef& inputs, int64_t dimension, @@ -229,17 +195,25 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, if (out.numel() == 0) { return; } - auto materialized_inputs = inputs.materialize(); + auto out_dtype = at::native::result_type(inputs); int idx = 0; for(const Tensor& t : materialized_inputs) { - TORCH_CHECK(t.dim() > 0, - "zero-dimensional tensor (at position ", idx, ") cannot be concatenated"); + TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated"); + auto lap = at::get_overlap_status(out, t); + TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full, + "torch.cat(): unsupported operation: the input tensors cannot refer to any " + "of the output memory locations. Found overlap in input tensor ", idx); idx++; } + // Check for type promotion + TORCH_CHECK(canCast(out_dtype, out.scalar_type()), + "torch.cat(): input types can't be cast to the desired output type ", out.scalar_type()); + TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size()); dimension = legacy_cat_wrap_dim(dimension, materialized_inputs); + TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension); // previously, size [0] tensors were the only possible empty tensors; thus, it // wasn't possible to cat empty tensors unless all the other tensors were @@ -250,35 +224,13 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, auto should_skip = [](const Tensor& t) { return t.dim() == 1 && at::native::size(t, 0) == 0; }; - - const Tensor* notSkippedTensor = NULL; // non-owning reference - - // Check for type promotion - TORCH_CHECK( - canCast(result_type(inputs), out.scalar_type()), - "torch.cat(): input types ", - " can't be cast to the desired output type ", - out.scalar_type()); - - // Inputs cannot alias the output tensor - idx = 0; - for(const Tensor& t : materialized_inputs) { - auto lap = at::get_overlap_status(out, t); - TORCH_CHECK( - lap != at::MemOverlapStatus::Partial && - lap != at::MemOverlapStatus::Full, - "torch.cat(): unsupported operation: the input tensors cannot refer to any " - "of the output memory locations. Found overlap in input " - "tensor ", - idx); - idx++; - } at::assert_no_internal_overlap(out); + Tensor notSkippedTensor; // Indices of tensors to be skipped because they're empty std::vector skipped_tensor_indices; // Tensors to be read - std::vector input_tensors; + std::vector input_tensors; int tensor_idx = 0; for(const Tensor& t : materialized_inputs) { if(t.numel() == 0 || should_skip(t)) { @@ -286,44 +238,25 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, tensor_idx++; continue; } - input_tensors.push_back(&t); + input_tensors.push_back(t); // TODO: Is this OK? - notSkippedTensor = &t; + notSkippedTensor = t; tensor_idx++; } - // If all inputs are empty tensors, return an empty tensor - if (notSkippedTensor == NULL) { + if (!notSkippedTensor.defined()) { return; } - - TORCH_CHECK( - inputs.size() > 0, - "torch.cat(): invalid number of inputs ", - inputs.size()); - TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension); - for (const Tensor& t : inputs) { - TORCH_CHECK( - t.device() == notSkippedTensor->device(), - "torch.cat(): all input tensors must be on the same device. Received ", - t.device(), - " and ", - notSkippedTensor->device()); + TORCH_CHECK(t.device() == notSkippedTensor.device(), + "torch.cat(): all input tensors must be on the same device. Received ", + t.device(), " and ", notSkippedTensor.device()); } + TORCH_CHECK(out.device() == notSkippedTensor.device(), + "torch.cat(): all input tensors and out must be on the same device, but inputs are on ", + notSkippedTensor.device(), " and out is on ", out.device()); - TORCH_CHECK( - out.device() == notSkippedTensor->device(), - "torch.cat(): all input tensors and out must be on the same device, but inputs are on ", - notSkippedTensor->device(), - " and out is on ", - out.device()); - - // TODO: memory_format is now an argument? - // // TODO: Factor out `compute_output_memory_format` - // c10::MemoryFormat memory_format = compute_output_memory_format(inputs); - - std::vector size(notSkippedTensor->sizes().vec()); + std::vector size(notSkippedTensor.sizes().vec()); // Compute size of the result in the cat dimension int64_t cat_dim_size = 0; @@ -333,104 +266,87 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, continue; } // TODO: Factor out `check_shape_except_dim` - check_shape_except_dim(*notSkippedTensor, tensor, dimension, idx); + check_shape_except_dim(notSkippedTensor, tensor, dimension, idx); cat_dim_size += at::native::size(tensor, dimension); idx++; } - // Compute the size of the result size[dimension] = cat_dim_size; - // skip resizing if size of result is same as expected if (out.sizes() != size) { out.resize_(size, memory_format); } - if (out.numel() == 0) { return; } - // Get stream - MPSStream* stream = getCurrentMPSStream(); + if (memory_format != MemoryFormat::Contiguous) { + switch (dimension) { + case 0: + break; + case 1: + dimension = out.dim() - dimension; + break; + default: + dimension--; + break; + } + } - struct CachedGraph : public MPSCachedGraph - { + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - // TODO: Free this when no longer needed globally - MPSGraphTensor** inputMPSGraphTensors_ = nil; + std::vector inputTensors_; MPSGraphTensor* outputTensor_ = nil; }; - MPSGraphCache *cache_ = MPSGraphCache::getInstance(); - // Make string out of skipped tensor indices - string skipped_indices_string = ""; - for(int idx : skipped_tensor_indices) - skipped_indices_string += (std::to_string(idx)+","); - string input_types = ""; - for(const Tensor& tensor : materialized_inputs) - input_types += (getMPSTypeString(tensor.scalar_type())+","); - @autoreleasepool { - string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs)) - + ":" + to_string(inputs.size()) - + ":" + skipped_indices_string - + ":" + input_types - + ":" + to_string(dimension); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" + + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); + + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { - // Initialize graph MPSGraph *mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - // Create placeholders auto len_tensor_array = inputs.size() - skipped_tensor_indices.size(); - std::vector inputMPSGraphTensors(len_tensor_array); - std::vector castInputMPSGraphTensors(len_tensor_array); - - int graph_tensor_idx = 0; - for(const Tensor* tensor : input_tensors) { - inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor->scalar_type()) ); - if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) { - castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx] - toType:MPSDataTypeFloat32 - name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]]; + std::vector castInputTensors(len_tensor_array); + newCachedGraph->inputTensors_.reserve(len_tensor_array); + + for (const auto idx : c10::irange(len_tensor_array)) { + const Tensor& tensor = input_tensors[idx]; + auto scalar_type = getMPSScalarType(tensor.scalar_type()); + if (tensor.scalar_type() == kBool) { + scalar_type = MPSDataTypeInt8; } - else { - if(tensor->scalar_type() != result_type(inputs)) - castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx] - toType:getMPSDataType(result_type(inputs)) - name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]]; - else - castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx]; + newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format)); + if (tensor.scalar_type() != out_dtype) { + castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx] + toType:getMPSDataType(out_dtype) + name:@"castInput"]; + } else { + castInputTensors[idx] = newCachedGraph->inputTensors_[idx]; } - graph_tensor_idx++; } - auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors.data() + auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array]; - // Use concatTensors to concatenate MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray dimension:dimension // Maybe convert this from int64_t -> int32 name:nil]; - - newCachedGraph->inputMPSGraphTensors_ = (MPSGraphTensor**)malloc(len_tensor_array * sizeof(MPSGraphTensor*)); - - for(int i = 0; i < len_tensor_array; i++) - newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i]; - if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) + if(getMPSDataType(out_dtype) == MPSDataTypeBool) { outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"]; + } newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; }); - cachedGraph = static_cast(tmpCachedGraph); } std::vector inputPlaceholders; @@ -438,14 +354,24 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, int t_idx = 0; for(const Tensor& tensor : materialized_inputs) { if(std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) { - Placeholder currentInputPlaceholder = Placeholder(cachedGraph->inputMPSGraphTensors_[t_idx], tensor); - inputPlaceholders.push_back(currentInputPlaceholder); + auto scalar_type = getMPSScalarType(tensor.scalar_type()); + if (tensor.scalar_type() == kBool) { + scalar_type = MPSDataTypeInt8; + } + inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, + getMPSShape(tensor, memory_format), + memory_format != MemoryFormat::ChannelsLast, scalar_type); t_idx++; } i++; } - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); + auto outputDataType = getMPSScalarType(out.scalar_type()); + if (!is_macos_13_or_newer() && out.scalar_type() == kBool) { + outputDataType = MPSDataTypeInt8; + } + Placeholder outputPlaceholder = Placeholder( + cachedGraph->outputTensor_, out, /*mpsShape=*/getMPSShape(out, memory_format), /*gatherTensorData=*/false, outputDataType); NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease]; for (int i = 0; i < inputPlaceholders.size(); i++) { @@ -455,312 +381,9 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() }; - mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); } - } -void upsample_backward_out_mps(const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input, - MPSGraphResizeMode requested_mode, - bool requested_align_corners - ) -{ - using namespace mps; - int64_t input_dims = input_size.size(); - - TORCH_CHECK((input_dims == 4), - "NCHW tensor expected for input"); - - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *gradInputTensor = nil, *gradOutputTensor = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - /* sizes */ - int64_t output_height = output_size[0]; - int64_t output_width = output_size[1]; - - int64_t input_n = input_size[0]; - int64_t input_c = input_size[1]; - int64_t input_height = input_size[2]; - int64_t input_width = input_size[3]; - - @autoreleasepool { - MPSShape* output_shape = getMPSShape(grad_output); - string key = string("upsample_backward:") + mps::getMPSShapeString(output_shape) + ":" + - getMPSTypeString(grad_output.scalar_type()) + - ":oh" + to_string(output_height) + ":ow" + to_string(output_width) + - ":ih" + to_string(input_height) + ":iw" + to_string(input_width) + - ":mode" + to_string(requested_mode); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_input.scalar_type()), output_shape); - MPSGraphTensor * shapeTensor = [mpsGraph constantWithScalar:0 - shape:@[[NSNumber numberWithLong: input_n], - [NSNumber numberWithLong: input_c], - [NSNumber numberWithLong:input_height], - [NSNumber numberWithLong:input_width]] - dataType:getMPSDataType(grad_output.scalar_type())]; - - newCachedGraph->gradInputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->gradOutputTensor - input: shapeTensor - mode: requested_mode - centerResult: true - alignCorners: requested_align_corners - layout: MPSGraphTensorNamedDataLayoutNCHW - name: nil]; - - } - return newCachedGraph; - })); - } - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor, grad_output); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor, grad_input); - - NSDictionary* feeds = @{ - gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), - }; - NSDictionary* results = @{ - gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); - } -} - -TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ - upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeNearest, false); -} - -TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ - upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeNearest, false); -} - -TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) ( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - bool align_corners, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& grad_input) -{ - upsample_backward_out_mps(grad_output, output_size, input_size, scales_h, scales_w, grad_input, MPSGraphResizeBilinear, align_corners); -} - -void upsample_out_mps(const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output, - MPSGraphResizeMode requested_mode, - bool requested_align_corners) -{ - // Get stream - using namespace mps; - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor = nil, *outputTensor = nil; - }; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - /* sizes */ - int64_t output_height = output_size[0]; - int64_t output_width = output_size[1]; - @autoreleasepool { - MPSShape* input_shape = getMPSShape(input); - string key = string("upsample_2d:") + mps::getMPSShapeString(input_shape) + ":" + - getMPSTypeString(input.scalar_type()) + - ":h" + to_string(output_height) + ":w" + to_string(output_width) + - ":mode" + to_string(requested_mode); - - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape); - newCachedGraph->outputTensor = [mpsGraph resizeTensor:newCachedGraph->inputTensor - size:@[ @(output_height), @(output_width)] - mode:requested_mode - centerResult: true - alignCorners: requested_align_corners - layout: MPSGraphTensorNamedDataLayoutNCHW - name:nil]; - } - return newCachedGraph; - })); - } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); - - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); - } -} - -TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ - // Note: this differs from the CPU implementation in the way - // ties are resolved wrt to nearest mostly in cases where the scale - // is not an integer. - // Example: - // For upsampling from (2, 5) to (2, 16) - // MPS: - // tensor([[[[0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.], - // [5., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]]) - // CPU: - // tensor([[[[0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4.], - // [5., 5., 5., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]]) - using namespace mps; - upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeNearest, false); -} - - -TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ - // Note: this differs from the CPU implementation in the way - // ties are resolved wrt to nearest mostly in cases where the scale - // is not an integer. - // Example: - // For upsampling from (2, 5) to (2, 16) - // MPS: - // tensor([[[[0., 0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.], - // [5., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]]) - // CPU: - // tensor([[[[0., 0., 0., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 4., 4., 4.], - // [5., 5., 5., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 9., 9., 9.]]]]) - using namespace mps; - upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeNearest, false); -} - -TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - bool align_corners, - c10::optional scales_h, - c10::optional scales_w, - const Tensor& output) -{ - using namespace mps; - upsample_out_mps(input, output_size, scales_h, scales_w, output, MPSGraphResizeBilinear, align_corners); -} - -void upsample1d_out_mps(const Tensor& input, - IntArrayRef output_size, - c10::optional scales, - const Tensor& output, - MPSGraphResizeMode requested_mode) -{ - // Get stream - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - /* sizes */ - int64_t out_size = output_size[0]; - @autoreleasepool { - MPSShape* input_shape = getMPSShape(input); - string key = string("upsample_1d:") + mps::getMPSShapeString(input_shape) + ":" + - getMPSTypeString(input.scalar_type()) + - ":size" + to_string(out_size) + - ":mode" + to_string(requested_mode); - - CachedGraph* cachedGraph = cache_->LookUpAs(key); - if(!cachedGraph) { - cachedGraph = static_cast(cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), input_shape); - newCachedGraph->outputTensor_ = [mpsGraph resizeTensor:newCachedGraph->inputTensor_ - size:@[ @(out_size), @(1)] - mode:requested_mode - centerResult: true - alignCorners: true - layout: MPSGraphTensorNamedDataLayoutCHW - name:nil]; - } - return newCachedGraph; - })); - } - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary* feeds = @{ - inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); - } -} - - -TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) ( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales, - const Tensor& output) -{ - using namespace mps; - upsample1d_out_mps(input, output_size, scales, output, MPSGraphResizeNearest); -} - - - - - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/SummaryOps.mm b/aten/src/ATen/native/mps/operations/SummaryOps.mm new file mode 100644 index 0000000000000..41d33cadcb3bd --- /dev/null +++ b/aten/src/ATen/native/mps/operations/SummaryOps.mm @@ -0,0 +1,155 @@ +// Copyright © 2022 Apple Inc. + +#include + +namespace at { +namespace native { + +Tensor& bincount_mps_impl(const Tensor& self, + const Tensor& weights, + Tensor& output) { + using namespace mps; + + struct CachedGraph : public MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* weightsTensor_ = nil; + MPSGraphTensor* scatterDataTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + }; + + MPSStream* stream = getCurrentMPSStream(); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + bool has_weights = weights.defined(); + + @autoreleasepool { + string key = "bincount_mps_impl" + getTensorsStringKey({self, weights}); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + // Initialize graph + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type())); + + MPSGraphTensor *updatesTensor = nil; + if (has_weights) { + updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, weights); + } + else { + updatesTensor = [mpsGraph constantWithScalar:1.0f + shape:getMPSShape(self) + dataType:getMPSDataType(output.scalar_type())]; + } + + MPSGraphTensor *castedInputTensor = inputTensor; + if (self.scalar_type() == kByte) { + castedInputTensor = [mpsGraph castTensor:inputTensor + toType:MPSDataTypeInt32 + name:@"castInputTensor"]; + } + + MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor + updatesTensor:updatesTensor + indicesTensor:castedInputTensor + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->scatterDataTensor_ = scatterDataTensor; + if (has_weights) { + newCachedGraph->weightsTensor_ = updatesTensor; + } + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, output); + Placeholder weightsPlaceholder = Placeholder(); + + // Create dictionary of inputs/feeds and outputs/results + NSMutableDictionary* feeds =[NSMutableDictionary dictionary]; + feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + feeds[scatterPlaceholder.getMPSGraphTensor()] = scatterPlaceholder.getMPSGraphTensorData(); + if(has_weights) { + weightsPlaceholder = Placeholder(cachedGraph->weightsTensor_, weights); + feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData(); + } + + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + // Run the graph + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + + return output; +} + +Tensor _bincount_mps(const Tensor& self, const c10::optional& weights_opt, int64_t minlength) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weights_maybe_owned = at::borrow_from_optional_tensor(weights_opt); + const Tensor& weights = *weights_maybe_owned; + + TORCH_CHECK(c10::isIntegralType(self.scalar_type(), /*includesBool=*/true)); + TORCH_CHECK(minlength >= 0, "minlength should be >= 0"); + + if (self.dim() == 1 && self.numel() == 0) { + return at::zeros( + {minlength}, + kLong, + c10::nullopt /* layout */, + kMPS, + c10::nullopt /* pin_memory */); + } + TORCH_CHECK(self.dim() == 1 && self.min().item() >= 0, "bincount only supports 1-d non-negative integral inputs."); + + bool has_weights = weights.defined(); + TORCH_CHECK(!(has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))), "weights should be 1-d and have the same length as input"); + + const int64_t nbins = std::max(self.max().item() + 1L, minlength); + Tensor output; + + Tensor weights_ = weights; + if (has_weights) { + if(weights.scalar_type() != ScalarType::Float && + weights.scalar_type() != ScalarType::Int && + weights.scalar_type() != ScalarType::Half) { + // Scatter doesn't work for int8/int16 dtypes + weights_ = weights.to(kInt); + } + output = at::zeros( + {nbins}, + optTypeMetaToScalarType(weights_.options().dtype_opt()), + weights_.options().layout_opt(), + weights_.options().device_opt(), + weights_.options().pinned_memory_opt()); + } + else { + output = at::zeros( + {nbins}, + kLong, + c10::nullopt /* layout */, + kMPS, + c10::nullopt /* pin_memory */); + } + + return bincount_mps_impl(self, weights_, output); +} + +} +} diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm index 44d19e99c2f62..419f2572ea926 100644 --- a/aten/src/ATen/native/mps/operations/TensorCompare.mm +++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm @@ -2,7 +2,7 @@ #include #include -#include +#include namespace at { namespace native { @@ -321,6 +321,23 @@ void clamp_scalar_out_mps(const Tensor& input_t, MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + MPSDataType conditionDataType = getMPSScalarType(condition.scalar_type()); + MPSDataType selfDataType = getMPSScalarType(self.scalar_type()); + MPSDataType otherDataType = getMPSScalarType(other.scalar_type()); + // Workaround for `selectWithPredicateTensor` on macOS Monterey where bool data type may cause a hang + // The issue is fixed in macOS Ventura (13.0) + if (!is_macos_13_or_newer()) { + if (condition.scalar_type() == kBool) { + conditionDataType = MPSDataTypeInt8; + } + if (self.scalar_type() == kBool) { + selfDataType = MPSDataTypeInt8; + } + if (other.scalar_type() == kBool) { + otherDataType = MPSDataTypeInt8; + } + } + @autoreleasepool { string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other}); @@ -336,9 +353,9 @@ void clamp_scalar_out_mps(const Tensor& input_t, MPSGraph* mpsGraph = make_mps_graph(); newCachedGraph = new CachedGraph(mpsGraph); - MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, cond_bool); - MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, other); + MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool)); + MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(self)); + MPSGraphTensor* otherTensor = mpsGraphRankedPlaceHolder(mpsGraph, otherDataType, getMPSShape(other)); MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:conditionTensor truePredicateTensor:selfTensor @@ -355,9 +372,12 @@ void clamp_scalar_out_mps(const Tensor& input_t, cachedGraph = static_cast(tmpCachedGraph); } - Placeholder conditionPlaceholder = Placeholder(cachedGraph->conditionTensor_, cond_bool); - Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); - Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other); + Placeholder conditionPlaceholder = Placeholder( + cachedGraph->conditionTensor_, cond_bool, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, conditionDataType); + Placeholder selfPlaceholder = Placeholder( + cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType); + Placeholder otherPlaceholder = Placeholder( + cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType); Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); NSDictionary* feeds = @{ @@ -416,5 +436,112 @@ Tensor where_mps(const Tensor& condition, } +Tensor& nan_to_num_out_mps(const Tensor& self, + c10::optional nan, + c10::optional pos_inf, + c10::optional neg_inf, + Tensor& result) +{ + TORCH_CHECK(self.scalar_type() == result.scalar_type(), "nan_to_num: dtype of out: ", + result.scalar_type(), " should be same as input: ", self.scalar_type()); + if(result.numel() == 0) { + return result; + } + if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) { + at::native::resize_output(result, self.sizes()); + result.copy_(self); + return result; + } + using namespace mps; + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* selfTensor = nil; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* nanReplacementTensor = nil; + MPSGraphTensor* posInfReplacementTensor = nil; + MPSGraphTensor* negInfReplacementTensor = nil; + }; + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = "nan_to_num" + getTensorsStringKey({self}); + MPSDataType self_dtype = getMPSScalarType(self.scalar_type()); + + CachedGraph* cachedGraph = cache_->LookUpAs(key); + if (!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { + CachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + newCachedGraph->nanReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); + newCachedGraph->posInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); + newCachedGraph->negInfReplacementTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_dtype, @[@1]); + + MPSGraphTensor* nanFreeTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isNaNWithTensor: newCachedGraph->selfTensor name:nil] + truePredicateTensor: newCachedGraph->nanReplacementTensor + falsePredicateTensor: newCachedGraph->selfTensor + name: nil]; + MPSGraphTensor* subZeroTensor = [mpsGraph lessThanWithPrimaryTensor: nanFreeTensor + secondaryTensor: [mpsGraph constantWithScalar: 0.0 dataType: self_dtype] + name: nil]; + MPSGraphTensor* isInfTensor = [mpsGraph isInfiniteWithTensor: nanFreeTensor name:nil]; + // workaround for Monterey; On Ventura the output of lessThan() is always Boolean + if (subZeroTensor.dataType != MPSDataTypeBool) { + subZeroTensor = castMPSTensor(mpsGraph, subZeroTensor, kBool); + } + if (isInfTensor.dataType != MPSDataTypeBool) { + isInfTensor = castMPSTensor(mpsGraph, isInfTensor, kBool); + } + MPSGraphTensor* isNegInfTensor = [mpsGraph logicalANDWithPrimaryTensor: subZeroTensor + secondaryTensor: isInfTensor + name: nil]; + MPSGraphTensor* negInfFreeTensor = [mpsGraph selectWithPredicateTensor: isNegInfTensor + truePredicateTensor: newCachedGraph->negInfReplacementTensor + falsePredicateTensor: nanFreeTensor + name: nil]; + newCachedGraph->outputTensor = [mpsGraph selectWithPredicateTensor: [mpsGraph isInfiniteWithTensor: negInfFreeTensor name:nil] + truePredicateTensor: newCachedGraph->posInfReplacementTensor + falsePredicateTensor: negInfFreeTensor + name: nil]; + } + return newCachedGraph; + }); + } + MPSScalar nanReplacementScalar, posInfReplacementScalar, negInfReplacementScalar; + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, self.scalar_type(), "nan_to_num_mps", [&]() { + scalar_t nan_replacement = static_cast(nan.value_or(0.)); + scalar_t pos_inf_replacement = pos_inf.has_value() ? + static_cast(pos_inf.value()) : + std::numeric_limits::max(); + scalar_t neg_inf_replacement = neg_inf.has_value() ? + static_cast(neg_inf.value()) : + std::numeric_limits::lowest(); + + nanReplacementScalar = getMPSScalar(nan_replacement, self.scalar_type()); + posInfReplacementScalar = getMPSScalar(pos_inf_replacement, self.scalar_type()); + negInfReplacementScalar = getMPSScalar(neg_inf_replacement, self.scalar_type()); + }); + + MPSStream* stream = getCurrentMPSStream(); + Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor, self); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, result); + + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), + cachedGraph->nanReplacementTensor : getMPSGraphTensorFromScalar(stream, nanReplacementScalar), + cachedGraph->posInfReplacementTensor : getMPSGraphTensorFromScalar(stream, posInfReplacementScalar), + cachedGraph->negInfReplacementTensor : getMPSGraphTensorFromScalar(stream, negInfReplacementScalar), + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } + return result; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index fb6e1c52ba49e..c276707964997 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -172,197 +172,5 @@ } -Tensor& diag_mps_out(const Tensor& self, - int64_t diagonal, - Tensor &output) { - - // Do checks, resize output - IntArrayRef input_size = self.sizes(); - auto num_input_dims = input_size.size(); - // Input can only be 1D or 2D - TORCH_CHECK(num_input_dims == 1 || num_input_dims == 2, - "diag_mps_out: Input tensor must be 1D or 2D") - - if(num_input_dims == 1) { - auto n = input_size[0]; - if(diagonal > 0) - n += diagonal; - else if(diagonal < 0) - n -= diagonal; - - output.resize_({n, n}); - } - else if(num_input_dims == 2) { - auto num_diag_elements = std::min(input_size[0], input_size[1]); - if(diagonal > 0) { - TORCH_CHECK(input_size[1] - diagonal > 0, "Matrix not big enough for requested diagonal") - num_diag_elements = std::min(input_size[0], input_size[1] - diagonal); - } - else if(diagonal < 0) { - TORCH_CHECK(input_size[0] + diagonal > 0, "Matrix not big enough for requested diagonal") - num_diag_elements = std::min(input_size[0] + diagonal, input_size[1]); - } - - output.resize_({num_diag_elements}); - } - - using namespace mps; - MPSStream* stream = getCurrentMPSStream(); - - // Derive from MPSCachedGraph - struct CachedGraph : public MPSCachedGraph - { - CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} - MPSGraphTensor *inputTensor_ = nil; - MPSGraphTensor *outputTensor_ = nil; - }; - - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - - @autoreleasepool { - - MPSShape* input_shape = getMPSShape(self); - MPSShape* output_shape = getMPSShape(output); - NSNumber* num_input_cols = nil; - NSNumber* num_output_cols = nil; - NSMutableArray* flat_input_shape = nil; - NSMutableArray* flat_output_shape = nil; - if(num_input_dims == 1) { - num_output_cols = output_shape[1]; - flat_output_shape = [NSMutableArray arrayWithCapacity:1]; - flat_output_shape[0] = [NSNumber numberWithInt:[output_shape[0] intValue] * [output_shape[1] intValue]]; - } - else if(num_input_dims == 2) { - num_input_cols = input_shape[1]; - flat_input_shape = [NSMutableArray arrayWithCapacity:1]; - flat_input_shape[0] = [NSNumber numberWithInt:[input_shape[0] intValue] * [input_shape[1] intValue]]; - } - NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - string key = "diag_mps_out:" + getMPSTypeString(self.scalar_type()) + ":" + std::to_string(diagonal) - + ":" + string([ns_shape_key UTF8String]); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - // TODO: Accept this as the flat version in 2D case - MPSGraphTensor* inputTensor = nil; - if(num_input_dims == 1) - inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); - else - inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), flat_input_shape); - - MPSGraphTensor* outputTensor = nil; - - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0 - dataType:MPSDataTypeInt32]; - MPSGraphTensor* numDiagElementsRange = nil; - MPSGraphTensor* diagOffset = nil; - MPSGraphTensor* rowMultiplier = nil; - MPSGraphTensor* rowIndices = nil; - MPSGraphTensor* colIndices = nil; - MPSGraphTensor* indicesTensor = nil; - - if(num_input_dims == 1) { - int shape_data[1] = {[input_shape[0] intValue]}; - MPSGraphTensor* inputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data length:sizeof(int)] - shape:@[@1] - dataType:MPSDataTypeInt32]; - numDiagElementsRange = [mpsGraph coordinateAlongAxisTensor: zeroTensor - withShapeTensor: inputShapeTensor - name: nil]; - diagOffset = [mpsGraph constantWithScalar:diagonal - dataType:MPSDataTypeInt32]; - rowMultiplier = [mpsGraph constantWithScalar:[num_output_cols intValue] - dataType:MPSDataTypeInt32]; - } - else { - int shape_data[1] = {[output_shape[0] intValue]}; - MPSGraphTensor* outputShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shape_data length:sizeof(int)] - shape:@[@1] - dataType:MPSDataTypeInt32]; - numDiagElementsRange = [mpsGraph coordinateAlongAxisTensor: zeroTensor - withShapeTensor: outputShapeTensor - name: nil]; - diagOffset = [mpsGraph constantWithScalar:diagonal - dataType:MPSDataTypeInt32]; - rowMultiplier = [mpsGraph constantWithScalar:[num_input_cols intValue] - dataType:MPSDataTypeInt32]; - } - - if(diagonal >= 0) { - rowIndices = numDiagElementsRange; - colIndices = [mpsGraph additionWithPrimaryTensor:numDiagElementsRange - secondaryTensor:diagOffset - name:nil]; - } - else { - rowIndices = [mpsGraph subtractionWithPrimaryTensor:numDiagElementsRange - secondaryTensor:diagOffset - name:nil];; - colIndices = numDiagElementsRange; - } - - indicesTensor = [mpsGraph multiplicationWithPrimaryTensor:rowIndices - secondaryTensor:rowMultiplier - name:nil]; - indicesTensor = [mpsGraph additionWithPrimaryTensor:indicesTensor - secondaryTensor:colIndices - name:nil]; - - if(num_input_dims == 1) { - // TODO: Scatter mode doesn't matter, so what should I set it to be? - outputTensor = [mpsGraph scatterWithUpdatesTensor:inputTensor - indicesTensor:indicesTensor - shape:flat_output_shape - axis:0 - mode:MPSGraphScatterModeAdd - name:nil]; - outputTensor = [mpsGraph reshapeTensor:outputTensor - withShape:output_shape - name:nil]; - } - else if(num_input_dims == 2) { - outputTensor = [mpsGraph gatherWithUpdatesTensor:inputTensor - indicesTensor:indicesTensor - axis:0 - batchDimensions:0 - name:nil]; - } - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } - - Placeholder selfPlaceholder = Placeholder(); - if(num_input_dims == 1) - selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - else - selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, flat_input_shape); - - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - - runMPSGraph(stream, cachedGraph->graph(), feeds, results); - } - - return output; -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index dd9c8176d0b7c..bbbb81cf47432 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -1,10 +1,8 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include +//#include #include +#include #include namespace at { @@ -14,6 +12,7 @@ typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*); using is_noop_p = std::function; +#define ConditionalOpFn(void) NSArray * (void) bool is_empty_tensor(const Tensor& self) { return self.numel() == 0; @@ -30,11 +29,11 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una } MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string key = op_name + getTensorsStringKey({self}, /*use_scalar_value*/ false); + string key = op_name + getTensorsStringKey({self, output}); auto cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph* () { MPSUnaryCachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); @@ -42,18 +41,22 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self); MPSGraphTensor* castTensor = newCachedGraph->inputTensor_; // Integer input must be cast to float if output is float - if (isIntegralType(self.scalar_type()) && isFloatingType(output.scalar_type())) { + if (isIntegralType(self.scalar_type(), true) && isFloatingType(output.scalar_type())) { castTensor = castMPSTensor(mpsGraph, newCachedGraph->inputTensor_, output.scalar_type()); } newCachedGraph->outputTensor_ = unaryBlock(mpsGraph, castTensor); } return newCachedGraph; }); - cachedGraph = tmpCachedGraph->as(); } - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); + bool gatherTensorData = true; + if (!output.is_contiguous() || output.is_view()) { + gatherTensorData = false; + } + + Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, /*mpsShape=*/nullptr, gatherTensorData); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false); NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() }; @@ -93,6 +96,24 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una { return mps::trunc_tensor(mpsGraph, inputTensor); }); } +TORCH_IMPL_FUNC(signbit_out_mps) (const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "signbit_out_mps", + ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* output; + // signbit is not implemented for int64 type. + // workaround for `Function signbitOp_i64 was not found in the library` + if ([inputTensor dataType] == MPSDataTypeInt64) { + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType]; + output = [mpsGraph lessThanWithPrimaryTensor:inputTensor + secondaryTensor:zeroTensor + name:nil]; + } else { + output = [mpsGraph signbitWithTensor: inputTensor name: nil]; + } + return mps::castMPSTensor(mpsGraph, output, ScalarType::Bool); + }); +} + TORCH_IMPL_FUNC(sign_out_mps) (const Tensor& self, const Tensor& output) { mps::unary_op(self, output, "sign_out_mps", ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { @@ -113,7 +134,7 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) \ { return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; }, \ [](const Tensor& t) -> bool { \ - return t.numel() == 0 || isIntegralType(t.scalar_type()); \ + return t.numel() == 0 || isIntegralType(t.scalar_type(), true); \ }); \ } CREATE_MPS_STRUCTURED_UNARY_ROUNDING_TORCH_IMPL_FUNC(ceil_out_mps, ceil) @@ -168,51 +189,32 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una return output; } +TORCH_IMPL_FUNC(sigmoid_out_mps) (const Tensor& self, const Tensor& output) +{ + TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support sigmoid op with int64 input"); + mps::unary_op(self, output, "sigmoid_out_mps", + ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + return [mpsGraph sigmoidWithTensor:inputTensor name:nil]; + }); +} + TORCH_IMPL_FUNC(log1p_out_mps) (const Tensor& self, const Tensor& output) { - using namespace mps; - if (!output.is_same_size(self)) { - output.resize_(self.sizes()); - } - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); - @autoreleasepool { - string key = string("log1p_out_mps") + getTensorsStringKey({self}); - auto cachedGraph = cache_->LookUpAs(key); - - if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { - MPSUnaryCachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new MPSUnaryCachedGraph(mpsGraph); - newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 - shape:getMPSShape(self) - dataType:mps::getMPSDataType(self.scalar_type())]; - MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor_ - secondaryTensor:oneTensor - name:nil]; - newCachedGraph->outputTensor_ = [mpsGraph logarithmWithTensor:addedTensor - name:nil]; - } - return newCachedGraph; - }); - cachedGraph = tmpCachedGraph->as(); - } - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); - NSDictionary* feeds = @{ - selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() - }; - NSDictionary* results = @{ - outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); - } + TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support log1p op with int64 input"); + mps::unary_op(self, output, "log1p_out_mps", + ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 + dataType:inputTensor.dataType]; + MPSGraphTensor* addedTensor = [mpsGraph additionWithPrimaryTensor:inputTensor + secondaryTensor:oneTensor + name:nil]; + return [mpsGraph logarithmWithTensor:addedTensor + name:nil]; + }); } -TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output) { +TORCH_IMPL_FUNC(frac_out_mps) (const Tensor& self, const Tensor& output) +{ TORCH_CHECK(isFloatingType(self.scalar_type()), "frac_out_mps is only implemented for floating types"); mps::unary_op(self, output, "frac_out_mps", ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { @@ -231,5 +233,224 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una }); } +TORCH_IMPL_FUNC(expm1_out_mps) (const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "expm1_out_mps", + ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 + shape:@[@1] + dataType:inputTensor.dataType]; + MPSGraphTensor* ePowTensor = [mpsGraph exponentWithTensor:inputTensor + name:nil]; + return [mpsGraph subtractionWithPrimaryTensor:ePowTensor + secondaryTensor:oneTensor + name: nil]; + }); +} + + + +TORCH_IMPL_FUNC(cumsum_out_mps) +(const Tensor& self, + int64_t dim, + c10::optional dtype, + const Tensor& result) { + + auto nDims = self.dim(); + auto wrapped_dim = maybe_wrap_dim(dim, nDims); + TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")"); + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade"); + auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype); + at::_copy_from_and_resize(cpu_result, result); + return; + } + auto input = dtype.has_value() ? self.to(dtype.value()) : self; + TORCH_CHECK(input.scalar_type() != ScalarType::Long, "MPS does not support cumsum op with int64 input"); + mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim), + ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { + // cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32 + if (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int) { + inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int); + } + auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor + axis: dim + name: nil]; + if (result.scalar_type()!= input.scalar_type() || + (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int)) { + return mps::castMPSTensor(mpsGraph, rc, result.scalar_type()); + } + return rc; + }); +} + +TORCH_IMPL_FUNC(sgn_out_mps) (const Tensor& self, const Tensor& output) +{ + using namespace mps; + + if (self.numel() == 0) { + return; + } + + if (!output.is_same_size(self)) { + output.resize_(self.sizes()); + } + + string graphSuffix = "_real"; + Tensor realInput; + Tensor realOutput; + Tensor flatInput = self.flatten(); + Tensor flatOutput = output.flatten(); + if (self.is_complex()) { + realInput = at::view_as_real(flatInput); + realOutput = at::view_as_real(flatOutput); + graphSuffix = "_complex"; + } else { + realInput = flatInput; + realOutput = flatOutput; + } + + MPSDataType selfDataType = getMPSScalarType(self.scalar_type()); + // Workaround for `constantWithScalar` crashes due to unsupported bool data type + // The issue is fixed in macOS Ventura (13.0) + if (!is_macos_13_or_newer()) { + if (self.scalar_type() == kBool) { + selfDataType = MPSDataTypeInt8; + } + } + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + @autoreleasepool { + string key = string("sgn_out_mps") + getTensorsStringKey({realInput}) + graphSuffix; + auto cachedGraph = cache_->LookUpAs(key); + + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph* () { + MPSUnaryCachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new MPSUnaryCachedGraph(mpsGraph); + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, selfDataType, getMPSShape(realInput)); + MPSGraphTensor* sgnTensor; + if (self.is_complex()) { + NSArray* complexNumberComponents = [mpsGraph splitTensor:newCachedGraph->inputTensor_ + numSplits: 2 + axis: 1 + name: nil]; + + MPSGraphTensor* realPartTensor = complexNumberComponents[0]; + MPSGraphTensor* imaginaryPartTensor = complexNumberComponents[1]; + + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 + shape:realPartTensor.shape + dataType:realPartTensor.dataType]; + + MPSGraphTensor* complexZeroTensor = [mpsGraph constantWithScalar:0.0 + shape: newCachedGraph->inputTensor_.shape + dataType:realPartTensor.dataType]; + + MPSGraphTensor* isRealZero = [mpsGraph equalWithPrimaryTensor:realPartTensor + secondaryTensor:zeroTensor + name: nil]; + + MPSGraphTensor* isImaginaryZero = [mpsGraph equalWithPrimaryTensor:imaginaryPartTensor + secondaryTensor:zeroTensor + name: nil]; + + MPSGraphTensor* isComplexZero = [mpsGraph logicalANDWithPrimaryTensor:isRealZero + secondaryTensor:isImaginaryZero + name: nil]; + + MPSGraphTensor* sgnDenomReal = [mpsGraph squareWithTensor:realPartTensor + name: nil]; + + MPSGraphTensor* sgnDenomImaginary = [mpsGraph squareWithTensor:imaginaryPartTensor + name: nil]; + + MPSGraphTensor* sgnDenomSum = [mpsGraph additionWithPrimaryTensor:sgnDenomReal + secondaryTensor:sgnDenomImaginary + name: nil]; + + MPSGraphTensor* sgnDenom = [mpsGraph squareRootWithTensor:sgnDenomSum + name: nil]; + + MPSGraphTensor* sgnRealTensor = [mpsGraph divisionWithPrimaryTensor:realPartTensor + secondaryTensor:sgnDenom + name: nil]; + + MPSGraphTensor* sgnImaginaryTensor = [mpsGraph divisionWithPrimaryTensor:imaginaryPartTensor + secondaryTensor:sgnDenom + name: nil]; + + MPSGraphTensor* sgnComplexTensor = [mpsGraph concatTensors:@[sgnRealTensor, sgnImaginaryTensor] + dimension: 1 + name: nil]; + + sgnTensor = [mpsGraph selectWithPredicateTensor:isComplexZero + truePredicateTensor:complexZeroTensor + falsePredicateTensor:sgnComplexTensor + name:nil]; + } else { + MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0 + shape:newCachedGraph->inputTensor_.shape + dataType:selfDataType]; + + MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1 + shape:newCachedGraph->inputTensor_.shape + dataType:selfDataType]; + + MPSGraphTensor* negativeOneTensor = [mpsGraph constantWithScalar:-1 + shape:newCachedGraph->inputTensor_.shape + dataType:selfDataType]; + + MPSGraphTensor* isPositive = [mpsGraph greaterThanWithPrimaryTensor:newCachedGraph->inputTensor_ + secondaryTensor:zeroTensor + name: nil]; + + MPSGraphTensor* isNegative = [mpsGraph lessThanWithPrimaryTensor:newCachedGraph->inputTensor_ + secondaryTensor:zeroTensor + name: nil]; + + MPSGraphTensor* notPositiveTensor = [mpsGraph selectWithPredicateTensor:isNegative + truePredicateTensor:negativeOneTensor + falsePredicateTensor:zeroTensor + name:nil]; + + sgnTensor = [mpsGraph selectWithPredicateTensor:isPositive + truePredicateTensor:oneTensor + falsePredicateTensor:notPositiveTensor + name:nil]; + } + newCachedGraph->outputTensor_ = sgnTensor; + } + return newCachedGraph; + }); + cachedGraph = tmpCachedGraph->as(); + } + + Placeholder selfPlaceholder = Placeholder( + cachedGraph->inputTensor_, realInput, /*mpsShape*/nullptr, /*gatherTensorData=*/true, selfDataType); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, realOutput); + NSDictionary* feeds = @{ + selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData() + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + + if (self.is_complex()) { + std::vector realSize = self.sizes().vec(); + realSize.push_back(2); + + Tensor originalShape = realOutput.reshape(realSize); + Tensor complexOutput = at::view_as_complex(originalShape); + output.copy_(complexOutput); + } else { + Tensor originalShape = at::reshape(realOutput, self.sizes()); + output.copy_(originalShape); + } +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm new file mode 100644 index 0000000000000..4319c4aad0f5e --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Unique.mm @@ -0,0 +1,345 @@ +// Copyright © 2022 Apple Inc. + +#include +#include +#include + +namespace at { +namespace native { +namespace mps { + +struct UniqueCachedGraph : public MPSCachedGraph +{ + UniqueCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor* inputTensor_ = nil; + MPSGraphTensor* outputTensor_ = nil; + MPSGraphTensor* inverseIndicesTensor_ = nil; + MPSGraphTensor* countsTensor_ = nil; + MPSGraphTensor* lengthTensor_ = nil; +}; + +static std::string getUniqueKey(const ScalarType& dtype, const IntArrayRef& base_shape, + const bool return_inverse, const bool return_counts, + const bool consecutive, c10::optional dimOpt) +{ + return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + + "]:[" + (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + + "]:[" + to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; +} + +// dim arg not supported when non consecutive, ie sorted +NSArray *buildUniqueGraph(const Tensor& self, UniqueCachedGraph *uniqueGraph, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dimOpt) { + int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0; + + MPSGraph *graph = uniqueGraph->graph(); + MPSGraphTensor *inputTensor = uniqueGraph->inputTensor_; + MPSShape *shape = [inputTensor shape]; + MPSShape *destShape = shape; + NSUInteger length = [shape[dim] integerValue]; + MPSDataType dataType = [inputTensor dataType]; + + MPSGraphTensor *resultTensor = (MPSGraphTensor *)[NSNull null]; + MPSGraphTensor *inverseIndicesTensor = (MPSGraphTensor *)[NSNull null]; + MPSGraphTensor *countTensor = (MPSGraphTensor *)[NSNull null]; + MPSGraphTensor *lengthTensor = (MPSGraphTensor *)[NSNull null]; + if (length <= 1) { + // Trivial case, only 1 element everything is unique + resultTensor = inputTensor; + lengthTensor = [graph constantWithScalar:0.0f + dataType:MPSDataTypeInt32]; + if (return_inverse) + inverseIndicesTensor = [graph constantWithScalar:0.0f + dataType:MPSDataTypeInt32]; + if (return_counts) + countTensor = [graph constantWithScalar:1.0f + dataType:MPSDataTypeInt32]; + return @[resultTensor, inverseIndicesTensor, countTensor, lengthTensor]; + } + + // #issue 104398441 sortWithTensor only supports following types, cast if necessary + if (dataType != MPSDataTypeInt32 && + dataType != MPSDataTypeFloat32 && + dataType != MPSDataTypeFloat16) { + dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32; + inputTensor = [graph castTensor:inputTensor + toType:dataType + name:@"castInputTensor"]; + } + + bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1); + if (needsFlatten) { + inputTensor = [graph reshapeTensor:inputTensor + withShape:@[@-1] + name:nil]; + length = 1; + for(NSUInteger i = 0; i < [shape count]; i++) + length *= [shape[i] integerValue]; + destShape = @[[NSNumber numberWithUnsignedInteger:length]]; + } + + MPSGraphTensor *sortedInput = nil; + if (consecutive) + sortedInput = inputTensor; + else + sortedInput = [graph sortWithTensor:inputTensor + axis:0 + name:nil]; + + MPSGraphTensor *frontNMinusOne = [graph sliceTensor:sortedInput + dimension:dim + start:0 + length:length-1 + name:nil]; + MPSGraphTensor *backNMinusOne = [graph sliceTensor:sortedInput + dimension:dim + start:1 + length:length-1 + name:nil]; + MPSGraphTensor *notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne + secondaryTensor:frontNMinusOne + name:nil]; + MPSGraphTensor *mask = [graph castTensor:notEqualToPreviousElement + toType:MPSDataTypeInt32 + name:@"castMaskTensor"]; + + // If comparing tensors, not scalars, check if entire tensor matches previos element using reductionOr over tensor + if (dimOpt.has_value() && [shape count] != 1) { + NSMutableArray *axes = [[NSMutableArray alloc] initWithCapacity:[shape count]-1]; + for (NSUInteger axis = 0; axis < [shape count]; axis++){ + if (axis != dim) + [axes addObject:[NSNumber numberWithUnsignedInteger:axis]]; + } + mask = [graph reductionOrWithTensor:mask + axes:axes + name:nil]; + mask = [graph squeezeTensor:mask + axes:axes + name:nil]; + [axes release]; + } + + MPSGraphTensor *scannedIndices = [graph cumulativeSumWithTensor:mask + axis:0 + name:nil]; + lengthTensor = [graph sliceTensor:scannedIndices + dimension:0 + start:length-2 + length:1 + name:nil]; + + MPSGraphTensor *minusOneTensor = [graph constantWithScalar:-1.0f + dataType:MPSDataTypeInt32]; + MPSGraphTensor *maskedIndices = [graph selectWithPredicateTensor:mask + truePredicateTensor:scannedIndices + falsePredicateTensor:minusOneTensor + name:nil]; + + MPSGraphTensor *zeroTensor = [graph constantWithScalar:0.0f + shape:@[@1] + dataType:MPSDataTypeInt32]; + MPSGraphTensor *maskedIndicesWithHead = [graph concatTensors:@[zeroTensor, maskedIndices] + dimension:0 + name:nil]; + MPSGraphTensor *scannedIndicesWithHead = [graph concatTensors:@[zeroTensor, scannedIndices] + dimension:0 + name:nil]; + + resultTensor = [graph scatterWithUpdatesTensor:sortedInput + indicesTensor:maskedIndicesWithHead + shape:destShape + axis:dim + mode:MPSGraphScatterModeSet + name:nil]; + // Cast back if necessary + if ([uniqueGraph->inputTensor_ dataType] != dataType) + resultTensor = [graph castTensor:resultTensor + toType:[uniqueGraph->inputTensor_ dataType] + name:@"castResultTensor"]; + + // Compute optional returned tensors if requested + if(return_inverse) { + MPSGraphTensor *argSortedInput = nil; + if (consecutive) + argSortedInput = [graph coordinateAlongAxis:0 + withShape:@[[NSNumber numberWithUnsignedInteger:length]] + name:nil]; + else + argSortedInput = [graph argSortWithTensor:inputTensor + axis:0 + name:nil]; + inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead + indicesTensor:argSortedInput + shape:@[[NSNumber numberWithUnsignedInteger:length]] + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; + if (needsFlatten) + inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor + withShape:shape + name:nil]; + } + + if (return_counts) { + MPSGraphTensor *unitTensor = [graph constantWithScalar:1.0f + shape:@[[NSNumber numberWithUnsignedInteger:length]] + dataType:MPSDataTypeInt32]; + countTensor = [graph scatterWithUpdatesTensor:unitTensor + indicesTensor:scannedIndicesWithHead + shape:@[[NSNumber numberWithUnsignedInteger:length]] + axis:0 + mode:MPSGraphScatterModeAdd + name:nil]; + } + + return @[resultTensor, inverseIndicesTensor, countTensor, lengthTensor]; +} + +static UniqueCachedGraph* getUniqueGraph(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dim) { + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + + @autoreleasepool { + string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim); + UniqueCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if(!cachedGraph) { + MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + + UniqueCachedGraph *newCachedGraph = nil; + + @autoreleasepool { + // Initialize graph + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new UniqueCachedGraph(mpsGraph); + + // Workaround for MPSShaderLibrary bug + // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved + auto inputType = getMPSScalarType(self.scalar_type()); + newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self.sizes())); + + NSArray *outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim); + + newCachedGraph->outputTensor_ = outputTensors[0]; + newCachedGraph->inverseIndicesTensor_ = outputTensors[1]; + newCachedGraph->countsTensor_ = outputTensors[2]; + newCachedGraph->lengthTensor_ = outputTensors[3]; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + return cachedGraph; + } +} + +void runUniqueGraph(UniqueCachedGraph *uniqueGraph, const Tensor& input, Tensor& output, + Tensor& inverse_indices, Tensor& counts, Tensor& length, + bool return_inverse, bool return_counts){ + Placeholder inputPlaceholder = Placeholder(uniqueGraph->inputTensor_, input); + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSMutableDictionary* results = [NSMutableDictionary dictionary]; + Placeholder outputPlaceholder = Placeholder(uniqueGraph->outputTensor_, output); + Placeholder lengthPlaceholder = Placeholder(uniqueGraph->lengthTensor_, length); + [results setObject:outputPlaceholder.getMPSGraphTensorData() + forKey:outputPlaceholder.getMPSGraphTensor()]; + [results setObject:lengthPlaceholder.getMPSGraphTensorData() + forKey:lengthPlaceholder.getMPSGraphTensor()]; + if (return_inverse) { + Placeholder inverseIndicesPlaceholder = Placeholder(uniqueGraph->inverseIndicesTensor_, inverse_indices); + [results setObject:inverseIndicesPlaceholder.getMPSGraphTensorData() + forKey:inverseIndicesPlaceholder.getMPSGraphTensor()]; + } + if (return_counts) { + Placeholder countsPlaceholder = Placeholder(uniqueGraph->countsTensor_, counts); + [results setObject:countsPlaceholder.getMPSGraphTensorData() + forKey:countsPlaceholder.getMPSGraphTensor()]; + } + + // Run the graph + MPSStream* stream = getCurrentMPSStream(); + runMPSGraph(stream, uniqueGraph->graph(), feeds, results); +} + +} // namespace mps + +std::tuple +_unique_impl_mps(const Tensor& self, const bool return_inverse, const bool return_counts, const bool consecutive, c10::optional dimOpt) { + + const Tensor& input = self.contiguous(); + + // get flat output size + int64_t totalElems = c10::multiply_integers(input.sizes()); + + IntArrayRef outputShape = IntArrayRef(totalElems); + IntArrayRef inverseIndicesShape = input.sizes(); + IntArrayRef countsShape = IntArrayRef(totalElems); + int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0; + + if (dimOpt.has_value()) { + outputShape = input.sizes(); + inverseIndicesShape = IntArrayRef(input.sizes()[dim]); + countsShape = IntArrayRef(input.sizes()[dim]); + } + if (!return_inverse) + inverseIndicesShape = {}; + if (!return_counts) + countsShape = {}; + + Tensor output = at::native::empty_mps(outputShape, input.scalar_type(), c10::nullopt, kMPS); + Tensor inverse_indices = at::native::empty_mps(inverseIndicesShape, ScalarType::Long, c10::nullopt, kMPS); + Tensor counts = at::native::empty_mps(countsShape, ScalarType::Long, c10::nullopt, kMPS); + Tensor length = at::native::empty_mps({1}, ScalarType::Int, c10::nullopt, kMPS); + + if (input.numel() == 0) { + return std::make_tuple(output, inverse_indices, counts); + } + + mps::UniqueCachedGraph *uniqueGraph = mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt); + mps::runUniqueGraph(uniqueGraph, input, output, inverse_indices, counts, length, return_inverse, return_counts); + + int64_t lengthScalar = length.item() + 1; // length actually holds max index, add 1 + if (output.sizes().size() != 0) { + output = at::slice(output, dim, 0, lengthScalar); + } + if (return_counts) + counts = at::slice(counts, 0, 0, lengthScalar); + + return std::make_tuple(output, inverse_indices, counts); +} + +std::tuple +unique_consecutive_mps(const Tensor& self, const bool return_inverse, const bool return_counts, c10::optional dim) { + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("MPS: unique_consecutive op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performace implications."); + return at::unique_consecutive(self.to("cpu"), return_inverse, return_counts, dim); + } + + return _unique_impl_mps(self, return_inverse, return_counts, true, dim); +} + +std::tuple +unique_dim_consecutive_mps(const Tensor& self, int64_t dim, const bool return_inverse, const bool return_counts) { + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("MPS: unique_dim_consecutive op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performace implications."); + return at::unique_dim_consecutive(self.to("cpu"), dim, return_inverse, return_counts); + } + + return _unique_impl_mps(self, return_inverse, return_counts, true, c10::make_optional((int64_t)dim)); +} + +std::tuple +_unique2_mps(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) { + if (!is_macos_13_or_newer()) { + TORCH_WARN_ONCE("MPS: _unique2 op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performace implications."); + return at::_unique2(self.to("cpu"), sorted, return_inverse, return_counts); + } + + return _unique_impl_mps(self, return_inverse, return_counts, false, c10::nullopt); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm new file mode 100644 index 0000000000000..2ed353283e03f --- /dev/null +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -0,0 +1,385 @@ +// Copyright © 2022 Apple Inc. + +#include +#include +#include + +namespace at { +namespace native { +namespace mps { + +// Upsampling operations (1D/2D forward and backward) +// supported resize_mode: 'nearest' | 'bilinear' | 'nearest-exact' +void upsample_out_template(const Tensor& input, + IntArrayRef output_size, + c10::optional input_size_opt, // only used for backward pass + c10::optional scale_h_opt, + c10::optional scale_w_opt, + const Tensor& output, + bool align_corners, + const c10::string_view resize_mode_str) +{ + if (input.numel() == 0) + return; + + const auto input_dim = input.sizes(); + if (input_dim.size() <= 3) + native::upsample_1d_common_check(input.sizes(), output_size); + else + native::upsample_2d_common_check(input.sizes(), output_size); + + bool centerResults = false; + MPSGraphResizeMode resizeMode = MPSGraphResizeNearest; + MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor; + MPSGraphTensorNamedDataLayout dataLayout = input_dim.size() > 3 ? + MPSGraphTensorNamedDataLayoutNCHW : + MPSGraphTensorNamedDataLayoutCHW; + if (resize_mode_str == "nearest") { + resizeMode = MPSGraphResizeNearest; + } else if (resize_mode_str == "bilinear") { + resizeMode = MPSGraphResizeBilinear; + centerResults = true; + } else if (resize_mode_str == "nearest-exact") { + centerResults = true; + nearestRoundingMode = MPSGraphResizeNearestRoundingModeRoundPreferCeil; + } else { + AT_ERROR("Unsupported resize mode ", resize_mode_str); + } + + const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); + const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0]; + const int64_t output_height = output_size.size() > 1 ? output_size[0] : 1; + const float scale_w = (scale_w_opt.has_value() && scale_w_opt.value() > 0.) ? static_cast(scale_w_opt.value()) : 0.; + const float scale_h = (scale_h_opt.has_value() && scale_h_opt.value() > 0.) ? static_cast(scale_h_opt.value()) : 1.; + const float offset_y = centerResults ? (scale_h - 1.0f) / 2.0f : 0.0f; + const float offset_x = centerResults ? (scale_w - 1.0f) / 2.0f : 0.0f; + + IntArrayRef input_size; + const bool is_backward_pass = input_size_opt.has_value(); + if (is_backward_pass) + input_size = input_size_opt.value(); + + struct CachedGraph : public MPSCachedGraph { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor = nil, *outputTensor = nil; + MPSGraphTensor *outputSizeTensor = nil; + }; + MPSStream* stream = getCurrentMPSStream(); + + @autoreleasepool { + string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + + getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + + (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; + + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + if(!cachedGraph) { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { + CachedGraph *newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->outputSizeTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@(2)]); + + MPSGraphTensor* scaleOffsetTensor = nullptr; + MPSGraphTensor* inputSizeTensor = nullptr; + + if (scale_w > 0.0) { + const float outScales[4] = {scale_h, scale_w, offset_y, offset_x}; + scaleOffsetTensor = [mpsGraph constantWithData: [NSData dataWithBytes: outScales length: sizeof(outScales)] + shape: @[@4] + dataType: MPSDataTypeFloat32]; + } + if (is_backward_pass) { + std::vector inputSizeVec(4); + inputSizeVec[0] = @(input_size[0]); + inputSizeVec[1] = @(input_size[1]); + inputSizeVec[2] = @(input_size[2]); + inputSizeVec[3] = @(input_dim.size() > 3 ? input_size[3] : 1); + inputSizeTensor = [mpsGraph constantWithScalar: 0 + shape: [NSArray arrayWithObjects:inputSizeVec.data() count:input_dim.size()] + dataType: getMPSDataType(input.scalar_type())]; + } + if (is_macOS_13_0_or_newer) { + if (!is_backward_pass) { + if (scaleOffsetTensor && !align_corners) { + if (resizeMode == MPSGraphResizeNearest) { + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor + sizeTensor: newCachedGraph->outputSizeTensor + scaleOffsetTensor: scaleOffsetTensor + nearestRoundingMode: nearestRoundingMode + layout: dataLayout + name: nil]; + } else { // bilinear forward + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor + sizeTensor: newCachedGraph->outputSizeTensor + scaleOffsetTensor: scaleOffsetTensor + layout: dataLayout + name: nil]; + } + } else { // scaleOffsetTensor == nil || align_corners + if (resizeMode == MPSGraphResizeNearest) { + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithTensor: newCachedGraph->inputTensor + sizeTensor: newCachedGraph->outputSizeTensor + nearestRoundingMode: nearestRoundingMode + centerResult: centerResults + alignCorners: align_corners + layout: dataLayout + name: nil]; + } else { // bilinear forward + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithTensor: newCachedGraph->inputTensor + sizeTensor: newCachedGraph->outputSizeTensor + centerResult: centerResults + alignCorners: align_corners + layout: dataLayout + name: nil]; + } + } + } else { // is_backward_pass == true + if (scaleOffsetTensor && !align_corners) { + if (resizeMode == MPSGraphResizeNearest) { + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor + input: inputSizeTensor + scaleOffsetTensor: scaleOffsetTensor + nearestRoundingMode: nearestRoundingMode + layout: dataLayout + name: nil]; + } else { // bilinear backward + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor + input: inputSizeTensor + scaleOffsetTensor: scaleOffsetTensor + layout: dataLayout + name: nil]; + } + } else { // scaleOffsetTensor == nil || align_corners + if (resizeMode == MPSGraphResizeNearest) { + newCachedGraph->outputTensor = [mpsGraph resizeNearestWithGradientTensor: newCachedGraph->inputTensor + input: inputSizeTensor + nearestRoundingMode: nearestRoundingMode + centerResult: centerResults + alignCorners: align_corners + layout: dataLayout + name: nil]; + } else { // bilinear backward + newCachedGraph->outputTensor = [mpsGraph resizeBilinearWithGradientTensor: newCachedGraph->inputTensor + input: inputSizeTensor + centerResult: centerResults + alignCorners: align_corners + layout: dataLayout + name: nil]; + } + } + } + } else { // if macOS version < 13.0 (for backwards compatibility) + if (!is_backward_pass) { + newCachedGraph->outputTensor = [mpsGraph resizeTensor: newCachedGraph->inputTensor + sizeTensor: newCachedGraph->outputSizeTensor + mode: resizeMode + centerResult: YES + alignCorners: align_corners + layout: dataLayout + name: nil]; + } else { + newCachedGraph->outputTensor = [mpsGraph resizeWithGradientTensor: newCachedGraph->inputTensor + input: inputSizeTensor + mode: resizeMode + centerResult: YES + alignCorners: align_corners + layout: dataLayout + name: nil]; + } + } + } + return newCachedGraph; + }); + } + MPSNDArrayDescriptor *sizeDesc = [MPSNDArrayDescriptor descriptorWithDataType: MPSDataTypeInt32 shape: @[@(2)]]; + MPSNDArray *sizeNDArray = [[[MPSNDArray alloc] initWithDevice: stream->device() descriptor: sizeDesc] autorelease]; + [sizeNDArray writeBytes: (int32_t[]) {(int32_t)output_height, (int32_t)output_width} strideBytes: nil]; + MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease]; + + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output); + + NSDictionary* feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + cachedGraph->outputSizeTensor : sizeTensorData, + }; + NSDictionary* results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + runMPSGraph(stream, cachedGraph->graph(), feeds, results); + } +} + +} // namespace mps + +static bool check_mps_compatibility(c10::optional scale) +{ + static const bool is_macOS_13_0_or_newer = is_macos_13_or_newer(); + // passing scale factors to MPS's resize APIs is not supported on macOS < 13 + if (!is_macOS_13_0_or_newer && scale.has_value() && scale.value() > 0.) { + TORCH_WARN_ONCE("MPS: passing scale factor to upsample ops is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performance implications."); + return false; + } + return true; +} + +TORCH_IMPL_FUNC(upsample_nearest1d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + c10::optional scale, + const Tensor& output) +{ + if (check_mps_compatibility(scale)) { + mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = at::upsample_nearest1d(input.to("cpu"), output_size, scale).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_mps) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scale, + const Tensor& grad_input) +{ + if (check_mps_compatibility(scale)) { + mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = at::upsample_nearest1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + c10::optional scale, + const Tensor& output) +{ + if (check_mps_compatibility(scale)) { + mps::upsample_out_template(input, output_size, c10::nullopt, c10::nullopt, scale, output, false, "nearest-exact"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = at::_upsample_nearest_exact1d(input.to("cpu"), output_size, scale).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_mps) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scale, + const Tensor& grad_input) +{ + if (check_mps_compatibility(scale)) { + mps::upsample_out_template(grad_output, output_size, input_size, c10::nullopt, scale, grad_input, false, "nearest-exact"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = at::_upsample_nearest_exact1d_backward(grad_output.to("cpu"), output_size, input_size, scale).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(upsample_nearest2d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = at::upsample_nearest2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_mps) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = at::upsample_nearest2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, false, "nearest-exact"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = at::_upsample_nearest_exact2d(input.to("cpu"), output_size, scales_h, scales_w).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_mps) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, false, "nearest-exact"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = at::_upsample_nearest_exact2d_backward(grad_output.to("cpu"), output_size, input_size, scales_h, scales_w).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) ( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& output) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(input, output_size, c10::nullopt, scales_h, scales_w, output, align_corners, "bilinear"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(output) = at::upsample_bilinear2d(input.to("cpu"), output_size, align_corners, scales_h, scales_w).clone().to("mps"); + } +} + +TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + const Tensor& grad_input) +{ + if (check_mps_compatibility(scales_w)) { + mps::upsample_out_template(grad_output, output_size, input_size, scales_h, scales_w, grad_input, align_corners, "bilinear"); + } else { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(grad_input) = at::upsample_bilinear2d_backward(grad_output.to("cpu"), output_size, input_size, align_corners, scales_h, scales_w).clone().to("mps"); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index 70c7d50b730f3..cd28295eab9e2 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -2,7 +2,7 @@ #include #include -#include +#include namespace at { namespace native { @@ -18,16 +18,27 @@ std::vector strideTensors; }; -static std::string getStridedKey(const ScalarType& dtype, const IntArrayRef& base_shape, - const IntArrayRef& new_shape, bool is_scatter) +static std::string getStridedKey(const ScalarType& self_dtype, const ScalarType& updates_dtype, const IntArrayRef& base_shape, + const IntArrayRef& new_shape, const IntArrayRef& stride, + int64_t storage_offset, bool is_scatter) { - return (is_scatter ? "scatter:" : "gather:") + getMPSTypeString(dtype) + "[" + - getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]"; + std::string dtype_key = getMPSTypeString(self_dtype); + if (is_scatter) { + dtype_key += ":" + getMPSTypeString(updates_dtype); + } + + return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + + getArrayRefString(base_shape) + "]:[" + getArrayRefString(new_shape) + "]:[" + + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; } // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op -static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output, - bool needsScatter, bool requires_sync = false) +static Tensor& runViewGraph( + ViewCachedGraph* cachedGraph, + const at::Tensor& src, + Tensor& output, + bool needsScatter, + id updatesBuffer = nil) { const id sourceBuffer = getMTLBufferStorage(src); const id outputBuffer = getMTLBufferStorage(output); @@ -48,9 +59,14 @@ shape: inputShape dataType: inputType] autorelease]; if (needsScatter) { - feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer + auto updatesType = getMPSScalarType(src.scalar_type()); + if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) { + updatesType = MPSDataTypeInt8; + } + + feeds[cachedGraph->updatesTensor] = [[[MPSGraphTensorData alloc] initWithMTLBuffer: (updatesBuffer != nil) ? updatesBuffer : sourceBuffer shape: getMPSShape(src.numel()) - dataType: inputType] autorelease]; + dataType: updatesType] autorelease]; } MPSScalar storageOffsetScalar = getMPSScalar(storage_offset, ScalarType::Int); feeds[cachedGraph->storageOffsetTensor] = getMPSGraphTensorFromScalar(stream, storageOffsetScalar); @@ -60,28 +76,441 @@ strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int); feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]); } - // Workaround for MPSShaderLibrary bug - // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved - auto outputType = getMPSDataType(output.scalar_type()); - if (outputType == MPSDataTypeUInt8) { + // Workaround for MPSShaderLibrary bug in macOS Monterey + // This is fixed in macOS Ventura + auto outputType = getMPSScalarType(output.scalar_type()); + if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) { outputType = MPSDataTypeInt8; } + MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer shape: outputShape dataType: outputType] autorelease]; NSDictionary* results = @{ cachedGraph->outputTensor : outputTensorData }; - stream->executeMPSGraph(cachedGraph->graph(), feeds, results, - requires_sync ? SyncType::COMMIT : SyncType::COMMIT_ADAPTIVE); + runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return output; } +MPSGraphTensor *permuteTensor(MPSGraph *graph, MPSGraphTensor *inputTensor, NSArray *permuteOrder) { + NSUInteger srcRank = [[inputTensor shape] count]; + if (srcRank != [permuteOrder count]) + return nil; + + MPSGraphTensor *outputTensor = inputTensor; + std::vector dimensionOrder(srcRank); + std::iota (std::begin(dimensionOrder), std::end(dimensionOrder), 0); + + for (NSUInteger i = 0; i < srcRank; i++) { + NSUInteger axis = [permuteOrder[i] integerValue]; + auto axisIter = std::find(dimensionOrder.begin(), dimensionOrder.end(), axis); + NSUInteger axis1 = i; + NSUInteger axis2 = axisIter - dimensionOrder.begin(); + iter_swap(dimensionOrder.begin() + i, axisIter); + + outputTensor = [graph transposeTensor:outputTensor + dimension:axis1 + withDimension:axis2 + name:nil]; + } + + return outputTensor; +} + +NSDictionary *getStrideToDimLengthOffsetDict(MPSGraphTensor *tensor, NSUInteger rank, NSUInteger offset) { + // Assuming input tensor has default strides + NSInteger stride = 1; + NSMutableDictionary *strideToDimLengthOffset = [[NSMutableDictionary alloc] init]; + for (NSInteger srcDim = rank - 1; srcDim >= 0; srcDim--) { + NSUInteger size = [[tensor shape][srcDim] integerValue]; + NSDictionary *entry = + @{ + @"dim": [NSNumber numberWithInteger:srcDim], + @"length": [tensor shape][srcDim], + @"offset": [NSNumber numberWithInteger:offset % size] // offset is determined traversing backwards through stride + }; + [strideToDimLengthOffset setValue:entry forKey:[NSString stringWithFormat:@"%ld",stride]]; + offset /= size; + stride *= size; + } + return strideToDimLengthOffset; +} + +// Detect only expand dims, allows for duplicate strides +MPSGraphTensor* asStridedLayer_expandDimsPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { + + NSUInteger srcRank = [[inputTensor shape] count]; + // Not an expand dims + if (srcRank >= dstRank) + return nil; + + NSMutableArray *expandAxes = [[NSMutableArray alloc] init]; + + BOOL isValidExpand = YES; + NSInteger currSrcDim = (NSInteger)srcRank - 1; + NSUInteger currSrcStride = 1; + for (NSInteger dstDim = dstRank - 1; dstDim >= 0 && isValidExpand; dstDim--) { + NSUInteger currDimLength = dstSizes[dstDim]; + NSUInteger currStride = dstStrides[dstDim]; + NSUInteger currSrcDimLength = currSrcDim >= 0 ? [[inputTensor shape][currSrcDim] integerValue] : 1; + + NSUInteger targetDimLength = currSrcDimLength; + if (currDimLength != targetDimLength) + targetDimLength = 1; + if (currDimLength != targetDimLength || currStride != currSrcStride) + isValidExpand = NO; + if (currSrcDim >= 0 && currSrcDimLength == targetDimLength) { + currSrcStride *= currSrcDimLength; + currSrcDim--; + } else { + [expandAxes addObject:[NSNumber numberWithInt:dstDim]]; + } + } + + // Did not use every dimension of source + if (!isValidExpand || currSrcDim >= 0) { + [expandAxes release]; + return nil; + } + + MPSGraphTensor *expandTensor = inputTensor; + if ([expandAxes count]) { + expandTensor = [graph expandDimsOfTensor:expandTensor + axes:expandAxes + name:nil]; + } + [expandAxes release]; + + return expandTensor; +} + +// Detect contiguous reshapes, no slicing +MPSGraphTensor* asStridedLayer_reshapePattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { + NSUInteger srcRank = [[inputTensor shape] count]; + // Not a reshape + if (srcRank <= dstRank) + return nil; + + NSMutableArray *dstShape = [[NSMutableArray alloc] init]; + + BOOL isValidReshape = YES; + NSInteger srcDim = srcRank - 1; + NSUInteger srcStride = 1; + for (NSInteger dstDim = dstRank - 1; dstDim >= 0 && isValidReshape; dstDim--) { + NSUInteger currDimLength = dstSizes[dstDim]; + NSUInteger currStride = dstStrides[dstDim]; + [dstShape insertObject:[NSNumber numberWithInteger:currDimLength] atIndex: 0]; + + NSUInteger targetDimLength = currDimLength; + NSUInteger currReshapeSize = 1; + NSUInteger innerStride = srcStride; + + while (currReshapeSize != targetDimLength && srcDim >= 0) { + NSUInteger srcDimLength = [[inputTensor shape][srcDim] integerValue]; + currReshapeSize *= srcDimLength; + srcStride *= srcDimLength; + srcDim--; + }; + + isValidReshape &= (currReshapeSize == targetDimLength && currStride == innerStride); + } + isValidReshape &= (srcDim < 0); + + MPSGraphTensor *outputTensor = nil; + if (isValidReshape) + outputTensor = [graph reshapeTensor: inputTensor + withShape: dstShape + name: nil]; + [dstShape release]; + return outputTensor; +} + +MPSGraphTensor* asStridedLayer_genericPattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { + + // Duplicate strides cannot be done + { + BOOL allUnique = YES; + NSMutableSet *uniqueStrides = [[NSMutableSet alloc] init]; + for (NSInteger dstDim = 0; (dstDim < dstRank) && allUnique; dstDim++) { + int stride = dstStrides[dstDim]; + NSNumber *strideObj = [NSNumber numberWithInt:stride]; + allUnique &= (stride == 0 || ![uniqueStrides containsObject:strideObj]); + [uniqueStrides addObject: strideObj]; + } + [uniqueStrides release]; + if (!allUnique) + return nil; + + // Skip for zero in dst shape + for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) + if (dstSizes[dstDim] == 0) { return nil; } + } + + // 1. Flatten the inputTensor if necessary + MPSGraphTensor *flatInputTensor = inputTensor; + { + // Flatten inputs to remove duplicate strides. + NSMutableArray *squeezeAxes = [[NSMutableArray alloc] init]; + for(NSUInteger srcDim = 1; srcDim < [[flatInputTensor shape] count]; srcDim++) { + if ([[flatInputTensor shape][srcDim] intValue] == 1) + [squeezeAxes addObject:[NSNumber numberWithInteger:srcDim]]; + } + // We have to leave at least 1 dimension, if all input dims are 1 + if ([squeezeAxes count]) + flatInputTensor = [graph squeezeTensor:flatInputTensor + axes:squeezeAxes + name:nil]; + [squeezeAxes release]; + } + + int srcRank = (int)[[flatInputTensor shape] count]; + NSDictionary *srcStrideToDimLengthOffset = getStrideToDimLengthOffsetDict(flatInputTensor, srcRank, offset); + + // Populate the dimension order, slice info, and broadcast info + NSMutableArray *dstDimOrder = [[NSMutableArray alloc] init]; + std::vector dstDimToSliceLength(dstRank); + std::vector dstDimToSliceOffset(dstRank); + bool needsBroadcast = false; + { + for (NSInteger dstDim = dstRank - 1; dstDim >= 0; dstDim--) { + if (dstStrides[dstDim] == 0) { + // This dimension should be a broadcast + needsBroadcast = true; + dstDimToSliceLength[dstDim] = dstSizes[dstDim]; + dstDimToSliceOffset[dstDim] = 0; + } else { + // Find what dimension and native length was for the specified stride + NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%lld",dstStrides[dstDim]]]; + + dstDimToSliceLength[dstDim] = dstSizes[dstDim]; + dstDimToSliceOffset[dstDim] = [srcDimLengthOffset[@"offset"] intValue]; + + // Stride does not exist in source tensor, or the specified size is too long. Not possible + // TODO: Longer length with same stride + removal of dim(s) above this is a flatten/reshape. Consider adding support + if (!srcDimLengthOffset || + // the offset + length of destination should not be larger than source's length when slicing + dstDimToSliceOffset[dstDim] + dstDimToSliceLength[dstDim] > [srcDimLengthOffset[@"length"] intValue]) { + return nil; + } + // Get the src dimension corresponding to the requested stride + NSNumber *srcDim = srcDimLengthOffset[@"dim"]; + [dstDimOrder insertObject:srcDim atIndex:0]; + } + } + } + + // 2. Slice out any unused dimensions + NSMutableArray *missingSrcDims = [[NSMutableArray alloc] init]; + MPSGraphTensor *slicedUnusedTensor = flatInputTensor; + { + // Find any src strides/dims that are not present in the dst + NSMutableArray *missingSrcStrides = [[NSMutableArray alloc] init]; + { + NSUInteger stride = 1; + for (NSInteger srcDim = [[flatInputTensor shape] count] - 1; srcDim >= 0; srcDim--) { + [missingSrcStrides addObject:[NSNumber numberWithInteger:stride]]; + stride *= [[flatInputTensor shape][srcDim] integerValue]; + } + for (NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { + [missingSrcStrides removeObject:[NSNumber numberWithInteger:dstStrides[dstDim]]]; + } + } + for (NSUInteger i = 0; i < [missingSrcStrides count]; i++) { + NSUInteger stride = [missingSrcStrides[i] integerValue]; + NSDictionary *srcDimLengthOffset = srcStrideToDimLengthOffset[[NSString stringWithFormat:@"%ld",stride]]; + NSNumber *missingSrcDim = srcDimLengthOffset[@"dim"]; + [missingSrcDims addObject:missingSrcDim]; + [dstDimOrder insertObject:missingSrcDim atIndex:0]; + + slicedUnusedTensor = [graph sliceTensor:slicedUnusedTensor + dimension:[missingSrcDim intValue] + start:[srcDimLengthOffset[@"offset"] intValue] + length:1 + name:nil]; + } + [missingSrcStrides release]; + } + + // 3. Transpose if necessary + MPSGraphTensor *transposedTensor = slicedUnusedTensor; + { + // TODO: Use Transpose API + BOOL needsTranspose = NO; + for(NSUInteger dstDim = 0; dstDim < [dstDimOrder count] && !needsTranspose; dstDim++ ) + needsTranspose |= ([dstDimOrder[dstDim] intValue] != dstDim); + if (needsTranspose) + transposedTensor = permuteTensor(graph, transposedTensor, dstDimOrder); + } + + // 4. Squeeze any unused dimensions following transpose + MPSGraphTensor *squeezedTensor = transposedTensor; + { + // Transpose the missing dims back + NSMutableArray *transposedMissingSrcDims = [[NSMutableArray alloc] init]; + for (NSUInteger dstDim = 0; dstDim < [dstDimOrder count]; dstDim++) { + NSNumber *srcDim = dstDimOrder[dstDim]; + if ([missingSrcDims containsObject:srcDim]) + [transposedMissingSrcDims addObject:[NSNumber numberWithInt:dstDim]]; + } + if ([transposedMissingSrcDims count]) + squeezedTensor = [graph squeezeTensor:squeezedTensor + axes:transposedMissingSrcDims + name:nil]; + [transposedMissingSrcDims release]; + } + + // 5. Slice + MPSGraphTensor *slicedTensor = squeezedTensor; + { + NSUInteger currDstDim = 0; + for (NSUInteger dstDim = 0; dstDim < dstRank; dstDim++) { + // Only dstDims with nonzero stride are in the current tensor, skip broadcasts + if (dstStrides[dstDim] != 0) { + int start = dstDimToSliceOffset[dstDim]; + int length = dstDimToSliceLength[dstDim]; + if (length != [[slicedTensor shape][currDstDim] intValue]) + slicedTensor = [graph sliceTensor:slicedTensor + dimension:currDstDim + start:start + length:length + name:nil]; + currDstDim++; + } + } + } + + // 6. Expand then broadcast the source tensor + MPSGraphTensor *broadcastTensor = slicedTensor; + if (needsBroadcast) { + NSMutableArray *broadcastShape = [[NSMutableArray alloc] init]; + NSMutableArray *expandAxes = [[NSMutableArray alloc] init]; + for(NSInteger dstDim = 0; dstDim < dstRank; dstDim++) { + [broadcastShape addObject:[NSNumber numberWithInt:dstSizes[dstDim]]]; + if (dstStrides[dstDim] == 0) + [expandAxes addObject:[NSNumber numberWithInt:dstDim]]; + } + + if ([expandAxes count]) { + MPSGraphTensor *expandTensor = [graph expandDimsOfTensor:broadcastTensor + axes:expandAxes + name:nil]; + broadcastTensor = [graph broadcastTensor:expandTensor + toShape:broadcastShape + name:nil]; + } + [broadcastShape release]; + [expandAxes release]; + } + + [srcStrideToDimLengthOffset release]; + [dstDimOrder release]; + [missingSrcDims release]; + + return broadcastTensor; +} + +MPSGraphTensor* asStridedLayer_pattern(MPSGraph *graph, MPSGraphTensor *inputTensor, int dstRank, const IntArrayRef& dstSizes, const IntArrayRef& dstStrides, int offset) { + if (!dstRank) + return nil; + + MPSGraphTensor *outputTensor = nil; + outputTensor = asStridedLayer_expandDimsPattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset); + if (!outputTensor) + outputTensor = asStridedLayer_reshapePattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset); + if (!outputTensor) + outputTensor = asStridedLayer_genericPattern(graph, inputTensor, dstRank, dstSizes, dstStrides, offset); + + return outputTensor; +} + +static +std::vector getViewShape(const Tensor& src, MPSShape *mpsShape) { + bool hasMPSShape = (mpsShape != nil); + std::vector src_view_shape; + if (hasMPSShape) { + int src_ndim_view = [mpsShape count]; + src_view_shape.resize(src_ndim_view); + for (const auto i : c10::irange(src_ndim_view)) { + src_view_shape[i] = [mpsShape[i] intValue]; + } + } else { + src_view_shape = src.sizes().vec(); + } + + return src_view_shape; +} + +bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape) { + if (!src.is_contiguous()) { + return false; + } + + IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data()); + size_t src_ndim_base = src_base_shape.size(); + std::vector src_view_shape = getViewShape(src, mpsShape); + size_t src_ndim_view = src_view_shape.size(); + if (src_ndim_base != src_ndim_view) { + return false; + } + + for (const auto i: c10::irange(src_ndim_base)) { + if (src_view_shape[i] > src_base_shape[i]) { + return false; + } + } + + return true; +} + +MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType) { + IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data()); + int src_ndim_base = src_base_shape.size(); + std::vector src_view_shape = getViewShape(src, mpsShape); + int src_ndim_view = src_view_shape.size(); + + TORCH_CHECK(src_ndim_base == src_ndim_view); + + MPSNDArray *srcTensorNDArrayView = nil; + MPSNDArrayDescriptor *srcTensorNDArrayDesc = nil; + MPSNDArray *srcTensorNDArray = nil; + id commandBuffer = getCurrentMPSStream()->commandBuffer(); + + srcTensorNDArray = ndArrayFromTensor(src, getMPSShape(src_base_shape), mpsDataType); + srcTensorNDArrayDesc = srcTensorNDArray.descriptor; + + int firstDimToSlice = 0; + while (src_base_shape[firstDimToSlice] == src_view_shape[firstDimToSlice]) { + firstDimToSlice++; + } + + int view_numel = 1; + for (const auto i : c10::irange(firstDimToSlice + 1, src_base_shape.size())) { + view_numel *= src_base_shape[i]; + } + + int sliceOffset = src.storage_offset() / view_numel; + // There are cases where both dimensions of a view can shrink + // E.g: x = torch.randn((3,6))[1, 1:3] + int nextSliceOffset = src.storage_offset() % view_numel; + + [srcTensorNDArrayDesc sliceDimension:src_ndim_base - 1 - firstDimToSlice withSubrange:{static_cast(sliceOffset), static_cast(src.sizes()[firstDimToSlice])}]; + if (nextSliceOffset) { + [srcTensorNDArrayDesc sliceDimension:src_ndim_base - 2 - firstDimToSlice withSubrange:{static_cast(nextSliceOffset), static_cast(src.sizes()[firstDimToSlice+1])}]; + } + + srcTensorNDArrayView = [srcTensorNDArray arrayViewWithCommandBuffer:commandBuffer + descriptor:srcTensorNDArrayDesc + aliasing:MPSAliasingStrategyShallAlias]; + + return [[[MPSGraphTensorData alloc] initWithMPSNDArray:srcTensorNDArrayView] autorelease]; +} + static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size, const IntArrayRef& stride, int64_t offset, const IntArrayRef& base_shape, bool needsScatter, - const bool needsBoolCast) + MPSGraphTensor* updatesTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); MPSGraphTensor *outputTensor = nil; @@ -123,12 +552,11 @@ name: nil]; MPSGraphTensor *inputTensor = cachedGraph->inputTensor; - // Workaround for bool scatter/gather deficiency - // See https://github.com/pytorch/pytorch/issues/82663 - if (needsBoolCast) { - inputTensor = [mpsGraph castTensor:inputTensor - toType:MPSDataTypeInt8 - name:@"Cast away from bool"]; + if (!needsScatter) { + MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset); + if (outputTensor) { + return outputTensor; + } } MPSGraphTensor *reshapedInputTensor = [mpsGraph reshapeTensor: inputTensor @@ -140,7 +568,7 @@ if (needsScatter) { MPSGraphTensor* scatteredTensor = [mpsGraph scatterAlongAxis: (NSInteger) 0 withDataTensor: reshapedInputTensor - updatesTensor: cachedGraph->updatesTensor + updatesTensor: updatesTensor indicesTensor: reshapedIndicesTensor mode: MPSGraphScatterModeSet name: nil]; @@ -159,18 +587,28 @@ withShapeTensor: shapeTensor name: nil]; } - - // Workaround for bool scatter/gather deficiency - // See https://github.com/pytorch/pytorch/issues/82663 - if (needsBoolCast) { - outputTensor = [mpsGraph castTensor:outputTensor - toType:MPSDataTypeBool - name:@"Cast back to bool"]; - } } return outputTensor; } +static IntArrayRef updateTensorBaseShape(const Tensor& self) +{ + IntArrayRef base_shape = getIMPSAllocator()->getBufferShape(self.storage().data()); + // if there's no base_shape stored in MPSAllocator, then infer it from tensor's size and store it + if (base_shape.size() == 0) { + // IntArrayRef wouldn't own the data, so we use a static storage + static const int64_t shape_1d = 1; + // self.sizes().size() could be zero + base_shape = self.sizes().size() ? self.sizes() : + ((self.is_view() && self._base().sizes().size()) ? self._base().sizes() : IntArrayRef(&shape_1d, 1)); + + // base_shape will be retained in MPSAllocator until buffer gets recycled + if (self.storage().data()) + getIMPSAllocator()->setBufferShape(self.storage().data(), base_shape); + } + return base_shape; +} + // There are few cases we need to consider: // Here nodes are the Tensors and the edges are the operations performed on the // Tensor. As a result of the operation performed we can have result as View @@ -188,24 +626,13 @@ // | / \ | // | / \ | // NonView T NonView T -static ViewCachedGraph* createViewGraph(const Tensor& self, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter) +static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &updates, IntArrayRef size, IntArrayRef stride, int64_t storage_offset, bool needsScatter) { - IntArrayRef base_shape = get_buffer_shape(self.storage().data()); - if (base_shape.size() == 0) { - // IntArrayRef wouldn't own the data, so we use a static storage - static const int64_t shape_1d = 1; - // self.sizes().size() could be zero - base_shape = self.sizes().size() ? self.sizes() : - self.is_view() ? self._base().sizes() : IntArrayRef(&shape_1d, 1); - - // base_shape will be retained in MPSAllocator until buffer gets recycled - if (self.storage().data()) - set_buffer_shape(self.storage().data(), base_shape); - } - MPSGraphCache* cache_ = MPSGraphCache::getInstance(); + IntArrayRef base_shape = updateTensorBaseShape(self); @autoreleasepool { - string key = getStridedKey(self.scalar_type(), base_shape, size, needsScatter); + string key = getStridedKey(self.scalar_type(), updates.scalar_type(), base_shape, size, stride, storage_offset, needsScatter); + MPSGraphCache* cache_ = MPSGraphCache::getInstance(); ViewCachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); if (!cachedGraph) { @@ -213,14 +640,15 @@ ViewCachedGraph *newCachedGraph = nil; @autoreleasepool { MPSGraph* mpsGraph = make_mps_graph(); + MPSGraphTensor* updatesTensor = nil; newCachedGraph = new ViewCachedGraph(mpsGraph); - // Workaround for MPSShaderLibrary bug - // TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved + // Workaround for MPSShaderLibrary bug in macOS Monterey + // This is fixed in macOS Ventura auto inputType = getMPSScalarType(self.scalar_type()); - if (inputType == MPSDataTypeUInt8) { - inputType = MPSDataTypeInt8; + if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) { + inputType = MPSDataTypeInt8; } - auto needsBoolCast = inputType == MPSDataTypeBool; + // Self is the input tensor we are creating view of newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]); @@ -228,9 +656,19 @@ newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1])); } if (needsScatter) { - newCachedGraph->updatesTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, inputType); + auto updatesType = getMPSScalarType(updates.scalar_type()); + if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) { + updatesType = MPSDataTypeInt8; + } + newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel())); + updatesTensor = newCachedGraph->updatesTensor; + if (inputType != updatesType) { + updatesTensor = [mpsGraph castTensor:updatesTensor + toType:inputType + name:@"castUpdatesTensor"]; + } } - newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, needsBoolCast); + newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor); } return newCachedGraph; })); @@ -241,48 +679,41 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { - ViewCachedGraph* cachedGraph = nullptr; - - const IntArrayRef& base_shape = get_buffer_shape(src.storage().data()); - if (base_shape.size() > 0) { - string key = getStridedKey(src.scalar_type(), base_shape, src.sizes(), /*is_scatter*/ false); - cachedGraph = static_cast(MPSGraphCache::getInstance()->LookUp(key)); - } - // there are cases where gatherViewTensor() is called without having as_strided() called beforehand. - // this typically may come from copy_mps variants. In such cases, when the base_shape isn't found the - // callers would resort to make the tensor contiguous in an alternative code path. - if (!cachedGraph) { + if (src.sizes().size() == 0) { return Tensor(); } - - bool requires_sync = false; Tensor output; if (!dst.has_storage()) { output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS); - requires_sync = true; } - return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync); + ViewCachedGraph* cachedGraph = createViewGraph(src.is_complex() ? at::view_as_real(src) : src, + dst, src.sizes(), src.strides(), + src.storage_offset(), /*needsScatter*/ false); + return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false); } -Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output) +Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output, id updatesBuffer) { - ViewCachedGraph* cachedGraph = createViewGraph(output, output.sizes(), output.strides(), + ViewCachedGraph* cachedGraph = createViewGraph(output.is_complex() ? at::view_as_real(output) : output, + src, output.sizes(), output.strides(), output.storage_offset(), /*needsScatter*/ true); - return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, /*requires_sync*/ true); + return runViewGraph(cachedGraph, src, output, /*needsScatter*/ true, updatesBuffer); } } // namespace mps // implementation of as_strided() op -Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) +Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, IntArrayRef stride, c10::optional storage_offset_) { auto storage_offset = storage_offset_.value_or(self.storage_offset()); auto result = detail::make_tensor(c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype()); setStrided(result, size, stride, storage_offset); - // 0 sizes won't result in any change in the shape of the Tensor so we can skip it. - if (size.size() > 0) - mps::createViewGraph(self, size, stride, storage_offset, /*needsScatter*/ false); + // creating the view graph will be deferred until gatherViewTensor() or scatterViewTensor() are called. + // In as_strided, we just update the base shape of the buffer in order to retrieve it later + // when we create/run the view graph. + IntArrayRef base_shape = mps::updateTensorBaseShape(self); + TORCH_INTERNAL_ASSERT(base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data()); return result; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ba1d38aa350b5..494c0eb2e8ba5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -244,7 +244,7 @@ CPU: native_dropout_cpu CUDA: native_dropout_cuda NestedTensorCPU, NestedTensorCUDA: native_dropout_nested - tags: nondeterministic_seeded, canonical + tags: [nondeterministic_seeded, canonical] autogen: native_dropout.out - func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor @@ -252,6 +252,7 @@ CPU, NestedTensorCPU, NestedTensorCUDA: native_dropout_backward CUDA: native_dropout_backward_cuda autogen: native_dropout_backward.out + tags: pointwise - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) @@ -296,7 +297,7 @@ CompositeExplicitAutograd: abs SparseCPU, SparseCUDA: abs_sparse SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr - tags: canonical + tags: [canonical, pointwise] - func: abs_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -313,6 +314,7 @@ MPS: abs_out_mps SparseCPU, SparseCUDA: abs_sparse_out SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr_out + tags: pointwise # Note [Adding an alias] # To add an alias do the following: @@ -336,8 +338,8 @@ # in op_db list in torch/testing/_internal/common_methods_invocations.py # # See torch.absolute, an alias for torch.abs, as an example. - # Absolute, alias for abs + - func: absolute(Tensor self) -> Tensor device_check: NoCheck # TensorIterator variants: function, method @@ -355,12 +357,14 @@ dispatch: CPU, CUDA: angle SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr + tags: pointwise - func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: angle_out SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr_out + tags: pointwise - func: view_as_real(Tensor(a) self) -> Tensor(a) variants: function @@ -378,6 +382,7 @@ dispatch: SparseCPU, SparseCUDA: sgn_sparse SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr + tags: pointwise - func: sgn_(Tensor(a!) self) -> Tensor(a!) variants: method @@ -385,14 +390,17 @@ dispatch: SparseCPU, SparseCUDA: sgn_sparse_ SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_ + tags: pointwise - func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: sgn_out + MPS: sgn_out_mps SparseCPU, SparseCUDA: sgn_sparse_out SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out + tags: pointwise - func: chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor variants: method @@ -423,18 +431,21 @@ - func: conj_physical(Tensor self) -> Tensor variants: function, method + tags: pointwise - func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: conj_physical_out SparseCPU, SparseCUDA: conj_physical_out_sparse SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr_out + tags: pointwise - func: conj_physical_(Tensor(a!) self) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: conj_physical_ SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr_ + tags: pointwise - func: resolve_conj(Tensor(a) self) -> Tensor(a) variants: function, method @@ -451,11 +462,13 @@ device_check: NoCheck # TensorIterator variants: function, method structured_delegate: acos.out + tags: pointwise - func: acos_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: acos.out + tags: pointwise - func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -464,6 +477,7 @@ dispatch: CPU, CUDA: acos_out MPS: acos_out_mps + tags: pointwise # arccos, alias of acos - func: arccos(Tensor self) -> Tensor @@ -491,7 +505,7 @@ MkldnnCPU: mkldnn_add ZeroTensor: add_zerotensor NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor - tags: canonical + tags: [canonical, pointwise] - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -502,6 +516,7 @@ SparseCsrCPU, SparseCsrCUDA: add_sparse_csr_ MkldnnCPU: mkldnn_add_ NestedTensorCPU, NestedTensorCUDA: NestedTensor_add__Tensor + tags: pointwise - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -517,6 +532,7 @@ SparseCsrCUDA: add_out_sparse_csr_cuda MkldnnCPU: mkldnn_add_out MPS: add_out_mps + tags: pointwise - func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor variants: function @@ -550,7 +566,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: add - tags: canonical + tags: [canonical, pointwise] - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -558,6 +574,7 @@ dispatch: CompositeExplicitAutograd: add_ autogen: add.Scalar_out + tags: pointwise - func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor structured_delegate: addmv.out @@ -713,10 +730,12 @@ - func: acosh(Tensor self) -> Tensor variants: function, method structured_delegate: acosh.out + tags: pointwise - func: acosh_(Tensor(a!) self) -> Tensor(a!) variants: function, method structured_delegate: acosh.out + tags: pointwise - func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -724,8 +743,9 @@ dispatch: CPU, CUDA: acosh_out MPS: acosh_out_mps - + tags: pointwise # arccosh, alias for acosh + - func: arccosh(Tensor self) -> Tensor variants: function, method @@ -740,6 +760,7 @@ dispatch: SparseCPU, SparseCUDA: asinh_sparse SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr + tags: pointwise - func: asinh_(Tensor(a!) self) -> Tensor(a!) variants: function, method @@ -747,6 +768,7 @@ dispatch: SparseCPU, SparseCUDA: asinh_sparse_ SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr_ + tags: pointwise - func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -756,6 +778,7 @@ MPS: asinh_out_mps SparseCPU, SparseCUDA: asinh_sparse_out SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr_out + tags: pointwise # arcsinh, alias for asinh - func: arcsinh(Tensor self) -> Tensor @@ -772,6 +795,7 @@ dispatch: SparseCPU, SparseCUDA: atanh_sparse SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr + tags: pointwise - func: atanh_(Tensor(a!) self) -> Tensor(a!) structured_delegate: atanh.out @@ -779,6 +803,7 @@ dispatch: SparseCPU, SparseCUDA: atanh_sparse_ SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr_ + tags: pointwise - func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -788,8 +813,9 @@ MPS: atanh_out_mps SparseCPU, SparseCUDA: atanh_sparse_out SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr_out - + tags: pointwise # arctanh, alias for atanh + - func: arctanh(Tensor self) -> Tensor variants: function, method @@ -815,7 +841,7 @@ device_guard: False tags: inplace_view dispatch: - CompositeExplicitAutogradNonFunctional: as_strided_ + CompositeExplicitAutogradNonFunctional: as_strided__symint - func: asin(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -824,6 +850,7 @@ dispatch: SparseCPU, SparseCUDA: asin_sparse SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr + tags: pointwise - func: asin_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -832,6 +859,7 @@ dispatch: SparseCPU, SparseCUDA: asin_sparse_ SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr_ + tags: pointwise - func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -842,6 +870,7 @@ MPS: asin_out_mps SparseCPU, SparseCUDA: asin_sparse_out SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr_out + tags: pointwise # arcsin, alias of asin - func: arcsin(Tensor self) -> Tensor @@ -859,6 +888,7 @@ dispatch: SparseCPU, SparseCUDA: atan_sparse SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr + tags: pointwise - func: atan_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -867,6 +897,7 @@ dispatch: SparseCPU, SparseCUDA: atan_sparse_ SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr_ + tags: pointwise - func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -877,6 +908,7 @@ MPS: atan_out_mps SparseCPU, SparseCUDA: atan_sparse_out SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr_out + tags: pointwise # arctan, alias of atan - func: arctan(Tensor self) -> Tensor @@ -985,6 +1017,8 @@ device_check: NoCheck # TensorIterator variants: function, method tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutogradNonFunctional: bernoulli - func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor @@ -1034,6 +1068,7 @@ dispatch: CPU: _bincount_cpu CUDA: _bincount_cuda + MPS: _bincount_mps tags: dynamic_output_shape autogen: bincount.out @@ -1041,12 +1076,13 @@ device_check: NoCheck # TensorIterator structured_delegate: bitwise_not.out variants: function, method - tags: canonical + tags: [canonical, pointwise] - func: bitwise_not_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: bitwise_not.out variants: method + tags: pointwise - func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1054,6 +1090,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: bitwise_not_out + tags: pointwise - func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1061,11 +1098,13 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: copysign_out + tags: pointwise - func: copysign.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: copysign.out + tags: pointwise - func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1076,6 +1115,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: copysign + tags: pointwise - func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) variants: method @@ -1085,78 +1125,91 @@ - func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: copysign_out + tags: pointwise - func: logical_not(Tensor self) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: logical_not + tags: pointwise - func: logical_not_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: logical_not_ + tags: pointwise - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_not_out MPS: logical_not_out_mps + tags: pointwise - func: logical_xor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: logical_xor + tags: pointwise - func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: logical_xor_ + tags: pointwise - func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_xor_out MPS: logical_xor_out_mps + tags: pointwise - func: logical_and(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: logical_and + tags: pointwise - func: logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: logical_and_ + tags: pointwise - func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_and_out MPS: logical_and_out_mps + tags: pointwise - func: logical_or(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: logical_or + tags: pointwise - func: logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: logical_or_ + tags: pointwise - func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: logical_or_out MPS: logical_or_out_mps + tags: pointwise - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -1193,8 +1246,10 @@ device_check: NoCheck device_guard: False -- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) +- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: broadcast_to_symint - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function @@ -1253,6 +1308,7 @@ dispatch: SparseCPU, SparseCUDA: ceil_sparse SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr + tags: pointwise - func: ceil_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1261,6 +1317,7 @@ dispatch: SparseCPU, SparseCUDA: ceil_sparse_ SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_ + tags: pointwise - func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1271,6 +1328,7 @@ MPS: ceil_out_mps SparseCPU, SparseCUDA: ceil_sparse_out SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_out + tags: pointwise # alias for torch.linalg.multi_dot - func: chain_matmul(Tensor[] matrices) -> Tensor @@ -1292,11 +1350,15 @@ CompositeImplicitAutograd: chunk NestedTensorCPU, NestedTensorCUDA: chunk_nested_tensor -- func: tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[] +- func: tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] variants: function, method + dispatch: + CompositeImplicitAutograd: tensor_split_sections_symint -- func: tensor_split.indices(Tensor(a -> *) self, int[] indices, int dim=0) -> Tensor(a)[] +- func: tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] variants: function, method + dispatch: + CompositeImplicitAutograd: tensor_split_indices_symint - func: tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] variants: function, method @@ -1308,21 +1370,24 @@ structured_delegate: clamp.out dispatch: QuantizedCPU: clamp_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor variants: function, method structured_delegate: clamp.Tensor_out + tags: pointwise - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method cpp_no_default_args: ['min'] structured_delegate: clamp.out + tags: pointwise - func: clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) variants: function, method structured_delegate: clamp.Tensor_out + tags: pointwise - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1332,6 +1397,7 @@ dispatch: CPU, CUDA: clamp_out MPS: clamp_out_mps + tags: pointwise - func: clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1340,24 +1406,29 @@ dispatch: CPU, CUDA: clamp_Tensor_out MPS: clamp_Tensor_out_mps + tags: pointwise - func: clamp_max(Tensor self, Scalar max) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: clamp_max.out + tags: pointwise - func: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor variants: function, method structured_delegate: clamp_max.Tensor_out + tags: pointwise - func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: clamp_max.out + tags: pointwise - func: clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) variants: function, method structured_delegate: clamp_max.Tensor_out + tags: pointwise - func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1366,6 +1437,7 @@ dispatch: CPU, CUDA: clamp_max_out MPS: clamp_max_out_mps + tags: pointwise - func: clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1374,24 +1446,29 @@ dispatch: CPU, CUDA: clamp_max_Tensor_out MPS: clamp_max_Tensor_out_mps + tags: pointwise - func: clamp_min(Tensor self, Scalar min) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: clamp_min.out + tags: pointwise - func: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor variants: function, method structured_delegate: clamp_min.Tensor_out + tags: pointwise - func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: clamp_min.out + tags: pointwise - func: clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) variants: function, method structured_delegate: clamp_min.Tensor_out + tags: pointwise - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1400,6 +1477,7 @@ dispatch: CPU, CUDA: clamp_min_out MPS: clamp_min_out_mps + tags: pointwise - func: clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1408,24 +1486,30 @@ dispatch: CPU, CUDA: clamp_min_Tensor_out MPS: clamp_min_Tensor_out_mps + tags: pointwise # clip is an alias for clamp - func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor cpp_no_default_args: ['min'] variants: function, method + tags: pointwise - func: clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor variants: function, method + tags: pointwise - func: clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) cpp_no_default_args: ['min'] variants: function, method + tags: pointwise - func: clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) variants: function, method + tags: pointwise - func: clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) cpp_no_default_args: ['min'] + tags: pointwise - func: clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) @@ -1463,13 +1547,13 @@ variants: method manual_cpp_binding: True -- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor +- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor dispatch: CompositeExplicitAutograd: convolution autogen: convolution.out tags: canonical -- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) +- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: CompositeExplicitAutograd, CUDA: convolution_backward autogen: convolution_backward.out @@ -1485,7 +1569,7 @@ CompositeExplicitAutograd: convolution_backward_overrideable autogen: convolution_backward_overrideable.out -- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor +- func: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor dispatch: CompositeExplicitAutograd: _convolution autogen: _convolution.out @@ -1494,7 +1578,7 @@ - func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] dilation, int groups) -> Tensor -- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) +- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] dilation=1, int groups=1) -> Tensor @@ -1527,6 +1611,8 @@ - func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: copy - func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) variants: method @@ -1537,6 +1623,7 @@ SparseCPU, SparseCUDA: copy_sparse_wrapper_ CompositeExplicitAutograd: copy_ SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_ + NestedTensorCPU, NestedTensorCUDA: copy_nested_ autogen: copy.out - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor @@ -1555,11 +1642,13 @@ device_check: NoCheck # TensorIterator variants: function, method structured_delegate: cos.out + tags: pointwise - func: cos_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: cos.out + tags: pointwise - func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1568,16 +1657,19 @@ dispatch: CPU, CUDA: cos_out MPS: cos_out_mps + tags: pointwise - func: cosh(Tensor self) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: cosh.out + tags: pointwise - func: cosh_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: cosh.out + tags: pointwise - func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1586,6 +1678,7 @@ dispatch: CPU, CUDA: cosh_out MPS: cosh_out_mps + tags: pointwise - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor @@ -1769,6 +1862,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: cumsum_out + MPS: cumsum_out_mps - func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor device_check: NoCheck # TensorIterator @@ -1815,7 +1909,7 @@ - func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor variants: function, method dispatch: - CompositeExplicitAutograd: diag_embed + CompositeExplicitAutogradNonFunctional: diag_embed autogen: diag_embed.out - func: diagflat(Tensor self, int offset=0) -> Tensor @@ -1878,7 +1972,8 @@ dispatch: SparseCPU, SparseCUDA: div_sparse ZeroTensor: div_zerotensor - tags: canonical + NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor + tags: [canonical, pointwise] - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1886,6 +1981,7 @@ structured_delegate: div.out dispatch: SparseCPU, SparseCUDA: div_sparse_ + tags: pointwise - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1895,6 +1991,7 @@ CPU, CUDA: div_out MPS: div_out_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim + tags: pointwise - func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor device_check: NoCheck # TensorIterator @@ -1902,6 +1999,7 @@ structured_delegate: div.out_mode dispatch: SparseCPU, SparseCUDA: div_sparse + tags: pointwise - func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1909,6 +2007,7 @@ structured_delegate: div.out_mode dispatch: SparseCPU, SparseCUDA: div_sparse_ + tags: pointwise - func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1918,6 +2017,7 @@ CPU, CUDA: div_out_mode MPS: div_out_mode_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim + tags: pointwise # For C++ only, until we have conversion from C++ numbers to Tensor - func: div.Scalar(Tensor self, Scalar other) -> Tensor @@ -1925,7 +2025,8 @@ variants: function, method dispatch: CompositeExplicitAutograd: div - tags: canonical + NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar + tags: [canonical, pointwise] - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -1933,17 +2034,20 @@ dispatch: CompositeExplicitAutograd: div_ autogen: div.Scalar_out + tags: pointwise - func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor variants: function, method dispatch: CompositeExplicitAutograd: div + tags: pointwise - func: div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) variants: method dispatch: CompositeExplicitAutograd: div_ autogen: div.Scalar_mode_out + tags: pointwise # divide, alias for div - func: divide.Tensor(Tensor self, Tensor other) -> Tensor @@ -1978,6 +2082,7 @@ - func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + tags: pointwise - func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2017,17 +2122,17 @@ - func: einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor -- func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor +- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor dispatch: - CompositeExplicitAutograd: embedding + CompositeExplicitAutograd: embedding_symint NestedTensorCPU, NestedTensorCUDA: NestedTensor_embedding autogen: embedding.out -- func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor +- func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor dispatch: CompositeImplicitAutograd: embedding_backward_symint -- func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor +- func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor dispatch: CPU: embedding_dense_backward_cpu CUDA: embedding_dense_backward_cuda @@ -2234,7 +2339,7 @@ dispatch: SparseCPU, SparseCUDA: erf_sparse SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr - tags: canonical + tags: [canonical, pointwise] - func: erf_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2243,6 +2348,7 @@ dispatch: SparseCPU, SparseCUDA: erf_sparse_ SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr_ + tags: pointwise - func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2253,16 +2359,19 @@ MPS: erf_out_mps SparseCPU, SparseCUDA: erf_sparse_out SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr_out + tags: pointwise - func: erfc(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: erfc.out variants: function, method + tags: pointwise - func: erfc_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: erfc.out variants: function, method + tags: pointwise - func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2270,17 +2379,19 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: erfc_out + tags: pointwise - func: exp(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: exp.out variants: function, method - tags: canonical + tags: [canonical, pointwise] - func: exp_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: exp.out variants: function, method + tags: pointwise - func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2289,14 +2400,17 @@ dispatch: CPU, CUDA: exp_out MPS: exp_out_mps + tags: pointwise - func: exp2(Tensor self) -> Tensor structured_delegate: exp2.out variants: function, method + tags: pointwise - func: exp2_(Tensor(a!) self) -> Tensor(a!) structured_delegate: exp2.out variants: function, method + tags: pointwise - func: exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -2304,6 +2418,7 @@ dispatch: CPU, CUDA: exp2_out MPS: exp2_out_mps + tags: pointwise - func: expm1(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -2312,6 +2427,7 @@ dispatch: SparseCPU, SparseCUDA: expm1_sparse SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr + tags: pointwise - func: expm1_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2320,6 +2436,7 @@ dispatch: SparseCPU, SparseCUDA: expm1_sparse_ SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_ + tags: pointwise - func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2327,8 +2444,10 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: expm1_out + MPS: expm1_out_mps SparseCPU, SparseCUDA: expm1_sparse_out SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out + tags: pointwise - func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. @@ -2402,6 +2521,7 @@ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: fill_nested_ autogen: fill.Scalar_out - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) @@ -2412,6 +2532,7 @@ MPS: fill_tensor_mps_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ + NestedTensorCPU, NestedTensorCUDA: fill_nested_ autogen: fill.Tensor_out - func: floor(Tensor self) -> Tensor @@ -2421,6 +2542,7 @@ dispatch: SparseCPU, SparseCUDA: floor_sparse SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr + tags: pointwise - func: floor_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2429,6 +2551,7 @@ dispatch: SparseCPU, SparseCUDA: floor_sparse_ SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_ + tags: pointwise - func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2439,12 +2562,14 @@ MPS: floor_out_mps SparseCPU, SparseCUDA: floor_sparse_out SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_out + tags: pointwise - func: floor_divide(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CPU, CUDA: floor_divide + MPS: floor_divide_mps SparseCPU, SparseCUDA: floor_divide_sparse - func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) @@ -2452,12 +2577,14 @@ variants: method dispatch: CPU, CUDA: floor_divide_ + MPS: floor_divide_mps_ SparseCPU, SparseCUDA: floor_divide_sparse_ - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: floor_divide_out + MPS: floor_divide_out_mps SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim - func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor @@ -2472,11 +2599,19 @@ device_check: NoCheck # TensorIterator structured_delegate: frac.out variants: function, method + dispatch: + SparseCPU, SparseCUDA: frac_sparse + SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr + tags: pointwise - func: frac_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: frac.out variants: function, method + dispatch: + SparseCPU, SparseCUDA: frac_sparse_ + SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr_ + tags: pointwise - func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -2485,6 +2620,9 @@ dispatch: CPU, CUDA: frac_out MPS: frac_out_mps + SparseCPU, SparseCUDA: frac_sparse_out + SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr_out + tags: pointwise - func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor device_check: NoCheck @@ -2518,10 +2656,12 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: gcd_out + tags: pointwise - func: gcd(Tensor self, Tensor other) -> Tensor structured_delegate: gcd.out variants: function, method + tags: pointwise - func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: gcd.out @@ -2532,10 +2672,12 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: lcm_out + tags: pointwise - func: lcm(Tensor self, Tensor other) -> Tensor structured_delegate: lcm.out variants: function, method + tags: pointwise - func: lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: lcm.out @@ -2564,6 +2706,7 @@ dispatch: CPU, QuantizedCPU: grid_sampler_2d_cpu CUDA: grid_sampler_2d_cuda + MPS: grid_sampler_2d_mps autogen: grid_sampler_2d.out tags: canonical @@ -2830,6 +2973,7 @@ SparseCPU, SparseCUDA: isnan_sparse SparseCsrCPU, SparseCsrCUDA: isnan_sparse_csr autogen: isnan.out + tags: pointwise - func: is_distributed(Tensor self) -> bool variants: function, method @@ -2913,7 +3057,9 @@ - func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) -- func: layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor +- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + dispatch: + CompositeImplicitAutograd: layer_norm_symint - func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) dispatch: @@ -2938,17 +3084,21 @@ dispatch: CompositeExplicitAutograd: nan_to_num SparseCPU, SparseCUDA: nan_to_num_sparse + tags: pointwise - func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: nan_to_num_ SparseCPU, SparseCUDA: nan_to_num_sparse_ + tags: pointwise - func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: nan_to_num_out + MPS: nan_to_num_out_mps SparseCPU, SparseCUDA: nan_to_num_sparse_out + tags: pointwise - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor python_module: nn @@ -3010,8 +3160,10 @@ - func: ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) variants: function, method + tags: pointwise - func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise - func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -3027,12 +3179,13 @@ device_check: NoCheck # TensorIterator structured_delegate: log.out variants: function, method - tags: canonical + tags: [canonical, pointwise] - func: log_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: log.out variants: function, method + tags: pointwise - func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3041,16 +3194,19 @@ dispatch: CPU, CUDA: log_out MPS: log_out_mps + tags: pointwise - func: log10(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: log10.out variants: function, method + tags: pointwise - func: log10_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: log10.out variants: function, method + tags: pointwise - func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3059,6 +3215,7 @@ dispatch: CPU, CUDA: log10_out MPS: log10_out_mps + tags: pointwise - func: log1p(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -3067,6 +3224,7 @@ dispatch: SparseCPU, SparseCUDA: log1p_sparse SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr + tags: pointwise - func: log1p_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3075,6 +3233,7 @@ dispatch: SparseCPU, SparseCUDA: log1p_sparse_ SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr_ + tags: pointwise - func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3085,16 +3244,19 @@ MPS: log1p_out_mps SparseCPU, SparseCUDA: log1p_sparse_out SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr_out + tags: pointwise - func: log2(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: log2.out variants: function, method + tags: pointwise - func: log2_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: log2.out variants: function, method + tags: pointwise - func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3103,6 +3265,7 @@ dispatch: CPU, CUDA: log2_out MPS: log2_out_mps + tags: pointwise - func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -3110,10 +3273,12 @@ dispatch: CPU, CUDA: logaddexp_out MPS: logaddexp_out_mps + tags: pointwise - func: logaddexp(Tensor self, Tensor other) -> Tensor variants: method, function structured_delegate: logaddexp.out + tags: pointwise - func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -3121,33 +3286,39 @@ dispatch: CPU, CUDA: logaddexp2_out MPS: logaddexp2_out_mps + tags: pointwise - func: logaddexp2(Tensor self, Tensor other) -> Tensor variants: method, function structured_delegate: logaddexp2.out + tags: pointwise - func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: xlogy.OutTensor variants: function, method + tags: pointwise - func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: xlogy + tags: pointwise - func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: xlogy + tags: pointwise # xlogy: inplace variant - func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method structured_delegate: xlogy.OutTensor + tags: pointwise - func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3163,18 +3334,21 @@ variants: function dispatch: CPU, CUDA: xlogy_out + tags: pointwise - func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: xlogy_out + tags: pointwise - func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: xlogy_out + tags: pointwise - func: logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -3469,6 +3643,7 @@ dispatch: CPU: median_cpu CUDA: median_cuda + MPS: median_mps autogen: median.out - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3480,6 +3655,7 @@ dispatch: CPU: median_out_cpu CUDA: median_out_cuda + MPS: median_out_mps - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -3556,7 +3732,7 @@ MPS: mps_convolution_backward autogen: mps_convolution_backward.out -- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor +- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups) -> Tensor dispatch: CompositeExplicitAutograd: mkldnn_convolution autogen: mkldnn_convolution.out @@ -3571,17 +3747,17 @@ CUDA: miopen_batch_norm_backward autogen: miopen_batch_norm_backward.out -- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: miopen_convolution autogen: miopen_convolution.out -- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: miopen_convolution_transpose autogen: miopen_convolution_transpose.out -- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: miopen_depthwise_convolution autogen: miopen_depthwise_convolution.out @@ -3660,7 +3836,7 @@ MkldnnCPU: mkldnn_mul ZeroTensor: mul_zerotensor NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor - tags: canonical + tags: [canonical, pointwise] - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3671,6 +3847,7 @@ SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr_ MkldnnCPU: mkldnn_mul_ NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Tensor + tags: pointwise - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3683,8 +3860,9 @@ SparseCUDA: mul_out_sparse_cuda SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr MkldnnCPU: mkldnn_mul_out - + tags: pointwise # For C++ only, until we have conversion from C++ numbers to Tensor + - func: mul.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method @@ -3692,7 +3870,7 @@ CompositeExplicitAutograd: mul SparseCsrCPU, SparseCsrCUDA: mul_scalar_sparse_csr NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Scalar - tags: canonical + tags: [canonical, pointwise] - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -3702,8 +3880,9 @@ SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Scalar autogen: mul.Scalar_out - + tags: pointwise # multiply, alias for mul + - func: multiply.Tensor(Tensor self, Tensor other) -> Tensor variants: function, method @@ -3731,25 +3910,28 @@ - func: mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: mvlgamma_out + tags: pointwise - func: mvlgamma(Tensor self, int p) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: mvlgamma + tags: pointwise - func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: mvlgamma_ + tags: pointwise - func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor variants: function, method dispatch: CPU: narrow_copy_dense_cpu SparseCPU, SparseCUDA: narrow_copy_sparse - CompositeExplicitAutogradNonFunctional: narrow_copy_dense + CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint tags: view_copy - func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) @@ -3782,6 +3964,36 @@ dispatch: CUDA: batch_norm_cuda_out MPS: batch_norm_mps_out + CPU: batch_norm_cpu_out + +# TODO: In 2 weeks, we should make native_batch_norm composite implicit so that this correct schema percolates correctly through our dispatching +- func: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: _batch_norm_legit_cpu + CUDA: _batch_norm_legit_cuda + MPS: _batch_norm_legit_mps + MkldnnCPU: _mkldnn_batch_norm_legit + autogen: _native_batch_norm_legit_functional + +- func: _native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + dispatch: + CPU: _batch_norm_legit_cpu_out + CUDA: _batch_norm_legit_cuda_out + MPS: _batch_norm_legit_mps_out + +- func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + dispatch: + CPU: _batch_norm_legit_no_stats_cpu + CUDA: _batch_norm_legit_no_stats_cuda + MPS: _batch_norm_legit_no_stats_mps + MkldnnCPU: _mkldnn_batch_norm_legit_no_stats + tags: canonical + +- func: _native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + dispatch: + CPU: _batch_norm_legit_no_stats_cpu_out + CUDA: _batch_norm_legit_no_stats_cuda_out + MPS: _batch_norm_legit_no_stats_mps_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) dispatch: @@ -3835,7 +4047,7 @@ - func: _nnpack_available() -> bool -- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor +- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, int[2] stride=1) -> Tensor variants: function dispatch: CompositeExplicitAutograd: _nnpack_spatial_convolution @@ -3861,6 +4073,7 @@ # NB: Although this composite mutates on the inside, it is # non-differentiable so NonFunctional doesn't apply CompositeExplicitAutograd: ones_like + NestedTensorCPU, NestedTensorCUDA: ones_like autogen: ones_like.out - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor @@ -3875,6 +4088,7 @@ - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor dispatch: CPU, CUDA: _cdist_forward + MPS: _cdist_forward_mps autogen: _cdist_forward.out - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor @@ -3995,32 +4209,44 @@ variants: function, method dispatch: CompositeExplicitAutograd: rad2deg + SparseCPU, SparseCUDA: rad2deg_sparse SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr - func: rad2deg_(Tensor(a!) self) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: rad2deg_ + SparseCPU, SparseCUDA: rad2deg_sparse_ SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_ - func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: rad2deg_out + SparseCPU, SparseCUDA: rad2deg_sparse_out SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_out - func: deg2rad(Tensor self) -> Tensor variants: function, method dispatch: CompositeExplicitAutograd: deg2rad + SparseCPU, SparseCUDA: deg2rad_sparse + SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr + tags: pointwise - func: deg2rad_(Tensor(a!) self) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: deg2rad_ + SparseCPU, SparseCUDA: deg2rad_sparse_ + SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr_ + tags: pointwise - func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: deg2rad_out + SparseCPU, SparseCUDA: deg2rad_sparse_out + SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr_out + tags: pointwise - func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -4186,6 +4412,7 @@ dispatch: CPU: randperm_out_cpu CUDA: randperm_out_cuda + MPS: randperm_out_mps - func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -4212,12 +4439,13 @@ device_check: NoCheck # TensorIterator structured_delegate: reciprocal.out variants: function, method - tags: canonical + tags: [canonical, pointwise] - func: reciprocal_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: reciprocal.out variants: function, method + tags: pointwise - func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4226,6 +4454,7 @@ dispatch: CPU, CUDA: reciprocal_out MPS: reciprocal_out_mps + tags: pointwise - func: neg(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -4234,7 +4463,8 @@ dispatch: SparseCPU, SparseCUDA: neg_sparse SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr - tags: canonical + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg + tags: [canonical, pointwise] - func: neg_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4243,6 +4473,8 @@ dispatch: SparseCPU, SparseCUDA: neg_sparse_ SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_ + NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_ + tags: pointwise - func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4253,8 +4485,9 @@ MPS: neg_out_mps SparseCPU, SparseCUDA: neg_out_sparse SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_out - + tags: pointwise # Alias for neg + - func: negative(Tensor self) -> Tensor variants: function, method @@ -4282,8 +4515,10 @@ - func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, int? output_size=None) -> Tensor variants: function, method -- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor +- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, int? output_size=None) -> Tensor variants: function, method + dispatch: + CompositeImplicitAutograd: repeat_interleave_symint - func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) variants: function, method @@ -4293,6 +4528,11 @@ CompositeImplicitAutograd: reshape_symint CompositeImplicitAutogradNestedTensor: reshape_nested +- func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + variants: function + dispatch: + CompositeExplicitAutograd: _reshape_copy_symint + # NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape. # They are not user-facing, hence the leading underscore. Please don't use it # anywhere else. @@ -4326,6 +4566,7 @@ dispatch: SparseCPU, SparseCUDA: round_sparse SparseCsrCPU, SparseCsrCUDA: round_sparse_csr + tags: pointwise - func: round_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4334,6 +4575,7 @@ dispatch: SparseCPU, SparseCUDA: round_sparse_ SparseCsrCPU, SparseCsrCUDA: round_sparse_csr_ + tags: pointwise - func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4345,16 +4587,19 @@ MPS: round_out_mps SparseCPU, SparseCUDA: round_sparse_out SparseCsrCPU, SparseCsrCUDA: round_sparse_csr_out + tags: pointwise - func: round.decimals(Tensor self, *, int decimals) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: round.decimals_out variants: function, method + tags: pointwise - func: round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: round.decimals_out variants: function, method + tags: pointwise - func: round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4363,6 +4608,7 @@ dispatch: CPU: round_decimals_out CUDA: round_decimals_out + tags: pointwise - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor device_check: NoCheck # TensorIterator @@ -4384,7 +4630,7 @@ NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu SparseCPU, SparseCUDA: relu_sparse SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr - tags: canonical + tags: [canonical, pointwise] - func: relu_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4399,6 +4645,7 @@ SparseCPU, SparseCUDA: relu_sparse_ SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr_ autogen: relu.out + tags: pointwise - func: relu6(Tensor self) -> Tensor python_module: nn @@ -4451,7 +4698,7 @@ QuantizedCPU: gelu_quantized_cpu QuantizedCUDA: gelu_quantized_cuda NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu - tags: canonical + tags: [canonical, pointwise] - func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) structured: True @@ -4467,6 +4714,7 @@ python_module: nn dispatch: MkldnnCPU: mkldnn_gelu_backward + tags: pointwise - func: infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor variants: function @@ -4500,12 +4748,13 @@ device_check: NoCheck # TensorIterator structured_delegate: rsqrt.out variants: function, method - tags: canonical + tags: [canonical, pointwise] - func: rsqrt_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: rsqrt.out variants: function, method + tags: pointwise - func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4514,35 +4763,36 @@ dispatch: CPU, CUDA: rsqrt_out MPS: rsqrt_out_mps + tags: pointwise - func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) variants: function, method device_check: NoCheck device_guard: False -- func: select.int(Tensor(a) self, int dim, int index) -> Tensor(a) +- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) variants: function, method device_check: NoCheck device_guard: False dispatch: - CompositeExplicitAutograd: select + CompositeExplicitAutograd: select_symint SparseCsrCPU, SparseCsrCUDA: select_sparse_csr NestedTensorCPU, NestedTensorCUDA: select_nested -- func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, int index) -> Tensor +- func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor variants: function device_check: NoCheck device_guard: False dispatch: - CompositeExplicitAutogradNonFunctional: select_backward + CompositeExplicitAutogradNonFunctional: select_backward_symint autogen: select_backward.out -- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, int index) -> Tensor +- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor variants: function device_check: NoCheck device_guard: False dispatch: - NestedTensorCPU, NestedTensorCUDA: _nested_select_backward + NestedTensorCPU, NestedTensorCUDA: _nested_select_backward_symint - func: selu(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -4619,7 +4869,7 @@ dispatch: QuantizedCPU: sigmoid_quantized_cpu MkldnnCPU: mkldnn_sigmoid - tags: canonical + tags: [canonical, pointwise] - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4627,6 +4877,7 @@ variants: function, method dispatch: MkldnnCPU: mkldnn_sigmoid_ + tags: pointwise - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4635,20 +4886,24 @@ dispatch: CPU, CUDA: sigmoid_out MPS: sigmoid_out_mps + tags: pointwise - func: logit(Tensor self, float? eps=None) -> Tensor variants: function, method dispatch: CPU, CUDA: logit + tags: pointwise - func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) variants: function, method dispatch: CPU, CUDA: logit_ + tags: pointwise - func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: logit_out + tags: pointwise - func: sin(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -4657,6 +4912,7 @@ dispatch: SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr SparseCPU, SparseCUDA: sin_sparse + tags: pointwise - func: sin_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4665,6 +4921,7 @@ dispatch: SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_ SparseCPU, SparseCUDA: sin_sparse_ + tags: pointwise - func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4675,20 +4932,24 @@ MPS: sin_out_mps SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_out SparseCPU, SparseCUDA: sin_sparse_out + tags: pointwise - func: sinc(Tensor self) -> Tensor structured_delegate: sinc.out variants: function, method + tags: pointwise - func: sinc_(Tensor(a!) self) -> Tensor(a!) structured_delegate: sinc.out variants: function, method + tags: pointwise - func: sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: sinc_out + tags: pointwise - func: sinh(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -4697,6 +4958,7 @@ dispatch: SparseCPU, SparseCUDA: sinh_sparse SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr + tags: pointwise - func: sinh_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4705,6 +4967,7 @@ dispatch: SparseCPU, SparseCUDA: sinh_sparse_ SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr_ + tags: pointwise - func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -4727,6 +4990,7 @@ # to false to make such changes explicitly illegal, in order to prevent users from # changing metadata of the detached tensor and expecting the original tensor to also # be updated. + tags: pointwise - func: detach(Tensor(a) self) -> Tensor(a) variants: function, method dispatch: @@ -4781,12 +5045,12 @@ autogen: slice_scatter.out tags: canonical -- func: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor +- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor variants: function, method device_check: NoCheck device_guard: False dispatch: - CompositeExplicitAutograd: select_scatter + CompositeExplicitAutograd: select_scatter_symint autogen: select_scatter.out - func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor @@ -5065,7 +5329,7 @@ dispatch: SparseCPU, SparseCUDA: sqrt_sparse SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr - tags: canonical + tags: [canonical, pointwise] - func: sqrt_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5074,6 +5338,7 @@ dispatch: SparseCPU, SparseCUDA: sqrt_sparse_ SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_ + tags: pointwise - func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5084,26 +5349,32 @@ MPS: sqrt_out_mps SparseCPU, SparseCUDA: sqrt_sparse_out SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_out + tags: pointwise - func: square(Tensor self) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + tags: pointwise - func: square_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function, method + tags: pointwise - func: square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise - func: std(Tensor self, bool unbiased=True) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + cpp_no_default_args: ["unbiased"] - func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + cpp_no_default_args: ["unbiased"] -- func: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor +- func: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: @@ -5114,12 +5385,14 @@ - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] - func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] -- func: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- func: std_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function dispatch: @@ -5129,15 +5402,17 @@ - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] -- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function - func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] -- func: std.correction_out(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) +- func: std.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: std_out @@ -5146,15 +5421,17 @@ - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + cpp_no_default_args: ["unbiased"] - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] -- func: std.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> Tensor +- func: std.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method -- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) +- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function @@ -5207,6 +5484,7 @@ dispatch: SparseCPU, SparseCUDA: tan_sparse SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr + tags: pointwise - func: tan_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5215,6 +5493,7 @@ dispatch: SparseCPU, SparseCUDA: tan_sparse_ SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr_ + tags: pointwise - func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5225,6 +5504,7 @@ MPS: tan_out_mps SparseCPU, SparseCUDA: tan_sparse_out SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr_out + tags: pointwise - func: tanh(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -5236,7 +5516,7 @@ SparseCPU, SparseCUDA: tanh_sparse SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh - tags: canonical + tags: [canonical, pointwise] - func: tanh_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5247,6 +5527,7 @@ SparseCPU, SparseCUDA: tanh_sparse_ SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr_ NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh_ + tags: pointwise - func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5257,6 +5538,7 @@ MPS: tanh_out_mps SparseCPU, SparseCUDA: tanh_sparse_out SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr_out + tags: pointwise - func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor variants: function @@ -5303,6 +5585,7 @@ MkldnnCPU: mkldnn_relu_backward SparseCPU, SparseCUDA: threshold_backward_sparse SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed + tags: pointwise - func: tile(Tensor self, int[] dims) -> Tensor variants: function, method @@ -5461,6 +5744,7 @@ dispatch: SparseCPU, SparseCUDA: trunc_sparse SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr + tags: pointwise - func: trunc_(Tensor(a!) self) -> Tensor(a!) structured_delegate: trunc.out @@ -5469,6 +5753,7 @@ dispatch: SparseCPU, SparseCUDA: trunc_sparse_ SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_ + tags: pointwise - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -5479,8 +5764,9 @@ MPS: trunc_out_mps SparseCPU, SparseCUDA: trunc_sparse_out SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_out - + tags: pointwise # Alias for trunc + - func: fix(Tensor self) -> Tensor variants: function, method @@ -5515,6 +5801,7 @@ dispatch: CPU: unique_consecutive_cpu CUDA: unique_consecutive_cuda + MPS: unique_consecutive_mps tags: dynamic_output_shape autogen: unique_consecutive.out @@ -5523,6 +5810,7 @@ dispatch: CPU: unique_dim_consecutive_cpu CUDA: unique_dim_consecutive_cuda + MPS: unique_dim_consecutive_mps tags: dynamic_output_shape autogen: unique_dim_consecutive.out @@ -5535,6 +5823,7 @@ dispatch: CPU: _unique2_cpu CUDA: _unique2_cuda + MPS: _unique2_mps tags: dynamic_output_shape autogen: _unique2.out @@ -5567,13 +5856,15 @@ - func: var(Tensor self, bool unbiased=True) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + cpp_no_default_args: ["unbiased"] - func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method tags: canonical + cpp_no_default_args: ["unbiased"] -- func: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor +- func: var.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: @@ -5582,8 +5873,9 @@ - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] -- func: var.correction_out(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) +- func: var.correction_out(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: var_out @@ -5591,27 +5883,31 @@ - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method + cpp_no_default_args: ["unbiased"] - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + cpp_no_default_args: ["unbiased"] -- func: var.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> Tensor +- func: var.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator variants: function, method -- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) +- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] - func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] -- func: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- func: var_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function dispatch: @@ -5621,8 +5917,9 @@ - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function + cpp_no_default_args: ["unbiased"] -- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) device_check: NoCheck # TensorIterator variants: function @@ -5637,7 +5934,7 @@ dispatch: CPU, CUDA: where MPS: where_mps - tags: canonical + tags: [canonical, pointwise] - func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -5694,6 +5991,7 @@ dispatch: CPU: _efficientzerotensor CUDA: _efficientzerotensor_cuda + Meta: _efficientzerotensor_meta autogen: _efficientzerotensor.out - func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -5881,6 +6179,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: norm_dtype_out + MPS: norm_dtype_out_mps - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -5908,12 +6207,14 @@ variants: method, function dispatch: CompositeExplicitAutograd: frexp + tags: pointwise - func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) dispatch: CPU, CUDA: frexp_out - + tags: pointwise # Deprecated (v.1.12) + - func: frobenius_norm(Tensor self) -> Tensor variants: function @@ -5955,6 +6256,7 @@ - func: positive(Tensor(a) self) -> Tensor(a) variants: function, method + tags: pointwise - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) use_const_ref_for_mutable_tensors: True @@ -5991,6 +6293,7 @@ CPU, CUDA: sub_out MPS: sub_out_mps SparseCPU, SparseCUDA: sub_out_sparse + tags: pointwise - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor device_check: NoCheck # TensorIterator @@ -5999,7 +6302,7 @@ dispatch: SparseCPU, SparseCUDA: sub_sparse ZeroTensor: sub_zerotensor - tags: canonical + tags: [canonical, pointwise] - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6007,14 +6310,15 @@ structured_delegate: sub.out dispatch: SparseCPU, SparseCUDA: sub_sparse_ - + tags: pointwise # For C++ only, until we have conversion from C++ numbers to Tensor + - func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor device_check: NoCheck # TensorIterator variants: function, method dispatch: CompositeExplicitAutograd: sub - tags: canonical + tags: [canonical, pointwise] - func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6022,8 +6326,9 @@ dispatch: CompositeExplicitAutograd: sub_ autogen: sub.Scalar_out - + tags: pointwise # subtract, alias for sub + - func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor @@ -6052,11 +6357,13 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: heaviside_out + tags: pointwise - func: heaviside(Tensor self, Tensor values) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: heaviside.out + tags: pointwise - func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -6073,6 +6380,7 @@ # Functionally the same as addmm, but we give it a different derivative formula # that doesn't propagate gradients to non-present entries on sparse. + tags: pointwise - func: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor python_module: sparse dispatch: @@ -6285,9 +6593,9 @@ SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse autogen: _sparse_coo_tensor_with_dims.out -- func: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor dispatch: - SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse + SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint autogen: _sparse_coo_tensor_with_dims_and_tensors.out - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) @@ -6495,10 +6803,11 @@ SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse autogen: to_sparse.sparse_dim_out -- func: to_sparse(Tensor self) -> Tensor +- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None) -> Tensor variants: method dispatch: CPU, CUDA: dense_to_sparse + SparseCPU, SparseCUDA: sparse_coo_to_sparse SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse autogen: to_sparse.out @@ -6540,7 +6849,7 @@ CPU: dense_to_mkldnn autogen: to_mkldnn.out -- func: mkldnn_reorder_conv2d_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1) -> Tensor +- func: mkldnn_reorder_conv2d_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor variants: function python_module: nn dispatch: @@ -6985,7 +7294,7 @@ - func: lift_fresh_copy(Tensor self) -> Tensor tags: view_copy dispatch: - CompositeExplicitAutograd: lift_fresh_copy + CompositeExplicitAutogradNonFunctional: lift_fresh_copy autogen: lift_fresh_copy.out - func: is_set_to(Tensor self, Tensor tensor) -> bool @@ -7011,6 +7320,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: masked_fill + tags: pointwise - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7079,7 +7389,7 @@ - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method dispatch: - CPU, CUDA, MPS: put_ + CPU, CUDA: put_ autogen: put.out - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor @@ -7284,18 +7594,21 @@ variants: function dispatch: CPU, CUDA: bitwise_and_out + tags: pointwise - func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: bitwise_and_out + tags: pointwise - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CompositeExplicitAutograd: bitwise_and + tags: pointwise - func: bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -7303,21 +7616,24 @@ dispatch: CompositeExplicitAutograd: bitwise_and autogen: bitwise_and.Scalar_Tensor_out + tags: pointwise - func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function structured_delegate: bitwise_and.Tensor_out - tags: canonical + tags: [canonical, pointwise] - func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: bitwise_and.Tensor_out + tags: pointwise - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -7342,16 +7658,19 @@ variants: function dispatch: CPU, CUDA: bitwise_or_out + tags: pointwise - func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: bitwise_or_out + tags: pointwise - func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -7359,21 +7678,24 @@ dispatch: CompositeExplicitAutograd: bitwise_or autogen: bitwise_or.Scalar_Tensor_out + tags: pointwise - func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function structured_delegate: bitwise_or.Tensor_out - tags: canonical + tags: [canonical, pointwise] - func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: bitwise_or.Tensor_out + tags: pointwise - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -7398,16 +7720,19 @@ variants: function dispatch: CPU, CUDA: bitwise_xor_out + tags: pointwise - func: bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: bitwise_xor_out + tags: pointwise - func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -7415,48 +7740,58 @@ dispatch: CompositeExplicitAutograd: bitwise_xor autogen: bitwise_xor.Scalar_Tensor_out + tags: pointwise - func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function structured_delegate: bitwise_xor.Tensor_out + tags: pointwise - func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: bitwise_xor.Tensor_out + tags: pointwise - func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CPU, CUDA: __lshift__ + tags: pointwise - func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CPU, CUDA: __lshift__ + tags: pointwise - func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7464,6 +7799,7 @@ dispatch: CPU, CUDA: __ilshift__ autogen: __lshift__.Scalar_out + tags: pointwise - func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7471,16 +7807,19 @@ dispatch: CPU, CUDA: __ilshift__ autogen: __lshift__.Tensor_out + tags: pointwise - func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: function, method structured_delegate: bitwise_left_shift.Tensor_out + tags: pointwise - func: bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: bitwise_left_shift.Tensor_out + tags: pointwise - func: bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7488,24 +7827,28 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: bitwise_left_shift_out + tags: pointwise - func: bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CompositeExplicitAutograd: bitwise_left_shift + tags: pointwise - func: bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: bitwise_left_shift_ + tags: pointwise - func: bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: bitwise_left_shift_out + tags: pointwise - func: bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -7513,18 +7856,21 @@ dispatch: CompositeExplicitAutograd: bitwise_left_shift autogen: bitwise_left_shift.Scalar_Tensor_out + tags: pointwise - func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CPU, CUDA: __rshift__ + tags: pointwise - func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CPU, CUDA: __rshift__ + tags: pointwise - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7544,11 +7890,13 @@ device_check: NoCheck # TensorIterator variants: function, method structured_delegate: bitwise_right_shift.Tensor_out + tags: pointwise - func: bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: bitwise_right_shift.Tensor_out + tags: pointwise - func: bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7556,24 +7904,28 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: bitwise_right_shift_out + tags: pointwise - func: bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CompositeExplicitAutograd: bitwise_right_shift + tags: pointwise - func: bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: bitwise_right_shift_ + tags: pointwise - func: bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: function dispatch: CompositeExplicitAutograd: bitwise_right_shift_out + tags: pointwise - func: bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -7581,6 +7933,7 @@ dispatch: CompositeExplicitAutograd: bitwise_right_shift autogen: bitwise_right_shift.Scalar_Tensor_out + tags: pointwise - func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) structured_delegate: tril.out @@ -7594,16 +7947,19 @@ device_check: NoCheck # TensorIterator structured_delegate: digamma.out variants: method + tags: pointwise - func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: lerp.Scalar_out + tags: pointwise - func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: lerp.Tensor_out + tags: pointwise - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) variants: method @@ -7697,22 +8053,9 @@ autogen: geometric, geometric.out - func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - CPU: diag_cpu_out - CUDA: diag_cuda_out - MPS: diag_mps_out - func: diag(Tensor self, int diagonal=0) -> Tensor variants: method, function - dispatch: - CompositeExplicitAutograd: diag - -- func: diag_backward(Tensor grad, SymInt[] input_sizes, int diagonal) -> Tensor - variants: function - device_check: NoCheck - device_guard: False - dispatch: - CompositeImplicitAutograd: diag_backward_symint - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) @@ -7758,6 +8101,7 @@ dispatch: CPU: trace_cpu CUDA: trace_cuda + MPS: trace_mps_out autogen: trace.out - func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor @@ -7775,6 +8119,7 @@ CPU, CUDA: ne_Scalar_out MPS: ne_scalar_out_mps QuantizedCPU: ne_out_quantized_cpu + tags: pointwise - func: ne.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: ne.Scalar_out @@ -7782,6 +8127,7 @@ variants: method, function dispatch: QuantizedCPU: ne_quantized_cpu + tags: pointwise - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -7791,6 +8137,7 @@ CPU, CUDA: ne_Tensor_out MPS: ne_tensor_out_mps QuantizedCPU: ne_out_quantized_cpu + tags: pointwise - func: ne.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: ne.Tensor_out @@ -7798,6 +8145,7 @@ variants: method, function dispatch: QuantizedCPU: ne_quantized_cpu + tags: pointwise - func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) structured_delegate: ne.Scalar_out @@ -7834,6 +8182,7 @@ CPU, CUDA: eq_Scalar_out MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu + tags: pointwise - func: eq.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: eq.Scalar_out @@ -7841,7 +8190,7 @@ variants: method, function dispatch: QuantizedCPU: eq_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -7851,6 +8200,7 @@ CPU, CUDA: eq_Tensor_out MPS: eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu + tags: pointwise - func: eq.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: eq.Tensor_out @@ -7858,6 +8208,7 @@ variants: method, function dispatch: QuantizedCPU: eq_quantized_cpu + tags: pointwise - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -7867,6 +8218,7 @@ CPU, CUDA: ge_Scalar_out MPS: ge_scalar_out_mps QuantizedCPU: ge_out_quantized_cpu + tags: pointwise - func: ge.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: ge.Scalar_out @@ -7874,7 +8226,7 @@ variants: method, function dispatch: QuantizedCPU: ge_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -7884,6 +8236,7 @@ CPU, CUDA: ge_Tensor_out MPS: ge_tensor_out_mps QuantizedCPU: ge_out_quantized_cpu + tags: pointwise - func: ge.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: ge.Tensor_out @@ -7891,6 +8244,7 @@ variants: method, function dispatch: QuantizedCPU: ge_quantized_cpu + tags: pointwise - func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) structured_delegate: ge.Scalar_out @@ -7927,6 +8281,7 @@ CPU, CUDA: le_Scalar_out MPS: le_scalar_out_mps QuantizedCPU: le_out_quantized_cpu + tags: pointwise - func: le.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: le.Scalar_out @@ -7934,7 +8289,7 @@ variants: method, function dispatch: QuantizedCPU: le_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -7944,6 +8299,7 @@ CPU, CUDA: le_Tensor_out MPS: le_tensor_out_mps QuantizedCPU: le_out_quantized_cpu + tags: pointwise - func: le.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: le.Tensor_out @@ -7951,6 +8307,7 @@ variants: method, function dispatch: QuantizedCPU: le_quantized_cpu + tags: pointwise - func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) structured_delegate: le.Scalar_out @@ -7987,6 +8344,7 @@ CPU, CUDA: gt_Scalar_out MPS: gt_scalar_out_mps QuantizedCPU: gt_out_quantized_cpu + tags: pointwise - func: gt.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: gt.Scalar_out @@ -7994,7 +8352,7 @@ variants: method, function dispatch: QuantizedCPU: gt_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8004,6 +8362,7 @@ CPU, CUDA: gt_Tensor_out MPS: gt_tensor_out_mps QuantizedCPU: gt_out_quantized_cpu + tags: pointwise - func: gt.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: gt.Tensor_out @@ -8011,6 +8370,7 @@ variants: method, function dispatch: QuantizedCPU: gt_quantized_cpu + tags: pointwise - func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) structured_delegate: gt.Scalar_out @@ -8047,6 +8407,7 @@ CPU, CUDA: lt_Scalar_out MPS: lt_scalar_out_mps QuantizedCPU: lt_out_quantized_cpu + tags: pointwise - func: lt.Scalar(Tensor self, Scalar other) -> Tensor structured_delegate: lt.Scalar_out @@ -8054,7 +8415,7 @@ variants: method, function dispatch: QuantizedCPU: lt_quantized_cpu - tags: canonical + tags: [canonical, pointwise] - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8064,6 +8425,7 @@ CPU, CUDA: lt_Tensor_out MPS: lt_tensor_out_mps QuantizedCPU: lt_out_quantized_cpu + tags: pointwise - func: lt.Tensor(Tensor self, Tensor other) -> Tensor structured_delegate: lt.Tensor_out @@ -8071,6 +8433,7 @@ variants: method, function dispatch: QuantizedCPU: lt_quantized_cpu + tags: pointwise - func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) structured_delegate: lt.Scalar_out @@ -8167,6 +8530,7 @@ dispatch: CPU: nonzero_out_cpu CUDA: nonzero_out_cuda + MPS: nonzero_out_mps tags: dynamic_output_shape - func: nonzero(Tensor self) -> Tensor @@ -8174,7 +8538,8 @@ dispatch: CPU: nonzero_cpu CUDA: nonzero_cuda - tags: dynamic_output_shape, canonical + MPS: nonzero_mps + tags: [dynamic_output_shape, canonical] - func: nonzero_numpy(Tensor self) -> Tensor[] variants: method, function @@ -8213,16 +8578,19 @@ dispatch: CPU, CUDA: addcmul_out MPS: addcmul_out_mps + tags: pointwise - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor structured_delegate: addcmul.out device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) structured_delegate: addcmul.out device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8231,16 +8599,19 @@ dispatch: CPU, CUDA: addcdiv_out MPS: addcdiv_out_mps + tags: pointwise - func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor structured_delegate: addcdiv.out device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) structured_delegate: addcdiv.out device_check: NoCheck # TensorIterator variants: method + tags: pointwise - func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor python_module: nn @@ -8251,6 +8622,7 @@ structured: True dispatch: CPU, CUDA: triangular_solve_out + MPS: triangular_solve_mps_out SparseCsrCPU: triangular_solve_out_sparse_csr_cpu SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda @@ -8266,12 +8638,14 @@ python_module: linalg dispatch: CPU, CUDA: linalg_solve_triangular_out + MPS: linalg_solve_triangular_mps_out - func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor python_module: linalg variants: function dispatch: CPU, CUDA: linalg_solve_triangular + MPS: linalg_solve_triangular_mps - func: linalg_vander(Tensor x, *, int? N=None) -> Tensor python_module: linalg @@ -8423,16 +8797,19 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: lgamma_out + tags: pointwise - func: lgamma_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: lgamma.out variants: method + tags: pointwise - func: lgamma(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: lgamma.out variants: method, function + tags: pointwise - func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8440,11 +8817,13 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: digamma_out + tags: pointwise - func: digamma(Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: digamma.out variants: method, function + tags: pointwise - func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8452,17 +8831,20 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: polygamma_out + tags: pointwise - func: polygamma(int n, Tensor self) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: polygamma.out variants: method, function + tags: pointwise - func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: polygamma_ + tags: pointwise - func: erfinv(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -8471,6 +8853,7 @@ dispatch: SparseCPU, SparseCUDA: erfinv_sparse SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr + tags: pointwise - func: erfinv_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8479,6 +8862,7 @@ dispatch: SparseCPU, SparseCUDA: erfinv_sparse_ SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_ + tags: pointwise - func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8488,20 +8872,24 @@ CPU, CUDA: erfinv_out SparseCPU, SparseCUDA: erfinv_sparse_out SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_out + tags: pointwise - func: i0(Tensor self) -> Tensor structured_delegate: i0.out variants: function, method + tags: pointwise - func: i0_(Tensor(a!) self) -> Tensor(a!) structured_delegate: i0.out variants: function, method + tags: pointwise - func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: i0_out + tags: pointwise - func: sign(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -8510,6 +8898,7 @@ dispatch: SparseCPU, SparseCUDA: sign_sparse SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr + tags: pointwise - func: sign_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8518,6 +8907,7 @@ dispatch: SparseCPU, SparseCUDA: sign_sparse_ SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_ + tags: pointwise - func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8528,6 +8918,7 @@ MPS: sign_out_mps SparseCPU, SparseCUDA: sign_sparse_out SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_out + tags: pointwise - func: signbit(Tensor self) -> Tensor variants: function, method @@ -8535,6 +8926,7 @@ dispatch: SparseCPU, SparseCUDA: signbit_sparse SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr + tags: pointwise - func: signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8542,8 +8934,10 @@ dispatch: CPU: signbit_out CUDA: signbit_out + MPS: signbit_out_mps SparseCPU, SparseCUDA: signbit_sparse_out SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr_out + tags: pointwise - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor device_check: NoCheck # TensorIterator @@ -8559,18 +8953,21 @@ dispatch: CPU, CUDA: atan2_out MPS: atan2_mps_out + tags: pointwise - func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: atan2.out variants: method + tags: pointwise - func: atan2(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: atan2.out variants: method, function - + tags: pointwise # arctan2, alias of atan2 + - func: arctan2(Tensor self, Tensor other) -> Tensor variants: method, function @@ -8586,6 +8983,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: lerp_Scalar + tags: pointwise - func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8593,16 +8991,19 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: lerp_Tensor + tags: pointwise - func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor device_check: NoCheck # TensorIterator variants: method, function structured_delegate: lerp.Scalar_out + tags: pointwise - func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor device_check: NoCheck # TensorIterator variants: method, function structured_delegate: lerp.Tensor_out + tags: pointwise - func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -8658,18 +9059,21 @@ device_check: NoCheck # TensorIterator dispatch: CompositeExplicitAutograd: fmod_out + tags: pointwise - func: fmod.Scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: CompositeExplicitAutograd: fmod + tags: pointwise - func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: CompositeExplicitAutograd: fmod_ + tags: pointwise - func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8677,87 +9081,104 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: fmod_out + tags: pointwise - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: fmod.Tensor_out variants: method, function - + tags: pointwise - func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method structured_delegate: fmod.Tensor_out + tags: pointwise - func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: hypot_out + tags: pointwise - func: hypot(Tensor self, Tensor other) -> Tensor structured_delegate: hypot.out variants: method, function + tags: pointwise - func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: hypot.out variants: method + tags: pointwise - func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: igamma_out + tags: pointwise - func: igamma(Tensor self, Tensor other) -> Tensor structured_delegate: igamma.out variants: method, function + tags: pointwise - func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: igamma.out variants: method + tags: pointwise - func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: igammac_out + tags: pointwise - func: igammac(Tensor self, Tensor other) -> Tensor structured_delegate: igammac.out variants: method, function + tags: pointwise - func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: igammac.out variants: method + tags: pointwise - func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: nextafter_out + tags: pointwise - func: nextafter(Tensor self, Tensor other) -> Tensor structured_delegate: nextafter.out variants: method, function + tags: pointwise - func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) structured_delegate: nextafter.out variants: method + tags: pointwise - func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: remainder_out + tags: pointwise - func: remainder.Scalar(Tensor self, Scalar other) -> Tensor variants: method, function dispatch: CompositeExplicitAutograd: remainder + tags: pointwise - func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) variants: method dispatch: CompositeExplicitAutograd: remainder_ + tags: pointwise - func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8765,16 +9186,20 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: remainder_out + MPS: remainder_out_mps + tags: pointwise - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: remainder.Tensor_out variants: method, function + tags: pointwise - func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: remainder.Tensor_out variants: method + tags: pointwise - func: remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -8782,6 +9207,7 @@ dispatch: CPU, CUDA: remainder autogen: remainder.Scalar_Tensor_out + tags: pointwise - func: min(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -8802,6 +9228,7 @@ structured_delegate: fmin.out device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8809,6 +9236,7 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: fmin_out + tags: pointwise - func: max(Tensor self) -> Tensor device_check: NoCheck # TensorIterator @@ -8822,6 +9250,7 @@ structured_delegate: fmax.out device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8829,12 +9258,13 @@ device_check: NoCheck # TensorIterator dispatch: CPU, CUDA: fmax_out + tags: pointwise - func: maximum(Tensor self, Tensor other) -> Tensor structured_delegate: maximum.out device_check: NoCheck # TensorIterator variants: method, function - tags: canonical + tags: [canonical, pointwise] - func: maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8843,15 +9273,18 @@ dispatch: CPU, CUDA: maximum_out MPS: maximum_out_mps + tags: pointwise # binary max, alias of maximum # NOTE: max is not an alias for maximum, since there is also unary max - func: max.other(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + tags: pointwise - func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -8863,7 +9296,7 @@ structured_delegate: minimum.out device_check: NoCheck # TensorIterator variants: method, function - tags: canonical + tags: [canonical, pointwise] - func: minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -8872,15 +9305,18 @@ dispatch: CPU, CUDA: minimum_out MPS: minimum_out_mps + tags: pointwise # binary min, alias for minimum # NOTE: min is not an alias for minimum, since there is also unary min - func: min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator + tags: pointwise - func: min.other(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function + tags: pointwise - func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor variants: method, function @@ -9013,7 +9449,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta: unfold + CPU, CUDA, Meta, MPS: unfold QuantizedCPU, QuantizedCUDA: unfold - func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor @@ -9023,7 +9459,7 @@ autogen: unfold_backward.out - func: equal(Tensor self, Tensor other) -> bool - tags: data_dependent_output + tags: [data_dependent_output, pointwise] variants: method, function dispatch: CPU: cpu_equal @@ -9038,21 +9474,25 @@ dispatch: CPU, CUDA: pow_Tensor_Tensor_out MPS: pow_tensor_tensor_out_mps + tags: pointwise - func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: pow.Tensor_Tensor_out variants: method, function + tags: pointwise - func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True dispatch: CPU, CUDA: pow_Scalar_out + tags: pointwise - func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor device_check: NoCheck # TensorIterator structured_delegate: pow.Scalar_out + tags: pointwise - func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -9062,6 +9502,7 @@ CPU, CUDA: pow_Tensor_Scalar_out SparseCPU, SparseCUDA: pow_out_sparse_scalar MPS: pow_tensor_scalar_out_mps + tags: pointwise - func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor device_check: NoCheck # TensorIterator @@ -9069,37 +9510,47 @@ variants: function, method dispatch: SparseCPU, SparseCUDA: pow_sparse_scalar - tags: canonical + tags: [canonical, pointwise] - func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: pow.Tensor_Scalar_out variants: method + tags: pointwise - func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: pow.Tensor_Tensor_out variants: method + tags: pointwise - func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise - func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor variants: function, method + tags: pointwise - func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise - func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor + tags: pointwise - func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + tags: pointwise - func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor variants: function, method + tags: pointwise - func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) variants: method + tags: pointwise - func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) variants: method + tags: pointwise - func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -9835,6 +10286,14 @@ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ autogen: _foreach_addcdiv.ScalarList_out +- func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_tensor_slow_ + CUDA: foreach_tensor_addcdiv_tensor_cuda_ + autogen: _foreach_addcdiv.Tensor_out + - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function @@ -9843,6 +10302,14 @@ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ autogen: _foreach_addcmul.ScalarList_out +- func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_addcmul_tensor_slow_ + CUDA: foreach_tensor_addcmul_tensor_cuda_ + autogen: _foreach_addcmul.Tensor_out + - func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function @@ -9864,6 +10331,13 @@ CPU: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda +- func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_addcdiv_tensor_slow + CUDA: foreach_tensor_addcdiv_tensor_cuda + - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function @@ -9871,6 +10345,13 @@ CPU: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda +- func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CPU: foreach_tensor_addcmul_tensor_slow + CUDA: foreach_tensor_addcmul_tensor_cuda + - func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function @@ -9930,17 +10411,6 @@ CPU: searchsorted_cpu CUDA: searchsorted_cuda -# [Note about _torch_cuda_cu_linker_symbol_op and torch_cuda_cu] -# This is a DUMMY function to force the linking against torch_cuda_cu on Windows. -# Otherwise, the Windows linker will optimize and not include torch_cuda_cu even when we -# want it to be included. This is similar to what we do with warp_size for torch_cuda_cpp, -# described as the solution to this issue: https://github.com/pytorch/pytorch/issues/31611 -# This op should NOT be used or exposed or edited or else Windows builds (with BUILD_SPLIT_CUDA) will break. -- func: _torch_cuda_cu_linker_symbol_op(Tensor self) -> Tensor - dispatch: - CUDA: _torch_cuda_cu_linker_symbol_op_cuda - autogen: _torch_cuda_cu_linker_symbol_op.out - - func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: searchsorted_out_cpu @@ -10338,23 +10808,27 @@ python_module: nn dispatch: CPU, CUDA: hardswish_out + MPS: hardswish_out_mps - func: hardswish(Tensor self) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: CPU, CUDA: hardswish + MPS: hardswish_mps - func: hardswish_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: CPU, CUDA: hardswish_ + MPS: hardswish_mps_ - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: CPU, CUDA: hardswish_backward + MPS: hardswish_backward_mps autogen: hardswish_backward.out - func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) @@ -10550,15 +11024,17 @@ autogen: _adaptive_avg_pool2d_backward.out tags: canonical -- func: adaptive_avg_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +- func: adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: adaptive_avg_pool3d_out_cpu CUDA: adaptive_avg_pool3d_out_cuda QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu -- func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor +- func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor python_module: nn + dispatch: + CompositeImplicitAutograd: adaptive_avg_pool3d_symint - func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor dispatch: @@ -10998,158 +11474,54 @@ - func: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_linear1d autogen: upsample_linear1d.vec_out -- func: upsample_linear1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_linear1d_backward - autogen: upsample_linear1d_backward.vec_out - - func: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_bilinear2d autogen: upsample_bilinear2d.vec_out tags: canonical -- func: upsample_bilinear2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_bilinear2d_backward - autogen: upsample_bilinear2d_backward.vec_out - tags: canonical - - func: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_bilinear2d_aa autogen: _upsample_bilinear2d_aa.vec_out -- func: _upsample_bilinear2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_bilinear2d_aa_backward - autogen: _upsample_bilinear2d_aa_backward.vec_out - - func: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_trilinear3d autogen: upsample_trilinear3d.vec_out -- func: upsample_trilinear3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_trilinear3d_backward - autogen: upsample_trilinear3d_backward.vec_out - - func: upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_bicubic2d autogen: upsample_bicubic2d.vec_out -- func: upsample_bicubic2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_bicubic2d_backward - autogen: upsample_bicubic2d_backward.vec_out - - func: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_bicubic2d_aa autogen: _upsample_bicubic2d_aa.vec_out -- func: _upsample_bicubic2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_bicubic2d_aa_backward - autogen: _upsample_bicubic2d_aa_backward.vec_out - - func: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_nearest1d autogen: upsample_nearest1d.vec_out - func: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_nearest_exact1d autogen: _upsample_nearest_exact1d.vec_out -- func: upsample_nearest1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_nearest1d_backward - autogen: upsample_nearest1d_backward.vec_out - -- func: _upsample_nearest_exact1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_nearest_exact1d_backward - autogen: _upsample_nearest_exact1d_backward.vec_out - - func: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_nearest2d autogen: upsample_nearest2d.vec_out tags: canonical - func: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_nearest_exact2d autogen: _upsample_nearest_exact2d.vec_out -- func: upsample_nearest2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: upsample_nearest2d_backward - autogen: upsample_nearest2d_backward.vec_out - tags: canonical - -- func: _upsample_nearest_exact2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CompositeExplicitAutograd: _upsample_nearest_exact2d_backward - autogen: _upsample_nearest_exact2d_backward.vec_out - - func: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CPU: upsample_nearest3d_cpu - CUDA: upsample_nearest3d_cuda - QuantizedCPU: upsample_nearest3d_quantized_cpu autogen: upsample_nearest3d.vec_out - func: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn - dispatch: - CPU: _upsample_nearest_exact3d_cpu - CUDA: _upsample_nearest_exact3d_cuda - QuantizedCPU: _upsample_nearest_exact3d_quantized_cpu autogen: _upsample_nearest_exact3d.vec_out -- func: upsample_nearest3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CPU: upsample_nearest3d_backward_cpu - CUDA: upsample_nearest3d_backward_cuda - autogen: upsample_nearest3d_backward.vec_out - -- func: _upsample_nearest_exact3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - python_module: nn - dispatch: - CPU: _upsample_nearest_exact3d_backward_cpu - CUDA: _upsample_nearest_exact3d_backward_cuda - autogen: _upsample_nearest_exact3d_backward.vec_out - # NOTE: all of the non-"vec" upsample overloads are only kept for backward compatibility. - func: upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -11301,6 +11673,7 @@ dispatch: CPU: _upsample_nearest_exact1d_out_cpu CUDA: _upsample_nearest_exact1d_out_cuda + MPS: _upsample_nearest_exact1d_out_mps - func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor python_module: nn @@ -11316,6 +11689,7 @@ dispatch: CPU: upsample_nearest1d_backward_out_cpu CUDA: upsample_nearest1d_backward_out_cuda + MPS: upsample_nearest1d_backward_out_mps - func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -11323,6 +11697,7 @@ dispatch: CPU: _upsample_nearest_exact1d_backward_out_cpu CUDA: _upsample_nearest_exact1d_backward_out_cuda + MPS: _upsample_nearest_exact1d_backward_out_mps - func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor python_module: nn @@ -11439,10 +11814,12 @@ dispatch: CPU, CUDA: sigmoid_backward_out MPS: sigmoid_backward_out_mps + tags: pointwise - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor python_module: nn structured_delegate: sigmoid_backward.grad_input + tags: pointwise - func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -11450,10 +11827,12 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: logit_backward_out + tags: pointwise - func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor python_module: nn structured_delegate: logit_backward.grad_input + tags: pointwise - func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -11462,6 +11841,7 @@ dispatch: CPU, CUDA: tanh_backward_out MPS: tanh_backward_out_mps + tags: pointwise - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor python_module: nn @@ -11484,25 +11864,26 @@ # one that is written in the native style: modern C++. Algorithmically, # these are the same thing, but we give them different prefixes to # make the operational distinction clear. + tags: pointwise -- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) +- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True dispatch: CPU: slow_conv_transpose2d_structured_cpu CUDA: slow_conv_transpose2d_structured_cuda -- func: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor +- func: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1) -> Tensor python_module: nn structured_delegate: slow_conv_transpose2d.out -- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) +- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: slow_conv_transpose3d_out_cpu CUDA: slow_conv_transpose3d_out_cuda -- func: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor +- func: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1) -> Tensor python_module: nn dispatch: CPU: slow_conv_transpose3d_cpu @@ -11539,47 +11920,47 @@ CUDA: slow_conv2d_backward_cuda autogen: _slow_conv2d_backward.output_mask_out -- func: _conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) out) -> Tensor(a!) +- func: _conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation, *, Tensor(a!) out) -> Tensor(a!) use_const_ref_for_mutable_tensors: True python_module: nn dispatch: CUDA: conv_depthwise2d_cuda_out -- func: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor +- func: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation) -> Tensor python_module: nn dispatch: CUDA: conv_depthwise2d_cuda -- func: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor +- func: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor python_module: nn dispatch: CUDA: conv_depthwise3d_cuda autogen: conv_depthwise3d.out -- func: slow_conv3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) +- func: slow_conv3d.out(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) python_module: nn -- func: slow_conv3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0) -> Tensor +- func: slow_conv3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0) -> Tensor python_module: nn -- func: slow_conv3d_forward.output(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, *, Tensor(a!) output) -> Tensor(a!) +- func: slow_conv3d_forward.output(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) python_module: nn dispatch: CPU: slow_conv3d_forward_out_cpu -- func: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor +- func: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding) -> Tensor python_module: nn dispatch: CPU: slow_conv3d_forward_cpu -- func: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor +- func: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, int[2] dilation=1) -> Tensor python_module: nn dispatch: CPU: slow_conv_dilated2d_cpu CUDA: slow_conv_dilated2d_cuda autogen: slow_conv_dilated2d.out -- func: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor +- func: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, int[3] dilation=1) -> Tensor python_module: nn dispatch: CPU: slow_conv_dilated3d_cpu @@ -11642,6 +12023,7 @@ dispatch: SparseCPU, SparseCUDA: isposinf_sparse SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr + tags: pointwise - func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11650,6 +12032,7 @@ CPU, CUDA: isposinf_out SparseCPU, SparseCUDA: isposinf_sparse_out SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr_out + tags: pointwise - func: isneginf(Tensor self) -> Tensor variants: function, method @@ -11657,6 +12040,7 @@ dispatch: SparseCPU, SparseCUDA: isneginf_sparse SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr + tags: pointwise - func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11665,6 +12049,7 @@ CPU, CUDA: isneginf_out SparseCPU, SparseCUDA: isneginf_sparse_out SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr_out + tags: pointwise # NOTE [_add_batch_dim and _remove_batch_dim] # _add_batch_dim and _remove_batch_dim are meant to be used in the implementation @@ -11688,6 +12073,7 @@ structured_delegate: special_entr.out python_module: special variants: function + tags: pointwise - func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11696,11 +12082,13 @@ variants: function dispatch: CPU, CUDA: special_entr_out + tags: pointwise - func: special_ndtri(Tensor self) -> Tensor structured_delegate: special_ndtri.out python_module: special variants: function + tags: pointwise - func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11709,11 +12097,13 @@ variants: function dispatch: CPU, CUDA: special_ndtri_out + tags: pointwise - func: special_log_ndtr(Tensor self) -> Tensor structured_delegate: special_log_ndtr.out python_module: special variants: function + tags: pointwise - func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -11722,6 +12112,7 @@ variants: function dispatch: CPU, CUDA: special_log_ndtr_out + tags: pointwise - func: special_expm1(Tensor self) -> Tensor python_module: special @@ -11782,6 +12173,7 @@ python_module: special variants: function structured_delegate: special_erfcx.out + tags: pointwise - func: special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) python_module: special @@ -11789,6 +12181,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: special_erfcx_out + tags: pointwise - func: special_erfinv(Tensor self) -> Tensor python_module: special @@ -11810,6 +12203,7 @@ python_module: special variants: function structured_delegate: special_xlog1py.out + tags: pointwise - func: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -11817,6 +12211,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_xlog1py + tags: pointwise - func: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -11824,6 +12219,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_xlog1py + tags: pointwise - func: special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11833,6 +12229,7 @@ variants: function dispatch: CPU, CUDA: special_xlog1py_out + tags: pointwise - func: special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11840,6 +12237,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_xlog1py_out + tags: pointwise - func: special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11847,6 +12245,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_xlog1py_out + tags: pointwise - func: special_xlogy(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -11883,6 +12282,7 @@ python_module: special variants: function structured_delegate: special_zeta.out + tags: pointwise - func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator @@ -11890,6 +12290,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_zeta + tags: pointwise - func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor device_check: NoCheck # TensorIterator @@ -11897,6 +12298,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_zeta + tags: pointwise - func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11906,6 +12308,7 @@ variants: function dispatch: CPU, CUDA: special_zeta_out + tags: pointwise - func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11913,6 +12316,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_zeta_out + tags: pointwise - func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -11920,6 +12324,7 @@ variants: function dispatch: CompositeExplicitAutograd: special_zeta_out + tags: pointwise - func: special_i0(Tensor self) -> Tensor python_module: special @@ -11933,6 +12338,7 @@ python_module: special variants: function structured_delegate: special_i0e.out + tags: pointwise - func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) python_module: special @@ -11940,11 +12346,13 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: special_i0e_out + tags: pointwise - func: special_i1(Tensor self) -> Tensor python_module: special variants: function structured_delegate: special_i1.out + tags: pointwise - func: special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) python_module: special @@ -11952,11 +12360,13 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: special_i1_out + tags: pointwise - func: special_i1e(Tensor self) -> Tensor python_module: special variants: function structured_delegate: special_i1e.out + tags: pointwise - func: special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) python_module: special @@ -11964,6 +12374,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA: special_i1e_out + tags: pointwise - func: special_logit(Tensor self, float? eps=None) -> Tensor python_module: special @@ -12282,7 +12693,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA: linalg_cross_out + CPU, CUDA, MPS: linalg_cross_out # linalg.lu_factor - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) @@ -12501,6 +12912,7 @@ structured: True dispatch: CPU, CUDA: linalg_inv_ex_out + MPS: linalg_inv_ex_out_mps - func: linalg_inv(Tensor A) -> Tensor python_module: linalg @@ -12826,6 +13238,12 @@ tags: view_copy autogen: _test_autograd_multiple_dispatch_view_copy.out +# Note: this function is only for testing. +- func: _test_inductor_realize(Tensor self) -> Tensor + variants: function + dispatch: + CPU, CUDA, Meta: _test_inductor_realize + - func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor variants: function dispatch: @@ -12895,7 +13313,7 @@ - func: as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: as_strided_copy + CompositeExplicitAutogradNonFunctional: as_strided_copy_symint tags: view_copy - func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor @@ -12913,7 +13331,7 @@ - func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: expand_copy + CompositeExplicitAutogradNonFunctional: expand_copy_symint tags: view_copy - func: permute_copy(Tensor self, int[] dims) -> Tensor @@ -12925,13 +13343,14 @@ - func: _reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: _reshape_alias_copy + CompositeExplicitAutogradNonFunctional: _reshape_alias_copy_symint tags: view_copy -- func: select_copy.int(Tensor self, int dim, int index) -> Tensor +- func: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: select_copy_int + CompositeExplicitAutogradNonFunctional: select_copy_symint + SparseCsrCPU, SparseCsrCUDA: select_copy_sparse_csr tags: view_copy - func: detach_copy(Tensor self) -> Tensor @@ -12943,19 +13362,19 @@ - func: slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor variants: function dispatch: - CompositeExplicitAutogradNonFunctional: slice_copy_Tensor + CompositeExplicitAutogradNonFunctional: slice_copy_Tensor_symint tags: view_copy - func: split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] variants: function dispatch: - CompositeExplicitAutogradNonFunctional: split_copy_Tensor + CompositeExplicitAutogradNonFunctional: split_copy_Tensor_symint tags: view_copy - func: split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] variants: function dispatch: - CompositeExplicitAutogradNonFunctional: split_with_sizes_copy + CompositeExplicitAutogradNonFunctional: split_with_sizes_copy_symint tags: view_copy - func: squeeze_copy(Tensor self) -> Tensor @@ -13027,14 +13446,14 @@ - func: ccol_indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: ccol_indices_copy + CompositeExplicitAutogradNonFunctional: ccol_indices_copy tags: view_copy autogen: ccol_indices_copy.out - func: row_indices_copy(Tensor self) -> Tensor variants: function dispatch: - CompositeExplicitAutograd: row_indices_copy + CompositeExplicitAutogradNonFunctional: row_indices_copy tags: view_copy autogen: row_indices_copy.out @@ -13140,10 +13559,10 @@ CompositeExplicitAutograd: _reshape_alias_copy_out -- func: select_copy.int_out(Tensor self, int dim, int index, *, Tensor(a!) out) -> Tensor(a!) +- func: select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: - CompositeExplicitAutograd: select_copy_int_out + CompositeExplicitAutograd: select_copy_symint_out - func: detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -13287,7 +13706,8 @@ - func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor) variants: function dispatch: - CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention + CPU, NestedTensorCPU: native_multi_head_attention_cpu + CUDA, NestedTensorCUDA: native_multi_head_attention_cuda autogen: _native_multi_head_attention.out - func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) @@ -13295,19 +13715,45 @@ variants: function autogen: _scaled_dot_product_attention.out -# Register the math kernel for cpu -- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) - variants: function +- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> int dispatch: - CUDA: _scaled_dot_product_attention_forward_cuda - CPU: _scaled_dot_product_attention_forward_math - NestedTensorCUDA: _scaled_dot_product_attention_forward_nested - NestedTensorCPU: _scaled_dot_product_attention_forward_math - Meta: _scaled_dot_product_attention_forward_math + CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp + CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda - func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function +- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_flash_attention_cuda + NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda + +- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_cuda + NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + +- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_backward_cuda + +# Returns ouput, softmax_logsumexp, softmax +- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _flash_attention_forward + +# Returns ouput, logsumexp if compute_logsumexp +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_forward + +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward + - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: @@ -13324,6 +13770,7 @@ python_module: special structured_delegate: special_airy_ai.out variants: function + tags: pointwise - func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13332,16 +13779,7 @@ structured_inherits: TensorIteratorBase structured: True variants: function - -- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor - variants: function - dispatch: - CUDA: flash_scaled_dot_product_attention - -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_forward + tags: pointwise - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor) variants: function @@ -13359,6 +13797,7 @@ python_module: special structured_delegate: special_bessel_j0.out variants: function + tags: pointwise - func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13367,11 +13806,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_bessel_j1(Tensor self) -> Tensor python_module: special structured_delegate: special_bessel_j1.out variants: function + tags: pointwise - func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13380,11 +13821,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_bessel_y0(Tensor self) -> Tensor python_module: special structured_delegate: special_bessel_y0.out variants: function + tags: pointwise - func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13393,11 +13836,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_bessel_y1(Tensor self) -> Tensor python_module: special structured_delegate: special_bessel_y1.out variants: function + tags: pointwise - func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13406,22 +13851,26 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_chebyshev_polynomial_t.out variants: function + tags: pointwise - func: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13431,11 +13880,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13443,22 +13894,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_chebyshev_polynomial_u.out variants: function + tags: pointwise - func: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13468,11 +13923,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13480,22 +13937,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_chebyshev_polynomial_v.out variants: function + tags: pointwise - func: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13505,11 +13966,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13517,22 +13980,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_chebyshev_polynomial_w.out variants: function + tags: pointwise - func: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13542,11 +14009,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13554,22 +14023,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_hermite_polynomial_h.out variants: function + tags: pointwise - func: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13579,11 +14052,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13591,22 +14066,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_hermite_polynomial_he.out variants: function + tags: pointwise - func: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13616,11 +14095,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13628,22 +14109,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_laguerre_polynomial_l.out variants: function + tags: pointwise - func: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13653,11 +14138,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13665,22 +14152,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_legendre_polynomial_p.out variants: function + tags: pointwise - func: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13690,11 +14181,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13702,11 +14195,13 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_modified_bessel_i0(Tensor self) -> Tensor python_module: special structured_delegate: special_modified_bessel_i0.out variants: function + tags: pointwise - func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13715,11 +14210,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_modified_bessel_i1(Tensor self) -> Tensor python_module: special structured_delegate: special_modified_bessel_i1.out variants: function + tags: pointwise - func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13728,11 +14225,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_modified_bessel_k0(Tensor self) -> Tensor python_module: special structured_delegate: special_modified_bessel_k0.out variants: function + tags: pointwise - func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13741,11 +14240,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_modified_bessel_k1(Tensor self) -> Tensor python_module: special structured_delegate: special_modified_bessel_k1.out variants: function + tags: pointwise - func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13754,11 +14255,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor python_module: special structured_delegate: special_scaled_modified_bessel_k0.out variants: function + tags: pointwise - func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13767,11 +14270,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor python_module: special structured_delegate: special_scaled_modified_bessel_k1.out variants: function + tags: pointwise - func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13780,22 +14285,26 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_shifted_chebyshev_polynomial_t.out variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13805,11 +14314,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13817,22 +14328,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_shifted_chebyshev_polynomial_u.out variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13842,11 +14357,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13854,22 +14371,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_shifted_chebyshev_polynomial_v.out variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13879,11 +14400,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13891,22 +14414,26 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor device_check: NoCheck python_module: special structured_delegate: special_shifted_chebyshev_polynomial_w.out variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck @@ -13916,11 +14443,13 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13928,11 +14457,13 @@ device_check: NoCheck python_module: special variants: function + tags: pointwise - func: special_spherical_bessel_j0(Tensor x) -> Tensor python_module: special structured_delegate: special_spherical_bessel_j0.out variants: function + tags: pointwise - func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -13941,6 +14472,7 @@ structured_inherits: TensorIteratorBase structured: True variants: function + tags: pointwise # Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default # within test/test_python_dispatch.py diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp index 0807e39e952d3..51a4210a56ae5 100644 --- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp @@ -154,18 +154,18 @@ Tensor _nested_sum_backward_cpu( } -Tensor _nested_select_backward( +Tensor _nested_select_backward_symint( const Tensor& grad, const Tensor& nested_self, int64_t dim, - int64_t index) { + c10::SymInt index) { auto nt_self = get_nested_tensor_impl(nested_self); const Tensor& self_buffer = nt_self->get_buffer(); const auto self_sizes = nt_self->get_nested_size_tensor(); const Tensor& self_grad_buffer = self_buffer.new_zeros(self_buffer.sizes()); auto nt_grad = wrap_buffer(self_grad_buffer, self_sizes); - nt_grad.select(dim, index).copy_(grad); + nt_grad.select_symint(dim, index).copy_(grad); return nt_grad; } diff --git a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp new file mode 100644 index 0000000000000..215252f91d6d2 --- /dev/null +++ b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp @@ -0,0 +1,247 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +DEFINE_DISPATCH(nested_dense_elementwise_stub); +REGISTER_NO_CPU_DISPATCH(nested_dense_elementwise_stub); + +std::pair +get_elementwise_nested_tensor_impl( + const Tensor& self, + const Tensor& other, + const std::string& op_name) { + if (self.is_nested() && !(other.is_nested())) { + TORCH_CHECK( + false, + "Expected both self and other to be nested, but got a nested self and non-nested other"); + } else if (!(self.is_nested()) && other.is_nested()) { + TORCH_CHECK( + false, + "Expected both self and other to be nested, but got a non-nested self and nested other"); + } else if (!(self.is_nested()) || !(other.is_nested())) { + TORCH_CHECK( + false, + "Expected both self and other to be nested, but got a non-nested self and non-nested other"); + } + + auto self_ptr = get_nested_tensor_impl(self); + auto other_ptr = get_nested_tensor_impl(other); + + TORCH_CHECK( + self.dim() == other.dim(), + op_name, + " does not support broadcasting when given a NestedTensor"); + TORCH_CHECK( + at::equal( + self_ptr->get_nested_size_tensor(), + other_ptr->get_nested_size_tensor()), + op_name, + " does not support broadcasting when given a NestedTensor"); + TORCH_CHECK( + at::equal( + self_ptr->get_nested_stride_tensor(), + other_ptr->get_nested_stride_tensor()), + op_name, + " requires strides to match when given NestedTensors"); + auto self_offsets = self_ptr->get_storage_offsets(); + auto other_offsets = other_ptr->get_storage_offsets(); + bool offsets_match = true; + for (size_t i = 0; i < self_offsets.size(); i++) { + offsets_match = offsets_match && (self_offsets[i] == other_offsets[i]); + } + TORCH_CHECK( + offsets_match, + op_name, + " requires offsets to match when given NestedTensors"); + return std::make_pair(self_ptr, other_ptr); +} + +template +Tensor NestedTensor_elementwise_Tensor( + const Tensor& self, + const Tensor& other, + const std::string& op_name, + Func f) { + // self is a scalar + if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { + auto other_impl = get_nested_tensor_impl(other); + return wrap_buffer( + f(self, other_impl->get_unsafe_storage_as_tensor()), + other_impl->get_nested_size_tensor().clone(), + other_impl->get_nested_stride_tensor().clone(), + other_impl->get_storage_offsets() + ); + } + // other is a scalar + if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { + auto self_impl = get_nested_tensor_impl(self); + return wrap_buffer( + f(self_impl->get_unsafe_storage_as_tensor(), other), + self_impl->get_nested_size_tensor().clone(), + self_impl->get_nested_stride_tensor().clone(), + self_impl->get_storage_offsets() + ); + } + // special case when other is dense + if (self.is_nested() && !other.is_nested()) { + // check for the [B, *, D], [B, 1, D] esuhm case + // TODO: this if statement is ugly and hopefully we will remove this in the near future + auto self_ptr = get_nested_tensor_impl(self); + if (self_ptr->dim() == 3 && + other.dim() == 3 && + self_ptr->size(0) == other.size(0) && + other.size(1) == 1 && + self_ptr->opt_size(2).has_value() && + self_ptr->opt_size(2).value() == other.size(2) && + self.is_cuda() && + other.is_cuda()) { + if (!nested_tensor_impl_is_contiguous(self_ptr)) { + self_ptr = get_nested_tensor_impl(self.contiguous()); + } + const auto self_buffer = self_ptr->get_buffer(); + const auto self_sizes = self_ptr->get_nested_size_tensor(); + auto result_buffer = at::empty_like(self_buffer); + auto result = wrap_buffer(result_buffer, self_sizes); + if (op_name == "add") { + nested_dense_elementwise_stub(self.device().type(), result, self, other, NESTED_DENSE_OP::ADD); + } else if (op_name == "mul") { + nested_dense_elementwise_stub(self.device().type(), result, self, other, NESTED_DENSE_OP::MUL); + } else { + TORCH_CHECK(false, "Unsupported nested dense elementwise op"); + } + return result; + } + TORCH_CHECK(false, "Expected both self and other to be nested, but got a nested self and non-nested other."); + } + + NestedTensorImpl* self_impl = nullptr; + NestedTensorImpl* other_impl = nullptr; + std::tie(self_impl, other_impl) = + get_elementwise_nested_tensor_impl(self, other, op_name); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl); + return wrap_buffer( + f(self_impl->get_unsafe_storage_as_tensor(), + other_impl->get_unsafe_storage_as_tensor()), + self_impl->get_nested_size_tensor(), + self_impl->get_nested_stride_tensor(), + self_impl->get_storage_offsets()); +} + +Tensor NestedTensor_add_Tensor( + const Tensor& self, + const Tensor& other, + const Scalar& alpha) { + return NestedTensor_elementwise_Tensor( + self, other, "add", [alpha](const Tensor& b1, const Tensor& b2) { + return at::add(b1, b2, alpha); + }); +} + +Tensor NestedTensor_mul_Tensor(const Tensor& self, const Tensor& other) { + return NestedTensor_elementwise_Tensor( + self, other, "mul", [](const Tensor& b1, const Tensor& b2) { + return at::mul(b1, b2); + }); +} + +// Only usable on the C++ side; scalars are converted to tensors coming from Python. +Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) { + return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other)); +} + +Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) { + return NestedTensor_elementwise_Tensor( + self, other, "div", [](const Tensor& b1, const Tensor& b2) { + return at::div(b1, b2); + }); +} + +// Only usable on the C++ side; scalars are converted to tensors coming from Python. +Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) { + return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other)); +} + +template +Tensor& NestedTensor_elementwise__Tensor( + Tensor& self, + const Tensor& other, + const std::string& op_name, + Func f) { + // self is a scalar + if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { + auto other_impl = get_nested_tensor_impl(other); + f(self, other_impl->get_buffer()); + return self; + } + // other is a scalar + if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { + auto self_impl = get_nested_tensor_impl(self); + f(self_impl->get_buffer(), other); + return self; + } + NestedTensorImpl* self_impl = nullptr; + NestedTensorImpl* other_impl = nullptr; + std::tie(self_impl, other_impl) = + get_elementwise_nested_tensor_impl(self, other, op_name); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl); + const auto& nt_self = *self_impl; + const auto& nt_other = *other_impl; + f(nt_self.get_buffer().view({-1}), nt_other.get_buffer().view({-1})); + return self; +} + +Tensor& NestedTensor_add__Tensor( + Tensor& self, + const Tensor& other, + const Scalar& alpha) { + return NestedTensor_elementwise__Tensor( + self, other, "add_", [alpha](const Tensor& b1, const Tensor& b2) { + return b1.add_(b2, alpha); + }); +} + +Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) { + return NestedTensor_elementwise__Tensor( + self, other, "mul_", [](const Tensor& b1, const Tensor& b2) { + return b1.mul_(b2); + }); +} + +// Only usable on the C++ side; scalars are converted to tensors coming from Python. +Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) { + return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other)); +} + +Tensor& fill_nested_(Tensor& self, const Scalar& value) { + const auto& self_buf = get_nested_tensor_impl(self)->get_buffer(); + self_buf.fill_(value); + return self; +} + +Tensor& fill_nested_(Tensor& self, const Tensor& value) { + const auto& self_buf = get_nested_tensor_impl(self)->get_buffer(); + self_buf.fill_(value); + return self; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorBinaryOps.h b/aten/src/ATen/native/nested/NestedTensorBinaryOps.h new file mode 100644 index 0000000000000..51eeaf2919111 --- /dev/null +++ b/aten/src/ATen/native/nested/NestedTensorBinaryOps.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { + +enum class NESTED_DENSE_OP: uint8_t {ADD, MUL}; + +using nested_dense_elementwise_fn = void (*)(Tensor& result, const Tensor & self, const Tensor & other, const NESTED_DENSE_OP& op); + +DECLARE_DISPATCH(nested_dense_elementwise_fn, nested_dense_elementwise_stub); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorFactories.cpp b/aten/src/ATen/native/nested/NestedTensorFactories.cpp index 998a62eb136d1..b45fbb24880ce 100644 --- a/aten/src/ATen/native/nested/NestedTensorFactories.cpp +++ b/aten/src/ATen/native/nested/NestedTensorFactories.cpp @@ -106,9 +106,20 @@ Tensor _to_copy_nested( Tensor r; r = at::empty_like(self, dtype, layout, device, pin_out, memory_format); get_nested_tensor_impl(r)->get_buffer().copy_( - get_nested_tensor_impl(self)->get_buffer()); + get_nested_tensor_impl(self)->get_buffer(), non_blocking); return r; } +Tensor& copy_nested_(Tensor& self, const Tensor& src, bool non_blocking) { + const auto* nt_self = get_nested_tensor_impl(self); + const auto* nt_src = get_nested_tensor_impl(src); + TORCH_CHECK( + at::equal( + nt_self->get_nested_size_tensor(), nt_src->get_nested_size_tensor()), + "copy_ only supports tensors that are the same size for Nested implementations"); + nt_self->get_buffer().copy_(nt_src->get_buffer(), non_blocking); + return self; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index fc9e11ea44914..5842c3b8b2172 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -17,14 +17,7 @@ namespace at { namespace native { - namespace { -template -Tensor map_nt(const Tensor& nt, Func f) { - auto* nt_impl = get_nested_tensor_impl(nt); - const auto& sizes = nt_impl->get_nested_size_tensor(); - return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); -} int64_t num_bytes(IntArrayRef sizes) { // 0-dim Tensors have torch.Size of .size() 0, but carry 1 memory. @@ -90,46 +83,6 @@ std::vector NestedTensor_unbind( return result_tensors; } -Tensor& NestedTensor_relu_(Tensor& self) { - auto self_ptr = get_nested_tensor_impl(self); - check_numel_equals_buffer_size(self_ptr); - auto buffer = self_ptr->get_buffer(); - at::relu_(buffer); - return self; -} - -Tensor NestedTensor_relu(const Tensor& self) { - return map_nt(self, at::relu); -} - -Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) { - auto self_ptr = get_nested_tensor_impl(self); - check_numel_equals_buffer_size(self_ptr); - auto buffer = self_ptr->get_buffer(); - at::gelu_(buffer, approximate); - return self; -} - -Tensor NestedTensor_gelu(const Tensor& self, c10::string_view approximate) { - return map_nt( - self, - [approximate](const Tensor& buffer) { - return at::gelu(buffer, approximate); - }); -} - -Tensor& NestedTensor_tanh_(Tensor& self) { - auto self_ptr = get_nested_tensor_impl(self); - check_numel_equals_buffer_size(self_ptr); - auto buffer = self_ptr->get_buffer(); - at::tanh_(buffer); - return self; -} - -Tensor NestedTensor_tanh(const Tensor& self) { - return map_nt(self, at::tanh); -} - Tensor NestedTensor_nested_tensor_from_mask(const Tensor& t, const Tensor& mask, bool mask_check) { TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Expected mask to be of ScalarType Bool, but got ", mask.scalar_type(), " instead."); TORCH_CHECK(mask.dim() == 2, "Padding mask should be 2D"); @@ -467,157 +420,6 @@ Tensor NestedTensor_embedding( result_buffer.reshape({-1}), std::move(new_sizes)); } -std::pair -get_elementwise_nested_tensor_impl( - const Tensor& self, - const Tensor& other, - const std::string& op_name) { - if (self.is_nested() && !(other.is_nested())) { - TORCH_CHECK( - false, - "Expected both self and other to be nested, but got a nested self and non-nested other"); - } else if (!(self.is_nested()) && other.is_nested()) { - TORCH_CHECK( - false, - "Expected both self and other to be nested, but got a non-nested self and nested other"); - } else if (!(self.is_nested()) || !(other.is_nested())) { - TORCH_CHECK( - false, - "Expected both self and other to be nested, but got a non-nested self and non-nested other"); - } - - auto self_ptr = get_nested_tensor_impl(self); - auto other_ptr = get_nested_tensor_impl(other); - - TORCH_CHECK( - self.dim() == other.dim(), - op_name, - " does not support broadcasting when given a NestedTensor"); - TORCH_CHECK( - at::equal( - self_ptr->get_nested_size_tensor(), - other_ptr->get_nested_size_tensor()), - op_name, - " does not support broadcasting when given a NestedTensor"); - TORCH_CHECK( - nested_tensor_impl_is_contiguous(self_ptr) && - nested_tensor_impl_is_contiguous(other_ptr), - op_name, - " does not support non-contiguous NestedTensor inputs"); - return std::make_pair(self_ptr, other_ptr); -} - -template -Tensor NestedTensor_elementwise_Tensor( - const Tensor& self, - const Tensor& other, - const std::string& op_name, - Func f) { - // self is a scalar - if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { - auto other_impl = get_nested_tensor_impl(other); - return wrap_buffer( - f(self, other_impl->get_buffer()), - other_impl->get_nested_size_tensor().clone() - ); - } - // other is a scalar - if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { - auto self_impl = get_nested_tensor_impl(self); - return wrap_buffer( - f(self_impl->get_buffer(), other), - self_impl->get_nested_size_tensor().clone() - ); - } - NestedTensorImpl* self_impl = nullptr; - NestedTensorImpl* other_impl = nullptr; - std::tie(self_impl, other_impl) = - get_elementwise_nested_tensor_impl(self, other, op_name); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl); - const auto& nt_self = *self_impl; - const auto& nt_other = *other_impl; - const auto& self_sizes = nt_self.get_nested_size_tensor(); - return wrap_buffer( - f(nt_self.get_buffer().reshape({-1}), - nt_other.get_buffer().reshape({-1})), - self_sizes); -} - -Tensor NestedTensor_add_Tensor( - const Tensor& self, - const Tensor& other, - const Scalar& alpha) { - return NestedTensor_elementwise_Tensor( - self, other, "add", [alpha](const Tensor& b1, const Tensor& b2) { - return at::add(b1, b2, alpha); - }); -} - -Tensor NestedTensor_mul_Tensor(const Tensor& self, const Tensor& other) { - return NestedTensor_elementwise_Tensor( - self, other, "mul", [](const Tensor& b1, const Tensor& b2) { - return at::mul(b1, b2); - }); -} - -// Only usable on the C++ side; scalars are converted to tensors coming from Python. -Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) { - return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other)); -} - -template -Tensor& NestedTensor_elementwise__Tensor( - Tensor& self, - const Tensor& other, - const std::string& op_name, - Func f) { - // self is a scalar - if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) { - auto other_impl = get_nested_tensor_impl(other); - f(self, other_impl->get_buffer()); - return self; - } - // other is a scalar - if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) { - auto self_impl = get_nested_tensor_impl(self); - f(self_impl->get_buffer(), other); - return self; - } - NestedTensorImpl* self_impl = nullptr; - NestedTensorImpl* other_impl = nullptr; - std::tie(self_impl, other_impl) = - get_elementwise_nested_tensor_impl(self, other, op_name); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl); - const auto& nt_self = *self_impl; - const auto& nt_other = *other_impl; - f(nt_self.get_buffer().view({-1}), nt_other.get_buffer().view({-1})); - return self; -} - -Tensor& NestedTensor_add__Tensor( - Tensor& self, - const Tensor& other, - const Scalar& alpha) { - return NestedTensor_elementwise__Tensor( - self, other, "add_", [alpha](const Tensor& b1, const Tensor& b2) { - return b1.add_(b2, alpha); - }); -} - -Tensor& NestedTensor_mul__Tensor(Tensor& self, const Tensor& other) { - return NestedTensor_elementwise__Tensor( - self, other, "mul_", [](const Tensor& b1, const Tensor& b2) { - return b1.mul_(b2); - }); -} - -// Only usable on the C++ side; scalars are converted to tensors coming from Python. -Tensor& NestedTensor_mul__Scalar(Tensor& self, const Scalar& other) { - return NestedTensor_mul__Tensor(self, wrapped_scalar_tensor(other)); -} - // Very rudimentary sum_dim for prototyping with torch_scatter.segment_reduce. Tensor NestedTensor_sum_dim_CPU( const Tensor& self, @@ -692,23 +494,59 @@ Tensor NestedTensor_sum_dim_CPU( Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { auto self_ptr = get_nested_tensor_impl(self); - int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim()); - TORCH_CHECK( - positive_dim == 0, - "NestedTensor can only be selected along dimension 0 ", - "got dimension ", dim, " instead." - ); - int64_t ntensors = self_ptr->size(0); - TORCH_CHECK_INDEX( - index >= -ntensors && index < ntensors, - "index ", index, - " is out of bounds for dimension 0 with size ", ntensors); - int64_t positive_index = index < 0 ? index + ntensors : index; - const at::Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(); std::vector sizes = NestedTensor_get_sizes(self_ptr), - strides = NestedTensor_get_strides(self_ptr); + strides = NestedTensor_get_strides(self_ptr); const std::vector& offsets = self_ptr->get_storage_offsets(); - return buffer.as_strided(sizes[positive_index], strides[positive_index], offsets[positive_index]); + const at::Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(); + int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim()); + int64_t ntensors = self_ptr->size(0); + TORCH_CHECK_INDEX(ntensors > 0, "You can only select when the NT is not empty."); + int64_t ndims = static_cast(sizes[0].size()); + if (positive_dim == 0) { + TORCH_CHECK_INDEX( + index >= -ntensors && index < ntensors, + "index ", + index, + " is out of bounds for dimension 0 with size ", + ntensors); + int64_t positive_index = index < 0 ? index + ntensors : index; + return buffer.as_strided( + sizes[positive_index], + strides[positive_index], + offsets[positive_index]); + } else { + auto new_sizes = at::empty({ntensors, ndims-1}, TensorOptions().dtype(kLong)); + auto new_strides = at::empty({ntensors, ndims-1}, TensorOptions().dtype(kLong)); + auto new_offsets = std::vector(offsets); + std::vector tensor_slices(ntensors); + for (int64_t i : c10::irange(ntensors)) { + int64_t *size_ptr = new_sizes[i].data_ptr(); + int64_t *stride_ptr = new_strides[i].data_ptr(); + + int64_t dim_idx = 0; + for (int64_t j : c10::irange(ndims)) { + if (j != dim - 1) { + size_ptr[dim_idx] = sizes[i][j]; + stride_ptr[dim_idx] = strides[i][j]; + ++dim_idx; + } else { + TORCH_CHECK_INDEX( + index >= 0 && index < sizes[i][j], + "index ", + index, + " is out of bounds for dimension ", + j, + " of the ", + i, + "th constituent tensor with size ", + sizes[i][j]); + new_offsets[i] = offsets[i] + index * strides[i][j]; + } + } + } + return create_nested_view_tensor(self, new_sizes, new_strides, std::move(new_offsets)); + } + } Tensor clone_nested( @@ -807,215 +645,6 @@ Tensor softmax_nested( return output; } -Tensor bmm_nested(const Tensor& self, const Tensor& mat2) { - if (self.is_nested() && !mat2.is_nested()) { - AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); - } - else if (!self.is_nested() && mat2.is_nested()) { - AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); - } - // dispatcher should have guaranteed that at least one is nested - auto self_ptr = get_nested_tensor_impl(self); - auto mat2_ptr = get_nested_tensor_impl(mat2); - TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor"); - TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor"); - int64_t ntensors = self_ptr->size(0), - ntensors2 = mat2_ptr->size(0); - TORCH_CHECK(ntensors == ntensors2, - "Expected size for the 1st dimension of batch2 tensor to be: ", ntensors, - " but got: ", ntensors2, "."); - const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(), - & mat2_buffer = mat2_ptr->get_unsafe_storage_as_tensor(); - std::vector self_sizes = NestedTensor_get_sizes(self_ptr), - mat2_sizes = NestedTensor_get_sizes(mat2_ptr), - self_strides = NestedTensor_get_strides(self_ptr), - mat2_strides = NestedTensor_get_strides(mat2_ptr); - const std::vector& self_offsets = self_ptr->get_storage_offsets(), - & mat2_offsets = mat2_ptr->get_storage_offsets(); - // create a contiguous output - int64_t out_numel = 0; - const Tensor& self_sizemat = self_ptr->get_nested_size_tensor(); - Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes()); - int64_t* out_sizemat_ptr = out_sizemat.data_ptr(); - for (int64_t i = 0; i < ntensors; i++) { - const IntArrayRef& self_shape = self_sizes[i], - & mat2_shape = mat2_sizes[i]; - const int64_t& self_size0 = self_shape[0], & self_size1 = self_shape[1], - & mat2_size0 = mat2_shape[0], & mat2_size1 = mat2_shape[1]; - TORCH_CHECK(self_size1 == mat2_size0, - i, "-th nested matrices in batch cannot be multiplied (", - self_size0, "x", self_size1, " and ", - mat2_size0, "x", mat2_size1, ")"); - out_sizemat_ptr[0] = self_size0; - out_sizemat_ptr[1] = mat2_size1; - out_sizemat_ptr += 2; - out_numel += self_size0 * mat2_size1; - } - Tensor out_buffer = self_buffer.new_empty(out_numel); - Tensor output = wrap_buffer(out_buffer, out_sizemat); - // call tensor mm - // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient - // until we have specialized nested tensor bmm kernel - // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_` - // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl` - std::vector output_unbind = output.unbind(); - for (int64_t i = 0; i < ntensors; i++) { - at::mm_out(output_unbind[i], - self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]), - mat2_buffer.as_strided(mat2_sizes[i], mat2_strides[i], mat2_offsets[i])); - } - return output; -} - -// utilities support `matmul_nested` -namespace { -// Args: -// self_sizes: the sizes of `self` in `matmul_nested` -// mat2_sizes: the sizes of `mat2` in `matmul_nested` -// buffer_op: the options for new buffer -// sizemat_op: the options for new size matrix -// Returns: -// the batch size of each input underlying tensor, i.e. the product of batch-dimension sizes -// the empty output nested tensor -inline std::tuple, Tensor> -matmul_nested_helper( - const std::vector& self_sizes, - const std::vector& mat2_sizes, - const c10::TensorOptions& buffer_op, - const c10::TensorOptions& sizemat_op) { - int64_t ntensors = self_sizes.size(), - ndims = self_sizes[0].size(); - std::vector batch_sizes(ntensors, 1); - Tensor sizemat = at::empty({ntensors, ndims}, sizemat_op); - int64_t* sizemat_ptr = sizemat.data_ptr(); - int64_t numel = 0; - for (int64_t i = 0; i < ntensors; i++) { - const IntArrayRef& self_size = self_sizes[i], - & mat2_size = mat2_sizes[i]; - int64_t& batch_size = batch_sizes[i]; - // batch dimensions - for (int64_t j = 0; j < ndims - 2; j++) { - const int64_t& self_sizej = self_size[j], - & mat2_sizej = mat2_size[j]; - TORCH_CHECK( - self_sizej == mat2_sizej, - "matmul: For nested tensors, no broadcasting is currently performed: ", - i, "-th nested matrices in batch at dimension ", j + 1, - " have mismatching sizes ", self_sizej, " and ", mat2_sizej); - sizemat_ptr[j] = self_sizej; - batch_size *= sizemat_ptr[j]; - } - // matrix multiplication dimensions - const int64_t& self_size0 = self_size[ndims - 2], & self_size1 = self_size[ndims - 1], - & mat2_size0 = mat2_size[ndims - 2], & mat2_size1 = mat2_size[ndims - 1]; - TORCH_CHECK( - self_size1 == mat2_size0, - "matmul: ", - i, "-th nested matrices in batch cannot be multiplied (", - self_size0, "x", self_size1, " and ", - mat2_size0, "x", mat2_size1, ")"); - sizemat_ptr[ndims - 2] = self_size0; - sizemat_ptr[ndims - 1] = mat2_size1; - sizemat_ptr += ndims; - numel += batch_size * self_size0 * mat2_size1; - } - Tensor buffer = at::empty(numel, buffer_op); - Tensor output = wrap_buffer(buffer, sizemat); - return std::make_tuple(batch_sizes, output); -} -} - -// Note [nested tensor matmul] -// This is really a generalized batched matmul dedicated to nested tensors, -// where `self` and `mat2` have same number (>= 3) of dimensions. -// The last 2 dimensions will be considered as matrix dimensions, -// so they should be matrix-multiplicable. -// The leading dimensions are considered as batch dimensions, -// and since nested tensor does not support broadcasting for now, -// for each batch dimension `self` and `mat2` must have same size. -// TODO: Should make full matmul semantics support some day -Tensor matmul_nested(const Tensor& self, const Tensor& mat2) { - if (self.is_nested() && !mat2.is_nested()) { - AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); - } - else if (!self.is_nested() && mat2.is_nested()) { - AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); - } - // to_padded_tensor only supports contiguous inputs - auto self_contig = self.contiguous(); - auto mat2_contig = mat2.contiguous(); - // dispatcher should have guaranteed that at least one is nested - const auto self_ptr = get_nested_tensor_impl(self_contig); - const auto mat2_ptr = get_nested_tensor_impl(mat2_contig); - int64_t self_dim = self_ptr->dim(), - mat2_dim = mat2_ptr->dim(); - TORCH_CHECK( - self_dim >= 3, - "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: ", - self_dim); - TORCH_CHECK( - mat2_dim >= 3, - "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ", - mat2_dim); - TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have the same rank"); - int64_t ntensors = self_ptr->size(0), - ntensors2 = mat2_ptr->size(0); - TORCH_CHECK(ntensors == ntensors2, - "matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors, - " but got: ", ntensors2, "."); - // Ensure batch dimensions have the same sizes (no broadcasting). - const auto& self_sizes = self_ptr->get_nested_size_tensor(); - const auto& mat2_sizes = mat2_ptr->get_nested_size_tensor(); - const auto& self_batch_sizes = self_sizes.narrow(1, 0, self_dim-3); - const auto& mat2_batch_sizes = mat2_sizes.narrow(1, 0, mat2_dim-3); - TORCH_CHECK(at::equal(self_batch_sizes, mat2_batch_sizes), - "matmul: For nested tensors, batch dimensions must have the same sizes, ", - "no broadcasting is currently performed. Got batch shapes for self ", - self_batch_sizes, - " and batch shapes for mat2 ", - mat2_batch_sizes); - // Ensure last dim of self and second last dim of mat2 have the same size - const auto& self_dim_size = self_sizes.select(1, -1); - const auto& mat2_dim_size = mat2_sizes.select(1, -2); - TORCH_CHECK(at::equal(self_dim_size, mat2_dim_size), - "matmul: Nested tensors cannot be matrix multiplied, last dimension of self has sizes", - self_dim_size, - "second last dimension of mat2 has sizes", - mat2_dim_size); - // Construct output size from input sizes - Tensor output_sizes = self_sizes.clone(); - // The last entry in every row of output_sizes should be last column of mat2_sizes - output_sizes.index_put_({at::indexing::Slice(), -1}, mat2_sizes.select(1, -1).clone()); - - auto self_padded = self_contig.to_padded_tensor(0.); - auto mat2_padded = mat2_contig.to_padded_tensor(0.); - auto output_padded = at::matmul(self_padded, mat2_padded); - auto output_nested = nested_from_padded_generic(output_padded, output_sizes); - return output_nested; -} - -Tensor& matmul_out_nested(const Tensor& tensor1, const Tensor& tensor2, Tensor& result) { - // TODO: this is a very quick and dirty implementation - // should improve it to avoid the intermediate memory usage - Tensor function_result = at::matmul(tensor1, tensor2); - auto function_result_ptr = get_nested_tensor_impl(function_result); - // TODO: this is to reproduce function_result_ptr->opt_sizes_ - // if an accessor is provided in the future, can replace this - std::vector sizes; - for (int64_t i = 0; i < function_result_ptr->dim(); i++) { - c10::optional opt_size = function_result_ptr->opt_size(i); - if (opt_size.has_value()) { - sizes.push_back(*opt_size); - } - else { - sizes.push_back(-1); - } - } - result.reshape(sizes); - result.copy_(function_result); - return result; -} - Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) { auto self_ptr = get_nested_tensor_impl(self); // check input dimensions @@ -1111,7 +740,6 @@ Tensor unsqueeze_nested(const Tensor& self, int64_t dim) { self, sizemat_unsqueezed, stridemat_unsqueezed, std::vector(self_ptr->get_storage_offsets())); } - // utilities supporting `view_nested` and `reshape_nested` namespace { // Args: diff --git a/aten/src/ATen/native/nested/NestedTensorMath.h b/aten/src/ATen/native/nested/NestedTensorMath.h index 69fe4ee3cd296..954fa807f1832 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.h +++ b/aten/src/ATen/native/nested/NestedTensorMath.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace at { @@ -11,5 +12,12 @@ TORCH_API Tensor NestedTensor_to_padded_tensor_generic( double padding, OptionalIntArrayRef output_size); +template +Tensor map_nt(const Tensor& nt, Func f) { + auto* nt_impl = get_nested_tensor_impl(nt); + const auto& sizes = nt_impl->get_nested_size_tensor(); + return at::detail::make_tensor(f(nt_impl->get_buffer()), sizes); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp new file mode 100644 index 0000000000000..c8cfa124330d6 --- /dev/null +++ b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp @@ -0,0 +1,352 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +Tensor bmm_nested(const Tensor& self, const Tensor& mat2) { + if (self.is_nested() && !mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); + } + else if (!self.is_nested() && mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); + } + // dispatcher should have guaranteed that at least one is nested + auto self_ptr = get_nested_tensor_impl(self); + auto mat2_ptr = get_nested_tensor_impl(mat2); + TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor"); + int64_t ntensors = self_ptr->size(0), + ntensors2 = mat2_ptr->size(0); + TORCH_CHECK(ntensors == ntensors2, + "Expected size for the 1st dimension of batch2 tensor to be: ", ntensors, + " but got: ", ntensors2, "."); + const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(), + & mat2_buffer = mat2_ptr->get_unsafe_storage_as_tensor(); + std::vector self_sizes = NestedTensor_get_sizes(self_ptr), + mat2_sizes = NestedTensor_get_sizes(mat2_ptr), + self_strides = NestedTensor_get_strides(self_ptr), + mat2_strides = NestedTensor_get_strides(mat2_ptr); + const std::vector& self_offsets = self_ptr->get_storage_offsets(), + & mat2_offsets = mat2_ptr->get_storage_offsets(); + // create a contiguous output + int64_t out_numel = 0; + const Tensor& self_sizemat = self_ptr->get_nested_size_tensor(); + Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes()); + int64_t* out_sizemat_ptr = out_sizemat.data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_shape = self_sizes[i], + & mat2_shape = mat2_sizes[i]; + const int64_t& self_size0 = self_shape[0], & self_size1 = self_shape[1], + & mat2_size0 = mat2_shape[0], & mat2_size1 = mat2_shape[1]; + TORCH_CHECK(self_size1 == mat2_size0, + i, "-th nested matrices in batch cannot be multiplied (", + self_size0, "x", self_size1, " and ", + mat2_size0, "x", mat2_size1, ")"); + out_sizemat_ptr[0] = self_size0; + out_sizemat_ptr[1] = mat2_size1; + out_sizemat_ptr += 2; + out_numel += self_size0 * mat2_size1; + } + Tensor out_buffer = self_buffer.new_empty(out_numel); + Tensor output = wrap_buffer(out_buffer, out_sizemat); + // call tensor mm + // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient + // until we have specialized nested tensor bmm kernel + // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_` + // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl` + std::vector output_unbind = output.unbind(); + for (int64_t i = 0; i < ntensors; i++) { + at::mm_out(output_unbind[i], + self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]), + mat2_buffer.as_strided(mat2_sizes[i], mat2_strides[i], mat2_offsets[i])); + } + return output; +} + +// utilities support `matmul_nested` +namespace { +// Args: +// self_sizes: the sizes of `self` in `matmul_nested` +// mat2_sizes: the sizes of `mat2` in `matmul_nested` +// buffer_op: the options for new buffer +// sizemat_op: the options for new size matrix +// Returns: +// the batch size of each input underlying tensor, i.e. the product of batch-dimension sizes +// the empty output nested tensor +inline std::tuple, Tensor> +matmul_nested_helper( + const std::vector& self_sizes, + const std::vector& mat2_sizes, + const c10::TensorOptions& buffer_op, + const c10::TensorOptions& sizemat_op) { + int64_t ntensors = self_sizes.size(), + ndims = self_sizes[0].size(); + std::vector batch_sizes(ntensors, 1); + Tensor sizemat = at::empty({ntensors, ndims}, sizemat_op); + int64_t* sizemat_ptr = sizemat.data_ptr(); + int64_t numel = 0; + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_size = self_sizes[i], + & mat2_size = mat2_sizes[i]; + int64_t& batch_size = batch_sizes[i]; + // batch dimensions + for (int64_t j = 0; j < ndims - 2; j++) { + const int64_t& self_sizej = self_size[j], + & mat2_sizej = mat2_size[j]; + TORCH_CHECK( + self_sizej == mat2_sizej, + "matmul: For nested tensors, no broadcasting is currently performed: ", + i, "-th nested matrices in batch at dimension ", j + 1, + " have mismatching sizes ", self_sizej, " and ", mat2_sizej); + sizemat_ptr[j] = self_sizej; + batch_size *= sizemat_ptr[j]; + } + // matrix multiplication dimensions + const int64_t& self_size0 = self_size[ndims - 2], & self_size1 = self_size[ndims - 1], + & mat2_size0 = mat2_size[ndims - 2], & mat2_size1 = mat2_size[ndims - 1]; + TORCH_CHECK( + self_size1 == mat2_size0, + "matmul: ", + i, "-th nested matrices in batch cannot be multiplied (", + self_size0, "x", self_size1, " and ", + mat2_size0, "x", mat2_size1, ")"); + sizemat_ptr[ndims - 2] = self_size0; + sizemat_ptr[ndims - 1] = mat2_size1; + sizemat_ptr += ndims; + numel += batch_size * self_size0 * mat2_size1; + } + Tensor buffer = at::empty(numel, buffer_op); + Tensor output = wrap_buffer(buffer, sizemat); + return std::make_tuple(batch_sizes, output); +} +} + +Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) { + // Tensor self = self_.contiguous(); + // Tensor mat2 = mat2_.contiguous(); + // self [N, n_heads, *, head_dim] + // mat2 [N, n_heads, head_dim, *] + const auto self_ptr = get_nested_tensor_impl(self); + const auto mat2_ptr = get_nested_tensor_impl(mat2); + // metadata for self + std::vector self_sizes = NestedTensor_get_sizes(self_ptr); + std::vector self_strides = NestedTensor_get_strides(self_ptr); + std::vector self_offsets = self_ptr->get_storage_offsets(); + auto opt = self_ptr->get_nested_size_tensor().options(); + + // metadata for mat2 + std::vector mat2_sizes = NestedTensor_get_sizes(mat2_ptr); + std::vector mat2_strides = NestedTensor_get_strides(mat2_ptr); + std::vector mat2_offsets = mat2_ptr->get_storage_offsets(); + auto opt2 = mat2_ptr->get_nested_size_tensor().options(); + + int64_t N = self_sizes.size(); + int64_t n_heads = self_sizes[0][0]; + + // viewed metadata for self + auto self_new_sizes = at::empty({N * n_heads, 2}, opt); + int64_t* self_new_sizes_ptr = self_new_sizes.data_ptr(); + + auto self_new_strides = at::empty({N * n_heads, 2}, opt); + int64_t* self_new_strides_ptr = self_new_strides.data_ptr(); + std::vector self_new_offsets; + + // viewed metadata for mat2 + auto mat2_new_sizes = at::empty({N * n_heads, 2}, opt2); + int64_t* mat2_new_sizes_ptr = mat2_new_sizes.data_ptr(); + + auto mat2_new_strides = at::empty({N * n_heads, 2}, opt2); + int64_t* mat2_new_strides_ptr = mat2_new_strides.data_ptr(); + std::vector mat2_new_offsets; + + for (int64_t i = 0; i < N; i++) { + const IntArrayRef& self_size_i = self_sizes[i]; + const IntArrayRef& self_stride_i = self_strides[i]; + int64_t self_offset = self_offsets[i]; + + const IntArrayRef& mat2_size_i = mat2_sizes[i]; + const IntArrayRef& mat2_stride_i = mat2_strides[i]; + int64_t mat2_offset = mat2_offsets[i]; + for (int64_t j = 0; j < n_heads; j++) { + auto idx = (i * n_heads + j) * 2; + self_new_sizes_ptr[idx] = self_size_i[1]; + self_new_sizes_ptr[idx + 1] = self_size_i[2]; + self_new_strides_ptr[idx] = self_stride_i[1]; + self_new_strides_ptr[idx + 1] = self_stride_i[2]; + self_new_offsets.push_back(self_offset); + self_offset += self_stride_i[0]; + + mat2_new_sizes_ptr[idx] = mat2_size_i[1]; + mat2_new_sizes_ptr[idx + 1] = mat2_size_i[2]; + mat2_new_strides_ptr[idx] = mat2_stride_i[1]; + mat2_new_strides_ptr[idx + 1] = mat2_stride_i[2]; + mat2_new_offsets.push_back(mat2_offset); + mat2_offset += mat2_stride_i[0]; + } + } + + + // view self as [N * n_heads, *, head_dim] (collapse first 2 dims) + auto viewed_self = create_nested_view_tensor( + self, self_new_sizes, self_new_strides, std::vector(self_new_offsets)); + + // view mat2 as [N * n_heads, head_dim, *] (collapse first 2_dims) + auto viewed_mat2 = create_nested_view_tensor( + mat2, mat2_new_sizes, mat2_new_strides, std::vector(mat2_new_offsets)); + + // output [N * n_heads, *, *] + auto bmm_output = at::bmm(viewed_self, viewed_mat2); + + // generate metadata for viewing output as [N, n_heads, *, *] + // output of bmm should be contiguous so stride calculations should hold + auto out_new_sizes = at::empty({N, 3}, opt); + auto out_new_strides = at::empty({N, 3}, opt); + std::vector out_new_offsets; + + int64_t* out_new_sizes_ptr = out_new_sizes.data_ptr(); + int64_t* out_new_strides_ptr = out_new_strides.data_ptr(); + + int64_t out_offset = 0; + for (int64_t i = 0; i < N; i++) { + out_new_offsets.push_back(out_offset); + const IntArrayRef& self_size_i = self_sizes[i]; + const IntArrayRef& mat2_size_i = mat2_sizes[i]; + auto idx = i * 3; + out_new_sizes_ptr[idx] = n_heads; + out_new_sizes_ptr[idx + 1] = self_size_i[1]; + out_new_sizes_ptr[idx + 2] = mat2_size_i[2]; + out_new_strides_ptr[idx] = self_size_i[1] * mat2_size_i[2]; + out_new_strides_ptr[idx + 1] = mat2_size_i[2]; + out_new_strides_ptr[idx + 2] = 1; + out_offset += n_heads * (self_size_i[1] * mat2_size_i[2]); + } + + auto viewed_out = create_nested_view_tensor( + bmm_output, out_new_sizes, out_new_strides, std::vector(out_new_offsets)); + + return viewed_out; + +} + +// Note [nested tensor matmul] +// This is really a generalized batched matmul dedicated to nested tensors, +// where `self` and `mat2` have same number (>= 3) of dimensions. +// The last 2 dimensions will be considered as matrix dimensions, +// so they should be matrix-multiplicable. +// The leading dimensions are considered as batch dimensions, +// and since nested tensor does not support broadcasting for now, +// for each batch dimension `self` and `mat2` must have same size. +// TODO: Should make full matmul semantics support some day +Tensor matmul_nested(const Tensor& self, const Tensor& mat2) { + if (self.is_nested() && !mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a nested self and non-nested other"); + } + else if (!self.is_nested() && mat2.is_nested()) { + AT_ERROR("Expected both to be nested, but got a non-nested self and nested other"); + } + // to_padded_tensor only supports contiguous inputs + auto self_contig = self.contiguous(); + auto mat2_contig = mat2.contiguous(); + // dispatcher should have guaranteed that at least one is nested + const auto self_ptr = get_nested_tensor_impl(self_contig); + const auto mat2_ptr = get_nested_tensor_impl(mat2_contig); + int64_t self_dim = self_ptr->dim(), + mat2_dim = mat2_ptr->dim(); + TORCH_CHECK( + self_dim >= 3, + "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: ", + self_dim); + TORCH_CHECK( + mat2_dim >= 3, + "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ", + mat2_dim); + TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have the same rank"); + int64_t ntensors = self_ptr->size(0), + ntensors2 = mat2_ptr->size(0); + TORCH_CHECK(ntensors == ntensors2, + "matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors, + " but got: ", ntensors2, "."); + // Ensure batch dimensions have the same sizes (no broadcasting). + const auto& self_sizes = self_ptr->get_nested_size_tensor(); + const auto& mat2_sizes = mat2_ptr->get_nested_size_tensor(); + const auto& self_batch_sizes = self_sizes.narrow(1, 0, self_dim-3); + const auto& mat2_batch_sizes = mat2_sizes.narrow(1, 0, mat2_dim-3); + TORCH_CHECK(at::equal(self_batch_sizes, mat2_batch_sizes), + "matmul: For nested tensors, batch dimensions must have the same sizes, ", + "no broadcasting is currently performed. Got batch shapes for self ", + self_batch_sizes, + " and batch shapes for mat2 ", + mat2_batch_sizes); + // Ensure last dim of self and second last dim of mat2 have the same size + const auto& self_dim_size = self_sizes.select(1, -1); + const auto& mat2_dim_size = mat2_sizes.select(1, -2); + TORCH_CHECK(at::equal(self_dim_size, mat2_dim_size), + "matmul: Nested tensors cannot be matrix multiplied, last dimension of self has sizes", + self_dim_size, + "second last dimension of mat2 has sizes", + mat2_dim_size); + + // use bmm inference-only fast path for [N, n_heads, *, head_dim] [N, n_heads, head_dim, *] + if (self.is_cuda() && + self_dim == 4 && self.is_contiguous() && + mat2_dim == 4 && mat2.is_contiguous() && + !(GradMode::is_enabled() && (self.requires_grad() || mat2.requires_grad()))) { + auto n_heads = self_sizes.select(0, 1).select(0, 0).item(); + auto self_first_dim_n_heads = at::all(self_sizes.select(1, 0) == n_heads).item(); + auto mat2_first_dim_n_heads = at::all(mat2_sizes.select(1, 0) == n_heads).item(); + if (self_first_dim_n_heads && mat2_first_dim_n_heads) { + return matmul_with_bmm_nested(self, mat2); + } + } + + // Construct output size from input sizes + Tensor output_sizes = self_sizes.clone(); + // The last entry in every row of output_sizes should be last column of mat2_sizes + output_sizes.index_put_({at::indexing::Slice(), -1}, mat2_sizes.select(1, -1).clone()); + + auto self_padded = self_contig.to_padded_tensor(0.); + auto mat2_padded = mat2_contig.to_padded_tensor(0.); + auto output_padded = at::matmul(self_padded, mat2_padded); + auto output_nested = nested_from_padded_generic(output_padded, output_sizes); + return output_nested; +} + +Tensor& matmul_out_nested(const Tensor& tensor1, const Tensor& tensor2, Tensor& result) { + // TODO: this is a very quick and dirty implementation + // should improve it to avoid the intermediate memory usage + Tensor function_result = at::matmul(tensor1, tensor2); + auto function_result_ptr = get_nested_tensor_impl(function_result); + // TODO: this is to reproduce function_result_ptr->opt_sizes_ + // if an accessor is provided in the future, can replace this + std::vector sizes; + for (int64_t i = 0; i < function_result_ptr->dim(); i++) { + c10::optional opt_size = function_result_ptr->opt_size(i); + if (opt_size.has_value()) { + sizes.push_back(*opt_size); + } + else { + sizes.push_back(-1); + } + } + result.reshape(sizes); + result.copy_(function_result); + return result; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp new file mode 100644 index 0000000000000..6be7239775ea6 --- /dev/null +++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp @@ -0,0 +1,74 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { + +Tensor& NestedTensor_relu_(Tensor& self) { + auto self_ptr = get_nested_tensor_impl(self); + check_numel_equals_buffer_size(self_ptr); + auto buffer = self_ptr->get_buffer(); + at::relu_(buffer); + return self; +} + +Tensor NestedTensor_relu(const Tensor& self) { + return map_nt(self, at::relu); +} + +Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) { + auto self_ptr = get_nested_tensor_impl(self); + check_numel_equals_buffer_size(self_ptr); + auto buffer = self_ptr->get_buffer(); + at::gelu_(buffer, approximate); + return self; +} + +Tensor NestedTensor_gelu(const Tensor& self, c10::string_view approximate) { + return map_nt( + self, + [approximate](const Tensor& buffer) { + return at::gelu(buffer, approximate); + }); +} + +Tensor& NestedTensor_tanh_(Tensor& self) { + auto self_ptr = get_nested_tensor_impl(self); + check_numel_equals_buffer_size(self_ptr); + auto buffer = self_ptr->get_buffer(); + at::tanh_(buffer); + return self; +} + +Tensor NestedTensor_tanh(const Tensor& self) { + return map_nt(self, at::tanh); +} + +Tensor& NestedTensor_neg_(Tensor& self) { + auto self_ptr = get_nested_tensor_impl(self); + check_numel_equals_buffer_size(self_ptr); + auto buffer = self_ptr->get_buffer(); + at::neg_(buffer); + return self; +} + +Tensor NestedTensor_neg(const Tensor& self) { + return map_nt(self, at::neg); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h index 77d512c519b28..6590db9116e09 100644 --- a/aten/src/ATen/native/nested/NestedTensorUtils.h +++ b/aten/src/ATen/native/nested/NestedTensorUtils.h @@ -1,17 +1,21 @@ #pragma once -#include +#include #include +#include +#include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS + #include #include #else #include +#include #include #include #include @@ -50,6 +54,19 @@ inline at::Tensor wrap_buffer( std::move(offsets)); } +inline at::Tensor wrap_buffer( + at::Tensor buffer, + at::Tensor nested_size_tensor, + at::Tensor nested_stride_tensor, + const std::vector& offsets) { + std::vector offsets_copy(offsets); + return wrap_buffer( + buffer, + nested_size_tensor, + nested_stride_tensor, + std::move(offsets_copy)); +} + inline at::Tensor get_buffer(const at::Tensor& tensor) { return get_nested_tensor_impl(tensor)->get_buffer(); } @@ -119,7 +136,6 @@ inline std::vector NestedTensor_get_sizes( return sizes; } - TORCH_API std::vector NestedTensor_get_max_size( const NestedTensorImpl& nt); @@ -161,17 +177,18 @@ inline std::vector NestedTensor_get_strides( inline void check_numel_equals_buffer_size(const at::Tensor& self) { auto self_impl = get_nested_tensor_impl(self); TORCH_CHECK( - self.numel() == self_impl -> get_buffer_size(), + self.numel() == self_impl->get_buffer_size(), "Number of elements in nested tensor must match number of elements in buffer."); } inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) { TORCH_CHECK( - self_ptr-> numel() == self_ptr -> get_buffer_size(), + self_ptr->numel() == self_ptr->get_buffer_size(), "Number of elements in nested tensor must match number of elements in buffer."); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Data structures and functions for generically applying a function on a nested tensor. +// Data structures and functions for generically applying a function on a nested +// tensor. namespace impl { template @@ -308,17 +325,84 @@ inline Tensor wrap_tensor_node( if (tensor_node.degree() == 0) { return wrap_buffer(ones({0}, dtype, layout, device), ones({})); } - std::vector sizes; - std::vector flat_tensors; + + // Fast path: if all tensors are on CPU, have contiguous memory, and the same + // dtype, copying can be done much faster. + bool all_tensors_cpu = true; + bool all_tensors_contiguous = true; + bool all_tensors_same_dtype = true; + auto first_dtype = tensor_node.children(0).dtype(); + std::vector start_offsets(tensor_node.degree()); + start_offsets[0] = 0; + long total_size = 0; for (const auto i : c10::irange(tensor_node.degree())) { - flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); - sizes.push_back(tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu(); + all_tensors_contiguous = + all_tensors_contiguous && tensor_node.children(i).is_contiguous(); + all_tensors_same_dtype = all_tensors_same_dtype && + (first_dtype == tensor_node.children(i).dtype()); + if (!(all_tensors_cpu && all_tensors_contiguous && + all_tensors_same_dtype)) { + break; + } + if (i > 0) { + start_offsets[i] = + start_offsets[i - 1] + tensor_node.children(i - 1).numel(); + } + total_size += tensor_node.children(i).numel(); } - TensorOptions options = flat_tensors[0].options().merge_in(options_); + TensorOptions options; + Tensor nt_buffer, nt_sizes; + if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) { + nt_buffer = at::empty({total_size}, tensor_node.children(0).options()); + nt_sizes = at::empty( + {static_cast(tensor_node.degree()), + static_cast(tensor_node.children(0).sizes().size())}, + TensorOptions().dtype(kLong)); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + at::ScalarType::Half, + at::ScalarType::Bool, + at::ScalarType::BFloat16, + c10::typeMetaToScalarType(first_dtype), + "create_nt_buffer", + [&]() { + at::parallel_for( + 0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // Only try copying memory if there is more than 0 elements + // for a certain tensor + if (tensor_node.children(i).numel() > 0) { + memcpy( + nt_buffer.data_ptr() + start_offsets[i], + tensor_node.children(i).data_ptr(), + tensor_node.children(i).numel() * sizeof(scalar_t)); + } + } + }); + }); + long sizes_offset = 0; + for (size_t i = 0; i < tensor_node.degree(); ++i) { + auto tensor_sizes = tensor_node.children(i).sizes(); + for (size_t j = 0; j < tensor_sizes.size(); ++j) { + nt_sizes.data_ptr()[sizes_offset++] = tensor_sizes[j]; + } + } + options = nt_buffer.options().merge_in(options_); + } else { // Slow path + std::vector flat_tensors; + std::vector sizes; + for (const auto i : c10::irange(tensor_node.degree())) { + flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); + sizes.push_back( + tensor(c10::IntArrayRef(tensor_node.children(i).sizes()))); + } + options = flat_tensors[0].options().merge_in(options_); + nt_buffer = at::cat(flat_tensors); + nt_sizes = at::native::stack(sizes); + } - return wrap_buffer( - at::cat(flat_tensors).to(options), at::native::stack(sizes)); + return wrap_buffer(nt_buffer.to(options), nt_sizes); } } // namespace impl diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu new file mode 100644 index 0000000000000..678e62f5a81c6 --- /dev/null +++ b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu @@ -0,0 +1,120 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + + +#include + +#define BLOCK_DIM 256 + +namespace at { +namespace native { + + +// only for nested [B, *, D], dense [B, 1, D] +template +__global__ void op_dense_esuhm( + const T* input, + const T* dense, + T* output, + int64_t embedding_dim, + const int64_t* offsets, + const func_t& f) +{ + // each batch is handled by a block + const int64_t batch_idx = blockIdx.x; + const int64_t grain_size = blockDim.x; + const int64_t tid = threadIdx.x; + const int64_t range = offsets[batch_idx + 1] - offsets[batch_idx]; + // each thread handles (embedding_dim // grain_size + (embedding_dim % grain_size <= tid)) elems + // of the dense embedding + for (int64_t idx = tid; idx < embedding_dim; idx += grain_size) { + const T dense_elem = dense[batch_idx * embedding_dim + idx]; + for (int64_t nested_idx = idx; nested_idx < range; nested_idx += embedding_dim) { + output[offsets[batch_idx] + nested_idx] = f(input[offsets[batch_idx] + nested_idx], dense_elem); + } + } +} + +template +void nested_op_dense_kernelLauncher( + const T* input, // [sum(*) x embedding_dim] + const T* dense, // [batch_size x embedding_dim] + T* output, // [sum(*) x embedding_dim] + int64_t batch_size, + int64_t embedding_dim, + const int64_t* input_offsets, // [batch_size] + func_t f) +{ + dim3 grid; + grid.x = batch_size; + const auto stream = at::cuda::getDefaultCUDAStream(); + + op_dense_esuhm<<>>( + input, + dense, + output, + embedding_dim, + input_offsets, + f); +} + +template +void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Tensor& other, func_t f) { + auto self_ptr = get_nested_tensor_impl(self); + auto result_ptr = get_nested_tensor_impl(result); + + const auto self_buffer = self_ptr->get_buffer(); + const auto offsets = self_ptr->get_storage_offsets(); + const auto batch_size = other.size(0); + const auto embedding_size = other.size(2); + + auto result_buffer = result_ptr->get_buffer(); + auto result_offsets = at::cat({at::tensor(offsets), at::tensor(self_ptr->numel())}); + result_offsets = result_offsets.to(kCUDA); + + const scalar_t* self_data_ptr = self_buffer.data_ptr(); + const scalar_t* other_data_ptr = other.data_ptr(); + scalar_t* result_data_ptr = result_buffer.data_ptr(); + int64_t* result_offsets_ptr = result_offsets.data_ptr(); + + nested_op_dense_kernelLauncher( + self_data_ptr, + other_data_ptr, + result_data_ptr, + batch_size, + embedding_size, + result_offsets_ptr, + f); +} + +void _nested_op_dense_esuhm_cuda(Tensor& result, const Tensor& self, const Tensor& other, const NESTED_DENSE_OP& op) { + AT_DISPATCH_ALL_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "_nested_op_dense_esuhm", [&]() { + switch (op) { + case NESTED_DENSE_OP::ADD : + _nested_op_dense_esuhm_kernel(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a + b; }); + break; + case NESTED_DENSE_OP::MUL : + _nested_op_dense_esuhm_kernel(result, self, other, [] __host__ __device__ (scalar_t a, scalar_t b) -> scalar_t { return a * b; }); + break; + } + }); +} + +REGISTER_CUDA_DISPATCH(nested_dense_elementwise_stub, &_nested_op_dense_esuhm_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu b/aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu new file mode 100644 index 0000000000000..22cf38f850208 --- /dev/null +++ b/aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu @@ -0,0 +1,416 @@ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#ifndef USE_ROCM +#ifndef _WIN32 +#include +#include +#include +#endif +#endif + +#include + +#define BLOCK_DIM 256 +#define GRID_DIM_Y 16 + +namespace at { +namespace native { + +#ifndef USE_ROCM +#ifndef _WIN32 +namespace { + +template < + typename scalar_t, + unsigned int kPad, + typename LayoutA, + typename LayoutB, + typename OpClass, + typename Arch, + typename ThreadBlockShape, + typename WarpShape, + typename InstructionShape> +void gemm_grouped_cuda_internal( + const std::vector& lda, + const std::vector& ldb, + const std::vector& ldd, + const std::vector& aptr, + const std::vector& bptr, + const std::vector& dptr, + const std::vector& gemm_sizes, + const int problem_count, + at::Device& device) { + using Element = scalar_t; + using ElementAcc = float; + + using GemmConfiguration = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + Arch, + Element, + Element, + Element, + ElementAcc>; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + Element, + LayoutA, + cutlass::ComplexTransform::kNone, + kPad, + Element, + LayoutB, + cutlass::ComplexTransform::kNone, + kPad, + Element, + cutlass::layout::RowMajor, + ElementAcc, + OpClass, + Arch, + ThreadBlockShape, + WarpShape, + InstructionShape, + typename GemmConfiguration::EpilogueOutputOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + GemmConfiguration::kStages>::GemmKernel; + + using GemmGrouped = typename cutlass::gemm::device::GemmGrouped; + using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0); + + const int64_t gemm_coord_size = + problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord)); + // Number of gmm args not including *problem_sizes + at::Tensor gmm_args = at::empty( + {problem_count * 6 + gemm_coord_size}, + at::TensorOptions().dtype(at::kLong).pinned_memory(true)); + + // Obtain pointers for each argument (on host) + int64_t* lda_data = gmm_args.data_ptr(); // Base pointer + int64_t* ldb_data = lda_data + problem_count; + int64_t* ldd_data = lda_data + 2 * problem_count; + int64_t* ptr_a_data = lda_data + 3 * problem_count; + int64_t* ptr_b_data = lda_data + 4 * problem_count; + int64_t* ptr_d_data = lda_data + 5 * problem_count; + cutlass::gemm::GemmCoord* problem_sizes_data = + reinterpret_cast(lda_data + 6 * problem_count); + + // Set arguments into gmm_args from input args + for (int i = 0; i < problem_count; ++i) { + problem_sizes_data[i] = gemm_sizes[i]; + lda_data[i] = lda[i]; + ldb_data[i] = ldb[i]; + ldd_data[i] = ldd[i]; + ptr_a_data[i] = reinterpret_cast(aptr[i]); + ptr_b_data[i] = reinterpret_cast(bptr[i]); + ptr_d_data[i] = reinterpret_cast(dptr[i]); + } + const int threadblock_count = + GemmGrouped::sufficient(problem_sizes_data, problem_count); + + // Transfer arguments to GPU + gmm_args = gmm_args.to(device, true); + + // Obtain pointers for each of arguments (on GPU) + lda_data = gmm_args.data_ptr(); // Base pointer + ldb_data = lda_data + problem_count; + ldd_data = lda_data + 2 * problem_count; + ptr_a_data = lda_data + 3 * problem_count; + ptr_b_data = lda_data + 4 * problem_count; + ptr_d_data = lda_data + 5 * problem_count; + problem_sizes_data = + reinterpret_cast(lda_data + 6 * problem_count); + + // Create GemmGrouped::Arguments using the arguments prepared above + typename GemmGrouped::Arguments args( + problem_sizes_data, + problem_count, + threadblock_count, + epilogue_op, + reinterpret_cast(ptr_a_data), + reinterpret_cast(ptr_b_data), + reinterpret_cast(ptr_d_data), + reinterpret_cast(ptr_d_data), + lda_data, + ldb_data, + ldd_data, + ldd_data); + + GemmGrouped gemm; + cutlass::Status status = + gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream()); + TORCH_CHECK( + status != cutlass::Status::kErrorWorkspaceNull, + "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace."); + TORCH_CHECK( + status != cutlass::Status::kErrorInternal, + "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error."); + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "Failed to initialize CUTLASS Grouped GEMM kernel."); + + // Run CUTLASS group GEMM + status = gemm.run(at::cuda::getCurrentCUDAStream()); + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "Failed to run CUTLASS Grouped GEMM kernel."); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +bool group_gemm_dispatch( + at::Device device, + const std::vector& aptr, + const std::vector& bptr, + const std::vector& dptr, + const std::vector& lda, + const std::vector& ldb, + const std::vector& ldd, + std::vector gemm_sizes, + int64_t ntensors) { + return false; +} + +template <> +bool group_gemm_dispatch( + at::Device device, + const std::vector& aptr, + const std::vector& bptr, + const std::vector& dptr, + const std::vector& lda, + const std::vector& ldb, + const std::vector& ldd, + std::vector gemm_sizes, + int64_t ntensors) { + + gemm_grouped_cuda_internal< + float, + 1, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>>( + lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device); + return true; +} + +template <> +bool group_gemm_dispatch( + at::Device device, + const std::vector& aptr_, + const std::vector& bptr_, + const std::vector& dptr_, + const std::vector& lda, + const std::vector& ldb, + const std::vector& ldd, + std::vector gemm_sizes, + int64_t ntensors) { + + // Check alignment + bool all_pad_8 = true; + for (int i = 0; i < ntensors; i++) { + all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0); + all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0); + + // Not sure if this is a requirement, on the safe side + all_pad_8 = all_pad_8 && (lda[i] % 8 == 0); + all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0); + all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0); + } + + std::vector aptr; + std::vector bptr; + std::vector dptr; + for (int64_t i = 0; i < ntensors; i++) { + aptr.push_back(reinterpret_cast(aptr_[i])); + bptr.push_back(reinterpret_cast(bptr_[i])); + dptr.push_back(reinterpret_cast(dptr_[i])); + } + if (all_pad_8) { + gemm_grouped_cuda_internal< + cutlass::half_t, + 8, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>>( + lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device); + return true; + } else { + gemm_grouped_cuda_internal< + cutlass::half_t, + 1, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>>( + lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device); + return true; + } + // Did not perform GEMM + return false; +} + +} // namespace + +#endif +#endif + +Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) { + if (self.is_nested() && !mat2.is_nested()) { + AT_ERROR( + "Expected both to be nested, but got a nested self and non-nested other"); + } else if (!self.is_nested() && mat2.is_nested()) { + AT_ERROR( + "Expected both to be nested, but got a non-nested self and nested other"); + } + // dispatcher should have guaranteed that at least one is nested + auto self_ptr = get_nested_tensor_impl(self); + auto mat2_ptr = get_nested_tensor_impl(mat2); + TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor"); + TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor"); + int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0); + TORCH_CHECK( + ntensors == ntensors2, + "Expected size for the 1st dimension of batch2 tensor to be: ", + ntensors, + " but got: ", + ntensors2, + "."); + + // create a contiguous output + const Tensor& self_sizemat = self_ptr->get_nested_size_tensor(); + Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes()); + int64_t* out_sizemat_ptr = out_sizemat.data_ptr(); + + std::vector self_sizes = NestedTensor_get_sizes(self_ptr); + std::vector mat2_sizes = NestedTensor_get_sizes(mat2_ptr); + + int64_t out_numel = 0; + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i]; + const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1], + &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1]; + TORCH_CHECK( + self_size1 == mat2_size0, + i, + "-th nested matrices in batch cannot be multiplied (", + self_size0, + "x", + self_size1, + " and ", + mat2_size0, + "x", + mat2_size1, + ")"); + out_sizemat_ptr[0] = self_size0; + out_sizemat_ptr[1] = mat2_size1; + out_sizemat_ptr += 2; + out_numel += self_size0 * mat2_size1; + } + const Tensor &self_buffer = self_ptr->get_unsafe_storage_as_tensor(); + const Tensor &mat2_buffer = mat2_ptr->get_unsafe_storage_as_tensor(); + Tensor out_buffer = self_buffer.new_empty(out_numel); + Tensor output = wrap_buffer(out_buffer, out_sizemat); + auto out_ptr = get_nested_tensor_impl(output); + + std::vector self_strides = NestedTensor_get_strides(self_ptr); + std::vector mat2_strides = NestedTensor_get_strides(mat2_ptr); + const std::vector& self_offsets = self_ptr->get_storage_offsets(); + const std::vector& mat2_offsets = mat2_ptr->get_storage_offsets(); + const std::vector& out_offsets = out_ptr->get_storage_offsets(); + +#ifndef USE_ROCM +#ifndef _WIN32 + bool success = false; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + self.scalar_type(), "group_gemm_dispatch", [&] { + std::vector aptr(ntensors); + std::vector bptr(ntensors); + std::vector dptr(ntensors); + std::vector lda(ntensors); + std::vector ldb(ntensors); + std::vector ldd(ntensors); + std::vector gemm_sizes; + bool all_row_major = true; + for (int64_t i = 0; i < ntensors; i++) { + const IntArrayRef& self_shape = self_sizes[i]; + const IntArrayRef& mat2_shape = mat2_sizes[i]; + const int64_t &self_size0 = self_shape[0]; + const int64_t &self_size1 = self_shape[1]; + const int64_t &mat2_size0 = mat2_shape[0]; + const int64_t &mat2_size1 = mat2_shape[1]; + gemm_sizes.push_back( + cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1)); + aptr[i] = self_buffer.data_ptr() + self_offsets[i]; + bptr[i] = mat2_buffer.data_ptr() + mat2_offsets[i]; + dptr[i] = out_buffer.data_ptr() + out_offsets[i]; + all_row_major = all_row_major && (self_strides[i][1] == 1); + all_row_major = all_row_major && (mat2_strides[i][1] == 1); + lda[i] = self_strides[i][0]; + ldb[i] = mat2_strides[i][0]; + ldd[i] = mat2_size1; + } + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + if (all_row_major && + self.is_contiguous() && + mat2.is_contiguous() && + is_sm8x) { + success = group_gemm_dispatch( + output.device(), + aptr, + bptr, + dptr, + lda, + ldb, + ldd, + gemm_sizes, + ntensors); + } + }); + if (success) { + return output; + } +#endif +#endif + + std::vector output_unbind = output.unbind(); + for (int64_t i = 0; i < ntensors; i++) { + at::mm_out( + output_unbind[i], + self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]), + mat2_buffer.as_strided( + mat2_sizes[i], mat2_strides[i], mat2_offsets[i])); + } + return output; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index a90af2fe0af32..9c72454560d38 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -44,16 +44,16 @@ Tensor nested_from_padded_cuda( const Tensor& sizes, bool do_transform_0213) { if (padded.dim() > 1 && padded.dim() < 5) { + // Instead of erroring call the generic version + if(!(padded.dim() == 4 && do_transform_0213) && !(padded.dim() == 3 && !do_transform_0213)){ + return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213); + } if (padded.dtype() != kFloat && padded.dtype() != kHalf) { TORCH_WARN_ONCE( "nested_from_padded CUDA kernels only support fp32/fp16; falling " "back to slower generic kernel"); return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213); } - TORCH_CHECK( - (padded.dim() == 4 && do_transform_0213) || - (padded.dim() == 3 && !do_transform_0213), - "padded tensor size error"); Tensor target_offsets = NestedTensor_batch_offsets_from_size_tensor(sizes, 0); Tensor padded_sizes_tensor = at::tensor(padded.sizes()); @@ -152,8 +152,8 @@ Tensor NestedTensor_to_padded_tensor_cuda( if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) && !(output_size.has_value())) { Tensor nt_sizes = nt_input->get_nested_size_tensor(); - Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1); - Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1); + Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1); + Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1); Tensor result = at::detail::make_tensor( nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2); @@ -214,27 +214,7 @@ Tensor NestedTensor_to_padded_tensor_cuda( return NestedTensor_to_padded_tensor_generic(t, padding, output_size); } -std::tuple _scaled_dot_product_attention_forward_nested( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - - // Determine which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - // TODO: enable flash attention kernel - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention"); - return std::make_tuple(Tensor(), Tensor()); - } -} - +namespace{ /** * This function is used to calculate two pieces of metadata that are needed @@ -242,9 +222,10 @@ std::tuple _scaled_dot_product_attention_forward_nested( * cumulative sequence_length over a batch of sequences and the maximum sequence * length. * - * @return A tuple of cumulative sequence lengths and the maximum sequence length + * @return A tuple of cumulative sequence lengths and the maximum sequence length, + * and the last element in the cumulative_sequence_lengths */ -std::tuple cumulative_and_max_seq_len(Tensor qkv) { +std::tuple cumulative_and_max_seq_len(Tensor qkv) { TORCH_CHECK( qkv.is_nested(), "QKV must be nested for flash cumulative_seq_len calculation.") @@ -274,7 +255,7 @@ std::tuple cumulative_and_max_seq_len(Tensor qkv) { // Send to GPU, this is pretty light weight calc for normal batch size // but maybe this needs to be on gpu cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA)); - return std::tuple{cumulative_seqlen, max_seqlen}; + return std::tuple{cumulative_seqlen, max_seqlen, sum}; } /** @@ -321,14 +302,15 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { const int64_t* tensor_size_ptr = tensor_sizes.data_ptr(); const int64_t* tensor_stride_ptr = tensor_strides.data_ptr(); - int64_t offset_constant = (tensor_offsets[1] - tensor_offsets[0]) / - tensor_size_ptr[0] * tensor_stride_ptr[0]; + int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]); + TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!"); + int64_t offset_constant = (tensor_offsets[1] - tensor_offsets[0]) / numel_0; for (int64_t i = 2; i < n_tensors; i++) { - int64_t current_offset_constant = - (tensor_offsets[i] - tensor_offsets[i - 1]) / - tensor_size_ptr[(i - 1) * tensor_stride_0] * - tensor_stride_ptr[(i - 1) * tensor_stride_0]; + // TODO: When 0 seq_len nested tensors are allowed we need to guard against this + int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * tensor_stride_ptr[(i - 1) * tensor_stride_0]; + TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!"); + int64_t current_offset_constant = (tensor_offsets[i] - tensor_offsets[i - 1]) / previous_numel; if (current_offset_constant != offset_constant) { return false; } @@ -337,37 +319,99 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { return true; } -std::tuple mem_efficient_helper_nested_unpacked( +} // namespace + +std::tuple _scaled_dot_product_flash_attention_nestedtensor_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { + TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.") // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) const int64_t num_heads = query.size(1); const int64_t head_dim = query.size(3); - Tensor q_t = query.transpose(1, 2); - Tensor k_t = key.transpose(1, 2); - Tensor v_t = value.transpose(1, 2); - - auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t); - auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t); + // Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head) + // Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + // Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + Tensor q_t = query.transpose(1, 2).contiguous(); + Tensor k_t = key.transpose(1, 2).contiguous(); + Tensor v_t = value.transpose(1, 2).contiguous(); // K and V have to have the same Nnz, should probably torch_check // assume in order to not iterate over v + auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t); + auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t); + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k); const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q); + const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k); - const int64_t Nnz_q = cumulative_sequence_length_q[-1].item(); + const int64_t Nnz_q = cumulative_sequence_length_q[-1].item(); const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item(); + auto query_buffer_reshaped = + get_buffer(q_t).view({Nnz_q, num_heads, head_dim}); + auto key_buffer_reshaped = + get_buffer(k_t).view({Nnz_kv, num_heads, head_dim}); + auto value_buffer_reshaped = + get_buffer(v_t).view({Nnz_kv, num_heads, head_dim}); + + auto attention_and_lse_and_softmax = + at::_flash_attention_forward( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + return_softmax, + dropout_p, + is_causal); + // Reshape output to convert nnz to batch_size and seq_len + Tensor attention = std::get<0>(attention_and_lse_and_softmax); + attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2); + return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax)); +} + +std::tuple _scaled_dot_product_efficient_attention_nestedtensor_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { + // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) + // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + const int64_t num_heads = query.size(1); + const int64_t head_dim = query.size(3); + + Tensor q_t = query.transpose(1, 2); + Tensor k_t = key.transpose(1, 2); + Tensor v_t = value.transpose(1, 2); + + auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len(q_t); + auto cumulative_and_max_k_and_nnz_k = cumulative_and_max_seq_len(k_t); + + // K and V have to have the same Nnz, should probably torch_check + // assume in order to not iterate over v + + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q_and_nnz_q); + Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k_and_nnz_k); + + const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q); + + const int64_t Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q); + const int64_t Nnz_kv = std::get<2>(cumulative_and_max_k_and_nnz_k); + Tensor query_buffer_reshaped; Tensor key_buffer_reshaped; Tensor value_buffer_reshaped; @@ -429,8 +473,7 @@ std::tuple mem_efficient_helper_nested_unpacked( {Nnz_kv, num_heads, head_dim}, {nnz_v_stride, head_v_stride, head_dim_stride}, value_impl->get_storage_offsets()[0]); - - std::tuple attention_and_weights = + std::tuple attention_and_logsumexp= at::_efficient_attention_forward( query_buffer_reshaped.unsqueeze(0), key_buffer_reshaped.unsqueeze(0), @@ -438,14 +481,14 @@ std::tuple mem_efficient_helper_nested_unpacked( cumulative_sequence_length_q, cumulative_sequence_length_k, max_seqlen_batch_q, - false, - false); + compute_log_sumexp, + is_causal); // Reshape output to convert nnz to batch_size and seq_len - Tensor attention = std::get<0>(attention_and_weights); + Tensor attention = std::get<0>(attention_and_logsumexp); attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()) .transpose(1, 2); - return std::tie(attention, std::get<1>(attention_and_weights)); + return std::tie(attention, std::get<1>(attention_and_logsumexp)); } Tensor flash_attention_helper( @@ -460,15 +503,15 @@ Tensor flash_attention_helper( int64_t head_dim{query.size(-1)}; int64_t num_heads{query.size(-2)}; - auto cumulative_and_max_q = cumulative_and_max_seq_len(query); - Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); - int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q); + auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len(query); + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q_and_nnz_q); + int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q); TORCH_CHECK( key.is_same(key) && query.is_same(value), "Key and Value must be the same tensor"); - int64_t Nnz_q{cumulative_sequence_length_q[-1].item()}; + int64_t Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q); // For the packed case we need to set the output size for dim 2 to 1 auto atten_size = get_nested_size_tensor(query).clone(); @@ -490,7 +533,7 @@ Tensor flash_attention_helper( // If we are passing in query, key, value all the same tensors then we have // packed them into one tensor and need to slice for flash attention Tensor attention = - at::_flash_scaled_dot_product_attention( + std::get<0>(at::_flash_attention_forward( q, k, v, @@ -498,8 +541,9 @@ Tensor flash_attention_helper( cumulative_sequence_length_q, max_seqlen_batch_q, max_seqlen_batch_q, + false /*return_softmax*/, dropout_p, - is_causal); + is_causal)); // Output of flash_attention is a regular tensor lets wrap it back up to // form a nested tensor diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index fc84d07ba6797..56cac2a898034 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -462,277 +462,5 @@ template void add_padding_kernelLauncher( const int batch_size, const int output_batch_size); -namespace { - -#ifndef USE_ROCM -#ifndef _WIN32 -template -void gemm_grouped_cuda_internal( - const std::vector& lda, - const std::vector& ldb, - const std::vector& ldd, - const std::vector& aptr, - const std::vector& bptr, - const std::vector& dptr, - const std::vector& gemm_sizes, - const int problem_count, - at::Device& device) { - using Element = scalar_t; - using ElementAcc = float; - using OpClass = cutlass::arch::OpClassSimt; - - using GemmConfiguration = - typename cutlass::gemm::device::DefaultGemmConfiguration< - OpClass, - cutlass::arch::Sm80, - Element, - Element, - Element, - ElementAcc>; - - using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - Element, - cutlass::layout::RowMajor, - cutlass::ComplexTransform::kNone, - GemmConfiguration::kAlignmentA, - Element, - cutlass::layout::RowMajor, - cutlass::ComplexTransform::kNone, - GemmConfiguration::kAlignmentB, - Element, - cutlass::layout::RowMajor, - ElementAcc, - OpClass, - cutlass::arch::Sm80, - typename GemmConfiguration::ThreadblockShape, - typename GemmConfiguration::WarpShape, - typename GemmConfiguration::InstructionShape, - typename GemmConfiguration::EpilogueOutputOp, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, - GemmConfiguration::kStages>::GemmKernel; - - using GemmGrouped = typename cutlass::gemm::device::GemmGrouped; - using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp; - typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0); - - const int64_t gemm_coord_size = - problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord)); - // Number of gmm args not including *problem_sizes - at::Tensor gmm_args = at::empty( - {problem_count * 6 + gemm_coord_size}, - at::TensorOptions().dtype(at::kLong).pinned_memory(true)); - - // Obtain pointers for each argument (on host) - int64_t* lda_data = gmm_args.data_ptr(); // Base pointer - int64_t* ldb_data = lda_data + problem_count; - int64_t* ldd_data = lda_data + 2 * problem_count; - int64_t* ptr_a_data = lda_data + 3 * problem_count; - int64_t* ptr_b_data = lda_data + 4 * problem_count; - int64_t* ptr_d_data = lda_data + 5 * problem_count; - cutlass::gemm::GemmCoord* problem_sizes_data = - reinterpret_cast(lda_data + 6 * problem_count); - - // Set arguments into gmm_args from input args - for (int i = 0; i < problem_count; ++i) { - problem_sizes_data[i] = gemm_sizes[i]; - lda_data[i] = lda[i]; - ldb_data[i] = ldb[i]; - ldd_data[i] = ldd[i]; - ptr_a_data[i] = reinterpret_cast(aptr[i]); - ptr_b_data[i] = reinterpret_cast(bptr[i]); - ptr_d_data[i] = reinterpret_cast(dptr[i]); - } - const int threadblock_count = - GemmGrouped::sufficient(problem_sizes_data, problem_count); - - // Transfer arguments to GPU - gmm_args = gmm_args.to(device, true); - - // Obtain pointers for each of arguments (on GPU) - lda_data = gmm_args.data_ptr(); // Base pointer - ldb_data = lda_data + problem_count; - ldd_data = lda_data + 2 * problem_count; - ptr_a_data = lda_data + 3 * problem_count; - ptr_b_data = lda_data + 4 * problem_count; - ptr_d_data = lda_data + 5 * problem_count; - problem_sizes_data = - reinterpret_cast(lda_data + 6 * problem_count); - - // Create GemmGrouped::Arguments using the arguments prepared above - typename GemmGrouped::Arguments args( - problem_sizes_data, - problem_count, - threadblock_count, - epilogue_op, - reinterpret_cast(ptr_a_data), - reinterpret_cast(ptr_b_data), - reinterpret_cast(ptr_d_data), - reinterpret_cast(ptr_d_data), - lda_data, - ldb_data, - ldd_data, - ldd_data); - - GemmGrouped gemm; - cutlass::Status status = - gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream()); - TORCH_CHECK( - status != cutlass::Status::kErrorWorkspaceNull, - "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace."); - TORCH_CHECK( - status != cutlass::Status::kErrorInternal, - "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error."); - TORCH_CHECK( - status == cutlass::Status::kSuccess, - "Failed to initialize CUTLASS Grouped GEMM kernel."); - - // Run CUTLASS group GEMM - status = gemm.run(at::cuda::getCurrentCUDAStream()); - TORCH_CHECK( - status == cutlass::Status::kSuccess, - "Failed to run CUTLASS Grouped GEMM kernel."); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} -#endif -#endif - -} // namespace - -Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) { - if (self.is_nested() && !mat2.is_nested()) { - AT_ERROR( - "Expected both to be nested, but got a nested self and non-nested other"); - } else if (!self.is_nested() && mat2.is_nested()) { - AT_ERROR( - "Expected both to be nested, but got a non-nested self and nested other"); - } - // dispatcher should have guaranteed that at least one is nested - auto self_ptr = get_nested_tensor_impl(self); - auto mat2_ptr = get_nested_tensor_impl(mat2); - TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor"); - TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor"); - int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0); - TORCH_CHECK( - ntensors == ntensors2, - "Expected size for the 1st dimension of batch2 tensor to be: ", - ntensors, - " but got: ", - ntensors2, - "."); - const Tensor &self_buffer = self_ptr->get_buffer(), - &mat2_buffer = mat2_ptr->get_buffer(); - std::vector self_sizes = NestedTensor_get_sizes(self_ptr), - mat2_sizes = NestedTensor_get_sizes(mat2_ptr), - self_strides = NestedTensor_get_strides(self_ptr), - mat2_strides = NestedTensor_get_strides(mat2_ptr); - const std::vector& self_offsets = self_ptr->get_storage_offsets(); - const std::vector& mat2_offsets = mat2_ptr->get_storage_offsets(); - - // create a contiguous output - int64_t out_numel = 0; - int64_t a_numel = 0; - int64_t b_numel = 0; - const Tensor& self_sizemat = self_ptr->get_nested_size_tensor(); - Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes()); - int64_t* out_sizemat_ptr = out_sizemat.data_ptr(); - std::vector output_offsets; - std::vector a_offsets; - std::vector b_offsets; - std::vector lda; - std::vector ldb; - std::vector ldd; -#ifndef USE_ROCM -#ifndef _WIN32 - std::vector gemm_sizes; -#endif -#endif - bool all_row_major = true; - for (int64_t i = 0; i < ntensors; i++) { - const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i]; - const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1], - &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1]; - TORCH_CHECK( - self_size1 == mat2_size0, - i, - "-th nested matrices in batch cannot be multiplied (", - self_size0, - "x", - self_size1, - " and ", - mat2_size0, - "x", - mat2_size1, - ")"); - out_sizemat_ptr[0] = self_size0; - out_sizemat_ptr[1] = mat2_size1; - out_sizemat_ptr += 2; - output_offsets.push_back(out_numel); - out_numel += self_size0 * mat2_size1; -#ifndef USE_ROCM -#ifndef _WIN32 - gemm_sizes.push_back( - cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1)); -#endif -#endif - lda.push_back(self_strides[i][0]); - ldb.push_back(mat2_strides[i][0]); - ldd.push_back(mat2_size1); - a_offsets.push_back(a_numel); - b_offsets.push_back(b_numel); - a_numel += self_size0 * self_strides[i][0]; - b_numel += mat2_size0 * mat2_strides[i][0]; - all_row_major = all_row_major && (self_strides[i][1] == 1); - all_row_major = all_row_major && (mat2_strides[i][1] == 1); - } - Tensor out_buffer = self_buffer.new_empty(out_numel); - Tensor output = wrap_buffer(out_buffer, out_sizemat); - at::Device device = output.device(); - -#ifndef USE_ROCM -#ifndef _WIN32 - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - if (is_sm8x && all_row_major) { - if (self.dtype() == at::kFloat) { - std::vector aptr; - std::vector bptr; - std::vector dptr; - for (int64_t i = 0; i < ntensors; i++) { - aptr.push_back(self_buffer.data_ptr() + a_offsets[i]); - bptr.push_back(mat2_buffer.data_ptr() + b_offsets[i]); - dptr.push_back(out_buffer.data_ptr() + output_offsets[i]); - } - gemm_grouped_cuda_internal( - lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device); - return output; - } - if (self.dtype() == at::kHalf) { - std::vector aptr; - std::vector bptr; - std::vector dptr; - for (int64_t i = 0; i < ntensors; i++) { - aptr.push_back(self_buffer.data_ptr() + a_offsets[i]); - bptr.push_back(mat2_buffer.data_ptr() + b_offsets[i]); - dptr.push_back(out_buffer.data_ptr() + output_offsets[i]); - } - gemm_grouped_cuda_internal( - lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device); - return output; - } - } -#endif -#endif - std::vector output_unbind = output.unbind(); - for (int64_t i = 0; i < ntensors; i++) { - at::mm_out( - output_unbind[i], - self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]), - mat2_buffer.as_strided( - mat2_sizes[i], mat2_strides[i], mat2_offsets[i])); - } - return output; -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/prim_native_functions.cpp b/aten/src/ATen/native/prim_native_functions.cpp index 8f82345c19058..4e79c112d7fc6 100644 --- a/aten/src/ATen/native/prim_native_functions.cpp +++ b/aten/src/ATen/native/prim_native_functions.cpp @@ -1,4 +1,11 @@ -#include +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/quantized/AffineQuantizer.cpp b/aten/src/ATen/native/quantized/AffineQuantizer.cpp index e2fa8f65adc60..dbda6ebd5f902 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizer.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizer.cpp @@ -97,6 +97,21 @@ void checkSameSize( " only works with Tensors with the same shape"); } +void checkPerChannelParamsSize( + const Tensor& rtensor, + int64_t axis, + const Tensor& scales, + const Tensor& zero_points +) { + int64_t channel = rtensor.size(axis); + TORCH_CHECK( + channel == int64_t(scales.numel()), + "length of scales must equal to channel, expected ", channel, " got, ", scales.numel()); + TORCH_CHECK( + channel == int64_t(zero_points.numel()), + "length of zero_points must equal to channel expected ", channel, " got, ", zero_points.numel()); +} + } // anonymous namespace Tensor& quantize_tensor_per_tensor_affine( @@ -156,13 +171,7 @@ Tensor& quantize_tensor_per_channel_affine( "Expected: [0, ", rtensor.dim(), ")"); - int64_t channel = rtensor.size(axis); - TORCH_CHECK( - channel == int64_t(scales.numel()), - "length of scales must equal to channel"); - TORCH_CHECK( - channel == int64_t(zero_points.numel()), - "length of zero_points must equal to channel"); + checkPerChannelParamsSize(rtensor, axis, scales, zero_points); quantize_tensor_per_channel_affine_stub( rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis); @@ -195,13 +204,7 @@ Tensor& quantize_tensor_per_channel_float_qparams( "Expected: [0, ", rtensor.dim(), ")"); - int64_t channel = rtensor.size(axis); - TORCH_CHECK( - channel == int64_t(scales.numel()), - "length of scales must equal to channel"); - TORCH_CHECK( - channel == int64_t(zero_points.numel()), - "length of zero_points must equal to channel"); + checkPerChannelParamsSize(rtensor, axis, scales, zero_points); quantize_tensor_per_channel_float_qparams_stub( rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis); @@ -260,13 +263,7 @@ Tensor& dequantize_tensor_per_channel_affine( " Expected: [0, ", qtensor.dim(), ")"); - int64_t channel = qtensor.size(axis); - TORCH_CHECK( - channel == int64_t(scales.numel()), - "length of scales must equal to channel"); - TORCH_CHECK( - channel == int64_t(zero_points.numel()), - "length of zero_points must equal to channel"); + checkPerChannelParamsSize(rtensor, axis, scales, zero_points); dequantize_tensor_per_channel_affine_stub( qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis); @@ -297,13 +294,7 @@ Tensor& dequantize_tensor_per_channel_float_qparams( " Expected: [0, ", qtensor.dim(), ")"); - int64_t channel = qtensor.size(axis); - TORCH_CHECK( - channel == int64_t(scales.numel()), - "length of scales must equal to channel"); - TORCH_CHECK( - channel == int64_t(zero_points.numel()), - "length of zero_points must equal to channel"); + checkPerChannelParamsSize(rtensor, axis, scales, zero_points); dequantize_tensor_per_channel_float_qparams_stub( qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis); diff --git a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp index 700b3b14b180c..aac039f0e03ef 100644 --- a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp +++ b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp @@ -122,10 +122,10 @@ Tensor fake_quantize_per_tensor_affine_cachemask_backward( const Tensor& dY, const Tensor& mask) { TORCH_CHECK(mask.scalar_type() == ScalarType::Bool); - TORCH_CHECK(mask.numel() == dY.numel(), + TORCH_CHECK(mask.sym_numel() == dY.sym_numel(), "`mask` and `dY` are not the same size: ", - "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel()); - if (dY.numel() <= 0) { + "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel()); + if (dY.sym_numel() <= 0) { return dY; } // Note: no additional kernels needed, since mask is pre-computed diff --git a/aten/src/ATen/native/quantized/PackedParams.h b/aten/src/ATen/native/quantized/PackedParams.h index 179fcce23dfe5..a442628573fec 100644 --- a/aten/src/ATen/native/quantized/PackedParams.h +++ b/aten/src/ATen/native/quantized/PackedParams.h @@ -36,6 +36,55 @@ struct LinearPackedParamsBase : public torch::jit::CustomClassHolder { return output; } + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): + // input -> q* -> dq* -> linear* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): + // input -> q* -> dq* -> linear* -> relu* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // input: float32 Tensor, will be quantized to quint8 in the op + // Returns: + // float32 Tensor + virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + throw std::runtime_error( + "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed " + "parameter type"); + return {}; + } + virtual at::Tensor apply_dynamic( at::Tensor input, bool reduce_range = false) = 0; diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index 5a9bbfb387e43..b3ff8bd8b3274 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -330,6 +330,10 @@ std::tuple choose_qparams_optimized( const double ratio, int64_t bit_width) { + if (numel < 0 || numel > input_tensor.numel()) { + TORCH_CHECK(false, "numel is out of the bound of input tensor"); + } + TORCH_CHECK(numel <= input_tensor.numel(), "numel ", numel, " greater than input_tensor.numel() ", input_tensor.numel()); const float* input_row = input_tensor.data_ptr(); diff --git a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp index 8444f9ca615be..58a7036bdd7e2 100644 --- a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp @@ -36,10 +36,10 @@ namespace { inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK( qa.qscheme() == kPerTensorAffine, - "Only per tensor quantization is suported in Add."); + "Only per tensor quantization is supported in Add."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), - "Both inputs to Add must have the same quantization shceme."); + "Both inputs to Add must have the same quantization scheme."); TORCH_CHECK( qa.scalar_type() == qb.scalar_type(), "Add operands should have same data type."); diff --git a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h index 533d83361f05d..85eaf93ac4bc2 100644 --- a/aten/src/ATen/native/quantized/cpu/OnednnUtils.h +++ b/aten/src/ATen/native/quantized/cpu/OnednnUtils.h @@ -167,6 +167,12 @@ struct DeconvPrimitiveCache : PrimitiveCache { } }; +enum PostOps { + NoPostOp, + Relu, + LeakyRelu, +}; + struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { PackedLinearWeightsOnednn( std::unique_ptr weight, @@ -196,6 +202,12 @@ struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override; at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override; + at::Tensor apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope); + std::tuple> unpack() override; c10::optional bias() override { @@ -210,11 +222,12 @@ struct PackedLinearWeightsOnednn : public LinearPackedParamsBase { LinearPrimitiveCache prim_cache; std::unique_ptr cache_initialized_flag; - template + template at::Tensor apply_impl( at::Tensor input, double output_scale, - int64_t output_zero_point); + int64_t output_zero_point, + torch::List post_op_args = torch::List()); template at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false); diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index 40dfad16e9c52..9c6c721657cb1 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -272,8 +272,9 @@ struct PackedConvWeightsQnnp : public ConvPackedParamsBase { void* zero_buffer = malloc(zero_size); if (zero_buffer == nullptr) { pytorch_qnnp_delete_operator(convolution); - pytorch_qnnp_log_error( - "failed to allocate %zu bytes for zero padding", zero_size); + TORCH_INTERNAL_ASSERT( + false, "failed to allocate %zu bytes for zero padding", + zero_size); } // Need to set to input zero point // memset(zero_buffer, input_zero_point, zero_size); diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp index 871f700ef4fb1..4b4c63eb7c3d3 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp @@ -238,27 +238,5 @@ Tensor _upsample_nearest_exact3d_quantized_cpu( input, osize, scale_d, scale_h, scale_w); } -Tensor upsample_nearest3d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return upsample_nearest3d_quantized_cpu(input, osize, scale_d, scale_h, scale_w); -} - -Tensor _upsample_nearest_exact3d_quantized_cpu( - const Tensor& input, - at::OptionalIntArrayRef output_size, - c10::optional> scale_factors) { - auto osize = compute_output_size(input.sizes(), output_size, scale_factors); - auto scale_d = get_scale_value(scale_factors, 0); - auto scale_h = get_scale_value(scale_factors, 1); - auto scale_w = get_scale_value(scale_factors, 2); - return _upsample_nearest_exact3d_quantized_cpu(input, osize, scale_d, scale_h, scale_w); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 658e2c48481e6..8af21bbc7df8b 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -560,6 +560,7 @@ int register_embedding_params() { return PackedEmbeddingBagWeight::prepack(weight); }) .def("bit_rate", &EmbeddingPackedParamsBase::bit_rate) + .def("unpack", &EmbeddingPackedParamsBase::unpack) .def("version", &EmbeddingPackedParamsBase::version); return 0; diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index d43409231ab69..bfaf5b93d667b 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -45,6 +45,7 @@ struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { at::Tensor input, double output_scale, int64_t output_zero_point) override; + at::Tensor apply_relu( at::Tensor input, double output_scale, @@ -62,8 +63,19 @@ struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { int64_t output_zero_point, at::Tensor& output) override; + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + + at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) override; + at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override; + at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override; @@ -85,6 +97,12 @@ struct TORCH_API PackedLinearWeight : public LinearPackedParamsBase { int64_t output_zero_point, at::Tensor& output); + template + at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32_impl( + const at::Tensor& input, + double input_scale, + int64_t input_zero_point); + template at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false); }; diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index a286e01e28625..a1f8f0d7c2457 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -119,7 +119,7 @@ Tensor qcat_nhwc_kernel( c10::nullopt); // N, H, and W are explicitly captured here because there's a bug in GCC5 - // which causes an internal compiler error if they're not + // and clang5 which causes an internal compiler error if they're not AT_DISPATCH_QINT_TYPES(output.scalar_type(), "qcat_nhwc", [&, N, H, W]() { using Vec = Vectorized; at::parallel_for(0, N * H * W, 0, [&](int64_t begin, int64_t end) { diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 2cd7cd81b9034..31945234f2a9a 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -130,7 +130,7 @@ at::SmallVector MakeDeConvOutputShape( ", output padding: ", output_padding[idx], ", dilation: ", dilation[idx]) TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim, - "Output dimension is beyound reasonable maximum for ", idx, + "Output dimension is beyond reasonable maximum for ", idx, " axis;" " kernel: ", kernel[idx], ", stride: ", stride[idx], @@ -1329,14 +1329,19 @@ at::Tensor PackedConvWeightsOnednn::apply_impl( ideep::convolution_forward::compute( pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups()); } else { - ideep::convolution_forward::compute_v2( - src, weights, b, dst_dims, dst, + src.set_zero_point(src_zero_points); + dst.set_zero_point(dst_zero_points); + ConvParams params; + ideep::convolution_forward::prepare( + params, src, weights, b, dst_dims, dst, strides, dilates, padding_l, padding_r, groups(), src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale), - src_zero_points, dst_zero_points, op_attr, - dnnl::algorithm::convolution_direct, + op_attr, dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference, ideep::u8s8, ideep::engine::cpu_engine()); + onednn_utils::try_reorder( + weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales); + ideep::convolution_forward::compute(params, src, weights, b, dst); } } return output; diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 2250e84ad7a6e..9d2f1a96c31ba 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -1,4 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include @@ -444,7 +445,7 @@ c10::intrusive_ptr> PackedConvWeightsOnednn< exp_wgt.init(w_desc); exp_wgt.set_scale(wgt_scales); // Also for feed_from() exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 00a3a9b10e96a..dab19e0908e35 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -266,9 +266,10 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { } #else - const auto weight_data = weight_contig->scalar_type() == at::ScalarType::Half - ? weight_contig->to(at::ScalarType::Float).data_ptr() - : weight_contig->data_ptr(); + const Tensor& float_weight = weight_contig->scalar_type() == at::ScalarType::Half + ? weight_contig->to(at::ScalarType::Float) + : *weight_contig; + const auto weight_data = float_weight.data_ptr(); constexpr float kEpsilon = 1e-8f; for (auto row : c10::irange(embedding_rows)) { const float* input_row = weight_data + row * embedding_cols; diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 111b5eb5f1394..93a0f82978716 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -283,6 +283,162 @@ at::Tensor& PackedLinearWeight::apply_relu_out( return apply_impl(input, output_scale, output_zero_point, output); } +at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_output_fp32 is quantized; " + "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32 to be full precision."); + + return apply_with_input_q_dq_qweight_dq_output_fp32_impl(input, input_scale, input_zero_point); +} + +at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_relu_output_fp32( + at::Tensor input, + double input_scale, + int64_t input_zero_point) { + TORCH_CHECK(!input.is_quantized(), "Input tensor for apply_with_input_q_dq_qweight_dq_output_fp32 is quantized; " + "Expected input tensor in PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32 to be full precision."); + + return apply_with_input_q_dq_qweight_dq_output_fp32_impl(input, input_scale, input_zero_point); +} + + +template +at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl( + const at::Tensor& input, + double input_scale, + int64_t input_zero_point) { + TORCH_CHECK( + fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + + auto input_contig = input.expect_contiguous(); + const auto* input_ptr = input_contig->data_ptr(); + + TORCH_CHECK( + input.dim() >= 2, + "The dimension of input tensor should be larger than or equal to 2"); + int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); + + auto packB = w.get(); + + int64_t N = static_cast(packB->numCols()); + int64_t K = input.sizes()[input.dim() - 1]; + TORCH_CHECK( + K == static_cast(packB->numRows()), + "The number of rows in the packB should be equal to K: " + + std::to_string(K)); + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + float input_scale_float = input_scale; + int32_t input_zero_point_int32 = input_zero_point; + + TORCH_CHECK( + w_scale.size() == w_zp.size(), + "Weight scales and zero points vectors should have the same size."); + + const float* bias_ptr = nullptr; + c10::MaybeOwned bias_contig; + if (this->bias_.has_value()) { + auto& bias = this->bias_.value(); + bias_contig = bias.expect_contiguous(); + TORCH_CHECK(bias_contig->dim() == 1, "bias should be a vector (1D Tensor)"); + TORCH_CHECK( + bias_contig->sizes()[0] == N, "bias should have N elements: " + std::to_string(N)); + bias_ptr = bias_contig->data_ptr(); + } + + std::vector out_sizes = input.sizes().vec(); + out_sizes.back() = N; + // Allocate output Tensor and a buffer for fbgemmPacked to use + auto output = at::empty(out_sizes, input.options().dtype(at::kFloat)); + auto buffer = at::empty_like( + output, + output.options().dtype(at::kInt), + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + int num_tasks = at::get_num_threads(); + at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) { + fbgemm::PackAWithQuantRowOffset packA( + /*trans=*/fbgemm::matrix_op_t::NoTranspose, + /*nRow=*/M, + /*nCol=*/K, + /*smat=*/input_ptr, + /*ld=*/K, + /*pmat=*/nullptr, + /*scale=*/input_scale_float, + /*zero_pt=*/input_zero_point_int32); + + fbgemm::DoNothing doNothingObj{}; + for (const auto task_id : c10::irange(begin, end)) { + if (q_scheme == c10::kPerTensorAffine) { + // Process the per tensor quantization. + // + // After the uint8 * int8 matrix multiplication is performed, this + // operation does: + // 1) Add in row and column offsets to the rows and columns, + // respectively. + // 2) Add in the bias term. + fbgemm::ReQuantizeForFloat + outputProcObj( + doNothingObj, + input_scale_float, + w_scale.data(), + input_zero_point_int32, + w_zp.data(), + packA.getRowOffsetBuffer(), + col_offsets.data(), + bias_ptr, + N /* nCol */); + + // Do the GEMM + fbgemm::fbgemmPacked( + /*packA=*/packA, + /*packB=*/*packB, + /*C=*/output.data_ptr(), + /*C_buffer=*/buffer.data_ptr(), + /*ldc=*/N, + /*outProcess=*/outputProcObj, + /*thread_id=*/task_id, + /*num_threads=*/num_tasks); + } else if (q_scheme == c10::kPerChannelAffine) { + // Process the per channel quantization. + // + // After the uint8 * int8 matrix multiplication is performed, this + // operation does: + // 1) Add in row and column offsets to the rows and columns, + // respectively. + // 2) Add in the bias term. + fbgemm::ReQuantizeForFloat< + ReluFused, + fbgemm::QuantizationGranularity::OUT_CHANNEL> + outputProcObj( + doNothingObj, + input_scale_float, + w_scale.data(), + input_zero_point_int32, + w_zp.data(), + packA.getRowOffsetBuffer(), + col_offsets.data(), + bias_ptr, + N /* nCol */); + + // Do the GEMM + fbgemm::fbgemmPacked( + /*packA=*/packA, + /*packB=*/*packB, + /*C=*/output.data_ptr(), + /*C_buffer=*/buffer.data_ptr(), + /*ldc=*/N, + /*outProcess=*/outputProcObj, + /*thread_id=*/task_id, + /*num_threads=*/num_tasks); + } + } + }); + return output; +} + #endif // USE_FBGEMM #ifdef USE_PYTORCH_QNNPACK @@ -621,11 +777,12 @@ at::Tensor PackedLinearWeightsQnnp::apply_relu( #endif // USE_PYTORCH_QNNPACK #if AT_MKLDNN_ENABLED() -template +template at::Tensor PackedLinearWeightsOnednn::apply_impl( at::Tensor input, double output_scale, - int64_t output_zero_point) { + int64_t output_zero_point, + torch::List post_op_args) { const int64_t dim = input.dim(); TORCH_CHECK( dim != 0, @@ -639,7 +796,12 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl( auto input_dims = {M, K}; auto input_data_type = dnnl::memory::data_type::u8; auto input_desc = ideep::tensor::desc(input_dims, input_data_type); - ideep::attr_t op_attr = ReluFused ? ideep::attr_t::fuse_relu() : ideep::attr_t(); + ideep::attr_t op_attr = ideep::attr_t(); + if (post_op == Relu) { + op_attr = ideep::attr_t::fuse_relu(); + } else if (post_op == LeakyRelu) { + op_attr = ideep::attr_t::fuse_relu(/*scale=*/1.0f, /*alpha=*/post_op_args.get(0).to()); + } ideep::tensor x(input_desc, input_contig->data_ptr()); auto dst_dims = {M, N}; double input_scale = input.q_scale(); @@ -705,14 +867,27 @@ at::Tensor PackedLinearWeightsOnednn::apply( at::Tensor input, double output_scale, int64_t output_zero_point) { - return apply_impl(std::move(input), output_scale, output_zero_point); + return apply_impl( + std::move(input), output_scale, output_zero_point); } at::Tensor PackedLinearWeightsOnednn::apply_relu( at::Tensor input, double output_scale, int64_t output_zero_point) { - return apply_impl(std::move(input), output_scale, output_zero_point); + return apply_impl( + std::move(input), output_scale, output_zero_point); +} + +at::Tensor PackedLinearWeightsOnednn:: apply_leaky_relu( + at::Tensor input, + double output_scale, + int64_t output_zero_point, + double negative_slope) { + torch::List post_op_args = + {at::Scalar(negative_slope)}; + return apply_impl( + std::move(input), output_scale, output_zero_point, post_op_args); } #endif // #if AT_MKLDNN_ENABLED() @@ -739,15 +914,63 @@ class QLinearInt8 final { } }; +class QLinearLeakyReluInt8 final { + public: + static at::Tensor run( + at::Tensor input, + const c10::intrusive_ptr& packed_weight, + double output_scale, + int64_t output_zero_point, + double negative_slope) { + auto& ctx = at::globalContext(); +#if AT_MKLDNN_ENABLED() + if (ctx.qEngine() == at::QEngine::ONEDNN) { + return dynamic_cast(packed_weight.get())->apply_leaky_relu( + std::move(input), output_scale, output_zero_point, negative_slope); + } +#endif + TORCH_CHECK( + false, + "Didn't find engine for operation quantized::linear_leaky_relu ", + toString(ctx.qEngine())); + } +}; + +template +class QLinearInt8FusedQDQ final { + public: + static at::Tensor run( + at::Tensor input, + double input_scale, + int64_t input_zero_point, + const c10::intrusive_ptr& packed_weight) { + if (ReluFused) { + return packed_weight->apply_with_input_q_dq_qweight_dq_relu_output_fp32( + std::move(input), input_scale, input_zero_point); + } else { + return packed_weight->apply_with_input_q_dq_qweight_dq_output_fp32( + std::move(input), input_scale, input_zero_point); + } + } +}; + TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear"), TORCH_FN(QLinearInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_relu"), TORCH_FN(QLinearInt8::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_leaky_relu"), TORCH_FN(QLinearLeakyReluInt8::run)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("_quantized::linear"), TORCH_FN(QLinearInt8::run)); } +TORCH_LIBRARY_IMPL(quantized, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ::run)); + m.impl(TORCH_SELECTIVE_NAME("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32"), TORCH_FN(QLinearInt8FusedQDQ::run)); +} + } // namespace } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 537d0f492f8f1..c7f350c60e87b 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -662,6 +662,7 @@ class QLinearDynamicFp16 final { }; TORCH_LIBRARY_IMPL(quantized, CPU, m) { + register_linear_params(); m.impl( TORCH_SELECTIVE_NAME("quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); @@ -677,6 +678,7 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { } TORCH_LIBRARY_IMPL(_quantized, CPU, m) { + register_linear_params(); m.impl( TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"), TORCH_FN(QLinearDynamicInt8::run)); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index dda600e9b41c0..9dcf21689d57d 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -23,6 +23,7 @@ #include #include +#include #include int register_linear_params(); @@ -249,7 +250,7 @@ c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( dnnl::memory::data_type::u8); ideep::tensor exp_wgt(w_desc); exp_wgt.feed_from(wgt); - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); @@ -380,20 +381,24 @@ class QLinearPackWeightFp16Legacy final { }; TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_legacy"), TORCH_FN(QLinearPackWeightInt8Legacy::run)); } TORCH_LIBRARY_IMPL(quantized, CPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); } TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run)); } TORCH_LIBRARY_IMPL(_quantized, CPU, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16::run)); m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); } diff --git a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp index c1e5041a5734c..4da714e0bcf0b 100644 --- a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp @@ -21,7 +21,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { "MatMul operands should have same data type."); TORCH_CHECK( qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per-tensor quantization is suported in Matmul."); + "Only per-tensor quantization is supported in Matmul."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), "Both inputs to Matmul must have the same quantization scheme."); @@ -45,7 +45,7 @@ Tensor qmatmul( " and ", b_num_dims, " provided)"); TORCH_CHECK( num_dims >= 2, - "Quantized Matmul currently only suports operands which are at least 2-dimensional. (", + "Quantized Matmul currently only supports operands which are at least 2-dimensional. (", num_dims, " provided)"); const int64_t m = qa.size(num_dims - 2); diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index 35d2139c6c142..aa6ad0e724f5b 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -40,7 +40,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK(qa.scalar_type() == qb.scalar_type(), "Mul operands should have same data type."); TORCH_CHECK(qa.qscheme() == qb.qscheme(), - "Both inputs to Mul must have the same quantization shceme."); + "Both inputs to Mul must have the same quantization scheme."); } // Note: out is assumed to be the same size as self and other. @@ -314,7 +314,7 @@ class QMulScalarTensor final { static Tensor run(Tensor qa, Tensor b) { TORCH_CHECK(qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per tensor quantization is suported in Mul."); + "Only per tensor quantization is supported in Mul."); auto qc = at::empty_like(qa, qa.suggest_memory_format()); return _mul_scalar_out(qc, qa, b.item()); } diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c index 0b2da5a62bed5..398496e081156 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c @@ -327,14 +327,15 @@ void pytorch_q8gemm_ukernel_4x4c2__sse2( (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout)); *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12)); } else { + typedef PYTORCH_QNNP_UNALIGNED uint16_t unaligned_uint16_t; if (nr >= 2) { - *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0); + *((unaligned_uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0); c0 += 2; - *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2); + *((unaligned_uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2); c1 += 2; - *((uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4); + *((unaligned_uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4); c2 += 2; - *((uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6); + *((unaligned_uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6); c3 += 2; vout = _mm_srli_epi32(vout, 16); nr -= 2; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h index 14bcc01d21ed0..fbfaa85904c78 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h @@ -80,3 +80,15 @@ #if defined(_MSC_VER) #define __builtin_prefetch #endif + +#if defined(__GNUC__) + #define PYTORCH_QNNP_UNALIGNED __attribute__((__aligned__(1))) +#elif defined(_MSC_VER) + #if defined(_M_IX86) + #define PYTORCH_QNNP_UNALIGNED + #else + #define PYTORCH_QNNP_UNALIGNED __unaligned + #endif +#else + #error "Platform-specific implementation of PYTORCH_QNNP_UNALIGNED required" +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp b/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp index f29f548fc758c..921e1cffeb5b2 100644 --- a/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp +++ b/aten/src/ATen/native/quantized/cpu/qsoftmax.cpp @@ -81,8 +81,6 @@ Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) { initQNNPACK(); pytorch_qnnp_operator_t softargmax = nullptr; - std::unique_ptr softmax_op( - softargmax); pytorch_qnnp_status status = pytorch_qnnp_create_softargmax_nc_q8( channels, @@ -96,6 +94,9 @@ Tensor qsoftmax_qnnpack(const Tensor& qx, const int64_t dim) { "failed to create QNNPACK Softmax operator"); TORCH_CHECK_NOTNULL(softargmax); + std::unique_ptr softmax_op( + softargmax); + status = pytorch_qnnp_setup_softargmax_nc_q8( softargmax, batch_size, input, input_stride, output, output_stride); TORCH_CHECK( diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index d9abd8bcfc797..fbb46b4b0174c 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -71,10 +71,10 @@ std::unordered_map> PackedConvWeightCudnn< int64_t groups, bool transpose) { // TODO: need to check out to implement groups for conv operator in Conv.cpp - TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currenty limited to groups = 1; received groups =", groups); + TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currently limited to groups = 1; received groups =", groups); TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme())); TORCH_CHECK( kSpatialDim == 2, // 1D is packed as 2d, hence we don't need other checks diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index a6ac4b330b0f1..92990fc267270 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -152,6 +152,39 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y")); + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32): + // input -> q* -> dq* -> linear* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_with_input_q_dq_qweight_dq_output_fp32(Tensor X, float X_scale, int X_zero_point, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); + // Corresponding pattern (the ops with `*` are part of the pattern that + // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32): + // input -> q* -> dq* -> linear* -> relu* -> + // qweight -> dq* / + // + // After fusion: + // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* -> + // qweight / + // + // Additional Note: the weight is packed as well + // Params: + // X: float32 Tensor, will be quantized to quint8 in the op + // W_prepack: packed qint8 quantized weight and bias + // Returns: + // Y: float32 Tensor + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32(Tensor X, float X_scale, int X_zero_point, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack")); diff --git a/aten/src/ATen/native/quantized/qconv_unpack.cpp b/aten/src/ATen/native/quantized/qconv_unpack.cpp index 90e210ebe227d..cff99560b7eec 100644 --- a/aten/src/ATen/native/quantized/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/qconv_unpack.cpp @@ -28,6 +28,12 @@ and /cudnn/ConvUnpackImpl.cpp, for cudnn. #include #endif +template +int register_conv_params(); + +extern template int register_conv_params<2>(); +extern template int register_conv_params<3>(); + namespace at { namespace native { @@ -192,6 +198,8 @@ unpack_quantized_prepacked_sizes_conv2d(const IValue& ivalue) { } TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { + register_conv_params<2>(); + register_conv_params<3>(); // conv_unpack is deprecated, please use conv2d_unpack for 2D conv. m.impl(TORCH_SELECTIVE_NAME("quantized::conv_unpack"), TORCH_FN(QConvUnpackWeightsInt8<2>::run)); // We use conv2d_unpack to be consistent with conv3d_unpack diff --git a/aten/src/ATen/native/quantized/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/qlinear_unpack.cpp index f293a7307e330..19c9890c82e38 100644 --- a/aten/src/ATen/native/quantized/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/qlinear_unpack.cpp @@ -13,6 +13,8 @@ and /cudnn/linear_unpack_impl.cpp, for cudnn. #include #include +int register_linear_params(); + namespace at { namespace native { namespace { @@ -68,6 +70,7 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { } TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { + register_linear_params(); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack"), TORCH_FN(QLinearUnpackWeightInt8::run)); m.impl(TORCH_SELECTIVE_NAME("quantized::linear_unpack_fp16"), TORCH_FN(QLinearUnpackWeightFp16::run)); } diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 2bcbe00a87205..59db274a978b2 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,8 @@ #include #include #include +#include +#include #include #include #include @@ -50,6 +53,7 @@ #include #include #include +#include #endif namespace at { @@ -59,6 +63,50 @@ using namespace at::sparse_csr; namespace { +bool solve_arange(const Tensor& input, int64_t& start, int64_t& end, int64_t& step) { + /* + This function solves the equation + + input == arange(start, end, step) + + for integers start, end, and step, if possible. If the solution + exists, returns true. + */ + int64_t n = input.numel(); + if (n == 0) { + // a trivial solution + start = end = 0; + step = 1; + } else if (n == 1) { + // a simple solution + start = input[0].item(); + end = start + 1; + step = 1; + } else { + Tensor first_last = input.slice(0, 0, n, n - 1).cpu(); + int64_t start_candidate = first_last[0].item(); + int64_t end_candidate = first_last[1].item() + 1; + if (end_candidate - start_candidate == n) { + // a special solution + start = start_candidate; + end = end_candidate; + step = 1; + } else { + // detect if general solution exists + Tensor possible_steps = input.slice(0, 1).sub(input.slice(0, 0, n - 1)); + Tensor possible_step = possible_steps[0]; + if ((possible_steps.eq(possible_step)).all().item()) { + start = start_candidate; + end = end_candidate; + step = possible_step.item(); + } else { + // no solution + return false; + } + } + } + return true; +} } // end anonymous namespace @@ -129,7 +177,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind // 3.1 TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); // For CSR/CSC formats, we define blocksize=(1, 1) so that checking @@ -380,7 +428,7 @@ DimVector _estimate_sparse_compressed_tensor_size( } TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); return size; } @@ -559,13 +607,13 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki "torch.copy_: expected shapes of self and src to match along dimension ", self_compressed_dim, " for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } else { TORCH_CHECK(self_compressed_dims == src_compressed_dims, "torch.copy_: expected shapes of self and src to match along dimensions ", self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{}, @@ -576,7 +624,7 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim()-2, 2)); TORCH_CHECK(self_blocksize == src_blocksize, "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.", - " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectivly."); + " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively."); }); AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{ @@ -744,17 +792,19 @@ Tensor empty_like_sparse_csr( } } -Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { +template +Tensor select_sparse_csr_worker(const Tensor& self, int64_t dim, int64_t index) { + constexpr const char* select_name = (require_view ? "select()" : "select_copy()"); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS( - self.layout(), "select()", []() { return; }); + self.layout(), "select", []() { return; }); TORCH_CHECK_INDEX( - self.dim() != 0, "select() cannot be applied to a 0-dim tensor."); + self.dim() != 0, select_name, " cannot be applied to a 0-dim tensor."); dim = maybe_wrap_dim(dim, self.dim()); auto size = self.size(dim); if (index < -size || index >= size) { TORCH_CHECK_INDEX( false, - "select(): index ", + select_name, ": index ", index, " out of range for tensor of size ", self.sizes(), @@ -765,6 +815,14 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { index += size; } + auto select_strided = [](const Tensor& self, int64_t dim, int64_t index) { + if (require_copy) { + return at::select_copy(self, dim, index); + } else { + return self.select(dim, index); + } + }; + TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim()); auto new_sizes = DimVector(self.sizes()); @@ -790,7 +848,7 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { return at::native::_sparse_compressed_tensor_unsafe( compressed_indices.select(dim, index), plain_indices.select(dim, index), - self.values().select(dim, index), + select_strided(self.values(), dim, index), new_sizes, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), @@ -798,28 +856,237 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { options.pinned_memory_opt()); } else if (dim < n_batch + 2) { // Selecting sparse dimension - TORCH_CHECK( - self.layout() == kSparseCsr || self.layout() == kSparseCsc, - "select(): selecting non-batch dimensions is currently only supported for non-blocked sparse compressed layouts tensors."); TORCH_CHECK( n_batch == 0, - "select(): selecting rows or columns is not implemented for batched sparse compressed tensors.") - // Converting to COO and calling select is slightly slower than operating - // on the CSR indices directly for constructing a COO vector, however - // current version is more readable and easier to understand. - return self.to_sparse().select(dim, index); + select_name, ": selecting sparse dimensions is not implemented for batched sparse compressed tensors.") + TORCH_INTERNAL_ASSERT(dim == 0 || dim == 1); + + DimVector blocksize{1, 1}; + AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&] {}, [&] { + blocksize[0] = std::max(1, self.values().size(n_batch + 1)); + blocksize[1] = std::max(1, self.values().size(n_batch + 2)); + }); + + auto indices_options = compressed_indices.options(); + int64_t fast_dim = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() { return 0; }, [&]() { return 1; }); + int64_t other_dim = (dim == 0 ? 1 : 0); + Tensor indices; + Tensor values; + bool is_view = dim == fast_dim; + if (is_view) { + // select is always a view operation + Tensor start_end = compressed_indices.narrow(0, index / blocksize[dim], 2).cpu(); + int64_t start = start_end[0].item(); + int64_t end = start_end[1].item(); + indices = plain_indices.slice(0, start, end); + values = self.values().slice(0, start, end); + } else { + Tensor decompressed_indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices) + .select(0, 0); + + Tensor dim_indices = at::where(plain_indices.eq(index / blocksize[dim]))[0]; + // Notice that dim_indices is a sorted sequence of non-negative + // distinct integers. Below we'll try to solve `dim_indices == + // arange(start, stop, step)`. If the solution exists then the + // select will be a view operation also for the `dim != + // fast_dim` case. + int64_t start{}, end{}, step{}; + if (solve_arange(dim_indices, start, end, step)) { + indices = decompressed_indices.slice(0, start, end, step); + values = self.values().slice(0, start, end, step); + is_view = true; + } else { + // select will be a copy operation due to index_select! + indices = decompressed_indices.index_select(0, dim_indices); + values = self.values().index_select(0, dim_indices); + } + } + + AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "select", [&]() {}, + [&]() { + /* + The formula for select indices and values below are best + explained by an example. Consider a BSR tensor with a + block size (2, 3) having four blocks (the other two blocks + contain all zeros and hence will not be specified): + + [ 1 2 3] | [ 7 8 9] + [ 4 5 6] | [10 11 12] + --------------------- + [13 14 15] | [ 0 0 0] + [16 17 18] | [ 0 0 0] + ----------------------- + [ 0 0 0] | [19 20 21] + [ 0 0 0] | [22 23 24] + + that represents a 6 x 6 tensor: + + [ 1 2 3 7 8 9 ] + [ 4 5 6 10 11 12 ] + [ 13 14 15 0 0 0 ] + [ 16 17 18 0 0 0 ] + [ 0 0 0 19 20 21 ] + [ 0 0 0 22 23 24 ] + + The corresponding data for the BSR representation is: + + crow_indices = [0 2 3 4] + col_indices = [0 1 0 1] + values = [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]], [[13 14 15], [16 17 18]], [[19 20 21], [22 23 24]] ] + shape = (6, 6) + + From crow_indices, we can find that + + row_indices = [0 0 1 2] + + In the following, we'll illustrate the details of + computing the result of torch.select_copy(input, dim, + index) where dim is 0 or 1, and index is in + range(shape[dim]). + + Select a row of a BSR tensor + ---------------------------- + + We will consider first the dim=0 case that corresponds to + selecting a index-th row of the tensor. For instance, for + dim=0 and index=1, the expected result would represent a + 1D tensor: + + [ 4 5 6 10 11 12 ] + + that is a concatenated tensor of certain slices from the + first and the second block that is computed as follows: + + values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1) + -> values[[0, 1]][:, 1 % 2].flatten(0, 1) + -> [ [[1 2 3], [4 5 6]], [[7 8 9], [10 11 12]] ][:, 1].flatten(0, 1) + -> [ [4 5 6], [10 11 12]].flatten(0, 1) + -> [ 4 5 6 10 11 12] + + where dim_indices is found as + + where(row_indices == index//blocksize[dim]) + -> where([0 0 1 2] == 1//2) + -> [0 1] + + The corresponding column indices are computed as + + (col_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1) + + where other_dim is 1 if dim is 0, and 0 if dim is 1. Let's + expand the above expression with the data in the example: + + -> (col_indices[[0, 1]].mul(3).unsqueeze(1) + arange(3).unsqueeze(0)).flatten(0, 1) + -> ([[0 1].mul(3).unsqueeze(1) + [[0 1 2]]).flatten(0, 1) + -> ([[[0], [3]] + [[0 1 2]]).flatten(0, 1) <- here addition will use broadcasting rules! + -> ([[[0 1 2], [3 4 5]]).flatten(0, 1) + -> [0 1 2 3 4 5] + + Finally, the select(dim=0, index=1) op on the given sparse + compressed tensors will return a COO tensor: + + sparse_coo_tensor([0 1 2 3 4 5].unsqueeze(0), [4 5 6 10 11 12], (6,)) + + that represents the expected result: [ 4 5 6 10 11 12 ] + + Select a column of a BSR tensor + ------------------------------- + + Next, we'll consider the dim=1 case that corresponds to + selecting the index-th column of the tensor. For instance, + for dim=1 and index=4, the expected result would represent + a 1D tensor: + + [ 8 11 0 0 20 23] + + that is a concatenated tensor of certain slices from the + second and the last block: + + values[dim_indices].select(1 + dim, index % blocksize[dim]).flatten(0, 1) + -> values[[1, 3]][:, :, 4 % 3 ].flatten(0, 1) + -> [ [[7 8 9], [10 11 12]], [[19 20 21], [22 23 24]] ][:, 1, 1].flatten(0, 1) + -> [ [8 11], [20 23]].flatten(0, 1) + -> [ 8 11 20 23 ] + + The corresponding row indices are computed as + + (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1) + + where dim_indices is + + where(col_indices == index//blocksize[dim]) + -> where([0 1 0 1] == 4//3) + -> [1 3] + + and we have + + (row_indices[dim_indices].mul(blocksize[other_dim]).unsqueeze(1) + arange(blocksize[other_dim]).unsqueeze(0)).flatten(0, 1) + -> (row_indices[[1 3]].mul(2).unsqueeze(1) + arange(2).unsqueeze(0)).flatten(0, 1) + -> ([0 4].unsqueeze(1) + [0 1].unsqueeze(0)).flatten(0, 1) + -> ([[0], [4]] + [[0 1]]).flatten(0, 1) <- here addition will use broadcasting rules! + -> ([[0 1], [4 5]]).flatten(0, 1) + -> [ 0 1 4 5 ] + + Finally, the select(dim=1, index=4) op on the given sparse + compressed tensors will return a COO tensor: + + sparse_coo_tensor([0 1 4 5].unsqueeze(0), [8 11 20 23], (6,)) + + that represents the expected result: [ 8 11 0 0 20 23 ] + + */ + Tensor subblock_indices = at::arange(0, blocksize[other_dim], indices_options); + indices = indices.mul(blocksize[other_dim]).unsqueeze(1).add(subblock_indices.unsqueeze(0)).flatten(0, 1); + values = values.select(dim + 1, index % blocksize[dim]).flatten(0, 1); + // flatten(0, 1) can be a view or a copy operation. If view + // is required, it will be checked below via is_alias_of, + // otherwise, we'll check if copy is made here to avoid + // unnecessary clone below: + if (require_copy) { + is_view = values.is_alias_of(self.values()); + } + }); + + if (require_view) { + TORCH_CHECK(values.is_alias_of(self.values()), select_name, + ": no view exists for the given input, consider using torch.select_copy."); + } + + indices = indices.unsqueeze(0).to(kLong); + if (require_copy && is_view) { + values = values.clone(); + } + return at::_sparse_coo_tensor_unsafe(indices, values, new_sizes)._coalesced_(true); } else { // Selecting dense dimension - return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( + Tensor new_values = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( self.layout(), "select", // Non blocked layout (2 sparse dims become 1 nnz dim in values, so dim // is found one position to the left) - [&]() { return self.values().select(dim - 1, index); }, + [&]() { return select_strided(self.values(), dim - 1, index); }, // Block layout (2 sparse dims become 1 nnz dim + 2 block-shape dims in // values, so dim is found 1 position to the right) - [&]() { return self.values().select(dim + 1, index); }); + [&]() { return select_strided(self.values(), dim + 1, index); }); + return at::native::_sparse_compressed_tensor_unsafe( + compressed_indices, + plain_indices, + new_values, + new_sizes, + optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt()); } } + +Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { + return select_sparse_csr_worker(self, dim, index); +} + +Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) { + return select_sparse_csr_worker(self, dim, index); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index acc95564e6ddb..3818b558cc1d1 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -48,6 +48,8 @@ #include #include #include +#include +#include #include #include #include @@ -58,6 +60,8 @@ #include #include #include +#include +#include #include #include #include @@ -117,7 +121,8 @@ namespace meta { TORCH_META_FUNC(_convert_indices_from_coo_to_csr) (const Tensor& self, const int64_t size, const bool out_int32) { - TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector"); + TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector, but got ", + self.dim(), " dimensional tensor."); ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long; c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type); @@ -130,8 +135,10 @@ TORCH_META_FUNC(_convert_indices_from_csr_to_coo) const bool out_int32, const bool transpose) { TORCH_CHECK( - crow_indices.dim() == 1, "crow_indices is supposed to be a vector"); - TORCH_CHECK(col_indices.dim() == 1, "col_indices is supposed to be a vector"); + crow_indices.dim() == 1, "crow_indices is supposed to be a vector, but got ", + crow_indices.dim(), " dimensional tensor."); + TORCH_CHECK(col_indices.dim() == 1, "col_indices is supposed to be a vector, but got ", + col_indices.dim(), " dimensional tensor."); ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long; c10::TensorOptions options = crow_indices.options().dtype(scalar_type); set_output_raw_strided(0, {2, col_indices.numel()}, {}, options, {}); @@ -355,10 +362,12 @@ CREATE_UNARY_UFUNC(asinh); CREATE_UNARY_UFUNC(atan); CREATE_UNARY_UFUNC(atanh); CREATE_UNARY_UFUNC(ceil); +CREATE_UNARY_UFUNC(deg2rad); CREATE_UNARY_UFUNC(erf); CREATE_UNARY_UFUNC(erfinv); CREATE_UNARY_UFUNC(expm1); CREATE_UNARY_UFUNC(floor); +CREATE_UNARY_UFUNC(frac); CREATE_UNARY_UFUNC(log1p); CREATE_UNARY_UFUNC(neg); CREATE_UNARY_UFUNC(rad2deg); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 625f5b1c0b080..859218b2f7042 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -163,10 +163,10 @@ SparseTensor new_with_dims_sparse( return self; } -SparseTensor new_with_dims_and_tensor_sparse( +SparseTensor new_with_dims_and_tensor_sparse_symint( int64_t sparse_dim, int64_t dense_dim, - ArrayRef size, + c10::SymIntArrayRef size, const Tensor& indices, const Tensor& values, c10::optional dtype, @@ -444,7 +444,9 @@ Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& val Tensor values = expand_values_if_needed(values_); - auto sparse_dim = indices.sym_size(0); + // This guard is intentional: we don't support dynamic shapes along the + // indices dimension because that implies variable dimensionality + auto sparse_dim = indices.sym_size(0).guard_int(__FILE__, __LINE__); auto dense_dim = values.dim() - 1; return at::_sparse_coo_tensor_with_dims_and_tensors_symint( @@ -516,7 +518,40 @@ const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTens return self; } -SparseTensor dense_to_sparse(const Tensor& self) { +SparseTensor dense_to_sparse(const Tensor& self, c10::optional layout, OptionalIntArrayRef blocksize) { + if (layout.has_value()) { + if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) { + AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion does not use specified blocksize"); + } + if (self.layout() == *layout) { + return self; + } + switch (*layout) { + case kStrided: + return self; + case kSparse: + return dense_to_sparse(self, self.dim()); + case kSparseCsr: + return self.to_sparse_csr(); + case kSparseCsc: + return self.to_sparse_csc(); + case kSparseBsr: + if (blocksize.has_value()) { + return self.to_sparse_bsr(*blocksize); + } + AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion requires blocksize"); + break; + case kSparseBsc: + if (blocksize.has_value()) { + return self.to_sparse_bsc(*blocksize); + } + break; + AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion requires blocksize"); + default: + break; + } + AT_ERROR("to_sparse not implemented for ", self.layout(), " to ", *layout, " conversion"); + } return dense_to_sparse(self, self.dim()); } diff --git a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp index ed6df15fe7795..9e0503337b5de 100644 --- a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp +++ b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include #include #include @@ -27,6 +29,8 @@ #include #include #include +#include +#include #include #include #include @@ -39,6 +43,8 @@ #include #include #include +#include +#include #include #include #include @@ -165,12 +171,15 @@ COALESCED_UNARY_UFUNC(asinh); COALESCED_UNARY_UFUNC(atan); COALESCED_UNARY_UFUNC(atanh); COALESCED_UNARY_UFUNC(ceil); +COALESCED_UNARY_UFUNC(deg2rad); COALESCED_UNARY_UFUNC(erf); COALESCED_UNARY_UFUNC(erfinv); COALESCED_UNARY_UFUNC(expm1); COALESCED_UNARY_UFUNC(floor); +COALESCED_UNARY_UFUNC(frac); COALESCED_UNARY_UFUNC(log1p); COALESCED_UNARY_UFUNC(round); +COALESCED_UNARY_UFUNC(rad2deg); COALESCED_UNARY_UFUNC(sign); COALESCED_UNARY_UFUNC(sgn); COALESCED_UNARY_UFUNC(sin); diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp index 379640bad56b9..833fd41eb6a02 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp @@ -1401,7 +1401,7 @@ void sampled_addmm_out_sparse_csr( const Scalar& beta, const Scalar& alpha, const at::sparse_csr::SparseCsrTensor& C) { -#if !AT_USE_CUSPARSE_GENERIC_SDDMM() +#if !(AT_USE_CUSPARSE_GENERIC_SDDMM() || AT_USE_HIPSPARSE_GENERIC_52_API()) TORCH_CHECK( false, "Calling sampled_addmm with sparse GPU tensors requires compiling ", diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 8cc5fc3157c38..33123abccbe93 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -70,7 +70,7 @@ Tensor _to_csr_int(const Tensor& rowIndices, int64_t dim, int64_t nnz) { #pragma push // NVCC complains that confirm_mult_size is not used, // but it is used in specializations of CusparseMatrixMultiplyOp below -#pragma diag_suppress 177 // Function was declared but never referenced +#pragma nv_diag_suppress 177 // Function was declared but never referenced int confirm_mult_size(const std::vector& mat1_size, const std::vector& mat2_size) { TORCH_CHECK( mat1_size[1] == mat2_size[0], diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index 5d2a69db016fd..ce75b0ae10c63 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -40,3 +40,8 @@ type promotion and boardcasting ops. Canonical aten ops is also effectively the opset produced by torchdynamo.export(aten_graph=True), and thus can be used as an opset for export purpose. +- tag: pointwise + desc: | + Pointwise operators are operators where each element of the output is computed only by accessing + the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted + shape of the inputs. diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index c03935ecfbf3d..06ea49bb516c4 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -14,7 +16,6 @@ #endif #include - namespace at { namespace native { @@ -106,6 +107,17 @@ void transform_bias_rescale_qkv_inner_loop( } } +Tensor transform_0213(const Tensor& a) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1)); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3)); + return a.permute({0, 2, 1, 3}) + .contiguous() + .view({a.size(0), a.size(2), a.size(1) * a.size(3)}); +} + +} // namespace + + Tensor bmm_nt(const Tensor& a, const Tensor& b) { auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)}); auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)}); @@ -118,7 +130,7 @@ Tensor masked_softmax( Tensor& attn_scores, c10::optional attn_mask, const Tensor& query, - c10::optional mask_type = NULL) { + c10::optional mask_type) { if (query.is_nested() && !attn_mask) { return at::_nested_tensor_softmax_with_shape(attn_scores, query); } @@ -128,15 +140,6 @@ Tensor masked_softmax( "negatively affect performance. Prefer to use a boolean mask directly."); attn_mask = attn_mask->to(at::kBool); } - if (attn_scores.is_cpu() && attn_mask && attn_mask->dim() == 2) { - // TODO: CPU path does not support transformer mask yet. - const auto batch_size = attn_scores.sizes()[0]; - const auto seq_len = attn_scores.sizes()[3]; - TORCH_CHECK(attn_mask->sizes()[0] == batch_size); - TORCH_CHECK(attn_mask->sizes()[1] == seq_len); - attn_mask = attn_mask->view({batch_size, 1, 1, seq_len}); - attn_mask = at::expand_inplace(attn_scores, *attn_mask)->contiguous(); - } if (attn_mask) { return _masked_softmax(attn_scores, *attn_mask, attn_scores.dim() - 1, mask_type); } else { @@ -156,13 +159,6 @@ Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) { return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)}); } -Tensor transform_0213(const Tensor& a) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1)); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3)); - return a.permute({0, 2, 1, 3}) - .contiguous() - .view({a.size(0), a.size(2), a.size(1) * a.size(3)}); -} Tensor transform0213_gemm_nt_bias( const Tensor& a, @@ -254,8 +250,6 @@ Tensor qkv_projection( return qkv; } -} // namespace - // compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias std::tuple transform_bias_rescale_qkv_cpu( const Tensor& qkv, @@ -312,7 +306,7 @@ std::tuple transform_bias_rescale_qkv_cpu( return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]); } -std::tuple native_multi_head_attention( +std::tuple native_multi_head_attention_cpu( const Tensor& query, const Tensor& key, const Tensor& value, @@ -683,27 +677,74 @@ std::tuple native_decoder_only_multi_head_attent // L: Target sequence length // E: Embedding dimension std::tuple _scaled_dot_product_attention( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - if (query_.requires_grad() || key.requires_grad() || value.requires_grad()){ - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - } - return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + const Tensor& query_, + const Tensor& key, + const Tensor& value, + const c10::optional& attn_mask_, + double dropout_p, + bool need_attn_weights, + bool is_causal) { + // TODO: The second return is the attention weights if the math kernel is + // used. The fused kernels do not return this Tensor so for the fused kernels + // The second return SHOULD always be an empty Tensor, unless need_attn_weights + // is true (in which case the fused kernels would not be called). This blows up + // op_info tests. + int64_t choice_int = at::_fused_sdp_choice( + query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + sdp::SDPBackend backend = static_cast(choice_int); + switch (backend) { + case sdp::SDPBackend::flash_attention: { + auto out_lse_softmax = at::_scaled_dot_product_flash_attention( + query_, key, value, dropout_p, need_attn_weights, is_causal); + return std::make_tuple( + std::move(std::get<0>(out_lse_softmax)), + std::move(std::get<2>(out_lse_softmax))); + } + case sdp::SDPBackend::efficient_attention: { + bool compute_logsumexp = + (query_.requires_grad() || key.requires_grad() || + value.requires_grad()); + return at::_scaled_dot_product_efficient_attention( + query_, key, value, compute_logsumexp, is_causal); + } + case sdp::SDPBackend::math: + return at::_scaled_dot_product_attention_math( + query_, + key, + value, + attn_mask_, + dropout_p, + need_attn_weights, + is_causal); + default: + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found."); + return std::make_tuple(Tensor(), Tensor()); + } } -std::tuple _scaled_dot_product_attention_forward_math( - const Tensor& query_, const Tensor& key, const Tensor& value, +int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - } + return static_cast(sdp::SDPBackend::math); +} std::tuple _scaled_dot_product_attention_math( const Tensor& query_, const Tensor& key, const Tensor& value, const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { + if (query_.is_nested() || key.is_nested() || value.is_nested()) { + TORCH_CHECK( + query_.is_contiguous() && key.is_contiguous() && + value.is_contiguous(), + "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous"); + } auto attn_mask = attn_mask_; // Naive, composite implementation defined here. - const auto embed_size = query_.size(-1); - const auto query = query_ * (1. / ::sqrt(static_cast(embed_size))); + + // Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math + const auto embed_size = SymFloat(query_.sym_size(-1)); + const auto scaling_factor = embed_size.sqrt().sqrt(); + const auto query = query_ / scaling_factor; if (is_causal) { TORCH_CHECK(!attn_mask.has_value(), "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True"); @@ -711,8 +752,8 @@ std::tuple _scaled_dot_product_attention_math( "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True"); // Replace attn_mask with causal mask; lower triangular elements take part in attention. - const auto L = query.size(-2), S = key.size(-2); - attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril(); + const auto L = query.sym_size(-2), S = key.sym_size(-2); + attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril(); } if (attn_mask.has_value()) { TORCH_CHECK(!query.is_nested() && !key.is_nested(), @@ -726,7 +767,7 @@ std::tuple _scaled_dot_product_attention_math( } // Otherwise, attn_mask represents an additive attention tensor } - auto attn = at::matmul(query, key.transpose(-2, -1)); + auto attn = at::matmul(query, key.transpose(-2, -1)/scaling_factor); if (attn_mask.has_value()) { attn.add_(*attn_mask); } diff --git a/aten/src/ATen/native/transformers/attention.h b/aten/src/ATen/native/transformers/attention.h new file mode 100644 index 0000000000000..783b22869137e --- /dev/null +++ b/aten/src/ATen/native/transformers/attention.h @@ -0,0 +1,33 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +TORCH_API Tensor bmm_nt(const Tensor& a, const Tensor& b); +TORCH_API Tensor masked_softmax( + Tensor& attn_scores, + c10::optional attn_mask, + const Tensor& query, + c10::optional mask_type = NULL); + +TORCH_API Tensor transform0213_gemm_nt_bias( + const Tensor& a, + const Tensor& b, + const Tensor& c, + const Tensor& query); + +TORCH_API Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b); + +TORCH_API void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape); + +TORCH_API Tensor qkv_projection( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const int64_t embed_dim, + const Tensor& qkv_weight); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index acccb8821d833..8dcb99b3380d9 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -17,6 +17,7 @@ #include +#include #include #include #include @@ -479,12 +480,210 @@ __host__ std::tuple transform_bias_rescale_qkv_cuda( return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]); } -std::tuple flash_attention_helper_dense_unpacked( +std::tuple native_multi_head_attention_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + const int64_t embed_dim, + const int64_t num_head, + const Tensor& qkv_weight, + const Tensor& qkv_bias, + const Tensor& proj_weight, + const Tensor& proj_bias, + const c10::optional& mask, + bool need_weights, + bool average_attn_weights, + const c10::optional mask_type) { + // query shape: [B, T, D] + // qkv_weight shape: [3 * D, D] + + TORCH_CHECK( + !mask || !query.is_nested(), + "NestedTensor with mask is not supported yet"); + const auto D = embed_dim; + TORCH_CHECK( + query.dim() == 3, + "expected 3-D `query`, got ", + query.dim(), + "-D tensor"); + TORCH_CHECK( + query.is_nested() || query.sizes()[2] == embed_dim, + "passed-in embed_dim ", + embed_dim, + " didn't match last dim of query ", + query.sizes()[2]); + TORCH_CHECK( + key.dim() == 3, + "expected 3-D `key`, got ", + key.dim(), + "-D tensor"); + TORCH_CHECK( + value.dim() == 3, + "expected 3-D `value`, got ", + value.dim(), + "-D tensor"); + TORCH_CHECK( + query.is_nested() || key.is_nested() || value.is_nested() || + (query.sizes() == key.sizes() && key.sizes() == value.sizes()), + "expected `query`/`key`/`value` shapes to match"); + TORCH_CHECK( + qkv_weight.dim() == 2, + "expected 2-D `qkv_weight`, got ", + qkv_weight.dim(), + "-D tensor"); + TORCH_CHECK( + D * 3 == qkv_weight.sizes()[0], + "expected `qkv_weight` first dim to be 3x embed_dim"); + TORCH_CHECK( + D == qkv_weight.sizes()[1], + "expected `qkv_weight` second dim to be embed_Dim"); + TORCH_CHECK( + qkv_bias.dim() == 1, + "expected 2-D `qkv_bias`, got ", + qkv_bias.dim(), + "-D tensor"); + TORCH_CHECK( + qkv_bias.sizes()[0] == 3 * D, + "expected `qkv_bias` first dim and first dim of query to be equal"); + TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`"); + +#ifndef NDEBUG + const auto B = query.is_nested() + ? get_nested_tensor_impl(query)->get_nested_size_tensor().size(0) + : query.sizes()[0]; + auto T = query.is_nested() ? 0 : query.sizes()[1]; + +#endif + const auto dim_per_head = D / num_head; + if ((query.is_same(key) && key.is_same(value)) && dim_per_head % 8 == 0 ) { + + // We have not done linear projection yet but the input for SDP + // Is expected to be 4 dimensional. We "cheaply" create view tensors + // That will then be used for checking hot path conditions with select_sd_backend + auto q = query.view({query.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2); + + sdp::sdp_params kernel_params{q, k, v, mask.has_value(), 0.0, need_weights, false}; + auto backend = select_sdp_backend(kernel_params); + if (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention) { + auto x = at::linear(query, qkv_weight, qkv_bias); + auto chunks = x.chunk(3, -1); + auto x_size_0 = x.size(0); + + chunks[0] = (chunks[0].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + chunks[1] = (chunks[1].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head})) + .transpose(1, 2); + + auto y = at::_scaled_dot_product_attention( + chunks[0], chunks[1], chunks[2], mask, 0.0, need_weights, false); + auto past_sdp = + std::get<0>(y).transpose(1, 2).reshape({x_size_0, -1, embed_dim}); + return std::make_tuple( + at::linear(past_sdp, proj_weight, proj_bias), Tensor()); + } + // Returned math or error lets not use it + } + + // shape: [B, T, 3 x D] + auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight); + + if (!qkv.is_nested() && qkv.numel() == 0) { + if (query.is_nested()) { + return std::make_tuple(Tensor(), Tensor()); + } + return std::make_tuple(at::empty_like(query), Tensor()); + } + +#ifndef NDEBUG + if (!query.is_nested() || !qkv.is_nested()) { + if (query.is_nested()) { + T = qkv.size(1); + } + debug_assert_shape(__LINE__, qkv, {B, T, 3 * D}); + } +#endif + +#ifdef DEBUG_PRINT_EACH_STEP + if (!qkv.is_nested()) { + std::cerr << "qkv: " << qkv << std::endl; + } +#endif + // shape: 3 x [B, num_head, T, dim_per_head] + auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head); + qkv = Tensor(); // Not used any more, allow free + auto& q = std::get<0>(q_k_v); + const auto& k = std::get<1>(q_k_v); + const auto& v = std::get<2>(q_k_v); +#ifndef NDEBUG + debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head}); + debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head}); + debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "q: " << q << std::endl; + std::cerr << "k: " << k << std::endl; + std::cerr << "v: " << v << std::endl; +#endif + + // shape: [B, num_head, T, T] + auto qkt = bmm_nt(q, k); + // q & k are dead but cannot be freed because they were packed with v +#ifndef NDEBUG + debug_assert_shape(__LINE__, qkt, {B, num_head, T, T}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "qkt: " << qkt << std::endl; +#endif + + // shape: [B, num_head, T, T] + // TODO: long-term, have a kernel that works with + // NestedTensor directly if there is no mask passed + qkt = masked_softmax(qkt, mask, query, mask_type); +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "qkt after softmax: " << qkt << std::endl; +#endif + + // shape: [B, num_head, T, dim_per_head] + // reuse storage for q; we're done with it + auto attn_ctx = bmm_nn(q, qkt, v); + // qkv is not dead; we just reused storage for q! + if (!need_weights) { + qkt = Tensor(); + } +#ifndef NDEBUG + debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head}); +#endif +#ifdef DEBUG_PRINT_EACH_STEP + std::cerr << "attn_ctx: " << attn_ctx << std::endl; +#endif + + // shape: [B, T, D] + // Fuse transform_0213 inside + auto proj = transform0213_gemm_nt_bias( + attn_ctx, proj_weight, proj_bias, query); +#ifndef NDEBUG + debug_assert_shape(__LINE__, proj, {B, T, D}); +#endif + if (need_weights && average_attn_weights) { + // weights are not needed for full transformer, so don't worry too + // much about performance -- we implement this just to make use + // cases that don't disable need_weights still get some speedup. + qkt = qkt.sum(1); + qkt /= num_head; + } + return std::make_tuple(std::move(proj), std::move(qkt)); +} + +std::tuple _scaled_dot_product_flash_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -527,8 +726,9 @@ std::tuple flash_attention_helper_dense_unpacked( Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim}); Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim}); - Tensor attention = - at::_flash_scaled_dot_product_attention( + Tensor attention, log_sumexp, softmax; + std::tie(attention, log_sumexp, softmax) = + at::_flash_attention_forward( query_reshaped, key_reshaped, value_reshaped, @@ -536,18 +736,22 @@ std::tuple flash_attention_helper_dense_unpacked( cumulative_sequence_length_k, max_seqlen_batch_q, max_seqlen_batch_k, + return_softmax, dropout_p, is_causal); // Reshape output to convert nnz to batch_size and seq_len attention = attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2); - return std::tuple(attention, Tensor()); + return std::make_tuple(attention, log_sumexp, softmax); } -std::tuple mem_eff_helper( + +std::tuple _scaled_dot_product_efficient_attention_cuda( const Tensor& query, const Tensor& key, - const Tensor& value){ + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) @@ -555,38 +759,34 @@ std::tuple mem_eff_helper( Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - Tensor attention = std::get<0>(at::_efficient_attention_forward( + Tensor attention, log_sumexp; + std::tie(attention, log_sumexp) = at::_efficient_attention_forward( q_t, k_t, v_t, c10::nullopt, c10::nullopt, c10::nullopt, - false, - false)).transpose(1,2); - return std::make_tuple(attention, Tensor()); + compute_log_sumexp, + is_causal); + attention = attention.transpose(1,2); + return std::make_tuple(std::move(attention), std::move(log_sumexp)); } -std::tuple _scaled_dot_product_attention_forward_cuda( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - // Determine which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found."); - return std::make_tuple(Tensor(), Tensor()); - } +int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; + auto backend = select_sdp_backend(kernel_params); + if (backend == sdp::SDPBackend::error) { + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found. ", + "This is likely due to turning off both the math kernel and the fused kernels."); + } + return static_cast(backend); } -Tensor flash_scaled_dot_product_attention( +std::tuple _flash_attention_forward( const Tensor& query, const Tensor& key, const Tensor& value, @@ -594,11 +794,12 @@ Tensor flash_scaled_dot_product_attention( const Tensor& cumulative_sequence_length_k, const int64_t max_seqlen_batch_q, const int64_t max_seqlen_batch_k, + bool return_softmax, double dropout_p, bool is_causal) { #if defined(USE_FLASH_ATTENTION) auto softmax_scale = std::pow(query.size(-1), -0.5); - std::vector output = fmha::mha_fwd( + return fmha::mha_fwd( query, key, value, @@ -610,12 +811,11 @@ Tensor flash_scaled_dot_product_attention( softmax_scale, false, is_causal, - false, + return_softmax, c10::nullopt); - return output[0]; #endif TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return Tensor(); + return std::make_tuple(Tensor(), Tensor(), Tensor()); } std::tuple _efficient_attention_forward( @@ -636,7 +836,6 @@ std::tuple _efficient_attention_forward( // TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a // machine that is >= 5.0. In practice, this is not a problem but since // this would avoid runtime architecture checks, we should look into it - TORCH_CHECK(query.dim() == 4); TORCH_CHECK(key.dim() == 4); TORCH_CHECK(value.dim() == 4); @@ -768,7 +967,7 @@ std::tuple _efficient_attention_forward( kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); } Kernel::check_supported(p); - kernel_fn<<>>(p); + kernel_fn<<>>(p); }; // Dispatch to the right kernel DISPATCH_KERNEL(query, key, value, ([&]() { diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu new file mode 100644 index 0000000000000..a063aacb901ee --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -0,0 +1,289 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#include +#ifdef USE_FLASH_ATTENTION +#include +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ + } + +#define DISPATCH_MAXK(func) \ + { \ + const auto maxK = std::max(query.size(3), value.size(3)); \ + if (maxK <= 64) { \ + constexpr int kMaxK = 64; \ + func(); \ + } else if (maxK <= 128) { \ + constexpr int kMaxK = 128; \ + func(); \ + } else { \ + constexpr int kMaxK = std::numeric_limits::max(); \ + func(); \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = \ + AttentionBackwardKernel; \ + bool isAligned = \ + (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })); \ + } + +namespace at { + +namespace native { + +std::tuple _efficient_attention_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + bool causal) { + #if defined(USE_FLASH_ATTENTION) + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + // ndim + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + TORCH_CHECK(query.dim() == 4); + + // batch size + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // seqlen + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + TORCH_CHECK(query.size(2) == grad_out_.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + TORCH_CHECK(value.size(3) == grad_out_.size(3)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::cuda::CUDAGuard device_guard(query.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t nH = query.size(2); + int64_t K = query.size(3); + + // It does not make sense to use that in practice, + // but let's still make sure we are correct + // As we iterate through keys first, we skip + // keys with no query associated, so they are not + // initialized + bool grad_kv_needs_init = causal && N > M; + at::Tensor grad_q, grad_k, grad_v; + int8_t gQKV_strideM_multiplier = 1; + if (!grad_kv_needs_init && query.size(1) == key.size(1) && + query.size(3) == value.size(3) && + query.storage().is_alias_of(key.storage()) && + query.storage().is_alias_of(value.storage())) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. + // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + gQKV_strideM_multiplier=3; + } else { + grad_q = at::empty(query.sizes(), query.options()); + grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options()) + : at::empty(key.sizes(), key.options()); + grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options()) + : at::empty(value.sizes(), value.options()); + } + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)) + .sum(-1) + .transpose(-2, -1) + .contiguous(); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == nH); + TORCH_INTERNAL_ASSERT(delta.size(2) == M); + + typename Kernel::Params p; + p.query_ptr = (scalar_t*)query.data_ptr(); + p.key_ptr = (scalar_t*)key.data_ptr(); + p.value_ptr = (scalar_t*)value.data_ptr(); + p.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); + p.output_ptr = (scalar_t*)out.data_ptr(); + p.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); + p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + p.delta_ptr = (float*)delta.data_ptr(); + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = query.size(1); + p.num_keys = key.size(1); + p.num_batches = B; + p.num_heads = nH; + p.causal = causal; + + ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); + p.gQKV_strideM_multiplier = gQKV_strideM_multiplier; + TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); + TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); + TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + + Kernel::check_supported(p); + + constexpr auto kernel_fn = attention_kernel_backward_batched; + + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + // second syntax resulted in the error below on windows + // error C3495: 'kernel_fn': a simple capture must be a variable + // with automatic storage duration declared + // in the reaching scope of the lambda +#ifdef _WIN32 + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + TORCH_INTERNAL_ASSERT( + attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, + "Something went wrong in the build process"); +#else + auto checkBinaryArchMatches = [&]() { + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); +#endif + + kernel_fn<<>>(p); + }; + + DISPATCH_KERNEL( + query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); + #endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); +} + + +std::tuple _scaled_dot_product_efficient_attention_backward_cuda( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + bool causal){ + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + auto grad_out = grad_out_.transpose(1, 2); + auto out_t = out.transpose(1, 2); + auto q_t = query.transpose(1, 2); + auto k_t = key.transpose(1, 2); + auto v_t = value.transpose(1, 2); + + Tensor grad_q, grad_k, grad_v; + std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal); + return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index a8d6110e951d9..7d9807260db2f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -26,9 +26,11 @@ * ******************************************************************************/ +#include #ifdef USE_FLASH_ATTENTION #include #include +#include #include #include @@ -62,7 +64,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, bool is_causal) { // Reset the parameters - memset(¶ms, 0, sizeof(params)); + params = {}; params.is_bf16 = q.dtype() == at::kBFloat16; @@ -114,7 +116,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.is_causal = is_causal; } -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -185,6 +187,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto opts = q.options(); auto o = at::empty({ total_q, num_heads, head_size }, opts); @@ -237,9 +242,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fprop(launch_params, /*configure=*/false); - std::vector result = {o, softmax_lse}; - if (return_softmax) {result.push_back(s);} - return result; + return std::make_tuple(o, softmax_lse, s); } } // namespace fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h index 226d4ddd2b551..b0555463be040 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h @@ -7,7 +7,7 @@ namespace fmha { TORCH_API -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu deleted file mode 100644 index 07c14ad8195dd..0000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu +++ /dev/null @@ -1,166 +0,0 @@ -#include - -#define DISPATCH_MAXK(func) \ - { \ - const auto maxK = std::max(query.size(2), value.size(2)); \ - if (maxK <= 64) { \ - constexpr int kMaxK = 64; \ - func(); \ - } else if (maxK <= 128) { \ - constexpr int kMaxK = 128; \ - func(); \ - } else { \ - constexpr int kMaxK = std::numeric_limits::max(); \ - func(); \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(1) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ - } - -namespace { -std::tuple -mem_efficient_attention_backward_cutlass( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& logsumexp, - const at::Tensor& out, - bool causal) { - TORCH_CHECK(query.dim() == grad_out_.dim()); - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == 3); - - TORCH_CHECK(query.size(0) == grad_out_.size(0)); - TORCH_CHECK(query.size(1) == grad_out_.size(1)); - TORCH_CHECK(value.size(2) == grad_out_.size(2)); - - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - - // handle potentially non-contiguous grad_out through a copy - auto grad_out = grad_out_.contiguous(); - - CHECK_NOSPARSE_CONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(value); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); - - at::cuda::CUDAGuard device_guard(query.device()); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t K = query.size(2); - - // It does not make sense to use that in practice, - // but let's still make sure we are correct - // As we iterate through keys first, we skip - // keys with no query associated, so they are not - // initialized - bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q = at::empty_like(query); - at::Tensor grad_k = - grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - at::Tensor grad_v = - grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - - // TODO: Fuse this into a kernel? - // This is a bottleneck for smaller sequences (M <= 128) - auto delta = Kernel::kKernelComputesDelta - ? at::empty({B, M}, query.options().dtype(at::ScalarType::Float)) - : (grad_out.to(at::kFloat) * out.to(at::kFloat)).sum(-1); - TORCH_INTERNAL_ASSERT(delta.size(0) == B); - TORCH_INTERNAL_ASSERT(delta.size(1) == M); - - typename Kernel::Params params; - params.query_ptr = (scalar_t*)query.data_ptr(); - params.key_ptr = (scalar_t*)key.data_ptr(); - params.value_ptr = (scalar_t*)value.data_ptr(); - params.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); - params.output_ptr = (scalar_t*)out.data_ptr(); - params.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); - params.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); - params.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); - params.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); - params.delta_ptr = (float*)delta.data_ptr(); - params.head_dim = query.size(2); - params.head_dim_value = value.size(2); - params.num_queries = query.size(1); - params.num_keys = key.size(1); - params.num_batches = B; - params.causal = causal; - Kernel::check_supported(params); - - constexpr auto kernel_fn = attention_kernel_backward_batched; - - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - } - - auto checkBinaryArchMatches = [&]() { - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; - }; - TORCH_INTERNAL_ASSERT( - checkBinaryArchMatches(), "Something went wrong in the build process"); - - kernel_fn<<>>( - params); - }; - - DISPATCH_KERNEL( - query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); -} // namespace - -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_cutlass"), -// TORCH_FN(mem_efficient_attention_backward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu deleted file mode 100644 index 59b3637c8a438..0000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - - -#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ - { \ - if (VALUE_HEAD_DIM <= 64) { \ - constexpr bool kIs64x64 = true; \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kIs64x64 = false; \ - if (VALUE_HEAD_DIM <= 128) { \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kSingleValueIteration = false; \ - FN(); \ - } \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_BLOCKSIZE( \ - VALUE.size(-1), ([&]() { \ - static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \ - static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - true, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - /* Run a more efficient kernel (with `isAligned=True`) \ - if memory is correctly aligned*/ \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \ - KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \ - VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \ - /* TODO: Should we warn or log somewhere when we use a \ - less efficient kernel due to wrong alignment? */ \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - kIsAligned, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - FUNC(); \ - })) \ - })) \ - })); \ - })); \ - } - -namespace { -/* - There are 2 modes for using this function. - (Mode BMHK) With all the heads having the same seqlen - (Mode 1MHK) `batch=1` with all tokens across batches concatenated -*/ -std::tuple efficient_attention_forward_cutlass( - const at::Tensor& query, // [b, seqlen, num_heads, K] - const at::Tensor& key, // [b, seqlen, num_heads, K] - const at::Tensor& value, // [b, seqlen, num_heads, Kv] - // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the - // position of the first query token for batch $b - const c10::optional& cu_seqlens_q, - // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the - // position of the first key token for batch $b - const c10::optional& cu_seqlens_k, - // (Mode 1MHK only) Maximum sequence length across batches - const c10::optional max_seqlen_q_, - bool compute_logsumexp, - bool causal) { - TORCH_CHECK(query.dim() == 4); - TORCH_CHECK(key.dim() == 4); - TORCH_CHECK(value.dim() == 4); - - // Batch sizes - TORCH_CHECK(query.size(0) == key.size(0)); - TORCH_CHECK(query.size(0) == value.size(0)); - - // Sequence length - TORCH_CHECK(key.size(1) == value.size(1)); - - // Num heads - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(2) == value.size(2)); - - // Embedding per head - TORCH_CHECK(query.size(3) == key.size(3)); - - int64_t max_seqlen_q, max_seqlen_k; - TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); - if (cu_seqlens_q.has_value()) { - TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); - TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); - TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); - TORCH_CHECK(max_seqlen_q_.has_value()); - max_seqlen_q = *max_seqlen_q_; - max_seqlen_k = 0; // Will be set inside the kernel - } else { - max_seqlen_q = query.size(1); - max_seqlen_k = key.size(1); - } - - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - - at::cuda::CUDAGuard device_guard(query.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t num_heads = query.size(-2); - int64_t K = query.size(-1); - int64_t Kv = value.size(-1); - - at::Tensor res; - at::Tensor logsumexp; - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - res = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - - // NOTE: Should be aligned (by padding) in case M is - // not a good number for loading during backward - constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; - logsumexp = at::empty( - {B, - num_heads, - compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, - query.options().dtype(at::ScalarType::Float)); - - typename Kernel::Params p; - p.query_ptr = (scalar_t*)query.data_ptr(); - p.key_ptr = (scalar_t*)key.data_ptr(); - p.value_ptr = (scalar_t*)value.data_ptr(); - p.logsumexp_ptr = compute_logsumexp - ? (typename Kernel::lse_scalar_t*)logsumexp.data_ptr() - : nullptr; - at::Tensor output_accum; - if (Kernel::kNeedsOutputAccumulatorBuffer) { - output_accum = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - p.output_accum_ptr = - (typename Kernel::output_accum_t*)output_accum.data_ptr(); - } else { - p.output_accum_ptr = nullptr; - } - p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); - - if (cu_seqlens_q.has_value()) { - p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); - p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); - } - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ - } - - p.num_heads = num_heads; - p.head_dim = query.size(3); - p.head_dim_value = value.size(3); - p.num_queries = max_seqlen_q; - p.num_keys = max_seqlen_k; - p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B; - p.causal = causal; - - ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); - - constexpr auto kernel_fn = attention_kernel_batched; - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - AT_CUDA_CHECK(cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - } - Kernel::check_supported(p); - kernel_fn<<>>(p); - }; - // Dispatch to the right kernel - DISPATCH_KERNEL(query, key, value, ([&]() { - launchKernel(Kernel{}, computeCapability); - })); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); -} -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_cutlass"), -// TORCH_FN(efficient_attention_forward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h index 399593fd09573..b0e7106f3cfc8 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h @@ -1,15 +1,16 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instanciate to run a GEMM with various parameters (see + datastructures to instantiate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instanciation priority rules, it will only create an MmaMultiStage with + instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and - `default_mma_core.h` files and wrapped this template to allow our usecase. + `default_mma_core.h` files and wrapped this template to allow our use case. This is really only for the FastF32 case - aka using TensorCores with fp32. */ +#pragma once #include #include diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index e25701a7588ac..e629aaaecab4b 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1,7 +1,5 @@ #pragma once - #include -#include #include #include @@ -75,46 +73,113 @@ struct AttentionBackwardKernel { struct Params { // Input tensors - scalar_t* query_ptr; // [num_queries, head_dim] - scalar_t* key_ptr; // [num_keys, head_dim] - scalar_t* value_ptr; // [num_keys, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_queries] - scalar_t* output_ptr; // [num_queries, head_dim_value] - scalar_t* grad_output_ptr; // [num_queries, head_dim_value] - accum_t* delta_ptr; // [num_queries] + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [Mq, nH] // Output tensors - scalar_t* grad_query_ptr; // [num_queries, head_dim] - scalar_t* grad_key_ptr; // [num_keys, head_dim] - scalar_t* grad_value_ptr; // [num_keys, head_dim_value] + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] // Dimensions/strides int32_t head_dim; int32_t head_dim_value; int32_t num_queries; int32_t num_keys; - int32_t num_batches; + int32_t num_heads; bool causal; - __device__ void advance_batches(int32_t batch_id) { + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int32_t num_batches; + + int64_t gO_strideB; + int64_t gQ_strideB; + int64_t gK_strideB; + int64_t gV_strideB; + int64_t gO_strideH; + int64_t gQ_strideH; + int64_t gK_strideH; + int64_t gV_strideH; + + CUTLASS_DEVICE void advance_to_block() { constexpr int32_t kAlignLSE = 32; // block size of backward auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - query_ptr += batch_id * head_dim * num_queries; - key_ptr += batch_id * head_dim * num_keys; - value_ptr += batch_id * head_dim_value * num_keys; - logsumexp_ptr += batch_id * lse_dim; - output_ptr += batch_id * head_dim_value * num_queries; - grad_output_ptr += batch_id * head_dim_value * num_queries; - delta_ptr += batch_id * num_queries; - - grad_query_ptr += batch_id * head_dim * num_queries; - grad_key_ptr += batch_id * head_dim * num_keys; - grad_value_ptr += batch_id * head_dim_value * num_keys; + int32_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += (batch_id * num_heads + head_id) * num_queries; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + + gO_strideM = warp_uniform(gO_strideM); + gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); + q_strideM = warp_uniform(q_strideM); + k_strideM = warp_uniform(k_strideM); + v_strideM = warp_uniform(v_strideM); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); } __host__ dim3 getBlocksGrid() const { - return dim3(1, 1, num_batches); + return dim3(1, num_heads, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); @@ -179,7 +244,6 @@ struct AttentionBackwardKernel { attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue - with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = @@ -225,7 +289,6 @@ struct AttentionBackwardKernel { struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul - Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ @@ -601,7 +664,7 @@ struct AttentionBackwardKernel { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; - __device__ __forceinline__ void clear() { + CUTLASS_DEVICE void clear() { gradV.clear(); gradK.clear(); } @@ -614,17 +677,21 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); TORCH_CHECK( - p.head_dim % kMinimumAlignment == 0, - "query/key is not correctly aligned"); + p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); + TORCH_CHECK( + p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); TORCH_CHECK( - p.head_dim_value % kMinimumAlignment == 0, - "value is not correctly aligned"); + p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); } - static __device__ void kernel(Params& p_) { + static CUTLASS_DEVICE void kernel(Params& p_) { // Hint to nvcc to store points & tensor shapes in registers // as we use them a lot +#if __cplusplus < 201703L register const Params p = p_; +#else + const Params p = p_; +#endif extern __shared__ char smem_buffer[]; SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); @@ -658,7 +725,11 @@ struct AttentionBackwardKernel { __syncthreads(); } +#if __cplusplus < 201703L + OutputFragments register output_frags; +#else OutputFragments output_frags; +#endif int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; key_start += kBlockSizeJ) { @@ -695,7 +766,7 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ void loadDi( + static CUTLASS_DEVICE void loadDi( cutlass::Array& di, Params const& p, int32_t query_start) { @@ -710,7 +781,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void processBlockIJ( + static CUTLASS_DEVICE void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -718,9 +789,9 @@ struct AttentionBackwardKernel { int32_t key_start) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim))); - int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int32_t warp_id = threadIdx.y; - int32_t lane_id = threadIdx.x; + int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + int8_t warp_id = warp_uniform(threadIdx.y); + int8_t lane_id = threadIdx.x; __syncthreads(); loadDi(shared_storage.di(), p, query_start); @@ -734,8 +805,8 @@ struct AttentionBackwardKernel { auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -747,8 +818,8 @@ struct AttentionBackwardKernel { }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); @@ -757,8 +828,8 @@ struct AttentionBackwardKernel { }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); @@ -770,14 +841,14 @@ struct AttentionBackwardKernel { }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -803,16 +874,16 @@ struct AttentionBackwardKernel { // k_j typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -893,14 +964,14 @@ struct AttentionBackwardKernel { num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value + col, + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -951,16 +1022,16 @@ struct AttentionBackwardKernel { using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -1057,16 +1128,16 @@ struct AttentionBackwardKernel { num_keys_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradQ::OutputTileIterator( - typename MatmulGradQ::OutputTileIterator::Params{p.head_dim}, - p.grad_query_ptr + query_start * p.head_dim + col, + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); }; // k_j typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1153,8 +1224,8 @@ struct AttentionBackwardKernel { num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim + col, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); @@ -1162,8 +1233,8 @@ struct AttentionBackwardKernel { // q_i typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1236,15 +1307,15 @@ struct AttentionBackwardKernel { kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; auto thread_id = get_thread_id(); typename MatmulQK::Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); @@ -1259,7 +1330,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void writeFragsToGmem( + static CUTLASS_DEVICE void writeFragsToGmem( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -1268,8 +1339,8 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : std::min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value, + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, get_thread_id()); accumulateInGmem( @@ -1279,8 +1350,8 @@ struct AttentionBackwardKernel { true); typename MatmulGradK::OutputTileIterator outputK_it( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, get_thread_id()); @@ -1292,7 +1363,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void accumulateInGmem( + static CUTLASS_DEVICE void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, @@ -1334,7 +1405,9 @@ struct AttentionBackwardKernel { } template - static __device__ void computeDelta(Params const& p, int32_t query_start) { + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row @@ -1349,13 +1422,15 @@ struct AttentionBackwardKernel { bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; - const __restrict__ AccessType* grad_output_ptr = - reinterpret_cast( - p.grad_output_ptr + (query_start + laneRow) * p.head_dim_value + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); - const __restrict__ AccessType* output_ptr = - reinterpret_cast( - p.output_ptr + (query_start + laneRow) * p.head_dim_value + + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); static constexpr int64_t kMaxIters = @@ -1430,13 +1505,13 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ int8_t get_lane_id() { + static CUTLASS_DEVICE int8_t get_lane_id() { return threadIdx.x; } - static __device__ __forceinline__ int8_t get_warp_id() { + static CUTLASS_DEVICE int8_t get_warp_id() { return threadIdx.y; } - static __device__ __forceinline__ int16_t get_thread_id() { + static CUTLASS_DEVICE int16_t get_thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } }; @@ -1457,8 +1532,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) #define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ _ATTENTION_KERNEL_BACKWARD_BEGIN( \ AttentionBackwardKernel) \ - auto batch_id = blockIdx.z; \ - p.advance_batches(batch_id); \ + p.advance_to_block(); \ Kernel::kernel(p); \ _ATTENTION_KERNEL_BACKWARD_END(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 218322f995d67..2b57ef6dd6f6c 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include @@ -26,8 +28,6 @@ struct sdp_params { bool is_causal; }; -enum class SDPBackend { flash_attention, efficient_attention, math, error }; - template inline bool check_tensor_dtype( sdp_params params, @@ -40,7 +40,9 @@ inline bool check_tensor_dtype( allowed_dtypes.end()))) { TORCH_CHECK( !debug, - "Expected query, key and value to be of dtype float16 or bfloat16 but got Query dtype: ", + "Expected query, key and value to all be of dtype: {", + c10::Join(", ", allowed_dtypes), "}. Got ", + "Query dtype: ", params.query.dtype(), ", Key dtype: ", params.key.dtype(), @@ -62,6 +64,60 @@ inline bool check_for_attn_weights(sdp_params params, bool debug) { return true; } +inline bool check_for_non_zero_dropout(sdp_params params, bool debug) { + if (params.dropout != 0.0) { + TORCH_CHECK(!debug, "Mem_efficient does not support non_zero dropout. Dropout_p: ", params.dropout); + return false; + } + return true; +} + +inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { + if (!params.query.is_nested()) { + return true; + } + const at::Tensor& sizes = at::native::get_nested_tensor_impl(params.query)->get_nested_size_tensor(); + auto* sizes_ptr = sizes.data_ptr(); + const int64_t n_tensors = params.query.size(0); + const int64_t size_tensor_stride = sizes.stride(0); + + // This is being called inside sdp with shape [batch, heads, {seq_len}, dim] + for (const auto i : c10::irange(n_tensors)) { + if (sizes_ptr[(i * size_tensor_stride) + 1] <= 1) { + TORCH_CHECK( + !debug, "Flash Attention does not support sequence_length <= 1"); + return false; + } + } + + return true; +} + +inline bool check_for_nested_inputs(sdp_params params, bool debug){ + if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) { + TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors."); + return false; + } + return true; +} + +inline bool check_requires_grad(sdp_params params, bool debug) { + if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { + TORCH_CHECK(!debug, "Flash Attention does not currently support training."); + return false; + } + return true; +} + +inline bool check_requires_grad_and_nested(sdp_params params, bool debug) { + // If we fail both checks then we return false + if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){ + TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs."); + return false; + } + return true; +} + inline bool check_for_attn_mask(sdp_params params, bool debug) { if (params.has_attn_mask) { TORCH_CHECK(!debug, "Flash Attention does not support attention mask."); @@ -73,7 +129,7 @@ inline bool check_for_attn_mask(sdp_params params, bool debug) { inline bool check_tensor_shapes(sdp_params params, bool debug) { auto query_dim = params.query.dim(); if (!(query_dim == params.key.dim() && query_dim == params.value.dim() && - query_dim == 4)) { + (query_dim == 4 ))) { TORCH_CHECK( !debug, "Flash attention requires query, key and value to be 4 dimensional, but got Query dim: ", @@ -108,7 +164,26 @@ inline bool check_head_dim_size(sdp_params params, bool debug) { return true; } -inline bool check_runtime_disabled(sdp_params params, bool debug) { +inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) { + const int64_t query_size_last = params.query.size(-1); + if (!(query_size_last == params.key.size(-1) && + query_size_last == params.value.size(-1) && query_size_last >= 8)) { + TORCH_CHECK( + !debug, + "Mem efficient attention requires last dimension of inputs to be >= 8.", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.size(-1), + ", Value.size(-1): ", + params.value.size(-1), + " instead."); + return false; + } + return true; +} + +inline bool check_runtime_disabled_flash(sdp_params params, bool debug) { // We check the global context to see if user has explicitly turned of flash // sdp kernels if (!at::globalContext().userEnabledFlashSDP()) { @@ -118,6 +193,16 @@ inline bool check_runtime_disabled(sdp_params params, bool debug) { return true; } +inline bool check_runtime_disabled_mem_efficient(sdp_params params, bool debug) { + // We check the global context to see if user has explicitly turned of mem_efficient + // sdp kernels + if (!at::globalContext().userEnabledMemEfficientSDP()) { + TORCH_CHECK(!debug, "Memory Efficient attention has been runtime disabled."); + return false; + } + return true; +} + inline bool check_gpu_sm75_or_greater(sdp_params params, bool debug) { // Check that the gpu is capable of running flash attention auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -158,27 +243,31 @@ inline bool use_flash_attention(sdp_params params, bool debug) { TORCH_CHECK(!debug, "Torch was not compiled with flash attention."); return false; #endif - // Constraints specific to flash attention - static const std::vector flash_dtypes{ - at::kHalf, at::kBFloat16}; - // Define gate functions that determine if a flash kernel can be ran - std::vector> constraints{ - check_runtime_disabled, + constexpr std::array constraints {{ + check_runtime_disabled_flash, + check_requires_grad, check_tensor_shapes, check_for_attn_weights, check_for_attn_mask, check_head_dim_size, - check_gpu_sm75_or_greater}; + check_gpu_sm75_or_greater, + check_for_nested_inputs, + check_for_seq_len_1_nested_tensor}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; } } - if (!check_tensor_dtype(params, flash_dtypes, debug)) { - return false; + + auto dprop = at::cuda::getCurrentDeviceProperties(); + if (dprop->major >= 8) { + static const std::array sm80_flash_dtypes{at::kHalf, at::kBFloat16}; + return check_tensor_dtype(params, sm80_flash_dtypes, debug); + } else { + static const std::array default_flash_dtypes{at::kHalf}; + return check_tensor_dtype(params, default_flash_dtypes, debug); } - return true; } inline bool use_mem_efficient_attention(sdp_params params, bool debug) { @@ -191,12 +280,16 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { at::kHalf, at::kFloat, at::kBFloat16}; // Define gate functions that determine if a flash kernel can be ran - std::vector> constraints{ + constexpr std::array constraints{{ check_gpu_sm50_or_greater, - check_runtime_disabled, + check_runtime_disabled_mem_efficient, + check_requires_grad_and_nested, check_for_attn_weights, check_tensor_shapes, - check_for_attn_mask}; + check_for_attn_mask, + check_head_dim_size_mem_efficient, + check_for_seq_len_1_nested_tensor, + check_for_non_zero_dropout}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; @@ -214,7 +307,7 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) { // 2. Mem Efficient Attention // 3. Math fallback auto& ctx = at::globalContext(); - if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP()) { + if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) { return SDPBackend::error; } // Because TORCHCHECK checks if condition is true we negate debug so that diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h new file mode 100644 index 0000000000000..9641a36b33b2c --- /dev/null +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -0,0 +1,9 @@ +#pragma once +namespace sdp { +enum class SDPBackend { + error = -1, + math = 0, + flash_attention = 1, + efficient_attention = 2 +}; +} // namespace sdp \ No newline at end of file diff --git a/aten/src/ATen/native/transformers/transformer.cpp b/aten/src/ATen/native/transformers/transformer.cpp index afadb0d6bce8b..4a4c9946b35aa 100644 --- a/aten/src/ATen/native/transformers/transformer.cpp +++ b/aten/src/ATen/native/transformers/transformer.cpp @@ -95,44 +95,27 @@ Tensor transformer_encoder_layer_forward( if (norm_first) { x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor); } + x = std::get<0>(at::_native_multi_head_attention( + x, + x, + x, + embed_dim, + num_heads, + qkv_weight, + qkv_bias, + proj_weight, + proj_bias, + mask, + false /* need_weights */, + true /* average_attn_weights */, + mask_type)); -#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION - if (x.is_nested() && x.is_cuda() && x.dtype() == at::kHalf && !mask.has_value() && - (embed_dim / num_heads == 16 || - embed_dim / num_heads == 32 || - embed_dim / num_heads == 64 || - embed_dim / num_heads == 128)) { - TORCH_WARN_ONCE("transformer_encoder_layer_forward is using flash attention."); - x = at::linear(x, qkv_weight, qkv_bias); - auto x_size_0 = x.size(0); - x = x.view({x_size_0, -1, 3, num_heads, embed_dim / num_heads}); - x = flash_attention_helper(x, x, x, 0.0, false, false); - x = x.view({x_size_0, -1, embed_dim}); - x = at::linear(x, proj_weight, proj_bias); - } else { -#endif - x = std::get<0>(native_multi_head_attention( - x, - x, - x, - embed_dim, - num_heads, - qkv_weight, - qkv_bias, - proj_weight, - proj_bias, - mask, - false /* need_weights */, - true /* average_attn_weights */, - mask_type)); -#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION - } -#endif x.add_(src); if (!norm_first) { x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor); } + auto pre_ffn_res = x; if (norm_first) { diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index fc287045dc9dd..f4c3ee8498960 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ b/aten/src/ATen/native/ts_native_functions.yaml @@ -189,6 +189,7 @@ supported: # after functionalization, # but their implementations call view operators (which we need to functionalize away). - block_diag + - diag_embed - diagonal_backward - slice_backward - new_empty_strided @@ -226,33 +227,10 @@ non_native: - ShapeCompute - TreatScalarsAsConstants - CanBeReusedDeclOnly + # Even we have removed all the other view ops in favor of the *_copy version, expand + # is still kept because it's used in copy_. - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor - - func: view(Tensor input, int[] output_size) -> Tensor - properties: - - ShapeCompute - func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor opkind: ltc_cast properties: - ShapeCompute - - # View ops only required until proper functionalization pass is introduced into LTC - - func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor - opkind: ltc_as_strided_view_update - - func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor - - func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor - opkind: ltc_diagonal_view_update - properties: - - ShapeCompute - - func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor - - func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor - opkind: ltc_narrow_view_update - - func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor - - func: permute(Tensor input, int[] dims) -> Tensor - - func: resize(Tensor input, int[] size) -> Tensor - - func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor - opkind: ltc_select_view_update - properties: - - ShapeCompute - - func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor - - func: squeeze(Tensor input, int dim) -> Tensor - - func: unsqueeze(Tensor input, int dim) -> Tensor diff --git a/aten/src/ATen/native/utils/ParamUtils.h b/aten/src/ATen/native/utils/ParamUtils.h index 376467ff79cf5..adb5f1cfa49f9 100644 --- a/aten/src/ATen/native/utils/ParamUtils.h +++ b/aten/src/ATen/native/utils/ParamUtils.h @@ -6,12 +6,13 @@ namespace at { namespace native { -inline std::vector expand_param_if_needed( - IntArrayRef list_param, +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, const char* param_name, int64_t expected_dim) { if (list_param.size() == 1) { - return std::vector(expected_dim, list_param[0]); + return std::vector(expected_dim, list_param[0]); } else if ((int64_t)list_param.size() != expected_dim) { std::ostringstream ss; ss << "expected " << param_name << " to be a single integer value or a " @@ -23,5 +24,19 @@ inline std::vector expand_param_if_needed( } } +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/utils/ParamsHash.h b/aten/src/ATen/native/utils/ParamsHash.h index 76bb4de53d633..c4056ab1b3f1e 100644 --- a/aten/src/ATen/native/utils/ParamsHash.h +++ b/aten/src/ATen/native/utils/ParamsHash.h @@ -11,8 +11,8 @@ namespace at { namespace native { template struct ParamsHash { // Params must be a POD because we read out its memory - // contenst as char* when hashing - static_assert(std::is_pod::value, "Params is not POD"); + // contents as char* when hashing + static_assert(std::is_standard_layout::value, "Params is not POD"); size_t operator()(const Params& params) const { auto ptr = reinterpret_cast(¶ms); @@ -28,8 +28,8 @@ struct ParamsHash { template struct ParamsEqual { // Params must be a POD because we read out its memory - // contenst as char* when comparing - static_assert(std::is_pod::value, "Params is not POD"); + // contents as char* when comparing + static_assert(std::is_standard_layout::value, "Params is not POD"); bool operator()(const Params& a, const Params& b) const { auto ptr1 = reinterpret_cast(&a); diff --git a/aten/src/ATen/native/vol2col.h b/aten/src/ATen/native/vol2col.h index 12718a8f00afc..2b2ee3b57b0c4 100644 --- a/aten/src/ATen/native/vol2col.h +++ b/aten/src/ATen/native/vol2col.h @@ -1,8 +1,6 @@ #pragma once -#include -#include -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/vulkan/api/Adapter.cpp b/aten/src/ATen/native/vulkan/api/Adapter.cpp index 311648b6894ed..176236611c1d9 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.cpp +++ b/aten/src/ATen/native/vulkan/api/Adapter.cpp @@ -195,7 +195,7 @@ std::string get_device_type_str(const VkPhysicalDeviceType type) { case VK_PHYSICAL_DEVICE_TYPE_CPU: return "CPU"; default: - return "UNKOWN"; + return "UNKNOWN"; } } diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h index 738592408f6f8..3cfee491d7eab 100644 --- a/aten/src/ATen/native/vulkan/api/Common.h +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -21,6 +21,12 @@ CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \ name##_spv_layout \ } +#define VK_SHADER(name) \ + ::at::native::vulkan::api::ShaderInfo { \ + CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \ + name##_spv_layout, name##_spv_tile_size, name##_spv_bias_storage_type, \ + name##_spv_weight_storage_type, \ + } #endif /* USE_VULKAN_SHADERC_RUNTIME */ /* diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp index 73bbb4b21c4ad..06038b9e4ecfa 100644 --- a/aten/src/ATen/native/vulkan/api/Context.cpp +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -154,6 +154,63 @@ Context* context() { return context.get(); } +// +// UniformParamsBuffer +// + +namespace { + +void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) { + MemoryMap dst_mapping(dst, MemoryAccessType::WRITE); + + MemoryMap src_mapping(src, api::MemoryAccessType::READ); + src_mapping.invalidate(); + + void* dst_ptr = dst_mapping.template data(); + void* src_ptr = src_mapping.template data(); + + memcpy(dst_ptr, src_ptr, src.mem_size()); +} + +} // namespace + +UniformParamsBuffer::UniformParamsBuffer(const UniformParamsBuffer& other) + : context_p_(other.context_p_), vulkan_buffer_{} { + if (other.vulkan_buffer_) { + vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer( + other.vulkan_buffer_.mem_size()); + + memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_); + } +} + +UniformParamsBuffer& UniformParamsBuffer::operator=( + const UniformParamsBuffer& other) { + if (&other != this) { + context_p_ = other.context_p_; + + // Move vulkan_buffer_ to another VulkanBuffer for cleanup + if (vulkan_buffer_) { + VulkanBuffer temp_buffer(std::move(vulkan_buffer_)); + context_p_->register_buffer_cleanup(temp_buffer); + } + // vulkan_buffer_ should now be empty + + if (other.vulkan_buffer_) { + vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer( + other.vulkan_buffer_.mem_size()); + + memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_); + } + } + + return *this; +} + +// +// VulkanImpl +// + struct VulkanImpl final : public at::vulkan::VulkanImplInterface { bool is_vulkan_available() const override { return available(); diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h index 56db8fa6a173b..ce0525abda573 100644 --- a/aten/src/ATen/native/vulkan/api/Context.h +++ b/aten/src/ATen/native/vulkan/api/Context.h @@ -206,14 +206,16 @@ class UniformParamsBuffer final { VulkanBuffer vulkan_buffer_; public: + UniformParamsBuffer() : context_p_{nullptr}, vulkan_buffer_{} {} + template UniformParamsBuffer(Context* context_p, const Block& block) : context_p_(context_p), vulkan_buffer_( context_p_->adapter_ptr()->vma().create_params_buffer(block)) {} - UniformParamsBuffer(const UniformParamsBuffer&) = delete; - UniformParamsBuffer& operator=(const UniformParamsBuffer&) = delete; + UniformParamsBuffer(const UniformParamsBuffer&); + UniformParamsBuffer& operator=(const UniformParamsBuffer&); UniformParamsBuffer(UniformParamsBuffer&&) = default; UniformParamsBuffer& operator=(UniformParamsBuffer&&) = default; diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index 9cfdbcdb03f3e..517bd0a56232f 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -26,8 +26,8 @@ namespace api { * always created with the corresponding VkFormat. Consequently, kHalf tensors * are currently unsupported in favor of enforcing inputs to be of kFloat dtype. */ -VkFormat vk_format(const caffe2::TypeMeta dtype) { - switch (c10::typeMetaToScalarType(dtype)) { +VkFormat vk_format(const at::ScalarType dtype) { + switch (dtype) { case kFloat: #ifdef USE_VULKAN_FP16_INFERENCE return VK_FORMAT_R16G16B16A16_SFLOAT; @@ -36,6 +36,10 @@ VkFormat vk_format(const caffe2::TypeMeta dtype) { #endif /* USE_VULKAN_FP16_INFERENCE */ case c10::kQUInt8: return VK_FORMAT_R8G8B8A8_UINT; + case c10::kQInt8: + return VK_FORMAT_R8G8B8A8_SINT; + case c10::kQInt32: + return VK_FORMAT_R32G32B32A32_SINT; default: TORCH_CHECK( @@ -663,6 +667,21 @@ VulkanBuffer MemoryAllocator::create_staging_buffer(const VkDeviceSize size) { return VulkanBuffer(allocator_, size, mem_props); } +VulkanBuffer MemoryAllocator::create_uniform_buffer(const VkDeviceSize size) { + const VulkanBuffer::MemoryProperties mem_props{ + DEFAULT_ALLOCATION_STRATEGY | + VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT, + VMA_MEMORY_USAGE_AUTO, + 0u, + 0u, + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + }; + + VulkanBuffer uniform_buffer(allocator_, size, mem_props); + + return uniform_buffer; +} + // // VulkanFence // diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 52153ebc0e05f..9180b3422db13 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -18,7 +18,7 @@ namespace api { typedef uint8_t MemoryAccessFlags; -VkFormat vk_format(const caffe2::TypeMeta dtype); +VkFormat vk_format(const at::ScalarType dtype); c10::ScalarType c10_scalartype(const VkFormat image_format); @@ -401,6 +401,14 @@ class MemoryAllocator final { VulkanBuffer create_staging_buffer(const VkDeviceSize); + /* + * Create a uniform buffer with a specified size + */ + VulkanBuffer create_uniform_buffer(const VkDeviceSize); + + /* + * Create a uniform buffer containing the data in an arbitrary struct + */ template VulkanBuffer create_params_buffer(const Block& block); }; @@ -486,16 +494,7 @@ struct FencePool final { template inline VulkanBuffer MemoryAllocator::create_params_buffer(const Block& block) { - const VulkanBuffer::MemoryProperties mem_props{ - DEFAULT_ALLOCATION_STRATEGY | - VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT, - VMA_MEMORY_USAGE_AUTO, - 0u, - 0u, - VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, - }; - - VulkanBuffer uniform_buffer(allocator_, sizeof(Block), mem_props); + VulkanBuffer uniform_buffer = create_uniform_buffer(sizeof(Block)); // Fill the uniform buffer with data in block { diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp index 2ead82bc934e2..1ca37ba999998 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.cpp +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -50,6 +50,23 @@ ShaderSource::ShaderSource( kernel_name{std::move(name)}, kernel_layout{layout} {} +ShaderInfo::ShaderInfo( + std::string name, + const uint32_t* const spirv_bin, + const uint32_t size, + const std::vector& layout, + const std::vector& tile_size, + const StorageType bias_storage_type, + const StorageType weight_storage_type) + : shader_src(name, spirv_bin, size, layout), + tile_size(tile_size), + bias_storage_type(bias_storage_type), + weight_storage_type(weight_storage_type) { + for (uint64_t i = 0; i < tile_size.size(); ++i) { + shader_src.out_tile_size.data[i] = tile_size[i]; + } +} + bool operator==(const ShaderSource& _1, const ShaderSource& _2) { if (_1.type != _2.type) { return false; diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h index 12cc5c193d123..c676d10b19379 100644 --- a/aten/src/ATen/native/vulkan/api/Shader.h +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -3,6 +3,7 @@ #ifdef USE_VULKAN_API #include +#include #include #include #include @@ -74,6 +75,24 @@ struct ShaderSource final { bool operator==(const ShaderSource& _1, const ShaderSource& _2); +struct ShaderInfo final { + ShaderSource shader_src; + c10::SmallVector tile_size; + StorageType bias_storage_type{StorageType::UNKNOWN}; + StorageType weight_storage_type{StorageType::UNKNOWN}; + + explicit ShaderInfo() = default; + explicit ShaderInfo(std::string, const char*); + explicit ShaderInfo( + std::string, + const uint32_t*, + const uint32_t, + const std::vector&, + const std::vector& tile_size, + const StorageType bias_storage_type, + const StorageType weight_storage_type); +}; + class ShaderModule final { public: explicit ShaderModule(const VkDevice device, const ShaderSource& source); diff --git a/aten/src/ATen/native/vulkan/api/Types.h b/aten/src/ATen/native/vulkan/api/Types.h new file mode 100644 index 0000000000000..ff4ce3e7044d7 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Types.h @@ -0,0 +1,21 @@ +#pragma once + +#ifdef USE_VULKAN_API +namespace at { +namespace native { +namespace vulkan { +namespace api { + +enum class StorageType { + BUFFER, + TEXTURE_3D, + TEXTURE_2D, + UNKNOWN, +}; + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl index 6ec93422b0d6b..0ec7dbdf4fcf5 100644 --- a/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl +++ b/aten/src/ATen/native/vulkan/glsl/batchnorm.glsl @@ -1,37 +1,61 @@ #version 450 core #define PRECISION $precision -#define FORMAT $format +#define FORMAT $format layout(std430) buffer; -/* Qualifiers: layout - storage - precision - memory */ - -layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; -layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; -layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma; -layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta; -layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean; -layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar; -layout(set = 0, binding = 6) uniform PRECISION restrict Block { - ivec3 isize; - int channels_ext; +/* + * Output Image + */ +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput; + +/* + * Input Textures + */ +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler3D uGamma; +layout(set = 0, binding = 3) uniform PRECISION sampler3D uBeta; +layout(set = 0, binding = 4) uniform PRECISION sampler3D uMean; +layout(set = 0, binding = 5) uniform PRECISION sampler3D uVar; + +/* + * Params Buffer + */ +layout(set = 0, binding = 6) uniform PRECISION restrict Block { + // xyz contains extents of the output texture, w contains the number of + // channels divided by 4, rounded up. + ivec4 out_extents; float eps; -} uBlock; +} +uBlock; +/* + * Local Work Group + */ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +/* + * Computes a Batch normalization. Each shader invocation calculates the output + * at a single output location. + */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (all(lessThan(pos, uBlock.isize.xyz))) { - const ivec3 chn = ivec3(0, 0, pos.z % uBlock.channels_ext); - imageStore( - uOutput, - pos, - (texelFetch(uInput, pos, 0) - - texelFetch(uMean, chn, 0)) - / sqrt(texelFetch(uVar, chn, 0) + uBlock.eps) - * texelFetch(uGamma, chn, 0) - + texelFetch(uBeta, chn, 0)); + // Return if this global position is outside output texture bounds + if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) { + return; } + + const ivec3 ch_pos = ivec3(0, 0, pos.z % uBlock.out_extents.w); + + const vec4 in_tex = texelFetch(uInput, pos, 0); + const vec4 gamma_tex = texelFetch(uGamma, ch_pos, 0); + const vec4 beta_tex = texelFetch(uBeta, ch_pos, 0); + const vec4 mean_tex = texelFetch(uMean, ch_pos, 0); + const vec4 var_tex = texelFetch(uVar, ch_pos, 0); + + const vec4 out_tex = + (in_tex - mean_tex) / sqrt(var_tex + uBlock.eps) * gamma_tex + beta_tex; + + imageStore(uOutput, pos, out_tex); } diff --git a/aten/src/ATen/native/vulkan/glsl/buffer_to_buffer.glsl b/aten/src/ATen/native/vulkan/glsl/buffer_to_buffer.glsl new file mode 100644 index 0000000000000..7a67a8ca37372 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/buffer_to_buffer.glsl @@ -0,0 +1,78 @@ +#version 450 core + +#define PRECISION $precision +#define FORMAT $format + +#include "indexing.h" + +layout(std430) buffer; + +/* + * Output Buffer + */ +layout(set = 0, binding = 0) buffer PRECISION restrict writeonly OutBuffer { + float data[]; +} +uOutput; + +/* + * Output Buffer Metadata + */ +layout(set = 0, binding = 1) uniform PRECISION restrict OutMeta { + uvec4 sizes; + uvec4 strides; + uint ndim; + uint buf_length; +} +uOutMeta; + +/* + * Input Buffer + */ +layout(set = 0, binding = 2) buffer PRECISION restrict readonly InBuffer { + float data[]; +} +uInput; + +/* + * Input Buffer Metadata + */ +layout(set = 0, binding = 3) uniform PRECISION restrict InMeta { + uvec4 sizes; + uvec4 strides; + uint ndim; + uint buf_length; +} +uInMeta; + +/* + * Local Work Group Size + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Copies data from the tensor at uInput to the tensor at uOutput based on 4D + * coordinate. Each element at (x,y,c,n) in uInput will be copied to uOutput at + * (x,y,c,n). If (x,y,c,n) is outside the bounds of uInput then 0 will be + * written. + * + * Each shader invocation is responsible for one element of the output buffer. + */ +void main() { + const uint write_idx = ivec3(gl_GlobalInvocationID).x; + + if (write_idx >= uOutMeta.buf_length) { + return; + } + + uvec4 write_coord = + idx_to_coord(write_idx, uOutMeta.strides, uOutMeta.sizes); + + float outval = 0u; + if (all(lessThan(write_coord, uInMeta.sizes))) { + uint read_idx = coord_to_idx(write_coord, uInMeta.strides); + outval = uInput.data[read_idx]; + } + + uOutput.data[write_idx] = outval; +} diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index 4afae20127e80..9d73356c71e7e 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -2,6 +2,12 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_2D + * BIAS_STORAGE = TEXTURE_2D + */ + layout(std430) buffer; /* diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl index 671b6410e61df..ab2ce6459c67d 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_dw.glsl @@ -2,6 +2,13 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_2D + * BIAS_STORAGE = TEXTURE_2D + * Note that for DW kernel IC = 1 so the weight layout is really OC4, H, W, 4oc + */ + layout(std430) buffer; /* @@ -60,24 +67,22 @@ void main() { // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so any reads from the padding region is skipped. - const ivec2 start = max(ivec2(0), ipos); - const ivec2 end = min(ipos + uBlock.overlay_region.xy, uBlock.in_extents.xy); - // Compute the start of the kernel based on how far we are skipping ahead when - // reading the input - const ivec2 kstart = (start - ipos) / uBlock.dilate; + const ivec2 start = ipos; + const ivec2 end = ipos + uBlock.overlay_region.xy; vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0); const int dil_y = uBlock.dilate.y; const int dil_x = uBlock.dilate.x; - for (int y = start.y, ky = kstart.y; y < end.y; y += dil_y, ky++) { - for (int x = start.x, kx = kstart.x; x < end.x; x += dil_x, kx++) { + int k_ind = 0; + for (int y = start.y; y < end.y; y += dil_y) { + for (int x = start.x; x < end.x; x += dil_x) { // The weight kernel was rearranged so that every NxN filter was flattened // so that it fits on one row. Each filter was then stacked on top of each // other vertically. - const int k_ind = kx + ky * uBlock.kernel_size.x; const vec4 k_tex = texelFetch(uKernel, ivec2(k_ind, pos.z), 0); const vec4 i_tex = texelFetch(uInput, ivec3(x, y, pos.z), 0); sum = fma(i_tex, k_tex, sum); + k_ind++; } } diff --git a/aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl index ba9fbbd8df363..b3c983fc52149 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv_transpose2d.glsl @@ -2,6 +2,12 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_2D + * BIAS_STORAGE = TEXTURE_2D + */ + layout(std430) buffer; /* Qualifiers: layout - storage - precision - memory */ diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_int32.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_int32.glsl new file mode 100644 index 0000000000000..f6f1a48105a7e --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_int32.glsl @@ -0,0 +1,52 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* + * Input Sampler + */ +layout(set = 0, binding = 0) uniform PRECISION isampler3D uImage; + +/* + * Output Buffer + */ +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + int data[]; +} +uBuffer; + +/* + * Params Buffer + */ +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + // xyz contain the extents of the input texture, w contains HxW to help + // calculate buffer offsets + ivec4 in_extents; +} +uBlock; + +/* + * Local Work Group in_extents + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, uBlock.in_extents.xyz))) { + return; + } + + const ivec4 intex = texelFetch(uImage, pos, 0); + + const int base_index = + pos.x + uBlock.in_extents.x * pos.y + (4 * uBlock.in_extents.w) * pos.z; + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * uBlock.in_extents.w; + + uBuffer.data[buf_indices.x] = intex.x; + uBuffer.data[buf_indices.y] = intex.y; + uBuffer.data[buf_indices.z] = intex.z; + uBuffer.data[buf_indices.w] = intex.w; +} diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl index 2f5999b465e35..3fe0447a33a53 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized.glsl @@ -11,7 +11,7 @@ layout(set = 0, binding = 0) uniform PRECISION isampler3D uImage; /* * Output Buffer */ -layout(set = 0, binding = 1) buffer PRECISION Buffer { +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { uint data[]; } uBuffer; @@ -33,55 +33,47 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (pos.y == 0 && pos.z == 0) { - ivec4 texture_pos = ivec4(0, 1, 2, 3) + 4 * pos.x; + // each instance of the shader writes out a single element of the output + // the global size matches the size of the output, in other words: + // global size = {div_up(numel, 4), 1u, 1u} + // pos = {pos.x, 1, 1} where pos.x is the index of the output element - ivec4 last_eight; - last_eight.z = texture_pos.x / (uBlock.in_extents.x * uBlock.in_extents.y); - last_eight.w = texture_pos.x % (uBlock.in_extents.x * uBlock.in_extents.y); - last_eight.y = last_eight.w / uBlock.in_extents.x; - last_eight.x = last_eight.w % uBlock.in_extents.x; + ivec4 input_pos = ivec4(0, 1, 2, 3) + 4 * pos.x; + // each output element is a uint32 made up four consecutive uint8 from the + // input in nchw format. input_pos contains the positions of these four + // elements from the input in nchw format. - ivec4 sec_last_eight; - sec_last_eight.z = - texture_pos.y / (uBlock.in_extents.x * uBlock.in_extents.y); - sec_last_eight.w = - texture_pos.y % (uBlock.in_extents.x * uBlock.in_extents.y); - sec_last_eight.y = sec_last_eight.w / uBlock.in_extents.x; - sec_last_eight.x = sec_last_eight.w % uBlock.in_extents.x; + ivec4 nc_pos = input_pos / uBlock.in_extents.w; + // we divide by HxW (uBlock.in_extents.w), to find the position along the + // batch/channel axis of these four elements. - ivec4 thr_last_eight; - thr_last_eight.z = - texture_pos.z / (uBlock.in_extents.x * uBlock.in_extents.y); - thr_last_eight.w = - texture_pos.z % (uBlock.in_extents.x * uBlock.in_extents.y); - thr_last_eight.y = thr_last_eight.w / uBlock.in_extents.x; - thr_last_eight.x = thr_last_eight.w % uBlock.in_extents.x; + ivec4 w_pos = input_pos % uBlock.in_extents.w; + // we compute the reminder mod HxW, to find the positions in the flatten + // out HxW plane. - ivec4 four_last_eight; - four_last_eight.z = - texture_pos.w / (uBlock.in_extents.x * uBlock.in_extents.y); - four_last_eight.w = - texture_pos.w % (uBlock.in_extents.x * uBlock.in_extents.y); - four_last_eight.y = four_last_eight.w / uBlock.in_extents.x; - four_last_eight.x = four_last_eight.w % uBlock.in_extents.x; + ivec4 x_pos = w_pos % uBlock.in_extents.x; + ivec4 y_pos = w_pos / uBlock.in_extents.x; + // we divide this "flatten out position" by H, to find the positions along + // the y-axis (height) and we compute its reminder mod H, to find the + // position along the x-axis (width). - ivec3 last_eight_pos = ivec3(last_eight.x, last_eight.y, last_eight.z / 4); - ivec3 sec_last_eight_pos = - ivec3(sec_last_eight.x, sec_last_eight.y, sec_last_eight.z / 4); - ivec3 thr_last_eight_pos = - ivec3(thr_last_eight.x, thr_last_eight.y, thr_last_eight.z / 4); - ivec3 four_last_eight_pos = - ivec3(four_last_eight.x, four_last_eight.y, four_last_eight.z / 4); + ivec4 z_pos = nc_pos / 4; + ivec4 ix = nc_pos % 4; + // z_pos contains the texel positions along the z-axis, and ix the + // indices inside each texel. - int texel_1 = texelFetch(uImage, last_eight_pos, 0)[last_eight.z]; - int texel_2 = texelFetch(uImage, sec_last_eight_pos, 0)[sec_last_eight.z]; - int texel_3 = texelFetch(uImage, thr_last_eight_pos, 0)[thr_last_eight.z]; - int texel_4 = texelFetch(uImage, four_last_eight_pos, 0)[four_last_eight.z]; + // now we fetch each uint8 element from the input, and we write out a uint32 + // whose binary representation is equal to: tex3 tex2 tex1 tex0 - uint ui32 = (uint(texel_4 & 0xFF) << 24) | (uint(texel_3 & 0xFF) << 16) | - (uint(texel_2 & 0xFF) << 8) | (uint(texel_1 & 0xFF)); + int tex0 = texelFetch(uImage, ivec3(x_pos[0], y_pos[0], z_pos[0]), 0)[ix[0]]; + int tex1 = texelFetch(uImage, ivec3(x_pos[1], y_pos[1], z_pos[1]), 0)[ix[1]]; + int tex2 = texelFetch(uImage, ivec3(x_pos[2], y_pos[2], z_pos[2]), 0)[ix[2]]; + int tex3 = texelFetch(uImage, ivec3(x_pos[3], y_pos[3], z_pos[3]), 0)[ix[3]]; - uBuffer.data[texture_pos.x / 4] = ui32; - } + uint ui32 = (uint(tex3 & 0xFF) << 24) + | (uint(tex2 & 0xFF) << 16) + | (uint(tex1 & 0xFF) << 8) + | (uint(tex0 & 0xFF)); + + uBuffer.data[pos.x] = ui32; } diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized_mul4.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized_mul4.glsl new file mode 100644 index 0000000000000..210ed2b85ed66 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_quantized_mul4.glsl @@ -0,0 +1,75 @@ +#version 450 core +#define PRECISION $precision + +layout(std430) buffer; + +/* + * Input Sampler + */ +layout(set = 0, binding = 0) uniform PRECISION isampler3D uImage; + +/* + * Output Buffer + */ +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + uint data[]; +} +uBuffer; + +/* + * Params Buffer + */ +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + // xyz contain the extents of the input texture, w contains HxW to help + // calculate buffer offsets + ivec4 in_extents; +} +uBlock; + +/* + * Local Work Group in_extents + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + // each instance of the shader writes out four elements of the output + // by processing 4 consecutive texels at the same depth. + // global size = {HxW / 4, 1u, z_extent}. + // this shader requires HxW to be a multiple of 4, so that multiple + // planes can be processed in parallel + + if (4 * pos.x >= uBlock.in_extents.w || + pos.y > 0 || + pos.z >= uBlock.in_extents.z) { + return; + } + + ivec4 xy_pos = ivec4(0, 1, 2, 3) + 4 * pos.x; + // each output element is a uint32 made up four consecutive uint8 from the + // input in nchw format. xy_pos contains the positions of these four + // elements from the input in the flatten out HxW plane. + + ivec4 x_pos = xy_pos % uBlock.in_extents.x; + ivec4 y_pos = xy_pos / uBlock.in_extents.x; + // we divide this "flatten out position" by H, to find the positions along + // the y-axis (height) and we compute its reminder mod H, to find the + // position along the x-axis (width). + + const ivec4 intex0 = texelFetch(uImage, ivec3(x_pos[0], y_pos[0], pos.z), 0); + const ivec4 intex1 = texelFetch(uImage, ivec3(x_pos[1], y_pos[1], pos.z), 0); + const ivec4 intex2 = texelFetch(uImage, ivec3(x_pos[2], y_pos[2], pos.z), 0); + const ivec4 intex3 = texelFetch(uImage, ivec3(x_pos[3], y_pos[3], pos.z), 0); + + const int base_index = 4 * pos.x + 4 * uBlock.in_extents.w * pos.z; + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * uBlock.in_extents.w; + + for (int i = 0; i < 4; i += 1) { + uint ui32 = (uint(intex3[i] & 0xFF) << 24) + | (uint(intex2[i] & 0xFF) << 16) + | (uint(intex1[i] & 0xFF) << 8) + | (uint(intex0[i] & 0xFF)); + uBuffer.data[buf_indices[i] / 4] = ui32; + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/indexing.h b/aten/src/ATen/native/vulkan/glsl/indexing.h new file mode 100644 index 0000000000000..e7b6a29fc16ed --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/indexing.h @@ -0,0 +1,13 @@ +/* + * Computes a 4D tensor co-ordinate from a linearized index + */ +uvec4 idx_to_coord(const uint idx, const uvec4 strides, const uvec4 sizes) { + return ivec4(mod(idx / strides, sizes)); +} + +/* + * Computes a linearized index from a 4D tensor co-ordinate + */ +uint coord_to_idx(const uvec4 coord, const uvec4 strides) { + return int(dot(coord * strides, ivec4(1))); +} diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int32.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int32.glsl new file mode 100644 index 0000000000000..1d0eb65e2604a --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int32.glsl @@ -0,0 +1,55 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +/* + * Output Image + */ +layout(set = 0, binding = 0, rgba32i) uniform PRECISION restrict writeonly iimage3D uImage; + +/* + * Input Buffer + */ +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + int data[]; +} +uBuffer; + +/* + * Params Buffer + */ +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + // xyz contain the extents of the input texture, w contains HxW to help + // calculate buffer offsets + ivec4 out_extents; +} +uBlock; + +/* + * Local Work Group Size + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) { + return; + } + + const int base_index = + pos.x + uBlock.out_extents.x * pos.y + (4 * uBlock.out_extents.w) * pos.z; + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w; + + int val_x = uBuffer.data[buf_indices.x]; + int val_y = uBuffer.data[buf_indices.y]; + int val_z = uBuffer.data[buf_indices.z]; + int val_w = uBuffer.data[buf_indices.w]; + + imageStore(uImage, pos, ivec4(val_x, val_y, val_z, val_w)); +} diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int8.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int8.glsl new file mode 100644 index 0000000000000..4189ad219810c --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_int8.glsl @@ -0,0 +1,85 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +/* + * Output Image + */ +layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uImage; + +/* + * Input Buffer + */ +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + int data[]; +} +uBuffer; + +/* + * Params Buffer + */ +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + // xyz contain the extents of the input texture, w contains HxW to help + // calculate buffer offsets + ivec4 out_extents; +} +uBlock; + +/* + * Local Work Group Size + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Extends sign of int8 + */ +int extend_sign(int x) { + if (x >> 7 == 1) { + return x | 0xFFFFFF00; + } + return x; +} + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) { + return; + } + + const int base_index = + pos.x + uBlock.out_extents.x * pos.y + (4 * uBlock.out_extents.w) * pos.z; + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w; + + int shift = (1 << 8) - 1; + ivec4 masks; + masks.x = shift << 8 * (buf_indices.x % 4); + masks.y = shift << 8 * (buf_indices.y % 4); + masks.z = shift << 8 * (buf_indices.z % 4); + masks.w = shift << 8 * (buf_indices.w % 4); + + int buf_in_1 = uBuffer.data[buf_indices.x / 4]; + int a_v = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4); + a_v = extend_sign(a_v); + + int buf_in_2 = uBuffer.data[buf_indices.y / 4]; + int b_v = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4); + b_v = extend_sign(b_v); + + int buf_in_3 = uBuffer.data[buf_indices.z / 4]; + int g_v = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4); + g_v = extend_sign(g_v); + + int buf_in_4 = uBuffer.data[buf_indices.w / 4]; + int r_v = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4); + r_v = extend_sign(r_v); + + ivec4 texel = ivec4(a_v, b_v, g_v, r_v); + + imageStore(uImage, pos, texel); +} diff --git a/aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_uint8.glsl similarity index 98% rename from aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl rename to aten/src/ATen/native/vulkan/glsl/nchw_to_image_uint8.glsl index cca8d88fcd7d5..68adb45fa37b1 100644 --- a/aten/src/ATen/native/vulkan/glsl/nchw_to_image_quantized.glsl +++ b/aten/src/ATen/native/vulkan/glsl/nchw_to_image_uint8.glsl @@ -7,12 +7,12 @@ layout(std430) buffer; /* Qualifiers: layout - storage - precision - memory */ /* - * Input Sampler + * Output Image */ layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage3D uImage; /* - * Output Buffer + * Input Buffer */ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { uint data[]; diff --git a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint32.glsl b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint32.glsl new file mode 100644 index 0000000000000..75fca31ee23b2 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint32.glsl @@ -0,0 +1,31 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba32i) uniform PRECISION restrict writeonly iimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; //input +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + vec2 scale; + ivec2 zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x; + + ivec4 ret = ivec4(q_res); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint8.glsl b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint8.glsl new file mode 100644 index 0000000000000..2ba863d321312 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_qint8.glsl @@ -0,0 +1,31 @@ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +layout(std430) buffer; + +/* Qualifiers: layout - storage - precision - memory */ + +layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage3D uOutput; +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; //input +layout(set = 0, binding = 2) uniform PRECISION restrict Block { + ivec4 size; + vec2 scale; + ivec2 zero_point; +} uBlock; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (all(lessThan(pos, uBlock.size.xyz))) { + vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x; + + ivec4 ret = ivec4(q_res); + + imageStore( + uOutput, + pos, + ret); + } +} diff --git a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_quint8.glsl similarity index 81% rename from aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl rename to aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_quint8.glsl index 910603aa29f26..f67954ad48c14 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor_quint8.glsl @@ -19,11 +19,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); if (all(lessThan(pos, uBlock.size.xyz))) { - vec4 ret = texelFetch(uInput, pos, 0) / uBlock.scale.x + uBlock.zero_point.x; - uvec4 texel = uvec4(int(ret.x), int(ret.y), int(ret.z), int(ret.w)); + vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x; + + uvec4 ret = uvec4(q_res); + imageStore( uOutput, pos, - texel); + ret); } } diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl index 8f6e51397d1c1..a526dc2121bf7 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 + deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl index a53078b8b269f..63bf055761cc9 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -2,6 +2,12 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_3D + * BIAS_STORAGE = TEXTURE_3D + */ + layout(std430) buffer; /* Qualifiers: layout - storage - precision - memory */ @@ -58,7 +64,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } /* diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl index d842ab97bcc8d..0d823620a517f 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl @@ -2,6 +2,13 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_3D + * BIAS_STORAGE = TEXTURE_3D + * Note that for DW kernel IC = 1 so the weight layout is really OC4, H, W, 4oc + */ + layout(std430) buffer; /* Qualifiers: layout - storage - precision - memory */ @@ -58,7 +65,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } void main() { diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl index 21e6d1a607f19..2ef6d3d60f324 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl @@ -2,6 +2,12 @@ #define PRECISION $precision #define FORMAT $format +/* + * TILE_SIZE = (2, 2, 1) + * WEIGHT_STORAGE = TEXTURE_3D + * BIAS_STORAGE = TEXTURE_3D + */ + /* * Output Image */ @@ -54,7 +60,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } /* diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl index aa961eb349934..1998c5abbca38 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 / deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl index 459f56915d774..c1ce18dbb38c1 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 * deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl index 6bd00f33a89c0..767181f080fdd 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 - deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl index 28c167515405e..46abbb1a8d768 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl @@ -25,8 +25,7 @@ void main() { ivec2(0), uBlock.isize); - vec4 texel = texelFetch(uInput, ivec3(ipos, pos.z), 0); - uvec4 ret = uvec4(int(texel.r), int(texel.g), int(texel.b), int(texel.a)); + uvec4 ret = texelFetch(uInput, ivec3(ipos, pos.z), 0); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw.glslt b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw.glslt new file mode 100644 index 0000000000000..3afbefa2be492 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw.glslt @@ -0,0 +1,88 @@ +/* + * KERNEL_SIZE = ($KERNEL_SIZE_X, $KERNEL_SIZE_Y) + * TILE_SIZE = (1, 1, 1) + * WEIGHT_STORAGE = TEXTURE_2D + * BIAS_STORAGE = TEXTURE_2D + * Note that for DW kernel IC = 1 so the weight layout is really OC4, H, W, 4oc + */ + +layout(std430) buffer; + +/* + * Output Image + */ +layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOut; + +/* + * Input Textures + */ +layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput; +layout(set = 0, binding = 2) uniform PRECISION sampler2D uKernel; +layout(set = 0, binding = 3) uniform PRECISION sampler2D uBias; + +/* + * Params Buffer + */ +layout(set = 0, binding = 4) uniform PRECISION restrict Block { + // extents of the output texture + ivec4 out_extents; + // extents of the input texture + ivec4 in_extents; + // size of the overlay region of the kernel + ivec4 overlay_region; + // width and height of the kernel + ivec2 kernel_size; + // convolution parameters + ivec2 stride; + ivec2 padding; + ivec2 dilate; + vec2 clamp_thresh; +} +uBlock; + +/* + * Local Work Group + */ +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes depthwise convolution. Each shader invocation calculates the output + * of a single output location. + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Return if this global position is outside output texture bounds + if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) { + return; + } + + // Compute the index of the top-left element of the overlay region. Note that + // negative indices can be produced indicating that the top-left element is in + // a region added by padding. + const ivec2 ipos = pos.xy * uBlock.stride - uBlock.padding; + + // Compute the start and end of the input indices to load. Padding is assumed + // to be constant 0 padding, so any reads from the padding region is skipped. + const ivec2 start = ipos; + const ivec2 end = ipos + uBlock.overlay_region.xy; + + vec4 sum = texelFetch(uBias, ivec2(pos.z, 0), 0); + const int dil_y = uBlock.dilate.y; + const int dil_x = uBlock.dilate.x; + int k_ind = 0; + for (int y = start.y, i = 0; i < $KERNEL_SIZE_Y; y += dil_y, i++) { + for (int x = start.x, j = 0; j < $KERNEL_SIZE_X; x += dil_x, j++) { + // The weight kernel was rearranged so that every NxN filter was flattened + // so that it fits on one row. Each filter was then stacked on top of each + // other vertically. + const vec4 kernel_vals = texelFetch(uKernel, ivec2(k_ind, pos.z), 0); + const vec4 i_tex = texelFetch(uInput, ivec3(x, y, pos.z), 0); + sum = fma(i_tex, kernel_vals, sum); + k_ind++; + } + } + + imageStore( + uOut, pos, clamp(sum, uBlock.clamp_thresh.x, uBlock.clamp_thresh.y)); +} diff --git a/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw_params.yaml b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw_params.yaml new file mode 100644 index 0000000000000..baddf153cb7bc --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_dw_params.yaml @@ -0,0 +1,7 @@ +conv2d_dw: + parameter_names_with_default_values: + KERNEL_SIZE_X: 3 + KERNEL_SIZE_Y: 3 + parameter_values: + - KERNEL_SIZE_X: 5 + KERNEL_SIZE_Y: 5 diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt similarity index 78% rename from aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl rename to aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt index b497f41587ff5..8f3c5a38db870 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw_2x2.glsl +++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw.glslt @@ -1,6 +1,9 @@ -#version 450 core -#define PRECISION $precision -#define FORMAT $format +/* + * TILE_SIZE = ($TILE_SIZE_X, $TILE_SIZE_Y, 1) + * WEIGHT_STORAGE = TEXTURE_2D + * WEIGHT_STORAGE_LAYOUT = OC4,IC4,4ic,4oc + * BIAS_STORAGE = TEXTURE_2D + */ layout(std430) buffer; @@ -49,17 +52,19 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 gpos = ivec3(gl_GlobalInvocationID); - // Determine the output positions that will be written to. + // Output position for TILE_SIZE_X, TILE_SIZE_Y = 2, 2 // +--------+--------+ // | pos[0] | pos[1] | // +--------+--------+ // | pos[2] | pos[3] | // +--------+--------+ - ivec3 pos[4]; - pos[0] = ivec3(gpos.x * 2, gpos.y * 2, gpos.z); - pos[1] = ivec3(gpos.x * 2 + 1, gpos.y * 2, gpos.z); - pos[2] = ivec3(gpos.x * 2, gpos.y * 2 + 1, gpos.z); - pos[3] = ivec3(gpos.x * 2 + 1, gpos.y * 2 + 1, gpos.z); + ivec3 pos[$TILE_SIZE_X * $TILE_SIZE_Y]; + for (int y = 0, i = 0; y < $TILE_SIZE_Y; ++y) { + for (int x = 0; x < $TILE_SIZE_X; ++x) { + pos[i] = ivec3(gpos.x * $TILE_SIZE_X + x, gpos.y * $TILE_SIZE_Y + y, gpos.z); + i++; + } + } // If the top left position is out of bounds, then this invocation will have // no work to do. @@ -70,14 +75,14 @@ void main() { // Compute the index of the input texture that needs to be loaded for each // output position. Note that negative indices can be produced indicating that // the top-left element is in a region added by padding. - ivec2 ipos[4]; - for (int i = 0; i < 4; ++i) { + ivec2 ipos[$TILE_SIZE_X * $TILE_SIZE_Y]; + for (int i = 0; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) { ipos[i] = pos[i].xy * uBlock.stride - uBlock.padding; } - vec4 sum[4]; + vec4 sum[$TILE_SIZE_X * $TILE_SIZE_Y]; sum[0] = texelFetch(uBias, ivec2(gpos.z, 0), 0); - for (int i = 1; i < 4; ++i) { + for (int i = 1; i < $TILE_SIZE_X * $TILE_SIZE_Y; ++i) { sum[i] = sum[0]; } @@ -87,13 +92,18 @@ void main() { // During prepacking, the weight tensor has been permuted so that the // channel (IC) dim is along the x axis, and the batch (OC) dim is along // the z axis. + vec4 in_tex[$TILE_SIZE_X * $TILE_SIZE_Y]; const vec4 ktex_0 = texelFetch(uKernel, ivec2(z + 0, gpos.z), 0); const vec4 ktex_1 = texelFetch(uKernel, ivec2(z + 1, gpos.z), 0); const vec4 ktex_2 = texelFetch(uKernel, ivec2(z + 2, gpos.z), 0); const vec4 ktex_3 = texelFetch(uKernel, ivec2(z + 3, gpos.z), 0); - for (int i = 0; i < 4; ++i) { - const vec4 in_tex = texelFetch(uInput, ivec3(ipos[i], z4), 0); + for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) { + in_tex[i] = texelFetch(uInput, ivec3(ipos[i], z4), 0); + } + + for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) { + // For 2x2 tile size algorithm works as follows. // To explain the calculations below, the contents one in_tex and the // group of 4 texels loaded from uKernel are shown: // @@ -126,15 +136,14 @@ void main() { // // which is what is expressed in the following calculations. This is done // for each output position. - - sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]); - sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]); - sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]); - sum[i] = fma(in_tex.wwww, ktex_3, sum[i]); + sum[i] = fma(in_tex[i].xxxx, ktex_0, sum[i]); + sum[i] = fma(in_tex[i].yyyy, ktex_1, sum[i]); + sum[i] = fma(in_tex[i].zzzz, ktex_2, sum[i]); + sum[i] = fma(in_tex[i].wwww, ktex_3, sum[i]); } } - for (int i = 0; i < 4; ++i) { + for (int i = 0; i < $TILE_SIZE_Y * $TILE_SIZE_X; ++i) { if (all(lessThan(pos[i], uBlock.out_extents.xyz))) { imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml new file mode 100644 index 0000000000000..fef8f20f4e733 --- /dev/null +++ b/aten/src/ATen/native/vulkan/glsl/templates/conv2d_pw_params.yaml @@ -0,0 +1,7 @@ +conv2d_pw: + parameter_names_with_default_values: + TILE_SIZE_X: 2 + TILE_SIZE_Y: 2 + parameter_values: + - TILE_SIZE_X: 1 + TILE_SIZE_Y: 1 diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp index 84828aa60468c..d1fecca2abeb0 100644 --- a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp @@ -1,109 +1,44 @@ -#include +#include +#include #include namespace at { namespace native { namespace vulkan { namespace ops { -namespace { - -using namespace api::utils; -Tensor batch_norm( - const at::Tensor& input_arg, - const c10::optional& weight_opt /* optional */, - const c10::optional& bias_opt /* optional */, - const c10::optional& running_mean_opt /* optional */, - const c10::optional& running_var_opt /* optional */, - bool training, - double /* momentum, not used in eval mode */, - double eps, - bool /* cudnn_enable, deprecated */) { - TORCH_CHECK(!training, "Vulkan batchnorm only supports evaluation mode."); - TORCH_CHECK( - weight_opt && weight_opt->defined() && bias_opt && bias_opt->defined(), - "Vulkan batchnorm expects weight and bias arguments to be defined"); - TORCH_CHECK( - running_mean_opt && running_mean_opt->defined(), - "running_mean must be defined in evaluation mode."); - TORCH_CHECK( - running_var_opt && running_var_opt->defined(), - "running_var must be defined in evaluation mode."); - TORCH_CHECK(input_arg.dim() == 4, "Vulkan batchnorm expects 4-dim input!"); - TORCH_CHECK( - get_dim(input_arg) % 4 == 0, - "Vulkan batchnorm expects channel dim to be multiple of 4!"); - - const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan(); - const vTensor& v_input = convert(input); - const IntArrayRef v_input_sizes = v_input.sizes(); - - auto num_features = v_input.sizes()[1]; - auto channels_ext = num_features / 4; - - const Tensor weight_opt_3d = weight_opt->reshape({num_features, 1, 1}); - const Tensor weight = - weight_opt_3d.is_vulkan() ? weight_opt_3d : weight_opt_3d.vulkan(); - const vTensor& v_weight = convert(weight); - TORCH_CHECK( - weight.numel() == num_features, - "weight tensor should contain ", - num_features, - " elements!"); - - const Tensor bias_opt_3d = bias_opt->reshape({num_features, 1, 1}); - const Tensor bias = - bias_opt_3d.is_vulkan() ? bias_opt_3d : bias_opt_3d.vulkan(); - const vTensor& v_bias = convert(bias); - TORCH_CHECK( - bias.numel() == num_features, - "bias tensor should contain ", - num_features, - " elements!"); +namespace batchnorm { + +struct Params final { + api::utils::ivec3 out_extents; + int32_t c4; + float eps; +}; + +void record_op( + api::Context* const context, + vTensor& v_output, + const vTensor& v_input, + const vTensor& v_weight, + const vTensor& v_bias, + const vTensor& v_running_mean, + const vTensor& v_running_var, + const float eps) { + api::PipelineBarrier pipeline_barrier{}; - const Tensor running_mean_opt_3d = - running_mean_opt->reshape({num_features, 1, 1}); - const Tensor running_mean = running_mean_opt_3d.is_vulkan() - ? running_mean_opt_3d - : running_mean_opt_3d.vulkan(); - const vTensor& v_running_mean = convert(running_mean); - TORCH_CHECK( - running_mean.numel() == num_features, - "running mean tensor should contain ", - num_features, - " elements!"); - - const Tensor running_var_opt_3d = - running_var_opt->reshape({num_features, 1, 1}); - const Tensor running_var = running_var_opt_3d.is_vulkan() - ? running_var_opt_3d - : running_var_opt_3d.vulkan(); - const vTensor& v_running_var = convert(running_var); - TORCH_CHECK( - running_var.numel() == num_features, - "running var tensor should contain ", - num_features, - " elements!"); + api::utils::uvec3 global_size = v_output.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - api::Context* const context = api::context(); + uint32_t num_features = get_dim(v_input.sizes()); + uint32_t channels_ext = api::utils::div_up(num_features, 4u); - vTensor v_output{ - context, - v_input_sizes, - v_input.options(), + Params block{ + api::utils::make_ivec3(v_output.extents()), + api::utils::safe_downcast(channels_ext), + eps, }; - const struct Block final { - uvec3 iextents; - int32_t channels_ext; - float epsilon; - } block{ - v_output.extents(), - safe_downcast(channels_ext), - safe_downcast(eps)}; - api::UniformParamsBuffer params(context, block); - api::PipelineBarrier pipeline_barrier{}; context->submit_compute_job( // shader descriptor @@ -111,9 +46,9 @@ Tensor batch_norm( // pipeline barrier pipeline_barrier, // global work group size - v_output.extents(), + global_size, // local work group size - adaptive_work_group_size(v_output.extents()), + local_size, // fence handle VK_NULL_HANDLE, // shader arguments @@ -128,8 +63,34 @@ Tensor batch_norm( v_running_var.image(pipeline_barrier, api::PipelineStage::COMPUTE), // params buffer params.buffer()); +} - return convert(v_output); +} // namespace batchnorm + +namespace { + +using namespace api::utils; + +Tensor batch_norm( + const at::Tensor& input_arg, + const c10::optional& weight_opt /* optional */, + const c10::optional& bias_opt /* optional */, + const c10::optional& running_mean_opt /* optional */, + const c10::optional& running_var_opt /* optional */, + bool training, + double /* momentum, not used in eval mode */, + double eps, + bool /* cudnn_enable, deprecated */) { + TORCH_CHECK(!training, "Only evaluation mode is supported!"); + TORCH_CHECK(input_arg.dim() == 4, "Input must have dim == 4!"); + TORCH_CHECK( + get_dim(input_arg) % 4 == 0, + "Input must have channels divisible by 4!"); + + return run_batchnorm_context( + input_arg, + c10::make_intrusive(BatchNormPackedContext( + weight_opt, bias_opt, running_mean_opt, running_var_opt, eps))); } #ifdef USE_VULKAN_API @@ -141,6 +102,143 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { #endif /* USE_VULKAN_API */ } // namespace + +BatchNormPackedContext::BatchNormPackedContext( + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + double eps) + : unpacked_{c10::AnyType::get()} { + packed_.reserve(ListArgs::kNumArgs); + + // Each optional tensor arg, if provided should be a 1 dimensional tensor. To + // achieve more efficient packing as a texture, they are first reshaped to {N, + // 1, 1}. Eventually this rearrangement should happen automatically in vTensor + // itself. + + // Weight + TORCH_CHECK(weight_opt, "Weight must be provided!"); + TORCH_CHECK(weight_opt->dim() == 1, "Weight must have ndim == 1!"); + + const int64_t num_features = + api::utils::safe_downcast(weight_opt->numel()); + const Tensor weight_3d = weight_opt->reshape({num_features, 1, 1}); + packed_.emplace_back(weight_3d.vulkan()); + + // Bias + TORCH_CHECK(bias_opt, "Bias must be provided!"); + TORCH_CHECK(bias_opt->dim() == 1, "Bias must have ndim == 1!"); + TORCH_CHECK( + bias_opt->numel() == num_features, + "Bias must have the same numel as weight!"); + + const Tensor bias_3d = bias_opt->reshape({num_features, 1, 1}); + packed_.emplace_back(bias_3d.vulkan()); + + // Running Mean + TORCH_CHECK(running_mean_opt, "Running mean must be provided!"); + TORCH_CHECK(running_mean_opt->dim() == 1, "Running mean must have ndim == 1"); + TORCH_CHECK( + running_mean_opt->numel() == num_features, + "Running mean must have the same numel as weight!"); + + const Tensor running_mean_3d = + running_mean_opt->reshape({num_features, 1, 1}); + packed_.emplace_back(running_mean_3d.vulkan()); + + // Running var + TORCH_CHECK(running_var_opt, "Running var must be provided!"); + TORCH_CHECK(running_var_opt->dim() == 1, "Running var must have ndim == 1"); + TORCH_CHECK( + running_var_opt->numel() == num_features, + "Running var must have the same numel as weight!"); + + const Tensor running_var_3d = running_var_opt->reshape({num_features, 1, 1}); + packed_.emplace_back(running_var_3d.vulkan()); + + // Epsilon + packed_.emplace_back(eps); + + if (!at::globalContext().releaseWeightsWhenPrepacking()) { + unpacked_.reserve(ListArgs::kNumArgs); + unpacked_.emplace_back(weight_opt); + unpacked_.emplace_back(bias_opt); + unpacked_.emplace_back(running_mean_opt); + unpacked_.emplace_back(running_var_opt); + unpacked_.emplace_back(eps); + } +} + +BatchNormPackedContext BatchNormPackedContext::pack( + c10::impl::GenericList unpacked) { + return BatchNormPackedContext( + get_optional_tensor(unpacked, ListArgs::kWeight), + get_optional_tensor(unpacked, ListArgs::kBias), + get_optional_tensor(unpacked, ListArgs::kRunningMean), + get_optional_tensor(unpacked, ListArgs::kRunningVar), + unpacked.get(ListArgs::kEps).toDouble()); +} + +c10::intrusive_ptr create_batchnorm_context( + c10::optional&& weight_opt, + c10::optional&& bias_opt, + c10::optional&& running_mean_opt, + c10::optional&& running_var_opt, + bool training, + double /* momentum */, + double eps, + bool /* cudnn_enable, deprecated */) { + return c10::make_intrusive(BatchNormPackedContext( + weight_opt, bias_opt, running_mean_opt, running_var_opt, eps)); +} + +Tensor run_batchnorm_context( + const Tensor& input_arg, + const c10::intrusive_ptr& batchnorm_context) { + api::Context* const context = api::context(); + + const vTensor& v_input = convert(input_arg); + + const vTensor& v_weight = convert( + batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kWeight) + .toTensor()); + + const vTensor& v_bias = convert( + batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kBias) + .toTensor()); + + const vTensor& v_running_mean = convert( + batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningMean) + .toTensor()); + + const vTensor& v_running_var = convert( + batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kRunningVar) + .toTensor()); + + const float eps = api::utils::safe_downcast( + batchnorm_context->get_val(BatchNormPackedContext::ListArgs::kEps) + .toDouble()); + + vTensor v_output{ + context, + v_input.sizes(), + v_input.options(), + }; + + batchnorm::record_op( + context, + v_output, + v_input, + v_weight, + v_bias, + v_running_mean, + v_running_var, + eps); + + return convert(v_output); +} + } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.h b/aten/src/ATen/native/vulkan/ops/Batchnorm.h new file mode 100644 index 0000000000000..6afaeb6f243b3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.h @@ -0,0 +1,68 @@ +#pragma once + +#ifdef USE_VULKAN_API + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace ops { + +class BatchNormPackedContext final : virtual public VulkanPackedContext, + public torch::jit::CustomClassHolder { + private: + c10::impl::GenericList unpacked_; + + public: + BatchNormPackedContext( + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + double eps); + + /* + * Assigns a name to each index in the packed/unpacked list. + */ + struct ListArgs final { + static constexpr uint32_t kWeight = 0u; + static constexpr uint32_t kBias = 1u; + static constexpr uint32_t kRunningMean = 2u; + static constexpr uint32_t kRunningVar = 3u; + static constexpr uint32_t kEps = 4u; + + static constexpr uint32_t kNumArgs = 5u; + }; + + static BatchNormPackedContext pack(c10::impl::GenericList); + + const c10::impl::GenericList unpack() const override { + TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!"); + + return unpacked_; + } +}; + +c10::intrusive_ptr create_batchnorm_context( + c10::optional&& weight_opt, + c10::optional&& bias_opt, + c10::optional&& running_mean_opt, + c10::optional&& running_var_opt, + bool training, + double /* momentum */, + double eps, + bool /* cudnn_enable, deprecated */); + +Tensor run_batchnorm_context( + const Tensor& input_arg, + const c10::intrusive_ptr& context); + +} // namespace ops +} // namespace vulkan +} // namespace native +} // namespace at + +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Clone.cpp b/aten/src/ATen/native/vulkan/ops/Clone.cpp index de353a10cb931..2601d785ddb52 100644 --- a/aten/src/ATen/native/vulkan/ops/Clone.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clone.cpp @@ -21,7 +21,7 @@ Tensor clone( TORCH_CHECK( (c10::MemoryFormat::Preserve == memory_format) || (c10::MemoryFormat::Contiguous == memory_format), - "Vulkan supports Preserve and Contiguous memory foramts"); + "Vulkan supports Preserve and Contiguous memory formats"); Tensor self; if (memory_format == MemoryFormat::Preserve) { diff --git a/aten/src/ATen/native/vulkan/ops/Common.cpp b/aten/src/ATen/native/vulkan/ops/Common.cpp index 5a3daeb074288..4c645ba3b1423 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.cpp +++ b/aten/src/ATen/native/vulkan/ops/Common.cpp @@ -5,6 +5,15 @@ namespace native { namespace vulkan { namespace ops { +api::utils::uvec4 make_nchw_uvec4(const IntArrayRef arr) { + uint32_t w = get_dim(arr); + uint32_t h = get_dim(arr); + uint32_t c = get_dim(arr); + uint32_t n = get_dim(arr); + + return {w, h, c, n}; +} + api::utils::uvec3 adaptive_work_group_size( const api::utils::uvec3& global_work_group) { api::utils::uvec3 local_group_size = {4, 4, 4}; diff --git a/aten/src/ATen/native/vulkan/ops/Common.h b/aten/src/ATen/native/vulkan/ops/Common.h index 9d4e50c800955..4248417b3c991 100644 --- a/aten/src/ATen/native/vulkan/ops/Common.h +++ b/aten/src/ATen/native/vulkan/ops/Common.h @@ -106,6 +106,12 @@ uint32_t get_dim(const vTensor& v_in) { return get_dim(v_in.sizes()); } +/* + * Given an IntArrayRef of up to 4 elements, constructs a uvec4 containing those + * elements in reverse order. + */ +api::utils::uvec4 make_nchw_uvec4(const IntArrayRef arr); + inline c10::optional get_optional_tensor( const c10::impl::GenericList& gen_list, const uint32_t idx) { diff --git a/aten/src/ATen/native/vulkan/ops/Concat.cpp b/aten/src/ATen/native/vulkan/ops/Concat.cpp index ac15b3924b080..412bda4fcde06 100644 --- a/aten/src/ATen/native/vulkan/ops/Concat.cpp +++ b/aten/src/ATen/native/vulkan/ops/Concat.cpp @@ -37,15 +37,15 @@ Tensor cat_feature( const struct Block final { uvec3 size; // output texture size - uint32_t fill_0; // dummy + uint32_t fill0; // dummy uvec3 isize; // input texture size - uint32_t fill_1; // dummy - uint32_t batch_size; // input tensor's batch size - uint32_t ch_size; // input tensor's channel size + uint32_t fill1; // dummy + uint32_t batchSize; // input tensor's batch size + uint32_t chSize; // input tensor's channel size uint32_t - ch_interval; // channel interval (total # of channels for all tensors) + chInterval; // channel interval (total # of channels for all tensors) uint32_t - ch_size_allprior; // # of channels for tensor 0 to i-1 at ith tensor + chSizeAllprior; // # of channels for tensor 0 to i-1 at ith tensor } block{ v_output.extents(), 0u, @@ -181,10 +181,12 @@ Tensor cat_height( return convert(v_output); } -Tensor cat(const at::ITensorListRef& tensors, const int64_t dim) { +Tensor cat(const at::ITensorListRef& tensors, const int64_t in_dim) { TORCH_CHECK(tensors.size() > 0, "Vulkan cat expects at least one tensor"); + const int64_t dim = normalize_dim(in_dim, 4); auto materialized = tensors.materialize(); + TORCH_INTERNAL_ASSERT(materialized.size() > 0, "Accessing empty array"); const at::Tensor& tensor = materialized[0]; int64_t cat_dim_size = 0; bool is_mult4ch = true; @@ -209,6 +211,7 @@ Tensor cat(const at::ITensorListRef& tensors, const int64_t dim) { } auto result_size = tensor.sizes().vec(); + TORCH_INTERNAL_ASSERT(result_size.size() > 0, "Accessing empty array"); result_size[dim] = cat_dim_size; vTensor v_output{api::context(), result_size, tensor.options()}; diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index d1fca607cc768..9ab19a6e9b0f3 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -125,7 +125,7 @@ at::Tensor rearrange_weights_dw(const Tensor& weight_in) { // reshape to stack the resulting batches vertically weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * C, H * W}); - return weight; + return weight.contiguous(); } /* @@ -228,7 +228,7 @@ at::Tensor rearrange_weights_2d(const Tensor& weight_in, bool tconv) { // Collapse the outermost dim so that each group of 4 is stacked vertically weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * H, C_aligned * W}); - return weight; + return weight.contiguous(); } /* @@ -272,7 +272,7 @@ at::Tensor rearrange_bias( bias = bias.reshape({L4, 4}).permute({1, 0}); bias = bias.reshape({4, 1, L4}); - return bias; + return bias.contiguous(); } // @@ -287,7 +287,7 @@ static api::ShaderSource get_shader( const Conv2dMethod method, const bool transposed, const bool quantized) { - api::ShaderSource shader; + api::ShaderInfo shader; if (quantized) { if (transposed) { @@ -296,39 +296,46 @@ static api::ShaderSource get_shader( switch (method) { case Conv2dSlidingWindow: - shader = VK_KERNEL(quantized_conv2d); - return shader; + shader = VK_SHADER(quantized_conv2d); + break; case Conv2dDepthwise: - shader = VK_KERNEL(quantized_conv2d_dw); - return shader; + shader = VK_SHADER(quantized_conv2d_dw); + break; case Conv2dPointwise: - shader = VK_KERNEL(quantized_conv2d_pw_2x2); - // Set explicitly for now. In the future, this will be set automatically - // by shader codegen. - shader.out_tile_size = {2u, 2u, 1u}; - return shader; + shader = VK_SHADER(quantized_conv2d_pw_2x2); + break; + // todo fail for quantized transposed conv } + return shader.shader_src; } if (transposed) { - shader = VK_KERNEL(conv_transpose2d); - return shader; + shader = VK_SHADER(conv_transpose2d); + return shader.shader_src; } switch (method) { case Conv2dSlidingWindow: - shader = VK_KERNEL(conv2d); - return shader; + shader = VK_SHADER(conv2d); + break; case Conv2dDepthwise: - shader = VK_KERNEL(conv2d_dw); - return shader; + shader = VK_SHADER(conv2d_dw); + if (kernel_size.size() == 4 && kernel_size[2] == 3 && + kernel_size[3] == 3) { + // 1x1 refers to the output tile size + shader = VK_SHADER(conv2d_dw_3x3); + } + if (kernel_size.size() == 4 && kernel_size[2] == 5 && + kernel_size[3] == 5) { + // 1x1 refers to the output tile size + shader = VK_SHADER(conv2d_dw_5x5); + } + break; case Conv2dPointwise: - shader = VK_KERNEL(conv2d_pw_2x2); - // Set explicitly for now. In the future, this will be set automatically - // by shader codegen. - shader.out_tile_size = {2u, 2u, 1u}; - return shader; + shader = VK_SHADER(conv2d_pw_2x2); + break; } + return shader.shader_src; } // @@ -520,8 +527,8 @@ vTensor pack_weights( vTensor v_weight{ api::context(), weight_rearranged.sizes(), - quantized ? StorageType::TEXTURE_3D : StorageType::TEXTURE_2D, weight_arg.options(), + quantized ? api::StorageType::TEXTURE_3D : api::StorageType::TEXTURE_2D, }; if (quantized) { @@ -545,14 +552,14 @@ vTensor pack_biases( vTensor v_bias{ api::context(), bias_rearranged.sizes(), - quantized ? StorageType::TEXTURE_3D : StorageType::TEXTURE_2D, - weight.options(), + bias_rearranged.options(), + quantized ? api::StorageType::TEXTURE_3D : api::StorageType::TEXTURE_2D, }; if (quantized) { v_bias.set_is_quantized(); - v_bias.set_scale(weight.q_scale()); - v_bias.set_zero_point(weight.q_zero_point()); + v_bias.set_scale(bias_rearranged.q_scale()); + v_bias.set_zero_point(bias_rearranged.q_zero_point()); } pack_cpu_to_vulkan(bias_rearranged, v_bias); @@ -985,7 +992,7 @@ c10::intrusive_ptr create_qconv2d_context( dilation, /* transposed = */ false, /* quantized = */ true, - /* output_padding_arg = */ {}, + /* output_padding_arg = */ {0}, groups, output_min, output_max)); diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index dbac25e0c7ee3..4e414e4366bcb 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -17,10 +18,15 @@ void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping) { memcpy_to_mapping_impl(src, dst_mapping); } else if (src.dtype() == c10::kQUInt8) { memcpy_to_mapping_impl(src, dst_mapping); + } else if (src.dtype() == c10::kQInt8) { + memcpy_to_mapping_impl(src, dst_mapping); + } else if (src.dtype() == c10::kQInt32) { + memcpy_to_mapping_impl(src, dst_mapping); } else { TORCH_CHECK( false, - "Invalid Data Type: expected c10::QUint8, at::kHalf or at::Float but got ", + "Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,", + " at::kHalf or at::Float but got ", src.dtype()); } } @@ -32,10 +38,15 @@ void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst) { memcpy_from_mapping_impl(src_mapping, dst); } else if (dst.dtype() == c10::kQUInt8) { memcpy_from_mapping_impl(src_mapping, dst); + } else if (dst.dtype() == c10::kQInt8) { + memcpy_from_mapping_impl(src_mapping, dst); + } else if (dst.dtype() == c10::kQInt32) { + memcpy_from_mapping_impl(src_mapping, dst); } else { TORCH_CHECK( false, - "Invalid Data Type: expected c10::QUint8, at::kHalf or Float but got ", + "Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,", + " at::kHalf or at::Float but got ", dst.dtype()); } } @@ -52,7 +63,7 @@ void transfer_cpu_to_vulkan(const Tensor& src, vTensor& v_dst) { // a 16 bit format will be used for at::kFloat. Tensor src_nc4hw = utils::nchw_to_nc4hw(src).to(v_dst.texture_dtype()); - api::StorageBuffer staging(context, v_dst.texture_dtype(), v_dst.numcells()); + api::StorageBuffer staging(context, v_dst.texture_dtype(), v_dst.gpu_numel()); // Copy data into the staging buffer { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); @@ -71,7 +82,7 @@ void transfer_vulkan_to_cpu(vTensor& v_src, Tensor& dst) { // Temporary tensor to receive copied NC4HW data at::Tensor dst_tmp = utils::create_staging_tensor(v_src); - api::StorageBuffer staging(context, v_src.texture_dtype(), v_src.numcells()); + api::StorageBuffer staging(context, v_src.texture_dtype(), v_src.gpu_numel()); api::VulkanFence fence = context->fences().get_fence(); @@ -135,13 +146,16 @@ void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) { void pack_cpu_to_vulkan(const Tensor& src, vTensor& dst) { api::Context* const context = api::context(); + // Ensure that src is contiguous in its memory format + Tensor src_contig = src.contiguous(src.suggest_memory_format()); + // Note that the float data type has been enforced for the storage buffer // below. The reason for this is that the nchw_to_image and image_to_nchw // shaders which perform the transfer to/from an image texture expect a buffer // of floats as input. GLSL/Vulkan does not natively support 16 bit arithmetic // types, so for now storage buffers created for compute shaders must define // floats as their base data type. - api::StorageBuffer staging(context, at::kFloat, dst.numcells()); + api::StorageBuffer staging(context, at::kFloat, dst.gpu_numel()); { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); @@ -150,20 +164,23 @@ void pack_cpu_to_vulkan(const Tensor& src, vTensor& dst) { // buffer as input (note that at::kFloat is used to create the StorageBuffer // above). if (src.dtype() == at::kHalf) { - memcpy_to_mapping(src.to(at::kFloat), mapping); + memcpy_to_mapping(src_contig.to(at::kFloat), mapping); } else { - memcpy_to_mapping(src, mapping); + memcpy_to_mapping(src_contig, mapping); } } utils::pack_staging_to_vtensor(staging.buffer(), dst); } void pack_vulkan_to_cpu(vTensor& src, Tensor& dst) { + TORCH_CHECK( + !src.is_quantized(), + "Copy of vulkan quantized tensors to cpu is currently disabled!"); api::Context* const context = api::context(); // Refer to the comment in pack_cpu_to_vulkan for why at::kFloat is specified // for the storage buffer below. - api::StorageBuffer staging(context, at::kFloat, src.numcells()); + api::StorageBuffer staging(context, at::kFloat, src.gpu_numel()); api::VulkanFence fence = context->fences().get_fence(); @@ -245,6 +262,28 @@ Tensor& copy_(Tensor& dst, const Tensor& src) { return dst; } +ops::vTensor to_vulkan(at::Tensor& src, const api::StorageType storage_type) { + TORCH_CHECK( + src.device().type() == at::kCPU, + "Vulkan to_vulkan(): input tensor must be a CPU tensor!") + + ops::vTensor v_ret{ + api::context(), + src.sizes(), + src.options().memory_format(src.suggest_memory_format()), + storage_type}; + + ops::pack_cpu_to_vulkan(src, v_ret); + + return v_ret; +} + +at::Tensor from_vulkan(ops::vTensor& v_src) { + at::Tensor ret = at::empty(v_src.sizes(), v_src.options().device(at::kCPU)); + ops::pack_vulkan_to_cpu(v_src, ret); + return ret; +} + } // namespace ops } // namespace vulkan } // namespace native diff --git a/aten/src/ATen/native/vulkan/ops/Copy.h b/aten/src/ATen/native/vulkan/ops/Copy.h index bf72a96b219fb..a91d500a1a343 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.h +++ b/aten/src/ATen/native/vulkan/ops/Copy.h @@ -19,6 +19,12 @@ void pack_vulkan_to_cpu(vTensor& src, Tensor& dst); Tensor& copy_(Tensor& dst, const Tensor& src); +ops::vTensor to_vulkan( + at::Tensor& src, + const api::StorageType storage_type = api::StorageType::TEXTURE_3D); + +at::Tensor from_vulkan(ops::vTensor& v_src); + // // Utility functions for memcpy // @@ -28,7 +34,7 @@ void memcpy_to_mapping_impl(const Tensor& src, api::MemoryMap& dst_mapping) { T* data_ptr = dst_mapping.template data(); memcpy( data_ptr, - src.contiguous().data_ptr(), + src.data_ptr(), std::min(src.nbytes(), dst_mapping.nbytes())); } diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index b003a322804ad..c8225f08354ef 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -43,7 +43,7 @@ vTensor pack_weights(const Tensor& weight_arg) { weight.options(), }; - api::StorageBuffer staging(context, at::kFloat, v_weight.numcells()); + api::StorageBuffer staging(context, at::kFloat, v_weight.gpu_numel()); { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); @@ -106,7 +106,7 @@ vTensor pack_biases( bias_arg->options(), }; - api::StorageBuffer staging(context, at::kFloat, v_bias.numcells()); + api::StorageBuffer staging(context, at::kFloat, v_bias.gpu_numel()); { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); @@ -135,7 +135,7 @@ vTensor pack_biases( weight_arg.options(), }; - api::StorageBuffer staging(context, at::kFloat, v_bias.numcells()); + api::StorageBuffer staging(context, at::kFloat, v_bias.gpu_numel()); { api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE); diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp index c4ba030b5bb4e..4bb0880383575 100644 --- a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp @@ -10,12 +10,29 @@ namespace ops { using namespace api::utils; +static api::ShaderSource get_quantize_per_tensor_shader( + const c10::ScalarType dtype) { + switch (dtype) { + case c10::ScalarType::QUInt8: + return VK_KERNEL(quantize_per_tensor_quint8); + case c10::ScalarType::QInt8: + return VK_KERNEL(quantize_per_tensor_qint8); + case c10::ScalarType::QInt32: + return VK_KERNEL(quantize_per_tensor_qint32); + default: + TORCH_CHECK( + false, + "Vulkan quantization currently not supported for dtype ", + dtype); + } +} + Tensor quantize_per_tensor( const at::Tensor& input_arg, const double scale, const int64_t zero_point, const c10::ScalarType dtype) { - TORCH_CHECK(dtype == c10::ScalarType::QUInt8, "Expected type c10::kQUint8"); + api::ShaderSource compute_shader = get_quantize_per_tensor_shader(dtype); api::Context* const context = api::context(); @@ -23,11 +40,7 @@ Tensor quantize_per_tensor( const vTensor& v_input = convert(input); vTensor v_output{ - context, - input.sizes(), - input.options().dtype(c10::kQUInt8), - scale, - zero_point}; + context, input.sizes(), input.options().dtype(dtype), scale, zero_point}; const struct Block final { uvec3 extents; @@ -50,7 +63,7 @@ Tensor quantize_per_tensor( context->submit_compute_job( // shader descriptor - VK_KERNEL(quantize_per_tensor), + compute_shader, // barrier pipeline_barrier, // global work group size diff --git a/aten/src/ATen/native/vulkan/ops/Register.cpp b/aten/src/ATen/native/vulkan/ops/Register.cpp index 18d5a6facfaed..f9f0c2ad6aff3 100644 --- a/aten/src/ATen/native/vulkan/ops/Register.cpp +++ b/aten/src/ATen/native/vulkan/ops/Register.cpp @@ -1,5 +1,6 @@ #ifdef USE_VULKAN_API +#include #include #include #include @@ -16,6 +17,19 @@ namespace ops { namespace { TORCH_LIBRARY(vulkan, m) { + m.class_("BatchNormPackedContext") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& context) { + // context is packed + return context->unpack(); + }, + // __setstate__ + [](c10::impl::GenericList state) { + // state is unpacked + return c10::make_intrusive( + BatchNormPackedContext::pack(state)); + }); m.class_("LinearPackedContext") .def_pickle( // __getstate__ @@ -114,6 +128,14 @@ TORCH_LIBRARY(vulkan_prepack, m) { m.def(TORCH_SELECTIVE_SCHEMA( "vulkan_prepack::run_tconv2d_context(Tensor X, " "__torch__.torch.classes.vulkan.Conv2dPackedContext W_prepack) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::create_qconv2d_context(Tensor W, Tensor? B, " + "int[2] stride, int[2] padding, int[2] dilation, int groups, " + "Scalar? output_min=None, Scalar? output_max=None) " + "-> __torch__.torch.classes.vulkan.Conv2dPackedContext")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::run_qconv2d_context(Tensor X, float scale, int zero_point, " + "__torch__.torch.classes.vulkan.Conv2dPackedContext vk_context) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA( "vulkan_prepack::create_linear_context(Tensor W, Tensor? B) " "-> __torch__.torch.classes.vulkan.LinearPackedContext")); @@ -147,6 +169,22 @@ TORCH_LIBRARY(vulkan_prepack, m) { "Tensor hx_vk, " "Tensor cx_vk, " "__torch__.torch.classes.vulkan.LstmPackedContext L_prepack) -> (Tensor next_input, Tensor hidden_state, Tensor cell_state)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::create_batchnorm_context(" + "Tensor? weight_opt, " + "Tensor? bias_opt, " + "Tensor? running_mean_opt, " + "Tensor? running_var_opt, " + "bool training, " + "float momentum, " + "float eps, " + "bool cudnn_enable) " + "-> __torch__.torch.classes.vulkan.BatchNormPackedContext")); + m.def(TORCH_SELECTIVE_SCHEMA( + "vulkan_prepack::run_batchnorm_context(" + "Tensor input_vk, " + "__torch__.torch.classes.vulkan.BatchNormPackedContext context) " + "-> Tensor out")); } TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { @@ -168,6 +206,15 @@ TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("vulkan_prepack::create_lstm_context"), TORCH_FN(create_lstm_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_batchnorm_context"), + TORCH_FN(create_batchnorm_context)); +} + +TORCH_LIBRARY_IMPL(vulkan_prepack, QuantizedCPU, m) { + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::create_qconv2d_context"), + TORCH_FN(create_qconv2d_context)); } TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { @@ -180,6 +227,9 @@ TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { m.impl( TORCH_SELECTIVE_NAME("vulkan_prepack::run_tconv2d_context"), TORCH_FN(run_tconv2d_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_qconv2d_context"), + TORCH_FN(run_qconv2d_context)); m.impl( TORCH_SELECTIVE_NAME("vulkan_prepack::run_linear_context"), TORCH_FN(run_linear_context)); @@ -189,6 +239,9 @@ TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) { m.impl( TORCH_SELECTIVE_NAME("vulkan_prepack::run_lstm_context"), TORCH_FN(run_lstm_context)); + m.impl( + TORCH_SELECTIVE_NAME("vulkan_prepack::run_batchnorm_context"), + TORCH_FN(run_batchnorm_context)); } } // namespace diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index d8263e59668e6..4209a3781cd28 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -22,7 +22,7 @@ Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { self.options(), }; - api::StorageBuffer buffer(context, at::kFloat, v_self.numcells(), true); + api::StorageBuffer buffer(context, at::kFloat, v_self.gpu_numel(), true); utils::pack_vtensor_to_staging(v_self, buffer.buffer()); diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.cpp b/aten/src/ATen/native/vulkan/ops/Tensor.cpp index 8a829bda0708f..315462ac0d1df 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tensor.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -9,45 +10,207 @@ namespace ops { namespace { -api::utils::uvec3 image_extents(const IntArrayRef sizes) { - int64_t width = 1; - int64_t height = 1; - int64_t depth = 1; +/* + * Extracts the memory format member of a TensorOptions struct. If there is no + * empty format listed, then a contiguous format is assumed. + */ +at::MemoryFormat get_memory_format(const TensorOptions& options) { + return options.memory_format_opt() ? *(options.memory_format_opt()) + : at::MemoryFormat::Contiguous; +} - switch (sizes.size()) { - case 1: - width = sizes[0]; - break; +/* + * Calculates the strides of a contiguous tensor. empty_tensor_restride from + * TensorImpl.h was used as a reference. + */ +c10::SmallVector calc_contiguous_strides(const IntArrayRef sizes) { + int64_t ndim = sizes.size(); + c10::SmallVector strides(ndim); + + int64_t running_product = 1; + if (ndim >= 1) { + strides[ndim - 1] = running_product; + for (int i = sizes.size() - 2; i >= 0; --i) { + running_product *= sizes[i + 1]; + strides[i] = running_product; + } + } - case 2: - width = sizes[1]; - height = sizes[0]; - break; + return strides; +} - case 3: - width = sizes[2]; - height = sizes[1]; - depth = sizes[0]; - break; +c10::SmallVector calc_channels_last_strides( + const IntArrayRef sizes) { + c10::SmallVector strides(sizes.size()); + switch (sizes.size()) { case 4: - width = sizes[3]; - height = sizes[2]; - depth = sizes[0] * sizes[1]; - break; - + strides[1] = 1; + strides[3] = sizes[1]; + strides[2] = strides[3] * sizes[3]; + strides[0] = strides[2] * sizes[2]; + return strides; + case 3: + strides[0] = 1; + strides[2] = sizes[0]; + strides[1] = strides[2] * sizes[2]; + return strides; default: - TORCH_INTERNAL_ASSERT( - false, - "Only Tensors with 1 <= dim <= 4 can be represented as a Vulkan Image!"); + TORCH_CHECK( + false, "ChannelsLast format only available for 3 <= ndim <= 4!"); } - return { - api::utils::safe_downcast(width), - api::utils::safe_downcast(height), - api::utils::safe_downcast( - api::utils::div_up(depth, INT64_C(4))), + return strides; +} + +/* + * Calculates the strides of a tensor based on the sizes and memory format. Note + * that strides are only valid for vTensors that are backed by buffer storage; + * if texture storage is used then the strides are invalid and set to zeros. + */ +c10::SmallVector calc_strides( + const IntArrayRef sizes, + const at::MemoryFormat memory_format, + const api::StorageType storage_type) { + if (storage_type == api::StorageType::BUFFER) { + switch (memory_format) { + case MemoryFormat::Contiguous: + return calc_contiguous_strides(sizes); + break; + case MemoryFormat::ChannelsLast: + return calc_channels_last_strides(sizes); + break; + default: + TORCH_CHECK(false, "Invalid memory format used to create vTensor!"); + } + } else { + c10::SmallVector strides(sizes.size()); + return strides; + } +} + +/* + * When stored on the GPU, one dimension will be aligned to the next multiple of + * 4 in order to take advantage of vec4 data types. This function adjusts one of + * the dimensions based on the desired memory format and storage type. + */ +c10::SmallVector calc_gpu_sizes( + const IntArrayRef sizes, + const at::MemoryFormat memory_format, + const api::StorageType storage_type) { + size_t ndim = sizes.size(); + + // For buffer formats, the innermost dim (i.e. where the stride is 1) will be + // aligned up. + if (storage_type == api::StorageType::BUFFER) { + c10::SmallVector gpu_sizes{sizes}; + + switch (memory_format) { + case at::MemoryFormat::Contiguous: + gpu_sizes[ndim - 1] = api::utils::align_up(sizes[ndim - 1], INT64_C(4)); + break; + + case at::MemoryFormat::ChannelsLast: + switch (ndim) { + case 3: + gpu_sizes[0] = api::utils::align_up(sizes[0], INT64_C(4)); + break; + + case 4: + gpu_sizes[1] = api::utils::align_up(sizes[1], INT64_C(4)); + break; + } + break; + + default: + TORCH_CHECK(false, "Invalid memory format used to create vTensor!"); + break; + } + + return gpu_sizes; + } else { + TORCH_CHECK( + ndim >= 1 && ndim <= 4, + "Texture storage only valid for 1 <= ndim <= 4!"); + + c10::SmallVector gpu_sizes(3); + + // Channel dim will be always be aligned. For 4 dimensional tensors, batch + // and channel are combined, then aligned. + switch (ndim) { + case 1: + gpu_sizes[0] = 4; + gpu_sizes[1] = 1; + gpu_sizes[2] = sizes[0]; + break; + + case 2: + gpu_sizes[0] = 4; + gpu_sizes[1] = sizes[0]; + gpu_sizes[2] = sizes[1]; + break; + + case 3: + gpu_sizes[0] = api::utils::align_up(sizes[0], INT64_C(4)); + gpu_sizes[1] = sizes[1]; + gpu_sizes[2] = sizes[2]; + break; + + case 4: + int64_t combined_depth = sizes[0] * sizes[1]; + gpu_sizes[0] = api::utils::align_up(combined_depth, INT64_C(4)); + gpu_sizes[1] = sizes[2]; + gpu_sizes[2] = sizes[3]; + break; + } + return gpu_sizes; + } +} + +/* + * Creates a uvec3 denoting the extents of the image texture that will be + * created to store a tensor of a given size. + */ +api::utils::uvec3 create_image_extents( + const IntArrayRef gpu_sizes, + const api::StorageType storage_type) { + size_t ndim = gpu_sizes.size(); + + if (storage_type == api::StorageType::BUFFER) { + // image extents do not apply to buffer storage + return {0u, 0u, 0u}; + } else { + TORCH_CHECK( + ndim >= 1 && ndim <= 3, + "Texture storage only valid for 1 <= ndim <= 3!"); + + uint32_t width = get_dim(gpu_sizes); + uint32_t height = get_dim(gpu_sizes); + uint32_t depth = get_dim(gpu_sizes); + + TORCH_CHECK(depth % 4 == 0, "Channels must be divisible by 4!") + + return {width, height, depth / 4u}; + } +} + +api::UniformParamsBuffer make_metadata_uniform( + api::Context* const context, + const IntArrayRef sizes, + const IntArrayRef strides, + const api::StorageType storage_type) { + if (storage_type != api::StorageType::BUFFER) { + return api::UniformParamsBuffer(); + } + + vTensor::BufferMetadata metadata{ + ops::make_nchw_uvec4(sizes), + ops::make_nchw_uvec4(strides), + api::utils::safe_downcast(sizes.size()), + api::utils::safe_downcast(c10::multiply_integers(sizes)), }; + + return api::UniformParamsBuffer(context, metadata); } } // namespace @@ -59,58 +222,69 @@ api::utils::uvec3 image_extents(const IntArrayRef sizes) { vTensor::vTensor( api::Context* const context, const IntArrayRef sizes, - const TensorOptions& options) - : view_(std::make_shared( + const TensorOptions& options, + const api::StorageType storage_type) + : options_(options), + memory_format_(get_memory_format(options)), + // Calculate sizes and strides + sizes_{sizes}, + strides_{calc_strides(sizes, memory_format_, storage_type)}, + gpu_sizes_{calc_gpu_sizes(sizes, memory_format_, storage_type)}, + gpu_strides_{calc_strides(gpu_sizes_, memory_format_, storage_type)}, + // Vulkan uniform buffer containing sizes and stride info + metadata_uniform_{make_metadata_uniform( context, - sizes, - StorageType::TEXTURE_3D, - options)) {} - -vTensor::vTensor( - api::Context* const context, - const IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options) - : view_(std::make_shared( + gpu_sizes_, + gpu_strides_, + storage_type)}, + // Construct Tensor storage + view_(std::make_shared( context, - sizes, storage_type, - options)) {} + gpu_sizes_, + dtype())) { + ops::verify(options); +} vTensor::vTensor( api::Context* const context, const IntArrayRef sizes, const TensorOptions& options, double q_scale, - int64_t q_zero_point) - : view_(std::make_shared( + int64_t q_zero_point, + const api::StorageType storage_type) + : options_(options), + memory_format_(get_memory_format(options)), + // Calculate sizes and strides + sizes_{sizes}, + strides_{calc_strides(sizes, memory_format_, storage_type)}, + gpu_sizes_{calc_gpu_sizes(sizes, memory_format_, storage_type)}, + gpu_strides_{calc_strides(gpu_sizes_, memory_format_, storage_type)}, + // Vulkan uniform buffer containing sizes and stride info + metadata_uniform_{make_metadata_uniform( context, - sizes, - StorageType::TEXTURE_3D, - options, - q_scale, - q_zero_point)) {} - -vTensor::vTensor( - api::Context* const context, - const IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options, - double q_scale, - int64_t q_zero_point) - : view_(std::make_shared( + gpu_sizes_, + gpu_strides_, + storage_type)}, + // Quantization params + is_quantized_{true}, + q_scale_{q_scale}, + q_zero_point_{q_zero_point}, + // Construct Tensor storage + view_(std::make_shared( context, - sizes, storage_type, - options, - q_scale, - q_zero_point)) {} + gpu_sizes_, + dtype())) { + verify(options); +} api::VulkanImage& vTensor::image( api::PipelineBarrier& pipeline_barrier, const api::PipelineStageFlags stage) const& { - view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ); + TORCH_CHECK(view_->image_, "vTensor has empty image texture!"); + view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ); return view_->image_; } @@ -118,11 +292,40 @@ api::VulkanImage& vTensor::image( api::PipelineBarrier& pipeline_barrier, const api::PipelineStageFlags stage, const api::MemoryAccessFlags access) & { - view_->transition(pipeline_barrier, stage, access); + TORCH_CHECK(view_->image_, "vTensor has empty image texture!"); + view_->transition(pipeline_barrier, stage, access); return view_->image_; } +api::VulkanBuffer& vTensor::buffer( + api::PipelineBarrier& pipeline_barrier, + const api::PipelineStageFlags stage) const& { + TORCH_CHECK(view_->buffer_, "vTensor has empty buffer!"); + + view_->transition(pipeline_barrier, stage, api::MemoryAccessType::READ); + return view_->buffer_; +} + +api::VulkanBuffer& vTensor::buffer( + api::PipelineBarrier& pipeline_barrier, + const api::PipelineStageFlags stage, + const api::MemoryAccessFlags access) & { + TORCH_CHECK(view_->buffer_, "vTensor has empty buffer!"); + + view_->transition(pipeline_barrier, stage, access); + return view_->buffer_; +} + +vTensor::BufferMetadata vTensor::get_cpu_buffer_metadata() const { + return { + ops::make_nchw_uvec4(sizes_), + ops::make_nchw_uvec4(strides_), + api::utils::safe_downcast(sizes_.size()), + api::utils::safe_downcast(c10::multiply_integers(sizes_)), + }; +} + // // vTensorStorage // @@ -130,7 +333,7 @@ api::VulkanImage& vTensor::image( api::VulkanImage allocate_image( api::Context* const context_ptr, api::utils::uvec3& extents, - StorageType storage_type, + const api::StorageType storage_type, const VkFormat image_format) { api::Adapter* adapter_ptr = context_ptr->adapter_ptr(); @@ -145,14 +348,17 @@ api::VulkanImage allocate_image( VkImageViewType image_view_type = VK_IMAGE_VIEW_TYPE_3D; switch (storage_type) { - case StorageType::TEXTURE_3D: + case api::StorageType::TEXTURE_3D: image_type = VK_IMAGE_TYPE_3D; image_view_type = VK_IMAGE_VIEW_TYPE_3D; break; - case StorageType::TEXTURE_2D: + case api::StorageType::TEXTURE_2D: image_type = VK_IMAGE_TYPE_2D; image_view_type = VK_IMAGE_VIEW_TYPE_2D; break; + default: + // Return an empty VulkanImage by default + return api::VulkanImage(); } VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props); @@ -167,53 +373,48 @@ api::VulkanImage allocate_image( true); } -vTensorStorage::vTensorStorage( - api::Context* const context, - const IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options) - : context_(context), - extents_(image_extents(sizes)), - options_(options), - sizes_(sizes), - strides_(sizes.size()), - storage_type_{storage_type}, - image_(allocate_image( - context_, - extents_, - storage_type_, - api::vk_format(options_.dtype()))), - last_access_{} { - ops::verify(options); +api::VulkanBuffer allocate_buffer( + api::Context* const context_ptr, + const int64_t numel, + const api::StorageType storage_type, + const c10::ScalarType dtype) { + api::Adapter* adapter_ptr = context_ptr->adapter_ptr(); + + switch (storage_type) { + case api::StorageType::BUFFER: + break; + default: + // Return an empty VulkanBuffer if Buffer storage is not used + return api::VulkanBuffer(); + } + + return adapter_ptr->vma().create_storage_buffer( + c10::elementSize(dtype) * numel, true); } vTensorStorage::vTensorStorage( api::Context* const context, - const IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options, - double q_scale_in, - int64_t q_zero_point_in) + const api::StorageType storage_type, + const IntArrayRef gpu_sizes, + const at::ScalarType dtype) : context_(context), - extents_(image_extents(sizes)), - options_(options), - sizes_(sizes), - strides_(sizes.size()), - is_quantized_{true}, - q_scale{q_scale_in}, - q_zero_point{q_zero_point_in}, storage_type_{storage_type}, + extents_(create_image_extents(gpu_sizes, storage_type)), + buffer_length_{c10::multiply_integers(gpu_sizes)}, image_(allocate_image( context_, extents_, storage_type_, - api::vk_format(options_.dtype()))), - last_access_{} { - ops::verify(options); -} + api::vk_format(dtype))), + buffer_(allocate_buffer(context_, buffer_length_, storage_type_, dtype)), + last_access_{} {} vTensorStorage::~vTensorStorage() { - context_->register_image_cleanup(image_); + if (image_) { + context_->register_image_cleanup(image_); + } else if (buffer_) { + context_->register_buffer_cleanup(buffer_); + } } void vTensorStorage::transition( @@ -224,12 +425,18 @@ void vTensorStorage::transition( api::PipelineStageFlags prev_stage = last_access_.stage; api::MemoryAccessFlags prev_access = last_access_.access; - const VkImageLayout cur_layout = image_.layout(); - const VkImageLayout new_layout = api::vk_layout(cur_stage, cur_access); - - const bool layout_changed = cur_layout != new_layout; const bool prev_written = (prev_access & api::MemoryAccessType::WRITE) != 0; + VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED; + VkImageLayout new_layout = VK_IMAGE_LAYOUT_UNDEFINED; + bool layout_changed = false; + if (image_) { + cur_layout = image_.layout(); + new_layout = api::vk_layout(cur_stage, cur_access); + + layout_changed = cur_layout != new_layout; + } + if (prev_written || layout_changed) { VkPipelineStageFlags src_stage = api::vk_stage(prev_stage); if (0u == src_stage) { @@ -243,14 +450,21 @@ void vTensorStorage::transition( pipeline_barrier.stage.src |= src_stage; pipeline_barrier.stage.dst |= dst_stage; - pipeline_barrier.images.push_back(api::ImageMemoryBarrier( - api::vk_access(prev_stage, prev_access), - api::vk_access(cur_stage, cur_access), - cur_layout, - new_layout, - image_)); - - image_.set_layout(new_layout); + if (image_) { + pipeline_barrier.images.push_back(api::ImageMemoryBarrier( + api::vk_access(prev_stage, prev_access), + api::vk_access(cur_stage, cur_access), + cur_layout, + new_layout, + image_)); + + image_.set_layout(new_layout); + } else if (buffer_) { + pipeline_barrier.buffers.push_back(api::BufferMemoryBarrier( + api::vk_access(prev_stage, prev_access), + api::vk_access(cur_stage, cur_access), + buffer_)); + } } last_access_.stage = cur_stage; @@ -303,9 +517,10 @@ void verify(const TensorOptions& options) { !options.has_layout() || (c10::kStrided == options.layout()), "'layout' tensor option is not yet supported under Vulkan!"); + at::MemoryFormat memory_format = get_memory_format(options); TORCH_CHECK( - !options.has_memory_format() || - (c10::MemoryFormat::Contiguous == options.memory_format_opt()), + memory_format == at::MemoryFormat::ChannelsLast || + memory_format == at::MemoryFormat::Contiguous, "'memory_format' tensor option is not yet supported under Vulkan!"); } diff --git a/aten/src/ATen/native/vulkan/ops/Tensor.h b/aten/src/ATen/native/vulkan/ops/Tensor.h index 9e5651cb510f3..241d2c839b80a 100644 --- a/aten/src/ATen/native/vulkan/ops/Tensor.h +++ b/aten/src/ATen/native/vulkan/ops/Tensor.h @@ -26,11 +26,6 @@ struct LastAccess { : stage{stage_flags}, access{access_flags} {} }; -enum class StorageType { - TEXTURE_3D, - TEXTURE_2D, -}; - class vTensorStorage final { public: // Do not allow empty vTensorStorage construction @@ -38,16 +33,9 @@ class vTensorStorage final { vTensorStorage( api::Context* context, - IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options); - vTensorStorage( - api::Context* context, - IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options, - double q_scale, - int64_t q_zero_point); + const api::StorageType storage_type, + const IntArrayRef sizes, + const at::ScalarType dtype); vTensorStorage(const vTensorStorage&) = delete; vTensorStorage& operator=(const vTensorStorage&) = delete; @@ -63,18 +51,15 @@ class vTensorStorage final { // Context api::Context* context_; - // Metadata + api::StorageType storage_type_; + + // Resource sizings api::utils::uvec3 extents_; - TensorOptions options_; - c10::SmallVector sizes_; - c10::SmallVector strides_; - bool is_quantized_{false}; - double q_scale{1.0f}; - int64_t q_zero_point{0u}; + int64_t buffer_length_; // Image Texture - StorageType storage_type_; mutable api::VulkanImage image_; + mutable api::VulkanBuffer buffer_; // Last Access - used to insert memory barriers LastAccess last_access_; @@ -100,33 +85,53 @@ class vTensor final { // Do not allow empty vTensor construction vTensor() = default; + // Default constructor vTensor( api::Context* context, IntArrayRef sizes, - const TensorOptions& options); - - vTensor( - api::Context* context, - IntArrayRef sizes, - const StorageType storage_type, - const TensorOptions& options); - - vTensor( - api::Context* const context, - const IntArrayRef sizes, const TensorOptions& options, - double q_scale, - int64_t q_zero_point); + const api::StorageType storage_type = api::StorageType::TEXTURE_3D); + // Default constructor with quantization parameters vTensor( api::Context* const context, const IntArrayRef sizes, - const StorageType storage_type, const TensorOptions& options, double q_scale, - int64_t q_zero_point); + int64_t q_zero_point, + const api::StorageType storage_type = api::StorageType::TEXTURE_3D); + + // Used for passing buffer sizes and strides data to shaders + struct BufferMetadata { + api::utils::uvec4 sizes; + api::utils::uvec4 strides; + uint32_t ndim; + uint32_t buffer_length; + }; private: + // Tensor Options + TensorOptions options_; + at::MemoryFormat memory_format_; + + // Sizes and Strides + c10::SmallVector sizes_; + c10::SmallVector strides_; + + // Storage Dimensions. When stored on the GPU, one dimension will be aligned + // to the next multiple of 4 in order to take advantage of vec4 data types. + c10::SmallVector gpu_sizes_; + c10::SmallVector gpu_strides_; + + // A Vulkan uniform buffer containing sizes and strides of the GPU buffer that + // can be passed into a shader. + api::UniformParamsBuffer metadata_uniform_; + + // Quantization params + bool is_quantized_{false}; + double q_scale_{1.0f}; + int64_t q_zero_point_{0u}; + // Even at the cost of a heap allocation plus the resulting negative impact // on cache locality due to the subsequent pointer chasing, it is still // critcal to share the view across vTensor implementations to minimize @@ -151,7 +156,7 @@ class vTensor final { Texture Access */ - inline StorageType storage_type() const { + inline api::StorageType storage_type() const { return view_->storage_type_; } @@ -163,6 +168,15 @@ class vTensor final { const api::PipelineStageFlags, const api::MemoryAccessFlags) &; + api::VulkanBuffer& buffer( + api::PipelineBarrier&, + const api::PipelineStageFlags) const&; + + api::VulkanBuffer& buffer( + api::PipelineBarrier&, + const api::PipelineStageFlags, + const api::MemoryAccessFlags) &; + /* Metadata */ @@ -171,6 +185,13 @@ class vTensor final { return view_->extents_; } + /* + * Extract a ScalarType from the TensorOptions member + */ + inline c10::ScalarType dtype() const { + return c10::typeMetaToScalarType(options_.dtype()); + } + /* * Get a c10::ScalarType that corresponds to the image format of the texture */ @@ -178,77 +199,96 @@ class vTensor final { return api::c10_scalartype(view_->texture_format()); } + inline at::MemoryFormat memory_format() const { + return memory_format_; + } + inline const TensorOptions& options() const { - return view_->options_; + return options_; } inline IntArrayRef sizes() const { - return view_->sizes_; + return sizes_; } inline IntArrayRef strides() const { - return view_->strides_; + return strides_; } - inline void set_is_quantized() const { - view_->is_quantized_ = true; + inline IntArrayRef gpu_sizes() const { + return gpu_sizes_; + } + + inline IntArrayRef gpu_strides() const { + return gpu_strides_; + } + + /* + * Get a uniform buffer containing sizes and strides information of the GPU + * buffer + */ + inline api::VulkanBuffer& buffer_metadata() { + return metadata_uniform_.buffer(); + } + + /* + * Constructs a BufferMetdata struct based on the original sizes and strides + * to pass into a shader. + */ + BufferMetadata get_cpu_buffer_metadata() const; + + inline void set_is_quantized() { + is_quantized_ = true; } inline bool is_quantized() const { - return view_->is_quantized_; + return is_quantized_; } - inline void set_scale(const double q_scale) const { - view_->q_scale = q_scale; + inline void set_scale(const double q_scale) { + q_scale_ = q_scale; } inline double get_scale() const { - return view_->q_scale; + return q_scale_; } inline float get_scale_float() const { - return api::utils::safe_downcast(view_->q_scale); + return api::utils::safe_downcast(q_scale_); } - inline void set_zero_point(const int64_t q_zero_point) const { - view_->q_zero_point = q_zero_point; + inline void set_zero_point(const int64_t q_zero_point) { + q_zero_point_ = q_zero_point; } inline int64_t get_zero_point() const { - return view_->q_zero_point; + return q_zero_point_; } inline int32_t get_zero_point_int32() const { - return api::utils::safe_downcast(view_->q_zero_point); + return api::utils::safe_downcast(q_zero_point_); } - inline size_t nbytes() const { - return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) * - c10::multiply_integers(sizes()); + inline size_t numel() const { + return c10::multiply_integers(sizes()); } /* - * Number of texels in the image texture. + * Returns numel but based on gpu_sizes_ instead of sizes_ */ - inline VkDeviceSize numtexels() { - return view_->extents_.data[0u] * view_->extents_.data[1u] * - view_->extents_.data[2u]; + inline size_t gpu_numel() const { + return view_->buffer_length_; } - /* - * Number of "cells" in the image texture. 4 cells make up a texel. - */ - inline VkDeviceSize numcells() { - return view_->extents_.data[0u] * view_->extents_.data[1u] * - (4u * view_->extents_.data[2u]); + inline size_t nbytes() const { + return c10::elementSize(dtype()) * numel(); } /* - * Number of bytes needed for a buffer to receive all data in the texture + * Return nbytes but bnased on gpu_sizes_ instead of sizes_ */ - inline VkDeviceSize buffer_bytes() { - return c10::elementSize(this->texture_dtype()) * view_->extents_.data[0u] * - view_->extents_.data[1u] * (4u * view_->extents_.data[2u]); + inline VkDeviceSize gpu_nbytes() const { + return c10::elementSize(dtype()) * gpu_numel(); } }; diff --git a/aten/src/ATen/native/vulkan/ops/Utils.cpp b/aten/src/ATen/native/vulkan/ops/Utils.cpp index 23ae1e9f57b46..fed1fd482fd58 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.cpp +++ b/aten/src/ATen/native/vulkan/ops/Utils.cpp @@ -19,40 +19,78 @@ namespace packing { static api::ShaderSource get_nchw_to_image_shader(const vTensor& v_dst) { if (v_dst.is_quantized()) { switch (v_dst.storage_type()) { - case StorageType::TEXTURE_3D: - return VK_KERNEL(nchw_to_image_quantized); - case StorageType::TEXTURE_2D: + case api::StorageType::TEXTURE_3D: + switch (v_dst.dtype()) { + case c10::ScalarType::QUInt8: + return VK_KERNEL(nchw_to_image_uint8); + case c10::ScalarType::QInt8: + return VK_KERNEL(nchw_to_image_int8); + case c10::ScalarType::QInt32: + return VK_KERNEL(nchw_to_image_int32); + default: + TORCH_CHECK( + false, + "Vulkan quantization currently not supported for dtype ", + v_dst.dtype()); + } + default: TORCH_CHECK(false, "No kernel available!"); + case api::StorageType::BUFFER: + case api::StorageType::UNKNOWN: + TORCH_CHECK(false, "Requested storage type must be a texture type."); } } switch (v_dst.storage_type()) { - case StorageType::TEXTURE_3D: + case api::StorageType::TEXTURE_3D: return VK_KERNEL(nchw_to_image); - case StorageType::TEXTURE_2D: + case api::StorageType::TEXTURE_2D: return VK_KERNEL(nchw_to_image2d); + default: + TORCH_CHECK(false, "No kernel available!"); } } static api::ShaderSource get_image_to_nchw_shader(const vTensor& v_src) { if (v_src.is_quantized()) { + auto plane_size = + get_dim(v_src) * get_dim(v_src); switch (v_src.storage_type()) { - case StorageType::TEXTURE_3D: - return VK_KERNEL(image_to_nchw_quantized); - case StorageType::TEXTURE_2D: + case api::StorageType::TEXTURE_3D: + switch (v_src.dtype()) { + case c10::ScalarType::QUInt8: + return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4) + : VK_KERNEL(image_to_nchw_quantized); + case c10::ScalarType::QInt8: + return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4) + : VK_KERNEL(image_to_nchw_quantized); + case c10::ScalarType::QInt32: + return VK_KERNEL(image_to_nchw_int32); + default: + TORCH_CHECK( + false, + "Vulkan quantization currently not supported for dtype ", + v_src.dtype()); + } + default: TORCH_CHECK(false, "No kernel available!"); + case api::StorageType::BUFFER: + case api::StorageType::UNKNOWN: + TORCH_CHECK(false, "Requested storage type must be a texture type."); } } switch (v_src.storage_type()) { - case StorageType::TEXTURE_3D: + case api::StorageType::TEXTURE_3D: return VK_KERNEL(image_to_nchw); - case StorageType::TEXTURE_2D: + case api::StorageType::TEXTURE_2D: return VK_KERNEL(image2d_to_nchw); + default: + TORCH_CHECK(false, "No kernel available!"); } } -struct Params final { +struct ToFromTextureParams final { api::utils::ivec3 extents; int32_t plane_size; }; @@ -73,7 +111,7 @@ void record_nchw_to_image_op( api::utils::safe_downcast(get_dim(v_dst)); int32_t plane_size = height * width; - Params block{ + ToFromTextureParams block{ api::utils::make_ivec3(v_dst.extents()), plane_size, }; @@ -116,11 +154,25 @@ void record_image_to_nchw_op( api::utils::safe_downcast(get_dim(v_src)); int32_t plane_size = height * width; - Params block{ + ToFromTextureParams block{ api::utils::make_ivec3(v_src.extents()), plane_size, }; + if (v_src.dtype() == c10::ScalarType::QUInt8 || + v_src.dtype() == c10::ScalarType::QInt8) { + if (plane_size % 4 == 0) { + global_size.data[0u] = plane_size / 4; + global_size.data[1u] = 1; + local_size.data[0u] *= local_size.data[1u]; + local_size.data[1u] = 1; + } else { + uint32_t numel = v_src.numel(); + global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u}; + local_size = {64u, 1u, 1u}; + } + } + api::UniformParamsBuffer params(context, block); context->submit_compute_job( // shader descriptor @@ -143,6 +195,76 @@ void record_image_to_nchw_op( params.buffer()); } +void record_nchw_to_buffer_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst, + api::PipelineBarrier pipeline_barrier, + const VkFence fence_handle) { + uint32_t gpu_buf_len = api::utils::safe_downcast(v_dst.gpu_numel()); + + api::utils::uvec3 global_size = {gpu_buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {32u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_dst.get_cpu_buffer_metadata()); + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(buffer_to_buffer), + // pipeline barrier + pipeline_barrier, + // global work group size + global_size, + // local work group size + local_size, + // fence handle + fence_handle, + // shader arguments + v_dst.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_dst.buffer_metadata(), + src_buffer, + cpu_buffer_metadata.buffer()); +} + +void record_buffer_to_nchw_op( + api::Context* const context, + vTensor& v_src, + api::VulkanBuffer& dst_buffer, + api::PipelineBarrier pipeline_barrier, + const VkFence fence_handle) { + uint32_t buf_len = api::utils::safe_downcast(v_src.numel()); + + api::utils::uvec3 global_size = {buf_len, 1u, 1u}; + api::utils::uvec3 local_size = {4u, 1u, 1u}; + + api::UniformParamsBuffer cpu_buffer_metadata( + context, v_src.get_cpu_buffer_metadata()); + + context->submit_compute_job( + // shader descriptor + VK_KERNEL(buffer_to_buffer), + // pipeline barrier + pipeline_barrier, + // global work group size + global_size, + // local work group size + local_size, + // fence handle + fence_handle, + // shader arguments + dst_buffer, + cpu_buffer_metadata.buffer(), + v_src.buffer( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + v_src.buffer_metadata()); +} + } // namespace packing namespace utils { @@ -247,7 +369,7 @@ void copy_buffer_to_vtensor( api::Context* const context = api::context(); TORCH_CHECK( - src_buffer.mem_size() == v_dst.buffer_bytes(), + src_buffer.mem_size() == v_dst.gpu_nbytes(), "Vulkan copy_buffer_to_vtensor: source buffer and destination texture " "do not have the same number of bytes"); @@ -297,7 +419,7 @@ void copy_vtensor_to_buffer( api::Context* const context = api::context(); TORCH_CHECK( - v_src.buffer_bytes() == dst_buffer.mem_size(), + v_src.gpu_nbytes() == dst_buffer.mem_size(), "Vulkan copy_vtensor_to_buffer: source texture and destination buffer " "do not have the same number of bytes"); @@ -324,14 +446,20 @@ void pack_buffer_to_vtensor( api::PipelineBarrier& pipeline_barrier) { api::Context* const context = api::context(); - api::ShaderSource compute_shader = packing::get_nchw_to_image_shader(v_self); - packing::record_nchw_to_image_op( - context, - compute_shader, - buffer, - v_self, - pipeline_barrier, - VK_NULL_HANDLE); + if (v_self.storage_type() == api::StorageType::BUFFER) { + packing::record_nchw_to_buffer_op( + context, buffer, v_self, pipeline_barrier, VK_NULL_HANDLE); + } else { + api::ShaderSource compute_shader = + packing::get_nchw_to_image_shader(v_self); + packing::record_nchw_to_image_op( + context, + compute_shader, + buffer, + v_self, + pipeline_barrier, + VK_NULL_HANDLE); + } } void pack_staging_to_vtensor(api::VulkanBuffer& staging, vTensor& v_self) { @@ -344,11 +472,22 @@ void pack_vtensor_to_staging( api::VulkanBuffer& staging, const VkFence fence_handle) { api::Context* const context = api::context(); - api::ShaderSource compute_shader = packing::get_image_to_nchw_shader(v_self); - api::PipelineBarrier pipeline_barrier{}; - packing::record_image_to_nchw_op( - context, compute_shader, v_self, staging, pipeline_barrier, fence_handle); + + if (v_self.storage_type() == api::StorageType::BUFFER) { + packing::record_buffer_to_nchw_op( + context, v_self, staging, pipeline_barrier, fence_handle); + } else { + api::ShaderSource compute_shader = + packing::get_image_to_nchw_shader(v_self); + packing::record_image_to_nchw_op( + context, + compute_shader, + v_self, + staging, + pipeline_barrier, + fence_handle); + } } } // namespace utils diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index d4c143211a21a..323dc5f888b87 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -611,6 +611,16 @@ void record_function_with_scope_and_debug_handle( RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs) +// Bookend to the RECORD_FUNCTION macros. Use this after the kernel +// launch to let the profiler bind the outputs to the op that produced +// them. Note that guard is declared by RECORD_FUNCTION so this macro +// needs to be called from the same scope as RECORD_FUNCTION +#define RECORD_OUTPUTS(outputs) \ + if (guard.needsOutputs()) { \ + guard.setOutputs( \ + std::vector(outputs.begin(), outputs.end())); \ + } + /** * addThreadLocalCallback adds a thread local callback to run with * RecordFunction, returns handle to use with removeThreadLocalCallback diff --git a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp index d6a7266952e9b..7548d7c1a3a8a 100644 --- a/aten/src/ATen/templates/CompositeViewCopyKernels.cpp +++ b/aten/src/ATen/templates/CompositeViewCopyKernels.cpp @@ -30,17 +30,25 @@ std::vector clone_arg(const at::TensorList& t_list) { return out; } +// duped with gen_resize_out_helper from structured kernels void copy_arg(const at::Tensor& dst, const at::Tensor& src) { + TORCH_CHECK(src.dtype() == dst.dtype(), + "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead"); + TORCH_CHECK(src.device() == dst.device(), + "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead"); dst.copy_(src); } void copy_arg(const at::TensorList& dst, const at::TensorList& src) { TORCH_INTERNAL_ASSERT(dst.size() == src.size()); for (const auto& i : c10::irange(dst.size())) { - dst[i].copy_(src[i]); + copy_arg(dst[i], src[i]); } } +// TODO: this doesn't handle restriding empty tensors correctly; see +// gen_resize_out_helper for the correct algorithm + void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) { at::native::resize_output(dst, src.sizes()); } diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 5c8fda81b3d9c..27b9e37596529 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -45,7 +45,7 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/vmap_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/legacy_vmap_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/wrapdim_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp diff --git a/aten/src/ATen/test/vmap_test.cpp b/aten/src/ATen/test/legacy_vmap_test.cpp similarity index 99% rename from aten/src/ATen/test/vmap_test.cpp rename to aten/src/ATen/test/legacy_vmap_test.cpp index 1feafaa59f3a4..5ca827de2d98a 100644 --- a/aten/src/ATen/test/vmap_test.cpp +++ b/aten/src/ATen/test/legacy_vmap_test.cpp @@ -1,8 +1,8 @@ #include #include -#include -#include +#include +#include #include using namespace at; diff --git a/aten/src/ATen/test/math_kernel_test.cpp b/aten/src/ATen/test/math_kernel_test.cpp index 15ce0af4001d5..8875e72a6af9b 100644 --- a/aten/src/ATen/test/math_kernel_test.cpp +++ b/aten/src/ATen/test/math_kernel_test.cpp @@ -114,16 +114,6 @@ TEST(MathKernelTest, MishBackward) { ASSERT_ALLCLOSE_TOLERANCES(out, math_out, 1e-4, 1e-6); } -TEST(MathKernelTest, NarrowCopy) { - auto x = rand({5, 8, 7}); - for (const auto dim : c10::irange(3)) { - const int64_t start = 1, length = 4; - auto y_ref = x.narrow(dim, start, length); - auto y_test = at::native::narrow_copy_dense(x, dim, start, length); - ASSERT_ALLCLOSE_TOLERANCES(y_ref, y_test, 0, 0); - } -} - TEST(MathKernelTest, Bmm) { auto test_bmm = [](int64_t last_dim) { auto x = rand({1, 4, 4}, at::kFloat); diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index bd9e84bc23554..b6762e1739458 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -194,34 +194,3 @@ TEST(TestScalar, TestFormatting) { ASSERT_EQ("(2,3.1)", format(Scalar(c10::complex(2.0, 3.1)))); ASSERT_EQ("4", format(Scalar(Scalar(4).toSymInt()))); } - -TEST(TestSymInt, Basic) { - Scalar foo; - auto a_impl = c10::make_intrusive(); - foo = Scalar(a_impl->toSymInt()); - ASSERT_EQ(a_impl.use_count(), 2); - Scalar bar{foo}; - ASSERT_EQ(a_impl.use_count(), 3); - auto baz = bar; - ASSERT_EQ(a_impl.use_count(), 4); - auto foo2 = std::move(bar); - ASSERT_EQ(a_impl.use_count(), 4); - ASSERT_TRUE(foo2.isSymInt()); - // NOLINTNEXTLINE(bugprone-use-after-move,clang-analyzer-cplusplus.Move) - ASSERT_TRUE(bar.isIntegral(false)); - foo2 = SymInt(4); - ASSERT_FALSE(foo2.isSymInt()); - ASSERT_EQ(foo2.toSymInt().expect_int(), 4); - // NOLINTNEXTLINE(clang-diagnostic-self-assign-overloaded) - foo2 = foo2; - ASSERT_FALSE(foo2.isSymInt()); - ASSERT_EQ(foo2.toSymInt().expect_int(), 4); - - ASSERT_EQ(a_impl.use_count(), 3); - - ASSERT_THROW(foo.to(), c10::Error); - - Scalar int_s = 3; - TORCH_CHECK(int_s.toSymInt().expect_int(), 3); - -} diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index a0f00daed5742..39edf4ae3a8c2 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include // TODO: These functions should move to a common place. @@ -251,6 +252,7 @@ class VulkanAPITest : public ::testing::Test { }; TEST_F(VulkanAPITest, copy_to_texture) { + using namespace at::native::vulkan; at::Tensor test_tensors[] = { // 4D at::rand({7, 17, 134, 213}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), @@ -272,6 +274,8 @@ TEST_F(VulkanAPITest, copy_to_texture) { std::cout << "Copy failed on size " << in_cpu.sizes() << "with dtype" << in_cpu.dtype() << std::endl; } + + ASSERT_TRUE(check_copy); } } @@ -629,7 +633,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -643,7 +647,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -657,7 +661,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({7}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -671,7 +675,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -685,7 +689,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -699,7 +703,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -713,7 +717,7 @@ TEST_F(VulkanAPITest, batch_norm_invalid_inputs) { at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({8}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), at::rand({12}, at::device(at::kCPU).dtype(at::kFloat)).vulkan(), - true, + false, 0.1, 1e-05, false); @@ -844,6 +848,130 @@ TEST_F(VulkanAPITest, clamp_) { ASSERT_TRUE(check); } +void test_conv2d_context( + const at::IntArrayRef input_shape, + const at::IntArrayRef weight_shape, + const at::IntArrayRef bias_shape, + std::vector stride, + std::vector padding, + std::vector dilation, + int64_t groups) { + c10::InferenceMode mode; + + at::Tensor input = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weight = at::rand(weight_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor bias = at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)); + + // cpu + const auto out_cpu = at::conv2d( + input, weight, bias, stride, padding, dilation, groups); + + // vulkan + const auto prepack_vulkan = callOpByName( + "vulkan_prepack::create_conv2d_context", + "", + weight, bias, stride, padding, dilation, groups, c10::nullopt, c10::nullopt); + + const auto vulkan_output = callOpByName( + "vulkan_prepack::run_conv2d_context", + "", + input.vulkan(), prepack_vulkan[0]); + + const auto out_vulkan = vulkan_output[0].toTensor(); + const auto out_vk_cpu = out_vulkan.cpu(); + + // check + const bool check = almostEqual(out_cpu, out_vk_cpu); + if (!check) { + showRtol(out_cpu, out_vk_cpu); + } + + ASSERT_TRUE(check); +} + +void test_backwards_compatible_conv2d_context( + const at::IntArrayRef input_shape, + const at::IntArrayRef weight_shape, + const at::IntArrayRef bias_shape, + std::vector stride, + std::vector padding, + std::vector dilation, + int64_t groups) { + c10::InferenceMode mode; + + at::Tensor input = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weight = at::rand(weight_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor bias = at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)); + + // cpu + const auto out_cpu = at::conv2d( + input, weight, bias, stride, padding, dilation, groups); + + // vulkan + const auto prepack_vulkan = callOpByName( + "vulkan_prepack::conv2d_clamp_prepack", + "", + weight, bias, stride, padding, dilation, groups, c10::nullopt, c10::nullopt); + + const auto vulkan_output = callOpByName( + "vulkan_prepack::conv2d_clamp_run", + "", + input.vulkan(), prepack_vulkan[0]); + + const auto out_vulkan = vulkan_output[0].toTensor(); + const auto out_vk_cpu = out_vulkan.cpu(); + + // check + const bool check = almostEqual(out_cpu, out_vk_cpu); + if (!check) { + showRtol(out_cpu, out_vk_cpu); + } + + ASSERT_TRUE(check); +} + +void test_transposed_conv2d_context( + const at::IntArrayRef input_shape, + const at::IntArrayRef weight_shape, + const at::IntArrayRef bias_shape, + std::vector stride, + std::vector padding, + std::vector output_padding, + std::vector dilation, + int64_t groups) { + c10::InferenceMode mode; + + at::Tensor input = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weight = at::rand(weight_shape, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor bias = at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)); + + // cpu + const auto out_cpu = at::conv_transpose2d( + input, weight, bias, stride, padding, output_padding, groups, dilation); + + // vulkan + const auto prepack_vulkan = callOpByName( + "vulkan_prepack::create_tconv2d_context", + "", + weight, bias, stride, padding, output_padding, dilation, groups, c10::nullopt, c10::nullopt); + + const auto vulkan_output = callOpByName( + "vulkan_prepack::run_tconv2d_context", + "", + input.vulkan(), prepack_vulkan[0]); + + const auto out_vulkan = vulkan_output[0].toTensor(); + const auto out_vk_cpu = out_vulkan.cpu(); + + // check + const bool check = almostEqual(out_cpu, out_vk_cpu); + if (!check) { + showRtol(out_cpu, out_vk_cpu); + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, conv2d) { constexpr int64_t groups = 1; constexpr std::array stride{2, 2}; @@ -913,6 +1041,158 @@ TEST_F(VulkanAPITest, conv2d) { ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, conv2d_prepack) { + test_conv2d_context( + {1, 3, 8, 8}, // input_shape + {1, 3, 3, 3}, // weight_shape + {1}, // bias_shape + {2, 2}, // stride + {1, 1}, // padding + {1, 1}, // dilation + 1); // groups +} + +TEST_F(VulkanAPITest, conv2d_prepack_bc) { + test_backwards_compatible_conv2d_context( + {1, 3, 8, 8}, // input_shape + {1, 3, 3, 3}, // weight_shape + {1}, // bias_shape + {2, 2}, // stride + {1, 1}, // padding + {1, 1}, // dilation + 1); // groups +} + +TEST_F(VulkanAPITest, conv2d_dw_3x3) { + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{1, groups, 137, 199}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{groups, 1, 3, 3}; + + const auto input_cpu = + at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::rand( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto output_cpu = at::conv2d( + input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups); + + const auto output_vulkan = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const bool check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d_dw_5x5) { + constexpr int64_t groups = 7; + constexpr std::array stride{2, 3}; + constexpr std::array padding{0, 4}; + constexpr std::array dilation{3, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + batches, + channels, + width, + height, + }; + } + } input{1, groups, 137, 199}; + + constexpr struct { + uint32_t output_channels; + uint32_t input_channels; + uint32_t width; + uint32_t height; + + std::array size() const { + return { + output_channels, + input_channels, + width, + height, + }; + } + } weights{groups, 1, 5, 5}; + + const auto input_cpu = + at::rand(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = + at::rand(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::rand( + {weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + const auto output_cpu = at::conv2d( + input_cpu, weights_cpu, bias_cpu, stride, padding, dilation, groups); + + const auto output_vulkan = at::conv2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + dilation, + groups); + + const bool check = almostEqual(output_cpu, output_vulkan.cpu()); + if (!check) { + showRtol(output_cpu, output_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, conv2d_dw) { constexpr int64_t groups = 7; constexpr std::array stride{2, 3}; @@ -981,6 +1261,28 @@ TEST_F(VulkanAPITest, conv2d_dw) { ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, conv2d_dw_prepack) { + test_conv2d_context( + {1, 7, 137, 199}, // input_shape + {7, 1, 17, 7}, // weight_shape + {7}, // bias_shape + {2, 3}, // stride + {0, 4}, // padding + {3, 1}, // dilation + 7); // groups +} + +TEST_F(VulkanAPITest, conv2d_dw_prepack_bc) { + test_backwards_compatible_conv2d_context( + {1, 7, 137, 199}, // input_shape + {7, 1, 17, 7}, // weight_shape + {7}, // bias_shape + {2, 3}, // stride + {0, 4}, // padding + {3, 1}, // dilation + 7); // groups +} + TEST_F(VulkanAPITest, conv2d_pw) { constexpr int64_t groups = 1; constexpr std::array stride{1, 1}; @@ -1049,6 +1351,115 @@ TEST_F(VulkanAPITest, conv2d_pw) { ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, conv2d_pw_prepack) { + test_conv2d_context( + {1, 17, 127, 397}, // input_shape + {29, 17, 1, 1}, // weight_shape + {29}, // bias_shape + {1, 1}, // stride + {0, 0}, // padding + {1, 1}, // dilation + 1); // groups +} + +TEST_F(VulkanAPITest, conv2d_pw_prepack_bc) { + test_backwards_compatible_conv2d_context( + {1, 17, 127, 397}, // input_shape + {29, 17, 1, 1}, // weight_shape + {29}, // bias_shape + {1, 1}, // stride + {0, 0}, // padding + {1, 1}, // dilation + 1); // groups +} + +TEST_F(VulkanAPITest, conv2d_transposed) { + // Arrange + constexpr int64_t groups = 1; + constexpr std::array stride{1, 2}; + constexpr std::array padding{1, 0}; + constexpr std::array output_padding{0, 1}; + //TODO: Support conv_transpose2d with dilation != 1 + constexpr std::array dilation{1, 1}; + + constexpr struct { + uint32_t batches; + uint32_t channels; + uint32_t height; + uint32_t width; + + std::array size() const { + return { + batches, + channels, + height, + width, + }; + } + } input {1, 55, 7, 19}; + + constexpr struct { + uint32_t input_channels; + uint32_t output_channels; + uint32_t height; + uint32_t width; + + std::array size() const { + return { + input_channels, + output_channels, + height, + width, + }; + } + } weights {input.channels, 47, 2, 3}; + + const auto input_cpu = at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto weights_cpu = at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); + const auto bias_cpu = at::zeros({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto output_cpu = at::conv_transpose2d( + input_cpu, + weights_cpu, + bias_cpu, + stride, + padding, + output_padding, + groups, + dilation); + + const auto output_vk = at::conv_transpose2d( + input_cpu.vulkan(), + weights_cpu, + bias_cpu, + stride, + padding, + output_padding, + groups, + dilation).cpu(); + + // Assert + const bool check = almostEqual(output_cpu, output_vk); + if (!check) { + showRtol(output_cpu, output_vk); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d_transposed_prepack) { + test_transposed_conv2d_context( + {1, 55, 7, 19}, // input_shape + {55, 47, 2, 3}, // weight_shape + {47}, // bias_shape + {1, 2}, // stride + {1, 0}, // padding + {0, 1}, // output_padding + {1, 1}, // dilation + 1); // groups +} + TEST_F(VulkanAPITest, copy) { const auto cpu = at::rand({13, 17, 37, 19}, at::device(at::kCPU).dtype(at::kFloat)); const auto vulkan = cpu.vulkan(); @@ -1445,6 +1856,36 @@ TEST_F(VulkanAPITest, hardshrink_) { } } +TEST_F(VulkanAPITest, hardtanh) { + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) * 10; + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::hardtanh(in_cpu, 3, 7); + const auto out_vulkan = at::hardtanh(in_vulkan, 3, 7); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, hardtanh_) { + auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)) * 10; + auto a_vulkan = a_cpu.vulkan(); + + at::hardtanh_(a_cpu, 3, 7); + at::hardtanh_(a_vulkan, 3, 7); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + if (!check) { + showRtol(a_cpu, a_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, layer_norm_invalid_inputs) { c10::InferenceMode mode; @@ -2229,6 +2670,38 @@ TEST_F(VulkanAPITest, mul_to_scalar_wrapped) { ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, relu) { + const auto in_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_vulkan = in_cpu.vulkan(); + + const auto out_cpu = at::relu(in_cpu); + const auto out_vulkan = at::relu(in_vulkan); + + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, relu_) { + auto a_cpu = at::rand({17, 197, 302, 5}, at::device(at::kCPU).dtype(at::kFloat)); + auto a_vulkan = a_cpu.vulkan(); + + at::relu_(a_cpu); + at::relu_(a_vulkan); + + const auto check = almostEqual(a_cpu, a_vulkan.cpu()); + + if (!check) { + showRtol(a_cpu, a_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, reflection_pad2d) { const auto a_cpu = at::rand({2, 3, 47, 63}, at::device(at::kCPU).dtype(at::kFloat)); const auto a_vulkan = a_cpu.vulkan(); @@ -2702,81 +3175,6 @@ TEST_F(VulkanAPITest, sub_to_scalar_wrapped) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, transposed_conv2d) { - // Arrange - constexpr int64_t groups = 1; - constexpr std::array stride{1, 2}; - constexpr std::array padding{1, 0}; - constexpr std::array output_padding{0, 1}; - //TODO: Support conv_transpose2d with dilation != 1 - constexpr std::array dilation{1, 1}; - - constexpr struct { - uint32_t batches; - uint32_t channels; - uint32_t height; - uint32_t width; - - std::array size() const { - return { - batches, - channels, - height, - width, - }; - } - } input {1, 55, 7, 19}; - - constexpr struct { - uint32_t input_channels; - uint32_t output_channels; - uint32_t height; - uint32_t width; - - std::array size() const { - return { - input_channels, - output_channels, - height, - width, - }; - } - } weights {input.channels, 47, 2, 3}; - - const auto input_cpu = at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); - const auto weights_cpu = at::randn(weights.size(), at::device(at::kCPU).dtype(at::kFloat)); - const auto bias_cpu = at::zeros({weights.output_channels}, at::device(at::kCPU).dtype(at::kFloat)); - - // Act - const auto output_cpu = at::conv_transpose2d( - input_cpu, - weights_cpu, - bias_cpu, - stride, - padding, - output_padding, - groups, - dilation); - - const auto output_vk = at::conv_transpose2d( - input_cpu.vulkan(), - weights_cpu, - bias_cpu, - stride, - padding, - output_padding, - groups, - dilation).cpu(); - - // Assert - const bool check = almostEqual(output_cpu, output_vk); - if (!check) { - showRtol(output_cpu, output_vk); - } - - ASSERT_TRUE(check); -} - TEST_F(VulkanAPITest, upsample_nearest2d) { const auto in_cpu = at::rand({1, 2, 2, 3}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); const auto out_cpu = at::upsample_nearest2d(in_cpu, {4, 6}); @@ -2884,6 +3282,44 @@ TEST_F(VulkanAPITest, view_invalid_inputs) { }, ::std::runtime_error); } +TEST_F(VulkanAPITest, cat_dim0_invalidinputs_exceptions) { + // Arrange: Vulkan cat inputs must have matching sizes except concatenated dimension + { + const auto in_cpu1 = at::rand({3, 5, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({3, 9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 331, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 0); + }, ::c10::Error); + } + + // Arrange: Vulkan cat expects 4 dimensional inputs + { + const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 331, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 0); + }, ::c10::Error); + } + + // Arrange: Vulkan cat not implemented for batch dimension! + { + const auto in_cpu1 = at::rand({221, 3, 9, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({112, 3, 9, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({331, 3, 9, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 0); + }, ::c10::Error); + } +} + #if !defined(__APPLE__) TEST_F(VulkanAPITest, DISABLED_cat_dim1_samefeature_success) { // Arrange @@ -3112,6 +3548,25 @@ TEST_F(VulkanAPITest, cat_dim2_diffheight_success) { ASSERT_TRUE(check); } +TEST_F(VulkanAPITest, cat_dim2_negdim_success) { + // Arrange + const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({3, 9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 331, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + const auto out_cpu = at::cat({in_cpu1, in_cpu2, in_cpu3}, -2); + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, -2); + + // Assert + const auto check = almostEqual(out_cpu, out_vulkan.cpu()); + if (!check) { + showRtol(out_cpu, out_vulkan.cpu()); + } + + ASSERT_TRUE(check); +} + TEST_F(VulkanAPITest, cat_dim2_singledepth_success) { // Arrange: batch x channel (1x1) = single depth texture const auto in_cpu1 = at::rand({1, 1, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); @@ -3175,6 +3630,44 @@ TEST_F(VulkanAPITest, cat_dim2_invalidinputs_exceptions) { } } +TEST_F(VulkanAPITest, cat_dim3_invalidinputs_exceptions) { + // Arrange: Vulkan cat inputs must have matching sizes except concatenated dimension + { + const auto in_cpu1 = at::rand({3, 5, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({3, 9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 331, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 3); + }, ::c10::Error); + } + + // Arrange: Vulkan cat expects 4 dimensional inputs + { + const auto in_cpu1 = at::rand({3, 9, 221, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({9, 112, 193}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 331, 193}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 3); + }, ::c10::Error); + } + + // Arrange: Vulkan cat not implemented for width dimension! + { + const auto in_cpu1 = at::rand({3, 9, 193, 221}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu2 = at::rand({3, 9, 193, 112}, at::device(at::kCPU).dtype(at::kFloat)); + const auto in_cpu3 = at::rand({3, 9, 193, 331}, at::device(at::kCPU).dtype(at::kFloat)); + + // Act + EXPECT_THROW({ + const auto out_vulkan = at::cat({in_cpu1.vulkan(), in_cpu2.vulkan(), in_cpu3.vulkan()}, 3); + }, ::c10::Error); + } +} + TEST_F(VulkanAPITest, permute_2d_success) { // Arrange const auto in_cpu = at::rand({2, 3}, at::device(at::kCPU).dtype(at::kFloat)); diff --git a/aten/src/ATen/test/vulkan_perf_test.cpp b/aten/src/ATen/test/vulkan_perf_test.cpp index 0c1c6b9cfe378..51cce68d4c4db 100644 --- a/aten/src/ATen/test/vulkan_perf_test.cpp +++ b/aten/src/ATen/test/vulkan_perf_test.cpp @@ -1,3 +1,4 @@ +#include #ifdef USE_VULKAN_API #include @@ -10,8 +11,46 @@ #include #include +#include + namespace { +namespace vulkan_api = at::native::vulkan::api; +void report_pep(const std::string& name, const uint64_t duration) { + std::stringstream buffer; + buffer << "PyTorchObserver {\"type\": \""; + buffer << name << "\","; + buffer << "\"unit\": \"" + << "ns" + << "\"," + << "\"metric\": \"" + << "latency" + << "\","; + buffer << "\"value\": \"" << duration << "\""; + buffer << "}\n"; + std::cout << buffer.str(); +} + +void report_aibench_res(vulkan_api::QueryPool& qpool) { + std::unordered_map shader_runtimes; + uint64_t num_additions = 0; + auto result_aggregator = + [&shader_runtimes, &num_additions](const vulkan_api::ShaderDuration& s) { + if (shader_runtimes.count(s.kernel_name) == 0) { + shader_runtimes[s.kernel_name] = 0; + } + shader_runtimes[s.kernel_name] += s.execution_duration_ns; + num_additions += 1; + }; + qpool.shader_log_for_each(result_aggregator); + uint64_t num_iters = num_additions / shader_runtimes.size(); + for (const auto& i : shader_runtimes) { + const auto& name = i.first; + const auto& duration = i.second / num_iters; + report_pep(name, duration); + } +} + at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) { auto q_options = in_cpu.options(); if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8) { @@ -536,10 +575,10 @@ static void conv2ddw_op_benchmark(benchmark::State& state) { const auto batches_in = safe_downcast(state.range(0)); const auto height_in = safe_downcast(state.range(2)); const auto width_in = safe_downcast(state.range(3)); - constexpr int64_t groups = 7; - constexpr std::array stride{2, 3}; - constexpr std::array padding{0, 4}; - constexpr std::array dilation{3, 1}; + constexpr int64_t groups = 32; + constexpr std::array stride{1, 1}; + constexpr std::array padding{0, 0}; + constexpr std::array dilation{1, 1}; struct { uint32_t batches; @@ -571,7 +610,7 @@ static void conv2ddw_op_benchmark(benchmark::State& state) { height, }; } - } weights{groups, 1, 17, 7}; + } weights{groups, 1, 3, 3}; const auto input_cpu = at::randn(input.size(), at::device(at::kCPU).dtype(at::kFloat)); @@ -606,6 +645,7 @@ static void conv2ddw_op_benchmark(benchmark::State& state) { #if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__) at::native::vulkan::api::context()->querypool().extract_results(); at::native::vulkan::api::context()->querypool().print_results(); + report_aibench_res(vulkan_api::context()->querypool()); state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("conv2d_dw") / 1000000.0); #endif } @@ -1053,7 +1093,7 @@ BENCHMARK(conv2ddw_op_benchmark) ->UseManualTime() ->Threads(1) ->Iterations(10) - ->Args({1, 7, 137, 199}); + ->Args({1, 32, 256, 256}); BENCHMARK(conv2ddw_op_q_benchmark) ->Apply(CommonBenchmarkSettings) ->UseManualTime() diff --git a/aten/src/ATen/test/vulkan_quantized_api_test.cpp b/aten/src/ATen/test/vulkan_quantized_api_test.cpp index 50cceafdb5ff2..89c205ef3d7d4 100644 --- a/aten/src/ATen/test/vulkan_quantized_api_test.cpp +++ b/aten/src/ATen/test/vulkan_quantized_api_test.cpp @@ -9,12 +9,24 @@ #include #include #include +#include +#include #include +/* + * TODO: rename this file to something like vulkan_experimental_test and move + * this under caffe2/fb/vulkan. This file should be used to test experimental + * features of the Vulkan backend. vulkan_api_test cannot serve this purpose + * because it cannot link against symbols in the ATen/native/vulkan folder. + */ + namespace { -bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { +bool checkRtol( + const at::Tensor& diff, + const std::vector& inputs, + const float tolerated_error = 0) { float maxValue = 0.0f; for (const auto& tensor : inputs) { @@ -27,11 +39,11 @@ bool checkRtol(const at::Tensor& diff, const std::vector& inputs) { constexpr float tolerance = 1e-5; #endif - return diff.abs().max().item() <= (tolerance * maxValue); + return diff.abs().max().item() <= (tolerance * maxValue + tolerated_error); } -bool almostEqual(const at::Tensor& a, const at::Tensor& b) { - return checkRtol(a - b, {a, b}); +bool almostEqual(const at::Tensor& a, const at::Tensor& b, const float tolerated_error = 0) { + return checkRtol(a - b, {a, b}, tolerated_error); } /* Unused function @@ -99,6 +111,93 @@ inline std::vector callOpByName( namespace { +double rand01() { + return (double)rand() / (double)RAND_MAX; +} + +int64_t rand_pos_int(const int max_val) { + TORCH_CHECK(max_val > 0, "max value must be positive"); + return 1 + rand() % max_val; +} + +at::Tensor produce_random_tensor( + const at::IntArrayRef tensor_shape, + const float s_min = 1.0, + const float s_max = 100.0, + const float shift = 0.45) { + // tensor is randomly generated with values in the range + // [-shift * s, (1-shift) * s), where s is randomly generated in the range + // [s_min, s_max] + // with these default values, s is randomly generated in the range [1, 100] + // this means that the range of the tensor values could be as narrow as + // [-0.45, 0.55) or as wide as [-45.0, 55.0) + TORCH_CHECK(s_min > 0, "scalar lower bound must be positive"); + TORCH_CHECK(s_min <= s_max, "scalar lower bound must be <= upper bound"); + const auto scalar = s_min + (s_max - s_min) * (float)rand()/(float)RAND_MAX; + return scalar * + (at::rand(tensor_shape, at::device(at::kCPU).dtype(at::kFloat)) - shift); +} + +double produce_random_scale( + const double scale_min = 0.001, + const double scale_max = 2.0) { + TORCH_CHECK(scale_min <= scale_max, "scale min must be <= scale max"); + // scale is randomly generated in the range [scale_min, scale_max) + return rand01() * (scale_max - scale_min) + scale_min; +} + +int64_t produce_random_zero_point(const c10::ScalarType dtype) { + int64_t zero_point; + switch (dtype) { + case c10::ScalarType::QUInt8: + zero_point = rand() % 256; + break; + case c10::ScalarType::QInt8: + zero_point = rand() % 256 - 128; + break; + case c10::ScalarType::QInt32: + zero_point = rand() % 100000 - 200000; + break; + default: + TORCH_CHECK( + false, "Vulkan quantization currently not supported for dtype ", dtype + ); + } + return zero_point; +} + +std::tuple compute_quant_params( + const at::Tensor tensor, + const c10::ScalarType dtype = c10::ScalarType::QUInt8) { + int zero_point_min; + int zero_point_max; + if (dtype == c10::ScalarType::QUInt8) { + zero_point_min = 0; + zero_point_max = 255; + } else if (dtype == c10::ScalarType::QInt8) { + zero_point_min = -128; + zero_point_max = 127; + } else { + TORCH_CHECK(false, "Computation of quant params only available for dtypes", + "QUInt8 and QInt8"); + } + const auto tensor_max = tensor.max().item(); + const auto tensor_min = tensor.min().item(); + auto q_params = quant_utils::ChooseQuantizationParams( + /*min=*/tensor_min, + /*max=*/tensor_max, + /*qmin=*/zero_point_min, + /*qmax=*/zero_point_max, + /*preserve_sparsity=*/false, + /*force_scale_power_of_two=*/false, + /*reduce_range=*/false); + return std::tuple(q_params.scale, q_params.zero_point); +} + +} // namespace + +namespace { + class VulkanAPITest : public ::testing::Test { public: void SetUp() { @@ -125,10 +224,12 @@ class VulkanAPITest : public ::testing::Test { at::Tensor cpu_to_vulkan(at::Tensor in_cpu) { auto options = in_cpu.options(); - if (options.dtype().toScalarType() == c10::ScalarType::QUInt8) { + if (options.dtype().toScalarType() == c10::ScalarType::QUInt8 || + options.dtype().toScalarType() == c10::ScalarType::QInt8 || + options.dtype().toScalarType() == c10::ScalarType::QInt32) { auto ret = at::native::vulkan::ops::_empty_affine_quantized( in_cpu.sizes(), - c10::ScalarType::QUInt8, + options.dtype().toScalarType(), options.layout(), options.device(), options.pinned_memory(), @@ -146,7 +247,9 @@ at::Tensor cpu_to_vulkan(at::Tensor in_cpu) { at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) { auto q_options = in_cpu.options(); - if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8) { + if (q_options.dtype().toScalarType() == c10::ScalarType::QUInt8 || + q_options.dtype().toScalarType() == c10::ScalarType::QInt8 || + q_options.dtype().toScalarType() == c10::ScalarType::QInt32) { auto output = at::native::empty_affine_quantized( in_cpu.sizes(), q_options.dtype().toScalarType(), @@ -164,7 +267,87 @@ at::Tensor vulkan_to_cpu(at::Tensor vulkan, at::Tensor in_cpu) { } } -TEST_F(VulkanAPITest, support_vulkan) { +TEST_F(VulkanAPITest, uniform_buffer_copy) { + using namespace at::native::vulkan; + + struct TestStruct{ + int a; + int b; + int c; + }; + + TestStruct test_struct{4, 9, 10}; + + api::UniformParamsBuffer params(api::context(), test_struct); + api::UniformParamsBuffer params_copy = params; + + api::MemoryMap copy_mapping( + params_copy.buffer(), api::MemoryAccessType::READ); + + TestStruct* test_copy_p = copy_mapping.template data(); + + ASSERT_TRUE(test_copy_p->a == test_struct.a); + ASSERT_TRUE(test_copy_p->b == test_struct.b); + ASSERT_TRUE(test_copy_p->c == test_struct.c); +} + +TEST_F(VulkanAPITest, copy_to_buffer) { + using namespace at::native::vulkan; + + at::Tensor test_tensors[] = { + // 4D + at::rand({7, 17, 134, 213}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), + // 3D + at::rand({67, 134, 213}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), + // 2D + at::rand({229, 213}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), + // 1D + at::rand({1902}, at::TensorOptions(at::kCPU).dtype(at::kFloat)), + }; + + for (auto in_cpu : test_tensors) { + ops::vTensor in_vk_copied = ops::to_vulkan(in_cpu, api::StorageType::BUFFER); + at::Tensor out_copied = ops::from_vulkan(in_vk_copied); + + const auto check_copy = almostEqual(out_copied, in_cpu); + + if(!check_copy) { + std::cout << "Copy failed on size " << in_cpu.sizes() + << "with dtype" << in_cpu.dtype() << std::endl; + } + + ASSERT_TRUE(check_copy); + } +} + +TEST_F(VulkanAPITest, copy_to_buffer_channels_last) { + using namespace at::native::vulkan; + + at::TensorOptions options(at::kCPU); + options = options.dtype(at::kFloat); + + at::Tensor test_tensors[] = { + // 4D + at::rand({7, 17, 134, 213}, options).to(at::MemoryFormat::ChannelsLast), + }; + + for (auto in_cpu : test_tensors) { + ops::vTensor in_vk_copied = ops::to_vulkan(in_cpu, api::StorageType::BUFFER); + at::Tensor out_copied = ops::from_vulkan(in_vk_copied); + + const auto check_copy = almostEqual(out_copied, in_cpu); + + if(!check_copy) { + std::cout << "Copy failed on size " << in_cpu.sizes() + << "with dtype" << in_cpu.dtype() << std::endl; + } + + ASSERT_TRUE(check_copy); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_support_vulkan) { const double scale = 0.1; const int64_t zero_point = 10; @@ -202,7 +385,256 @@ TEST_F(VulkanAPITest, support_vulkan) { ASSERT_TRUE(check); } -TEST_F(VulkanAPITest, quantize_per_tensor) { +void test_cpu_to_vulkan_and_vulkan_to_cpu( + const at::IntArrayRef input_shape, + const double scale, + const int zero_point, + const c10::ScalarType dtype = c10::ScalarType::QUInt8) { + + // produce random quantized cpu tensor + auto in_cpu = produce_random_tensor(input_shape); + auto in_q_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, dtype); + + // copy quantized cpu tensor to vulkan + auto in_q_cpu_vk = cpu_to_vulkan(in_q_cpu); + + // copy quantized vulkan tensor to cpu + auto out_q_cpu = vulkan_to_cpu(in_q_cpu_vk, in_q_cpu); + + // check that the copy equals the original + const auto diff = at::native::int_repr_quantized_cpu(in_q_cpu) + - at::native::int_repr_quantized_cpu(out_q_cpu); + + const int error = diff.abs().max().item(); + + const auto check = (error == 0); + + if (!check) { + std::cout + << "Copy to vulkan and back to cpu failed with input shape: " + << input_shape << " scale: " << scale << " and zero point: " + << zero_point << std::endl; + std::cout << "Error: " << error << std::endl; + } + + ASSERT_TRUE(check); +} + +void test_cpu_to_vulkan_and_vulkan_to_cpu_random( + const c10::ScalarType dtype) { + const double scale = produce_random_scale(); + const int64_t zero_point = produce_random_zero_point(dtype); + const at::IntArrayRef tensor_shape = + {rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)}; + test_cpu_to_vulkan_and_vulkan_to_cpu( + tensor_shape, scale, zero_point, dtype); +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_quint8) { + const c10::ScalarType dtype = c10::ScalarType::QUInt8; + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, 21, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 4}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 4, 1}, 0.2, 120, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 7, 7}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 8, 8}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 8, 8}, 0.04, 97, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 11, 17}, 0.07, 15, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({2, 4, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, 43, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_qint8) { + const c10::ScalarType dtype = c10::ScalarType::QInt8; + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, -21, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 4}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 4, 1}, 0.2, -120, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 7, 7}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 8, 8}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 8, 8}, 0.04, 97, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 11, 17}, 0.07, -15, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 12, 17}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({2, 4, 17, 12}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, -43, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1, -19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 25, 29}, 0.1, -19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_cpu_to_vulkan_and_vulkan_to_cpu_qint32) { + const c10::ScalarType dtype = c10::ScalarType::QInt32; + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, -21123, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 1, 4}, 0.339, 8734, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 4, 1}, 0.228, -12023, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 7, 7}, 0.338, 8723, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 8, 8}, 0.193, -1023, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 8, 8}, 0.0449, 972, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 11, 17}, 0.073, -15, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1572, 102, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 12, 17}, 0.147, -156, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 17, 12}, 0.129, 10448, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({2, 4, 17, 12}, 0.137, -10, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, -43267, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1243, 19, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1889, -19784, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1345, 196, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({4, 4, 25, 29}, 0.129, -19489, dtype); + test_cpu_to_vulkan_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_vulkan_to_cpu_random(dtype); + } +} + +void test_cpu_to_vulkan_and_dequantize( + const at::IntArrayRef input_shape, + const double scale, + const int zero_point, + const c10::ScalarType dtype = c10::ScalarType::QUInt8) { + + // produce random quantized cpu tensor + auto in_cpu = produce_random_tensor(input_shape); + auto in_q_cpu = at::quantize_per_tensor( + in_cpu, scale, zero_point, dtype); + + // copy quantized cpu tensor to vulkan + auto in_q_cpu_vk = cpu_to_vulkan(in_q_cpu); + + // dequantize tensors + const auto out_cpu_deq = at::dequantize(in_q_cpu); + const auto out_vk_deq = at::dequantize(in_q_cpu_vk); + const auto out_vk_deq_cpu = out_vk_deq.cpu(); + + // check dequantized tensors are equal + const auto check = almostEqual(out_cpu_deq, out_vk_deq_cpu); + + if (!check) { + const auto error = at::abs(out_vk_deq_cpu - out_cpu_deq).max().item(); + std::cout + << "Copy cpu to vulkan and dequantize failed with input shape: " + << input_shape << " scale: " << scale << " and zero point: " + << zero_point << std::endl; + std::cout << "Error: " << error << std::endl; + } + ASSERT_TRUE(check); +} + +void test_cpu_to_vulkan_and_dequantize_random( + const c10::ScalarType dtype) { + const double scale = produce_random_scale(); + const int64_t zero_point = produce_random_zero_point(dtype); + const at::IntArrayRef tensor_shape = + {rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)}; + test_cpu_to_vulkan_and_dequantize( + tensor_shape, scale, zero_point, dtype); +} + +TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_quint8) { + const c10::ScalarType dtype = c10::ScalarType::QUInt8; + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 1}, 0.13, 21, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 4, 1}, 0.2, 120, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 8, 8}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 11, 17}, 0.07, 15, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({2, 4, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 14}, 0.009, 43, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 9, 17}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_dequantize_random(dtype); + } +} + +TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_qint8) { + const c10::ScalarType dtype = c10::ScalarType::QInt8; + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 1}, 0.13, -21, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 4, 1}, 0.2, -120, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 8, 8}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 11, 17}, 0.07, -15, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 12, 17}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype); + test_cpu_to_vulkan_and_dequantize({2, 4, 17, 12}, 0.1, -10, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 14}, 0.009, -43, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 9, 17}, 0.1, -19, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 25, 29}, 0.1, -19, dtype); + test_cpu_to_vulkan_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_dequantize_random(dtype); + } +} + +TEST_F(VulkanAPITest, cpu_to_vulkan_and_dequantize_qint32) { + const c10::ScalarType dtype = c10::ScalarType::QInt32; + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 1}, 0.13, -21123, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 1, 4}, 0.339, 8734, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 4, 1}, 0.228, -12023, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 7, 7}, 0.338, 8723, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 8, 8}, 0.193, -1023, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 8, 8}, 0.0449, 972, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 11, 17}, 0.073, -15, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 12, 17}, 0.1572, 102, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 12, 17}, 0.147, -156, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 17, 12}, 0.129, 10448, dtype); + test_cpu_to_vulkan_and_dequantize({2, 4, 17, 12}, 0.137, -10, dtype); + test_cpu_to_vulkan_and_dequantize({1, 1, 10, 14}, 0.0001, 101, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 14}, 0.009, -43267, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 10, 15}, 0.1243, 19, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 9, 17}, 0.1889, -19784, dtype); + test_cpu_to_vulkan_and_dequantize({3, 5, 25, 29}, 0.1345, 196, dtype); + test_cpu_to_vulkan_and_dequantize({4, 4, 25, 29}, 0.129, -19489, dtype); + test_cpu_to_vulkan_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_cpu_to_vulkan_and_dequantize_random(dtype); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor) { const auto in_cpu = at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; const auto in_vulkan = in_cpu.vulkan(); @@ -230,6 +662,139 @@ TEST_F(VulkanAPITest, quantize_per_tensor) { ASSERT_TRUE(check); } +void test_quantize_per_tensor_and_vulkan_to_cpu( + const at::IntArrayRef input_shape, + const double input_scale, + const int input_zero_point, + const c10::ScalarType dtype = c10::ScalarType::QUInt8, + const int tolerance = 1) { + // tolerance = 1, to allow for precision differences after dividing by random + // scale which could result on a difference of 1 unit in the quantized result + + at::Tensor input = produce_random_tensor(input_shape); + + // quantize tensor + at::Tensor out_q_cpu = at::quantize_per_tensor( + input, input_scale, input_zero_point, dtype); + + at::Tensor out_q_vk = at::quantize_per_tensor( + input.vulkan(), input_scale, input_zero_point, dtype); + + // copy vulkan tensor to cpu + at::Tensor out_q_vk_cpu = vulkan_to_cpu(out_q_vk, out_q_cpu); + + const auto diff = at::native::int_repr_quantized_cpu(out_q_vk_cpu) + - at::native::int_repr_quantized_cpu(out_q_cpu); + + const int error = diff.abs().max().item(); + + const auto check = (error <= tolerance); + + if (!check) { + std::cout + << "Quantize and copy to cpu failed with input shape: " << input_shape + << " scale: " << input_scale << " and zero point: " << input_zero_point + << std::endl; + std::cout << "Error: " << error << std::endl; + } + + ASSERT_TRUE(check); +} + +void test_quantize_per_tensor_and_vulkan_to_cpu_random( + const c10::ScalarType dtype) { + const double scale = produce_random_scale(); + const int64_t zero_point = produce_random_zero_point(dtype); + const at::IntArrayRef tensor_shape = + {rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)}; + test_quantize_per_tensor_and_vulkan_to_cpu( + tensor_shape, scale, zero_point, dtype); +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_quint8) { + const c10::ScalarType dtype = c10::ScalarType::QUInt8; + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, 21, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 4}, 0.3, 87, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 4, 1}, 0.2, 120, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 7, 7}, 0.3, 87, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 8, 8}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 8, 8}, 0.04, 97, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 11, 17}, 0.07, 15, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({2, 4, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, 43, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 16, 77, 54}, 0.204173, 229, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_qint8) { + const c10::ScalarType dtype = c10::ScalarType::QInt8; + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, -21, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 4}, 0.3, 87, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 4, 1}, 0.2, -120, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 7, 7}, 0.3, 87, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 8, 8}, 0.1, -10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 8, 8}, 0.04, 97, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 11, 17}, 0.07, -15, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 12, 17}, 0.1, -10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({2, 4, 17, 12}, 0.1, -10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, -43, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1, -19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 25, 29}, 0.1, -19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 16, 77, 54}, 0.204173, 229, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype); + } +} + +// TODO: Fix vulkan to cpu on Android +TEST_F(VulkanAPITest, DISABLED_quantize_per_tensor_and_vulkan_to_cpu_qint32) { + const c10::ScalarType dtype = c10::ScalarType::QInt32; + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 1}, 0.13, -21123, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 1, 4}, 0.339, 8734, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 4, 1}, 0.228, -12023, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 7, 7}, 0.338, 8723, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 8, 8}, 0.193, -1023, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 8, 8}, 0.0449, 972, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 11, 17}, 0.073, -15, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 12, 17}, 0.1572, 102, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 12, 17}, 0.147, -156, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 17, 12}, 0.129, 10448, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({2, 4, 17, 12}, 0.137, -10, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({1, 1, 10, 14}, 0.0001, 101, dtype, 1); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 14}, 0.009, -43267, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 10, 15}, 0.1243, 19, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 9, 17}, 0.1889, -19784, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 5, 25, 29}, 0.1345, 196, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({4, 4, 25, 29}, 0.129, -19489, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({11, 17, 25, 29}, 0.027, 89, dtype); + test_quantize_per_tensor_and_vulkan_to_cpu({3, 16, 77, 54}, 0.204173, 229, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_vulkan_to_cpu_random(dtype); + } +} + TEST_F(VulkanAPITest, quantize_dequantize) { const auto in_cpu = at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; @@ -268,6 +833,130 @@ TEST_F(VulkanAPITest, quantize_dequantize) { ASSERT_TRUE(check_two); } +void test_quantize_per_tensor_and_dequantize( + const at::IntArrayRef input_shape, + const double input_scale, + const int input_zero_point, + const c10::ScalarType dtype = c10::ScalarType::QUInt8) { + at::Tensor input = produce_random_tensor(input_shape); + + // quantize tensors + at::Tensor out_q_cpu = at::quantize_per_tensor( + input, input_scale, input_zero_point, dtype); + at::Tensor out_q_vk = at::quantize_per_tensor( + input.vulkan(), input_scale, input_zero_point, dtype); + + // dequantize tensors + const auto out_cpu_deq = at::dequantize(out_q_cpu); + const auto out_vk_deq = at::dequantize(out_q_vk); + const auto out_vk_deq_cpu = out_vk_deq.cpu(); + + // check dequantized tensor are equal + const float tolerance = input_scale; + // tolerated error = scale, to allow for precision differences after dividing + // by random scale, which could result on a difference of 1 unit in the + // quantized result. + const auto check = almostEqual(out_cpu_deq, out_vk_deq_cpu, tolerance); + + if (!check) { + const auto error = at::abs(out_vk_deq_cpu - out_cpu_deq).max().item(); + std::cout + << "Quantize and Dequantize failed with input shape: " << input_shape + << " scale: " << input_scale << " and zero point: " << input_zero_point + << std::endl; + std::cout << "Error: " << error << std::endl; + } + ASSERT_TRUE(check); +} + +void test_quantize_per_tensor_and_dequantize_random( + const c10::ScalarType dtype) { + const double scale = produce_random_scale(); + const int64_t zero_point = produce_random_zero_point(dtype); + const at::IntArrayRef tensor_shape = + {rand_pos_int(30), rand_pos_int(30), rand_pos_int(100), rand_pos_int(100)}; + test_quantize_per_tensor_and_dequantize( + tensor_shape, scale, zero_point, dtype); +} + +TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_quint8) { + const c10::ScalarType dtype = c10::ScalarType::QUInt8; + test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, 21, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.2, 120, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.07, 15, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, 43, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_dequantize_random(dtype); + } +} + +TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint8) { + const c10::ScalarType dtype = c10::ScalarType::QInt8; + test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, -21, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.3, 87, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.2, -120, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.3, 87, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.1, -10, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.04, 97, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.07, -15, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.1, -10, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.1, 10, dtype); + test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.1, -10, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, -43, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1, -19, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1, 19, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.1, -19, dtype); + test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_dequantize_random(dtype); + } +} + +TEST_F(VulkanAPITest, quantize_per_tensor_and_dequantize_qint32) { + const c10::ScalarType dtype = c10::ScalarType::QInt32; + test_quantize_per_tensor_and_dequantize({1, 1, 1, 1}, 0.13, -21123, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 1, 4}, 0.339, 8734, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 4, 1}, 0.228, -12023, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 7, 7}, 0.338, 8723, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 8, 8}, 0.193, -1023, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 8, 8}, 0.0449, 972, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 11, 17}, 0.073, -15, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 12, 17}, 0.1572, 102, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 12, 17}, 0.147, -156, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 17, 12}, 0.129, 10448, dtype); + test_quantize_per_tensor_and_dequantize({2, 4, 17, 12}, 0.137, -10, dtype); + test_quantize_per_tensor_and_dequantize({1, 1, 10, 14}, 0.001, 101, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 14}, 0.009, -43267, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 10, 15}, 0.1243, 19, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 9, 17}, 0.1889, -19784, dtype); + test_quantize_per_tensor_and_dequantize({3, 5, 25, 29}, 0.1345, 196, dtype); + test_quantize_per_tensor_and_dequantize({4, 4, 25, 29}, 0.129, -19489, dtype); + test_quantize_per_tensor_and_dequantize({11, 17, 25, 29}, 0.027, 89, dtype); + + for (int i = 0; i < 20; i += 1) { + test_quantize_per_tensor_and_dequantize_random(dtype); + } +} + TEST_F(VulkanAPITest, quantized_add) { const auto in_cpu = at::rand({2, 13, 32, 27}, at::device(at::kCPU).dtype(at::kFloat)) * 6; @@ -462,7 +1151,6 @@ TEST_F(VulkanAPITest, quantized_add_broadcast2) { ASSERT_TRUE(check); } - TEST_F(VulkanAPITest, quantized_add_broadcast3) { if (!at::is_vulkan_available()) { return; @@ -1049,6 +1737,735 @@ TEST_F(VulkanAPITest, quantized_upsample_nearest2d) { ASSERT_TRUE(check); } +std::tuple produce_inputs_for_binary_op( + const bool compute_quantization_params, + const bool random_quantization_params, + const char* op_name, + const at::IntArrayRef input1_shape, + const at::IntArrayRef input2_shape, + double in1_scale, double in2_scale, + int in1_zero_point, int in2_zero_point, + at::Tensor& input1_cpu, at::Tensor& input1_cpu_q, + at::Tensor& input1_cpu_deq, + at::Tensor& input1_vk, at::Tensor& input1_vk_q, + at::Tensor& input1_vk_deq, at::Tensor& input1_vk_deq_cpu, + at::Tensor& input2_cpu, at::Tensor& input2_cpu_q, + at::Tensor& input2_cpu_deq, + at::Tensor& input2_vk, at::Tensor& input2_vk_q, + at::Tensor& input2_vk_deq, at::Tensor& input2_vk_deq_cpu) { + + int num_attempts = 5; + // in order to make sure we start with input tensors that are numerically + // the same (cpu vs vulkan), we allow multiple attempts when randomly + // generating the inputs. If the cpu quantized tensor and the vk quantized + // tensors are not the same (maybe off by 1 due to differences in rounding + // and precision), we try again. + for (int i = 0; i < num_attempts; i += 1) { + // produce random inputs + input1_cpu = produce_random_tensor(input1_shape); + input2_cpu = produce_random_tensor(input1_shape); + + if (compute_quantization_params) { + // compute appropiate scale and zero point for inputs + const auto in1_quant_params = compute_quant_params(input1_cpu); + in1_scale = std::get<0>(in1_quant_params); + in1_zero_point = std::get<1>(in1_quant_params); + + const auto in2_quant_params = compute_quant_params(input2_cpu); + in2_scale = std::get<0>(in2_quant_params); + in2_zero_point = std::get<1>(in2_quant_params); + } else if (random_quantization_params) { + // produce random scale and zero point for inputs + in1_scale = produce_random_scale(); + in1_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8); + + in2_scale = produce_random_scale(); + in2_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8); + } + + // we do this, to avoid dividing by zero + if (strcmp(op_name, "quantized::div") == 0) { + // we might end up dividing by 0, if we allow random scale and zero point + // of the divisor. + if (random_quantization_params) { + const auto in2_quant_params = compute_quant_params(input2_cpu); + in2_scale = std::get<0>(in2_quant_params); + in2_zero_point = std::get<1>(in2_quant_params); + } + + const auto non_zero_sign = input2_cpu.sign() - input2_cpu.sign().abs() + 1; + // non_zero_sign = 1 if the value is non negative, and -1 if it is negative + input2_cpu = input2_cpu + in2_scale * non_zero_sign; + // this will force abs(input2_cpu) >= in2_scale, which means that none of + // the quantized values of the second input will be equal to the zero point. + } + + // quantize cpu inputs + input1_cpu_q = at::quantize_per_tensor( + input1_cpu, in1_scale, in1_zero_point, c10::ScalarType::QUInt8); + input2_cpu_q = at::quantize_per_tensor( + input2_cpu, in2_scale, in2_zero_point, c10::ScalarType::QUInt8); + + // dequantize quantized cpu inputs + input1_cpu_deq = at::dequantize(input1_cpu_q); + input2_cpu_deq = at::dequantize(input2_cpu_q); + + // vulkan quantized inputs + input1_vk = input1_cpu.vulkan(); + input1_vk_q = at::quantize_per_tensor( + input1_vk, in1_scale, in1_zero_point, c10::ScalarType::QUInt8); + input2_vk = input2_cpu.vulkan(); + input2_vk_q = at::quantize_per_tensor( + input2_vk, in2_scale, in2_zero_point, c10::ScalarType::QUInt8); + + // dequantize quantized vulkan inputs + input1_vk_deq = at::dequantize(input1_vk_q); + input2_vk_deq = at::dequantize(input2_vk_q); + + input1_vk_deq_cpu = input1_vk_deq.cpu(); + input2_vk_deq_cpu = input2_vk_deq.cpu(); + + const float input1_dif = at::abs(input1_cpu_deq - input1_vk_deq_cpu).max().item(); + const float input2_dif = at::abs(input2_cpu_deq - input2_vk_deq_cpu).max().item(); + if (input1_dif < 1e-5 && input2_dif < 1e-5 && input1_dif < in1_scale/2 && input2_dif < in2_scale/2) { + break; + } + } + + return {in1_scale, in2_scale, in1_zero_point, in2_zero_point}; +} + +at::Tensor apply_cpu_quantized_binary_op( + const char* op_name, + at::Tensor input1_cpu_deq, + at::Tensor input2_cpu_deq) { + if (strcmp(op_name, "quantized::add") == 0) { + return at::add(input1_cpu_deq, input2_cpu_deq); + } else if (strcmp(op_name, "quantized::sub") == 0) { + return at::sub(input1_cpu_deq, input2_cpu_deq); + } else if (strcmp(op_name, "quantized::mul") == 0) { + return at::mul(input1_cpu_deq, input2_cpu_deq); + } else if (strcmp(op_name, "quantized::div") == 0) { + return at::div(input1_cpu_deq, input2_cpu_deq); + } else { + TORCH_CHECK(false, "Invalid op"); + } +} + +at::Tensor apply_vulkan_quantized_binary_op( + const char* op_name, + at::Tensor input1_vk_q, + at::Tensor input2_vk_q, + double out_scale, + int64_t out_zero_point) { + if (strcmp(op_name, "quantized::add") == 0) { + return at::native::vulkan::ops::quantized_add( + input1_vk_q, input2_vk_q, out_scale, out_zero_point); + } else if (strcmp(op_name, "quantized::sub") == 0) { + return at::native::vulkan::ops::quantized_sub( + input1_vk_q, input2_vk_q, out_scale, out_zero_point); + } else if (strcmp(op_name, "quantized::mul") == 0) { + return at::native::vulkan::ops::quantized_mul( + input1_vk_q, input2_vk_q, out_scale, out_zero_point); + } else if (strcmp(op_name, "quantized::div") == 0) { + return at::native::vulkan::ops::quantized_div( + input1_vk_q, input2_vk_q, out_scale, out_zero_point); + } else { + TORCH_CHECK(false, "Invalid op"); + } +} + +void test_quantized_binary_op( + const bool compute_quantization_params, + const bool random_quantization_params, + const char* op_name, + const at::IntArrayRef input1_shape, + const at::IntArrayRef input2_shape, + double in1_scale_default = 0.103, + double in2_scale_default = 0.171, + double out_scale_default = 0.139, + int64_t in1_zero_point_default = 11, + int64_t in2_zero_point_default = 9, + int64_t out_zero_point_default = 17) { + + // produce inputs + at::Tensor input1_cpu, input1_cpu_q, input1_cpu_deq; + at::Tensor input1_vk, input1_vk_q, input1_vk_deq, input1_vk_deq_cpu; + at::Tensor input2_cpu, input2_cpu_q, input2_cpu_deq; + at::Tensor input2_vk, input2_vk_q, input2_vk_deq, input2_vk_deq_cpu; + + auto input_params = produce_inputs_for_binary_op( + compute_quantization_params, random_quantization_params, op_name, + input1_shape, input2_shape, + in1_scale_default, in2_scale_default, + in1_zero_point_default, in2_zero_point_default, + input1_cpu, input1_cpu_q, input1_cpu_deq, + input1_vk, input1_vk_q, input1_vk_deq, input1_vk_deq_cpu, + input2_cpu, input2_cpu_q, input2_cpu_deq, + input2_vk, input2_vk_q, input2_vk_deq, input2_vk_deq_cpu); + + double in1_scale = std::get<0>(input_params); + double in2_scale = std::get<1>(input_params); + int64_t in1_zero_point = std::get<2>(input_params); + int64_t in2_zero_point = std::get<3>(input_params); + + double out_scale = out_scale_default; + int64_t out_zero_point = out_zero_point_default; + + // apply op on dequantized cpu tensors + at::Tensor output_cpu = apply_cpu_quantized_binary_op( + op_name, input1_cpu_deq, input2_cpu_deq); + + if (compute_quantization_params || random_quantization_params) { + // compute appropiate scale and zero point for output + const auto out_quant_params = compute_quant_params(output_cpu); + out_scale = std::get<0>(out_quant_params); + out_zero_point = std::get<1>(out_quant_params); + } + + // quantize and dequantize cpu output + const auto output_cpu_q = at::quantize_per_tensor( + output_cpu, out_scale, out_zero_point, c10::ScalarType::QUInt8); + const auto output_cpu_deq = at::dequantize(output_cpu_q); + + // vulkan quantized output + at::Tensor output_vk_q = apply_vulkan_quantized_binary_op( + op_name, input1_vk_q, input2_vk_q, out_scale, out_zero_point); + + const auto output_vk_deq = at::dequantize(output_vk_q); + const auto output_vk_deq_cpu = output_vk_deq.cpu(); + + // check + const float tolerance = + (compute_quantization_params || random_quantization_params) ? out_scale : 0; + const auto check = almostEqual(output_cpu_deq, output_vk_deq_cpu, tolerance); + + if (!check) { + const auto vk_q_error = at::abs(output_vk_deq_cpu - output_cpu_deq).max().item(); + std::cout << "Binary op " << op_name << " failed with inputs: " << std::endl; + std::cout << "input1: shape " << input1_shape << " scale " << in1_scale + << " and zero point " << in1_zero_point << std::endl; + std::cout << "input2: shape " << input2_shape << " scale " << in2_scale + << " and zero point " << in2_zero_point << std::endl; + std::cout << "output scale " << out_scale + << " and zero point " << out_zero_point << std::endl; + std::cout << "error: " << vk_q_error << std::endl; + } + ASSERT_TRUE(check); +} + +void quantized_binary_op_test_set( + const char* op_name) { + // fixed params + test_quantized_binary_op(false, false, op_name, {1, 1, 1, 1}, {1, 1, 1, 1}); + test_quantized_binary_op(false, false, op_name, {1, 1, 8, 8}, {1, 1, 8, 8}); + test_quantized_binary_op(false, false, op_name, {1, 1, 12, 17}, {1, 1, 12, 17}); + test_quantized_binary_op(false, false, op_name, {2, 13, 32, 27}, {2, 13, 32, 27}); + test_quantized_binary_op(false, false, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting + test_quantized_binary_op(false, false, op_name, {7, 1, 6, 17}, {7, 5, 6, 17}); // broadcasting + + // compute params + test_quantized_binary_op(true, false, op_name, {1, 1, 1, 1}, {1, 1, 1, 1}); + test_quantized_binary_op(true, false, op_name, {1, 1, 8, 8}, {1, 1, 8, 8}); + test_quantized_binary_op(true, false, op_name, {1, 1, 12, 17}, {1, 1, 12, 17}); + test_quantized_binary_op(true, false, op_name, {2, 13, 32, 27}, {2, 13, 32, 27}); + test_quantized_binary_op(true, false, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting + test_quantized_binary_op(true, false, op_name, {7, 1, 6, 17}, {7, 5, 6, 17}); // broadcasting + + // random params + test_quantized_binary_op(false, true, op_name, {1, 1, 1, 1}, {1, 1, 1, 1}); + test_quantized_binary_op(false, true, op_name, {1, 1, 8, 8}, {1, 1, 8, 8}); + test_quantized_binary_op(false, true, op_name, {1, 1, 12, 17}, {1, 1, 12, 17}); + test_quantized_binary_op(false, true, op_name, {2, 13, 32, 27}, {2, 13, 32, 27}); + test_quantized_binary_op(false, true, op_name, {7, 15, 6, 17}, {7, 15, 1, 17}); // broadcasting + test_quantized_binary_op(false, true, op_name, {7, 1, 6, 17}, {7, 5, 6, 17}); // broadcasting + + // random shape and params + for (int i = 0; i < 10; i += 1) { + const at::IntArrayRef tensor_shape = + { + rand_pos_int(30), + rand_pos_int(30), + rand_pos_int(100), + rand_pos_int(100) + }; + test_quantized_binary_op(false, true, op_name, tensor_shape, tensor_shape); + } +} + +TEST_F(VulkanAPITest, quantized_add_tests) { + quantized_binary_op_test_set("quantized::add"); +} + +TEST_F(VulkanAPITest, quantized_sub_tests) { + quantized_binary_op_test_set("quantized::sub"); +} + +TEST_F(VulkanAPITest, quantized_mul_tests) { + quantized_binary_op_test_set("quantized::mul"); +} + +TEST_F(VulkanAPITest, quantized_div_tests) { + quantized_binary_op_test_set("quantized::div"); +} + +void test_quantized_conv2d( + const bool prepacking, + const bool compute_quantization_params, + const bool random_quantization_params, + const at::IntArrayRef input_shape, + const at::IntArrayRef weight_shape, + const at::IntArrayRef bias_shape, + std::vector stride, + std::vector padding, + std::vector dilation, + int64_t groups, + double in_scale = 0.13, + double w_scale = 0.29, + double b_scale = 0.19, + double out_scale = 0.15, + int64_t in_zero_point = 11, + int64_t w_zero_point = 19, + int64_t b_zero_point = 27, + int64_t out_zero_point = 10) { + c10::InferenceMode mode; + + // input cpu + at::Tensor input_cpu; // input cpu tensor + at::Tensor input_cpu_q; // input cpu tensor -> quantized + at::Tensor input_cpu_deq; // input cpu tensor -> quantized -> dequantized + + // input vulkan + at::Tensor input_vk; // input cpu tensor -> to vulkan + at::Tensor input_vk_q; // input cpu tensor -> to vulkan -> quantized + at::Tensor input_vk_deq; // input cpu tensor -> to vulkan -> quantized -> dequantized + at::Tensor input_vk_deq_cpu; // input cpu tensor -> to vulkan -> quantized -> dequantized -> to cpu + + // weight cpu + at::Tensor weight_cpu; // weight cpu tensor + at::Tensor weight_cpu_q; // weight cpu tensor -> quantized + at::Tensor weight_cpu_deq; // weight cpu tensor -> quantized -> dequantized + + // bias cpu + at::Tensor bias_cpu; // bias cpu tensor + at::Tensor bias_cpu_q; // bias cpu tensor -> quantized + at::Tensor bias_cpu_deq; // bias cpu tensor -> quantized -> dequantized + + // When we randomly generate the input tensor, we might get unlucky + // and one of the entries might be generated such that when it is divided + // by the scale we get something like 2.50003 for example which could be + // rounded to 2 or 3 depending on the precision and rounding method. + // Because of that possibility, we generate the input and check the + // difference between input_cpu_deq and input_vk_deq_cpu + // If they are different we regenerated them again (up to 3 times) + // The goal is to start with input tensors that remain equal after quantization. + int num_attempts = 5; + for (int i = 0; i < num_attempts; i += 1) { + // produce random input, weight and bias + input_cpu = produce_random_tensor(input_shape, 1.26, 5.97, 0.59); + weight_cpu = produce_random_tensor(weight_shape, 1.26, 5.97, 0.59); + bias_cpu = produce_random_tensor(bias_shape, 1.26, 5.97, 0.59); + + if (compute_quantization_params) { + // compute appropiate scale and zero point for input, weight and bias + const auto in_quant_params = compute_quant_params(input_cpu); + in_scale = std::get<0>(in_quant_params); + in_zero_point = std::get<1>(in_quant_params); + + const auto w_quant_params = compute_quant_params(weight_cpu); + w_scale = std::get<0>(w_quant_params); + w_zero_point = std::get<1>(w_quant_params); + + const auto input_max = input_cpu.max().item(); + const auto input_min = input_cpu.min().item(); + const auto input_range = input_max - input_min; + + bias_cpu = input_range * at::rand(bias_shape, at::device(at::kCPU).dtype(at::kFloat)) + input_min; + b_scale = in_scale; + b_zero_point = in_zero_point; + } + else if (random_quantization_params) { + // produce random scale and zero point for inputs + in_scale = produce_random_scale(); + in_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8); + + w_scale = produce_random_scale(); + w_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8); + + b_scale = produce_random_scale(); + b_zero_point = produce_random_zero_point(c10::ScalarType::QUInt8); + } + + // quantize cpu input, weight and bias + input_cpu_q = at::quantize_per_tensor( + input_cpu, in_scale, in_zero_point, c10::ScalarType::QUInt8); + weight_cpu_q = at::quantize_per_tensor( + weight_cpu, w_scale, w_zero_point, c10::ScalarType::QUInt8); + bias_cpu_q = at::quantize_per_tensor( + bias_cpu, b_scale, b_zero_point, c10::ScalarType::QUInt8); + + // dequantize quantized cpu input, weight and bias + input_cpu_deq = at::dequantize(input_cpu_q); + weight_cpu_deq = at::dequantize(weight_cpu_q); + bias_cpu_deq = at::dequantize(bias_cpu_q); + + // vulkan quantized input + input_vk = input_cpu.vulkan(); + input_vk_q = at::quantize_per_tensor( + input_vk, in_scale, in_zero_point, c10::ScalarType::QUInt8); + + // dequantize quantized vulkan input + input_vk_deq = at::dequantize(input_vk_q); + input_vk_deq_cpu = input_vk_deq.cpu(); + + const float input_dif = at::abs(input_cpu_deq - input_vk_deq_cpu).max().item(); + + if (input_dif < 1e-5 && input_dif < in_scale/2) { + break; + } else { + std::cout << "input_dif too big: " << input_dif; + if (i + 1 < num_attempts) { + std::cout << ". generating input again ..." << std::endl; + } else { + std::cout << std::endl; + } + } + } + + // conv2d on dequantized cpu tensors + // Note: we apply the convolutio to the dequantized quantized tensors, that way + // we are performing the operations on the same numeric values. + const auto output_cpu = at::conv2d( + input_cpu_deq, weight_cpu_deq, bias_cpu_deq, stride, padding, dilation, groups); + + if (compute_quantization_params || random_quantization_params) { + // compute appropiate scale and zero point for output + const auto out_quant_params = compute_quant_params(output_cpu); + out_scale = std::get<0>(out_quant_params); + out_zero_point = std::get<1>(out_quant_params); + } + + // quantize and dequantize cpu output + at::Tensor output_cpu_q = at::quantize_per_tensor( + output_cpu, out_scale, out_zero_point, c10::ScalarType::QUInt8); + at::Tensor output_cpu_deq = at::dequantize(output_cpu_q); + + // vulkan quantized output + at::Tensor output_vk_q; + + if (!prepacking) { + // vulkan quantized conv2d + output_vk_q = at::native::vulkan::ops::quantized_conv2d( + input_vk_q, weight_cpu_q, bias_cpu_q, + stride, padding, dilation, groups, + out_scale, out_zero_point); + } else { + // vulkan quantized conv2d call by name + const auto prepack_vulkan_call_by_name = callOpByName( + "vulkan_prepack::create_qconv2d_context", + "", + weight_cpu_q, bias_cpu_q, stride, padding, dilation, groups, c10::nullopt, c10::nullopt); + const auto vulkan_output = callOpByName( + "vulkan_prepack::run_qconv2d_context", + "", + input_vk_q, out_scale, out_zero_point, prepack_vulkan_call_by_name[0]); + output_vk_q = vulkan_output[0].toTensor(); + } + + // dequantize vulkan output + const auto output_vk_deq = at::dequantize(output_vk_q); + const auto output_vk_deq_cpu = output_vk_deq.cpu(); + + // check + const float tolerance = out_scale; + const auto check = almostEqual(output_cpu_deq, output_vk_deq_cpu, tolerance); + + if (!check) { + const auto vk_q_error = at::abs(output_vk_deq_cpu - output_cpu_deq).max().item(); + std::cout << "Quantized Conv2d failed with: " << std::endl; + std::cout << "input: shape " << input_shape << " scale " << in_scale + << " and zero point " << in_zero_point << std::endl; + std::cout << "weight: shape " << weight_shape << " scale " << w_scale + << " and zero point " << w_zero_point << std::endl; + std::cout << "bias: shape " << bias_shape << " scale " << b_scale + << " and zero point " << b_zero_point << std::endl; + std::cout << "output scale " << out_scale + << " and zero point " << out_zero_point << std::endl; + std::cout << "error: " << vk_q_error << std::endl; + } + ASSERT_TRUE(check); +} + +TEST_F(VulkanAPITest, conv2d_quantized_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_quantized_computed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_quantized_random_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_quantized_prepack_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_quantized_prepack_computed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_quantized_prepack_random_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 3, 8, 8}, + /* weight_shape */ {1, 3, 3, 3}, + /* bias_shape */ {1}, + /* stride */ {2, 2}, + /* padding */ {1, 1}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_computed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_random_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_computed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_dw_quantized_prepack_random_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 7, 137, 199}, + /* weight_shape */ {7, 1, 17, 7}, + /* bias_shape */ {7}, + /* stride */ {2, 3}, + /* padding */ {0, 4}, + /* dilation */ {3, 1}, + /* groups */ 7 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_computed_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_random_params) { + test_quantized_conv2d( + /* prepacking? */ false, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_fixed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ false, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_computed_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */true, + /* random params */ false, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + +TEST_F(VulkanAPITest, conv2d_pw_quantized_prepack_random_params) { + test_quantized_conv2d( + /* prepacking? */ true, + /* compute params */false, + /* random params */ true, + /* input_shape */ {1, 17, 127, 397}, + /* weight_shape */ {29, 17, 1, 1}, + /* bias_shape */ {29}, + /* stride */ {1, 1}, + /* padding */ {0, 0}, + /* dilation */ {1, 1}, + /* groups */ 1 + ); +} + } // namespace #endif /* USE_VULKAN_API */ diff --git a/aten/src/README.md b/aten/src/README.md index add2816926331..3127ed5c8c399 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -69,8 +69,8 @@ will `retain` it itself. ``` Sometimes, you have a tensor in hand which you'd like to use directly, but -under some conditions you have to have to call, e.g., `newContiguous`, to get -it into the correct form: +under some conditions you have to call, e.g., `newContiguous`, to get it into +the correct form: ``` if (!(k_->stride(3) == 1) || !(k_->stride[2] == k_->size(3))) { diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh index 5b0c02c2846a4..3ae0da113bca7 100755 --- a/aten/tools/run_tests.sh +++ b/aten/tools/run_tests.sh @@ -26,7 +26,7 @@ fi ./Dict_test ./NamedTensor_test ./cpu_generator_test -./vmap_test +./legacy_vmap_test ./operators_test if [[ -x ./cudnn_test ]]; then ./cudnn_test diff --git a/benchmarks/cpp/nvfuser/CMakeLists.txt b/benchmarks/cpp/nvfuser/CMakeLists.txt index bac36d19f3d16..ad9053bb3a3aa 100644 --- a/benchmarks/cpp/nvfuser/CMakeLists.txt +++ b/benchmarks/cpp/nvfuser/CMakeLists.txt @@ -20,6 +20,7 @@ if(USE_CUDA) softmax_backward.cpp scale_bias_relu.cpp transpose.cpp + matmul.cpp timm.cpp utils.cpp main.cpp) diff --git a/benchmarks/cpp/nvfuser/batch_norm_channels_first.cpp b/benchmarks/cpp/nvfuser/batch_norm_channels_first.cpp index 723d222516df4..2f839f0c8332a 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_channels_first.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_channels_first.cpp @@ -73,10 +73,6 @@ static void NvFuserScheduler_BatchNorm( DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); - const bool kTraining = true; - const float kMomentum = 0.1; - const float kEps = 1e-5; - std::vector input_shape{ benchmark_state.range(0), benchmark_state.range(1), diff --git a/benchmarks/cpp/nvfuser/batch_norm_channels_first_backward.cpp b/benchmarks/cpp/nvfuser/batch_norm_channels_first_backward.cpp index af2b4d145fc8f..62a4e99e21ef6 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_channels_first_backward.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_channels_first_backward.cpp @@ -25,7 +25,6 @@ static void setupBatchNorm_BWD(Fusion* fusion, DataType dtype) { FusionGuard fg(fusion); const bool kTraining = true; - const float kMomentum = 0.1; const float kEps = 1e-5; // setup fusion @@ -85,9 +84,6 @@ static void NvFuserScheduler_BatchNorm_BWD( DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); - const bool kTraining = true; - const float kEps = 1e-5; - std::vector input_shape{ benchmark_state.range(0), benchmark_state.range(1), diff --git a/benchmarks/cpp/nvfuser/batch_norm_channels_last.cpp b/benchmarks/cpp/nvfuser/batch_norm_channels_last.cpp index 14fde631aec0b..7b8972a0aad07 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_channels_last.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_channels_last.cpp @@ -74,10 +74,6 @@ static void NvFuserScheduler_BatchNorm_nhwc( DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); - const bool kTraining = true; - const float kMomentum = 0.1; - const float kEps = 1e-5; - std::vector input_shape{ benchmark_state.range(0), benchmark_state.range(2), diff --git a/benchmarks/cpp/nvfuser/batch_norm_channels_last_backward.cpp b/benchmarks/cpp/nvfuser/batch_norm_channels_last_backward.cpp index 0660b75e39426..29bcfb3e81be7 100644 --- a/benchmarks/cpp/nvfuser/batch_norm_channels_last_backward.cpp +++ b/benchmarks/cpp/nvfuser/batch_norm_channels_last_backward.cpp @@ -25,7 +25,6 @@ static void setupBatchNorm_nhwc_BWD(Fusion* fusion, DataType dtype) { FusionGuard fg(fusion); const bool kTraining = true; - const float kMomentum = 0.1; const float kEps = 1e-5; // setup fusion @@ -86,9 +85,6 @@ static void NvFuserScheduler_BatchNorm_nhwc_BWD( DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); - const bool kTraining = true; - const float kEps = 1e-5; - std::vector input_shape{ benchmark_state.range(0), benchmark_state.range(2), diff --git a/benchmarks/cpp/nvfuser/gelu_backward.cpp b/benchmarks/cpp/nvfuser/gelu_backward.cpp index e6a24111e848f..732ad7f0ea0fd 100644 --- a/benchmarks/cpp/nvfuser/gelu_backward.cpp +++ b/benchmarks/cpp/nvfuser/gelu_backward.cpp @@ -113,9 +113,6 @@ BENCHMARK(GeluBackward_AutoSchedule)->Unit(benchmark::kMicrosecond); //------------------------------------------------------------------------------ static void GeluBackward_Lower(benchmark::State& benchmark_state) { - constexpr int kHiddenFeatures = 512; - constexpr int kBatchSize = 64; - Fusion fusion; // setup fusion diff --git a/benchmarks/cpp/nvfuser/layer_norm.cpp b/benchmarks/cpp/nvfuser/layer_norm.cpp index 316fe22c1ff4f..d2cff09e5d2ed 100644 --- a/benchmarks/cpp/nvfuser/layer_norm.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm.cpp @@ -22,7 +22,6 @@ static void setupLayerNorm(Fusion* fusion, DataType dtype) { FusionGuard fg(fusion); - const int kReductionAxis = 1; const float kEps = 1e-5; Double* eps_ptr = IrBuilder::create(kEps); @@ -61,7 +60,6 @@ static void NvFuserScheduler_LayerNorm( std::vector input_shape{ benchmark_state.range(0), benchmark_state.range(1)}; - const float kEps = 1e-5; // inputs at::manual_seed(0); diff --git a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp index cce8aa42ce933..c431622e7b9f4 100644 --- a/benchmarks/cpp/nvfuser/layer_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/layer_norm_backward.cpp @@ -22,9 +22,6 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) { TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half); - const int kReductionAxis = 1; - Double* eps_ptr = IrBuilder::create(1e-5); - // setup fusion auto grad_out = makeContigTensor(2, dtype); auto input = makeContigTensor(2, dtype); diff --git a/benchmarks/cpp/nvfuser/matmul.cpp b/benchmarks/cpp/nvfuser/matmul.cpp new file mode 100644 index 0000000000000..25fc6cfe23569 --- /dev/null +++ b/benchmarks/cpp/nvfuser/matmul.cpp @@ -0,0 +1,357 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +using namespace torch::jit::fuser::cuda; + +bool cudaArchGuardShouldSkip(int required_major, int required_minor) { + int capability_major = at::cuda::getCurrentDeviceProperties()->major; + int capability_minor = at::cuda::getCurrentDeviceProperties()->minor; + + if (capability_major < required_major || + (capability_major == required_major && + capability_minor < required_minor)) { + return true; + } + return false; +} + +bool hasRequiredSmemSize(size_t required_size) { + // Only checking device 0 + return at::cuda::getDeviceProperties(0)->sharedMemPerBlockOptin >= + required_size; +} + +#define NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( \ + REQUIRED_MAJOR, REQUIRED_MINOR, SMEM_SIZE, STATE) \ + if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR) || \ + !hasRequiredSmemSize(SMEM_SIZE)) { \ + STATE.SkipWithError("Unsupported arch or not enough smem!"); \ + return; \ + } + +// util to track support matmul operand layout. +using MatmulLayout = MmaOptions::MmaInputLayout; + +static constexpr std::array kAllSupportedLayout = { + MatmulLayout::TT, + MatmulLayout::NT, + MatmulLayout::TN}; + +// Generic interface to get matmul op with the given layout. +TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) { + TORCH_CHECK( + a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests"); + TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr; + switch (layout) { + case MatmulLayout::TT: + tv0b = broadcast(a, {false, false, true}); + tv1b = broadcast(b, {true, false, false}); + tv2 = fusedMultiplySum(tv0b, tv1b, {1}); + break; + case MatmulLayout::TN: + tv0b = broadcast(a, {false, true, false}); + tv1b = broadcast(b, {true, false, false}); + tv2 = fusedMultiplySum(tv0b, tv1b, {2}); + break; + case MatmulLayout::NT: + tv0b = broadcast(a, {false, false, true}); + tv1b = broadcast(b, {false, true, false}); + tv2 = fusedMultiplySum(tv0b, tv1b, {0}); + break; + default: + TORCH_CHECK(false, "unsupported data layout."); + } + return tv2; +} + +// Utility to generate matmul input tensors based on given layout +at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) { + switch (layout) { + case MatmulLayout::TT: + return a.matmul(b); + case MatmulLayout::TN: + return a.matmul(b.t()); + case MatmulLayout::NT: + return a.t().matmul(b); + default: + TORCH_CHECK(false, "unsupported data layout."); + } + return at::Tensor(); +} + +// Utility to generate reference results based on given layout +std::pair fp16MatmulAtInput( + int M, + int N, + int K, + MatmulLayout layout) { + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + + switch (layout) { + case MatmulLayout::TT: + return std::make_pair( + at::randn({M, K}, options), at::randn({K, N}, options)); + case MatmulLayout::TN: + return std::make_pair( + at::randn({M, K}, options), at::randn({N, K}, options)); + case MatmulLayout::NT: + return std::make_pair( + at::randn({K, M}, options), at::randn({K, N}, options)); + default: + TORCH_CHECK(false, "unsupported data layout."); + } + return std::make_pair(at::Tensor(), at::Tensor()); +} + +// TODO: separate compute and schedule definition once the can schedule +// logic and pattern matching is ready. +void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) { + // Only hgemm on the initial setup + auto a = makeContigTensor(2, DataType::Half); + auto b = makeContigTensor(2, DataType::Half); + + auto c = matmul(a, b, layout); + + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + + scheduleMatmul(c, a, b, params); +} + +static void SingleMatmulBase( + benchmark::State& benchmark_state, + MatmulLayout layout, + MatmulParam params) { + std::vector input_mnk{ + benchmark_state.range(0), + benchmark_state.range(1), + benchmark_state.range(2)}; + + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + // Define fusion graph + setupMatmul(fusion, layout, params); + + // inputs + at::manual_seed(0); + + // Tensor inputs + auto inputs = fp16MatmulAtInput( + input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout); + + KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder( + {inputs.first, inputs.second}); + + // Always use 32b indexing mode for now. + TORCH_INTERNAL_ASSERT(args.getIndexMode() == KernelIndexMode::INT32); + + // Compile kernel + FusionExecutor fe; + fe.compileFusion(fusion, args, LaunchParams()); + + // Warm up run + auto outputs = fe.runFusion({inputs.first, inputs.second}); + fe.setMeasureKernelTimeFlag(true); + + // Sync everything up before we start + for (auto _ : benchmark_state) { + clearL2Cache(); + auto outputs = fe.runFusion({inputs.first, inputs.second}); + benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0); + } + // Sync everything up before we're finished, don't want to run ahead on the + // cpu while benchmarking. + cudaDeviceSynchronize(); + + // TODO: FLOPS calculation +} + +static void EagerModeMatmul( + benchmark::State& benchmark_state, + MatmulLayout layout) { + std::vector input_mnk{ + benchmark_state.range(0), + benchmark_state.range(1), + benchmark_state.range(2)}; + + at::manual_seed(0); + + auto inputs = fp16MatmulAtInput( + input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout); + + // warm up run + auto outputs = atMatmul(inputs.first, inputs.second, layout); + + for (auto _ : benchmark_state) { + clearL2Cache(); + CudaKernelTimer timer; + outputs = atMatmul(inputs.first, inputs.second, layout); + benchmark_state.SetIterationTime(timer.elapsed() / 1000.0); + } + // Sync everything up before we're finished, don't want to run ahead on the + // cpu while benchmarking. + cudaDeviceSynchronize(); +} + +// Actual benchmarking +// ----------------------------------------------------------------- + +size_t getSmemSize(GemmTile cta_tile, int stage_number) { + return ((cta_tile.m * cta_tile.k) + (cta_tile.n * cta_tile.k)) * + dataTypeSize(DataType::Half) * stage_number; +} + +// TODO: this part eventually will be automated by heuristics +MatmulParam getMatmulParams( + GemmTile cta_tile, + int stage_number, + MatmulLayout layout) { + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = cta_tile; + // TODO: pipe through split K + gemm_tile.warp_tile = GemmTile(64, 64, cta_tile.k); + gemm_tile.instruction_tile = GemmTile(16, 16, 16); + + // Collect mma swizzle info + auto mma_builder = + MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile) + .layout(layout); + + MatmulParam params(mma_builder); + params.tile_sizes = gemm_tile; + params.async_gmem_load_operands = true; + params.double_buffer_options.double_buffer_smem_write = true; + params.double_buffer_options.double_buffer_smem_read = true; + params.double_buffer_options.smem_double_buffer_stage = stage_number; + + return params; +} + +static void Nvfuser_Matmul_4warp3stage( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(128, 128, 32); + int number_of_stage = 3; + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +static void Nvfuser_Matmul_8warp3stage( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(256, 128, 32); + int number_of_stage = 3; + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +static void Nvfuser_Matmul_4warp4stage( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(128, 128, 32); + int number_of_stage = 4; + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +static void Nvfuser_Matmul_8warp4stage( + benchmark::State& benchmark_state, + MatmulLayout layout) { + auto cta_tile = GemmTile(256, 128, 32); + int number_of_stage = 4; + + auto params = getMatmulParams(cta_tile, number_of_stage, layout); + + NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( + 8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state); + + // Run benchmark: + SingleMatmulBase(benchmark_state, layout, params); +} + +// ----------------------------- Benchmark Instantiation------- + +// Common utils: +#define NO_TILE_QUANTIZATION_ARGS \ + ArgsProduct( \ + {{2048}, {3456}, benchmark::CreateDenseRange(512, 4096, /*step=*/512)}) \ + ->Unit(benchmark::kMicrosecond) \ + ->UseManualTime(); + +#define ForAllLayouts(run) \ + run(TT, MatmulLayout::TT); \ + run(TN, MatmulLayout::TN); \ + run(NT, MatmulLayout::NT) + +// Instantiations: +#define Nvfuser_4warp3stage_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_4warp3stage, \ + no_quant_nvfuser_4warp_##layout_label, \ + layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +#define Nvfuser_8warp3stage_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_8warp3stage, \ + no_quant_nvfuser_8warp_##layout_label, \ + layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +#define Nvfuser_4warp4stage_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_4warp4stage, \ + no_quant_nvfuser_4warp_##layout_label, \ + layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +#define Nvfuser_8warp4stage_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + Nvfuser_Matmul_8warp4stage, \ + no_quant_nvfuser_8warp_##layout_label, \ + layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +#define Eagermode_test(layout_label, layout) \ + BENCHMARK_CAPTURE( \ + EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \ + ->NO_TILE_QUANTIZATION_ARGS + +ForAllLayouts(Nvfuser_4warp3stage_test); +ForAllLayouts(Nvfuser_4warp4stage_test); +ForAllLayouts(Nvfuser_8warp3stage_test); +ForAllLayouts(Nvfuser_8warp4stage_test); +ForAllLayouts(Eagermode_test); diff --git a/benchmarks/cpp/nvfuser/rms_norm.cpp b/benchmarks/cpp/nvfuser/rms_norm.cpp index 81fdf46cf8189..37911ea6b1fd2 100644 --- a/benchmarks/cpp/nvfuser/rms_norm.cpp +++ b/benchmarks/cpp/nvfuser/rms_norm.cpp @@ -24,7 +24,6 @@ static void setupRMSNorm(Fusion* fusion, DataType dtype) { FusionGuard fg(fusion); - const int kReductionAxis = 2; const float kEps = 1e-6; Double* eps_ptr = IrBuilder::create(kEps); @@ -61,7 +60,6 @@ static void NvFuserScheduler_RMSNorm( dtype == DataType::BFloat16); std::vector input_shape{8, benchmark_state.range(0), 1024}; - const float kEps = 1e-6; // inputs at::manual_seed(0); diff --git a/benchmarks/cpp/nvfuser/rms_norm_backward.cpp b/benchmarks/cpp/nvfuser/rms_norm_backward.cpp index b4c6ac413c758..987c3bf234fa2 100644 --- a/benchmarks/cpp/nvfuser/rms_norm_backward.cpp +++ b/benchmarks/cpp/nvfuser/rms_norm_backward.cpp @@ -24,9 +24,6 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) { dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16); - const int kReductionAxis = 2; - Double* eps_ptr = IrBuilder::create(1e-6); - // setup fusion auto grad_out = makeContigTensor(3, dtype); auto input = makeContigTensor(3, dtype); diff --git a/benchmarks/cpp/nvfuser/timm.cpp b/benchmarks/cpp/nvfuser/timm.cpp index 013b609be6020..4669ff0ecabf6 100644 --- a/benchmarks/cpp/nvfuser/timm.cpp +++ b/benchmarks/cpp/nvfuser/timm.cpp @@ -115,7 +115,7 @@ static void setup_vit_base_patch16_224_bcast5(Fusion* fusion, void* null) { auto t6 = set(t5); auto t7 = broadcast(t6, bcast_pattern0); auto t8 = add(t4, t7); - auto t9 = randlike(t8); + auto t9 = rand_like(t8); auto d34 = sub(IrBuilder::create(1.0), IrBuilder::create(0.0)); auto t10 = lt(t9, d34); @@ -139,7 +139,6 @@ static void setup_vit_base_patch16_224_bcast5(Fusion* fusion, void* null) { auto t20 = sum(t37, {2}); auto t24 = broadcast(t20, bcast_pattern1); auto d95 = castOp(DataType::Double, t2->axis(2)->extent()); - auto d96 = mul(IrBuilder::create(1.0), d95); auto d105 = reciprocal(d95); auto t25 = mul(t24, d105); auto t26 = add(t25, IrBuilder::create(1e-6)); @@ -289,7 +288,7 @@ static void setup_vit_base_patch16_224_norm_inner3(Fusion* fusion, void* null) { auto t10 = broadcast(t9, {false, false, false, true}); auto t11 = reciprocal(t10); auto t12 = mul(t8, t11); - auto t13 = randlike(t12); + auto t13 = rand_like(t12); auto d79 = sub(IrBuilder::create(1), IrBuilder::create(0)); auto t14 = lt(t13, d79); auto t15 = castOp(DataType::Float, t14); @@ -320,8 +319,6 @@ static void NvFuserScheduler_TIMM_vit_base_patch16_224_norm_inner3( at::manual_seed(0); auto fp16_options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto fp32_options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(input_shape, fp16_options); @@ -367,7 +364,7 @@ static void setup_vit_base_patch16_224_bcast_outer6( auto t9 = add(IrBuilder::create(1), t8); auto t10 = mul(IrBuilder::create(0.5), t9); auto t11 = mul(t6, t10); - auto t12 = randlike(t11); + auto t12 = rand_like(t11); auto d66 = sub(IrBuilder::create(1), IrBuilder::create(0)); auto t13 = lt(t12, d66); auto t14 = castOp(DataType::Float, t13); @@ -456,7 +453,7 @@ static void setup_vit_base_patch16_224_bcast_inner6( auto t9 = add(IrBuilder::create(1), t8); auto t10 = mul(IrBuilder::create(0.5), t9); auto t11 = mul(t6, t10); - auto t12 = randlike(t11); + auto t12 = rand_like(t11); auto d66 = sub(IrBuilder::create(1), IrBuilder::create(0)); auto t13 = lt(t12, d66); auto t14 = castOp(DataType::Float, t13); diff --git a/benchmarks/distributed/ddp/README.md b/benchmarks/distributed/ddp/README.md index 0bf254ee4cce2..f89aaff9809eb 100644 --- a/benchmarks/distributed/ddp/README.md +++ b/benchmarks/distributed/ddp/README.md @@ -158,7 +158,7 @@ Benchmark: resnext101_32x8d with batch size 32 ``` This compares throughput between `bucket_cap_mb=25` (the default) and -`bucket_cap_mb=1` on 8 DGX machines with V100 GPUs. It confims that +`bucket_cap_mb=1` on 8 DGX machines with V100 GPUs. It confirms that even for a relatively small model on machines with a very fast interconnect (4x 100Gb InfiniBand per machine), it still pays off to batch allreduce calls. diff --git a/benchmarks/dynamo/Makefile_dashboard b/benchmarks/dynamo/Makefile_dashboard index 729178f538408..904b6726c494c 100644 --- a/benchmarks/dynamo/Makefile_dashboard +++ b/benchmarks/dynamo/Makefile_dashboard @@ -5,16 +5,20 @@ PIP ?= python -m pip clone-deps: (cd ../../.. \ && (test -e torchvision || git clone --recursive https://github.com/pytorch/vision torchvision) \ + && (test -e torchdata || git clone --recursive https://github.com/pytorch/data.git torchdata) \ && (test -e torchtext || git clone --recursive https://github.com/pytorch/text torchtext) \ + && (test -e torchaudio || git clone --recursive https://github.com/pytorch/audio torchaudio) \ && (test -e detectron2 || git clone --recursive https://github.com/facebookresearch/detectron2) \ && (test -e torchbenchmark || git clone --recursive https://github.com/pytorch/benchmark torchbenchmark) \ && (test -e triton || git clone --recursive https://github.com/openai/triton.git) \ ) -pull-deps: +pull-deps: clone-deps echo $(TRITON_VERSION) (cd ../../../torchvision && git pull && git submodule update --init --recursive) + (cd ../../../torchdata && git pull && git submodule update --init --recursive) (cd ../../../torchtext && git pull && git submodule update --init --recursive) + (cd ../../../torchaudio && git pull && git submodule update --init --recursive) (cd ../../../detectron2 && git pull && git submodule update --init --recursive) (cd ../../../torchbenchmark && git pull && git submodule update --init --recursive) (cd ../../../triton && git checkout master && git pull && git checkout $(TRITON_VERSION) && git submodule update --init --recursive) @@ -28,7 +32,9 @@ build-deps: clone-deps conda install -y -c pytorch magma-cuda116 conda install -y -c conda-forge librosa (cd ../../../torchvision && python setup.py clean && python setup.py develop) + (cd ../../../torchdata && python setup.py install) (cd ../../../torchtext && python setup.py clean && python setup.py develop) + (cd ../../../torchaudio && python setup.py clean && python setup.py develop) (cd ../../../detectron2 && python setup.py clean && python setup.py develop) (cd ../../../torchbenchmark && python install.py --continue_on_fail) (cd ../../../triton/python && python setup.py clean && python setup.py develop) diff --git a/benchmarks/dynamo/README.md b/benchmarks/dynamo/README.md index 5307e77b9b173..91556084cd0db 100644 --- a/benchmarks/dynamo/README.md +++ b/benchmarks/dynamo/README.md @@ -27,11 +27,13 @@ For HF and TIMM models, the scripts already install the transformers and timm pa There are a lot of flags in the benchmark runner, and it can be confusing to know which settings to use or what machine to run it on. In order to support apples-to-apples comparison, we have provided the following 'standard' settings in `runner.py`. This script is a wrapper over the common benchmarking infrastructure and simplifies the flags. We will continually update `runner.py` with the latest and most relevant compilers for training and inference. It also provides some graph utilities to visualize and compare results. Some of the example commands are **Inference Commands** -* Inference compilers on torchbench models - `python benchmarks/runner.py --suites=torchbench --inference --dtypes=float16` +* Inference compilers on torchbench models - `python benchmarks/dynamo/runner.py --suites=torchbench --inference --dtypes=float16` +* Inductor Inference compiler on torchbench models - `python benchmarks/dynamo/runner.py --suites=torchbench --inference --dtypes=float16 --compilers=inductor` **Training Commands** -* Training compilers on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --output-dir=timm_logs` -* AOTAutograd Training compiler on TIMM models - `python benchmarks/runner.py --suites=timm_models --training --dtypes=float32 --compilers=aot_nvfuser --output-dir=timm_logs` +* Training compilers on TIMM models - `python benchmarks/dynamo/runner.py --suites=timm_models --training --dtypes=float32 --output-dir=timm_logs` +* AOTAutograd Training compiler on TIMM models - `python benchmarks/dynamo/runner.py --suites=timm_models --training --dtypes=float32 --compilers=aot_nvfuser --output-dir=timm_logs` +* Inductor Training compiler on TIMM models - `python benchmarks/dynamo/runner.py --suites=timm_models --training --dtypes=float32 --compilers=inductor --output-dir=timm_logs` Running runner.py generates a file named `run.sh`. This file contains the actual commands that invoke the common benchmarking infrastructure with the appropriate flags. Which brings us to the advanced usage. @@ -40,11 +42,11 @@ Running runner.py generates a file named `run.sh`. This file contains the actual One could directly call `torchbench.py`, `huggingface.py` or `timm_models.py` with the necessary flags. There are a lot of flags in the benchmarks runner. Some of the examples are as follows. These are subject to change. **Inference Commands** -* TorchScript NVFuser Inference - `python benchmarks/torchbench.py -dcuda -n100 --speedup-ts` -* TorchInductor CUDA Graphs Inference - `python benchmarks/torchbench.py -dcuda --inductor-settings --float32 -n50 --inductor` +* TorchScript (with TorchDynamo capture) NVFuser Inference - `python benchmarks/dynamo/torchbench.py -dcuda -n100 --speedup-dynamo-ts --performance` +* TorchInductor CUDA Graphs Inference - `python benchmarks/dynamo/torchbench.py -dcuda --float32 -n50 --inductor --performance` **Training Commands** -* Torchscript (with TorchDynamo capture) NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --speedup-dynamo-ts --use-eval-mode` -* AOTAutograd Torchscript NVFuser Training - `python benchmarks/torchbench.py --float32 -dcuda --training --nvfuser --accuracy-aot-ts-mincut --use-eval-mode` +* Torchscript (with TorchDynamo capture) NVFuser Training - `python benchmarks/dynamo/torchbench.py --float32 -dcuda --training --nvfuser --speedup-dynamo-ts --performance` +* TorchInductor CUDA Graphs Training - `python benchmarks/dynamo/torchbench.py --float32 -dcuda --training --inductor --performance` Above commands are for torchbench models. You can simply replace `torchbench.py` with `huggingface.py` for HF models, and `timm_model.py` for TIMM models. diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 8ff1fb5c3ae93..ba4ca471e8f4b 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -4,6 +4,7 @@ import copy import csv import functools +import importlib import io import logging import os @@ -13,6 +14,7 @@ import sys import time import warnings +from contextlib import contextmanager import numpy as np import pandas as pd @@ -20,23 +22,25 @@ import torch._dynamo import torch._dynamo.utils -from microbenchmarks.operator_inp_utils import OperatorInputsMode +import torch.distributed from scipy.stats import gmean, ttest_ind from torch._dynamo.optimizations import backends from torch._dynamo.optimizations.log_args import conv_args_analysis from torch._dynamo.profiler import fx_insert_profiling, Profiler from torch._dynamo.testing import dummy_fx_compile, format_speedup, same from torch._dynamo.utils import clone_inputs -from torch._inductor.utils import fresh_triton_cache +from torch._functorch.aot_autograd import set_model_name +from torch._inductor import config as inductor_config +from torch._inductor.utils import fresh_inductor_cache from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils._pytree import tree_map try: - from functorch._src.aot_autograd import set_model_name + from .microbenchmarks.operator_inp_utils import OperatorInputsMode except ImportError: - - def set_model_name(name): - pass + from microbenchmarks.operator_inp_utils import OperatorInputsMode log = logging.getLogger(__name__) @@ -83,81 +87,93 @@ def set_model_name(name): CI_SKIP_INDCUTOR_INFERENCE = [ *CI_SKIP_AOT_EAGER_INFERENCE, # TorchBench + "DALLE2_pytorch", "detectron2", - "hf_Reformer", + "hf_T5", # accuracy + "hf_BigBird", # accuracy + "hf_GPT2_large", # OOM + "maml", # accuracy + "mobilenet_v2_quantized_qat", # The eval test only supports CPU "moco", # accuracy + "pytorch_struct", # Test eval is not implemented "pyhpc_equation_of_state", # Accuracy "pyhpc_turbulent_kinetic_energy", # Accuracy "tacotron2", "vision_maskrcnn", # accuracy - "yolov3", # Accuracy # Huggingface - "BigBird", - "YituTechConvBert", + "DebertaV2ForQuestionAnswering", # OOM # TIMM "cait_m36_384", # Accuracy "ghostnet_100", # Accuracy - "swin_base_patch4_window7_224", # Accuracy ] CI_SKIP_INDUCTOR_TRAINING = [ - # CI does not check accuracy for inductor training yet - # *CI_SKIP_AOT_EAGER_TRAINING, - # *CI_SKIP_INDCUTOR_INFERENCE, + *CI_SKIP_INDCUTOR_INFERENCE, # TorchBench - "attention_is_all_you_need_pytorch", - "drq", - "hf_Albert", - "hf_Bart", - "hf_GPT2", - "hf_Reformer", - "mobilenet_v3_large", - "moco", - "pytorch_struct", - "vgg16", - "speech_transformer", # from functionalization - "vision_maskrcnn", # from functionalization - "timm_efficientnet", # from functionalization (only fails for inductor) - "hf_Bert", - "soft_actor_critic", - "tacotron2", - "yolov3", - # OOM - "Background_Matting", - "fastNLP_Bert", - "hf_BigBird", - "mobilenet_v2", - "mobilenet_v2_quantized_qat", - "resnet50_quantized_qat", - "timm_regnet", + "Background_Matting", # fp64_OOM + "dlrm", # Fails on CI - unable to repro locally + "mobilenet_v3_large", # accuracy + "resnet50_quantized_qat", # Eager model failed to run # Huggingface - "AllenaiLongformerBase", - "AlbertForMaskedLM", # OOM - "BartForConditionalGeneration", # OOM + "BlenderbotForCausalLM", # OOM + "GoogleFnet", # Eager model failed to run "M2M100ForConditionalGeneration", # OOM - "MBartForConditionalGeneration", # OOM - "MT5ForConditionalGeneration", # OOM - "PegasusForConditionalGeneration", # OOM - "XGLMForCausalLM", # fp64_OOM - # OOM - "BigBird", - "TrOCRForCausalLM", - "AlbertForQuestionAnswering", + "XGLMForCausalLM", # OOM # TIMM - "cait_m36_384", # fp64_OOM - "coat_lite_mini", # time out "convit_base", # fp64_OOM - "gernet_l", # accuracy - "gluon_xception65", - "lcnet_0500", # accuracy - "levit_128", # levit_128 - "rexnet_100", # accuracy - "swin_base_patch4_window7_224", - "twins_pcpvt_base", # time out + "dm_nfnet_f0", # accuracy + "convmixer_768_32", # accuracy - Unable to repro on A100 + "hrnet_w18", # accuracy - Unable to repro on A100 + "sebotnet33ts_256", # accuracy - Unable to repro on A100 + "hrnet_w18", # accuracy - Unable to repro on A100 + "eca_botnext26ts_256", # accuracy - Fails on A100 + "eca_halonext26ts", # accuracy + "fbnetv3_b", # accuracy + "levit_128", # fp64_OOM + "res2net101_26w_4s", # accuracy + "spnasnet_100", # accuracy + "resnest101e", # accuracy + "swin_base_patch4_window7_224", # accuracy "xcit_large_24_p8_224", # fp64_OOM ] +def model_specified_by_path(path_and_class_str): + return ":" in path_and_class_str + + +def load_model_from_path(path_and_class_str): + configs = {} + for kvstr in path_and_class_str.split(","): + k, v = kvstr.split(":") + configs[k] = v + + for name in ["path", "class"]: + if name not in configs: + raise RuntimeError( + "Invalid --only arguments. Check help message for the correct format" + ) + + path = configs["path"] + class_name = configs["class"] + + if path[:1] != "/": + raise RuntimeError( + "Use absolute path since dynamo may change the current working directory which makes using relative path tricky" + ) + + spec = importlib.util.spec_from_file_location("module_name", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + model_class = getattr(module, class_name) + assert issubclass(model_class, torch.nn.Module) + model = model_class() + assert hasattr(model, "get_example_inputs") + inputs = model.get_example_inputs() + return model, inputs + + def output_csv(filename, headers, row): assert filename existed = os.path.exists(filename) @@ -182,6 +198,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass +def nothing(f): + return f + + @functools.lru_cache(None) def patch_torch_manual_seed(): """Make torch manual seed deterministic. Helps with accuracy testing.""" @@ -230,13 +250,36 @@ def print_summary(filename): pass +def tensor_is_on_xla(tensors): + if not isinstance(tensors, (tuple, list)): + tensors = [tensors] + tensors = [x for x in tensors if isinstance(x, torch.Tensor)] + return any(map(lambda x: x.device.type == "xla", tensors)) + + def timed(model, model_iter_fn, example_inputs, times=1, return_result=False): synchronize() + if tensor_is_on_xla(example_inputs): + import torch_xla.core.xla_model as xm + + xm.mark_step() + reset_rng_state() t0 = time.perf_counter() # Dont collect outputs to correctly measure timing for _ in range(times): result = model_iter_fn(model, example_inputs, collect_outputs=False) + if tensor_is_on_xla(result): + # If the model is on XLA device, it's possible that after running + # the model, the computation is accumulated but not performed yet. + # Flush all the accumulated computations to make the time measurement + # accurate. + import torch_xla + + result_list = result + if not isinstance(result, (tuple, list)): + result_list = [result] + torch_xla._XLAC._xla_sync_multi(result_list, []) synchronize() t1 = time.perf_counter() return (t1 - t0, result) if return_result else t1 - t0 @@ -348,82 +391,11 @@ def randomize_input(inputs): ) -def cold_start_experiment(args, model_iter_fn, model, example_inputs, optimize_ctx): - compile_iters = 2 - total_iters = compile_iters + 2 - timings = np.zeros((total_iters, 2), np.float64) - # if we randomize the input, we should also check the result is correct - should_check_result = should_randomize_input = args.randomize_input - is_correct = True +def maybe_mark_step(args): + if args.trace_on_xla: + import torch_xla.core.xla_model as xm - optimized_model_iter_fn = optimize_ctx(model_iter_fn) - for rep in range(total_iters): - inputs = ( - randomize_input(copy.deepcopy(example_inputs)) - if should_randomize_input - else example_inputs - ) - - # interleave the runs to handle frequency scaling and load changes - timings[rep, 0], expected_output = timed( - model, model_iter_fn, inputs, return_result=True - ) - timings[rep, 1], actual_output = timed( - model, optimized_model_iter_fn, inputs, return_result=True - ) - if should_check_result: - is_correct = is_correct and same(expected_output, actual_output) - pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue - worst = np.max(timings, axis=0) - - def breakeven(dynamo_times, eager_times): - """ - Solve for the number of iterations it takes dynamo to 'catch up' with eager, - taking into account the time it spent compiling. Assumes all compilation - happens up front and the model is static thereafter, which is definitely not - true in general but might be across torchbench. - - dc1, dc2 = dynamo compilation iterations (with Prof Exec) - d, e = dynamo, eager warmed up iteration - B = num iters to break even - dc1 + dc2 + (B-2)d = B*e - B = (dc1 + dc2 - 2d) / (e - d) - """ - dc1, dc2, d = dynamo_times[0], dynamo_times[1], np.median(dynamo_times[2:]) - e = np.median(eager_times) - if d < e: - return (dc1 + dc2 + 2 * d) / (e - d) - else: - # if optimized dynamo is not faster than eager we'll compute - # a nonsense negative number - return 0 - - speedup = worst[0] / worst[1] - eager_times, dynamo_times = timings[:, 0], timings[:, 1] - output_csv( - output_filename, - ("dev", "name", "batch_size", "cold-start speedup", "breakeven iters"), - [ - current_device, - current_name, - current_batch_size, - float(speedup), - breakeven(dynamo_times, eager_times), - ], - ) - - def format_speedup( - speedup, pvalue, breakeven_iters, is_correct=True, pvalue_threshold=0.1 - ): - if not is_correct: - return "ERROR" - if pvalue > pvalue_threshold: - return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters SAME" - return f"{speedup:.3f}x breakeven={breakeven_iters:.2f} iters p={pvalue:.2f}" - - return format_speedup( - speedup, pvalue, breakeven(dynamo_times, eager_times), is_correct=is_correct - ) + xm.mark_step() def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): @@ -450,6 +422,16 @@ def maybe_profile(*args, **kwargs): else: yield + @contextlib.contextmanager + def maybe_mark_profile(*args, **kwargs): + prof: torch.profiler.profile = kwargs.pop("p", None) + mark = kwargs.pop("mark", None) + if prof: + with torch.profiler.record_function(mark): + yield + else: + yield + with maybe_profile(enabled=args.export_profiler_trace) as p: frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) for rep in range(args.repeat): @@ -458,16 +440,28 @@ def maybe_profile(*args, **kwargs): if should_randomize_input else example_inputs ) + # need call mark_step to perform the computation + # on randomize_input. Otherwise the first call using the + # inputs will incur high penalty then the next one. + maybe_mark_step(args) # interleave the runs to handle frequency scaling and load changes - timings[rep, 0], expected_output = timed( - model, model_iter_fn, inputs, return_result=True - ) - timings[rep, 1], actual_output = timed( - model, frozen_model_iter_fn, inputs, return_result=True - ) + with maybe_mark_profile(p=p, mark="expected"): + timings[rep, 0], expected_output = timed( + model, model_iter_fn, inputs, return_result=True + ) + + # call mark_step between the 2 calls to make the comparison fair. + maybe_mark_step(args) + + with maybe_mark_profile(p=p, mark="actual"): + timings[rep, 1], actual_output = timed( + model, frozen_model_iter_fn, inputs, return_result=True + ) + if should_check_result: is_correct = is_correct and same(expected_output, actual_output) + if args.export_profiler_trace: name = args.profiler_trace_name + "_" + model.name + ".json" name = os.path.join(torch._dynamo.config.base_dir, name) @@ -481,8 +475,14 @@ def maybe_profile(*args, **kwargs): timings, ) - headers = ("dev", "name", "batch_size", "speedup") - row = [current_device, current_name, current_batch_size, float(speedup)] + headers = ("dev", "name", "batch_size", "speedup", "abs_latency") + row = [ + current_device, + current_name, + current_batch_size, + float(speedup), + median[1] * 1000, + ] if "compilation_latency" in kwargs: headers = headers + ("compilation_latency", "compression_ratio") row.append(kwargs["compilation_latency"]) @@ -665,66 +665,6 @@ def try_script(model, example_inputs): return None -def speedup_experiment_ts(args, model_iter_fn, model, example_inputs): - """ - Measure baseline performance (without using TorchDynamo) of TorchScript and optimize_for_inference. - - Writes to ./baseline_ts.csv - """ - if args.training: - return baselines( - [ - ("eager", model), - ("ts", try_script(model, example_inputs)), - ], - model_iter_fn, - example_inputs, - args, - ) - - return baselines( - [ - ("eager", model), - ("ts", try_script(model, example_inputs)), - ( - "ofi", - backends.ofi(try_script(model, example_inputs), example_inputs), - ), - # ("nnc", backends.nnc(try_script(model, example_inputs), example_inputs)), - # ("nvfuser", backends.nvfuser(try_script(model, example_inputs), example_inputs)), - ], - model_iter_fn, - example_inputs, - args, - ) - - -def speedup_experiment_sr(args, model_iter_fn, model, example_inputs): - """ - Measure baseline performance (without using TorchDynamo) of static runtime. - - Writes to ./baseline_sr.csv - """ - - if current_name not in ("opacus_cifar10", "timm_nfnet", "hf_T5"): - sr = backends.static_runtime(try_script(model, example_inputs), example_inputs) - else: - # segfaults on these models - sr = None - return baselines( - [ - ("eager", model), - ( - "sr", - sr, - ), - ], - model_iter_fn, - example_inputs, - args, - ) - - def speedup_experiment_onnx(args, model_iter_fn, model, example_inputs): """ Measure baseline performance (without using TorchDynamo) of ONNXRT and TensorFlow. @@ -891,19 +831,19 @@ def scale(self, loss): return loss -def maybe_fresh_cache(fn): - def inner(self, *args, **kwargs): +def maybe_fresh_cache(fn, is_cold_start): + def inner(*args, **kwargs): cache_minder = NullContext() - if self.args.cold_start_latency: + if is_cold_start: cache_entries = {} - cache_minder = fresh_triton_cache(cache_entries) + cache_minder = fresh_inductor_cache(cache_entries) try: with cache_minder: - return fn(self, *args, **kwargs) + return fn(*args, **kwargs) finally: dump_cache = False - if dump_cache and self.args.cold_start_latency: + if dump_cache and is_cold_start: output_csv( output_filename[:-4] + "_triton_cache.csv", ["dev", "name", "batch_size", "triton_cache"], @@ -918,6 +858,24 @@ def inner(self, *args, **kwargs): return inner +@contextmanager +def maybe_init_distributed(should_init_distributed, port="6789", rank=0, world_size=1): + # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, + # Just manually implement the most important part of the dynamo behavior to reset/clear. + try: + if should_init_distributed: + torch.cuda.set_device(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = port + torch.distributed.init_process_group( + "nccl", rank=rank, world_size=world_size + ) + yield + finally: + if should_init_distributed: + torch.distributed.destroy_process_group() + + class BenchmarkRunner: def __init__(self): self.model_iter_fn = None @@ -942,16 +900,26 @@ def setup_amp(self): # Since we are not running a long iteration, default value of # init_scale 65536 is going to turn all gradients to inf. Therefore, # we just use a init_scale of 2.0 for benchmarking purpose. - self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0) + + # Disabling Gradscaler because + # 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful. + # 2) Current setup shares grad_scaler for eager and dynamo model, + # which is bad as Gradscaler has state and can adjust the scaling + # factor between eager and dynamo run, making accuracy check + # harder. + # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0) self.autocast = torch.cuda.amp.autocast def init_optimizer(self, device, params): - param_list = list(params) - if device == "cuda" and len(param_list) != 0: - # capturable is only supported on cuda at the moment - self.optimizer = torch.optim.Adam(param_list, capturable=True) - else: - self.optimizer = None + self.optimizer = None + # TODO - Currently, optimizers are used incorrectly. Fix optimizers with + # https://github.com/pytorch/pytorch/pull/87492 + # param_list = list(params) + # if device == "cuda" and len(param_list) != 0: + # # capturable is only supported on cuda at the moment + # self.optimizer = torch.optim.Adam(param_list, capturable=True) + # else: + # self.optimizer = None @property def args(self): @@ -1033,8 +1001,8 @@ def validate_model(self, model, example_inputs): try: self.model_iter_fn(model, example_inputs) - except Exception: - raise NotImplementedError("Eager model failed to run") + except Exception as e: + raise NotImplementedError("Eager model failed to run") from e def maybe_cast(self, model, example_inputs): model = copy.deepcopy(model) @@ -1053,7 +1021,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2): out_batch_size = batch_size - 1 return max(0, int(out_batch_size)) - def batch_size_finder(self, device, model_name, initial_batch_size=128): + def batch_size_finder(self, device, model_name, initial_batch_size=1024): batch_size = initial_batch_size while batch_size >= 1: torch.cuda.empty_cache() @@ -1077,9 +1045,11 @@ def run_n_iterations(self, mod, inputs, n=2): self.model_iter_fn(mod, inputs, collect_outputs=False) return self.model_iter_fn(mod, inputs, collect_outputs=True) - def optimizer_zero_grad(self): + def optimizer_zero_grad(self, mod): if self.optimizer is not None: self.optimizer.zero_grad(True) + else: + mod.zero_grad(True) def optimizer_step(self): if self.optimizer is not None: @@ -1116,44 +1086,56 @@ def record_status(accuracy_status): ) return "PASS" if accuracy_status in ("pass", "pass_due_to_skip") else "FAIL" - tolerance, cos_similarity = self.get_tolerance_and_cosine_flag( - self.args.training, current_device, name - ) - if name in self.skip_accuracy_checks_large_models_dashboard: return record_status("pass_due_to_skip") + def deepcopy_and_maybe_ddp(model): + model = copy.deepcopy(model) + if self.args.ddp: + model = DDP(model, find_unused_parameters=True) + elif self.args.fsdp: + model = FSDP(model, use_orig_params=True) + torch._inductor.config.triton.cudagraphs = False + log.warn("Disabling cudagraphs for FSDP compatibility") + return model + # Collect the fp64 reference outputs to be used later for accuracy checking. fp64_outputs = None try: fp64_outputs = self.run_n_iterations( *cast_to_fp64( - copy.deepcopy(model), + deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs), ) ) except Exception: - log.warning(f"fp64 golden ref were not generated for {name}") + log.warning( + f"fp64 golden ref were not generated for {name}. Setting accuracy check to cosine" + ) + self.args.cosine = True fp64_outputs = None if self.args.ci and self.args.training: return record_status("fp64_OOM") + tolerance, cos_similarity = self.get_tolerance_and_cosine_flag( + self.args.training, current_device, name + ) + # Cast the model to float16/float32 as necessary model, example_inputs = self.maybe_cast(model, example_inputs) - accuracy_status = "pass" with self.pick_grad(name, self.args.training): # Get results of native pytorch reset_rng_state() correct_result = self.run_n_iterations( - copy.deepcopy(model), clone_inputs(example_inputs) + deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs) ) # Rerun native pytorch reset_rng_state() correct_rerun_result = self.run_n_iterations( - copy.deepcopy(model), clone_inputs(example_inputs) + deepcopy_and_maybe_ddp(model), clone_inputs(example_inputs) ) if not same( correct_result, @@ -1170,7 +1152,10 @@ def record_status(accuracy_status): torch._dynamo.reset() try: optimized_model_iter_fn = optimize_ctx(self.run_n_iterations) - new_result = optimized_model_iter_fn(model, example_inputs) + + new_result = optimized_model_iter_fn( + deepcopy_and_maybe_ddp(model), example_inputs + ) except Exception as e: accuracy_status = "fail_to_run" print( @@ -1232,7 +1217,9 @@ def warmup(fn, model, example_inputs, mode, niters=5): ) compilation_time = dynamo_latency - eager_latency - compression_ratio = eager_peak_mem / dynamo_peak_mem + compression_ratio = ( + eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0 + ) # print( # f"memory: eager: {eager_peak_mem:.2f} GB, " # f"dynamo: {dynamo_peak_mem:.2f} GB, " @@ -1321,7 +1308,6 @@ def compare_branches( "--diff_main called on main branch, what are you diffing?" ) - @maybe_fresh_cache def run_one_model( self, name, @@ -1331,6 +1317,7 @@ def run_one_model( experiment, diff=False, branch=None, + explain=False, ): if diff: self.compare_branches( @@ -1340,6 +1327,8 @@ def run_one_model( print("RUNNING ON BRANCH:", branch) mode = "train" if self.args.training else "eval" print(f"{current_device:4} {mode:5} {current_name:34} ", end="", flush=True) + start_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"] + start_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"] if self.args.accuracy: status = self.check_accuracy( name, model, example_inputs, optimize_ctx, experiment @@ -1350,14 +1339,20 @@ def run_one_model( name, model, example_inputs, optimize_ctx, experiment ) print(status) + end_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"] + end_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"] + if explain: + print( + f"Dynamo produced {end_unique_graphs-start_unique_graphs} graph(s) " + f"covering {end_calls_captured-start_calls_captured} ops" + ) def help(fn): return fn.__doc__ -def parse_args(): - +def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument( "--filter", "-k", action="append", help="filter benchmarks with regexp" @@ -1378,7 +1373,10 @@ def parse_args(): default=0, help="ID of the benchmark suite partition to be run. Used to divide CI tasks", ) - parser.add_argument("--devices", "-d", action="append", help="cpu or cuda") + parser.add_argument( + "--devices", "--device", "-d", action="append", help="cpu or cuda" + ) + parser.add_argument("--device-index", help="CUDA device index") parser.add_argument( "--repeat", "-n", type=int, default=30, help="number of timing runs" ) @@ -1434,12 +1432,58 @@ def parse_args(): parser.add_argument( "--fast", "-f", action="store_true", help="skip slow benchmarks" ) - parser.add_argument("--only", help="Run just one model") + parser.add_argument( + "--only", + help="""Run just one model from torchbench. Or + specify the path and class name of the model in format like: + --only=path:,class: + + Due to the fact that dynamo changes current working directory, + the path should be an absolute path. + + The class should have a method get_example_inputs to return the inputs + for the model. An example looks like + ``` + class LinearModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + def get_example_inputs(self): + return (torch.randn(2, 10),) + ``` + """, + ) parser.add_argument( "--training", action="store_true", help="Performs training", ) + parser.add_argument( + "--ddp", + action="store_true", + help="Wraps model in DDP before running it, and uses dynamo DDPOptmizer (graph breaks) by default.", + ) + parser.add_argument( + "--fsdp", + action="store_true", + help="""Wraps model in FSDP before running it. Disables cudagraphs by default. + Doesn't recursively wrap, mainly useful for checking dynamo UnspecNNModule compatibility + """, + ) + parser.add_argument( + "--no-optimize-ddp", + action="store_true", + help="Disables dynamo DDPOptimizer (graph breaks). (Applies only when using --ddp benchmark mode).", + ) + parser.add_argument( + "--distributed-master-port", + default="6789", + help="Port to bind for for torch.distributed. Use the default unless it's conflicting with another user", + ) parser.add_argument( "--dynamic-shapes", action="store_true", @@ -1466,24 +1510,22 @@ def parse_args(): help="Use same settings as --inductor for baseline comparisons", ) parser.add_argument( - "--raise-on-assertion-error", + "--suppress-errors", action="store_true", - help="Fail a benchmark if torch._dynamo triggers an internal assertion", + help="Suppress errors instead of raising them", ) parser.add_argument( - "--raise-on-backend-error", - action="store_true", - help="Fail a benchmark if backend throws an exception", + "--output", + help="Overrides the output filename", ) parser.add_argument( - "--raise-on-any", - "--raise", - action="store_true", - help="Raise on assertion or backend errors", + "--output-directory", + help="Overrides the directory to place output files.", ) parser.add_argument( - "--output", - help="Overrides the output filename", + "--part", + default=None, + help="Specify the part of the model to run.", ) parser.add_argument( "--export-profiler-trace", @@ -1498,11 +1540,27 @@ def parse_args(): help="Delta this branch against main. In the future, we may add support for picking the branch.", ) + parser.add_argument( + "--explain", + action="store_true", + help="print some graph/op statistics during the run, similar to .explain()", + ) + parser.add_argument( "--cold_start_latency", action="store_true", help="Use a fresh triton cachedir when running each model, to force cold-start compile.", ) + parser.add_argument( + "--disable-cudagraphs", + action="store_true", + help="Disables cudagraphs for Inductor", + ) + parser.add_argument( + "--trace-on-xla", + action="store_true", + help="Whether to trace the model on XLA or on eager device", + ) group_fuser = parser.add_mutually_exclusive_group() # --nvfuser is now the default, keep the option to not break scripts @@ -1528,28 +1586,9 @@ def parse_args(): group.add_argument( "--coverage", action="store_true", help="(default) " + help(coverage_experiment) ) - group.add_argument( - "--speedup-ltc", - action="store_true", - help="speedup using the ltc backend", - ) - group.add_argument( - "--speedup-ltc-trivial", - action="store_true", - help="speedup using the ltc backend without reusing compiled graph", - ) - group.add_argument( - "--cold-start", action="store_true", help=help(cold_start_experiment) - ) group.add_argument( "--overhead", action="store_true", help=help(overhead_experiment) ) - group.add_argument( - "--speedup-ts", action="store_true", help=help(speedup_experiment_ts) - ) - group.add_argument( - "--speedup-sr", action="store_true", help=help(speedup_experiment_sr) - ) group.add_argument( "--speedup-onnx", action="store_true", help=help(speedup_experiment_onnx) ) @@ -1620,17 +1659,23 @@ def parse_args(): mode_group.add_argument( "--performance", action="store_true", help="Measures performance speedup" ) - args = parser.parse_args() - return args + return parser.parse_args(args) def main(runner, original_dir=None): args = parse_args() + with maybe_init_distributed( + (args.ddp or args.fsdp) and args.only, port=args.distributed_master_port + ): + return maybe_fresh_cache(run, args.cold_start_latency and args.only)( + runner, args, original_dir + ) + +def run(runner, args, original_dir=None): # Pass the parsed args object to benchmark runner object runner.args = args - # defaults args.filter = args.filter or [r"."] args.exclude = args.exclude or [r"^$"] @@ -1650,7 +1695,24 @@ def main(runner, original_dir=None): if args.training else CI_SKIP_INDCUTOR_INFERENCE ) - + if args.ddp: + # TODO: we could also hook DDP bench up to --speedup bench, _not_ for mgpu e2e perf, + # but just to measure impact on singlenode of performing graph-breaks. + # Left it as a follow up to keep this PR isolated. + assert ( + args.accuracy + ), "DDP benchmark is currently only hooked up to --accuracy bench" + assert args.training, "DDP benchmark requires --training mode" + if args.no_optimize_ddp: + torch._dynamo.config.optimize_ddp = False + else: + # TODO(whc) after enabling DDPOptimizer by default this could be removed or assert + torch._dynamo.config.optimize_ddp = True + if args.only == "dlrm": + log.error( + "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP" + ) + return sys.exit(-1) if args.accuracy: # Use small batch size. We use >1 batch size to ensure we test # batch_norm type of operators that work on batch dims. @@ -1658,21 +1720,31 @@ def main(runner, original_dir=None): if args.batch_size is None: if runner.suite_name == "huggingface": args.batch_size = 1 + elif runner.suite_name == "torchbench": + args.batch_size = 4 else: - args.batch_size = 2 + # Larger batch size of TIMM models to have stable batch_norm + assert runner.suite_name == "timm_models" + args.batch_size = 8 # Remove sources of randomness - args.use_eval_mode = True + if runner.suite_name != "timm_models": + # TODO - Using train mode for timm_models. Move to train mode for HF and Torchbench as well. + args.use_eval_mode = True + inductor_config.fallback_random = True # Remove randomeness when torch manual seed is called patch_torch_manual_seed() # Some models e.g. yolov3 assert batch size on n_gpus if "CUDA_VISIBLE_DEVICES" not in os.environ: - os.environ["CUDA_VISIBLE_DEVICES"] = "0" + args.device_index = "0" # Stricter check to disable fallbacks - args.raise_on_any = True + args.suppress_errors = False + + if args.device_index is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index elif args.performance: # Ensure that we test on real scenarios @@ -1736,12 +1808,7 @@ def main(runner, original_dir=None): if args.quiet: torch._dynamo.config.log_level = logging.ERROR - torch._dynamo.config.raise_on_assertion_error = ( - args.raise_on_assertion_error or args.raise_on_any - ) - torch._dynamo.config.raise_on_backend_error = ( - args.raise_on_backend_error or args.raise_on_any - ) + torch._dynamo.config.suppress_errors = args.suppress_errors if args.training: runner.model_iter_fn = runner.forward_and_backward_pass @@ -1785,17 +1852,7 @@ def main(runner, original_dir=None): optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython) experiment = speedup_experiment output_filename = "overheads.csv" - elif args.cold_start: - optimize_ctx = torch._dynamo.optimize("aot_nvfuser", nopython=args.nopython) - experiment = cold_start_experiment - assert args.nvfuser, "TODO - Add another aot string for mem fusion with NNC" - backend_str = "nvfuser" if args.nvfuser else "nnc" - output_filename = f"cold_start_{backend_str}.csv" - # TODO(whc) should we move this to a more general part of the script? - torch.backends.cuda.matmul.allow_tf32 = True elif args.inductor or args.inductor_dynamic: - from torch._inductor import config as inductor_config - inductor_config.debug = args.verbose if args.threads: inductor_config.cpp.threads = args.threads @@ -1812,24 +1869,6 @@ def main(runner, original_dir=None): optimize_ctx = torch._dynamo.optimize("inductor", nopython=args.nopython) experiment = speedup_experiment output_filename = "inductor.csv" - elif args.speedup_ltc: - optimize_ctx = torch._dynamo.optimize( - backends.ltc_reuse_graph, nopython=args.nopython - ) - experiment = speedup_experiment - output_filename = "speedups_ltc.csv" - elif args.speedup_ltc_trivial: - optimize_ctx = torch._dynamo.optimize( - backends.ltc_trivial, nopython=args.nopython - ) - experiment = speedup_experiment - output_filename = "speedups_ltc_trivial.csv" - elif args.speedup_ts: - experiment = speedup_experiment_ts - output_filename = "baseline_ts.csv" - elif args.speedup_sr: - experiment = speedup_experiment_sr - output_filename = "baseline_sr.csv" elif args.speedup_onnx: experiment = speedup_experiment_onnx output_filename = "baseline_onnx.csv" @@ -1875,7 +1914,8 @@ def main(runner, original_dir=None): nopython=args.nopython, ) elif args.nothing: - pass + optimize_ctx = nothing + output_filename = "nothing.csv" elif args.backend: optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) experiment = speedup_experiment @@ -1898,13 +1938,22 @@ def main(runner, original_dir=None): experiment = coverage_experiment output_filename = "coverage.csv" + if args.inductor or args.backend == "inductor": + if args.disable_cudagraphs: + inductor_config.triton.cudagraphs = False + runner.setup_amp() if args.output: output_filename = args.output if output_filename: - output_filename = os.path.join(torch._dynamo.config.base_dir, output_filename) + if args.output_directory: + output_filename = os.path.join(args.output_directory, output_filename) + else: + output_filename = os.path.join( + torch._dynamo.config.base_dir, output_filename + ) if args.find_batch_sizes and args.only: for device in args.devices: @@ -1934,19 +1983,47 @@ def main(runner, original_dir=None): batch_size = read_batch_size_from_file( args, args.batch_size_file, model_name ) - try: - device, name, model, example_inputs, batch_size = runner.load_model( - device, - model_name, - batch_size=batch_size, - ) - except NotImplementedError as e: - print(e) - import traceback + if model_specified_by_path(args.only): + model, example_inputs = load_model_from_path(args.only) + name = model.__class__.__name__ + model = model.to(device=device) + example_inputs = tree_map(lambda x: x.to(device=device), example_inputs) + else: + try: + if args.part: + ( + device, + name, + model, + example_inputs, + batch_size, + ) = runner.load_model( + device, model_name, batch_size=batch_size, part=args.part + ) + else: + ( + device, + name, + model, + example_inputs, + batch_size, + ) = runner.load_model(device, model_name, batch_size=batch_size) + except NotImplementedError as e: + print(e) + import traceback + + print(traceback.format_exc()) + logging.warn(f"{args.only} failed to load") + continue # bad benchmark implementation + + if args.trace_on_xla: + import torch_xla.core.xla_model as xm - print(traceback.format_exc()) - logging.warn(f"{args.only} failed to load") - continue # bad benchmark implementation + xla_dev = xm.xla_device() + model = model.to(device=xla_dev) + example_inputs = tree_map( + lambda x: x.to(device=xla_dev), example_inputs + ) current_name = name current_device = device @@ -1971,6 +2048,7 @@ def main(runner, original_dir=None): optimize_ctx, experiment, diff=args.diff_main, + explain=args.explain, ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py new file mode 100644 index 0000000000000..24625c84e1a10 --- /dev/null +++ b/benchmarks/dynamo/dist_util.py @@ -0,0 +1,148 @@ +import argparse +import functools +import importlib +import os + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._dynamo.testing import reduce_to_scalar_loss +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, + checkpoint_wrapper, + CheckpointImpl, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import ModuleWrapPolicy + +try: + from .torchbench import setup_torchbench_cwd +except ImportError: + from torchbench import setup_torchbench_cwd + +from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead +from transformers.models.t5.modeling_t5 import T5Block + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355") + os.environ["RANK"] = os.getenv("RANK", "0") + os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1") + dist.init_process_group("nccl") + + +def cleanup(): + dist.destroy_process_group() + + +class CustomLinear(torch.nn.Module): + def __init__(self, a, b): + super(CustomLinear, self).__init__() + self.weight = nn.Parameter(torch.randn(a, b)) + + def forward(self, x): + return torch.mm(x, self.weight) + + +class MyModule(torch.nn.Module): + def __init__(self, a, b): + super(MyModule, self).__init__() + self.net = nn.Sequential( + nn.Linear(a, b), + nn.ReLU(), + ) + + def forward(self, x): + return self.net(x) + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net = nn.Sequential( + *[nn.Linear(10, 10000), nn.ReLU()] + + [nn.Linear(10000, 10000), nn.ReLU()] + + [MyModule(10000, 10000)] + + [MyModule(10000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [MyModule(1000, 1000)] + + [nn.Linear(1000, 5)] + ) + + def forward(self, x): + return self.net(x) + + +def model_iter_fn(model, example_inputs, collect_outputs=False): + outputs = model(*example_inputs) + loss = reduce_to_scalar_loss(outputs) + loss.backward() + if collect_outputs: + return outputs + + +def get_model(args): + if args.torchbench_model: + old_cwd = setup_torchbench_cwd() + module = importlib.import_module( + f"torchbenchmark.models.{args.torchbench_model}" + ) + benchmark_cls = getattr(module, "Model", None) + bm = benchmark_cls( + test="train", device=args.device, jit=False, batch_size=args.batch_size + ) + model, inputs = bm.get_module() + elif args.toy_model: + model = ToyModel() + inputs = (torch.randn(20, 10),) + else: + raise argparse.ArgumentError( + args.torchbench_model, message="Must specify a model" + ) + + return model, inputs + + +def fsdp_checkpointing_base(model, blocks): + """apply activation checkpointing to model + returns None as model is updated directly + """ + non_reentrant_wrapper = functools.partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + + def check_fn(submodule): + return isinstance(submodule, blocks) + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) + + +MODEL_FSDP_WRAP = { + "toy_model": (MyModule,), + "hf_Bert": (BertLayer, BertLMPredictionHead), + "hf_T5": (T5Block,), +} + + +def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True): + wrap_policy = None + blocks = MODEL_FSDP_WRAP[ + "toy_model" if model.__class__ is ToyModel else args.torchbench_model + ] + if use_wrap_policy: + wrap_policy = ModuleWrapPolicy(blocks) + + model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True) + if use_checkpointing: + fsdp_checkpointing_base(model, blocks) + return model diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py new file mode 100644 index 0000000000000..b490c48ade90e --- /dev/null +++ b/benchmarks/dynamo/distributed.py @@ -0,0 +1,169 @@ +import argparse +import logging +import os +from functools import partial + +import torch +import torch._dynamo as dynamo +import torch.utils._pytree as pytree +from torch._dynamo.testing import reduce_to_scalar_loss +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.profiler import profile, ProfilerActivity, record_function + +try: + from .common import timed + from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup +except ImportError: + from common import timed + from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup + +log = logging.getLogger(__name__) + + +def torchviz_model(args, model, inputs, rank): + from torchviz import make_dot + + outputs = model(*inputs) + loss = reduce_to_scalar_loss(outputs) + parameter_names = dict(model.named_parameters()) + dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True) + if rank == 0: + dot.render("torchviz.dot") + + +def profile_model(args, model, inputs, rank): + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + for i in range(args.repeat): + with record_function("Forward"): + outputs = model(*inputs) + loss = reduce_to_scalar_loss(outputs) + with record_function("Backward"): + loss.backward() + if rank == 0: + prof.export_chrome_trace(args.trace_file) + + +def run_model(args, model, inputs, key): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + # result_q = [] + + setup(rank, world_size) + if args.device == "cuda": + # needed for FSDP + torch.cuda.set_device(rank) + + dev_rank = f"{args.device}:{rank}" + model = model.to(dev_rank) + + def move_tensor(maybe_tensor): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.to(dev_rank) + return maybe_tensor + + inputs = pytree.tree_map(move_tensor, inputs) + + if args.fsdp: + model = apply_fsdp( + args, + model, + use_checkpointing=args.fsdp_checkpoint, + use_wrap_policy=args.fsdp_wrap, + ) + elif args.ddp: + model = DDP(model) + + if args.verbose: + print(model) + + if args.dynamo: + dynamo.reset() + if args.verbose: + dynamo.config.verbose = True + dynamo.config.log_level = logging.DEBUG + if args.dynamo_no_optimize_ddp: + dynamo.config.optimize_ddp = False + if args.dynamo == "inductor" and args.fsdp: + torch._inductor.config.triton.cudagraphs = False + log.warn("disabling inductor cudagraphs for compatibility with FSDP") + + def print_compile(gm, ex): + print( + f"print_compile:\n{str(gm.graph)}\n-----------------------------------------" + ) + return gm + + dynamo_ctx = dynamo.optimize( + print_compile if args.dynamo == "print" else args.dynamo + ) + model = dynamo_ctx(model) + + # warmup + _ = timed(model, model_iter_fn, inputs, times=3, return_result=False) + t_total = timed( + model, model_iter_fn, inputs, times=args.repeat, return_result=False + ) + if args.torchviz: + torchviz_model(args, model, inputs, rank) + if args.profile: + profile_model(args, model, inputs, rank) + + cleanup() + return t_total + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--dynamo", + default=None, + help="if set to a str, uses dynamo[str] backend. else, eager", + ) + parser.add_argument("--verbose", action="store_true") + parser.add_argument("--batch_size", default=None) + parser.add_argument( + "--torchviz", action="store_true", help="Dump autograd graph with torchviz" + ) + parser.add_argument("--profile", action="store_true", help="Run the profiler") + parser.add_argument("--trace_file", default="profile.json", help="Run the profiler") + parser.add_argument("--repeat", default=10, help="Repeats for timing run") + parser.add_argument( + "--dynamo_no_optimize_ddp", + action="store_true", + help="Enable dynamo's ddp optimizer", + ) + parser.add_argument( + "--fsdp_checkpoint", + action="store_true", + help="whether to use gradient checkpointing via model-specific policy", + ) + parser.add_argument( + "--fsdp_wrap", + action="store_true", + help="whether to apply fsdp to submodules via model-specific policy", + ) + + dist_arg = parser.add_mutually_exclusive_group() + dist_arg.add_argument("--ddp", action="store_true") + dist_arg.add_argument("--fsdp", action="store_true") + + model_arg = parser.add_mutually_exclusive_group(required=True) + model_arg.add_argument( + "--torchbench_model", help="name of torchbench model, e.g. hf_Bert" + ) + model_arg.add_argument( + "--toy_model", action="store_true", help="use toy model instead" + ) + args = parser.parse_args() + + model_name = args.torchbench_model + if args.toy_model: + model_name = "ToyModel" + model, inputs = get_model(args) + + fn = partial(run_model, args, model, inputs) + + world_size = os.getenv("WORLD_SIZE", 1) + t_total = fn(f"{model_name}_{world_size}") + print(f"mean latency {t_total / args.repeat} across {args.repeat} runs") diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index b563c229529d3..84caea0d910ec 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -68,9 +68,6 @@ def pip_install(package): exec(f"from transformers import {cls}") -USE_HALF_BATCH_SIZE = True - - # These models contain the models present in huggingface_models_list. It is a # combination of models supported by HF Fx parser and some manually supplied # models. For these models, we already know the largest batch size that can fit @@ -92,46 +89,77 @@ def pip_install(package): SKIP = { - # Difficult to run and compare + # Difficult to setup accuracy test because .eval() not supported "Reformer", # Fails deepcopy - "BlenderbotForCausalLM", "BlenderbotForConditionalGeneration", - "GPTJForCausalLM", - "GPTJForQuestionAnswering", "GPTNeoForCausalLM", "GPTNeoForSequenceClassification", # Fails with even batch size = 1 - "DebertaV2ForMaskedLM", - "DebertaV2ForQuestionAnswering", + "GPTJForCausalLM", + "GPTJForQuestionAnswering", } # TODO - Fails even after fake tensors -USE_SMALL_BATCH_SIZE = { +BATCH_SIZE_DIVISORS = { "AlbertForMaskedLM": 2, - "AlbertForPreTraining": 4, "AlbertForQuestionAnswering": 2, + "AllenaiLongformerBase": 2, "BartForCausalLM": 2, - "BartForConditionalGeneration": 1, - "BlenderbotSmallForConditionalGeneration": 32, - "DebertaForMaskedLM": 4, + "BartForConditionalGeneration": 2, + "BertForMaskedLM": 2, + "BertForQuestionAnswering": 2, + "BlenderbotForCausalLM": 8, + # "BlenderbotForConditionalGeneration" : 16, + "BlenderbotSmallForCausalLM": 4, + "BlenderbotSmallForConditionalGeneration": 2, + "CamemBert": 2, + "DebertaForMaskedLM": 8, "DebertaForQuestionAnswering": 4, - "DebertaV2ForMaskedLM": 1, - "DebertaV2ForQuestionAnswering": 1, - "DistilBertForMaskedLM": 16, - "ElectraForCausalLM": 1, - "GPTNeoForCausalLM": 1, - "GPTNeoForSequenceClassification": 1, - "M2M100ForConditionalGeneration": 2, + "DebertaV2ForMaskedLM": 8, + "DebertaV2ForQuestionAnswering": 4, + "DistilBertForMaskedLM": 2, + "DistilBertForQuestionAnswering": 2, + "DistillGPT2": 2, + "ElectraForCausalLM": 2, + "ElectraForQuestionAnswering": 2, + "GPT2ForSequenceClassification": 2, + # "GPTJForCausalLM" : 2, + # "GPTJForQuestionAnswering" : 2, + # "GPTNeoForCausalLM" : 32, + # "GPTNeoForSequenceClassification" : 2, + "GoogleFnet": 2, + "LayoutLMForMaskedLM": 2, + "LayoutLMForSequenceClassification": 2, + "M2M100ForConditionalGeneration": 4, + "MBartForCausalLM": 2, + "MBartForConditionalGeneration": 2, "MT5ForConditionalGeneration": 2, - "MegatronBertForCausalLM": 2, - "OPTForCausalLM": 4, - "PegasusForCausalLM": 8, - "PegasusForConditionalGeneration": 4, - "RobertaForCausalLM": 4, - "TrOCRForCausalLM": 8, - "XGLMForCausalLM": 1, - "XLNetLMHeadModel": 4, + "MegatronBertForCausalLM": 4, + "MegatronBertForQuestionAnswering": 2, + "MobileBertForMaskedLM": 4, + "MobileBertForQuestionAnswering": 2, + "OPTForCausalLM": 2, + "PLBartForCausalLM": 2, + "PLBartForConditionalGeneration": 2, + "PegasusForCausalLM": 4, + "PegasusForConditionalGeneration": 2, + "RobertaForCausalLM": 2, + "RobertaForQuestionAnswering": 2, + "Speech2Text2ForCausalLM": 4, + "T5ForConditionalGeneration": 2, + "T5Small": 2, + "TrOCRForCausalLM": 2, + "XGLMForCausalLM": 4, + "XLNetLMHeadModel": 2, + "YituTechConvBert": 2, +} + +SKIP_ACCURACY_CHECK_MODELS = { + # Models too large to have eager, dynamo and fp64_numbers simultaneosuly + # even for 40 GB machine. + "DebertaV2ForMaskedLM", + "BlenderbotForCausalLM", } @@ -146,18 +174,33 @@ def get_module_cls_by_model_name(model_cls_name): def get_sequence_length(model_cls, model_name): - if model_name.startswith(("Bert", "Roberta", "Blenderbot")): + if model_name.startswith(("Blenderbot",)): seq_length = 128 - elif model_name.startswith(("GPT2", "Bart", "T5")): + elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): seq_length = 1024 elif model_name in ("AllenaiLongformerBase", "BigBird"): seq_length = 1024 + elif model_name.startswith("OPT"): + seq_length = 2048 elif "Reformer" in model_name: seq_length = 4096 elif model_name.startswith( - ("Albert", "Deberta", "Layout", "Electra", "XLNet") + ( + "Albert", + "Deberta", + "Layout", + "Electra", + "XLNet", + "MegatronBert", + "Bert", + "Roberta", + ) ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): seq_length = 512 + elif model_name in ("TrOCRForCausalLM"): + seq_length = 256 + elif model_name.startswith("MobileBert"): + seq_length = 128 else: log.warning( f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily" @@ -294,10 +337,10 @@ def rand_int_tensor(device, low, high, shape): AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM, ), - "BigBird": ( - BigBirdConfig(attention_type="block_sparse"), - AutoModelForMaskedLM, - ), + # "BigBird": ( + # BigBirdConfig(attention_type="block_sparse"), + # AutoModelForMaskedLM, + # ), "DistillGPT2": ( AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM, @@ -369,13 +412,8 @@ def load_model( if batch_size is None: batch_size = batch_size_default - if model_name in USE_SMALL_BATCH_SIZE: - batch_size = USE_SMALL_BATCH_SIZE[model_name] - log.warning( - f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" - ) - elif USE_HALF_BATCH_SIZE and batch_size >= 2: - batch_size = int(batch_size / 2) + if model_name in BATCH_SIZE_DIVISORS: + batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1) log.warning( f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" ) @@ -416,6 +454,12 @@ def iter_model_names(self, args): continue yield model_name + @property + def skip_accuracy_checks_large_models_dashboard(self): + if self.args.dashboard or self.args.accuracy: + return SKIP_ACCURACY_CHECK_MODELS + return set() + def pick_grad(self, name, is_training): if is_training: return torch.enable_grad() @@ -436,7 +480,7 @@ def forward_pass(self, mod, inputs, collect_outputs=True): def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): cloned_inputs = clone_inputs(inputs) - self.optimizer_zero_grad() + self.optimizer_zero_grad(mod) with self.autocast(): pred = mod(**cloned_inputs) loss = self.compute_loss(pred) @@ -473,10 +517,10 @@ def refresh_model_names_and_batch_sizes(): if model_cls in [ CLIPModel, CLIPVisionModel, - SwinForImageClassification, - SwinForImageClassification, - SwinForMaskedImageModeling, - SwinModel, + # SwinForImageClassification, + # SwinForImageClassification, + # SwinForMaskedImageModeling, + # SwinModel, ViTForImageClassification, ViTForMaskedImageModeling, ViTModel, diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 8272c79b12bda..6e3cf19a783d7 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -1,53 +1,51 @@ AlbertForMaskedLM,8 AlbertForQuestionAnswering,8 -AllenaiLongformerBase,1 -BartForCausalLM,16 +AllenaiLongformerBase,8 +BartForCausalLM,8 BartForConditionalGeneration,4 -BertForMaskedLM,128 -BertForQuestionAnswering,128 -BigBird,1 +BertForMaskedLM,32 +BertForQuestionAnswering,32 BlenderbotForCausalLM,32 -BlenderbotForConditionalGeneration,32 -BlenderbotSmallForCausalLM,128 +BlenderbotForConditionalGeneration,16 +BlenderbotSmallForCausalLM,256 BlenderbotSmallForConditionalGeneration,128 -CamemBert,1 +CamemBert,32 DebertaForMaskedLM,32 DebertaForQuestionAnswering,32 DebertaV2ForMaskedLM,8 DebertaV2ForQuestionAnswering,8 -DistilBertForMaskedLM,64 -DistilBertForQuestionAnswering,64 -DistillGPT2,1 +DistilBertForMaskedLM,256 +DistilBertForQuestionAnswering,512 +DistillGPT2,32 ElectraForCausalLM,64 ElectraForQuestionAnswering,128 GPT2ForSequenceClassification,8 GPTJForCausalLM,1 GPTJForQuestionAnswering,1 -GPTNeoForCausalLM,8 -GPTNeoForSequenceClassification,8 -GoogleFnet,1 +GPTNeoForCausalLM,32 +GPTNeoForSequenceClassification,32 +GoogleFnet,32 LayoutLMForMaskedLM,32 LayoutLMForSequenceClassification,32 -M2M100ForConditionalGeneration,8 -MBartForCausalLM,32 -MBartForConditionalGeneration,16 -MT5ForConditionalGeneration,8 +M2M100ForConditionalGeneration,64 +MBartForCausalLM,8 +MBartForConditionalGeneration,4 +MT5ForConditionalGeneration,32 MegatronBertForCausalLM,16 MegatronBertForQuestionAnswering,16 -MobileBertForMaskedLM,32 -MobileBertForQuestionAnswering,64 -OPTForCausalLM,32 -PLBartForCausalLM,32 -PLBartForConditionalGeneration,16 -PegasusForCausalLM,32 -PegasusForConditionalGeneration,16 -Reformer,1 -RobertaForCausalLM,128 -RobertaForQuestionAnswering,128 -Speech2Text2ForCausalLM,128 +MobileBertForMaskedLM,256 +MobileBertForQuestionAnswering,256 +OPTForCausalLM,4 +PLBartForCausalLM,16 +PLBartForConditionalGeneration,8 +PegasusForCausalLM,128 +PegasusForConditionalGeneration,64 +RobertaForCausalLM,32 +RobertaForQuestionAnswering,32 +Speech2Text2ForCausalLM,1024 T5ForConditionalGeneration,8 -T5Small,1 -TrOCRForCausalLM,32 -XGLMForCausalLM,8 -XLNetLMHeadModel,128 -YituTechConvBert,1 +T5Small,8 +TrOCRForCausalLM,64 +XGLMForCausalLM,32 +XLNetLMHeadModel,16 +YituTechConvBert,32 diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index ce952095bd352..3f45b55fd77ee 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -26,20 +26,26 @@ import argparse import dataclasses +import functools import glob import importlib import io import itertools import logging import os +import re import shutil import subprocess +import sys +import tempfile from collections import defaultdict -from datetime import datetime +from datetime import datetime, timedelta, timezone from os.path import abspath, exists from random import randint import matplotlib.pyplot as plt + +import numpy as np import pandas as pd import torch @@ -63,15 +69,17 @@ "eager": "--training --backend=eager ", "aot_eager": "--training --backend=aot_eager ", "aot_cudagraphs": "--training --backend=aot_cudagraphs ", - "aot_nvfuser": "--training --nvfuser --backend=aot_nvfuser ", + "aot_nvfuser": "--training --nvfuser --backend=aot_ts_nvfuser ", + "nvprims_nvfuser": "--training --backend=nvprims_nvfuser ", "inductor": "--training --inductor ", + "inductor_no_cudagraphs": "--training --inductor --disable-cudagraphs ", }, "inference": { "ts_nnc": "--speedup-ts", "ts_nvfuser": "-n100 --speedup-ts --nvfuser", "trt": "-n100 --speedup-trt", - "ts_nvfuser_cudagraphs": "--inductor-settings --float32 -n50 --backend=cudagraphs_ts", - "inductor": "--inductor-settings --float32 -n50 --inductor", + "ts_nvfuser_cudagraphs": "--backend=cudagraphs_ts", + "inductor": "-n50 --inductor", }, } @@ -82,11 +90,14 @@ "training": [ "eager", "aot_eager", - "aot_cudagraphs", - "aot_nvfuser", "inductor", + "inductor_no_cudagraphs", ], "inference": ["ts_nvfuser_cudagraphs", "inductor"], + "flag_compilers": { + "training": ["inductor", "inductor_no_cudagraphs"], + "inference": ["inductor"], + }, "dtypes": [ "float32", ], @@ -109,6 +120,30 @@ } +def flag_speedup(x): + return x < 0.95 + + +def flag_compilation_latency(x): + return x > 120 + + +def flag_compression_ratio(x): + return x < 0.9 + + +def flag_accuracy(x): + return "pass" not in x + + +FLAG_FNS = { + "speedup": flag_speedup, + "compilation_latency": flag_compilation_latency, + "compression_ratio": flag_compression_ratio, + "accuracy": flag_accuracy, +} + + def percentage(part, whole, decimals=2): if whole == 0: return 0 @@ -125,6 +160,12 @@ def parse_args(): action="append", help=f"For --inference, options are {INFERENCE_COMPILERS}. For --training, options are {TRAINING_COMPILERS}", ) + + parser.add_argument( + "--flag-compilers", + action="append", + help="List of compilers to flag issues. Same format as --compilers.", + ) parser.add_argument( "--quick", action="store_true", help="Just runs one model. Helps in debugging" ) @@ -179,6 +220,30 @@ def parse_args(): default=False, help="Updates to dashboard", ) + parser.add_argument( + "--no-graphs", + action="store_true", + default=False, + help="Do not genenerate and upload metric graphs", + ) + parser.add_argument( + "--no-update-archive", + action="store_true", + default=False, + help="Do not update lookup.csv or the log archive", + ) + parser.add_argument( + "--no-gh-comment", + action="store_true", + default=False, + help="Do not write a comment to github", + ) + parser.add_argument( + "--update-dashboard-test", + action="store_true", + default=False, + help="does all of --no-graphs, --no-update-lookup, and --no-gh-comment", + ) parser.add_argument( "--dashboard-image-uploader", default=DASHBOARD_DEFAULTS["dashboard_image_uploader"], @@ -189,6 +254,11 @@ def parse_args(): default=DASHBOARD_DEFAULTS["dashboard_archive_path"], help="Archived directory path", ) + parser.add_argument( + "--archive-name", + help="Directory name under dashboard-archive-path to copy output-dir to. " + "If not provided, a generated name is used.", + ) parser.add_argument( "--dashboard-gh-cli-path", default=DASHBOARD_DEFAULTS["dashboard_gh_cli_path"], @@ -223,6 +293,11 @@ def get_skip_tests(suite): return skip_str +def generate_csv_name(args, dtype, suite, device, compiler, testing): + mode = get_mode(args) + return f"{compiler}_{suite}_{dtype}_{mode}_{device}_{testing}.csv" + + def generate_commands(args, dtypes, suites, devices, compilers, output_dir): mode = get_mode(args) with open("run.sh", "w") as runfile: @@ -242,7 +317,7 @@ def generate_commands(args, dtypes, suites, devices, compilers, output_dir): info = TABLE[mode] for compiler in compilers: base_cmd = info[compiler] - output_filename = f"{output_dir}/{compiler}_{suite}_{dtype}_{mode}_{device}_{testing}.csv" + output_filename = f"{output_dir}/{generate_csv_name(args, dtype, suite, device, compiler, testing)}" cmd = f"python benchmarks/dynamo/{suite}.py --{testing} --{dtype} -d{device} --output={output_filename}" cmd = f"{cmd} {base_cmd} {args.extra_args} --no-skip --dashboard" @@ -256,7 +331,10 @@ def generate_commands(args, dtypes, suites, devices, compilers, output_dir): filters = DEFAULTS["quick"][suite] cmd = f"{cmd} {filters}" - if testing == "performance" and compiler == "inductor": + if testing == "performance" and compiler in ( + "inductor", + "inductor_no_cudagraphs", + ): cmd = f"{cmd} --cold_start_latency" lines.append(cmd) lines.append("") @@ -274,7 +352,7 @@ def generate_dropdown_comment(title, body): return str_io.getvalue() -def build_summary(): +def build_summary(args): import git out_io = io.StringIO() @@ -283,38 +361,45 @@ def print_commit_hash(path, name): if exists(path): repo = git.Repo(path, search_parent_directories=True) sha = repo.head.object.hexsha + date = repo.head.object.committed_datetime out_io.write(f"{name} commit: {sha}\n") + out_io.write(f"{name} commit date: {date}\n") else: out_io.write(f"{name} Absent\n") def env_var(name): out_io.write(f"{name} = {os.environ[name]}\n") - out_io.write("## Commit hashes ##\n") - print_commit_hash(".", "torch._dynamo") + out_io.write("\n") + out_io.write("### Run name ###\n") + out_io.write(get_archive_name(args, args.dtypes[0])) + out_io.write("\n") + + out_io.write("\n") + out_io.write("### Commit hashes ###\n") print_commit_hash("../pytorch", "pytorch") print_commit_hash("../functorch", "functorch") print_commit_hash("../torchbenchmark", "torchbench") out_io.write("\n") - out_io.write("## TorchDynamo config flags ##\n") + out_io.write("### TorchDynamo config flags ###\n") for key in dir(torch._dynamo.config): val = getattr(torch._dynamo.config, key) if not key.startswith("__") and isinstance(val, bool): out_io.write(f"torch._dynamo.config.{key} = {val}\n") out_io.write("\n") - out_io.write("## Torch version ##\n") + out_io.write("### Torch version ###\n") out_io.write(f"torch: {torch.__version__}\n") out_io.write("\n") - out_io.write("## Environment variables ##\n") + out_io.write("### Environment variables ###\n") env_var("TORCH_CUDA_ARCH_LIST") env_var("CUDA_HOME") env_var("USE_LLVM") out_io.write("\n") - out_io.write("## GPU details ##\n") + out_io.write("### GPU details ###\n") out_io.write(f"CUDNN VERSION: {torch.backends.cudnn.version()}\n") out_io.write(f"Number CUDA Devices: {torch.cuda.device_count()}\n") out_io.write(f"Device Name: {torch.cuda.get_device_name(0)}\n") @@ -328,12 +413,70 @@ def env_var(name): gh_fh.write(comment) +@functools.lru_cache(None) +def archive_data(archive_name): + if archive_name is not None: + prefix_match = re.search(r"\w+(?=_performance)", archive_name) + if prefix_match is not None: + prefix = prefix_match.group(0) + else: + prefix = "" + day_match = re.search(r"day_(\d+)_", archive_name) + if day_match is not None: + day = day_match.group(1) + else: + day = "000" + else: + now = datetime.now(tz=timezone(timedelta(hours=-8))) + day = now.strftime("%j") + prefix = now.strftime(f"day_{day}_%d_%m_%y") + return day, prefix + + +@functools.lru_cache(None) +def default_archive_name(dtype): + _, prefix = archive_data(None) + return f"{prefix}_performance_{dtype}_{randint(100, 999)}" + + +def get_archive_name(args, dtype): + return ( + default_archive_name(dtype) if args.archive_name is None else args.archive_name + ) + + +def archive(src_dir, dest_dir_prefix, archive_name, dtype): + if archive_name is None: + archive_name = default_archive_name(dtype) + # Copy the folder to archived location + dest = os.path.join(dest_dir_prefix, archive_name) + shutil.copytree(src_dir, dest, dirs_exist_ok=True) + print(f"copied contents of {src_dir} to {dest}") + + +def get_metric_title(metric): + if metric == "speedup": + return "Performance speedup" + elif metric == "accuracy": + return "Accuracy" + elif metric == "compilation_latency": + return "Compilation latency (sec)" + elif metric == "compression_ratio": + return "Peak Memory Compression Ratio" + elif metric == "abs_latency": + return "Absolute latency (ms)" + raise RuntimeError("unknown metric") + + class Parser: - def __init__(self, suites, devices, dtypes, compilers, mode, output_dir): + def __init__( + self, suites, devices, dtypes, compilers, flag_compilers, mode, output_dir + ): self.suites = suites self.devices = devices self.dtypes = dtypes self.compilers = compilers + self.flag_compilers = flag_compilers self.output_dir = output_dir self.mode = mode @@ -347,11 +490,20 @@ def has_header(self, output_filename): class ParsePerformanceLogs(Parser): - def __init__(self, suites, devices, dtypes, compilers, mode, output_dir): - super().__init__(suites, devices, dtypes, compilers, mode, output_dir) + def __init__( + self, suites, devices, dtypes, compilers, flag_compilers, mode, output_dir + ): + super().__init__( + suites, devices, dtypes, compilers, flag_compilers, mode, output_dir + ) self.parsed_frames = defaultdict(lambda: defaultdict(None)) self.untouched_parsed_frames = defaultdict(lambda: defaultdict(None)) - self.metrics = ["speedup", "compilation_latency", "compression_ratio"] + self.metrics = [ + "speedup", + "abs_latency", + "compilation_latency", + "compression_ratio", + ] self.bottom_k = 50 self.parse() @@ -384,6 +536,7 @@ def read_csv(self, output_filename): "name", "batch_size", "speedup", + "abs_latency", "compilation_latency", "compression_ratio", ], @@ -395,12 +548,6 @@ def parse(self): self.extract_df("accuracy", "accuracy") for metric in self.metrics: self.extract_df(metric, "performance") - self.generate_executive_summary() - for suite in self.suites: - self.plot_graph( - self.untouched_parsed_frames[suite]["speedup"], - f"{suite}_{self.dtypes[0]}", - ) def clean_batch_sizes(self, frames): # Clean up batch sizes when its 0 @@ -427,6 +574,8 @@ def extract_df(self, metric, testing): for compiler in self.compilers: output_filename = f"{self.output_dir}/{compiler}_{suite}_{dtype}_{self.mode}_{device}_{testing}.csv" df = self.read_csv(output_filename) + if metric not in df: + df.insert(len(df.columns), metric, np.nan) df = df[["dev", "name", "batch_size", metric]] df.rename(columns={metric: compiler}, inplace=True) df["batch_size"] = df["batch_size"].astype(int) @@ -446,26 +595,31 @@ def extract_df(self, metric, testing): df_copy = df_copy.sort_values( by=list(reversed(self.compilers)), ascending=False ) + if "inductor" in self.compilers: + df_copy = df_copy.sort_values(by="inductor", ascending=False) self.untouched_parsed_frames[suite][metric] = df_copy if testing == "performance": df_accuracy = self.parsed_frames[suite]["accuracy"] perf_rows = [] for model_name in df["name"]: - perf_row = df[df["name"] == model_name] + perf_row = df[df["name"] == model_name].copy() acc_row = df_accuracy[df_accuracy["name"] == model_name] for compiler in self.compilers: if not perf_row.empty: if acc_row.empty: - perf_row[compiler].iloc[0] = 0.0 + perf_row[compiler] = 0.0 elif acc_row[compiler].iloc[0] not in ( "pass", "pass_due_to_skip", ): - perf_row[compiler].iloc[0] = 0.0 + perf_row[compiler] = 0.0 perf_rows.append(perf_row) df = pd.concat(perf_rows) df = df.sort_values(by=list(reversed(self.compilers)), ascending=False) + + if "inductor" in self.compilers: + df = df.sort_values(by="inductor", ascending=False) self.parsed_frames[suite][metric] = df def get_passing_entries(self, compiler, df): @@ -581,6 +735,54 @@ def generate_executive_summary(self): str_io.write(peak_memory_summary) self.executive_summary = str_io.getvalue() + def flag_bad_entries(self, suite, metric, flag_fn): + df = self.untouched_parsed_frames[suite][metric] + df = df.drop("dev", axis=1) + df = df.rename(columns={"batch_size": "bs"}) + # apply flag_fn elementwise to flag_compilers columns, + # if one element fails, the entire row is flagged + flag = np.logical_or.reduce( + df[self.flag_compilers].applymap(flag_fn), + axis=1, + ) + df = df[flag] + df = df.assign(suite=suite) + return df.reindex(columns=["suite", "name"] + self.flag_compilers) + + def generate_warnings(self): + title = "## Warnings ##" + body = ( + "We flag models where:\n\n" + " - accuracy fails\n" + " - speedup < 0.95x (NOTE: 0.0 speedup typically signifies a failure in the performance test)\n" + " - compilation latency > 120 sec.\n" + " - compression ratio < 0.9\n" + "\n" + ) + for metric in [ + "accuracy", + "speedup", + "compilation_latency", + "compression_ratio", + ]: + dfs = [] + for suite in self.suites: + dfs.append(self.flag_bad_entries(suite, metric, FLAG_FNS[metric])) + df = pd.concat(dfs, axis=0) + if df.empty: + continue + tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") + str_io = io.StringIO() + str_io.write("\n") + str_io.write(get_metric_title(metric) + " warnings\n") + str_io.write("~~~\n") + str_io.write(f"{tabform}\n") + str_io.write("~~~\n") + body += str_io.getvalue() + + comment = generate_dropdown_comment(title, body) + return comment + def prepare_message(self, suite): title = f"## {suite} suite with {self.dtypes[0]} precision ##" body = "" @@ -589,6 +791,7 @@ def prepare_message(self, suite): "accuracy", "compilation_latency", "compression_ratio", + "abs_latency", ]: df = self.untouched_parsed_frames[suite][metric] df = df.drop("dev", axis=1) @@ -596,14 +799,7 @@ def prepare_message(self, suite): tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") str_io = io.StringIO() str_io.write("\n") - if metric == "speedup": - str_io.write("Performance speedup\n") - elif metric == "accuracy": - str_io.write("Accuracy\n") - elif metric == "compilation_latency": - str_io.write("Compilation latency (sec)\n") - elif metric == "compression_ratio": - str_io.write("Peak Memory Compression Ratio\n") + str_io.write(get_metric_title(metric) + "\n") str_io.write("~~~\n") str_io.write(f"{tabform}\n") str_io.write("~~~\n") @@ -613,6 +809,13 @@ def prepare_message(self, suite): return comment def gen_summary_files(self): + self.generate_executive_summary() + for suite in self.suites: + self.plot_graph( + self.untouched_parsed_frames[suite]["speedup"], + f"{suite}_{self.dtypes[0]}", + ) + with open(f"{self.output_dir}/gh_title.txt", "w") as gh_fh: str_io = io.StringIO() str_io.write("\n") @@ -622,23 +825,27 @@ def gen_summary_files(self): with open(f"{self.output_dir}/gh_executive_summary.txt", "w") as gh_fh: gh_fh.write(self.executive_summary) - print(self.executive_summary) + + with open(f"{self.output_dir}/gh_warnings.txt", "w") as gh_fh: + warnings_body = self.generate_warnings() + gh_fh.write(warnings_body) str_io = io.StringIO() for suite in self.suites: str_io.write(self.prepare_message(suite)) str_io.write("\n") - print(str_io.getvalue()) with open(f"{self.output_dir}/gh_{self.mode}.txt", "w") as gh_fh: gh_fh.write(str_io.getvalue()) -def parse_logs(args, dtypes, suites, devices, compilers, output_dir): +def parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir): mode = get_mode(args) - build_summary() + build_summary(args) parser_class = ParsePerformanceLogs - parser = parser_class(suites, devices, dtypes, compilers, mode, output_dir) + parser = parser_class( + suites, devices, dtypes, compilers, flag_compilers, mode, output_dir + ) parser.gen_summary_files() return @@ -656,6 +863,204 @@ def get_date(log_info): return datetime.strptime(f"{log_info.day}", "%j").strftime("%m-%d") +def find_last_2_with_filenames(lookup_file, dashboard_archive_path, dtype, filenames): + df = pd.read_csv(lookup_file, names=("day", "mode", "prec", "path")) + df = df[df["mode"] == "performance"] + df = df[df["prec"] == dtype] + df = df[::-1] + last2 = [] + for path in df["path"]: + output_dir = os.path.join(dashboard_archive_path, path) + fullpaths = [ + os.path.join(dashboard_archive_path, path, name) for name in filenames + ] + if all([os.path.exists(fullpath) for fullpath in fullpaths]): + last2.append(output_dir) + if len(last2) >= 2: + return last2 + return None + + +class SummaryStatDiffer: + def __init__(self, args): + self.args = args + self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") + assert os.path.exists(self.lookup_file) + + def generate_diff(self, last2, filename, caption): + df_cur, df_prev = [pd.read_csv(os.path.join(path, filename)) for path in last2] + df_merge = df_cur.merge(df_prev, on="Compiler", suffixes=("_cur", "_prev")) + data = {col: [] for col in ("compiler", "suite", "prev_value", "cur_value")} + for _, row in df_merge.iterrows(): + if row["Compiler"] in self.args.flag_compilers: + for suite in self.args.suites: + if suite + "_prev" not in row or suite + "_cur" not in row: + continue + data["compiler"].append(row["Compiler"]) + data["suite"].append(suite) + data["prev_value"].append(row[suite + "_prev"]) + data["cur_value"].append(row[suite + "_cur"]) + + df = pd.DataFrame(data) + tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") + str_io = io.StringIO() + str_io.write("\n") + str_io.write(f"{caption}\n") + str_io.write("~~~\n") + str_io.write(f"{tabform}\n") + str_io.write("~~~\n") + return str_io.getvalue() + + def generate_comment(self): + title = "## Summary Statistics Diff ##\n" + body = ( + "For each relevant compiler, we compare the summary statistics " + "for the most 2 recent reports that actually run the compiler.\n\n" + ) + dtype = self.args.dtypes[0] + last2 = find_last_2_with_filenames( + self.lookup_file, + self.args.dashboard_archive_path, + dtype, + ["geomean.csv", "passrate.csv"], + ) + + if last2 is None: + body += "Could not find most 2 recent reports.\n\n" + else: + for state, path in zip(("Current", "Previous"), last2): + body += f"{state} report name: {path}\n\n" + body += self.generate_diff(last2, "passrate.csv", "Passrate diff") + body += self.generate_diff( + last2, "geomean.csv", "Geometric mean speedup diff" + ) + + comment = generate_dropdown_comment(title, body) + + with open(f"{self.args.output_dir}/gh_summary_diff.txt", "w") as gh_fh: + gh_fh.write(comment) + + +class RegressionDetector: + """ + Compares the most recent 2 benchmarks to find previously unflagged models + that are now flagged. + """ + + def __init__(self, args): + self.args = args + self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") + assert os.path.exists(self.lookup_file) + + def generate_comment(self): + title = "## Recent Regressions ##\n" + body = ( + "For each relevant compiler, we compare the most recent 2 reports " + "(that actually run the compiler) to find previously unflagged " + "models that are now flagged as problematic (according to the " + "'Warnings' section).\n\n" + ) + dtype = self.args.dtypes[0] + device = self.args.devices[0] + for suite in self.args.suites: + body += f"### Regressions for {suite} ###\n" + last2 = {} + + for compiler in self.args.flag_compilers: + filenames = [ + generate_csv_name( + self.args, dtype, suite, device, compiler, testing + ) + for testing in ["performance", "accuracy"] + ] + compiler_last2 = find_last_2_with_filenames( + self.lookup_file, self.args.dashboard_archive_path, dtype, filenames + ) + if compiler_last2 is not None: + last2[compiler] = [ + ParsePerformanceLogs( + [suite], + [device], + [dtype], + [compiler], + [compiler], + get_mode(self.args), + output_dir, + ) + for output_dir in compiler_last2 + ] + for state, path in zip(("Current", "Previous"), compiler_last2): + body += ( + f"{state} report name (compiler: {compiler}, " + f"suite: {suite}): {path}\n\n" + ) + + regressions_present = False + for metric in [ + "accuracy", + "speedup", + "compilation_latency", + "compression_ratio", + ]: + dfs = [] + for compiler in self.args.flag_compilers: + if last2[compiler] is None: + continue + + df_cur, df_prev = [ + last2[compiler][i].untouched_parsed_frames[suite][metric] + for i in (0, 1) + ] + df_merge = df_cur.merge( + df_prev, on="name", suffixes=("_cur", "_prev") + ) + flag_fn = FLAG_FNS[metric] + flag = np.logical_and( + df_merge[compiler + "_prev"].apply( + lambda x: not pd.isna(x) and not flag_fn(x) + ), + df_merge[compiler + "_cur"].apply( + lambda x: not pd.isna(x) and flag_fn(x) + ), + ) + df_bad = df_merge[flag] + dfs.append( + pd.DataFrame( + data={ + "compiler": compiler, + "name": df_bad["name"], + "prev_status": df_bad[compiler + "_prev"], + "cur_status": df_bad[compiler + "_cur"], + } + ) + ) + + if not dfs: + continue + df = pd.concat(dfs, axis=0) + if df.empty: + continue + regressions_present = True + tabform = tabulate( + df, headers="keys", tablefmt="pretty", showindex="never" + ) + str_io = io.StringIO() + str_io.write("\n") + str_io.write(f"{get_metric_title(metric)} regressions\n") + str_io.write("~~~\n") + str_io.write(f"{tabform}\n") + str_io.write("~~~\n") + body += str_io.getvalue() + + if not regressions_present: + body += "No regressions found.\n" + + comment = generate_dropdown_comment(title, body) + + with open(f"{self.args.output_dir}/gh_metric_regression.txt", "w") as gh_fh: + gh_fh.write(comment) + + class RegressionTracker: """ Plots progress of different metrics over time to detect regressions. @@ -687,13 +1092,14 @@ def find_last_k(self): def generate_comment(self): title = "## Metrics over time ##\n" str_io = io.StringIO() - for name in glob.glob(self.args.output_dir + "/*over_time.png"): - output = ( - subprocess.check_output([self.args.dashboard_image_uploader, name]) - .decode("ascii") - .rstrip() - ) - str_io.write(f"\n{name} : ![]({output})\n") + if not self.args.update_dashboard_test and not self.args.no_graphs: + for name in glob.glob(self.args.output_dir + "/*over_time.png"): + output = ( + subprocess.check_output([self.args.dashboard_image_uploader, name]) + .decode("ascii") + .rstrip() + ) + str_io.write(f"\n{name} : ![]({output})\n") comment = generate_dropdown_comment(title, str_io.getvalue()) with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh: @@ -702,7 +1108,7 @@ def generate_comment(self): def diff(self): log_infos = self.find_last_k() - for metric in ["geomean", "passrate"]: + for metric in ["geomean", "passrate", "comp_time", "memory"]: fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5)) for idx, suite in enumerate(self.suites): dfs = [] @@ -715,7 +1121,9 @@ def diff(self): if not os.path.exists(gmean_filename): continue df = pd.read_csv(gmean_filename) - if metric == "geomean": + if suite not in df: + continue + if metric == "geomean" or metric == "memory": df[suite] = df[suite].str.replace("x", "").astype(float) elif metric == "passrate": df[suite] = df[suite].str.split("%").str[0].astype(float) @@ -730,6 +1138,7 @@ def diff(self): dfs.append(df) df = pd.concat(dfs) + df = df.interpolate(method="linear") ax = df.plot( ax=axes[idx], kind="line", @@ -760,35 +1169,46 @@ def __init__(self, args): self.output_dir = args.output_dir self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") assert os.path.exists(self.lookup_file) - self.archive() - - def archive(self): - # Copy the folder to archived location - src = self.output_dir - day = datetime.today().strftime("%j") - prefix = datetime.today().strftime(f"day_{day}_%d_%m_%y") - target_dir = f"{prefix}_performance_{self.args.dtypes[0]}_{randint(100, 999)}" - target = os.path.join(self.args.dashboard_archive_path, target_dir) - shutil.copytree(src, target) + try: + if not self.args.update_dashboard_test and not self.args.no_update_archive: + self.update_lookup_file() + except subprocess.CalledProcessError: + sys.stderr.write("failed to update lookup file\n") - # Update lookup csv the folder to arhived logs + def update_lookup_file(self): dtype = self.args.dtypes[0] + day, _ = archive_data(self.args.archive_name) + target_dir = get_archive_name(self.args, dtype) + # Update lookup csv the folder to arhived logs subprocess.check_call( f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}', shell=True, ) + def archive(self): + dtype = self.args.dtypes[0] + # Copy the folder to archived location + archive( + self.output_dir, + self.args.dashboard_archive_path, + self.args.archive_name, + dtype, + ) + def upload_graphs(self): title = "## Performance graphs ##\n" str_io = io.StringIO() - for name in glob.glob(self.output_dir + "/*png"): - if "over_time" not in name: - output = ( - subprocess.check_output([self.args.dashboard_image_uploader, name]) - .decode("ascii") - .rstrip() - ) - str_io.write(f"\n{name} : ![]({output})\n") + if not self.args.update_dashboard_test and not self.args.no_graphs: + for name in glob.glob(self.output_dir + "/*png"): + if "over_time" not in name: + output = ( + subprocess.check_output( + [self.args.dashboard_image_uploader, name] + ) + .decode("ascii") + .rstrip() + ) + str_io.write(f"\n{name} : ![]({output})\n") comment = generate_dropdown_comment(title, str_io.getvalue()) with open(f"{self.output_dir}/gh_graphs.txt", "w") as gh_fh: @@ -798,14 +1218,21 @@ def gen_comment(self): files = [ "gh_title.txt", "gh_executive_summary.txt", + "gh_summary_diff.txt", + "gh_warnings.txt", "gh_regression.txt", + "gh_metric_regression.txt", "gh_training.txt", "gh_graphs.txt", + "gh_build_summary.txt", ] all_lines = [] for f in files: - with open(os.path.join(self.output_dir, f), "r") as fh: - all_lines.extend(fh.readlines()) + try: + with open(os.path.join(self.output_dir, f), "r") as fh: + all_lines.extend(fh.readlines()) + except FileNotFoundError: + pass return "\n".join([x.rstrip() for x in all_lines]) @@ -813,6 +1240,10 @@ def comment_on_gh(self, comment): """ Send a commment to dashboard """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + f.write(comment) + filename = f.name + subprocess.check_call( [ self.args.dashboard_gh_cli_path, @@ -820,21 +1251,32 @@ def comment_on_gh(self, comment): "comment", "--repo=https://github.com/pytorch/torchdynamo.git", "681", - "-b", - comment, + "-F", + filename, ] ) + os.remove(filename) + def update(self): self.upload_graphs() + SummaryStatDiffer(self.args).generate_comment() + RegressionDetector(self.args).generate_comment() try: RegressionTracker(self.args).diff() - except Exception: + except Exception as e: + logging.exception(e) with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh: gh_fh.write("") comment = self.gen_comment() - self.comment_on_gh(comment) + print(comment) + + if not self.args.update_dashboard_test: + if not self.args.no_gh_comment: + self.comment_on_gh(comment) + if not self.args.no_update_archive: + self.archive() if __name__ == "__main__": @@ -849,20 +1291,37 @@ def extract(key): if args.inference: compilers = DEFAULTS["inference"] if args.compilers is None else args.compilers + flag_compilers = ( + DEFAULTS["flag_compilers"]["inference"] + if args.flag_compilers is None + else args.flag_compilers + ) else: assert args.training compilers = DEFAULTS["training"] if args.compilers is None else args.compilers + flag_compilers = ( + DEFAULTS["flag_compilers"]["training"] + if args.flag_compilers is None + else args.flag_compilers + ) output_dir = args.output_dir args.compilers = compilers + args.devices = devices + args.dtypes = dtypes + flag_compilers = list(set(flag_compilers) & set(compilers)) + args.flag_compilers = flag_compilers args.suites = suites if args.print_run_commands: generate_commands(args, dtypes, suites, devices, compilers, output_dir) elif args.visualize_logs: - parse_logs(args, dtypes, suites, devices, compilers, output_dir) + parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir) elif args.run: generate_commands(args, dtypes, suites, devices, compilers, output_dir) + # generate memoized archive name now so that the date is reflective + # of when the run started + get_archive_name(args, dtypes[0]) # TODO - Do we need to worry about segfaults try: os.system("bash run.sh") @@ -872,7 +1331,23 @@ def extract(key): ) raise e if not args.log_operator_inputs: - parse_logs(args, dtypes, suites, devices, compilers, output_dir) + if not args.no_update_archive: + archive( + output_dir, + args.dashboard_archive_path, + args.archive_name, + dtypes[0], + ) + parse_logs( + args, dtypes, suites, devices, compilers, flag_compilers, output_dir + ) + if not args.no_update_archive: + archive( + output_dir, + args.dashboard_archive_path, + args.archive_name, + dtypes[0], + ) if args.update_dashboard: DashboardUpdater(args).update() diff --git a/benchmarks/dynamo/test.py b/benchmarks/dynamo/test.py new file mode 100644 index 0000000000000..438218462030f --- /dev/null +++ b/benchmarks/dynamo/test.py @@ -0,0 +1,44 @@ +import os +import unittest + +from .common import parse_args, run + +from .torchbench import setup_torchbench_cwd, TorchBenchmarkRunner + +try: + # fbcode only + from aiplatform.utils.sanitizer_status import is_asan_or_tsan +except ImportError: + + def is_asan_or_tsan(): + return False + + +class TestDynamoBenchmark(unittest.TestCase): + @unittest.skipIf(is_asan_or_tsan(), "ASAN/TSAN not supported") + def test_benchmark_infra_runs(self) -> None: + """ + Basic smoke test that TorchBench runs. + + This test is mainly meant to check that our setup in fbcode + doesn't break. + + If you see a failure here related to missing CPP headers, then + you likely need to update the resources list in: + //caffe2:inductor + """ + original_dir = setup_torchbench_cwd() + try: + args = parse_args( + [ + "-dcpu", + "--inductor", + "--performance", + "--only=BERT_pytorch", + "-n1", + "--batch_size=1", + ] + ) + run(TorchBenchmarkRunner(), args, original_dir) + finally: + os.chdir(original_dir) diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index f7ff2559cbb8a..98d67c501d633 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -40,48 +40,34 @@ def pip_install(package): # TODO - Figure out the reason of cold start memory spike + BATCH_SIZE_DIVISORS = { "beit_base_patch16_224": 2, - "cait_m36_384": 4, - "convit_base": 4, + "cait_m36_384": 2, + "convit_base": 2, "convmixer_768_32": 2, - "convnext_base": 4, - "crossvit_9_240": 2, + "convnext_base": 2, "cspdarknet53": 2, "deit_base_distilled_patch16_224": 2, - "dla102": 2, "dpn107": 2, - "eca_botnext26ts_256": 2, - "eca_halonext26ts": 2, - "gluon_senet154": 2, "gluon_xception65": 2, - "gmixer_24_224": 2, - "gmlp_s16_224": 2, - "hrnet_w18": 64, - "jx_nest_base": 4, - "mixer_b16_224": 2, - "mixnet_l": 2, - "mobilevit_s": 4, - "nfnet_l0": 2, + "mobilevit_s": 2, "pit_b_224": 2, "pnasnet5large": 2, "poolformer_m36": 2, "res2net101_26w_4s": 2, - "res2net50_14w_8s": 64, - "res2next50": 64, - "resnest101e": 4, + "resnest101e": 2, "sebotnet33ts_256": 2, "swin_base_patch4_window7_224": 2, "swsl_resnext101_32x16d": 2, - "tf_mixnet_l": 2, - "tnt_s_patch16_224": 2, - "twins_pcpvt_base": 4, + "twins_pcpvt_base": 2, "vit_base_patch16_224": 2, "volo_d1_224": 2, + "jx_nest_base": 4, "xcit_large_24_p8_224": 4, } -REQUIRE_HIGHER_TOLERANCE = set() +REQUIRE_HIGHER_TOLERANCE = set("botnet26t_256") SKIP = { # Unusual training setup @@ -89,6 +75,11 @@ def pip_install(package): } +MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = { + "cait_m36_384": 4, +} + + def refresh_model_names(): import glob @@ -230,11 +221,17 @@ def load_model( ) input_size = data_config["input_size"] recorded_batch_size = TIMM_MODELS[model_name] - recorded_batch_size = max( - int(recorded_batch_size / BATCH_SIZE_DIVISORS.get(model_name, 1)), 1 - ) + + if model_name in BATCH_SIZE_DIVISORS: + recorded_batch_size = max( + int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1 + ) batch_size = batch_size or recorded_batch_size + # Control the memory footprint for few models + if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK: + batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name]) + # example_inputs = torch.randn( # (batch_size,) + input_size, device=device, dtype=data_dtype # ) @@ -315,7 +312,7 @@ def forward_pass(self, mod, inputs, collect_outputs=True): def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): cloned_inputs = clone_inputs(inputs) - self.optimizer_zero_grad() + self.optimizer_zero_grad(mod) with self.autocast(): pred = mod(*cloned_inputs) if isinstance(pred, tuple): diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index c37422a19bfd9..d138e3e692462 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -20,31 +20,37 @@ # We are primarily interested in tf32 datatype torch.backends.cuda.matmul.allow_tf32 = True -original_dir = abspath(os.getcwd()) - -os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam -for torchbench_dir in ( - "./torchbenchmark", - "../torchbenchmark", - "../torchbench", - "../benchmark", - "../../torchbenchmark", - "../../torchbench", - "../../benchmark", -): + + +def setup_torchbench_cwd(): + original_dir = abspath(os.getcwd()) + + os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam + for torchbench_dir in ( + "./torchbenchmark", + "../torchbenchmark", + "../torchbench", + "../benchmark", + "../../torchbenchmark", + "../../torchbench", + "../../benchmark", + ): + if exists(torchbench_dir): + break + if exists(torchbench_dir): - break + torchbench_dir = abspath(torchbench_dir) + os.chdir(torchbench_dir) + sys.path.append(torchbench_dir) -if exists(torchbench_dir): - torchbench_dir = abspath(torchbench_dir) - os.chdir(torchbench_dir) - sys.path.append(torchbench_dir) + return original_dir # Some models have large dataset that doesn't fit in memory. Lower the batch # size to test the accuracy. USE_SMALL_BATCH_SIZE = { "demucs": 4, + "dlrm": 1024, "densenet121": 4, "hf_Reformer": 4, "timm_efficientdet": 1, @@ -113,8 +119,7 @@ } REQUIRE_COSINE_TOLERACE = { - # https://github.com/pytorch/torchdynamo/issues/556 - "resnet50_quantized_qat", + # Just keeping it here even though its empty, if we need this in future. } # non-deterministic output / cant check correctness @@ -177,6 +182,12 @@ } +MAX_BATCH_SIZE_FOR_ACCURACY_CHECK = { + "hf_GPT2": 2, + "pytorch_unet": 2, +} + + class TorchBenchmarkRunner(BenchmarkRunner): def __init__(self): super(TorchBenchmarkRunner, self).__init__() @@ -212,7 +223,7 @@ def failing_dynamic_shape_models(self): @property def skip_accuracy_checks_large_models_dashboard(self): - if self.args.dashboard: + if self.args.dashboard or self.args.accuracy: return SKIP_ACCURACY_CHECK_MODELS return set() @@ -221,12 +232,16 @@ def load_model( device, model_name, batch_size=None, + part=None, ): is_training = self.args.training use_eval_mode = self.args.use_eval_mode dynamic_shapes = self.args.dynamic_shapes - module = importlib.import_module(f"torchbenchmark.models.{model_name}") + try: + module = importlib.import_module(f"torchbenchmark.models.{model_name}") + except ModuleNotFoundError: + module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}") benchmark_cls = getattr(module, "Model", None) if not hasattr(benchmark_cls, "name"): benchmark_cls.name = model_name @@ -240,15 +255,30 @@ def load_model( if batch_size is None and is_training and model_name in USE_SMALL_BATCH_SIZE: batch_size = USE_SMALL_BATCH_SIZE[model_name] + # Control the memory footprint for few models + if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK: + batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name]) + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" torch.backends.__allow_nonbracketed_mutation_flag = True + extra_args = [] + if part: + extra_args = ["--part", part] if is_training: benchmark = benchmark_cls( - test="train", device=device, jit=False, batch_size=batch_size + test="train", + device=device, + jit=False, + batch_size=batch_size, + extra_args=extra_args, ) else: benchmark = benchmark_cls( - test="eval", device=device, jit=False, batch_size=batch_size + test="eval", + device=device, + jit=False, + batch_size=batch_size, + extra_args=extra_args, ) if dynamic_shapes: if not hasattr(benchmark, "get_dynamic_shapes_module"): @@ -306,9 +336,10 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): tolerance = 1e-4 cosine = self.args.cosine # Increase the tolerance for torch allclose - if self.args.float16: + if self.args.float16 or self.args.amp: return 1e-3, cosine if is_training and current_device == "cuda": + tolerance = 1e-3 if name in REQUIRE_COSINE_TOLERACE: cosine = True elif name in REQUIRE_HIGHER_TOLERANCE: @@ -325,7 +356,7 @@ def forward_pass(self, mod, inputs, collect_outputs=True): def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): cloned_inputs = clone_inputs(inputs) - self.optimizer_zero_grad() + self.optimizer_zero_grad(mod) with self.autocast(): pred = mod(*cloned_inputs) loss = self.compute_loss(pred) @@ -338,6 +369,7 @@ def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): if __name__ == "__main__": + original_dir = setup_torchbench_cwd() logging.basicConfig(level=logging.WARNING) warnings.filterwarnings("ignore") main(TorchBenchmarkRunner(), original_dir) diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index 8512028fbad0d..1e568d1d01f04 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -330,8 +330,9 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): super(TransformerModel, self).__init__() try: from torch.nn import TransformerEncoder, TransformerEncoderLayer - except Exception: - raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.') + except Exception as e: + raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or ' + 'lower.') from e self.model_type = 'Transformer' self.src_mask = None self.pos_encoder = PositionalEncoding(ninp, dropout) diff --git a/benchmarks/instruction_counts/README.md b/benchmarks/instruction_counts/README.md index ed2633caba151..32071e8aa80e0 100644 --- a/benchmarks/instruction_counts/README.md +++ b/benchmarks/instruction_counts/README.md @@ -73,7 +73,7 @@ Timer( ``` Moreover, because `signature` is provided we know that creation of `x` and `w` -is part of setup, and the overall comptation uses `x` and `w` to produce `y`. +is part of setup, and the overall computation uses `x` and `w` to produce `y`. As a result, we can derive TorchScript'd and AutoGrad variants as well. We can deduce that a TorchScript model will take the form: diff --git a/benchmarks/nested/nested_bmm_bench.py b/benchmarks/nested/nested_bmm_bench.py index 311b23395efdb..56e283effddf9 100644 --- a/benchmarks/nested/nested_bmm_bench.py +++ b/benchmarks/nested/nested_bmm_bench.py @@ -1,4 +1,5 @@ import argparse +import random import torch @@ -15,31 +16,38 @@ def bench(nt_a, nt_b, niter): nt_c = nt_a.bmm(nt_b) end_event.record() torch.cuda.synchronize() - runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter + runtime = (start_event.elapsed_time(end_event)) / niter return runtime -def sweep_n(ntensor, niter, dtype): - print("n, dtype, ntensor, gflop, runtime, tflop/s") - for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: - nt_a = torch.nested_tensor( - [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)] +def sweep_n(niter, dtype): + for ntensor in [4, 8, 16, 32, 64, 128, 256]: + tensors = [torch.randn(256, random.randint(100, 200)) for t in range(ntensor)] + nt_a = torch.nested.nested_tensor( + tensors, + dtype=dtype, + device="cuda", ) - nt_b = torch.nested_tensor( - [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)] + nt_b = torch.nested.nested_tensor( + [t.t() for t in tensors], + dtype=dtype, + device="cuda", ) runtime = bench(nt_a, nt_b, niter) - tflop = n * n * n * ntensor * 2 / 1e12 - print(n, dtype, ntensor, tflop, runtime, tflop / runtime) + nt_a_size = torch.ops.aten._nested_tensor_size(nt_a) + lengths = nt_a_size[:, 1] + print(",".join(map(str, [ntensor, dtype, lengths.min().item(), + lengths.float().mean().item(), lengths.max().item(), runtime]))) + if __name__ == "__main__": + random.seed(123) parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark") parser.add_argument("--niter", default="10", type=int) - parser.add_argument("--ntensor", default="20", type=int) args = parser.parse_args() niter = args.niter - ntensor = args.ntensor - sweep_n(ntensor, niter, torch.float32) - sweep_n(ntensor, niter, torch.float16) + print("ntensor,dtype,min_length,mean_length,max_length,runtime") + sweep_n(niter, torch.float32) + sweep_n(niter, torch.float16) diff --git a/benchmarks/operator_benchmark/README.md b/benchmarks/operator_benchmark/README.md index 59918f6fab3ca..cff275d9a1f97 100644 --- a/benchmarks/operator_benchmark/README.md +++ b/benchmarks/operator_benchmark/README.md @@ -374,7 +374,7 @@ unary_ops_list = op_bench.op_list( ``` #### Part 2. Create Tensors and Add Computation -In this example, both operators share the same input so we only need to implement one TorchBenchmakrBase subclass. +In this example, both operators share the same input so we only need to implement one TorchBenchmarkBase subclass. Every new subclass is required to implement 3 methods: * `init` is used to create tensors and set the operator name and function. In this example, the parameters to `init` are `M`, `N`, and `op_func` which have been specified in the configurations. * `forward` includes the operator to be tested and the computation based on the created tensors in `init`. Apart from `self`, the order of the arguments must match the entries specified in `self.inputs`. diff --git a/benchmarks/static_runtime/test_generated_ops.cc b/benchmarks/static_runtime/test_generated_ops.cc index 13be31e29a38a..415bf464fbd13 100644 --- a/benchmarks/static_runtime/test_generated_ops.cc +++ b/benchmarks/static_runtime/test_generated_ops.cc @@ -5584,138 +5584,6 @@ TEST(StaticRuntime, autogen_multilabel_margin_loss) { /*check_resize=*/false); } -TEST(StaticRuntime, autogen_nll_loss) { - const std::string script = R"IR( - graph(%self: Tensor, %target: Tensor, %weight: Tensor?, %reduction: int, %ignore_index: int): - %bias: None = prim::Constant() - %ret = aten::nll_loss(%self, %target, %weight, %reduction, %ignore_index) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6}); - auto target0 = at::randint(6, {6}, torch::kInt64); - auto weight0 = at::rand({6}); - auto reduction0 = 1; - auto ignore_index0 = 1; - std::vector args{self0, target0, weight0, reduction0, ignore_index0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/false); - - auto self1 = at::rand({22, 22}); - auto target1 = at::randint(22, {22}, torch::kInt64); - auto weight1 = at::rand({22}); - auto reduction1 = 1; - auto ignore_index1 = 1; - std::vector args2{self1, target1, weight1, reduction1, ignore_index1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/false); -} - -TEST(StaticRuntime, autogen_nll_loss_backward) { - const std::string script = R"IR( - graph(%grad_output: Tensor, %self: Tensor, %target: Tensor, %weight: Tensor?, %reduction: int, %ignore_index: int, %total_weight: Tensor): - %bias: None = prim::Constant() - %ret = aten::nll_loss_backward(%grad_output, %self, %target, %weight, %reduction, %ignore_index, %total_weight) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto grad_output0 = at::rand({}); - auto self0 = at::rand({6}); - auto target0 = at::randint(0, 5, {6}, torch::kInt64); - auto weight0 = at::rand({6}); - auto reduction0 = 1; - auto ignore_index0 = 1; - auto total_weight0 = at::rand({}); - std::vector args{ - grad_output0, - self0, - target0, - weight0, - reduction0, - ignore_index0, - total_weight0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); - - auto grad_output1 = at::rand({}); - auto self1 = at::rand({36}); - auto target1 = at::randint(0, 11, {36}, torch::kInt64); - auto weight1 = at::rand({36}); - auto reduction1 = 1; - auto ignore_index1 = 1; - auto total_weight1 = at::rand({}); - std::vector args2{ - grad_output1, - self1, - target1, - weight1, - reduction1, - ignore_index1, - total_weight1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/true); -} - -TEST(StaticRuntime, autogen_nll_loss2d) { - const std::string script = R"IR( - graph(%self: Tensor, %target: Tensor, %weight: Tensor?, %reduction: int, %ignore_index: int): - %bias: None = prim::Constant() - %ret = aten::nll_loss2d(%self, %target, %weight, %reduction, %ignore_index) - %cloned = aten::clone(%ret, %bias) - return (%cloned) - )IR"; - - auto self0 = at::rand({6, 6, 6, 6}); - auto target0 = at::randint(6, {6, 6, 6}, torch::kInt64); - auto weight0 = at::rand({6}); - auto reduction0 = 1; - auto ignore_index0 = 1; - std::vector args{self0, target0, weight0, reduction0, ignore_index0}; - testStaticRuntime( - script, - args, - {}, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/false); - - auto self1 = at::rand({22, 22, 22, 22}); - auto target1 = at::randint(22, {22, 22, 22}, torch::kInt64); - auto weight1 = at::rand({22}); - auto reduction1 = 1; - auto ignore_index1 = 1; - std::vector args2{self1, target1, weight1, reduction1, ignore_index1}; - testStaticRuntime( - script, - args, - args2, - /*use_allclose=*/false, - /*use_equalnan=*/false, - /*check_resize=*/false); -} - TEST(StaticRuntime, autogen_soft_margin_loss) { const std::string script = R"IR( graph(%self: Tensor, %target: Tensor, %reduction: int): @@ -7973,7 +7841,6 @@ TEST(StaticRuntime, autogen_diagonal) { auto offset0 = 0; auto dim10 = 2; auto dim20 = 1; - auto dim00 = 1; std::vector args{self0, offset0, dim10, dim20}; testStaticRuntime(script, args); } @@ -7991,7 +7858,6 @@ TEST(StaticRuntime, autogen_linalg_diagonal) { auto offset0 = 0; auto dim10 = 2; auto dim20 = 1; - auto dim00 = 1; std::vector args{A0, offset0, dim10, dim20}; testStaticRuntime(script, args); } diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc index 70d1d1d306939..3c927c9c41d9d 100644 --- a/benchmarks/static_runtime/test_static_module.cc +++ b/benchmarks/static_runtime/test_static_module.cc @@ -77,13 +77,6 @@ const auto sigmoid_inplace_script = R"JIT( return (a) )JIT"; -const auto sigmoid_out_script = R"JIT( - def forward(self, inp: Tensor): - a = inp + inp - b = torch.sigmoid(inp, out=a).clone() - return (b) -)JIT"; - } // namespace // Test that StaticModule::value_group groups values of the graph into @@ -354,6 +347,18 @@ TEST(StaticRuntime, CanEnableStaticRuntime) { EXPECT_TRUE(testCanEnableStaticRuntime(is_not_script_none)); } +TEST(StaticRuntime, CanEnableStaticRuntimeSubBlocks) { + const auto src = R"JIT( + def forward(self, a: Tensor, b: Tensor, cond: bool): + if cond: + # aten::__is__ on tensors is blocked + return a is b + return False + )JIT"; + + EXPECT_FALSE(testCanEnableStaticRuntime(src)); +} + TEST(StaticRuntime, NestedOutput) { // dict of tuple of list const auto nested_output_script_0 = R"JIT( diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index cf102224fc087..ef3bc75f921b2 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -2164,12 +2164,7 @@ TEST(StaticRuntime, Permute) { c10::List dims_b{0, 2, 1}; std::vector args_b{b, dims_b}; - auto c = at::randn({3, 3, 3}); - c10::List dims_c{0, -1, 1}; - std::vector args_c{c, dims_c}; - testStaticRuntime(permute_script, args_a); - testStaticRuntime(permute_script, args_c); testStaticRuntime(permute_script, args_a, args_b); permute_script = R"JIT( @@ -2560,8 +2555,8 @@ TEST(StaticRuntime, Tensor_Split) { std::vector args2{at::randn({8}), torch::tensor(3), 0}; const auto tensor_split_str3 = R"JIT( - def forward(self, a: Tensor, indicies: List[int], dim: int): - return torch.tensor_split(a, indicies, dim) + def forward(self, a: Tensor, indices: List[int], dim: int): + return torch.tensor_split(a, indices, dim) )JIT"; std::vector args3{at::randn({8}), c10::List({1, 6}), 0}; @@ -3194,9 +3189,14 @@ TEST(StaticRuntime, ReplaceWithMaybeCopy) { smodule.runtime().check_for_memory_leak(); EXPECT_TRUE(expected.equal(actual)); - EXPECT_FALSE(hasProcessedNodeWithName(smodule, "aten::to")); + + // Make a fresh graph to ensure the pass works in isolation + auto new_graph = std::make_shared(); + torch::jit::parseIR(to, new_graph.get()); + ReplaceWithMaybeCopy(new_graph); + EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to")); EXPECT_TRUE( - hasProcessedNodeWithName(smodule, "static_runtime::to_maybe_copy_out")); + hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out")); } TEST(StaticRuntime, Int) { @@ -3676,41 +3676,6 @@ TEST(StaticRuntime, ClampNaNToNum) { testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true); } -TEST(StaticRuntime, PrepackWeights) { - const std::string src = R"IR( - graph(%input: Tensor, %weight: Tensor, %bias: Tensor?, %scale: Tensor, %zero_point: Tensor): - %none: NoneType = prim::Constant() - %result: Tensor = fb::quantized_linear_unpacked_weight_v2(%input, %weight, %bias, %scale, %zero_point) - %dequantized: Tensor = aten::dequantize(%result) - return (%dequantized) - )IR"; - - auto graph = getGraphFromIR(src); - PrepackWeights(graph); - ASSERT_TRUE(graphHasOp(graph, "quantized::linear")); - ASSERT_TRUE(graphHasOp(graph, "quantized::linear_prepack")); - ASSERT_FALSE(graphHasOp(graph, "fb::quantized_linear_unpacked_weight_v2")); - - auto scale = at::tensor({2}, at::kFloat); - auto zero_point = at::tensor({3}, at::kLong); - - auto weight = - at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8); - auto input = - at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8); - auto args1 = std::vector{input, weight, c10::nullopt, scale, zero_point}; - - auto weight_2 = - at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8); - auto input_2 = - at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8); - auto bias_2 = torch::randn({3}, torch::kFloat); - auto args2 = std::vector{input, weight, bias_2, scale, zero_point}; - - testStaticRuntime(src, args1); - testStaticRuntime(src, args2); -} - TEST(StaticRuntime, IfReturningTuple) { const auto src = R"JIT( def forward(self, x, y, cond: bool, idx: int): diff --git a/benchmarks/static_runtime/test_utils.cc b/benchmarks/static_runtime/test_utils.cc index 7e0733fbc8af4..cc88801139334 100644 --- a/benchmarks/static_runtime/test_utils.cc +++ b/benchmarks/static_runtime/test_utils.cc @@ -124,7 +124,7 @@ void compareTensorLists( const bool use_allclose, const bool use_equalnan) { EXPECT_TRUE(l.size() == r.size()); - for (int i = 0; i < l.size(); ++i) { + for (auto i : c10::irange(l.size())) { ASSERT_TRUE(l[i].isTensor()); ASSERT_TRUE(r[i].isTensor()); VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl; @@ -172,7 +172,7 @@ void compareResults( EXPECT_TRUE(actual.isTuple()); auto lhs = expect.toTupleRef().elements(); auto rhs = actual.toTupleRef().elements(); - EXPECT_TRUE(lhs.size() == rhs.size()); + ASSERT_TRUE(lhs.size() == rhs.size()); for (size_t i = 0; i < lhs.size(); i++) { compareResults(lhs[i], rhs[i]); } @@ -180,7 +180,7 @@ void compareResults( EXPECT_TRUE(actual.isList()); auto lhs = expect.toList(); auto rhs = actual.toList(); - EXPECT_TRUE(lhs.size() == rhs.size()); + ASSERT_TRUE(lhs.size() == rhs.size()); for (size_t i = 0; i < lhs.size(); i++) { compareResults(lhs[i], rhs[i]); } @@ -191,7 +191,7 @@ void compareResults( EXPECT_TRUE(lhs.size() == rhs.size()); for (auto& lh : lhs) { auto f = rhs.find(lh.key()); - EXPECT_FALSE(f == rhs.end()); + ASSERT_FALSE(f == rhs.end()); compareResults(lh.value(), f->value()); } } else { @@ -298,11 +298,12 @@ void testStaticRuntime( // 1st run: collect allocation profiles (args) // 2nd run: exercise memory planner and resizing with args2 // 3rd run: run with args again - StaticModuleOptions opts{ - .enable_out_variant = enable_out_variant, - .optimize_memory = enable_out_variant, - .manage_output_tensors = manage_output_tensors, - .enable_tensorexpr_fusion = enable_tensorexpr_fusion}; + StaticModuleOptions opts; + opts.enable_out_variant = enable_out_variant; + opts.optimize_memory = enable_out_variant; + opts.manage_output_tensors = manage_output_tensors; + opts.enable_tensorexpr_fusion = enable_tensorexpr_fusion; + auto smodule = test_context->makeStaticModule(opts); StaticRuntime runtime(smodule); auto actual = runtime(args, {}); diff --git a/benchmarks/transformer/better_transformer_vs_mha_functional.py b/benchmarks/transformer/better_transformer_vs_mha_functional.py new file mode 100644 index 0000000000000..b76077ba4c22e --- /dev/null +++ b/benchmarks/transformer/better_transformer_vs_mha_functional.py @@ -0,0 +1,195 @@ +""" +Tests the performance of torch.nn.MultiheadAttention's fast path (BetterTransformer) +vs the slow path (torch.nn.functional.multi_head_attention) + +To run this script install these dependencies: + +pip install tqdm +pip install prettytable +""" + +import torch +import random +import numpy as np +from pprint import pprint +import itertools +import json +import argparse +from pathlib import Path +from typing import Optional + +from prettytable import PrettyTable +from collections import defaultdict, OrderedDict +from tqdm import tqdm + + +import warnings + +warnings.filterwarnings("ignore") + +error_dict = defaultdict(int) + + +def benchmark_torch_function(iters, f, *args, **kwargs): + f(*args, **kwargs) + f(*args, **kwargs) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(iters): + f(*args, **kwargs) + end_event.record() + torch.cuda.synchronize() + # elapsed_time has a resolution of 0.5 microseconds: + # but returns milliseconds, so we need to multiply it to increase resolution + return start_event.elapsed_time(end_event) * 1000 / iters, *f(*args, **kwargs) + + +def run(a: int, b: int, iters: int, batch_size: int, sequence_length: int, + embed_dim: int, num_heads: int, device: str, dtype: str, block_size: int, seed): + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + + from scipy.stats import beta + lengths = beta.rvs(a, b, size=batch_size) * (sequence_length + block_size - 1) // block_size + lengths = list(map(int, list(lengths))) + lengths = [l * block_size for l in lengths] + lengths = [max(l, block_size) for l in lengths] + + # Used to enforce no padding + # lengths = [sequence_length] * batch_size + + # Ensure one row in the batch of ele has the max_sequence_length + lengths[random.randint(0, batch_size - 1)] = sequence_length + + q = [torch.randn(l, embed_dim, device=device, dtype=dtype) + for l in lengths] + q = torch.nested.nested_tensor(q, device=device, dtype=dtype) + k, v = q, q + + qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + native_mha = torch.nn.MultiheadAttention( + embed_dim, num_heads, batch_first=True, device=device, dtype=dtype + ).eval() + native_mha.in_proj_weight = qkv.weight + native_mha.in_proj_bias = qkv.bias + native_mha.out_proj.weight = proj.weight + native_mha.out_proj.bias = proj.bias + + # Create query mask + q_mask = torch.nested.to_padded_tensor( + torch.nested.nested_tensor([ + torch.tensor([True] * length, dtype=torch.bool) + for length in lengths + ]), 0) + q_mask = q_mask.cuda() + + if q_mask.size(1) == 0: + return None + + # Benchmark the native MHA in core + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True): + with torch.inference_mode(): + time_native_mha_fast, y_native_mha_fast, _ = benchmark_torch_function( + iters, native_mha, q, k, v, need_weights=False) + q = q.to_padded_tensor(0) + k = q + v = q + # Internal Flash Attention + time_native_mha_slow, y_native_mha_slow, _ = benchmark_torch_function( + iters, native_mha, q, k, v, key_padding_mask=~q_mask, need_weights=False) + + # Convert to padded for comparison + if y_native_mha_fast.is_nested: + y_native_mha_fast = torch.nested.to_padded_tensor(y_native_mha_fast, 0) + y_native_mha_fast = y_native_mha_fast * q_mask.unsqueeze(-1) + + if y_native_mha_slow.is_nested: + y_native_mha_slow = torch.nested.to_padded_tensor(y_native_mha_slow, 0) + y_native_mha_slow = y_native_mha_slow * q_mask.unsqueeze(-1) + + # Correctness check + entry_name = f"batch:{batch_size}_seq_len:{sequence_length}_n_heads:{num_heads}_embed_dim:{embed_dim}" + try: + torch.testing.assert_close(y_native_mha_fast, y_native_mha_slow, atol=1e-3, rtol=1e-3) + except AssertionError as e: + error_dict[entry_name] += 1 + pprint(error_dict) + + # Calculate amount of padding + padding = 1 - q_mask.float().mean().item() + + # Calculate the speedup for flash attention + speedup_fast_internal = time_native_mha_slow / time_native_mha_fast + + result_entry = OrderedDict() + result_entry['dtype'] = dtype + result_entry["batch_size"] = batch_size + result_entry["sequence_length"] = sequence_length + result_entry["n_heads"] = num_heads + result_entry["embed_dim"] = embed_dim + result_entry["time_native_mha_slow(μs)"] = f"{time_native_mha_slow:.3f}" + result_entry["time_native_mha_fast (μs)"] = f"{time_native_mha_fast:.3f}" + result_entry["speedup flash_mha v native_mha"] = f"{speedup_fast_internal:.3f}" + result_entry["padding"] = f"{padding:.3f}" + return result_entry + + +def main(save_path: Optional[Path], error_path: Optional[Path]): + table = PrettyTable() + entries = defaultdict(list) + + print("CUDA device: ", torch.cuda.get_device_name(0)) + iters = 100 + header = None + batch_sizes = [16, 32, 64, 128, 256] + sequence_lengths = [64, 128, 256, 512] + embed_dims = [512, 1024] + num_heads_list = [8, 16] + betas = range(1, 64, 4) + + for (batch_size, sequence_length, embed_dim, num_heads, block_size, b) in tqdm( + list(itertools.product(batch_sizes, sequence_lengths, embed_dims, num_heads_list, [2], betas))): + seed = 26214 # Magic number that works well for higher b values + entry = run(1, b * 0.05, iters, batch_size, sequence_length, + embed_dim, num_heads, "cuda", torch.float16, block_size, seed) + if entry is None: + continue + if header is None: + table.field_names = list(entry.keys()) + header = list(entry.keys()) + row = [] + for k, v in entry.items(): + row.append(v) + entries[k].append(v) + table.add_row(row) + + # Print the full table to console + print(table) + pprint(error_dict) + + csv_string = table.get_csv_string() + if save_path is not None: + with open(save_path, 'w') as csvfile: + csvfile.write(csv_string) + + print(f"Total errors: {sum(error_dict.values())}") + if error_path is not None: + with open(error_path, 'w') as file: + file.write(json.dumps(error_dict)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--save_path", type=str, help="Path to save the results") + parser.add_argument("--error_save_path", type=str, help="Path to save the errors") + + args = parser.parse_args() + save_path = Path(args.save_path) if args.save_path else None + error_path = Path(args.error_save_path) if args.error_save_path else None + + main(save_path, error_path) diff --git a/benchmarks/transformer/sdp.py b/benchmarks/transformer/sdp.py index 50db76e9f8c21..fbd123fc39b31 100644 --- a/benchmarks/transformer/sdp.py +++ b/benchmarks/transformer/sdp.py @@ -7,6 +7,8 @@ import warnings warnings.filterwarnings("ignore") + + class CompositeMHA(torch.nn.Module): def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj): super().__init__() @@ -90,8 +92,8 @@ def benchmark_torch_function(iters, f, *args, **kwargs): return (start_event.elapsed_time(end_event) * 1.0e-3) / iters -def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer): - with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True): +def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, enable_math, enable_flash, writer): + with torch.backends.cuda.sdp_kernel(enable_math=enable_math, enable_flash=enable_flash): with torch.inference_mode(): dropout_p = 0.0 mask = None @@ -122,6 +124,8 @@ def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, results["cp_time"] = cp_time results["speedup"] = pt_time / cp_time results["dtype"] = str(x.dtype) + results["enable_math"] = str(enable_math) + results["enable_flash"] = str(enable_flash) writer.writerow(results) @@ -131,15 +135,22 @@ def main(): np.random.seed(seed) torch.manual_seed(seed) - headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"] + headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", + "cp_time", "speedup", "dtype", "enable_math", "enable_flash"] writer = csv.DictWriter(sys.stdout, headers) writer.writeheader() batch_size = 64 pad_percentage = 0.5 - for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]): - run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer) + for (enable_math, enable_flash) in [(False, True), (True, False), (True, True)]: + for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]): + run_timing(iters, batch_size, 1024, num_heads, max_seq_len, + pad_percentage, enable_math, enable_flash, writer) + run_timing(iters, batch_size, 1024, num_heads, max_seq_len, + pad_percentage, enable_math, enable_flash, writer) + run_timing(iters, batch_size, 1024, num_heads, max_seq_len, + pad_percentage, enable_math, enable_flash, writer) if __name__ == "__main__": diff --git a/benchmarks/transformer/sdp_backwards.py b/benchmarks/transformer/sdp_backwards.py new file mode 100644 index 0000000000000..2f745e157b280 --- /dev/null +++ b/benchmarks/transformer/sdp_backwards.py @@ -0,0 +1,189 @@ +import torch +import numpy as np +import random +import torch.utils.benchmark as benchmark +from torch.profiler import profile, record_function, ProfilerActivity + + +class CompositeMHA(torch.nn.Module): + def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj): + super().__init__() + self.in_proj_weight = in_proj_weight + self.in_proj_bias = in_proj_bias + self.out_proj = out_proj + self.num_heads = num_heads + + def forward(self, query, key, value, mask): + if not (query is key and key is value): + raise NotImplementedError( + "query, key and value must be the same Tensor for now." + ) + if mask is not None: + raise NotImplementedError("mask is currently not supported.") + + query_projected = torch.nn.functional.linear( + query, self.in_proj_weight, self.in_proj_bias + ) + + batch_size = query_projected.size(0) + embed_dim = query_projected.size(2) + head_dim = embed_dim // (self.num_heads * 3) + + query, key, value = query_projected.chunk(3, -1) + + query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + attn, _ = torch.nn.functional._scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + need_attn_weights=False, + is_causal=False, + ) + + attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim) + # Match return signature of nn.MHA + return self.out_proj(attn) + + +def build_composite_mha_from_nn_mha(pt): + assert pt._qkv_same_embed_dim + in_proj_weight = pt.in_proj_weight + assert in_proj_weight is not None + assert pt.batch_first + return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj) + + +def forw_back(model, input, upward): + output = model(*input) + output.backward(upward) + + +# Context manger not working in timer + + +def forw_back_fused(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + output = model(*input) + output.backward(upward) + + +def forw_back_eager(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + output = model(*input) + output.backward(upward) + + +def run_timing( + min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype +): + dropout_p = 0.0 + mask = None + + pt = torch.nn.MultiheadAttention( + embed_dim=embed_dimension, + num_heads=num_heads, + batch_first=True, + dropout=dropout_p, + ) + npt = pt.cuda().to(dtype) + cpt = build_composite_mha_from_nn_mha(npt) + x = torch.randn( + batch_size, + max_sequence_len, + embed_dimension, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + rand_fused_upward = cpt(x, x, x, mask).clone().detach() + + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + rand_eager_upward = cpt(x, x, x, mask).clone().detach() + + t0 = benchmark.Timer( + stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)", + globals={ + "forw_back_fused": forw_back_fused, + "cpt": cpt, + "x": x, + "rand_fused_upward": rand_fused_upward, + "mask": mask, + }, + label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + t1 = benchmark.Timer( + stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)", + globals={ + "forw_back_eager": forw_back_eager, + "cpt": cpt, + "x": x, + "rand_eager_upward": rand_eager_upward, + "mask": mask, + }, + label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + m0 = t0.blocked_autorange(min_run_time=min_run_time) + m1 = t1.blocked_autorange(min_run_time=min_run_time) + + print(m0) + print(m1) + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + print("Profile for Fused".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_fused_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + print("Profile for eager".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_eager_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + +def main(): + seed = 123 + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + min_run_time = 10 + batch_size = 64 + num_heads = 32 + max_seq_len = 256 + embed_dim = 1024 + dtype = torch.bfloat16 + + print( + f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} " + f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}" + ) + run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype) + + +if __name__ == "__main__": + main() diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc index 991bca7e55871..005b19ce888a4 100644 --- a/binaries/optimize_for_mobile.cc +++ b/binaries/optimize_for_mobile.cc @@ -16,13 +16,13 @@ #include #include -#include "torch/script.h" -#include "torch/csrc/jit/api/module.h" +#include +#include #include -#include "torch/csrc/jit/passes/vulkan_rewrite.h" -#include "torch/csrc/jit/passes/xnnpack_rewrite.h" -#include "torch/csrc/jit/serialization/import.h" -#include "torch/csrc/jit/serialization/export.h" +#include +#include +#include +#include C10_DEFINE_string(model, "", "The torch script model to optimize."); C10_DEFINE_string( @@ -86,7 +86,8 @@ int main(int argc, char** argv) { if (FLAGS_backend == "" || FLAGS_backend == "cpu") { optimized_module = torch::jit::optimizeForMobile(module); } else if (FLAGS_backend == "vulkan") { - optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods); + optimized_module = torch::jit::vulkanOptimizeForMobile( + module, std::set(), preserved_methods); } else if (FLAGS_backend == "metal"){ optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods); }else{ diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index ea523898b51e6..0fadfad5b9f28 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -180,6 +180,10 @@ class vkRunner final : public Runner { virtual c10::IValue run( T& module, const std::vector& inputs) override { + if (!module.attr("requires_backend_transfers", at::IValue(true)).toBool()) { + // No need to transfer input/output backends + return module.forward(inputs); + } if (inputs_.size() == 0) { // Upload the input tensor(s) to GPU memory. diff --git a/buckbuild.bzl b/buckbuild.bzl index d0185aa313a47..f84f21cd4d111 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -17,11 +17,12 @@ load( "aten_cpu_source_list", "aten_native_source_list", "core_sources_common", - "core_sources_full_mobile_no_backend_interface", + "core_sources_full_mobile_no_backend_interface_xplat", "core_trainer_sources", "jit_core_headers", "jit_core_sources", "libtorch_profiler_sources", + "torch_cpp_srcs", "torch_mobile_tracer_sources", ) load( @@ -97,6 +98,9 @@ def get_strip_error_messages(): return True # always strip in OSS CI to expose potential issues return read_bool("pt", "strip_error_messages", not _is_build_mode_dev()) +def get_disable_warn(): + return read_bool("pt", "disable_warn", False) + def get_enable_eager_symbolication(): return read_bool("pt", "enable_eager_symbolication", default = False, required = False) @@ -143,6 +147,7 @@ THIRD_PARTY_LIBS = { "rt": ["//xplat/third-party/linker_lib:rt", "//third_party:rt"], "ruy": ["//third-party/ruy:ruy_xplat_lib", "//third_party:ruy_lib"], "typing-extensions": ["//third-party/typing-extensions:typing-extensions", "//third_party:typing-extensions"], + "sleef_arm": ["//third-party/sleef:sleef_arm", "//third_party:sleef_arm"], } def third_party(name): @@ -199,6 +204,8 @@ _COMMON_PREPROCESSOR_FLAGS = [ ["-DC10_MOBILE_TRIM_DISPATCH_KEYS"] if get_enable_mobile_dispatch_keys_trimming() else [] ) + ( ["-DSTRIP_ERROR_MESSAGES"] if get_strip_error_messages() else [] +) + ( + ["-DDISABLE_WARN"] if get_disable_warn() else [] ) def get_aten_preprocessor_flags(): @@ -748,14 +755,13 @@ def get_pt_operator_registry_dict( "pt_operator_registry", ], deps = [ - # need absolute path here - ROOT + ":torch_mobile_core", - ROOT + ":aten_cpu", - ROOT + ":aten_metal_prepack_header", - third_party("glog"), - C10, - ] + ([ROOT + ":torch_mobile_train"] if train else []) + - ([ROOT + ":flatbuffers_mobile"] if enable_flatbuffer else []), + # need absolute path here + ROOT + ":torch_mobile_core", + ROOT + ":aten_cpu", + ROOT + ":aten_metal_prepack_header", + third_party("glog"), + C10, + ] + ([ROOT + ":torch_mobile_train"] if train else []), **kwargs ) @@ -1296,12 +1302,14 @@ def define_buck_targets( name = "torch_mobile_deserialize", srcs = [ "torch/csrc/jit/mobile/import.cpp", + "torch/csrc/jit/mobile/flatbuffer_loader.cpp", ], compiler_flags = get_pt_compiler_flags(), - exported_preprocessor_flags = get_pt_preprocessor_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []), header_namespace = "", exported_headers = [ "torch/csrc/jit/mobile/import.h", + "torch/csrc/jit/mobile/flatbuffer_loader.h", ], # torch_mobile_deserialize brings in sources neccessary to read a module # which depends on mobile module definition @@ -1324,6 +1332,7 @@ def define_buck_targets( ":torch_mobile_module", ":torch_mobile_observer", ":torch_mobile_deserialize_common", + ":mobile_bytecode", C10, ], ) @@ -1369,12 +1378,21 @@ def define_buck_targets( ) pt_xplat_cxx_library( - name = "torch_core", - srcs = core_sources_full_mobile_no_backend_interface + [ - "torch/csrc/api/src/jit.cpp", - "torch/csrc/jit/serialization/export_bytecode.cpp", - "torch/csrc/jit/serialization/export_module.cpp", + name = "torch_cpp_cpu", + srcs = torch_cpp_srcs, + headers = native.glob(["torch/csrc/api/include/**/*.h"]) + ["torch/script.h"], + compiler_flags = get_pt_compiler_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags(), + visibility = ["PUBLIC"], + exported_deps = [ + ":torch", + ":torch_mobile_deserialize_common", # for torch/csrc/api/src/serialize/input-archive.cpp ], + ) + + pt_xplat_cxx_library( + name = "torch_core", + srcs = core_sources_full_mobile_no_backend_interface_xplat, compiler_flags = get_pt_compiler_flags(), exported_preprocessor_flags = get_pt_preprocessor_flags(), visibility = [ @@ -1423,6 +1441,7 @@ def define_buck_targets( ":torch_core", ":torch_mobile_deserialize", ":torch_mobile_train", + ":jit_module_saving", C10, ], ) @@ -1454,6 +1473,7 @@ def define_buck_targets( ":generated-autograd-headers", ":torch_headers", ":torch_mobile_deserialize", + ":flatbuffers_serializer_mobile", C10, ], ) @@ -1543,15 +1563,16 @@ def define_buck_targets( "torch/csrc/jit/serialization/export_module.cpp", ], compiler_flags = get_pt_compiler_flags(), - exported_preprocessor_flags = get_pt_preprocessor_flags(), + exported_preprocessor_flags = get_pt_preprocessor_flags() + + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []), exported_headers = [ "torch/csrc/jit/serialization/export.h", - "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h", ], visibility = ["PUBLIC"], deps = [ ":torch", ":torch_mobile_core", + ":flatbuffers_serializer_mobile", ], ) @@ -1598,6 +1619,7 @@ def define_buck_targets( ]), ) + #TODO(qihan) delete pt_xplat_cxx_library( name = "torch_mobile_core_flatbuffer", srcs = [], @@ -1619,9 +1641,7 @@ def define_buck_targets( exported_deps = [ ":aten_cpu", ":torch_common", - ] + ([] if IS_OSS else [ - "//xplat/caffe2/fb/runtime:torch_mobile_deserialize_flatbuffer", - ]), + ], ) fb_xplat_cxx_library( @@ -1713,13 +1733,13 @@ def define_buck_targets( name = "mobile_bytecode", header_namespace = "", exported_headers = { - "torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h": ":mobile_bytecode_header[mobile_bytecode_generated_fbsource.h]", + ("torch/csrc/jit/serialization/mobile_bytecode_generated.h" if IS_OSS else "torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h"): ":mobile_bytecode_header[mobile_bytecode_generated_fbsource.h]", }, # Avoid leaking implementation details by only exposing this header to # the internals of the loader/serializer layer. visibility = [ "{}:flatbuffer_loader".format(ROOT), - "{}:flatbuffer_serializer_mobile".format(ROOT), + "{}:flatbuffers_serializer_mobile".format(ROOT), ], exported_deps = [ third_party("flatbuffers-api"), @@ -1746,14 +1766,15 @@ def define_buck_targets( C10, ], exported_deps = [ - ":torch_mobile_train", + ":torch_mobile_deserialize", + ":mobile_bytecode", ], ) + # TODO (qihan) delete pt_xplat_cxx_library( name = "flatbuffer_loader", srcs = [ - "torch/csrc/jit/mobile/flatbuffer_loader.cpp", ], exported_headers = [ "torch/csrc/jit/mobile/flatbuffer_loader.h", @@ -1783,17 +1804,13 @@ def define_buck_targets( ":mobile_bytecode", ], exported_deps = [ - ":torch_mobile_deserialize", C10, ], ) + # TODO(qihan) delete fb_xplat_cxx_library( name = "flatbuffers_serializer_jit", - srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp"], - exported_headers = [ - "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h", - ], compiler_flags = [ "-g0", "-O3", @@ -1801,6 +1818,12 @@ def define_buck_targets( "-frtti", "-Wno-deprecated-declarations", ], + headers = [ + "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h", + ], + srcs = [ + "torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp", + ], linker_flags = [ "-Wl,--no-as-needed", ], @@ -1830,6 +1853,7 @@ def define_buck_targets( exported_deps = [ ":flatbuffer_loader", ":flatbuffers_serializer_mobile", + ":torch_mobile_train", ], ) @@ -1911,7 +1935,12 @@ def define_buck_targets( third_party("glog"), third_party("XNNPACK"), third_party("pocketfft"), - ], + ] + select({ + "DEFAULT": [], + "ovr_config//runtime:fbcode-arm64": [ + third_party("sleef_arm"), + ], + }), compiler_flags = get_aten_compiler_flags(), exported_preprocessor_flags = get_aten_preprocessor_flags(), exported_deps = [ diff --git a/build_variables.bzl b/build_variables.bzl index f1801b446ed8c..5af378f0a0a65 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -134,6 +134,7 @@ libtorch_profiler_sources = [ "torch/csrc/autograd/profiler_legacy.cpp", "torch/csrc/autograd/profiler_kineto.cpp", "torch/csrc/profiler/collection.cpp", + "torch/csrc/profiler/data_flow.cpp", "torch/csrc/profiler/kineto_shim.cpp", "torch/csrc/profiler/kineto_client_interface.cpp", "torch/csrc/profiler/orchestration/observer.cpp", @@ -142,6 +143,7 @@ libtorch_profiler_sources = [ "torch/csrc/profiler/standalone/itt_observer.cpp", "torch/csrc/profiler/standalone/nvtx_observer.cpp", "torch/csrc/profiler/stubs/base.cpp", + "torch/csrc/profiler/perf.cpp", "torch/csrc/monitor/counters.cpp", "torch/csrc/monitor/events.cpp", ] @@ -175,7 +177,28 @@ core_trainer_sources = [ "torch/csrc/jit/serialization/type_name_uniquer.cpp", ] -core_sources_full_mobile_no_backend_interface = [ +torch_mobile_core = [ + # backend_debug_info.cpp provides + # __torch__.torch.classes.backend.BackendDebugInfo class + # This should not be needed eventually. + # TODO: Remove this dependency + "torch/csrc/jit/backends/backend_debug_info.cpp", + "torch/csrc/jit/mobile/compatibility/model_compatibility.cpp", + "torch/csrc/jit/mobile/function.cpp", + "torch/csrc/jit/mobile/import.cpp", + "torch/csrc/jit/mobile/flatbuffer_loader.cpp", + "torch/csrc/jit/mobile/interpreter.cpp", + "torch/csrc/jit/mobile/module.cpp", + "torch/csrc/jit/mobile/observer.cpp", + "torch/csrc/jit/mobile/parse_bytecode.cpp", + "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/quantization.cpp", + "torch/csrc/jit/mobile/upgrader_mobile.cpp", + "torch/csrc/jit/runtime/register_prim_ops.cpp", + "torch/csrc/jit/runtime/register_special_ops.cpp", +] + +core_sources_full_mobile_no_backend_interface_xplat = [ "torch/csrc/jit/api/function_impl.cpp", "torch/csrc/jit/api/module.cpp", "torch/csrc/jit/api/object.cpp", @@ -383,6 +406,26 @@ core_sources_full_mobile_no_backend_interface = [ "torch/csrc/utils/variadic.cpp", ] +core_sources_full_mobile_no_backend_interface = core_sources_full_mobile_no_backend_interface_xplat + [ + # backend_debug_info.cpp provides + # __torch__.torch.classes.backend.BackendDebugInfo class + # This should not be needed eventually. + # TODO: Remove this dependency + "torch/csrc/jit/backends/backend_debug_info.cpp", + "torch/csrc/jit/mobile/compatibility/model_compatibility.cpp", + "torch/csrc/jit/mobile/function.cpp", + "torch/csrc/jit/mobile/import.cpp", + "torch/csrc/jit/mobile/flatbuffer_loader.cpp", + "torch/csrc/jit/mobile/interpreter.cpp", + "torch/csrc/jit/mobile/module.cpp", + "torch/csrc/jit/mobile/observer.cpp", + "torch/csrc/jit/mobile/parse_bytecode.cpp", + "torch/csrc/jit/mobile/parse_operators.cpp", + "torch/csrc/jit/mobile/quantization.cpp", + "torch/csrc/jit/mobile/upgrader_mobile.cpp", +] + + core_sources_full_mobile = core_sources_full_mobile_no_backend_interface + [ "torch/csrc/jit/backends/backend_debug_info.cpp", "torch/csrc/jit/backends/backend_interface.cpp", @@ -414,7 +457,6 @@ lazy_tensor_core_sources = [ "torch/csrc/lazy/core/ir_metadata.cpp", "torch/csrc/lazy/core/ir_util.cpp", "torch/csrc/lazy/core/lazy_graph_executor.cpp", - "torch/csrc/lazy/core/lazy_view.cpp", "torch/csrc/lazy/core/metrics.cpp", "torch/csrc/lazy/core/multi_wait.cpp", "torch/csrc/lazy/core/ops/arithmetic_ir_ops.cpp", @@ -562,28 +604,6 @@ torch_mobile_tracer_sources = [ "torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.cpp", ] -torch_mobile_core = [ - # backend_debug_info.cpp provides - # __torch__.torch.classes.backend.BackendDebugInfo class - # This should not be needed eventually. - # TODO: Remove this dependency - "torch/csrc/jit/backends/backend_debug_info.cpp", - "torch/csrc/jit/mobile/compatibility/model_compatibility.cpp", - # TODO: This line needs to be uncommented to build mobile in OSS with flatbuffers - # "torch/csrc/jit/mobile/flatbuffer_loader.cpp", - "torch/csrc/jit/mobile/function.cpp", - "torch/csrc/jit/mobile/import.cpp", - "torch/csrc/jit/mobile/interpreter.cpp", - "torch/csrc/jit/mobile/module.cpp", - "torch/csrc/jit/mobile/observer.cpp", - "torch/csrc/jit/mobile/parse_bytecode.cpp", - "torch/csrc/jit/mobile/parse_operators.cpp", - "torch/csrc/jit/mobile/quantization.cpp", - "torch/csrc/jit/mobile/upgrader_mobile.cpp", - "torch/csrc/jit/runtime/register_prim_ops.cpp", - "torch/csrc/jit/runtime/register_special_ops.cpp", -] - libtorch_lite_eager_symbolication = [ "torch/csrc/jit/frontend/source_range.cpp", "torch/csrc/jit/ir/scope.cpp", @@ -620,6 +640,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ # when it is built in libtorch "torch/csrc/jit/mobile/debug_info.cpp", "torch/csrc/jit/mobile/function.cpp", + "torch/csrc/jit/mobile/flatbuffer_loader.cpp", "torch/csrc/jit/mobile/import.cpp", "torch/csrc/jit/mobile/import_data.cpp", "torch/csrc/jit/mobile/interpreter.cpp", @@ -637,24 +658,16 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ "torch/csrc/jit/serialization/export.cpp", "torch/csrc/jit/serialization/export_bytecode.cpp", "torch/csrc/jit/serialization/export_module.cpp", + "torch/csrc/jit/serialization/flatbuffer_serializer.cpp", "torch/csrc/jit/serialization/import_legacy.cpp", "torch/csrc/utils/byte_order.cpp", "torch/csrc/utils/out_types.cpp", ] def libtorch_sources(gencode_pattern = ":generate-code[{}]"): - enable_flatbuffer = bool(native.read_config("fbcode", "caffe2_enable_flatbuffer", None)) - flatbuffer_serializer_sources = [ - "torch/csrc/jit/serialization/flatbuffer_serializer.cpp", - "torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp", - ] - if enable_flatbuffer: - return ( - libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources + - flatbuffer_serializer_sources - ) - else: - return libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources + return ( + libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources + ) libtorch_cuda_core_sources = [ "torch/csrc/CudaIPCTypes.cpp", @@ -665,7 +678,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/autograd/functions/comm.cpp", "torch/csrc/jit/codegen/cuda/arith.cpp", "torch/csrc/jit/codegen/cuda/compute_at.cpp", - "torch/csrc/jit/codegen/cuda/inline_propagator.cpp", + "torch/csrc/jit/codegen/cuda/inlining.cpp", "torch/csrc/jit/codegen/cuda/compute_at_map.cpp", "torch/csrc/jit/codegen/cuda/codegen.cpp", "torch/csrc/jit/codegen/cuda/contiguity.cpp", @@ -699,6 +712,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp", "torch/csrc/jit/codegen/cuda/lower_allocation.cpp", "torch/csrc/jit/codegen/cuda/lower_double_buffer.cpp", + "torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp", "torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp", "torch/csrc/jit/codegen/cuda/lower_fused_reduction.cpp", "torch/csrc/jit/codegen/cuda/lower_fusion_simplifier.cpp", @@ -722,6 +736,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/lower_validation.cpp", "torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp", "torch/csrc/jit/codegen/cuda/lower2device.cpp", + "torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp", "torch/csrc/jit/codegen/cuda/manager.cpp", "torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp", "torch/csrc/jit/codegen/cuda/mutator.cpp", @@ -749,6 +764,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp", "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp", "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp", + "torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp", "torch/csrc/jit/codegen/cuda/type_inference.cpp", "torch/csrc/jit/codegen/cuda/type_promotion.cpp", "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp", @@ -799,6 +815,7 @@ torch_cpp_srcs = [ "torch/csrc/api/src/imethod.cpp", "torch/csrc/api/src/jit.cpp", "torch/csrc/api/src/serialize.cpp", + "torch/csrc/api/src/mps.cpp", "torch/csrc/api/src/nn/init.cpp", "torch/csrc/api/src/nn/module.cpp", "torch/csrc/api/src/nn/modules/_functions.cpp", @@ -858,6 +875,7 @@ libtorch_python_cuda_core_sources = [ "torch/csrc/cuda/shared/cudart.cpp", "torch/csrc/cuda/shared/nvtx.cpp", "torch/csrc/cuda/utils.cpp", + "torch/csrc/cuda/CUDAPluggableAllocator.cpp", ] libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [ @@ -895,6 +913,7 @@ libtorch_python_core_sources = [ "torch/csrc/autograd/python_function.cpp", "torch/csrc/autograd/python_hook.cpp", "torch/csrc/autograd/python_legacy_variable.cpp", + "torch/csrc/autograd/python_nested_functions_manual.cpp", "torch/csrc/autograd/python_torch_functions_manual.cpp", "torch/csrc/autograd/python_variable.cpp", "torch/csrc/autograd/python_variable_indexing.cpp", @@ -902,6 +921,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/guards.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", + "torch/csrc/mps/Module.cpp", "torch/csrc/jit/backends/backend_init.cpp", "torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp", "torch/csrc/jit/python/init.cpp", @@ -956,9 +976,12 @@ libtorch_python_core_sources = [ "torch/csrc/utils.cpp", "torch/csrc/utils/cuda_lazy_init.cpp", "torch/csrc/utils/invalid_arguments.cpp", + "torch/csrc/utils/nested.cpp", "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", "torch/csrc/utils/python_dispatch.cpp", + "torch/csrc/utils/python_symnode.cpp", + "torch/csrc/utils/pybind.cpp", "torch/csrc/utils/structseq.cpp", "torch/csrc/utils/tensor_apply.cpp", "torch/csrc/utils/tensor_dtypes.cpp", @@ -1019,7 +1042,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): aten_cpu_source_non_codegen_list = [ "aten/src/ATen/AccumulateType.cpp", - "aten/src/ATen/BatchedTensorImpl.cpp", + "aten/src/ATen/LegacyBatchedTensorImpl.cpp", "aten/src/ATen/CPUGeneratorImpl.cpp", "aten/src/ATen/Context.cpp", "aten/src/ATen/DLConvertor.cpp", @@ -1053,8 +1076,8 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/FuncTorchTLS.cpp", "aten/src/ATen/Utils.cpp", "aten/src/ATen/Version.cpp", - "aten/src/ATen/VmapMode.cpp", - "aten/src/ATen/VmapTransforms.cpp", + "aten/src/ATen/LegacyVmapMode.cpp", + "aten/src/ATen/LegacyVmapTransforms.cpp", "aten/src/ATen/core/BackendSelectFallbackKernel.cpp", "aten/src/ATen/core/DeprecatedTypeProperties.cpp", "aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp", @@ -1092,6 +1115,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/cpu/FlushDenormal.cpp", "aten/src/ATen/detail/CPUGuardImpl.cpp", "aten/src/ATen/detail/CUDAHooksInterface.cpp", + "aten/src/ATen/detail/MPSHooksInterface.cpp", "aten/src/ATen/detail/HIPHooksInterface.cpp", "aten/src/ATen/detail/ORTHooksInterface.cpp", "aten/src/ATen/metal/Context.cpp", @@ -1283,7 +1307,7 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/AveragePool3d.cpp", "aten/src/ATen/native/BatchLinearAlgebra.cpp", "aten/src/ATen/native/BatchLinearAlgebraKernel.cpp", - "aten/src/ATen/native/Batching.cpp", + "aten/src/ATen/native/LegacyBatching.cpp", "aten/src/ATen/native/BinaryOps.cpp", "aten/src/ATen/native/Blas.cpp", "aten/src/ATen/native/BlasKernel.cpp", @@ -1399,11 +1423,14 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/mkl/SparseBlasImpl.cpp", "aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp", "aten/src/ATen/native/mkl/SpectralOps.cpp", + "aten/src/ATen/native/nested/NestedTensorAliases.cpp", + "aten/src/ATen/native/nested/NestedTensorBackward.cpp", + "aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp", "aten/src/ATen/native/nested/NestedTensorFactories.cpp", "aten/src/ATen/native/nested/NestedTensorMath.cpp", - "aten/src/ATen/native/nested/NestedTensorAliases.cpp", + "aten/src/ATen/native/nested/NestedTensorMatmul.cpp", "aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp", - "aten/src/ATen/native/nested/NestedTensorBackward.cpp", + "aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp", "aten/src/ATen/native/nested/NestedTensorUtils.cpp", "aten/src/ATen/native/sparse/ParamUtils.cpp", "aten/src/ATen/native/sparse/SoftMax.cpp", @@ -1478,3 +1505,33 @@ aten_cuda_with_sort_by_key_source_list = [ aten_cuda_cu_with_sort_by_key_source_list = [ "aten/src/ATen/native/cuda/Unique.cu", ] + +# Followings are source code for xnnpack delegate + +xnnpack_delegate_serializer_header = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.h", +] + +xnnpack_delegate_serializer_source_list = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp", +] + +xnnpack_delegate_core_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", +] + +xnnpack_delegate_core_header = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h", + "torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h", +] + +xnnpack_backend_header = [ + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h", +] + xnnpack_delegate_core_header + +xnnpack_backend_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp", +] + xnnpack_delegate_core_source_list diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 0309d7a2d712e..9c80fa9051ab6 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(c10 CXX) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Main build file for the C10 library. diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 7b55d2dbe283b..96d2504ec7de5 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -47,6 +47,9 @@ DeviceType parse_type(const std::string& device_string) { if (device != types.end()) { return device->second; } + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } std::vector device_names; for (const auto& it : types) { if (it.first) { diff --git a/c10/core/Device.h b/c10/core/Device.h index cea7cfec119e9..d53ab38ff9cb9 100644 --- a/c10/core/Device.h +++ b/c10/core/Device.h @@ -148,7 +148,8 @@ struct C10_API Device final { /// Return true if the device supports arbirtary strides. bool supports_as_strided() const noexcept { - return type_ != DeviceType::XLA && type_ != DeviceType::Lazy; + return type_ != DeviceType::IPU && type_ != DeviceType::XLA && + type_ != DeviceType::Lazy; } /// Same string as returned from operator<<. diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index ac4c1f653efbf..22f0029d747d4 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -1,5 +1,9 @@ #include #include +#include +#include +#include +#include namespace c10 { @@ -46,7 +50,7 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { case DeviceType::IPU: return lower_case ? "ipu" : "IPU"; case DeviceType::PrivateUse1: - return lower_case ? "privateuseone" : "PRIVATEUSEONE"; + return get_privateuse1_backend(/*lowercase=*/lower_case); default: TORCH_CHECK( false, @@ -101,4 +105,46 @@ std::ostream& operator<<(std::ostream& stream, DeviceType type) { return stream; } +// We use both a mutex and an atomic here because: +// (1) Mutex is needed during writing: +// We need to first check the value and potentially error, +// before setting the value (without any one else racing in the middle). +// It's also totally fine for this to be slow, since it happens exactly once +// at import time. +// (2) Atomic is needed during reading: +// Whenever a user prints a privatuse1 device name, they need to read this +// variable. Although unlikely, we'll data race if someone else is trying to +// set this variable at the same time that another thread is print the +// device name. We could re-use the same mutex, but reading the atomic will +// be much faster. +static std::atomic privateuse1_backend_name_set; +static std::string privateuse1_backend_name; +static std::mutex privateuse1_lock; + +std::string get_privateuse1_backend(bool lower_case) { + // Applying the same atomic read memory ordering logic as in Note [Memory + // ordering on Python interpreter tag]. + auto name_registered = + privateuse1_backend_name_set.load(std::memory_order_acquire); + // Guaranteed that if the flag is set, then privateuse1_backend_name has been + // set, and will never be written to. + auto backend_name = + name_registered ? privateuse1_backend_name : "privateuseone"; + return backend_name; +} + +void register_privateuse1_backend(std::string backend_name) { + std::lock_guard guard(privateuse1_lock); + TORCH_CHECK( + !privateuse1_backend_name_set.load() || + privateuse1_backend_name == backend_name, + "torch.register_privateuse1_backend() has already been set! Current backend: ", + privateuse1_backend_name); + + privateuse1_backend_name = backend_name; + // Invariant: once this flag is set, privateuse1_backend_name is NEVER written + // to. + privateuse1_backend_name_set.store(true, std::memory_order_relaxed); +} + } // namespace c10 diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 000ad331828b0..065444827833d 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -95,6 +95,9 @@ C10_API bool isValidDeviceType(DeviceType d); C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type); +C10_API void register_privateuse1_backend(std::string backend_name); +C10_API std::string get_privateuse1_backend(bool lower_case = true); + } // namespace c10 namespace std { diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index bffafc59168c6..0bbea6a4f078a 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -189,6 +189,8 @@ const char* toString(DispatchKey t) { return "CompositeExplicitAutograd"; case DispatchKey::CompositeExplicitAutogradNonFunctional: return "CompositeExplicitAutogradNonFunctional"; + case DispatchKey::FuncTorchBatchedDecomposition: + return "FuncTorchBatchedDecomposition"; // Per-backend dispatch keys @@ -317,6 +319,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"SparseHIP", c10::DispatchKey::SparseHIP}, {"SparseXPU", c10::DispatchKey::SparseXPU}, {"SparseVE", c10::DispatchKey::SparseVE}, + {"SparseMeta", c10::DispatchKey::SparseMeta}, {"AutogradCPU", c10::DispatchKey::AutogradCPU}, {"AutogradCUDA", c10::DispatchKey::AutogradCUDA}, @@ -340,6 +343,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { c10::DispatchKey::CompositeExplicitAutograd}, {"CompositeExplicitAutogradNonFunctional", c10::DispatchKey::CompositeExplicitAutogradNonFunctional}, + {"FuncTorchBatchedDecomposition", + c10::DispatchKey::FuncTorchBatchedDecomposition}, }; auto it = key_map.find(k); TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k); diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 33f762e9da7d2..b28f770290e31 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -444,6 +444,15 @@ enum class DispatchKey : uint16_t { Autograd, CompositeImplicitAutograd, // registered at // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp + + // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from + // all + // other alias keysets + // and so precedence order doesn't matter + FuncTorchBatchedDecomposition, // registered at + // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp + // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is + // disjoint from all other alias keysets CompositeImplicitAutogradNestedTensor, // registered at // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp CompositeExplicitAutograd, // registered at diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index a8f60451be379..f180008a102c5 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -101,6 +101,8 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { // See Note [NestedTensor Not Included in Backend Keys] return k != DispatchKey::NestedTensor && non_functional_backend_dispatch_keyset.has(k); + case DispatchKey::FuncTorchBatchedDecomposition: + return functorch_batched_ks.has(k); default: return t == k; } diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 33916492a0ef5..a2f7b31fa9c5a 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -757,6 +757,9 @@ constexpr auto functorch_transforms_ks = DispatchKeySet( DispatchKey::VmapMode, DispatchKey::FuncTorchGradWrapper}); +constexpr auto functorch_batched_ks = + DispatchKeySet({DispatchKey::FuncTorchBatched}); + // This keyset has: // (1) the functionality bits corresponding to backends (dense, sparse, // quantized) (2) all of the backend bits set @@ -876,7 +879,10 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // treatment; return (s - autograd_dispatch_keyset_with_ADInplaceOrView - autocast_dispatch_keyset - - DispatchKeySet({DispatchKey::PythonTLSSnapshot, DispatchKey::Python})) + DispatchKeySet( + {DispatchKey::Functionalize, + DispatchKey::PythonTLSSnapshot, + DispatchKey::Python})) .highestPriorityTypeId(); } diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index c0d89315b65db..0c124177e38f7 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -92,8 +92,8 @@ class C10_API Scalar { SymInt toSymInt() const { if (Tag::HAS_si == tag) { - return c10::SymInt::toSymInt(intrusive_ptr::reclaim_copy( - static_cast(v.p))); + return c10::SymInt(intrusive_ptr::reclaim_copy( + static_cast(v.p))); } else { return toLong(); } @@ -101,9 +101,8 @@ class C10_API Scalar { SymFloat toSymFloat() const { if (Tag::HAS_sd == tag) { - return c10::SymFloat::toSymFloat( - intrusive_ptr::reclaim_copy( - static_cast(v.p))); + return c10::SymFloat(intrusive_ptr::reclaim_copy( + static_cast(v.p))); } else { return toDouble(); } diff --git a/c10/core/Storage.h b/c10/core/Storage.h index a89a0039fdfe6..09c5920b56493 100644 --- a/c10/core/Storage.h +++ b/c10/core/Storage.h @@ -76,7 +76,7 @@ struct C10_API Storage { } void set_nbytes(c10::SymInt size_bytes) const { - storage_impl_.get()->set_nbytes(size_bytes); + storage_impl_.get()->set_nbytes(std::move(size_bytes)); } bool resizable() const { diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index bbf0803842537..1d80daed871a2 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -112,7 +112,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } void set_nbytes(c10::SymInt size_bytes) { - size_bytes_ = size_bytes; + size_bytes_ = std::move(size_bytes); } bool resizable() const { diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp index 0ba980a9727ea..161313c777dda 100644 --- a/c10/core/SymFloat.cpp +++ b/c10/core/SymFloat.cpp @@ -1,79 +1,91 @@ #include -#include +#include #include +#include +#include namespace c10 { -SymFloatNode SymFloat::toSymFloatNodeImpl() const { +SymNode SymFloat::toSymNodeImpl() const { TORCH_CHECK(is_symbolic()); - return SymFloatNode::reclaim_copy(toSymFloatNodeImplUnowned()); + return SymNode::reclaim_copy(toSymNodeImplUnowned()); } -static std::array normalize_symfloats( - SymFloat a_, - SymFloat b_) { - SymFloatNode a, b; +static std::array normalize_symfloats( + const SymFloat& a_, + const SymFloat& b_) { + SymNode a, b; if (a_.is_symbolic()) - a = a_.toSymFloatNodeImpl(); + a = a_.toSymNodeImpl(); if (b_.is_symbolic()) - b = b_.toSymFloatNodeImpl(); + b = b_.toSymNodeImpl(); - SymFloatNodeImpl* common = a ? a.get() : b.get(); - // TODO: technically we need to check that the classes match + SymNodeImpl* common = a ? a.get() : b.get(); if (!a) { - a = common->wrap(a_.as_float_unchecked()); - a_.toSymFloat(a); // + a = common->wrap_float(a_.as_float_unchecked()); } if (!b) { - b = common->wrap(b_.as_float_unchecked()); - b_.toSymFloat(b); + b = common->wrap_float(b_.as_float_unchecked()); } - return {a, b}; + return {std::move(a), std::move(b)}; } -SymFloat SymFloat::operator+(SymFloat sci) const { +SymFloat SymFloat::operator+(const SymFloat& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymFloat(data_ + sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->add(res[1])); + return SymFloat(res[0]->add(res[1])); } -SymFloat SymFloat::operator-(SymFloat sci) const { +SymFloat SymFloat::operator-(const SymFloat& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymFloat(data_ - sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->sub(res[1])); + return SymFloat(res[0]->sub(res[1])); } -SymFloat SymFloat::operator*(SymFloat sci) const { +SymFloat SymFloat::operator*(const SymFloat& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymFloat(data_ * sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->mul(res[1])); + return SymFloat(res[0]->mul(res[1])); } -SymFloat SymFloat::operator/(SymFloat sci) const { +SymFloat SymFloat::operator/(const SymFloat& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymFloat(data_ / sci.data_); } auto res = normalize_symfloats(*this, sci); - return SymFloat::toSymFloat(res[0]->truediv(res[1])); + return SymFloat(res[0]->truediv(res[1])); } -c10::SymFloat SymFloat::toSymFloat(SymFloatNode sin_sp) { - return c10::SymFloat(std::move(sin_sp)); -} - -std::ostream& operator<<(std::ostream& os, SymFloat s) { +std::ostream& operator<<(std::ostream& os, const SymFloat& s) { if (s.is_symbolic()) { - os << s.toSymFloatNodeImpl()->str(); + os << s.toSymNodeImpl()->str(); } else { os << s.as_float_unchecked(); } return os; } +SymFloat SymFloat::sqrt() const { + if (!is_symbolic()) { + return SymFloat(std::sqrt(data_)); + } + auto other = SymFloat(-0.5); + auto res = normalize_symfloats(*this, other); + return SymFloat(res[0]->pow(res[1])); +} + +double SymFloat::guard_float(const char* file, int64_t line) const { + if (!is_symbolic()) { + return data_; + } + SymNode a = toSymNodeImpl(); + return a->guard_float(file, line); +} + } // namespace c10 diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h index 92abb81ea2a22..50512dc6fb206 100644 --- a/c10/core/SymFloat.h +++ b/c10/core/SymFloat.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -14,30 +14,44 @@ namespace c10 { class C10_API SymFloat { public: /*implicit*/ SymFloat(double d) : data_(d){}; - SymFloat(SymFloatNode ptr) - : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)){}; + SymFloat(SymNode ptr) + : data_(std::numeric_limits::quiet_NaN()), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_float()); + }; SymFloat() : data_(0.0) {} - SymFloatNodeImpl* toSymFloatNodeImplUnowned() const { + SymNodeImpl* toSymNodeImplUnowned() const { return ptr_.get(); } - SymFloatNodeImpl* release() && { + SymNodeImpl* release() && { return std::move(ptr_).release(); } - SymFloatNode toSymFloatNodeImpl() const; - static c10::SymFloat toSymFloat(SymFloatNode sin); + SymNode toSymNodeImpl() const; double expect_float() const { TORCH_CHECK(!is_symbolic()); return data_; } - SymFloat operator+(SymFloat) const; - SymFloat operator-(SymFloat) const; - SymFloat operator*(SymFloat) const; - SymFloat operator/(SymFloat) const; + SymFloat operator+(const SymFloat&) const; + SymFloat operator-(const SymFloat&) const; + SymFloat operator*(const SymFloat&) const; + SymFloat operator/(const SymFloat&) const; + + // Need guidance on where to put this code + SymFloat sqrt() const; + + // Insert a guard for the float to be its concrete value, and then return + // that value. This operation always works, even if the float is symbolic, + // so long as we know what the underlying value is. Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_float(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + double guard_float(const char* file, int64_t line) const; // N.B. It's important to keep this definition in the header // as we expect if checks to be folded for mobile builds @@ -53,8 +67,8 @@ class C10_API SymFloat { private: // TODO: optimize to union double data_; - SymFloatNode ptr_; + SymNode ptr_; }; -C10_API std::ostream& operator<<(std::ostream& os, SymFloat s); +C10_API std::ostream& operator<<(std::ostream& os, const SymFloat& s); } // namespace c10 diff --git a/c10/core/SymFloatNodeImpl.cpp b/c10/core/SymFloatNodeImpl.cpp deleted file mode 100644 index 714ee095d84e3..0000000000000 --- a/c10/core/SymFloatNodeImpl.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include -#include - -namespace c10 { - -c10::SymFloat SymFloatNodeImpl::toSymFloat() { - auto sit_sp = SymFloatNode::reclaim_copy(this); - return SymFloat::toSymFloat(sit_sp); -} - -c10::SymIntNode SymFloatNodeImpl::ceil() { - TORCH_CHECK(false, "NYI"); -} - -c10::SymIntNode SymFloatNodeImpl::floor() { - TORCH_CHECK(false, "NYI"); -} - -} // namespace c10 diff --git a/c10/core/SymFloatNodeImpl.h b/c10/core/SymFloatNodeImpl.h deleted file mode 100644 index 0ab9d952b5bbc..0000000000000 --- a/c10/core/SymFloatNodeImpl.h +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace c10 { - -class SymIntNodeImpl; -using SymIntNode = c10::intrusive_ptr; - -class SymFloat; -class SymFloatNodeImpl; -using SymFloatNode = c10::intrusive_ptr; - -class C10_API SymFloatNodeImpl : public c10::intrusive_ptr_target { - public: - c10::SymFloat toSymFloat(); - virtual ~SymFloatNodeImpl(){}; - - template - c10::intrusive_ptr dyn_cast() const { - return c10::intrusive_ptr::reclaim_copy(dynamic_cast(this)); - } - - virtual SymFloatNode wrap(double num) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode add(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode sub(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode mul(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode truediv(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode pow(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - } - virtual SymFloatNode eq(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode ne(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode gt(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode lt(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode le(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymFloatNode ge(const SymFloatNode& other) { - TORCH_CHECK(false, "NYI"); - }; - virtual SymIntNode ceil(); - virtual SymIntNode floor(); - virtual std::string str() { - TORCH_CHECK(false, "NYI"); - }; - std::ostream& operator<<(std::ostream& os) { - os << str(); - return os; - }; -}; - -} // namespace c10 diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index 03f39078b406c..977397be1264b 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -1,47 +1,49 @@ #include #include -#include +#include #include +#include namespace c10 { -static std::array normalize_symints(SymInt a_, SymInt b_) { - SymIntNode a, b; +static std::array normalize_symints( + const SymInt& a_, + const SymInt& b_) { + SymNode a, b; if (a_.is_symbolic()) - a = a_.toSymIntNodeImpl(); + a = a_.toSymNodeImpl(); if (b_.is_symbolic()) - b = b_.toSymIntNodeImpl(); + b = b_.toSymNodeImpl(); - SymIntNodeImpl* common = a ? a.get() : b.get(); + SymNodeImpl* common = a ? a.get() : b.get(); // TODO: technically we need to check that the classes match if (!a) { - a = common->wrap(a_.as_int_unchecked()); - a_.toSymInt(a); // + a = common->wrap_int(a_.as_int_unchecked()); } if (!b) { - b = common->wrap(b_.as_int_unchecked()); - b_.toSymInt(b); + b = common->wrap_int(b_.as_int_unchecked()); } - return {a, b}; + return {std::move(a), std::move(b)}; } -SymIntNode SymInt::toSymIntNodeImpl() const { +SymNode SymInt::toSymNodeImpl() const { TORCH_CHECK(is_symbolic()); - return SymIntNode::reclaim_copy(toSymIntNodeImplUnowned()); + return SymNode::reclaim_copy(toSymNodeImplUnowned()); } -c10::SymInt SymInt::toSymInt(SymIntNode sin_sp) { +SymInt::SymInt(SymNode sin_sp) { + TORCH_CHECK(sin_sp->is_int()); auto ptr = static_cast( reinterpret_cast(static_cast(sin_sp.release()))); auto rep = (ptr & ~MASK) | IS_SYM; - return c10::SymInt(UNCHECKED, static_cast(rep)); + data_ = static_cast(rep); } int64_t SymInt::guard_int(const char* file, int64_t line) const { if (!is_symbolic()) { return data_; } - SymIntNode a = toSymIntNodeImpl(); + SymNode a = toSymNodeImpl(); return a->guard_int(file, line); } @@ -49,50 +51,50 @@ SymInt::operator SymFloat() const { if (!is_symbolic()) { return SymFloat(double(data_)); } - return SymFloat::toSymFloat(toSymIntNodeImpl()->sym_float()); + return SymFloat(toSymNodeImpl()->sym_float()); } -SymInt SymInt::operator+(SymInt sci) const { +SymInt SymInt::operator+(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ + sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->add(res[1])); + return SymInt(res[0]->add(res[1])); } -SymInt SymInt::operator-(SymInt sci) const { +SymInt SymInt::operator-(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ - sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->sub(res[1])); + return SymInt(res[0]->sub(res[1])); } -SymInt SymInt::operator*(SymInt sci) const { +SymInt SymInt::operator*(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ * sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->mul(res[1])); + return SymInt(res[0]->mul(res[1])); } -SymInt SymInt::operator/(SymInt sci) const { +SymInt SymInt::operator/(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ / sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->floordiv(res[1])); + return SymInt(res[0]->floordiv(res[1])); } -SymInt SymInt::operator%(SymInt sci) const { +SymInt SymInt::operator%(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return SymInt(data_ % sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->mod(res[1])); + return SymInt(res[0]->mod(res[1])); } -bool SymInt::operator==(SymInt sci) const { +bool SymInt::operator==(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ == sci.data_; } @@ -100,11 +102,11 @@ bool SymInt::operator==(SymInt sci) const { return res[0]->eq(res[1])->bool_(); } -bool SymInt::operator!=(SymInt sci) const { +bool SymInt::operator!=(const SymInt& sci) const { return !(*this == sci); } -bool SymInt::operator<(SymInt sci) const { +bool SymInt::operator<(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ < sci.data_; } @@ -112,7 +114,7 @@ bool SymInt::operator<(SymInt sci) const { return res[0]->lt(res[1])->bool_(); } -bool SymInt::operator<=(SymInt sci) const { +bool SymInt::operator<=(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ <= sci.data_; } @@ -120,7 +122,7 @@ bool SymInt::operator<=(SymInt sci) const { return res[0]->le(res[1])->bool_(); } -bool SymInt::operator>(SymInt sci) const { +bool SymInt::operator>(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ > sci.data_; } @@ -128,7 +130,7 @@ bool SymInt::operator>(SymInt sci) const { return res[0]->gt(res[1])->bool_(); } -bool SymInt::operator>=(SymInt sci) const { +bool SymInt::operator>=(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ >= sci.data_; } @@ -136,26 +138,30 @@ bool SymInt::operator>=(SymInt sci) const { return res[0]->ge(res[1])->bool_(); } -SymInt SymInt::min(SymInt sci) const { +SymInt SymInt::min(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return std::min(data_, sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->min(res[1])); + return SymInt(res[0]->min(res[1])); } -SymInt SymInt::max(SymInt sci) const { +SymInt SymInt::max(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return std::max(data_, sci.data_); } auto res = normalize_symints(*this, sci); - return SymInt::toSymInt(res[0]->max(res[1])); + return SymInt(res[0]->max(res[1])); } -void SymInt::operator*=(SymInt sci) { +void SymInt::operator*=(const SymInt& sci) { *this = *this * sci; } -void SymInt::operator+=(SymInt sci) { +void SymInt::operator/=(const SymInt& sci) { + *this = *this / sci; +} + +void SymInt::operator+=(const SymInt& sci) { *this = *this + sci; } @@ -187,18 +193,18 @@ SymInt SymInt::operator*(int64_t sci) const { return *this * c10::SymInt(sci); } -std::ostream& operator<<(std::ostream& os, SymInt s) { +std::ostream& operator<<(std::ostream& os, const SymInt& s) { if (s.is_symbolic()) { - os << s.toSymIntNodeImpl()->str(); + os << s.toSymNodeImpl()->str(); } else { os << s.as_int_unchecked(); } return os; } -SymInt operator-(SymInt s) { +SymInt operator-(const SymInt& s) { if (s.is_symbolic()) { - return SymInt::toSymInt(s.toSymIntNodeImpl()->neg()); + return SymInt(s.toSymNodeImpl()->neg()); } else { return SymInt(-s.as_int_unchecked()); } diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index f5c2ddf00998e..6355f13395053 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -1,35 +1,31 @@ #pragma once -#include +#include #include #include #include #include #include +#include namespace c10 { class SymFloat; -// `SymInt` is a C++ wrapper class around int64_t data_ which and is used to -// represent concrete dimension values. +// SymInt represents either a regular int64_t, or a symbolic integer +// (represented in a type erased way as SymNode). The intention is for SymInt +// to represent symbolic sizes that arise when doing shape computation in +// operator kernels. This allows for tracing through programs without baking in +// concrete sizes into kernel calls. // -// `SymInt` is also a data type in Pytorch that can be used in function schemas -// to enable tracing. +// SymInt has an API equivalent to int64_t. In particular, it is a value type. +// Internally, SymInt is represented in a clever packed way, so that it only +// occupies one word of space; but morally, it is a union between an int64_t +// and an intrusive pointer to SymNodeImpl. // -// `SymInt` is introduced to enable tracing arithmetic -// operations on symbolic integers (e.g. sizes). Tracing symbolic sizes will -// allow LTC and AOTAutograd representing dynamic shapes in expression graphs -// faithfully without baking in concrete dimension values. -// -// To trace the operations, SymInt will overload arithmetic operators (e.g. +, -// -, *) and will provide overloads taking SymInt for commonly used math -// functions. -// -// SymInt will be extenteded to represent a union structure Union[int64_t, -// SymIntNodeImpl*] which will be implemented as a single packed int64_t field -// named data_. +// Invariant: the referenced SymNodeImpl is guaranteed to be a SymNode where +// is_int() returns true class C10_API SymInt { public: @@ -44,6 +40,7 @@ class C10_API SymInt { TORCH_CHECK(!is_symbolic()); }; SymInt() : data_(0) {} + SymInt(SymNode n); // unchecked c-tor accepting raw `data_` // One appropriate use for this is when you are constructing a symint @@ -55,28 +52,28 @@ class C10_API SymInt { // temporary and then use the move constructor/assignment SymInt(const SymInt& s) : data_(0) { if (s.is_symbolic()) { - *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + *this = SymInt(s.toSymNodeImpl()); } else { data_ = s.data_; } } - SymInt(SymInt&& s) : data_(s.data_) { + SymInt(SymInt&& s) noexcept : data_(s.data_) { s.data_ = 0; } SymInt& operator=(const SymInt& s) { if (this != &s) { if (s.is_symbolic()) { - *this = SymInt::toSymInt(s.toSymIntNodeImpl()); + *this = SymInt(s.toSymNodeImpl()); } else { data_ = s.data_; } } return *this; } - SymInt& operator=(SymInt&& s) { + SymInt& operator=(SymInt&& s) noexcept { if (this != &s) { - release_(); // release the current SymIntNode if any + release_(); // release the current SymNode if any data_ = s.data_; if (s.is_symbolic()) s.data_ = 0; @@ -86,31 +83,31 @@ class C10_API SymInt { SymInt clone() const { if (is_symbolic()) { - return toSymIntNodeImplUnowned()->clone()->toSymInt(); + return SymInt(toSymNodeImplUnowned()->clone()); } return *this; } - SymIntNodeImpl* toSymIntNodeImplUnowned() const { + SymNodeImpl* toSymNodeImplUnowned() const { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(is_symbolic()); uint64_t unextended_bits = static_cast(data_) & ~MASK; uint64_t sign_bit_mask = 1ULL << (62 - 1); // https://stackoverflow.com/questions/42534749/signed-extension-from-24-bit-to-32-bit-in-c uint64_t extended_bits = (unextended_bits ^ sign_bit_mask) - sign_bit_mask; - return static_cast( + return static_cast( reinterpret_cast(static_cast(extended_bits))); } void release_() { if (is_symbolic()) { - SymIntNode::reclaim(toSymIntNodeImplUnowned()); // steal + SymNode::reclaim(toSymNodeImplUnowned()); // steal } } - SymIntNodeImpl* release() && { + SymNodeImpl* release() && { #ifndef C10_MOBILE TORCH_INTERNAL_ASSERT(is_symbolic()); - auto* r = toSymIntNodeImplUnowned(); + auto* r = toSymNodeImplUnowned(); data_ = 0; // transfer ownership return r; #else @@ -118,8 +115,7 @@ class C10_API SymInt { #endif } - SymIntNode toSymIntNodeImpl() const; - static c10::SymInt toSymInt(SymIntNode sin); + SymNode toSymNodeImpl() const; ~SymInt() { release_(); @@ -156,22 +152,23 @@ class C10_API SymInt { #endif } - SymInt operator+(SymInt sci) const; - SymInt operator-(SymInt sci) const; - SymInt operator*(SymInt sci) const; - SymInt operator/(SymInt sci) const; - SymInt operator%(SymInt sci) const; - bool operator==(SymInt sci) const; - bool operator!=(SymInt p2) const; - bool operator<(SymInt sci) const; - bool operator<=(SymInt sci) const; - bool operator>(SymInt sci) const; - bool operator>=(SymInt sci) const; - void operator*=(SymInt sci); - void operator+=(SymInt sci); - - SymInt min(SymInt sci) const; - SymInt max(SymInt sci) const; + SymInt operator+(const SymInt& sci) const; + SymInt operator-(const SymInt& sci) const; + SymInt operator*(const SymInt& sci) const; + SymInt operator/(const SymInt& sci) const; + SymInt operator%(const SymInt& sci) const; + bool operator==(const SymInt& sci) const; + bool operator!=(const SymInt& p2) const; + bool operator<(const SymInt& sci) const; + bool operator<=(const SymInt& sci) const; + bool operator>(const SymInt& sci) const; + bool operator>=(const SymInt& sci) const; + void operator*=(const SymInt& sci); + void operator+=(const SymInt& sci); + void operator/=(const SymInt& sci); + + SymInt min(const SymInt& sci) const; + SymInt max(const SymInt& sci) const; SymInt operator*(int64_t sci) const; bool operator<(int64_t sci) const; @@ -235,9 +232,56 @@ inline c10::SymInt multiply_integers(const C& container) { container.begin(), container.end(), c10::SymInt(1), - [](c10::SymInt a, c10::SymInt b) { return a * b; }); + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +template < + typename Iter, + typename = std::enable_if_t::value_type, + c10::SymInt>::value>> +inline c10::SymInt multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, + end, + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + +inline SymInt operator+(int64_t a, const SymInt& b) { + return c10::SymInt(a) + b; +} +inline SymInt operator-(int64_t a, const SymInt& b) { + return c10::SymInt(a) - b; +} +inline SymInt operator*(int64_t a, const SymInt& b) { + return c10::SymInt(a) * b; +} +inline SymInt operator/(int64_t a, const SymInt& b) { + return c10::SymInt(a) / b; +} +inline SymInt operator%(int64_t a, const SymInt& b) { + return c10::SymInt(a) % b; +} +inline bool operator==(int64_t a, const SymInt& b) { + return c10::SymInt(a) == b; +} +inline bool operator!=(int64_t a, const SymInt& b) { + return c10::SymInt(a) != b; +} +inline bool operator<(int64_t a, const SymInt& b) { + return c10::SymInt(a) < b; +} +inline bool operator<=(int64_t a, const SymInt& b) { + return c10::SymInt(a) <= b; +} +inline bool operator>(int64_t a, const SymInt& b) { + return c10::SymInt(a) > b; +} +inline bool operator>=(int64_t a, const SymInt& b) { + return c10::SymInt(a) >= b; } -C10_API std::ostream& operator<<(std::ostream& os, SymInt s); -C10_API SymInt operator-(SymInt s); +C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); +C10_API SymInt operator-(const SymInt& s); } // namespace c10 diff --git a/c10/core/SymIntNodeImpl.cpp b/c10/core/SymIntNodeImpl.cpp deleted file mode 100644 index 483110a90fa64..0000000000000 --- a/c10/core/SymIntNodeImpl.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include -#include - -namespace c10 { - -c10::SymInt SymIntNodeImpl::toSymInt() { - auto sit_sp = SymIntNode::reclaim_copy(this); - return SymInt::toSymInt(sit_sp); -} - -} // namespace c10 diff --git a/c10/core/SymNodeImpl.cpp b/c10/core/SymNodeImpl.cpp new file mode 100644 index 0000000000000..80999ba50f1ed --- /dev/null +++ b/c10/core/SymNodeImpl.cpp @@ -0,0 +1,3 @@ +#include + +namespace c10 {} // namespace c10 diff --git a/c10/core/SymIntNodeImpl.h b/c10/core/SymNodeImpl.h similarity index 51% rename from c10/core/SymIntNodeImpl.h rename to c10/core/SymNodeImpl.h index 0b9d4c5579282..fcec452821d76 100644 --- a/c10/core/SymIntNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -10,13 +9,12 @@ namespace c10 { -class SymInt; -class SymIntNodeImpl; +class SymNodeImpl; +using SymNode = c10::intrusive_ptr; -class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { +class C10_API SymNodeImpl : public c10::intrusive_ptr_target { public: - c10::SymInt toSymInt(); - virtual ~SymIntNodeImpl(){}; + virtual ~SymNodeImpl(){}; template c10::intrusive_ptr dyn_cast() const { @@ -24,66 +22,84 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target { } // these could be pure virtual when we implement LTC versions - virtual SymIntNode add(const SymIntNode& other) { + virtual bool is_int() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode sub(const SymIntNode& other) { + virtual bool is_float() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode mul(const SymIntNode& other) { + virtual SymNode add(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymFloatNode truediv(const SymIntNode& other) { + virtual SymNode sub(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode floordiv(const SymIntNode& other) { + virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode mod(const SymIntNode& other) { + virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode eq(const SymIntNode& other) { + virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ne(const SymIntNode& other) { + virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode gt(const SymIntNode& other) { + virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode lt(const SymIntNode& other) { + virtual SymNode eq(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode le(const SymIntNode& other) { + virtual SymNode ne(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ge(const SymIntNode& other) { + virtual SymNode gt(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode ceil() { + virtual SymNode lt(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode neg() { + virtual SymNode le(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode min(const SymIntNode& other) { + virtual SymNode ge(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode max(const SymIntNode& other) { + virtual SymNode ceil() { TORCH_CHECK(false, "NYI"); }; - virtual SymIntNode clone() { + virtual SymNode floor() { TORCH_CHECK(false, "NYI"); }; - virtual SymFloatNode sym_float() { + virtual SymNode neg() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode min(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode max(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode clone() { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_float() { TORCH_CHECK(false, "NYI"); } - virtual SymIntNode wrap(int64_t num) { + virtual SymNode wrap_int(int64_t num) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode wrap_float(double num) { TORCH_CHECK(false, "NYI"); }; virtual int64_t guard_int(const char* file, int64_t line) { TORCH_CHECK(false, "NYI"); }; + virtual double guard_float(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; virtual int64_t int_() { TORCH_CHECK(false, "NYI"); }; diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 3951578a848cc..bee3fa32ec214 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -10,6 +10,8 @@ #include #include +#include + C10_DEFINE_bool( caffe2_keep_on_shrink, true, @@ -270,6 +272,9 @@ bool_is_contiguous _compute_contiguous( sizes_and_strides_.strides_arrayref())) bool_is_contiguous TensorImpl::compute_contiguous() const { + if (is_sparse()) { + return bool_is_contiguous(false); + } return COMPUTE_WITH_SIZES_STRIDES_NUMEL(_compute_contiguous); } @@ -304,6 +309,9 @@ bool_is_channels_last_contiguous _compute_channels_last_contiguous_2d( bool_is_channels_last_contiguous TensorImpl:: compute_channels_last_contiguous_2d() const { + if (is_sparse()) { + return bool_is_channels_last_contiguous(false); + } return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_2d); } @@ -338,17 +346,26 @@ bool_is_channels_last_3d_contiguous _compute_channels_last_contiguous_3d( bool_is_channels_last_3d_contiguous TensorImpl:: compute_channels_last_contiguous_3d() const { + if (is_sparse()) { + return bool_is_channels_last_3d_contiguous(false); + } return COMPUTE_WITH_SIZES_STRIDES(_compute_channels_last_contiguous_3d); } bool_is_channels_last TensorImpl::compute_strides_like_channels_last_2d() const { + if (is_sparse()) { + return bool_is_channels_last(false); + } return bool_is_channels_last( COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_2d)); } bool_is_channels_last_3d TensorImpl::compute_strides_like_channels_last_3d() const { + if (is_sparse()) { + return bool_is_channels_last_3d(false); + } return bool_is_channels_last_3d( COMPUTE_WITH_SIZES_STRIDES(is_channels_last_strides_3d)); } @@ -391,6 +408,9 @@ bool_is_non_overlapping_and_dense _compute_non_overlapping_and_dense( bool_is_non_overlapping_and_dense TensorImpl:: compute_non_overlapping_and_dense() const { + if (is_sparse()) { + return bool_is_non_overlapping_and_dense(false); + } return COMPUTE_WITH_SIZES_STRIDES(_compute_non_overlapping_and_dense); } @@ -611,12 +631,13 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach_core( VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { c10::intrusive_ptr r; - const auto& maybe_torch_dispatch_mode_state = - c10::impl::TorchDispatchModeTLS::get_mode(); + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); // TODO: do we have to exclude after Python dispatch key set? - if (maybe_torch_dispatch_mode_state && + if (mode_stack_len > 0 && !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { - r = maybe_torch_dispatch_mode_state->pyinterpreter()->detach(this); + const auto& cur_torch_dispatch_mode_state = + c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + r = cur_torch_dispatch_mode_state->pyinterpreter()->detach(this); } else if ( key_set_.has(DispatchKey::Python) && !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { @@ -785,7 +806,7 @@ void TensorImpl::Extend(int64_t num, float growthPct) { sizes_and_strides_.size_at_unchecked(0) * (1 + growthPct / 100)))); auto oldData = std::move(storage_.data_ptr()); auto oldSize = numel_; - Resize(newCapacity); + Resize(std::move(newCapacity)); auto* newData = raw_mutable_data(data_type_); if (data_type_.copy()) { TORCH_CHECK( @@ -837,7 +858,7 @@ void TensorImpl::ReserveSpace(int64_t outer_dim) { auto oldSize = numel_; SmallVector oldDims( sizes_and_strides.begin(), sizes_and_strides.end()); - Resize(newCapacity); + Resize(std::move(newCapacity)); // Allocate new memory but don't copy over the data raw_mutable_data(data_type_); sizes_and_strides_.set_sizes(oldDims); diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index f110b0e9fa460..a6ba3f16e2a27 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -672,6 +673,25 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + // From https://stackoverflow.com/a/3057522/23845 + // TODO: does C++14 have a stdlib template for this? + template + struct identity { + typedef T type; + }; + + template + ArrayRef generic_sizes() { + return _generic_sizes(identity()); + } + + ArrayRef _generic_sizes(identity) { + return sizes(); + } + ArrayRef _generic_sizes(identity) { + return sym_sizes(); + } + /** * The number of elements in a tensor. * @@ -1306,7 +1326,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * It can be expanded as needed in the future, e.g sparse Tensor. */ inline bool support_as_strided() const { - return is_nested() ? false : device().supports_as_strided(); + if (is_nested()) { + return false; + } + if (key_set_.has(DispatchKey::Functionalize)) { + return false; + } + return device().supports_as_strided(); } // ~~~~~ Autograd API ~~~~~ @@ -2037,7 +2063,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return c10::nullopt; } else if (interpreter == self_interpreter) { // NB: pyobj_ could still be null! - return c10::make_optional(_unchecked_untagged_pyobj()); + if (c10::impl::HermeticPyObjectTLS::get_state()) { + return c10::nullopt; + } else { + return c10::make_optional(_unchecked_untagged_pyobj()); + } } else { TORCH_CHECK( false, diff --git a/c10/core/WrapDimMinimal.cpp b/c10/core/WrapDimMinimal.cpp index 6703f0638901e..2375dc3ac5cf7 100644 --- a/c10/core/WrapDimMinimal.cpp +++ b/c10/core/WrapDimMinimal.cpp @@ -14,7 +14,8 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { "Dimension specified as ", dim, " but tensor has no dimensions"); - return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false); + return c10::maybe_wrap_dim( + std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); } T min = dim_post_expr * -1; diff --git a/c10/core/WrapDimMinimal.h b/c10/core/WrapDimMinimal.h index 0f5949f65082b..dda01fbe18f0f 100644 --- a/c10/core/WrapDimMinimal.h +++ b/c10/core/WrapDimMinimal.h @@ -38,7 +38,7 @@ inline c10::SymInt maybe_wrap_dim( c10::SymInt dim, c10::SymInt dim_post_expr, bool wrap_scalar = true) { - return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); + return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); } } // namespace c10 diff --git a/c10/core/impl/HermeticPyObjectTLS.cpp b/c10/core/impl/HermeticPyObjectTLS.cpp new file mode 100644 index 0000000000000..a7eb89430be8a --- /dev/null +++ b/c10/core/impl/HermeticPyObjectTLS.cpp @@ -0,0 +1,23 @@ +#include + +namespace c10 { +namespace impl { + +thread_local std::atomic hermeticPyObjectState{false}; + +std::atomic HermeticPyObjectTLS::haveState_{false}; + +void HermeticPyObjectTLS::set_state(bool state) { + hermeticPyObjectState = state; +} + +bool HermeticPyObjectTLS::get_tls_state() { + return hermeticPyObjectState; +} + +void HermeticPyObjectTLS::init_state() { + haveState_ = true; +} + +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/HermeticPyObjectTLS.h b/c10/core/impl/HermeticPyObjectTLS.h new file mode 100644 index 0000000000000..9ecc8e761247b --- /dev/null +++ b/c10/core/impl/HermeticPyObjectTLS.h @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +namespace c10 { +namespace impl { + +// This TLS controls whether or not we permanently associate PyObject +// with Tensor the first time it is allocated. When hermetic PyObject +// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor, +// meaning you get a distinct PyObject whenever you execute the code in +// question. +struct C10_API HermeticPyObjectTLS { + static void set_state(bool state); + static bool get_state() { + // Hypothetical fastpath if torchdeploy/multipy isn't used. Per + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // this qualifies relaxed access because it is a single-location data + // structure (only the boolean here). + // + // Forgetting about data races for a moment, is there a logical race? + // + // - Boolean only ever transitions from false to true. So the + // critical situation is when one interpreter is already running + // when a second interpreter switches haveState from false to true. + // + // - The first interpreter is indifferent whether or not it sees + // hasState true/false; obviously false works (this is what the + // interpreter was previously using; more directly, the interpreter + // calls into itself as the handler, so being hermetic is not + // required), and true simply means serviced python operator calls will + // be hermetic; in these cases it is expected to be functionally + // equivalent. + // + // - The second interpreter MUST see hasState true (as its requests will + // be forwarded to the first interpreter), but it is assumed that there + // is a synchronization between the interpreter initialization, and + // when we actually perform operations, so it is guaranteed to see + // hasState true. + // + // QED. + // + // This fastpath is currently disabled so that we can more easily test that + // hermetic mode works correctly even on stock build of PyTorch. + if (false && !haveState_.load(std::memory_order_relaxed)) + return false; + return get_tls_state(); + } + // Call this from the multipy/torchdeploy top level + static void init_state(); + + private: + // This only flipped once from false to true during torchdeploy/multipy + // initialization, and never again. + static std::atomic haveState_; + static bool get_tls_state(); +}; + +} // namespace impl +} // namespace c10 diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index 70af58b957165..391b8cff4939b 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -52,7 +52,7 @@ struct C10_API PODLocalDispatchKeySet { } }; static_assert( - std::is_pod::value, + std::is_trivial::value, "PODLocalDispatchKeySet must be a POD type."); struct C10_API LocalDispatchKeySet { diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index f1dd268bab806..8c29f13f3e5c3 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { PANIC(dispatch); } + void python_op_registration_trampoline( + const c10::OperatorHandle& op, + c10::DispatchKey, + torch::jit::Stack* stack) const override { + PANIC(python_op_registration_trampoline); + } + void python_dispatcher( const c10::OperatorHandle& op, c10::DispatchKeySet, diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 90fbb8dfebf88..da5b612f093b2 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -141,6 +141,15 @@ struct C10_API PyInterpreterVTable { virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) const = 0; + // This is only invoked in the multipy/torchdeploy situation from + // pythonOpRegistrationTrampoline; this lets us get to the Python + // interpreter to actually find the appropriate Python op registration + // entry to call. + virtual void python_op_registration_trampoline( + const c10::OperatorHandle& op, + c10::DispatchKey, + torch::jit::Stack* stack) const = 0; + // Invoke the Python dispatcher to handle this call virtual void python_dispatcher( const c10::OperatorHandle& op, diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp index 5f02686584255..6755657b73687 100644 --- a/c10/core/impl/TorchDispatchModeTLS.cpp +++ b/c10/core/impl/TorchDispatchModeTLS.cpp @@ -8,44 +8,12 @@ namespace impl { thread_local TorchDispatchModeTLS torchDispatchModeState; -// MODE -void TorchDispatchModeTLS::set_mode(std::shared_ptr mode) { - if (mode) { - c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); - c10::impl::tls_set_dispatch_key_included( - DispatchKey::PythonTLSSnapshot, true); - } else { - TorchDispatchModeTLS::reset_mode(); - } - torchDispatchModeState.mode_ = std::move(mode); -} - -const std::shared_ptr& TorchDispatchModeTLS::get_mode() { - return torchDispatchModeState.mode_; -} - -void TorchDispatchModeTLS::reset_mode() { - torchDispatchModeState.mode_.reset(); - c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); - c10::impl::tls_set_dispatch_key_included( - DispatchKey::PythonTLSSnapshot, false); -} - -void TorchDispatchModeTLS::swap_mode(std::shared_ptr& mode) { - if (mode) { +void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr mode) { + if (torchDispatchModeState.stack_.size() == 0) { c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); c10::impl::tls_set_dispatch_key_included( DispatchKey::PythonTLSSnapshot, true); - } else { - c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); - c10::impl::tls_set_dispatch_key_included( - DispatchKey::PythonTLSSnapshot, false); } - torchDispatchModeState.mode_.swap(mode); -} - -// STACK -void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr mode) { torchDispatchModeState.stack_.push_back(std::move(mode)); } @@ -56,6 +24,12 @@ const std::shared_ptr TorchDispatchModeTLS::pop_stack() { const std::shared_ptr out = torchDispatchModeState.stack_.back(); torchDispatchModeState.stack_.pop_back(); + + if (torchDispatchModeState.stack_.size() == 0) { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, false); + } return out; } @@ -71,20 +45,27 @@ int64_t TorchDispatchModeTLS::stack_len() { return torchDispatchModeState.stack_.size(); } -// STATE - const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() { return torchDispatchModeState; } void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) { torchDispatchModeState = state; + if (torchDispatchModeState.stack_.size() == 0) { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, false); + } else { + c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); + c10::impl::tls_set_dispatch_key_included( + DispatchKey::PythonTLSSnapshot, true); + } } // UTIL bool dispatch_mode_enabled() { - return static_cast(c10::impl::TorchDispatchModeTLS::get_mode()); + return TorchDispatchModeTLS::stack_len() > 0; } } // namespace impl diff --git a/c10/core/impl/TorchDispatchModeTLS.h b/c10/core/impl/TorchDispatchModeTLS.h index 708c22e014ad4..da30d0460427c 100644 --- a/c10/core/impl/TorchDispatchModeTLS.h +++ b/c10/core/impl/TorchDispatchModeTLS.h @@ -9,11 +9,6 @@ namespace c10 { namespace impl { struct C10_API TorchDispatchModeTLS { - static void set_mode(std::shared_ptr mode); - static const std::shared_ptr& get_mode(); - static void reset_mode(); - static void swap_mode(std::shared_ptr& mode); - static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); static const std::shared_ptr& get_stack_at(int64_t idx); @@ -23,12 +18,6 @@ struct C10_API TorchDispatchModeTLS { static void set_state(const TorchDispatchModeTLS& state); private: - // The mode TLS is split into - // - mode_, which is the C++ mode, that can only be the mode handling mode - // or null - // - stack_, which is a vector of modes representing the stack of user - // defined modes - std::shared_ptr mode_; std::vector> stack_; }; diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 1dc4435da5f00..2c26bc06f6ca4 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -21,16 +21,18 @@ configure_file( # and headers you add set(C10_CUDA_SRCS CUDACachingAllocator.cpp + CUDADeviceAssertionHost.cpp CUDAException.cpp CUDAFunctions.cpp + CUDAMallocAsyncAllocator.cpp CUDAMiscFunctions.cpp CUDAStream.cpp - CUDACachingAllocator.cpp - CUDAMallocAsyncAllocator.cpp impl/CUDAGuardImpl.cpp impl/CUDATest.cpp ) set(C10_CUDA_HEADERS + CUDACachingAllocator.h + CUDADeviceAssertionHost.h CUDAException.h CUDAFunctions.h CUDAGuard.h @@ -38,7 +40,6 @@ set(C10_CUDA_HEADERS CUDAMathCompat.h CUDAMiscFunctions.h CUDAStream.h - CUDACachingAllocator.h impl/CUDAGuardImpl.h impl/CUDATest.h ) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 8446c25669d77..aaa647502a897 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace c10 { @@ -104,6 +105,7 @@ constexpr size_t kLargeBuffer = constexpr size_t kMinLargeAlloc = 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB +constexpr size_t kRoundUpPowerOfTwoIntervals = 16; namespace { @@ -406,11 +408,24 @@ class CachingAllocatorConfig { // More description below in function roundup_power2_next_division // As ane example, if we want 4 divisions between 2's power, this can be done // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 - static size_t roundup_power2_divisions() { - return instance().m_roundup_power2_divisions; - } - static size_t roundup_bypass_threshold() { - return instance().m_roundup_bypass_threshold; + static size_t roundup_power2_divisions(size_t size) { + size_t log_size = (63 - llvm::countLeadingZeros(size)); + + // Our intervals start at 1MB and end at 64GB + const size_t interval_start = + 63 - llvm::countLeadingZeros(static_cast(1048576)); + const size_t interval_end = + 63 - llvm::countLeadingZeros(static_cast(68719476736)); + TORCH_CHECK( + (interval_end - interval_start == Native::kRoundUpPowerOfTwoIntervals), + "kRoundUpPowerOfTwoIntervals mismatch"); + + int index = static_cast(log_size) - static_cast(interval_start); + + index = std::max(0, index); + index = std::min( + index, static_cast(Native::kRoundUpPowerOfTwoIntervals) - 1); + return instance().m_roundup_power2_divisions[index]; } static CachingAllocatorConfig& instance() { @@ -423,128 +438,269 @@ class CachingAllocatorConfig { return *s_instance; } - void parseArgs(const char* env) { - // If empty, set the default values - m_max_split_size = std::numeric_limits::max(); - m_roundup_power2_divisions = 0; - m_roundup_bypass_threshold = std::numeric_limits::max(); - m_garbage_collection_threshold = 0; + void parseArgs(const char* env); - if (env == nullptr) { - return; + private: + CachingAllocatorConfig() + : m_max_split_size(std::numeric_limits::max()), + m_garbage_collection_threshold(0) { + m_roundup_power2_divisions.assign(Native::kRoundUpPowerOfTwoIntervals, 0); + } + + void lexArgs(const char* env, std::vector& config); + void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, + size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync); + + std::atomic m_max_split_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; +}; + +void CachingAllocatorConfig::lexArgs( + const char* env, + std::vector& config) { + std::vector buf; + + size_t env_length = strlen(env); + for (size_t i = 0; i < env_length; i++) { + if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') { + if (buf.size() != 0) { + config.emplace_back(std::string(buf.begin(), buf.end())); + buf.clear(); + } + config.emplace_back(std::string(1, env[i])); + } else if (env[i] != ' ') { + buf.emplace_back(static_cast(env[i])); } + } + if (!buf.empty()) { + config.emplace_back(std::string(buf.begin(), buf.end())); + } +} - const std::string config(env); +void CachingAllocatorConfig::consumeToken( + const std::vector& config, + size_t i, + const char c) { + TORCH_CHECK( + i < config.size() && config[i].compare(std::string(1, c)) == 0, + "Error parsing CachingAllocator settings, expected ", + c, + ""); +} - std::regex exp("[\\s,]+"); - std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); - std::sregex_token_iterator end; - std::vector options(it, end); +size_t CachingAllocatorConfig::parseMaxSplitSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > Native::kLargeBuffer / (1024 * 1024), + "CachingAllocator option max_split_size_mb too small, must be > ", + Native::kLargeBuffer / (1024 * 1024), + ""); + val1 = std::max(val1, Native::kLargeBuffer / (1024 * 1024)); + val1 = std::min(val1, (std::numeric_limits::max() / (1024 * 1024))); + m_max_split_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); + } + return i; +} - bool used_cudaMallocAsync = false; - bool used_native_specific_option = false; +size_t CachingAllocatorConfig::parseGarbageCollectionThreshold( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + double val1 = stod(config[i]); + TORCH_CHECK( + val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); + TORCH_CHECK( + val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); + m_garbage_collection_threshold = val1; + } else { + TORCH_CHECK( + false, "Error, expecting garbage_collection_threshold value", ""); + } + return i; +} - for (auto option : options) { - std::regex exp2("[:]+"); - std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); - std::sregex_token_iterator end2; - std::vector kv(it2, end2); - if (kv.size() >= 2) { - /* Maximum split size in MB. Limited to large size blocks */ - if (kv[0] == "max_split_size_mb") { - size_t val2 = stoi(kv[1]); +size_t CachingAllocatorConfig::parseRoundUpPower2Divisions( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + bool first_value = true; + + if (++i < config.size()) { + if (config[i].compare("[") == 0) { + size_t last_index = 0; + while (++i < config.size() && config[i].compare("]") != 0) { + std::string val1 = config[i]; + size_t val2 = 0; + + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + val2 = stoi(config[i]); + } else { TORCH_CHECK( - val2 > Native::kLargeBuffer / (1024 * 1024), - "CachingAllocator option max_split_size_mb too small, must be > ", - Native::kLargeBuffer / (1024 * 1024), - ""); - val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024)); - val2 = std::min( - val2, (std::numeric_limits::max() / (1024 * 1024))); - m_max_split_size = val2 * 1024 * 1024; - used_native_specific_option = true; - } else if (kv[0] == "roundup_power2_divisions") { - size_t val2 = stoi(kv[1]); + false, "Error parsing roundup_power2_divisions value", ""); + } + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "For roundups, the divisons has to be power of 2 ", + ""); + + if (val1.compare(">") == 0) { + std::fill( + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + last_index)), + m_roundup_power2_divisions.end(), + val2); + } else { + size_t val1_long = stoul(val1); TORCH_CHECK( - llvm::isPowerOf2_64(val2), - "For roundups, the divisons has to be power of 2 ", + llvm::isPowerOf2_64(val1_long), + "For roundups, the intervals have to be power of 2 ", ""); - m_roundup_power2_divisions = val2; - used_native_specific_option = true; - } else if (kv[0] == "roundup_bypass_threshold_mb") { - size_t val2 = stoi(kv[1]); - m_roundup_bypass_threshold = val2 * 1024 * 1024; - used_native_specific_option = true; - } else if (kv[0] == "backend") { - TORCH_CHECK( - ((kv[1] == "native") || (kv[1] == "cudaMallocAsync")), - "Unknown allocator backend, " - "options are native and cudaMallocAsync"); - used_cudaMallocAsync = (kv[1] == "cudaMallocAsync"); - if (used_cudaMallocAsync) { -#if CUDA_VERSION >= 11040 - int version; - C10_CUDA_CHECK(cudaDriverGetVersion(&version)); - TORCH_CHECK( - version >= 11040, - "backend:cudaMallocAsync requires CUDA runtime " - "11.4 or newer, but cudaDriverGetVersion returned ", - version); -#else - TORCH_CHECK( - false, - "backend:cudaMallocAsync requires PyTorch to be built with " - "CUDA 11.4 or newer, but CUDA_VERSION is ", - CUDA_VERSION); -#endif + + size_t index = 63 - llvm::countLeadingZeros(val1_long); + index = std::max((size_t)0, index); + index = std::min(index, m_roundup_power2_divisions.size() - 1); + + if (first_value) { + std::fill( + m_roundup_power2_divisions.begin(), + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + index)), + val2); + first_value = false; } - TORCH_INTERNAL_ASSERT( - kv[1] == get()->name(), - "Allocator backend parsed at runtime != " - "allocator backend parsed at load time"); - } else if (kv[0] == "garbage_collection_threshold") { - /* - * Perform garbage collection of GPU memory blocks to avoid - * triggering expensive sync-and-reclaim-all operation. Upon setting - * the threshold (e.g., 0.8), the allocator will start reclaiming - * blocks if GPU memory capacity usage exceeds the threshold (i.e., - * 80% of total memory). - * Values 0.0 and 1.0 are not allowed as they are less meaningful. - */ - double val2 = stod(kv[1]); - TORCH_CHECK( - val2 > 0, - "garbage_collect_threshold too small, set it 0.0~1.0", - ""); - TORCH_CHECK( - val2 < 1.0, - "garbage_collect_threshold too big, set it 0.0~1.0", - ""); - m_garbage_collection_threshold = val2; - used_native_specific_option = true; - } else { - TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]); + if (index < m_roundup_power2_divisions.size()) { + m_roundup_power2_divisions[index] = val2; + } + last_index = index; } - } - if (used_cudaMallocAsync && used_native_specific_option) { - TORCH_WARN( - "backend:cudaMallocAsync ignores max_split_size_mb, roundup_bypass_threshold_mb," - "roundup_power2_divisions, and garbage_collect_threshold."); + if (config[i + 1].compare("]") != 0) { + consumeToken(config, ++i, ','); + } } + } else { // Keep this for backwards compatibility + size_t val1 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val1), + "For roundups, the divisons has to be power of 2 ", + ""); + std::fill( + m_roundup_power2_divisions.begin(), + m_roundup_power2_divisions.end(), + val1); } + } else { + TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); } + return i; +} - private: - CachingAllocatorConfig() - : m_max_split_size(std::numeric_limits::max()), - m_roundup_power2_divisions(0), - m_garbage_collection_threshold(0) {} - std::atomic m_max_split_size; - std::atomic m_roundup_power2_divisions; - std::atomic m_roundup_bypass_threshold; - std::atomic m_garbage_collection_threshold; -}; +size_t CachingAllocatorConfig::parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + ((config[i] == "native") || (config[i] == "cudaMallocAsync")), + "Unknown allocator backend, " + "options are native and cudaMallocAsync"); + used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); + if (used_cudaMallocAsync) { +#if CUDA_VERSION >= 11040 + int version; + C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + TORCH_CHECK( + version >= 11040, + "backend:cudaMallocAsync requires CUDA runtime " + "11.4 or newer, but cudaDriverGetVersion returned ", + version); +#else + TORCH_CHECK( + false, + "backend:cudaMallocAsync requires PyTorch to be built with " + "CUDA 11.4 or newer, but CUDA_VERSION is ", + CUDA_VERSION); +#endif + } + TORCH_INTERNAL_ASSERT( + config[i] == get()->name(), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time"); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); + } + return i; +} + +void CachingAllocatorConfig::parseArgs(const char* env) { + // If empty, set the default values + m_max_split_size = std::numeric_limits::max(); + m_roundup_power2_divisions.assign(Native::kRoundUpPowerOfTwoIntervals, 0); + m_garbage_collection_threshold = 0; + bool used_cudaMallocAsync = false; + bool used_native_specific_option = false; + + if (env == nullptr) { + return; + } + + std::vector config; + lexArgs(env, config); + + for (size_t i = 0; i < config.size(); i++) { + if (config[i].compare("max_split_size_mb") == 0) { + i = parseMaxSplitSize(config, i); + used_native_specific_option = true; + } else if (config[i].compare("garbage_collection_threshold") == 0) { + i = parseGarbageCollectionThreshold(config, i); + used_native_specific_option = true; + } else if (config[i].compare("roundup_power2_divisions") == 0) { + i = parseRoundUpPower2Divisions(config, i); + used_native_specific_option = true; + } else if (config[i].compare("backend") == 0) { + i = parseAllocatorConfig(config, i, used_cudaMallocAsync); + } else { + TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", config[i]); + } + + if (i + 1 < config.size()) { + consumeToken(config, ++i, ','); + } + } + + if (used_cudaMallocAsync && used_native_specific_option) { + TORCH_WARN( + "backend:cudaMallocAsync ignores max_split_size_mb, roundup_bypass_threshold_mb," + "roundup_power2_divisions, and garbage_collect_threshold."); + } +} namespace Native { @@ -727,7 +883,7 @@ class DeviceCachingAllocator { device_free, params.size(), params.stream(), - context); + std::move(context)); } stats.num_ooms += 1; @@ -1137,10 +1293,8 @@ class DeviceCachingAllocator { static size_t round_size(size_t size) { if (size < kMinBlockSize) { return kMinBlockSize; - } else if (size > CachingAllocatorConfig::roundup_bypass_threshold()) { - return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); } else { - auto divisions = CachingAllocatorConfig::roundup_power2_divisions(); + auto divisions = CachingAllocatorConfig::roundup_power2_divisions(size); if (divisions > 0 && size > (kMinBlockSize * divisions)) { return roundup_power2_next_division(size, divisions); } else { @@ -1854,6 +2008,10 @@ class NativeCachingAllocator : public CUDAAllocator { } } + bool initialized() override { + return device_allocator.size() > 0; + } + /** allocates a block which is safe to use from the provided stream */ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) { TORCH_INTERNAL_ASSERT( @@ -2034,7 +2192,8 @@ class NativeCachingAllocator : public CUDAAllocator { CaptureId_t graph_id, MempoolId_t mempool_id) override { assertValidDevice(device); - device_allocator[device]->notifyCaptureBegin(graph_id, mempool_id); + device_allocator[device]->notifyCaptureBegin( + graph_id, std::move(mempool_id)); } void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) override { @@ -2046,7 +2205,7 @@ class NativeCachingAllocator : public CUDAAllocator { void notifyCaptureDestroy(int device, MempoolId_t mempool_id) override { assertValidDevice(device); - device_allocator[device]->notifyCaptureDestroy(mempool_id); + device_allocator[device]->notifyCaptureDestroy(std::move(mempool_id)); } void* raw_alloc(size_t nbytes) override { diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 74854b5a25fd3..41e082933d55d 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -183,6 +183,7 @@ class CUDAAllocator : public Allocator { virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0; virtual void raw_delete(void* ptr) = 0; virtual void init(int device_count) = 0; + virtual bool initialized() = 0; virtual void setMemoryFraction(double fraction, int device) = 0; virtual void emptyCache() = 0; virtual void cacheInfo(int dev_id, size_t* largestBlock) = 0; diff --git a/c10/cuda/CUDADeviceAssertion.h b/c10/cuda/CUDADeviceAssertion.h new file mode 100644 index 0000000000000..76c422d83bd69 --- /dev/null +++ b/c10/cuda/CUDADeviceAssertion.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +namespace c10 { +namespace cuda { + +#ifdef TORCH_USE_CUDA_DSA +// Copy string from `src` to `dst` +static __device__ void dstrcpy(char* dst, const char* src) { + int i = 0; + // Copy string from source to destination, ensuring that it + // isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1` + while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) { + *dst++ = *src++; + } + *dst = '\0'; +} + +__device__ __noinline__ void dsa_add_new_assertion_failure( + DeviceAssertionsData* assertions_data, + const char* assertion_msg, + const char* filename, + const char* function_name, + const int line_number, + const uint32_t caller, + const dim3 block_id, + const dim3 thread_id) { + // `assertions_data` may be nullptr if device-side assertion checking + // is disabled at run-time. If it is disabled at compile time this + // function will never be called + if (!assertions_data) { + return; + } + + // Atomically increment so other threads can fail at the same time + // Note that incrementing this means that the CPU can observe that + // a failure has happened and can begin to respond before we've + // written information about that failure out to the buffer. + const auto nid = atomicAdd(&(assertions_data->assertion_count), 1); + + if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) { + // At this point we're ran out of assertion buffer space. + // We could print a message about this, but that'd get + // spammy if a lot of threads did it, so we just silently + // ignore any other assertion failures. In most cases the + // failures will all probably be analogous anyway. + return; + } + + // Write information about the assertion failure to memory. + // Note that this occurs only after the `assertion_count` + // increment broadcasts that there's been a problem. + auto& self = assertions_data->assertions[nid]; + dstrcpy(self.assertion_msg, assertion_msg); + dstrcpy(self.filename, filename); + dstrcpy(self.function_name, function_name); + self.line_number = line_number; + self.caller = caller; + self.block_id[0] = block_id.x; + self.block_id[1] = block_id.y; + self.block_id[2] = block_id.z; + self.thread_id[0] = thread_id.x; + self.thread_id[1] = thread_id.y; + self.thread_id[2] = thread_id.z; +} + +// Emulates a kernel assertion. The assertion won't stop the kernel's progress, +// so you should assume everything the kernel produces is garbage if there's an +// assertion failure. +// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are +// arguments of the kernel and therefore accessible. +#define CUDA_KERNEL_ASSERT2(condition) \ + do { \ + if (C10_UNLIKELY(!(condition))) { \ + /* Has an atomic element so threads can fail at the same time */ \ + c10::cuda::dsa_add_new_assertion_failure( \ + assertions_data, \ + C10_STRINGIZE(condition), \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + assertion_caller_id, \ + blockIdx, \ + threadIdx); \ + /* Now that the kernel has failed we early exit the kernel, but */ \ + /* otherwise keep going and rely on the host to check UVM and */ \ + /* determine we've had a problem */ \ + return; \ + } \ + } while (false) +#else +#define CUDA_KERNEL_ASSERT2(condition) assert(condition) +#endif + +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDADeviceAssertionHost.cpp b/c10/cuda/CUDADeviceAssertionHost.cpp new file mode 100644 index 0000000000000..58ece480799cd --- /dev/null +++ b/c10/cuda/CUDADeviceAssertionHost.cpp @@ -0,0 +1,367 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS() \ + c10_cuda_check_implementation(__FILE__, __FUNCTION__, __LINE__, false) + +namespace c10 { +namespace cuda { + +namespace { + +/// Get the number of CUDA devices +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for CUDAKernelLaunchRegistry +int dsa_get_device_count() { + int device_count = -1; + C10_CUDA_ERROR_HANDLED(cudaGetDeviceCount(&device_count)); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + return device_count; +} + +bool dsa_check_if_all_devices_support_managed_memory() { +// It looks as though this'll work best on CUDA GPUs with Pascal +// architectures or newer, per +// https://developer.nvidia.com/blog/unified-memory-cuda-beginners/ +#ifdef TORCH_USE_CUDA_DSA + for (const auto i : c10::irange(dsa_get_device_count())) { + if (dsa_get_device_compute_capability(i) < 6) { + return false; + } + } + return true; +#else + return false; +#endif +} + +bool env_flag_set(const char* env_var_name) { + const char* const env_string = std::getenv(env_var_name); + return (env_string == nullptr) ? false : std::strcmp(env_string, "0"); +} + +/// Deleter for UVM/managed memory pointers +void uvm_deleter(DeviceAssertionsData* uvm_assertions_ptr) { + // Ignore error in destructor + if (uvm_assertions_ptr) { + C10_CUDA_IGNORE_ERROR(cudaFree(uvm_assertions_ptr)); + } +} + +#ifdef TORCH_USE_CUDA_DSA +/// Get current device id +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for CUDAKernelLaunchRegistry +int dsa_get_device_id() { + int device = -1; + C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device)); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + return device; +} + +/// Get a device's compute capability - note that this dangerously assumes +/// that if one CUDA GPU supports device-side assertions they all do. This is +/// probably fine since the latest CUDA GPU that doesn't support UVM is the +/// K80 released 2014-11-17. Mixing that GPU with a newer one is likely to be +/// rare enough that the defensive +/// We need our own implementation of this function to prevent +/// an infinite initialization loop for CUDAKernelLaunchRegistry +int dsa_get_device_compute_capability(const int device_num) { + int compute_capability = -1; + C10_CUDA_ERROR_HANDLED(cudaDeviceGetAttribute( + &compute_capability, cudaDevAttrComputeCapabilityMajor, device_num)); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + return compute_capability; +} +#endif + +} // namespace + +/// Check that kernels ran correctly by checking the message buffer. BLOCKING. +std::string c10_retrieve_device_side_assertion_info() { +#ifdef TORCH_USE_CUDA_DSA + const auto& launch_registry = CUDAKernelLaunchRegistry::get_singleton_ref(); + if (!launch_registry.enabled) { + return "Device-side assertion tracking was not enabled by user."; + } else if (!launch_registry.do_all_devices_support_managed_memory) { + return "Device-side assertions disabled because not all devices support managed memory."; + } + + // Hack that saves a lot of challenging sync logic. + // The GPU increments the number of errors it's observed and the CPU can see + // that happening immediately which means we can make it here before the GPU + // is done writing information about those errors to memory. + // A short pause gives it time to finish. Since something's gone wrong, this + // pause shouldn't affect perf. + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // The snapshot causes a brief block. That's okay because this function only + // executes if something's gone wrong such that speed is no longer a priority. + const auto launch_data = launch_registry.snapshot(); + const auto& assertion_data = launch_data.first; + const auto& launch_infos = launch_data.second; + + std::stringstream oss; + + { + oss << "This process interacted the following GPUs = {"; + bool first_gpu_listed = true; + for (const auto& x : uvm_assertions) { + if (x) { + if (!first_gpu_listed) { + oss << "," + } + first_gpu_listed = true; + oss << x; + } + } + oss << "}" << std::endl; + } + + // Loop over each device that could be managed by the process + for (const auto device_num : c10::irange(assertion_data.size())) { + const auto& assertion_data_for_device = assertion_data.at(device_num); + + // Did anything fail? + const auto failures_found = std::min( + assertion_data_for_device.assertion_count, + C10_CUDA_DSA_ASSERTION_COUNT); + if (failures_found == 0) { + continue; + } + + // Something failed, let's talk about that + oss << failures_found + << " CUDA device-side assertion failures were found on GPU #" + << device_num << "!" << std::endl; + if (assertion_data_for_device.assertion_count > + C10_CUDA_DSA_ASSERTION_COUNT) { + oss << "But at least " << assertion_data_for_device.assertion_count + << " assertion failures occurred on the device" << std::endl; + oss << "Adjust `C10_CUDA_DSA_ASSERTION_COUNT` if you need more assertion failure info" + << std::endl; + } + + for (const auto i : c10::irange(failures_found)) { + const auto& self = assertion_data_for_device.assertions[i]; + const auto& launch_info = launch_infos[self.caller % launch_infos.size()]; + oss << "Assertion failure " << i << std::endl; + oss << " GPU assertion failure message = " << self.assertion_msg + << std::endl; + oss << " File containing assertion = " << self.filename << ":" + << self.line_number << std::endl; + oss << " Device function containing assertion = " << self.function_name + << std::endl; + oss << " Thread ID that failed assertion = [" << self.thread_id[0] << "," + << self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl; + oss << " Block ID that failed assertion = [" << self.block_id[0] << "," + << self.block_id[1] << "," << self.block_id[2] << "]" << std::endl; + if (launch_info.generation_number == self.caller) { + oss << " File containing kernel launch = " + << launch_info.launch_filename << ":" << launch_info.launch_linenum + << std::endl; + oss << " Function containing kernel launch = " + << launch_info.launch_function << std::endl; + oss << " Name of kernel launched that led to failure = " + << launch_info.kernel_name << std::endl; + oss << " Device that launched kernel = " << launch_info.device + << std::endl; + oss << " Stream kernel was launched on = " << launch_info.stream + << std::endl; + oss << " Backtrace of kernel launch site = "; + if (launch_registry.gather_launch_stacktrace) { + oss << "Launch stacktracing disabled." << std::endl; + } else { + oss << "\n" << launch_info.launch_stacktrace << std::endl; + } + } else { + oss << " CPU launch site info: Unavailable, the circular queue wrapped around. Increase `CUDAKernelLaunchRegistry::max_size`." + << std::endl; + } + } + } + return oss.str(); +#else + return "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n"; +#endif +} + +CUDAKernelLaunchRegistry::CUDAKernelLaunchRegistry() + : do_all_devices_support_managed_memory( + dsa_check_if_all_devices_support_managed_memory()), + gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()), + enabled(check_env_for_dsa_enabled()) { + for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) { + uvm_assertions.emplace_back(nullptr, uvm_deleter); + } + + kernel_launches.resize(max_kernel_launches); +} + +bool CUDAKernelLaunchRegistry::check_env_for_enable_launch_stacktracing() + const { + return env_flag_set("PYTORCH_CUDA_DSA_STACKTRACING"); +} + +bool CUDAKernelLaunchRegistry::check_env_for_dsa_enabled() const { + return env_flag_set("PYTORCH_USE_CUDA_DSA"); +} + +uint32_t CUDAKernelLaunchRegistry::insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id) { +#ifdef TORCH_USE_CUDA_DSA + if (!is_enabled()) { + return 0; + } + + const auto backtrace = gather_launch_stacktrace ? c10::get_backtrace() : ""; + + const std::lock_guard lock(read_write_mutex); + + const auto my_gen_number = generation_number++; + // TODO: It would probably be good to get a stack trace here so that + // we can better indicate which launch caused the failure. + kernel_launches[my_gen_number % max_kernel_launches] = { + launch_filename, + launch_function, + launch_linenum, + backtrace, + kernel_name, + dsa_get_device_id(), + stream_id, + my_gen_number}; + return my_gen_number; +#else + return 0; +#endif +} + +std::pair, std::vector> +CUDAKernelLaunchRegistry::snapshot() const { + // This is likely to be the longest-lasting hold on the mutex, but + // we only expect it to be called in cases where we're already failing + // and speed is no longer important + const std::lock_guard lock(read_write_mutex); + + std::vector device_assertions_data; + for (const auto& x : uvm_assertions) { + if (x) { + device_assertions_data.push_back(*x); + } else { + device_assertions_data.emplace_back(); + } + } + + return std::make_pair(device_assertions_data, kernel_launches); +} + +DeviceAssertionsData* CUDAKernelLaunchRegistry:: + get_uvm_assertions_ptr_for_current_device() { +#ifdef TORCH_USE_CUDA_DSA + if (!is_enabled()) { + return nullptr; + } + + const auto device_num = dsa_get_device_id(); + + // If we've already set up this GPU with managed memory, return a pointer to + // the managed memory. This is a lock-free quick-return path. + if (uvm_assertions.at(device_num)) { + return uvm_assertions.at(device_num).get(); + } + + // Need a lock here so there's not race-condition on creating the new device + // assertions buffer + const std::lock_guard lock(gpu_alloc_mutex); + + // If we've already set up this GPU with managed memory, return a pointer to + // the managed memory. This locked path ensures that the device memory is + // allocated only once + if (uvm_assertions.at(device_num)) { + return uvm_assertions.at(device_num).get(); + } + + // Otherwise, set up the GPU to be able to use the device-side assertion + // system + DeviceAssertionsData* uvm_assertions_ptr = nullptr; + + C10_CUDA_ERROR_HANDLED( + cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData))); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + + C10_CUDA_ERROR_HANDLED(cudaMemAdvise( + uvm_assertions_ptr, + sizeof(DeviceAssertionsData), + cudaMemAdviseSetPreferredLocation, + cudaCpuDeviceId)); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + + // GPU will establish direct mapping of data in CPU memory, no page faults + // will be generated + C10_CUDA_ERROR_HANDLED(cudaMemAdvise( + uvm_assertions_ptr, + sizeof(DeviceAssertionsData), + cudaMemAdviseSetAccessedBy, + cudaCpuDeviceId)); + CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS(); + + // Initialize the memory from the CPU; otherwise, pages may have to be created + // on demand. We think that UVM documentation indicates that first access may + // not honor preferred location, which would be bad, if true, because we want + // this memory on the host so we can access it post-assertion. Initializing + // this on the CPU helps ensure that that's where the memory will live. + *uvm_assertions_ptr = DeviceAssertionsData(); + + // Ownership and lifetime management of `uvm_assertions_ptr` now passes to the + // uvm_assertions unique_ptr vector + uvm_assertions.at(device_num).reset(uvm_assertions_ptr); + + return uvm_assertions_ptr; +#else + return nullptr; +#endif +} + +CUDAKernelLaunchRegistry& CUDAKernelLaunchRegistry::get_singleton_ref() { + static CUDAKernelLaunchRegistry launch_registry; + return launch_registry; +} + +bool CUDAKernelLaunchRegistry::has_failed() const { + for (const auto& x : uvm_assertions) { + if (x && x->assertion_count > 0) { + return true; + } + } + return false; +} + +bool CUDAKernelLaunchRegistry::is_enabled() const { +#ifdef TORCH_USE_CUDA_DSA + std::cerr << "" +#else + std::cerr + << "TORCH_USE_CUDA_DSA not enabled in CUDAKernelLaunchRegistry::is_enabled" + << std::endl; + return false; +#endif +} + +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDADeviceAssertionHost.h b/c10/cuda/CUDADeviceAssertionHost.h new file mode 100644 index 0000000000000..7465f3d36b20e --- /dev/null +++ b/c10/cuda/CUDADeviceAssertionHost.h @@ -0,0 +1,156 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#ifdef USE_CUDA +#define TORCH_USE_CUDA_DSA +#endif + +/// Number of assertion failure messages we can store. If this is too small +/// threads will fail silently. +constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10; +constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512; + +namespace c10 { +namespace cuda { + +/// Holds information about any device-side assertions that fail. +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionData { + /// Stringification of the assertion + char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN]; + /// File the assertion was in + char filename[C10_CUDA_DSA_MAX_STR_LEN]; + /// Name of the function the assertion was in + char function_name[C10_CUDA_DSA_MAX_STR_LEN]; + /// Line number the assertion was at + int line_number; + /// Number uniquely identifying the kernel launch that triggered the assertion + uint32_t caller; + /// block_id of the thread that failed the assertion + int32_t block_id[3]; + /// third_id of the thread that failed the assertion + int32_t thread_id[3]; +}; + +/// Used to hold assertions generated by the device +/// Held in managed memory and access by both the CPU and the GPU. +struct DeviceAssertionsData { + /// Total number of assertions found; a subset of thse will be recorded + /// in `assertions` + int32_t assertion_count; + /// An array of assertions that will be written to in a race-free manner + DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT]; +}; + +/// Use to hold info about kernel launches so that we can run kernels +/// asynchronously and still associate launches with device-side +/// assertion failures +struct CUDAKernelLaunchInfo { + /// Filename of the code where the kernel was launched from + const char* launch_filename; + /// Function from which the kernel was launched + const char* launch_function; + /// Line number of where the code was launched from + uint32_t launch_linenum; + /// Backtrace of where the kernel was launched from, only populated if + /// CUDAKernelLaunchRegistry::gather_launch_stacktrace is True + std::string launch_stacktrace; + /// Kernel that was launched + const char* kernel_name; + /// Device the kernel was launched on + int device; + /// Stream the kernel was launched on + int32_t stream; + /// A number that uniquely identifies the kernel launch + uint64_t generation_number; +}; + +/// Circular buffer used to hold information about kernel launches +/// this is later used to reconstruct how a device-side kernel assertion failure +/// occurred CUDAKernelLaunchRegistry is used as a singleton +class C10_CUDA_API CUDAKernelLaunchRegistry { + private: + /// Assume that this is the max number of kernel launches that might ever be + /// enqueued across all streams on a single device + static constexpr int max_kernel_launches = 1024; + /// How many kernel launch infos we've inserted. Used to ensure that circular + /// queue doesn't provide false information by always increasing, but also to + /// mark where we are inserting into the queue +#ifdef TORCH_USE_CUDA_DSA + uint64_t generation_number = 0; +#endif + /// Shared mutex between writer and accessor to ensure multi-threaded safety. + mutable std::mutex read_write_mutex; + /// Used to ensure prevent race conditions in GPU memory allocation + mutable std::mutex gpu_alloc_mutex; + /// Pointer to managed memory keeping track of device-side assertions. There + /// is one entry for each possible device the process might work with. Unused + /// entries are nullptrs. We could also use an unordered_set here, but this + /// vector design will be faster and the wasted memory is small since we + /// expect the number of GPUs per node will always be small + std::vector< + std::unique_ptr> + uvm_assertions; + /// A single circular buffer holds information about every kernel launch the + /// process makes across all devices. + std::vector kernel_launches; + bool check_env_for_enable_launch_stacktracing() const; + bool check_env_for_dsa_enabled() const; + + public: + CUDAKernelLaunchRegistry(); + /// Register a new kernel launch and obtain a generation number back to be + /// passed to the kernel + uint32_t insert( + const char* launch_filename, + const char* launch_function, + const uint32_t launch_linenum, + const char* kernel_name, + const int32_t stream_id); + /// Get copies of the kernel launch registry and each device's assertion + /// failure buffer so they can be inspected without raising race conditions + std:: + pair, std::vector> + snapshot() const; + /// Get a pointer to the current device's assertion failure buffer. If no such + /// buffer exists then one is created. This means that the first kernel launch + /// made on each device will be slightly slower because memory allocations are + /// required + DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device(); + /// Gets the global singleton of the registry + static CUDAKernelLaunchRegistry& get_singleton_ref(); + /// If not all devices support DSA, we disable it + const bool do_all_devices_support_managed_memory = false; + /// Whether or not to gather stack traces when launching kernels + bool gather_launch_stacktrace = false; + /// Whether or not host-side DSA is enabled or disabled at run-time + /// Device-side code cannot be adjusted at run-time + bool enabled = false; + /// Whether or not a device has indicated a failure + bool has_failed() const; + /// Since multiple mechanisms can enable/disable, we add a function that + /// aggregates them + bool is_enabled() const; +}; + +std::string c10_retrieve_device_side_assertion_info(); + +} // namespace cuda +} // namespace c10 + +// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH +// requires the same input arguments. We introduce the following macro to +// standardize these. +#define TORCH_DSA_KERNEL_ARGS \ + c10::cuda::DeviceAssertionsData *const assertions_data, \ + uint32_t assertion_caller_id + +// This macro can be used to pass the DSA arguments onward to another +// function +#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index d35d72c9ba7ba..b6e9b9e3606d8 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -1,5 +1,6 @@ #include +#include #include #include @@ -9,23 +10,31 @@ namespace c10 { namespace cuda { void c10_cuda_check_implementation( - const std::string& filename, - const std::string& function_name, + const char* filename, + const char* function_name, const int line_number, const bool include_device_assertions) { - // We retrieve the error here in order to keep CUDA data types out of - // CUDAException.h thereby simplifying including it in other files - const cudaError_t err = cudaGetLastError(); + const auto cuda_error = cudaGetLastError(); + const auto cuda_kernel_failure = include_device_assertions + ? c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().has_failed() + : false; - if (C10_LIKELY(err == cudaSuccess)) { + if (C10_LIKELY(cuda_error == cudaSuccess && !cuda_kernel_failure)) { return; } std::string check_message; #ifndef STRIP_ERROR_MESSAGES check_message.append("CUDA error: "); - check_message.append(cudaGetErrorString(err)); + check_message.append(cudaGetErrorString(cuda_error)); check_message.append(c10::cuda::get_cuda_check_suffix()); + check_message.append("\n"); + if (include_device_assertions) { + check_message.append(c10_retrieve_device_side_assertion_info()); + } else { + check_message.append( + "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers."); + } #endif TORCH_CHECK(false, check_message); diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index cfc7424503a96..101036c4ae9eb 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include #include #include +#include #include // Note [CHECK macro] @@ -22,17 +24,17 @@ class C10_CUDA_API CUDAError : public c10::Error { }; } // namespace c10 -#define C10_CUDA_CHECK(EXPR) \ - do { \ - const cudaError_t __err = EXPR; \ - if (C10_UNLIKELY(__err != cudaSuccess)) { \ - c10::cuda::c10_cuda_check_implementation( \ - __FILE__, \ - __func__, /* Line number's data type is not well-defined between \ - compilers, so we perform an explicit cast */ \ - static_cast(__LINE__), \ - true); \ - } \ +#define C10_CUDA_CHECK(EXPR) \ + do { \ + /* We get & disarm the error inside of */ \ + /* `c10_cuda_check_implementation` */ \ + C10_UNUSED const cudaError_t __err = EXPR; \ + c10::cuda::c10_cuda_check_implementation( \ + __FILE__, \ + __func__, /* Line number's data type is not well-defined between \ + compilers, so we perform an explicit cast */ \ + static_cast(__LINE__), \ + true); \ } while (0) #define C10_CUDA_CHECK_WARN(EXPR) \ @@ -70,14 +72,29 @@ class C10_CUDA_API CUDAError : public c10::Error { // diagnostic if it didn't. #define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) +/// Launches a CUDA kernel appending to it all the information need to handle +/// device-side assertion failures. Checks that the launch was successful. +#define TORCH_DSA_KERNEL_LAUNCH( \ + kernel, blocks, threads, shared_mem, stream, ...) \ + do { \ + auto& launch_registry = \ + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref(); \ + kernel<<>>( \ + __VA_ARGS__, \ + launch_registry.get_uvm_assertions_ptr_for_current_device(), \ + launch_registry.insert( \ + __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } while (0) + namespace c10 { namespace cuda { /// In the event of a CUDA failure, formats a nice error message about that /// failure and also checks for device-side assertion failures C10_CUDA_API void c10_cuda_check_implementation( - const std::string& filename, - const std::string& function_name, + const char* filename, + const char* function_name, const int line_number, const bool include_device_assertions); diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 610342ac836bf..f567a2655c940 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -430,6 +430,10 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { (void)called; } + bool initialized() override { + return devs_initialized_flags.size() > 0; + } + static inline void assertValidDevice(int device) { TORCH_CHECK( 0 <= device && device < device_count, "Invalid device argument."); diff --git a/c10/cuda/impl/CUDATest.cpp b/c10/cuda/impl/CUDATest.cpp index fb58d1c3a0f8f..c5d9e3f1bf2b0 100644 --- a/c10/cuda/impl/CUDATest.cpp +++ b/c10/cuda/impl/CUDATest.cpp @@ -11,7 +11,7 @@ namespace impl { bool has_cuda_gpu() { int count; - C10_CUDA_CHECK(cudaGetDeviceCount(&count)); + C10_CUDA_IGNORE_ERROR(cudaGetDeviceCount(&count)); return count != 0; } diff --git a/c10/cuda/test/CMakeLists.txt b/c10/cuda/test/CMakeLists.txt index 30d60871b8f12..eed7fdff42ca1 100644 --- a/c10/cuda/test/CMakeLists.txt +++ b/c10/cuda/test/CMakeLists.txt @@ -1,6 +1,13 @@ # ---[ Test binaries. set(C10_CUDA_ALL_TEST_FILES + impl/CUDAAssertionsTest_1_var_test.cu + impl/CUDAAssertionsTest_catches_stream.cu + impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu + impl/CUDAAssertionsTest_from_2_processes.cu + impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu + impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu + impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu impl/CUDATest.cpp ) if(BUILD_TEST) diff --git a/c10/cuda/test/build.bzl b/c10/cuda/test/build.bzl index b2d700820cc17..334b3a75b6aa7 100644 --- a/c10/cuda/test/build.bzl +++ b/c10/cuda/test/build.bzl @@ -1,10 +1,42 @@ +dsa_tests = [ + "impl/CUDAAssertionsTest_1_var_test.cu", + "impl/CUDAAssertionsTest_catches_stream.cu", + "impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu", + "impl/CUDAAssertionsTest_from_2_processes.cu", + "impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu", + "impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu", + "impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu", +] + def define_targets(rules): rules.cc_test( name = "test", - srcs = ["impl/CUDATest.cpp"], + srcs = [ + "impl/CUDATest.cpp", + ], deps = [ "@com_google_googletest//:gtest_main", "//c10/cuda", ], target_compatible_with = rules.requires_cuda_enabled(), ) + + for src in dsa_tests: + name = src.replace("impl/", "").replace(".cu", "") + rules.cuda_library( + name = "test_" + name + "_lib", + srcs = [ + src, + ], + deps = [ + "@com_google_googletest//:gtest_main", + "//c10/cuda", + ], + target_compatible_with = rules.requires_cuda_enabled(), + ) + rules.cc_test( + name = "test_" + name, + deps = [ + ":test_" + name + "_lib", + ], + ) diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_1_var_test.cu b/c10/cuda/test/impl/CUDAAssertionsTest_1_var_test.cu new file mode 100644 index 0000000000000..f30774102a482 --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_1_var_test.cu @@ -0,0 +1,102 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +void did_not_fail_diagnostics() { +#ifdef TORCH_USE_CUDA_DSA + std::cerr << "DSA was enabled" << std::endl; +#else + std::cerr << "DSA was not enabled" << std::endl; +#endif + + std::cerr + << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = " + << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled + << std::endl; + std::cerr + << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled() = " + << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled() + << std::endl; + std::cerr + << "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().do_all_devices_support_managed_memory = " + << c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref() + .do_all_devices_support_managed_memory + << std::endl; +} + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * TEST: Triggering device side assertion on a simple <<<1,1>>> config. + * kernel used takes only 1 variable as parameter function. + */ +void cuda_device_assertions_1_var_test() { + const auto stream = c10::cuda::getStreamFromPool(); + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_fail_assertion_kernel, + 1, /* Blocks */ + 1, /* Threads */ + 0, /* Shared mem */ + stream, /* Stream */ + 1); + + try { + c10::cuda::device_synchronize(); + did_not_fail_diagnostics(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT( + err_str, + HasSubstr("CUDA device-side assertion failures were found on GPU #0!")); + ASSERT_THAT( + err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + ASSERT_THAT( + err_str, + HasSubstr( + "Function containing kernel launch = " + + std::string(__FUNCTION__))); + ASSERT_THAT( + err_str, + HasSubstr( + "Stream kernel was launched on = " + std::to_string(stream.id()))); + } +} + +TEST(CUDATest, cuda_device_assertions_1_var_test) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + std::cerr << "BEFORE TEST" << std::endl; + did_not_fail_diagnostics(); + cuda_device_assertions_1_var_test(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_catches_stream.cu b/c10/cuda/test/impl/CUDAAssertionsTest_catches_stream.cu new file mode 100644 index 0000000000000..71fcf3ee2491f --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_catches_stream.cu @@ -0,0 +1,101 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +/** + * Device kernel that takes mulitple integer parameters as arguments and + * will always trigger a device side assertion. + */ +__global__ void cuda_multiple_vars_always_fail_assertion_kernel( + const int a, + const int b, + const int c, + const int d, + TORCH_DSA_KERNEL_ARGS) { + int i = a + b + c + d; + if (i != 0) { + CUDA_KERNEL_ASSERT2(i == -i); + } else { + CUDA_KERNEL_ASSERT2(i == i + 1); + } +} + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * TEST: Triggering device side assertion on a simple <<<1,1>>> config. + * kernel used takes multiple variables as parameters to the function. + */ +void cuda_device_assertions_catches_stream() { + const auto stream = c10::cuda::getStreamFromPool(); + TORCH_DSA_KERNEL_LAUNCH( + cuda_multiple_vars_always_fail_assertion_kernel, + 1, /* Blocks */ + 1, /* Threads */ + 0, /* Shared mem */ + stream, /* Stream */ + 1, /* const int a */ + 2, /* const int b */ + 3, /* const int c */ + 4 /* const int d */ + ); + + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT( + err_str, HasSubstr("# of GPUs this process interacted with = 1")); + ASSERT_THAT( + err_str, + HasSubstr("CUDA device-side assertion failures were found on GPU #0!")); + ASSERT_THAT( + err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_multiple_vars_always_fail_assertion_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + ASSERT_THAT( + err_str, + HasSubstr( + "Function containing kernel launch = " + + std::string(__FUNCTION__))); + ASSERT_THAT( + err_str, + HasSubstr( + "Stream kernel was launched on = " + std::to_string(stream.id()))); + } +} + +TEST(CUDATest, cuda_device_assertions_catches_stream) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_catches_stream(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu b/c10/cuda/test/impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu new file mode 100644 index 0000000000000..1a0a0b475a0d9 --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu @@ -0,0 +1,86 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +/** + * Device kernel that takes 2 arguments + * @param bad_thread represents the thread we want to trigger assertion on. + * @param bad_block represents the block we want to trigger assertion on. + * This kernel will only trigger a device side assertion for <> pair. all the other blocks and threads pairs will basically be + * no-op. + */ +__global__ void cuda_device_assertions_fail_on_thread_block_kernel( + const int bad_thread, + const int bad_block, + TORCH_DSA_KERNEL_ARGS) { + if (threadIdx.x == bad_thread && blockIdx.x == bad_block) { + CUDA_KERNEL_ASSERT2(false); // This comparison necessarily needs to fail + } +} + +/** + * TEST: Triggering device side assertion on only 1 thread from <<<1024,128>>> + * grid. kernel used is unique, it take 2 parameters to tell which particular + * block and thread it should assert, all the other theads of the kernel will be + * basically no-op. + */ +void cuda_device_assertions_catches_thread_and_block_and_device() { + const auto stream = c10::cuda::getStreamFromPool(); + TORCH_DSA_KERNEL_LAUNCH( + cuda_device_assertions_fail_on_thread_block_kernel, + 1024, /* Blocks */ + 128, /* Threads */ + 0, /* Shared mem */ + stream, /* Stream */ + 29, /* bad thread */ + 937 /* bad block */ + ); + + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT( + err_str, HasSubstr("Thread ID that failed assertion = [29,0,0]")); + ASSERT_THAT( + err_str, HasSubstr("Block ID that failed assertion = [937,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_device_assertions_fail_on_thread_block_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + ASSERT_THAT( + err_str, + HasSubstr( + "Function containing kernel launch = " + + std::string(__FUNCTION__))); + ASSERT_THAT( + err_str, + HasSubstr( + "Stream kernel was launched on = " + std::to_string(stream.id()))); + } +} + +TEST(CUDATest, cuda_device_assertions_catches_thread_and_block_and_device) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_catches_thread_and_block_and_device(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_from_2_processes.cu b/c10/cuda/test/impl/CUDAAssertionsTest_from_2_processes.cu new file mode 100644 index 0000000000000..3f829259a6b03 --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_from_2_processes.cu @@ -0,0 +1,108 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +const auto max_assertions_failure_str = + "Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1); + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * Device kernel that takes a single integer parameter as argument and + * will never trigger a device side assertion. + */ +__global__ void cuda_always_succeed_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a == a); +} + +// Windows doesn't like `fork` +#ifndef _MSC_VER +/** + * TEST: Triggering device side assertion from 2 different processes from CPU. + * The following code is testing if two processes from CPU that are running + * GPU kernels (not necessarily simultaneously) and are asserting & writing + * to the respective UVMs, mess up anything for each other. + * Once parent process's kernel launch fails and causes a device-side assertion + * and is still alive when the second process is interacting with the GPU, + * trying to launch another kernel. + */ +void cuda_device_assertions_from_2_processes() { + const auto n1 = fork(); + if (n1 == 0) { + // This is the parent process, that will call an assertion failure. + // This should execute before the child process. + // We are achieving this by putting the child process to sleep. + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_fail_assertion_kernel, + 1, /* Blocks */ + 1, /* Threads */ + 0, /* Shared mem */ + c10::cuda::getStreamFromPool(), /* Stream */ + 1); + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT( + err_str, + HasSubstr( + "1 CUDA device-side assertion failures were found on GPU #0!")); + } + // Keep this alive so we can see what happened to the other process + std::this_thread::sleep_for(std::chrono::milliseconds(3000)); + } else { + // This is the child process + // We put it to sleep for next 2 seconds, to make sure that the parent has + // asserted a failure already. + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_succeed_assertion_kernel, + 1, /* Blocks */ + 1, /* Threads */ + 0, /* Shared mem */ + c10::cuda::getStreamFromPool(), /* Stream */ + 1); + try { + c10::cuda::device_synchronize(); + } catch (const c10::Error& err) { + ASSERT_TRUE(false); // This kernel should not have failed, but did. + } + // End the child process + exit(0); + } +} + +TEST(CUDATest, cuda_device_assertions_from_2_processes) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_from_2_processes(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} + +#else + +#endif diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu new file mode 100644 index 0000000000000..f5f8597f20c9a --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu @@ -0,0 +1,93 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +const auto max_assertions_failure_str = + "Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1); + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * TEST: Triggering device side assertion from multiple block but single thread + * <<<10,128>>>. Here we are triggering assertion on 10 blocks, each with only + * 128 thread. + */ +void cuda_device_assertions_multiple_writes_from_blocks_and_threads() { + bool run_threads = false; + + // Create a function to launch kernel that waits for a signal, to try to + // ensure everything is happening simultaneously + const auto launch_the_kernel = [&]() { + // Busy loop waiting for the signal to go + while (!run_threads) { + } + + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_fail_assertion_kernel, + 10, /* Blocks */ + 128, /* Threads */ + 0, /* Shared mem */ + c10::cuda::getCurrentCUDAStream(), /* Stream */ + 1); + }; + + // Spin up a bunch of busy-looping threads + std::vector threads; + for (int i = 0; i < 10; i++) { + threads.emplace_back(launch_the_kernel); + } + + // Paranoid - wait for all the threads to get setup + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Mash + run_threads = true; + + // Clean-up + for (auto& x : threads) { + x.join(); + } + + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str)); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + } +} + +TEST(CUDATest, cuda_device_assertions_multiple_writes_from_blocks_and_threads) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_multiple_writes_from_blocks_and_threads(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu new file mode 100644 index 0000000000000..a66c792d5a236 --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu @@ -0,0 +1,90 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +const auto max_assertions_failure_str = + "Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1); + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * TEST: Triggering device side assertion from multiple block but single thread + * <<<10,1>>>. Here we are triggering assertion on 10 blocks, each with only 1 + * thread. Since we have more than 10 SM on a GPU, we expect each block to be + * executed and successfully assert, Hence we will see assertions logged from + * each block here. + */ +void cuda_device_assertions_multiple_writes_from_multiple_blocks() { + const auto stream = c10::cuda::getStreamFromPool(); + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_fail_assertion_kernel, + 10, /* Blocks */ + 1, /* Threads */ + 0, /* Shared mem */ + stream, /* Stream */ + 1); + + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str)); + ASSERT_THAT( + err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [1,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [2,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [3,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [4,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [5,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [6,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [7,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [8,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [9,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + ASSERT_THAT( + err_str, + HasSubstr( + "Function containing kernel launch = " + + std::string(__FUNCTION__))); + ASSERT_THAT( + err_str, + HasSubstr( + "Stream kernel was launched on = " + std::to_string(stream.id()))); + } +} + +TEST(CUDATest, cuda_device_assertions_multiple_writes_from_multiple_blocks) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_multiple_writes_from_multiple_blocks(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu new file mode 100644 index 0000000000000..f1e39c8ba19d9 --- /dev/null +++ b/c10/cuda/test/impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu @@ -0,0 +1,78 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +using ::testing::HasSubstr; + +const auto max_assertions_failure_str = + "Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1); + +/** + * Device kernel that takes a single integer parameter as argument and + * will always trigger a device side assertion. + */ +__global__ void cuda_always_fail_assertion_kernel( + const int a, + TORCH_DSA_KERNEL_ARGS) { + CUDA_KERNEL_ASSERT2(a != a); +} + +/** + * TEST: Triggering device side assertion from single block and multiple threads + * <<<1,128>>>. Once the very first thread asserts all the other threads will + * basically be in bad state and the block id with failed asseriton would be + * [0,0,0]. + */ +void cuda_device_assertions_multiple_writes_from_same_block() { + const auto stream = c10::cuda::getStreamFromPool(); + TORCH_DSA_KERNEL_LAUNCH( + cuda_always_fail_assertion_kernel, + 1, /* Blocks */ + 128, /* Threads */ + 0, /* Shared mem */ + stream, /* Stream */ + 1); + + try { + c10::cuda::device_synchronize(); + throw std::runtime_error("Test didn't fail, but should have."); + } catch (const c10::Error& err) { + const auto err_str = std::string(err.what()); + ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str)); + ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]")); + ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0")); + ASSERT_THAT( + err_str, + HasSubstr( + "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel")); + ASSERT_THAT( + err_str, HasSubstr("File containing kernel launch = " __FILE__)); + ASSERT_THAT( + err_str, + HasSubstr( + "Function containing kernel launch = " + + std::string(__FUNCTION__))); + ASSERT_THAT( + err_str, + HasSubstr( + "Stream kernel was launched on = " + std::to_string(stream.id()))); + } +} + +TEST(CUDATest, cuda_device_assertions_multiple_writes_from_same_block) { +#ifdef TORCH_USE_CUDA_DSA + c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true; + cuda_device_assertions_multiple_writes_from_same_block(); +#else + GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled."; +#endif +} diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index beefca1d63c60..b6912004bd77c 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -255,13 +255,13 @@ using namespace c10::hip; // constants from // (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) // The maximum number of threads per multiprocessor is 1024 for Turing -// architecture (7.5), 1536 for Geforce Ampere (8.6), and 2048 for all other -// architectures. You'll get warnings if you exceed these constants. Hence, the -// following macros adjust the input values from the user to resolve potential -// warnings. +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. #if __CUDA_ARCH__ == 750 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; -#elif __CUDA_ARCH__ == 860 +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; #else constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; @@ -326,9 +326,8 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; // CUDA_KERNEL_ASSERT checks the assertion // even when NDEBUG is defined. This is useful for important assertions in CUDA // code that would otherwise be suppressed when building Release. -#if defined(__ANDROID__) || defined(__APPLE__) || \ - (defined(USE_ROCM) && ROCM_VERSION < 40100) || \ - (defined(USE_ROCM) && defined(ROCM_DISABLE_GPU_ASSERTS)) +#if defined(__ANDROID__) || defined(__APPLE__) || \ + (defined(USE_ROCM) && ROCM_VERSION < 40100) // Those platforms do not support assert() #define CUDA_KERNEL_ASSERT(cond) #define SYCL_KERNEL_ASSERT(cond) @@ -347,8 +346,8 @@ __host__ __device__ #endif // __CUDA_ARCH__ void _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); -} #endif // __SYCL_DEVICE_ONLY__ +} #endif // NDEBUG #define CUDA_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ @@ -368,7 +367,9 @@ extern SYCL_EXTERNAL void __assert_fail( unsigned int line, const char* func); #else // __SYCL_DEVICE_ONLY__ -#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) +#if ( \ + defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__)) && \ + !defined(TORCH_DISABLE_GPU_ASSERTS)) // CUDA supports __assert_fail function which are common for both device // and host side code. __host__ __device__ @@ -386,7 +387,7 @@ __host__ __device__ const char* function) throw() __attribute__((__noreturn__)); #if (defined(__HIP_ARCH__) || defined(__HIP__)) && \ - !defined(ROCM_DISABLE_GPU_ASSERTS) + !defined(TORCH_DISABLE_GPU_ASSERTS) // ROCm supports __assert_fail only as a device side function. __device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail( const char* assertion, @@ -439,15 +440,6 @@ __device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail( #define C10_IS_TRIVIALLY_COPYABLE(T) std::is_trivially_copyable::value #endif -#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ - __GNUC__ < 6 -#define CONSTEXPR_EXCEPT_GCC5 -#define IS_NOT_GCC5_CONSTEXPR 0 -#else -#define CONSTEXPR_EXCEPT_GCC5 constexpr -#define IS_NOT_GCC5_CONSTEXPR 1 -#endif - #if defined(__CUDA_ARCH__) #if defined(_MSC_VER) && defined(__CUDACC__) #define CONSTEXPR_EXCEPT_WIN_CUDA const diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index 932d0cabac4cb..50f283560d7e8 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -29,3 +29,12 @@ def define_targets(rules): "//conditions:default": [], }), ) + rules.filegroup( + name = "headers", + srcs = rules.glob( + ["*.h"], + exclude = [ + ], + ), + visibility = ["//:__pkg__"], + ) diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index a57e7c706486d..d889d72b5afb1 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include using namespace c10; #ifndef C10_MOBILE @@ -20,12 +20,6 @@ TEST(SymIntTest, ConcreteInts) { check(-4611686018427387904LL); } -TEST(SymIntTest, AddNode) { - auto n = c10::make_intrusive(); - auto i = n->toSymInt(); - EXPECT_TRUE(i.is_symbolic()); -} - TEST(SymIntTest, CheckRange) { EXPECT_FALSE(SymInt::check_range(INT64_MIN)); } diff --git a/c10/test/util/complex_math_test_common.h b/c10/test/util/complex_math_test_common.h index 15addf687856f..ce1be7b38d84d 100644 --- a/c10/test/util/complex_math_test_common.h +++ b/c10/test/util/complex_math_test_common.h @@ -166,6 +166,134 @@ C10_DEFINE_TEST(TestLog2, Rev) { } } +C10_DEFINE_TEST(TestLog1p, Normal) { + // log1p(x) = log(1 + x) + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0f + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0 + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Small) { + // log(1 + x) ~ x for |x| << 1 + { + c10::complex x(1e-9, 2e-9); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } + { + c10::complex x(1e-100, 2e-100); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Extreme) { + // log(1 + x) ~ x for |x| << 1 and in the brink of overflow / underflow + { + c10::complex x(-1, 1e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1e-30, tol); + } + { + c10::complex x(1e-30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e30, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.42412638010134, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-38, 1e-38); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-38, tol); + C10_ASSERT_NEAR(l.imag(), 1e-38, tol); + } + { + c10::complex x(1e-38, 2e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-30, tol); + C10_ASSERT_NEAR(l.imag(), 2e-30, tol); + } + { + c10::complex x(-1, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e250, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.9928468387914, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-250, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 2e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 2e-250, tol); + } + { + c10::complex x(2e-308, 1.5e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 2e-308, tol); + C10_ASSERT_NEAR(l.imag(), 1.5e-308, tol); + } +} + // Power functions C10_DEFINE_TEST(TestPowSqrt, Equal) { diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index 7ed1c292841d5..632fe7fc2f202 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -146,6 +146,11 @@ TEST(IntrusivePtrTest, givenInvalidPtr_whenCallingGet_thenReturnsNullptr) { EXPECT_EQ(nullptr, obj.get()); } +TEST(IntrusivePtrTest, givenNullptr_whenCallingGet_thenReturnsNullptr) { + intrusive_ptr obj(nullptr); + EXPECT_EQ(nullptr, obj.get()); +} + TEST(IntrusivePtrTest, givenValidPtr_whenDereferencing_thenReturnsObject) { intrusive_ptr obj = make_intrusive(5); diff --git a/c10/test/util/string_view_test.cpp b/c10/test/util/string_view_test.cpp index f63bd1ea71a7c..43e8994d8bfca 100644 --- a/c10/test/util/string_view_test.cpp +++ b/c10/test/util/string_view_test.cpp @@ -218,19 +218,17 @@ static_assert(!string_view("hello").empty(), ""); } // namespace test_empty namespace test_remove_prefix { -CONSTEXPR_EXCEPT_GCC5 string_view remove_prefix(string_view input, size_t len) { +constexpr string_view remove_prefix(string_view input, size_t len) { input.remove_prefix(len); return input; } TEST(StringViewTest, whenRemovingValidPrefix_thenWorks) { -#if IS_NOT_GCC5_CONSTEXPR static_assert( remove_prefix(string_view("hello"), 0) == string_view("hello"), ""); static_assert( remove_prefix(string_view("hello"), 1) == string_view("ello"), ""); static_assert(remove_prefix(string_view("hello"), 5) == string_view(""), ""); -#endif EXPECT_EQ(remove_prefix(string_view("hello"), 0), string_view("hello")); EXPECT_EQ(remove_prefix(string_view("hello"), 1), string_view("ello")); @@ -245,19 +243,17 @@ TEST(StringViewTest, whenRemovingTooLargePrefix_thenThrows) { } // namespace test_remove_prefix namespace test_remove_suffix { -CONSTEXPR_EXCEPT_GCC5 string_view remove_suffix(string_view input, size_t len) { +constexpr string_view remove_suffix(string_view input, size_t len) { input.remove_suffix(len); return input; } TEST(StringViewTest, whenRemovingValidSuffix_thenWorks) { -#if IS_NOT_GCC5_CONSTEXPR static_assert( remove_suffix(string_view("hello"), 0) == string_view("hello"), ""); static_assert( remove_suffix(string_view("hello"), 1) == string_view("hell"), ""); static_assert(remove_suffix(string_view("hello"), 5) == string_view(""), ""); -#endif EXPECT_EQ(remove_suffix(string_view("hello"), 0), string_view("hello")); EXPECT_EQ(remove_suffix(string_view("hello"), 1), string_view("hell")); @@ -272,17 +268,15 @@ TEST(StringViewTest, whenRemovingTooLargeSuffix_thenThrows) { } // namespace test_remove_suffix namespace test_swap_function { -CONSTEXPR_EXCEPT_GCC5 std::pair get() { +constexpr std::pair get() { string_view first = "first"; string_view second = "second"; swap(first, second); return std::make_pair(first, second); } TEST(StringViewTest, testSwapFunction) { -#if IS_NOT_GCC5_CONSTEXPR static_assert(string_view("second") == get().first, ""); static_assert(string_view("first") == get().second, ""); -#endif EXPECT_EQ(string_view("second"), get().first); EXPECT_EQ(string_view("first"), get().second); @@ -290,17 +284,15 @@ TEST(StringViewTest, testSwapFunction) { } // namespace test_swap_function namespace test_swap_method { -CONSTEXPR_EXCEPT_GCC5 std::pair get() { +constexpr std::pair get() { string_view first = "first"; string_view second = "second"; first.swap(second); return std::make_pair(first, second); } TEST(StringViewTest, testSwapMethod) { -#if IS_NOT_GCC5_CONSTEXPR static_assert(string_view("second") == get().first, ""); static_assert(string_view("first") == get().second, ""); -#endif EXPECT_EQ(string_view("second"), get().first); EXPECT_EQ(string_view("first"), get().second); diff --git a/c10/util/C++17.h b/c10/util/C++17.h index c51275721e584..09259ab840bed 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -49,6 +49,19 @@ using invoke_result = typename std::result_of; template using invoke_result_t = typename invoke_result::type; +// std::is_pod is deprecated in C++20, std::is_standard_layout and +// std::is_trivial are introduced in C++11, std::conjunction has been introduced +// in C++17. +template +#if defined(__cpp_lib_logical_traits) && __cpp_lib_logical_traits >= 201510L +using is_pod = std::conjunction, std::is_trivial>; +#else +using is_pod = std::is_pod; +#endif + +template +constexpr bool is_pod_v = is_pod::value; + namespace guts { template @@ -127,7 +140,7 @@ using void_t = typename make_void::type; #define CUDA_HOST_DEVICE C10_HOST_DEVICE #endif -#ifdef __cpp_lib_apply +#if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) template CUDA_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) { diff --git a/c10/util/Exception.h b/c10/util/Exception.h index d86a85adbe4c4..773107f668ae1 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -275,6 +275,12 @@ class C10_API OutOfMemoryError : public Error { using Error::Error; }; +// Used for collective communication library errors from the distributed module. +// These turn into DistBackendError when they cross into Python. +class C10_API DistBackendError : public Error { + using Error::Error; +}; + // A utility function to return an exception std::string by prepending its // exception type before its what() content C10_API std::string GetExceptionString(const std::exception& e); @@ -562,17 +568,21 @@ namespace detail { // Report a warning to the user. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // -#define TORCH_WARN_WITH(warning_t, ...) \ +#ifdef DISABLE_WARN +#define _TORCH_WARN_WITH(...) ((void)0); +#else +#define _TORCH_WARN_WITH(warning_t, ...) \ ::c10::warn(::c10::Warning( \ warning_t(), \ {__func__, __FILE__, static_cast(__LINE__)}, \ WARNING_MESSAGE_STRING(__VA_ARGS__), \ false)); +#endif -#define TORCH_WARN(...) TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__); +#define TORCH_WARN(...) _TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__); #define TORCH_WARN_DEPRECATION(...) \ - TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__); + _TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__); // Report a warning to the user only once. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< @@ -584,12 +594,16 @@ namespace detail { return true; \ }() +#ifdef DISABLE_WARN +#define TORCH_WARN_ONCE(...) ((void)0); +#else #define TORCH_WARN_ONCE(...) \ if (::c10::WarningUtils::get_warnAlways()) { \ TORCH_WARN(__VA_ARGS__); \ } else { \ _TORCH_WARN_ONCE(__VA_ARGS__); \ } +#endif // Report an error with a specific argument // NOTE: using the argument name in TORCH_CHECK's message is preferred diff --git a/c10/util/SmallBuffer.h b/c10/util/SmallBuffer.h index 4dfa04c87190a..b519d30ec3963 100644 --- a/c10/util/SmallBuffer.h +++ b/c10/util/SmallBuffer.h @@ -15,7 +15,9 @@ namespace c10 { template class SmallBuffer { - static_assert(std::is_pod::value, "SmallBuffer is intended for POD types"); + static_assert( + std::is_trivial::value, + "SmallBuffer is intended for POD types"); T storage_[N]; size_t size_; diff --git a/c10/util/ThreadLocalDebugInfo.cpp b/c10/util/ThreadLocalDebugInfo.cpp index 85cb839c6107a..e79ee00d1a61f 100644 --- a/c10/util/ThreadLocalDebugInfo.cpp +++ b/c10/util/ThreadLocalDebugInfo.cpp @@ -1,6 +1,8 @@ #include #include +#include + namespace c10 { C10_DEFINE_TLS_static(std::shared_ptr, tls_debug_info); @@ -67,7 +69,7 @@ DebugInfoGuard::DebugInfoGuard( return; } prev_info_ = debug_info; - ThreadLocalDebugInfo::_push(kind, info); + ThreadLocalDebugInfo::_push(kind, std::move(info)); active_ = true; } diff --git a/c10/util/TypeIndex.h b/c10/util/TypeIndex.h index 3e8114735a227..b78690c123bbe 100644 --- a/c10/util/TypeIndex.h +++ b/c10/util/TypeIndex.h @@ -12,8 +12,13 @@ namespace util { // TODO Make it work for more compilers +// Intel compiler works +#if defined(__INTEL_COMPILER) +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 +#define C10_TYPENAME_CONSTEXPR + // Clang works -#if defined(__clang__) +#elif defined(__clang__) // except for NVCC #if defined(__CUDACC__) diff --git a/c10/util/build.bzl b/c10/util/build.bzl index b981eba677185..8d79a557477f0 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -68,5 +68,5 @@ def define_targets(rules): exclude = [ ], ), - visibility = ["//c10:__pkg__"], + visibility = ["//c10:__pkg__", "//:__pkg__"], ) diff --git a/c10/util/complex_math.h b/c10/util/complex_math.h index ecfd0442b751b..8709fe4a0eb55 100644 --- a/c10/util/complex_math.h +++ b/c10/util/complex_math.h @@ -291,6 +291,35 @@ C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { #endif } +template +C10_HOST_DEVICE inline c10::complex log1p(const c10::complex& z) { + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +} + } // namespace c10_complex_math using c10_complex_math::acos; @@ -304,6 +333,7 @@ using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; +using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; @@ -325,6 +355,7 @@ using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; +using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index c87305b08be57..e75c1980fdfa7 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -326,6 +326,9 @@ class intrusive_ptr final { intrusive_ptr() noexcept : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + intrusive_ptr(std::nullptr_t) noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + // This constructor will not increase the ref counter for you. // We use the tagged dispatch mechanism to explicitly mark this constructor // to not increase the refcount diff --git a/c10/util/irange.h b/c10/util/irange.h index e734688c81d6e..78cf94f25c2d8 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -15,8 +15,15 @@ namespace detail { template < typename I, + bool one_sided = false, typename std::enable_if::value, int>::type = 0> -struct integer_iterator : std::iterator { +struct integer_iterator { + using iterator_category = std::input_iterator_tag; + using value_type = I; + using difference_type = std::ptrdiff_t; + using pointer = I*; + using reference = I&; + explicit integer_iterator(I value) : value(value) {} I operator*() const { @@ -39,11 +46,19 @@ struct integer_iterator : std::iterator { } bool operator==(const integer_iterator& other) const { - return value == other.value; + if /* constexpr -- we don't have C++17 yet, see #85969 */ (one_sided) { + // Range-for loops' end test is `begin != end`, not `begin < + // end`. To handle `c10::irange(n)` where n < 0 (which should be + // empty), we just make `begin != end` fail whenever `end` is + // negative. + return other.value < 0 || value == other.value; + } else { + return value == other.value; + } } bool operator!=(const integer_iterator& other) const { - return value != other.value; + return !(*this == other); } protected: @@ -54,20 +69,22 @@ struct integer_iterator : std::iterator { template < typename I, + bool one_sided = false, typename std::enable_if::value, bool>::type = true> struct integer_range { public: integer_range(I begin, I end) : begin_(begin), end_(end) {} - detail::integer_iterator begin() const { + using iterator = detail::integer_iterator; + iterator begin() const { return begin_; } - detail::integer_iterator end() const { + iterator end() const { return end_; } private: - detail::integer_iterator begin_; - detail::integer_iterator end_; + iterator begin_; + iterator end_; }; /// Creates an integer range for the half-open interval [begin, end) @@ -95,11 +112,8 @@ template < typename Integer, typename std::enable_if::value, bool>::type = true> -integer_range irange(Integer end) { - // If end<=begin then the range is empty; we can achieve this effect by - // choosing the larger of {0, end} as the loop terminator - // Handles the case where end<0. irange only works for ranges >=0 - return {Integer(), std::max(Integer(), end)}; +integer_range irange(Integer end) { + return {Integer(), end}; } } // namespace c10 diff --git a/c10/util/logging_is_not_google_glog.h b/c10/util/logging_is_not_google_glog.h index d27cc18e45300..d92f163453e93 100644 --- a/c10/util/logging_is_not_google_glog.h +++ b/c10/util/logging_is_not_google_glog.h @@ -49,7 +49,7 @@ class C10_API MessageLogger { // is not used" and "statement has no effect". class C10_API LoggerVoidify { public: - LoggerVoidify() {} + LoggerVoidify() = default; // This has to be an operator with a precedence lower than << but // higher than ?: void operator&(const std::ostream& s) {} diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index 7eb9ed39395d8..e5c249dd1d2b7 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -22,7 +22,13 @@ C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { return __builtin_add_overflow(a, b, out); #else unsigned long long tmp; +#if defined(_M_IX86) || defined(_M_X64) auto carry = _addcarry_u64(0, a, b, &tmp); +#else + tmp = a + b; + unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); + auto carry = vector >> 63; +#endif *out = tmp; return carry; #endif diff --git a/c10/util/string_view.h b/c10/util/string_view.h index 0a4e043740b29..9ad4397d83775 100644 --- a/c10/util/string_view.h +++ b/c10/util/string_view.h @@ -179,7 +179,7 @@ class basic_string_view final { return size() == 0; } - CONSTEXPR_EXCEPT_GCC5 void remove_prefix(size_type n) { + constexpr void remove_prefix(size_type n) { if (n > size()) { throw std::out_of_range( "basic_string_view::remove_prefix: out of range. PrefixLength: " + @@ -189,7 +189,7 @@ class basic_string_view final { size_ -= n; } - CONSTEXPR_EXCEPT_GCC5 void remove_suffix(size_type n) { + constexpr void remove_suffix(size_type n) { if (n > size()) { throw std::out_of_range( "basic_string_view::remove_suffix: out of range. SuffixLength: " + @@ -198,7 +198,7 @@ class basic_string_view final { size_ -= n; } - CONSTEXPR_EXCEPT_GCC5 void swap(basic_string_view& sv) noexcept { + constexpr void swap(basic_string_view& sv) noexcept { auto tmp = *this; *this = sv; sv = tmp; @@ -694,7 +694,7 @@ inline std::basic_ostream& operator<<( } template -CONSTEXPR_EXCEPT_GCC5 inline void swap( +constexpr inline void swap( basic_string_view& lhs, basic_string_view& rhs) { lhs.swap(rhs); diff --git a/c2_defs.bzl b/c2_defs.bzl index d77fed977f39e..fedbb4bca84b7 100644 --- a/c2_defs.bzl +++ b/c2_defs.bzl @@ -5,7 +5,7 @@ load("@fbsource//tools/build_defs:default_platform_defs.bzl", "compose_platform_ load("@fbsource//tools/build_defs:dict_defs.bzl", "dict_defs") load("@fbsource//tools/build_defs:expect.bzl", "expect") load("@fbsource//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") -load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") +load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode", "is_fbcode_mode_mac") load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") load("@fbsource//tools/build_defs/apple:build_mode_defs.bzl", "is_production_build") load("@fbsource//tools/build_defs/apple:config_utils_defs.bzl", "STATIC_LIBRARY_IOS_CONFIG", "STATIC_LIBRARY_MAC_CONFIG", "fbobjc_configs") @@ -166,6 +166,7 @@ def get_c2_fbandroid_xplat_compiler_flags(): # T95767731 -- remove this once all builds are on at least llvm-13 "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", + "-DHAVE_MMAP", ] if get_c2_strip_glog(): @@ -220,26 +221,13 @@ def get_c2_fbobjc_ios_frameworks(): frameworks = [] if get_c2_mpscnn(): - frameworks.append( + frameworks.extend([ "$SDKROOT/System/Library/Frameworks/Metal.framework", - ) + "$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", + ]) return frameworks -def get_c2_fbobjc_linker_flags(): - flags = [] - - if get_c2_mpscnn(): - # Need linker flags as no platform_frameworks exist, and we can't - # use MPSCNN on x86_64. - # We use weak_framework as it's iOS 10 - flags = [ - "-L$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", - "-weak_framework", - "MetalPerformanceShaders", - ] - return flags - def get_c2_fbobjc_exported_preprocessor_flags(): flags = [] @@ -310,12 +298,6 @@ def get_c2_default_cxx_args(): STATIC_LIBRARY_IOS_CONFIG, extra_target_config = C2_FBOBJC_EXTRA_TARGET_CONFIG, ), - fbobjc_exported_platform_linker_flags = [ - ( - "iphoneos", - get_c2_fbobjc_linker_flags(), - ), - ], fbobjc_exported_platform_preprocessor_flags = [ ( "iphoneos", @@ -351,7 +333,10 @@ def get_c2_aten_cpu_fbobjc_macosx_deps(): "fbsource//xplat/caffe2:cpukernel_avx2", ] else: - return [] + return select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": ["fbsource//xplat/deeplearning/fbgemm:fbgemm"], + }) if is_arvr_mode() else [] def get_c2_aten_cpu_fbobjc_macosx_platform_deps(): if is_focus_enabled(): @@ -377,10 +362,19 @@ def get_c2_aten_cpu_fbobjc_macosx_platform_deps(): }, ]) +def using_protobuf_v3(): + # Consider migrating this to `read_config("protobuf", "use_v3")` + # The `is_fbcode_mode_mac()` clause was added rather than changing to `read_config` to minimize changes in behavior + return is_arvr_mode() or is_fbcode_mode_mac() + +def get_c2_protobuf_dep(): + return "fbsource//third-party/protobuf:libprotobuf" if using_protobuf_v3() else "fbsource//xplat/third-party/protobuf:fb-protobuf-lite" + def c2_cxx_library(**kwargs): args = get_c2_default_cxx_args() args.update(kwargs) args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS)) + fb_xplat_cxx_library( labels = [ "supermodule:android/default/caffe2", @@ -403,7 +397,7 @@ def c2_protobuf_rule(protos): protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && chmod +w $SRCDIR/{} && echo \"option optimize_for = LITE_RUNTIME;\" >> $SRCDIR/{} && ".format(p, proto, proto, proto) + "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && chmod +w $SRCDIR/caffe2.proto && echo \"option optimize_for = LITE_RUNTIME;\" >> $SRCDIR/caffe2.proto && " + "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if is_arvr_mode() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + + ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) buck_genrule( name = proto, @@ -450,7 +444,7 @@ def c2_full_protobuf_rule(protos): protocmd = ("cp $SRCDIR/{} $SRCDIR/{} && ".format(p, proto) + "cp $SRCDIR/caffe2/proto/caffe2.proto $SRCDIR/caffe2.proto && " + "sed -i -e 's/caffe2\\/proto\\/caffe2.proto/caffe2.proto/g' $SRCDIR/{} && ".format(proto) + - ("$(exe fbsource//third-party/protobuf:protoc-host) " if is_arvr_mode() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + + ("$(exe fbsource//third-party/protobuf:protoc-host) " if using_protobuf_v3() else "$(exe fbsource//xplat/third-party/protobuf:protoc) --osx $(location fbsource//xplat/third-party/protobuf:protoc.Darwin) --linux $(location fbsource//xplat/third-party/protobuf:protoc.Linux) ") + "-I $SRCDIR --cpp_out=$OUT $SRCDIR/{}".format(proto)) buck_genrule( name = prefix + proto, @@ -484,7 +478,7 @@ def libcaffe2_cxx_library(name, use_hptt, **kwargs): name = name, exported_deps = [ "fbsource//xplat/caffe2/c10:c10", - "fbsource//third-party/protobuf:libprotobuf" if is_arvr_mode() else "fbsource//xplat/third-party/protobuf:fb-protobuf-lite", + get_c2_protobuf_dep(), ":caffe2_protobuf_headers", ":pthreadpool", ":common_core", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 64d53de5a64bb..c1b7b65f2353e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -492,6 +492,7 @@ if(BUILD_LITE_INTERPRETER) set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) else() append_filelist("libtorch_cmake_sources" LIBTORCH_CMAKE_SRCS) + list(APPEND LIBTORCH_CMAKE_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) if(BUILD_LAZY_TS_BACKEND) append_filelist("lazy_tensor_ts_sources" LIBTORCH_CMAKE_SRCS) endif() @@ -564,6 +565,7 @@ if(NOT INTERN_DISABLE_MOBILE_INTERP) ${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/train/sequential.cpp ${TORCH_SRC_DIR}/csrc/jit/mobile/upgrader_mobile.cpp + ${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp ) list(APPEND TORCH_SRCS ${MOBILE_SRCS}) list(APPEND TORCH_SRCS ${LITE_EAGER_SYMOBLICATION_SRCS}) @@ -599,7 +601,6 @@ if(NOT INTERN_BUILD_MOBILE AND NOT BUILD_LITE_INTERPRETER) ${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp ${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp ${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp - ${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer_jit.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp @@ -883,10 +884,6 @@ file(WRITE ${DUMMY_EMPTY_FILE} ${DUMMY_FILE_CONTENT}) # Wrapper library for people who link against torch and expect both CPU and CUDA support # Contains "torch_cpu" and "torch_cuda" add_library(torch ${DUMMY_EMPTY_FILE}) -if(BUILD_SPLIT_CUDA) - # When we split torch_cuda, we want a dummy torch_cuda library that contains both parts - add_library(torch_cuda ${DUMMY_EMPTY_FILE}) -endif() if(HAVE_SOVERSION) set_target_properties(torch PROPERTIES VERSION ${TORCH_VERSION} SOVERSION ${TORCH_SOVERSION}) @@ -926,37 +923,19 @@ elseif(USE_CUDA) ${Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY}) set_property(TARGET torch_cuda_w_sort_by_key PROPERTY CUDA_SEPARABLE_COMPILATION OFF) target_link_libraries(torch_cuda PRIVATE torch_cuda_w_sort_by_key) - elseif(BUILD_SPLIT_CUDA) - add_library(torch_cuda_cpp ${Caffe2_GPU_SRCS} ${Caffe2_GPU_SRCS_W_SORT_BY_KEY}) - add_library(torch_cuda_cu ${Caffe2_GPU_CU_SRCS} ${Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY}) else() add_library(torch_cuda ${Caffe2_GPU_SRCS} ${Caffe2_GPU_SRCS_W_SORT_BY_KEY} ${Caffe2_GPU_CU_SRCS} ${Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY}) endif() set(CUDA_LINK_LIBRARIES_KEYWORD) - if(BUILD_SPLIT_CUDA) - torch_compile_options(torch_cuda_cpp) # see cmake/public/utils.cmake - torch_compile_options(torch_cuda_cu) # see cmake/public/utils.cmake - target_compile_definitions(torch_cuda_cpp PRIVATE BUILD_SPLIT_CUDA) - target_compile_definitions(torch_cuda_cpp PRIVATE USE_CUDA) - target_compile_definitions(torch_cuda_cu PRIVATE BUILD_SPLIT_CUDA) - target_compile_definitions(torch_cuda_cu PRIVATE USE_CUDA) - else() - torch_compile_options(torch_cuda) # see cmake/public/utils.cmake - target_compile_definitions(torch_cuda PRIVATE USE_CUDA) - endif() - if(USE_NCCL AND BUILD_SPLIT_CUDA) - target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_nccl) - target_compile_definitions(torch_cuda_cpp PRIVATE USE_NCCL) - elseif(USE_NCCL) + torch_compile_options(torch_cuda) # see cmake/public/utils.cmake + target_compile_definitions(torch_cuda PRIVATE USE_CUDA) + if(USE_NCCL) target_link_libraries(torch_cuda PRIVATE __caffe2_nccl) target_compile_definitions(torch_cuda PRIVATE USE_NCCL) endif() - if(USE_UCC AND BUILD_SPLIT_CUDA) - target_link_libraries(torch_cuda_cpp PRIVATE __caffe2_ucc) - target_compile_definitions(torch_cuda_cpp PRIVATE USE_UCC) - elseif(USE_UCC) + if(USE_UCC) target_link_libraries(torch_cuda PRIVATE __caffe2_ucc) target_compile_definitions(torch_cuda PRIVATE USE_UCC) endif() @@ -998,13 +977,8 @@ elseif(USE_CUDA) endif() if(USE_PRECOMPILED_HEADERS) - if(BUILD_SPLIT_CUDA) - target_precompile_headers(torch_cuda_cpp PRIVATE - "$<$:ATen/core/ATen_pch.h>") - else() - target_precompile_headers(torch_cuda PRIVATE - "$<$:ATen/core/ATen_pch.h>") - endif() + target_precompile_headers(torch_cuda PRIVATE + "$<$:ATen/core/ATen_pch.h>") endif() endif() @@ -1085,12 +1059,7 @@ if(NOT NO_API) ${TORCH_SRC_DIR}/csrc/api/include) endif() -if(BUILD_SPLIT_CUDA AND MSVC) - # -INCLUDE is used to ensure torch_cuda_cpp/cu are linked against in a project that relies on them. - target_link_libraries(torch_cuda_cpp INTERFACE "-INCLUDE:?warp_size@cuda@at@@YAHXZ") - # See [Note about _torch_cuda_cu_linker_symbol_op and torch_cuda_cu] in native_functions.yaml - target_link_libraries(torch_cuda_cu INTERFACE "-INCLUDE:?_torch_cuda_cu_linker_symbol_op_cuda@native@at@@YA?AVTensor@2@AEBV32@@Z") -elseif(USE_CUDA AND MSVC) +if(USE_CUDA AND MSVC) # -INCLUDE is used to ensure torch_cuda is linked against in a project that relies on them. # Related issue: https://github.com/pytorch/pytorch/issues/31611 target_link_libraries(torch_cuda INTERFACE "-INCLUDE:?warp_size@cuda@at@@YAHXZ") @@ -1320,27 +1289,16 @@ if(USE_DISTRIBUTED) if(USE_UCC AND USE_C10D_UCC) target_compile_definitions(torch_cpu PUBLIC USE_C10D_UCC) if(USE_CUDA) - if(BUILD_SPLIT_CUDA) - target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_UCC) - else() - target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) - endif() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_UCC) endif() endif() if(USE_NCCL AND USE_C10D_NCCL) if(USE_ROCM) target_compile_definitions(torch_hip PUBLIC USE_C10D_NCCL) else() - if(BUILD_SPLIT_CUDA) - target_compile_definitions(torch_cuda_cpp PUBLIC USE_C10D_NCCL) - if(USE_NCCL_WITH_UCC) - target_compile_definitions(torch_cuda_cpp PUBLIC USE_NCCL_WITH_UCC) - endif() - else() - target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) - if(USE_NCCL_WITH_UCC) - target_compile_definitions(torch_cuda PUBLIC USE_NCCL_WITH_UCC) - endif() + target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) + if(USE_NCCL_WITH_UCC) + target_compile_definitions(torch_cuda PUBLIC USE_NCCL_WITH_UCC) endif() endif() endif() @@ -1423,14 +1381,7 @@ torch_set_target_props(torch_cpu) target_compile_options(torch_cpu PRIVATE "-DCAFFE2_BUILD_MAIN_LIB") -if(BUILD_SPLIT_CUDA) - target_compile_options(torch_cuda_cu PRIVATE "-DTORCH_CUDA_CU_BUILD_MAIN_LIB") - target_compile_options(torch_cuda_cpp PRIVATE "-DTORCH_CUDA_CPP_BUILD_MAIN_LIB") - # NB: This must be target_compile_definitions, not target_compile_options, - # as the latter is not respected by nvcc - target_compile_definitions(torch_cuda_cu PRIVATE "-DTORCH_CUDA_CU_BUILD_MAIN_LIB") - target_compile_definitions(torch_cuda_cpp PRIVATE "-DTORCH_CUDA_CPP_BUILD_MAIN_LIB") -elseif(USE_CUDA) +if(USE_CUDA) target_compile_options(torch_cuda PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") # NB: This must be target_compile_definitions, not target_compile_options, # as the latter is not respected by nvcc @@ -1441,10 +1392,7 @@ elseif(USE_ROCM) endif() if(USE_EXPERIMENTAL_CUDNN_V8_API) - if(BUILD_SPLIT_CUDA) - target_compile_definitions(torch_cuda_cu PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API") - target_compile_definitions(torch_cuda_cpp PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API") - elseif(USE_CUDA) + if(USE_CUDA) target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API") endif() endif() @@ -1534,10 +1482,6 @@ caffe2_interface_library(torch_cpu torch_cpu_library) if(USE_CUDA) caffe2_interface_library(torch_cuda torch_cuda_library) - if(BUILD_SPLIT_CUDA) - caffe2_interface_library(torch_cuda_cu torch_cuda_cu_library) - caffe2_interface_library(torch_cuda_cpp torch_cuda_cpp_library) - endif() elseif(USE_ROCM) caffe2_interface_library(torch_hip torch_hip_library) endif() @@ -1548,10 +1492,6 @@ install(TARGETS torch_cpu torch_cpu_library EXPORT Caffe2Targets DESTINATION "${ if(USE_CUDA) install(TARGETS torch_cuda torch_cuda_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") - if(BUILD_SPLIT_CUDA) - install(TARGETS torch_cuda_cu torch_cuda_cu_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") - install(TARGETS torch_cuda_cpp torch_cuda_cpp_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") - endif() elseif(USE_ROCM) install(TARGETS torch_hip torch_hip_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() @@ -1561,11 +1501,6 @@ target_link_libraries(torch PUBLIC torch_cpu_library) if(USE_CUDA) target_link_libraries(torch PUBLIC torch_cuda_library) - if(BUILD_SPLIT_CUDA) - # NS: Library order is important here to prevent cudnn double linking - target_link_libraries(torch_cuda PUBLIC torch_cuda_cpp_library) - target_link_libraries(torch_cuda PUBLIC torch_cuda_cu_library) - endif() elseif(USE_ROCM) target_link_libraries(torch PUBLIC torch_hip_library) endif() @@ -1578,10 +1513,7 @@ endif() # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) - if(BUILD_SPLIT_CUDA) - install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) - install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) - elseif(USE_CUDA) + if(USE_CUDA) install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) elseif(USE_ROCM) install(FILES $ DESTINATION "${TORCH_INSTALL_LIB_DIR}" OPTIONAL) @@ -1589,36 +1521,13 @@ if(MSVC AND BUILD_SHARED_LIBS) endif() # ---[ CUDA library. -if(BUILD_SPLIT_CUDA) - target_link_libraries(torch_cuda_cu INTERFACE torch::cudart) - target_link_libraries(torch_cuda_cpp INTERFACE torch::cudart) - target_link_libraries(torch_cuda_cu PUBLIC c10_cuda torch::nvtoolsext) - target_link_libraries(torch_cuda_cpp PUBLIC c10_cuda torch::nvtoolsext) - - target_include_directories( - torch_cuda_cu INTERFACE $) - target_include_directories( - torch_cuda_cpp INTERFACE $) - target_include_directories( - torch_cuda_cu PRIVATE ${Caffe2_GPU_INCLUDE}) - target_include_directories( - torch_cuda_cpp PRIVATE ${Caffe2_GPU_INCLUDE}) - target_link_libraries( - torch_cuda_cu PRIVATE ${Caffe2_CUDA_DEPENDENCY_LIBS}) - target_link_libraries( - torch_cuda_cpp PRIVATE ${Caffe2_CUDA_DEPENDENCY_LIBS}) - target_link_libraries(torch_cuda_cu PRIVATE torch_cuda_cpp) - if(USE_CUDNN) - target_link_libraries( - torch_cuda_cpp PRIVATE caffe2::cudnn-private) +if(USE_CUDA) + # FIXME: If kineto is linked with CUPTI it pollutes torch_cpu with CUDA dependencies + # Even worse, it never declares that it depends on cudart, but calls the API, see + # https://github.com/pytorch/kineto/blob/aef2f5c0f15e3be52406ac0b885e8689de6bc9f6/libkineto/src/CudaDeviceProperties.cpp#L24 + if(USE_KINETO AND NOT MSVC) + target_link_libraries(torch_cpu PRIVATE torch::cudart) endif() - - # These public dependencies must go after the previous dependencies, as the - # order of the libraries in the linker call matters here when statically - # linking; libculibos and cublas must be last. - target_link_libraries(torch_cuda_cpp PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) - target_link_libraries(torch_cuda_cu PUBLIC torch_cpu_library ${Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS}) -elseif(USE_CUDA) target_link_libraries(torch_cuda INTERFACE torch::cudart) target_link_libraries(torch_cuda PUBLIC c10_cuda torch::nvtoolsext) diff --git a/caffe2/README.md b/caffe2/README.md index 0b69eec8191b8..13171fca23bb7 100644 --- a/caffe2/README.md +++ b/caffe2/README.md @@ -1,7 +1,5 @@ # Caffe2 -[![Jenkins Build Status](https://ci.pytorch.org/jenkins/job/caffe2-master/lastCompletedBuild/badge/icon)](https://ci.pytorch.org/jenkins/job/caffe2-master) - Caffe2 is a lightweight, modular, and scalable deep learning framework. Building on the original [Caffe](http://caffe.berkeleyvision.org), Caffe2 is designed with expression, speed, and modularity in mind. ## Questions and Feedback diff --git a/caffe2/contrib/tensorrt/README.md b/caffe2/contrib/tensorrt/README.md index f1e449e727e94..6ffe1dfb53bc6 100644 --- a/caffe2/contrib/tensorrt/README.md +++ b/caffe2/contrib/tensorrt/README.md @@ -15,4 +15,4 @@ For further information please explore `caffe2/python/trt/test_trt.py` test show ## Questions and Feedback -Please use Github issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. +Please use GitHub issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. diff --git a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc index ebe27ef38a199..f1414deca8caa 100644 --- a/caffe2/contrib/tensorrt/tensorrt_tranformer.cc +++ b/caffe2/contrib/tensorrt/tensorrt_tranformer.cc @@ -518,7 +518,7 @@ void TensorRTTransformer::Transform( return SubnetToTrtOp(net, &mapped_ws, &exporter2, &shape_hints); }; - auto cutResult = opt::OptimizeForBackend(*pred_net, supports, trt_converter) + auto cutResult = opt::OptimizeForBackend(*pred_net, supports, trt_converter); NetDef net_opt = std::move(cutResult.net); // Need to figure out a proper place to handle device option diff --git a/caffe2/core/context_gpu.cu b/caffe2/core/context_gpu.cu index bfa563ca6b8bb..3359e88bcba20 100644 --- a/caffe2/core/context_gpu.cu +++ b/caffe2/core/context_gpu.cu @@ -235,7 +235,7 @@ static void Caffe2InitializeCuda() { // a reserved flag for cudaDeviceEnablePeerAccess that should always be // zero currently. // It is ok if peer access is already enabled... - cudaError_t err = cudaDeviceEnablePeerAccess(j, 0); + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0)); if ((err != cudaErrorPeerAccessAlreadyEnabled) && (err != cudaSuccess)) { CAFFE_THROW(cudaGetErrorString(err)); @@ -351,7 +351,7 @@ struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator { CUDA_ENFORCE(cudaHostUnregister(data)); GetDefaultCPUAllocator()->raw_deleter()(data); } else { - cudaError_t err = cudaFreeHost(data); + cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaFreeHost(data)); profiledCPUMemoryReporter().Delete(data); if (err == cudaErrorInvalidValue) { free(data); @@ -561,12 +561,12 @@ struct DefaultCUDAAllocator final : public at::Allocator { // some models that are currently running with the thc // allocator fit in memory. We will need to find some // way of resolving this problem. - cuda::CUDAStreamGuard g( + c10::cuda::CUDAStreamGuard g( Stream( Stream::DEFAULT, Device(kCUDA, CaffeCudaGetDevice()) )); - ptr = cuda::CUDACachingAllocator::raw_alloc(nbytes); + ptr = c10::cuda::CUDACachingAllocator::raw_alloc(nbytes); } if (FLAGS_caffe2_gpu_memory_tracking) { g_size_map[ptr] = nbytes; @@ -598,7 +598,7 @@ struct DefaultCUDAAllocator final : public at::Allocator { switch (g_cuda_memory_pool_type) { case CudaMemoryPoolType::NONE: { // If memory pool is not set up, use simple cudaFree. - cudaError_t error = cudaFree(ptr); + cudaError_t error = C10_CUDA_ERROR_HANDLED(cudaFree(ptr)); // For some reason, in Python runtime we sometimes delete a data pointer // after the cuda runtime exits - this is odd but is probably caused by // a static workspace that pycaffe2 uses, and the destruction got @@ -625,7 +625,7 @@ struct DefaultCUDAAllocator final : public at::Allocator { break; } case CudaMemoryPoolType::THC: { - cuda::CUDACachingAllocator::raw_delete(ptr); + c10::cuda::CUDACachingAllocator::raw_delete(ptr); if (FLAGS_caffe2_gpu_memory_tracking) { g_cuda_device_affiliation.erase(g_cuda_device_affiliation.find(ptr)); } diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h index e411d9cd735f1..611b2550c7e0c 100644 --- a/caffe2/core/context_gpu.h +++ b/caffe2/core/context_gpu.h @@ -195,10 +195,6 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { // SwitchToDevice() void FinishDeviceComputation() override { CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_))); - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) { - CAFFE_THROW("Encountered CUDA error: ", cudaGetErrorString(error)); - } } inline int device_id() const { @@ -309,11 +305,13 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext { } static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) { - auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id); - auto status = cudaStreamQuery(stream); + const auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id); + const auto status = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream)); if (status == cudaErrorNotReady) { // ignore and clear the error if not ready - (void)cudaGetLastError(); + C10_CUDA_CLEAR_ERROR(); + } else { + C10_CUDA_CHECK(status); // Reraise error } return status == cudaSuccess; } diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 9c9f734575634..2d9f03e94c0fc 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -44,6 +44,7 @@ static_assert( #cmakedefine CAFFE2_USE_NVTX #cmakedefine CAFFE2_USE_ITT #cmakedefine CAFFE2_USE_TRT +#cmakedefine TORCH_DISABLE_GPU_ASSERTS #ifndef EIGEN_MPL2_ONLY #cmakedefine EIGEN_MPL2_ONLY @@ -85,4 +86,5 @@ static_assert( {"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \ {"USE_ITT", "${CAFFE2_USE_ITT}"}, \ {"USE_TRT", "${CAFFE2_USE_TRT}"}, \ + {"TORCH_DISABLE_GPU_ASSERTS", "${TORCH_DISABLE_GPU_ASSERTS}"}, \ } diff --git a/caffe2/core/nomnigraph/CMakeLists.txt b/caffe2/core/nomnigraph/CMakeLists.txt index c4d4216ef9e97..8980c52ddfb4f 100644 --- a/caffe2/core/nomnigraph/CMakeLists.txt +++ b/caffe2/core/nomnigraph/CMakeLists.txt @@ -18,5 +18,5 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) set(Caffe2_CPU_INCLUDE ${Caffe2_CPU_INCLUDE} PARENT_SCOPE) set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE) if(USE_TENSORRT) -set(Caffe2_GPU_INCLUDE ${Caffe2_CPU_INCLUDE} PARENT_SCOPE) +set(Caffe2_GPU_INCLUDE ${Caffe2_GPU_INCLUDE} PARENT_SCOPE) endif() diff --git a/caffe2/mobile/contrib/libopencl-stub/README.md b/caffe2/mobile/contrib/libopencl-stub/README.md index 20b3dafa8095f..835ba2172cbe8 100644 --- a/caffe2/mobile/contrib/libopencl-stub/README.md +++ b/caffe2/mobile/contrib/libopencl-stub/README.md @@ -1,7 +1,7 @@ libopencl-stub ============== -A stub opecl library that dynamically dlopen/dlsyms opencl implementations at runtime based on environment variables. Will be useful when opencl implementations are installed in non-standard paths (say pocl on android) +A stub opencl library that dynamically dlopen/dlsyms opencl implementations at runtime based on environment variables. Will be useful when opencl implementations are installed in non-standard paths (say pocl on android) diff --git a/caffe2/operators/batch_box_cox_op.cc b/caffe2/operators/batch_box_cox_op.cc index aa444330969b5..6e2bb4d9a8d9d 100644 --- a/caffe2/operators/batch_box_cox_op.cc +++ b/caffe2/operators/batch_box_cox_op.cc @@ -2,72 +2,34 @@ #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" - -#ifdef CAFFE2_USE_MKL -#include -#endif // CAFFE2_USE_MKL +#include "caffe2/perfkernels/batch_box_cox.h" namespace caffe2 { -#ifdef CAFFE2_USE_MKL namespace { - -// Helpers for copying parameters. template -void TileArrayIntoVector(const T* a, int D, int K, vector* b) { - b->resize(K * D); - for (int k = 0; k < K; k++) { - std::copy(a, a + D, b->begin() + k * D); - } -} - -void TileIndicesInPlace(vector* v, int D, int K) { - int n = v->size(); - v->resize(K * n); - for (int k = 1; k < K; k++) { - for (int j = 0; j < n; j++) { - (*v)[k * n + j] = (*v)[j] + k * D; +void BoxCoxNaive( + int64_t N, + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T* output_ptr) { + constexpr T k_eps = static_cast(1e-6); + for (int64_t i = 0; i < N; i++) { + for (int64_t j = 0; j < D; j++, data_ptr++, output_ptr++) { + T lambda1_v = lambda1_ptr[j]; + T lambda2_v = lambda2_ptr[j]; + T tmp = std::max(*data_ptr + lambda2_v, k_eps); + if (lambda1_v == 0) { + *output_ptr = std::log(tmp); + } else { + *output_ptr = (std::pow(tmp, lambda1_v) - 1) / lambda1_v; + } } } } - -// MKL VML function templates. -template -void PackV(const int N, const T* a, const int* ia, T* y); -template -void UnpackV(const int N, const T* a, T* y, const int* iy); -template -void Pow(const int N, const T* a, const T* b, T* y); - -#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void PackV(const int N, const T* a, const int* ia, T* y) { \ - OriginalFunc(N, a, ia, y); \ - } -DELEGATE_PACKV_FUNCTION(float, vsPackV) -DELEGATE_PACKV_FUNCTION(double, vdPackV) -#undef DELEGATE_PACKV_FUNCTION - -#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void UnpackV(const int N, const T* a, T* y, const int* iy) { \ - OriginalFunc(N, a, y, iy); \ - } -DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) -DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) -#undef DELEGATE_UNPACKV_FUNCTION - -#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, const T* b, T* y) { \ - OriginalFunc(N, a, b, y); \ - } -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Pow, vsPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Pow, vdPow) -#undef DELEGATE_SIMPLE_BINARY_FUNCTION - -} // namespace -#endif // CAFFE2_USE_MKL +} template <> template @@ -93,227 +55,19 @@ bool BatchBoxCoxOp::DoRunWithType() { const auto* lambda1_ptr = lambda1.template data(); const auto* lambda2_ptr = lambda2.template data(); - const T k_eps = static_cast(1e-6); - #ifdef CAFFE2_USE_MKL if (min_block_size_ < 1) { - BoxCoxNaive(N, D, data_ptr, lambda1_ptr, lambda2_ptr, k_eps, output_ptr); - } else { - // Find zero-valued columns, since they get special treatment. - nonzeros_.clear(); - zeros_.clear(); - nonzeros_.reserve(D); - zeros_.reserve(D); - for (int64_t j = 0; j < D; j++) { - if (lambda1_ptr[j] == 0) { - zeros_.push_back(j); - } else { - nonzeros_.push_back(j); - } - } - - // Process K rows at a time for effective vectorization with small rows. - const int K = std::min(N, (min_block_size_ + D - 1) / D); - - // Avoid copying data if all lambda1 values are zero, or if all are nonzero. - // In each of the three cases here, when K > 1, first process batches of K - // rows by replicating the input parameters K times. Then finish row-by-row. - TypedCachedBuffers& b = GetBuffers(); - if (nonzeros_.size() == D) { - int64_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda1_ptr, D, K, &b.lambda1_); - TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_); - TORCH_DCHECK_EQ(K * D, b.lambda1_.size()); - TORCH_DCHECK_EQ(K * D, b.lambda2_.size()); - for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { - BoxCoxNonzeroLambda( - K * D, - data_ptr, - b.lambda1_.data(), - b.lambda2_.data(), - k_eps, - output_ptr); - } - } - for (; i < N; i++, data_ptr += D, output_ptr += D) { - BoxCoxNonzeroLambda( - D, data_ptr, lambda1_ptr, lambda2_ptr, k_eps, output_ptr); - } - } else if (zeros_.size() == D) { - int64_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda2_ptr, D, K, &b.lambda2_z_); - TORCH_DCHECK_EQ(K * D, b.lambda2_z_.size()); - for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { - BoxCoxZeroLambda( - K * D, data_ptr, b.lambda2_z_.data(), k_eps, output_ptr); - } - } - for (; i < N; i++, data_ptr += D, output_ptr += D) { - BoxCoxZeroLambda(D, data_ptr, lambda2_ptr, k_eps, output_ptr); - } - } else { // General case of mixed zero and non-zero lambda1 values. - int n = nonzeros_.size(); - if (K > 1) { - TileIndicesInPlace(&nonzeros_, 0, K); - TileIndicesInPlace(&zeros_, 0, K); - } - - // Gather parameter values into contiguous memory. - b.lambda1_.resize(nonzeros_.size()); - b.lambda2_.resize(nonzeros_.size()); - b.lambda2_z_.resize(zeros_.size()); - PackV(nonzeros_.size(), lambda1_ptr, nonzeros_.data(), b.lambda1_.data()); - PackV(nonzeros_.size(), lambda2_ptr, nonzeros_.data(), b.lambda2_.data()); - PackV(zeros_.size(), lambda2_ptr, zeros_.data(), b.lambda2_z_.data()); - - int64_t i = 0; - b.accumulator_.resize(std::max(nonzeros_.size(), zeros_.size())); - if (K > 1) { - // Truncate to original size, and re-tile with offsets this time. - nonzeros_.resize(n); - zeros_.resize(D - n); - TileIndicesInPlace(&nonzeros_, D, K); - TileIndicesInPlace(&zeros_, D, K); - TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda1_.size()); - TORCH_DCHECK_EQ(nonzeros_.size(), b.lambda2_.size()); - TORCH_DCHECK_EQ(zeros_.size(), b.lambda2_z_.size()); - for (; i < N - K + 1; i += K, data_ptr += K * D, output_ptr += K * D) { - BoxCoxMixedLambda( - data_ptr, - nonzeros_, - zeros_, - b.lambda1_.data(), - b.lambda2_.data(), - b.lambda2_z_.data(), - k_eps, - b.accumulator_.data(), - output_ptr); - } - // Truncate to original size. - nonzeros_.resize(n); - zeros_.resize(D - n); - } - for (; i < N; i++, data_ptr += D, output_ptr += D) { - BoxCoxMixedLambda( - data_ptr, - nonzeros_, - zeros_, - b.lambda1_.data(), - b.lambda2_.data(), - b.lambda2_z_.data(), - k_eps, - b.accumulator_.data(), - output_ptr); - } - } + BoxCoxNaive(N, D, data_ptr, lambda1_ptr, lambda2_ptr, output_ptr); + return true; } -#else // CAFFE2_USE_MKL - BoxCoxNaive(N, D, data_ptr, lambda1_ptr, lambda2_ptr, k_eps, output_ptr); -#endif // CAFFE2_USE_MKL + caffe2::compute_batch_box_cox( + N, D, min_block_size_, data_ptr, lambda1_ptr, lambda2_ptr, output_ptr); +#else + BoxCoxNaive(N, D, data_ptr, lambda1_ptr, lambda2_ptr, output_ptr); +#endif return true; } -template <> -template -void BatchBoxCoxOp::BoxCoxNaive( - int64_t N, - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* output_ptr) { - for (int64_t i = 0; i < N; i++) { - for (int64_t j = 0; j < D; j++, data_ptr++, output_ptr++) { - T lambda1_v = lambda1_ptr[j]; - T lambda2_v = lambda2_ptr[j]; - T tmp = std::max(*data_ptr + lambda2_v, k_eps); - if (lambda1_v == 0) { - *output_ptr = std::log(tmp); - } else { - *output_ptr = (std::pow(tmp, lambda1_v) - 1) / lambda1_v; - } - } - } -} - -#ifdef CAFFE2_USE_MKL - -template <> -template -void BatchBoxCoxOp::BoxCoxNonzeroLambda( - int64_t D, - const T* data_ptr, - const T* lambda1, - const T* lambda2, - T k_eps, - T* out) { - caffe2::math::Add(D, data_ptr, lambda2, out, &context_); - for (int64_t j = 0; j < D; j++) { - out[j] = std::max(out[j], k_eps); - } - Pow(D, out, lambda1, out); - for (int64_t j = 0; j < D; j++) { - out[j] -= 1.0; - } - caffe2::math::Div(D, out, lambda1, out, &context_); -} - -template <> -template -void BatchBoxCoxOp::BoxCoxZeroLambda( - int64_t D, - const T* data_ptr, - const T* lambda2, - T k_eps, - T* output_ptr) { - caffe2::math::Add(D, data_ptr, lambda2, output_ptr, &context_); - for (int64_t j = 0; j < D; j++) { - output_ptr[j] = std::max(output_ptr[j], k_eps); - } - caffe2::math::Log(D, output_ptr, output_ptr, &context_); -} - -template <> -template -void BatchBoxCoxOp::BoxCoxMixedLambda( - const T* data_ptr, - const vector& nonzeros, - const vector& zeros, - const T* lambda1, - const T* lambda2, - const T* lambda2_z, - T k_eps, - T* buffer, - T* output_ptr) { - PackV(nonzeros.size(), data_ptr, nonzeros.data(), buffer); - BoxCoxNonzeroLambda(nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); - UnpackV(nonzeros.size(), buffer, output_ptr, nonzeros.data()); - - PackV(zeros.size(), data_ptr, zeros.data(), buffer); - BoxCoxZeroLambda(zeros.size(), buffer, lambda2_z, k_eps, buffer); - UnpackV(zeros.size(), buffer, output_ptr, zeros.data()); -} - -// Helpers to access cached buffers. -#define DEFINE_CACHED_BUFFERS(T, tag) \ - template <> \ - template <> \ - BatchBoxCoxOp::TypedCachedBuffers& \ - BatchBoxCoxOp::GetBuffers() { \ - if (!buffers_ || buffers_->type_ != tag) { \ - buffers_.reset(new BatchBoxCoxOp::TypedCachedBuffers()); \ - buffers_->type_ = tag; \ - } \ - return *static_cast*>(buffers_.get()); \ - } -DEFINE_CACHED_BUFFERS(float, 1); -DEFINE_CACHED_BUFFERS(double, 2); -#undef DEFINE_CACHED_BUFFERS - -#endif // CAFFE2_USE_MKL namespace { diff --git a/caffe2/operators/batch_box_cox_op.h b/caffe2/operators/batch_box_cox_op.h index baa9c955b6cac..a177131e9adee 100644 --- a/caffe2/operators/batch_box_cox_op.h +++ b/caffe2/operators/batch_box_cox_op.h @@ -29,65 +29,7 @@ class BatchBoxCoxOp final : public Operator { bool DoRunWithType(); protected: - template - void BoxCoxNaive( - int64_t N, - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* output_ptr); - -#ifdef CAFFE2_USE_MKL - template - void BoxCoxNonzeroLambda( - int64_t D, - const T* data_ptr, - const T* lambda1, - const T* lambda2, - T k_eps, - T* output_ptr); - - template - void BoxCoxZeroLambda( - int64_t D, - const T* data_ptr, - const T* lambda2, - T k_eps, - T* output_ptr); - - template - void BoxCoxMixedLambda( - const T* data_ptr, - const vector& nonzeros, - const vector& zeros, - const T* lambda1, - const T* lambda2, - const T* lambda2_z, - T k_eps, - T* buffer, - T* output_ptr); - - vector nonzeros_, zeros_; - - // Buffers used by the MKL version are cached across calls. - struct CachedBuffers { - virtual ~CachedBuffers() {} - int type_; - }; - template - struct TypedCachedBuffers : public CachedBuffers { - vector lambda1_, lambda2_, lambda2_z_; - vector accumulator_; - }; - template - TypedCachedBuffers& GetBuffers(); - unique_ptr buffers_; - -#endif // CAFFE2_USE_MKL - - int min_block_size_; + std::size_t min_block_size_; INPUT_TAGS(DATA, LAMBDA1, LAMBDA2); }; diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu index 9776266154cf3..aac9b3b81db26 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_gpu.cu +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu.cu @@ -145,15 +145,15 @@ void nms_gpu_upright( // Overlapping CPU computes and D2H memcpy // both take about the same time cudaEvent_t copy_done; - cudaEventCreate(©_done); + C10_CUDA_CHECK(cudaEventCreate(©_done)); int nto_copy = std::min(CHUNK_SIZE, N); - CUDA_CHECK(cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( &h_delete_mask[0], &d_delete_mask[0], nto_copy * mask_ld * sizeof(int), cudaMemcpyDeviceToHost, context->cuda_stream())); - CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); + C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); int offset = 0; std::vector h_keep_sorted_list; std::vector rmv(mask_ld, 0); @@ -162,7 +162,7 @@ void nms_gpu_upright( int next_offset = offset + ncopied; nto_copy = std::min(CHUNK_SIZE, N - next_offset); if (nto_copy > 0) { - CUDA_CHECK(cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( &h_delete_mask[next_offset * mask_ld], &d_delete_mask[next_offset * mask_ld], nto_copy * mask_ld * sizeof(int), @@ -170,9 +170,10 @@ void nms_gpu_upright( context->cuda_stream())); } // Waiting for previous copy - CUDA_CHECK(cudaEventSynchronize(copy_done)); - if (nto_copy > 0) - cudaEventRecord(copy_done, context->cuda_stream()); + C10_CUDA_CHECK(cudaEventSynchronize(copy_done)); + if (nto_copy > 0){ + C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); + } for (int i = offset; i < next_offset; ++i) { int iblock = i / BOXES_PER_THREAD; int inblock = i % BOXES_PER_THREAD; @@ -186,15 +187,15 @@ void nms_gpu_upright( } offset = next_offset; } - cudaEventDestroy(copy_done); + C10_CUDA_CHECK(cudaEventDestroy(copy_done)); const int nkeep = h_keep_sorted_list.size(); - cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( d_keep_sorted_list, &h_keep_sorted_list[0], nkeep * sizeof(int), cudaMemcpyHostToDevice, - context->cuda_stream()); + context->cuda_stream())); *h_nkeep = nkeep; } @@ -502,15 +503,15 @@ void nms_gpu_rotated( // Overlapping CPU computes and D2H memcpy // both take about the same time cudaEvent_t copy_done; - cudaEventCreate(©_done); + C10_CUDA_CHECK(cudaEventCreate(©_done)); int nto_copy = std::min(CHUNK_SIZE, N); - CUDA_CHECK(cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( &h_delete_mask[0], &d_delete_mask[0], nto_copy * mask_ld * sizeof(int), cudaMemcpyDeviceToHost, context->cuda_stream())); - CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); + C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); int offset = 0; std::vector h_keep_sorted_list; std::vector rmv(mask_ld, 0); @@ -519,7 +520,7 @@ void nms_gpu_rotated( int next_offset = offset + ncopied; nto_copy = std::min(CHUNK_SIZE, N - next_offset); if (nto_copy > 0) { - CUDA_CHECK(cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( &h_delete_mask[next_offset * mask_ld], &d_delete_mask[next_offset * mask_ld], nto_copy * mask_ld * sizeof(int), @@ -527,9 +528,10 @@ void nms_gpu_rotated( context->cuda_stream())); } // Waiting for previous copy - CUDA_CHECK(cudaEventSynchronize(copy_done)); - if (nto_copy > 0) - cudaEventRecord(copy_done, context->cuda_stream()); + C10_CUDA_CHECK(cudaEventSynchronize(copy_done)); + if (nto_copy > 0){ + C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream())); + } for (int i = offset; i < next_offset; ++i) { int iblock = i / BOXES_PER_THREAD; int inblock = i % BOXES_PER_THREAD; @@ -543,15 +545,15 @@ void nms_gpu_rotated( } offset = next_offset; } - cudaEventDestroy(copy_done); + C10_CUDA_CHECK(cudaEventDestroy(copy_done)); const int nkeep = h_keep_sorted_list.size(); - cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( d_keep_sorted_list, &h_keep_sorted_list[0], nkeep * sizeof(int), cudaMemcpyHostToDevice, - context->cuda_stream()); + context->cuda_stream())); *h_nkeep = nkeep; } diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc index 6c8283b3d0fe4..ea656dd30e3b9 100644 --- a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc +++ b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc @@ -691,7 +691,7 @@ TEST(UtilsNMSTest, TestPerfRotatedNMS) { // list_nitems * sizeof(int), // cudaMemcpyDeviceToHost, // cuda_context.cuda_stream())); -// CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream())); +// CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()); // ASSERT_EQ(keep.size(), gpu_keep.size()); // std::sort(keep.begin(), keep.end()); diff --git a/caffe2/operators/rnn/recurrent_network_executor_gpu.cc b/caffe2/operators/rnn/recurrent_network_executor_gpu.cc index ef041959742ac..0356218c717f4 100644 --- a/caffe2/operators/rnn/recurrent_network_executor_gpu.cc +++ b/caffe2/operators/rnn/recurrent_network_executor_gpu.cc @@ -130,8 +130,7 @@ void CUDARecurrentNetworkExecutor::_ExecRange(int from, int to) { for (int stream_id = 0; stream_id <= std::min(stream_seq, max_streams - 1); stream_id++) { VLOG(1) << "Wait for stream:" << stream_id; - CUDA_CHECK( - cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id))); + CUDA_CHECK(cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id))); } } diff --git a/caffe2/operators/scale_blobs_op.cu b/caffe2/operators/scale_blobs_op.cu index 01421fb822c6f..7305fddece96f 100644 --- a/caffe2/operators/scale_blobs_op.cu +++ b/caffe2/operators/scale_blobs_op.cu @@ -138,9 +138,9 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp); } } } - cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize); - cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize, - cudaMemcpyHostToDevice); + C10_CUDA_CHECK(cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize)); + C10_CUDA_CHECK(cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize, + cudaMemcpyHostToDevice)); // ScaleBlobsCUDAKernelBalanced kernel launch ScaleBlobsCUDAKernelBalanced @@ -150,7 +150,7 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp); dOutputArr); C10_CUDA_KERNEL_LAUNCH_CHECK(); - cudaFree(dStartCoorArr); + C10_CUDA_CHECK(cudaFree(dStartCoorArr)); */ template diff --git a/caffe2/operators/segment_reduction_op_gpu.cu b/caffe2/operators/segment_reduction_op_gpu.cu index 7253df677025b..6985c3c3378b4 100644 --- a/caffe2/operators/segment_reduction_op_gpu.cu +++ b/caffe2/operators/segment_reduction_op_gpu.cu @@ -493,7 +493,7 @@ class CUDASparseLengthsSumOp : public Operator { enum { DATA = 0, INDICES = 1, LENGTHS = 1 + (SparseFused ? 1 : 0) }; private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -632,7 +632,7 @@ class CUDASparseLengthsMeanOp : public Operator { enum { DATA = 0, INDICES = 1, LENGTHS = 1 + (SparseFused ? 1 : 0) }; private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -765,7 +765,7 @@ class CUDASparseLengthsMaxOp : public Operator { enum { INDICES = 1, LENGTHS = 1 + (SparseFused ? 1 : 0) }; private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -861,7 +861,7 @@ class CUDASparseLengthsWeightedSumOp : public Operator { enum { DATA = 0, WEIGHTS = 1, INDICES = 2, LENGTHS = 3 }; private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -1356,7 +1356,7 @@ class CUDASparseLengthsSumGradientWithIndicesOp : public Operator { } private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -1437,7 +1437,7 @@ class CUDASparseLengthsMeanGradientWithIndicesOp } private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -1526,7 +1526,7 @@ class CUDASparseLengthsWeightedSumGradientWithIndicesOp } private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -1666,7 +1666,7 @@ class CUDALengthsMaxWithMainInputAndForwardOutputGradientOp } private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; @@ -1793,7 +1793,7 @@ class CUDASparseLengthsIndicesInGradientWeightedSumWithMainInputGradientOp } private: - // menber field to manage memory + // member field to manage memory Tensor inclusive_scan_buffer_{CUDA}; Tensor inclusive_scan_length_buffer_{CUDA}; }; diff --git a/caffe2/perfkernels/batch_box_cox.cc b/caffe2/perfkernels/batch_box_cox.cc new file mode 100644 index 0000000000000..3e840d8fa04d3 --- /dev/null +++ b/caffe2/perfkernels/batch_box_cox.cc @@ -0,0 +1,113 @@ +#include "caffe2/perfkernels/common.h" + +#include +#include +#include + +namespace caffe2 { + +namespace { +template +void BoxCoxNaive( + std::size_t N, + std::size_t D, + const T* data_ptr, + const T* __restrict lambda1_ptr, + const T* __restrict lambda2_ptr, + T* output_ptr) { + constexpr T k_eps = static_cast(1e-6); + + for (int64_t i = 0; i < N; i++) { + for (int64_t j = 0; j < D; j++, data_ptr++, output_ptr++) { + T lambda1_v = lambda1_ptr[j]; + T lambda2_v = lambda2_ptr[j]; + T tmp = std::max(*data_ptr + lambda2_v, k_eps); + if (lambda1_v == 0) { + *output_ptr = std::log(tmp); + } else { + T lambda_1 = 1 / lambda1_v; + T pow = std::pow(tmp, lambda1_v); + *output_ptr = lambda_1 * pow - lambda_1; + } + } + } + +} +} + +#if defined(CAFFE2_PERF_WITH_AVX2) && defined(CAFFE2_PERF_USE_MKL) +namespace details { +template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* data_ptr, + const T* __restrict lambda1_ptr, + const T* __restrict lambda2_ptr, + T* output_ptr); + +extern template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const float* self_data, + const float* __restrict lambda1_data, + const float* __restrict lambda2_data, + float* output_data); + +extern template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const double* self_data, + const double* __restrict lambda1_data, + const double* __restrict lambda2_data, + double* output_data); +} // namespace detail +#endif + +template +void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* data, + const T* lambda1_data, + const T* lambda2_data, + T* output_data) { +#ifdef CAFFE2_PERF_WITH_AVX2 + AVX2_FMA_DO( + details::compute_batch_box_cox, + N, + D, + block_size, + data, + lambda1_data, + lambda2_data, + output_data); +#endif + BoxCoxNaive(N, D, data, lambda1_data, lambda2_data, output_data); +} + +template void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const float* data, + const float* lambda1_data, + const float* lambda2_data, + float* output_data); + +template void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const double* data, + const double* lambda1_data, + const double* lambda2_data, + double* output_data); + +} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox.h b/caffe2/perfkernels/batch_box_cox.h new file mode 100644 index 0000000000000..60c973bbf8ea1 --- /dev/null +++ b/caffe2/perfkernels/batch_box_cox.h @@ -0,0 +1,35 @@ +// Impmenets BoxCox operator for CPU +#pragma once +#include + +namespace caffe2 { + +template +void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* self_data, + const T* lambda1_data, + const T* lambda2_data, + T* output_data); + +extern template void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const float* data, + const float* lambda1_data, + const float* lambda2_data, + float* output_data); + +extern template void compute_batch_box_cox( + std::size_t N, + std::size_t D, + std::size_t block_size, + const double* data, + const double* lambda1_data, + const double* lambda2_data, + double* output_data); + +} // namespace caffe2 diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc new file mode 100644 index 0000000000000..6171b5bfd0326 --- /dev/null +++ b/caffe2/perfkernels/batch_box_cox_avx2.cc @@ -0,0 +1,399 @@ +#include +#ifdef CAFFE2_PERF_USE_MKL +#include +#include +#include + +#include "vectorizer.h" + +// Enable compiler vectorized version only if numerical consistency is not +// required between dev and opt versions - disabled for now +#ifndef FAST_VECTORIZED_KERNEL +#define CPU_CAPABILITY_AVX2 +#include + +namespace at::vec { + +// Implements the vectorized version of std::max() operation, +// which DOESNOT propagates NaN for second argument +template +Vectorized max(const Vectorized& a, const Vectorized& b); + +template <> +Vectorized max(const Vectorized& a, const Vectorized& b) { + // std::max(NaN, nonNan) -> NaN + return _mm256_max_pd(b, a); +} + +template <> +Vectorized max(const Vectorized& a, const Vectorized& b) { + // std::max(NaN, nonNan) -> NaN + return _mm256_max_ps(b, a); +} + +// Implements recieprocal method based on newton-rapson method +// 1. user RCP approximiation +// 2. update with RCP = RCP * (2 - X * RCP) +template +Vectorized fast_recieprocal(const Vectorized& b); +template +scalar_t fast_recieprocal(scalar_t b); + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + auto minus2 = _mm256_set1_ps(-2.f); + auto rcp = _mm256_rcp_ps(b); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + return rcp; +} + +template <> +float fast_recieprocal(float b) { + auto minus2 = _mm_set_ss(-2.f); + auto b_reg = _mm_set_ss(b); + auto rcp = _mm_rcp_ss(b_reg); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + return _mm_cvtss_f32(rcp); +} + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + return b.reciprocal(); +} + +template <> +double fast_recieprocal(double b) { + return 1./b; +} + +} +#endif + +#include +#include +#include + +#include + +namespace caffe2::details { + +// MKL VML function templates. +template +void PackV(const int N, const T* a, const int* ia, T* y); +template +void UnpackV(const int N, const T* a, T* y, const int* iy); + +#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void PackV(const int N, const T* a, const int* ia, T* y) { \ + OriginalFunc(N, a, ia, y); \ + } +DELEGATE_PACKV_FUNCTION(float, vsPackV) +DELEGATE_PACKV_FUNCTION(double, vdPackV) +#undef DELEGATE_PACKV_FUNCTION + +#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ + template <> \ + void UnpackV(const int N, const T* a, T* y, const int* iy) { \ + OriginalFunc(N, a, y, iy); \ + } +DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) +DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) +#undef DELEGATE_UNPACKV_FUNCTION + +#ifndef FAST_VECTORIZED_KERNEL +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(self_data + j); + auto lambda2 = Vec::loadu(lambda2_data + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto res = max.log(); + res.store(output_data + j); + } + for ( ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(data_ptr + j); + auto lambda2 = Vec::loadu(lambda2_ptr + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto lambda1 = Vec::loadu(lambda1_ptr + j); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); + auto pow = max.pow(lambda1); + auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); + res.store(out + j); + } + for ( ;j < D; ++j) { + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); + auto pow = std::pow(max, lambda1_ptr[j]); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#else +template +void box_cox_zero_lambda( + size_t D, + const T* const self_data, + const T* const lambda2_data, + T k_eps, + T* const output_data) { + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); + } +} + +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + FAST_MATH + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#endif + +template +void box_cox_mixed_lambda( + const T* const self_data, + const std::vector& nonzeros, + const std::vector& zeros, + const T* const lambda1, + const T* const lambda2, + const T* const lambda2_z_, + T k_eps, + T* const buffer, + T* const output_data) { + PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); + box_cox_nonzero_lambda( + nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); + UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); + + PackV(zeros.size(), self_data, zeros.data(), buffer); + box_cox_zero_lambda( + zeros.size(), buffer, lambda2_z_, k_eps, buffer); + UnpackV(zeros.size(), buffer, output_data, zeros.data()); +} + +template +void TileArrayIntoVector( + const T* const a, + const size_t D, + const int K, + std::vector& b) { + b.resize(K * D); + for (const auto k : c10::irange(K)) { + std::copy(a, a + D, b.begin() + k * D); + } +} + +void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { + auto n = v.size(); + v.resize(K * n); + for (const auto k : c10::irange(1, K)) { + for (const auto j : c10::irange(n)) { + v[k * n + j] = v[j] + k * D; + } + } +} + +template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const T* self_data, + const T* __restrict lambda1_data, + const T* __restrict lambda2_data, + T* output_data) { + constexpr T k_eps = static_cast(1e-6); + + FOLLY_DECLARE_REUSED(zeros, std::vector); + FOLLY_DECLARE_REUSED(nonzeros, std::vector); + // Don't bother calling reserve; calls after the first will get a + // correctly-sized allocation anyway. + for (const auto j : c10::irange(D)) { + if (lambda1_data[j] == 0) { + zeros.push_back(j); + } else { + nonzeros.push_back(j); + } + } + + // Process K rows at a time for effective vectorization with small rows. + const auto K = std::min(N, (block_size + D - 1) / D); + + FOLLY_DECLARE_REUSED(lambda1_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_, std::vector); + FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); + + if (nonzeros.size() == D) { + // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda1_data, D, K, lambda1_); + TileArrayIntoVector(lambda2_data, D, K, lambda2_); + DCHECK_EQ(K * D, lambda1_.size()); + DCHECK_EQ(K * D, lambda2_.size()); + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_nonzero_lambda( + K * D, + self_data, + lambda1_.data(), + lambda2_.data(), + k_eps, + output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_nonzero_lambda( + D, self_data, lambda1_data, lambda2_data, k_eps, output_data); + } + } else if (zeros.size() == D) { + // ln(x + lambda2), if lambda1 == 0 + size_t i = 0; + if (K > 1) { + TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); + DCHECK_EQ(K * D, lambda2_z_.size()); + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_zero_lambda( + K * D, self_data, lambda2_z_.data(), k_eps, output_data); + } + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_zero_lambda( + D, self_data, lambda2_data, k_eps, output_data); + } + } else { + // mix zeros and nonzeros + const size_t n = nonzeros.size(); + if (K > 1) { + TileIndicesInPlace(nonzeros, 0, K); + TileIndicesInPlace(zeros, 0, K); + } + + FOLLY_DECLARE_REUSED(buffer, std::vector); + + buffer.resize(std::max(nonzeros.size(), zeros.size())); + lambda1_.resize(nonzeros.size()); + lambda2_.resize(nonzeros.size()); + lambda2_z_.resize(zeros.size()); + PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); + PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); + PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); + + size_t i = 0; + if (K > 1) { + // Truncate to original size, and re-tile with offsets this time. + nonzeros.resize(n); + DCHECK_GT(D, n); + zeros.resize(D - n); + TileIndicesInPlace(nonzeros, D, K); + TileIndicesInPlace(zeros, D, K); + DCHECK_EQ(nonzeros.size(), lambda1_.size()); + DCHECK_EQ(nonzeros.size(), lambda2_.size()); + DCHECK_EQ(zeros.size(), lambda2_z_.size()); + + for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + // Truncate to original size. + nonzeros.resize(n); + zeros.resize(D - n); + } + for (; i < N; i++, self_data += D, output_data += D) { + box_cox_mixed_lambda( + self_data, + nonzeros, + zeros, + lambda1_.data(), + lambda2_.data(), + lambda2_z_.data(), + k_eps, + buffer.data(), + output_data); + } + } +}; + + +template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const float* self_data, + const float* __restrict lambda1_data, + const float* __restrict lambda2_data, + float* output_data); + +template +void compute_batch_box_cox__avx2_fma( + std::size_t N, + std::size_t D, + std::size_t block_size, + const double* self_data, + const double* __restrict lambda1_data, + const double* __restrict lambda2_data, + double* output_data); + +} // namespace caffe2::detail +#endif diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index fb960dbe5dc3c..6fed9e1d6d06c 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -62,7 +62,10 @@ In foo.cc, do: #pragma once +#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \ + || defined(CAFFE2_PERF_WITH_AVX) #include +#endif // DO macros: these should be used in your entry function, similar to foo() // above, that routes implementations based on CPU capability. diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 2c9900b73e06b..48c869ee70381 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -1,5 +1,6 @@ #include "caffe2/perfkernels/embedding_lookup_idx.h" +#include #include #include #include "caffe2/core/common.h" @@ -214,6 +215,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false); EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false); EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false); EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false); +EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false); +EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false); EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false); EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false); @@ -221,6 +224,8 @@ EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true); EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true); EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true); EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true); +EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true); +EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true); EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true); EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true); diff --git a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc index 674af836ba10b..3ed48a1c52322 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_avx2.cc @@ -6,6 +6,7 @@ //// -------------------------- #include +#include #include namespace caffe2 { @@ -341,6 +342,7 @@ static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma( } } else { // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -471,6 +473,7 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { @@ -511,7 +514,9 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -626,7 +631,9 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -701,7 +708,9 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -756,7 +765,9 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -780,6 +791,7 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( } } else { // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -807,7 +819,9 @@ static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -1477,6 +1491,7 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { @@ -1517,7 +1532,9 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -1692,7 +1709,9 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -1797,7 +1816,9 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -1867,7 +1888,9 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -1928,7 +1951,9 @@ static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -2022,12 +2047,12 @@ bool EmbeddingLookupIdx_int64_t_half_float_true__avx2_fma( } template -static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( +static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, - const uint8_t* input, + const at::BFloat16* input, const int* indices, const int* offsets, const float* weights, @@ -2070,16 +2095,11 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) @@ -2089,104 +2109,138 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (0))))), - _mm256_add_ps(vop0, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (8))))), - _mm256_add_ps(vop8, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (16))))), - _mm256_add_ps(vop16, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (24))))), - _mm256_add_ps(vop24, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (32))))), - _mm256_add_ps(vop32, vbio)); - // skip unnecessary prefetch of (&ip_next_T0[32]) + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (32)))), + 16)), + vop32); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (40))))), - _mm256_add_ps(vop40, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (40)))), + 16)), + vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (48))))), - _mm256_add_ps(vop48, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (48)))), + 16)), + vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (56))))), - _mm256_add_ps(vop56, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (56)))), + 16)), + vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (64))))), - _mm256_add_ps(vop64, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (64)))), + 16)), + vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (72))))), - _mm256_add_ps(vop72, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (72)))), + 16)), + vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (80))))), - _mm256_add_ps(vop80, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (80)))), + 16)), + vop80); // skip unnecessary prefetch of (&ip_next_T0[80]) vop88 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (88))))), - _mm256_add_ps(vop88, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (88)))), + 16)), + vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (96))))), - _mm256_add_ps(vop96, vbio)); - // skip unnecessary prefetch of (&ip_next_T0[96]) + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (96)))), + 16)), + vop96); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (104))))), - _mm256_add_ps(vop104, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (104)))), + 16)), + vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (112))))), - _mm256_add_ps(vop112, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (112)))), + 16)), + vop112); // skip unnecessary prefetch of (&ip_next_T0[112]) vop120 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (120))))), - _mm256_add_ps(vop120, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (120)))), + 16)), + vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || length == 0) { @@ -2250,16 +2304,11 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) @@ -2269,55 +2318,72 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (0))))), - _mm256_add_ps(vop0, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (8))))), - _mm256_add_ps(vop8, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (16))))), - _mm256_add_ps(vop16, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (24))))), - _mm256_add_ps(vop24, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (32))))), - _mm256_add_ps(vop32, vbio)); - // skip unnecessary prefetch of (&ip_next_T0[32]) + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (32)))), + 16)), + vop32); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (40))))), - _mm256_add_ps(vop40, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (40)))), + 16)), + vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (48))))), - _mm256_add_ps(vop48, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (48)))), + 16)), + vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (56))))), - _mm256_add_ps(vop56, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (56)))), + 16)), + vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || length == 0) { @@ -2361,16 +2427,11 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) @@ -2380,31 +2441,39 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (0))))), - _mm256_add_ps(vop0, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (8))))), - _mm256_add_ps(vop8, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (16))))), - _mm256_add_ps(vop16, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (24))))), - _mm256_add_ps(vop24, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || length == 0) { @@ -2438,16 +2507,11 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) @@ -2457,19 +2521,23 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (0))))), - _mm256_add_ps(vop0, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (8))))), - _mm256_add_ps(vop8, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || length == 0) { @@ -2483,6 +2551,8 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( } } else { // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) + alignas(64) at::BFloat16 vtmp1[8] = {0}; for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -2504,16 +2574,11 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) @@ -2523,21 +2588,27 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( - reinterpret_cast(&ip[j])))), - _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(&ip[j]))), + 16)), + _mm256_loadu_ps(&op[j]))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { - op[j] = std::fma(wgt, (float)ip[j], bio + op[j]); + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(*(reinterpret_cast(vtmp1))), + 16)); + op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]); } } if (normalize_by_lengths && length) { @@ -2556,19 +2627,19 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( } return dataInd == index_size; } -bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( +bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, - const uint8_t* input, + const at::BFloat16* input, const int* indices, const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( + return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma( block_size, output_size, index_size, @@ -2581,19 +2652,19 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( normalize_by_lengths, out); } -bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( +bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, - const uint8_t* input, + const at::BFloat16* input, const int* indices, const int* offsets, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { - return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( + return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma( block_size, output_size, index_size, @@ -2608,12 +2679,12 @@ bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( } template -static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( +static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, - const uint8_t* input, + const at::BFloat16* input, const int64_t* indices, const int64_t* offsets, const float* weights, @@ -2621,6 +2692,7 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { @@ -2655,83 +2727,1304 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( return false; } float wgt = 1.f; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); - const uint8_t* ip = &input[idx * fused_block_size]; + const at::BFloat16* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } - const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (0))))), - _mm256_add_ps(vop0, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (8))))), - _mm256_add_ps(vop8, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (16))))), - _mm256_add_ps(vop16, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (24))))), - _mm256_add_ps(vop24, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (32))))), - _mm256_add_ps(vop32, vbio)); - // skip unnecessary prefetch of (&ip_next_T0[32]) + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (32)))), + 16)), + vop32); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (40))))), - _mm256_add_ps(vop40, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (40)))), + 16)), + vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (48))))), - _mm256_add_ps(vop48, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (48)))), + 16)), + vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (56))))), - _mm256_add_ps(vop56, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (56)))), + 16)), + vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( - _mm_loadl_epi64(reinterpret_cast(ip + (64))))), - _mm256_add_ps(vop64, vbio)); + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (64)))), + 16)), + vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, - _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (72)))), + 16)), + vop72); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (80)))), + 16)), + vop80); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (88)))), + 16)), + vop88); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (96)))), + 16)), + vop96); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (104)))), + 16)), + vop104); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (112)))), + 16)), + vop112); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (120)))), + 16)), + vop120); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const at::BFloat16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (32)))), + 16)), + vop32); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (40)))), + 16)), + vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (48)))), + 16)), + vop48); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (56)))), + 16)), + vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const at::BFloat16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (16)))), + 16)), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (24)))), + 16)), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const at::BFloat16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (0)))), + 16)), + vop0); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(ip + (8)))), + 16)), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) + alignas(64) at::BFloat16 vtmp1[8] = {0}; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + int64_t j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const at::BFloat16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(_mm_loadu_si128( + reinterpret_cast(&ip[j]))), + 16)), + _mm256_loadu_ps(&op[j]))); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32( + _mm256_cvtepu16_epi32(*(reinterpret_cast(vtmp1))), + 16)); + op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]); + } + } + if (normalize_by_lengths && length) { + float len_inv = 1.0f / length; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } + return dataInd == index_size; +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int* indices, + const int* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const int prefdist_T0 = 16; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + const int fused_block_size = block_size + 0; + int64_t dataInd = 0; + if (block_size == 128) { + // unrolling 16 times + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (64))))), + _mm256_add_ps(vop64, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (72))))), + _mm256_add_ps(vop72, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (80))))), + _mm256_add_ps(vop80, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (88))))), + _mm256_add_ps(vop88, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (96))))), + _mm256_add_ps(vop96, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[96]) + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (104))))), + _mm256_add_ps(vop104, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (112))))), + _mm256_add_ps(vop112, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (120))))), + _mm256_add_ps(vop120, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (!normalize_by_lengths || length == 0) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else { + __m256 vlen_inv = _mm256_set1_ps(1.0f / length); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) + for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + int64_t j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast(&ip[j])))), + _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] = std::fma(wgt, (float)ip[j], bio + op[j]); + } + } + if (normalize_by_lengths && length) { + float len_inv = 1.0f / length; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } + return dataInd == index_size; +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int* indices, + const int* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int* indices, + const int* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + const int64_t fused_block_size = block_size + 0; + int64_t dataInd = 0; + if (block_size == 128) { + // unrolling 16 times + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + if (dataInd != offsets[rangeIndex] - offsets[0]) { + return false; + } + int64_t end_offset = offsets[rangeIndex + 1]; + int64_t length = end_offset - offsets[rangeIndex]; + for (int64_t start = dataInd; dataInd < end_offset - offsets[0]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float bio; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; + } + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { + return false; + } + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast(ip + (64))))), + _mm256_add_ps(vop64, vbio)); + _mm_prefetch( + reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (72))))), _mm256_add_ps(vop72, vbio)); // skip unnecessary prefetch of (&ip_next_T0[72]) @@ -2844,7 +4137,9 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -2953,7 +4248,9 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -3028,7 +4325,9 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { @@ -3060,6 +4359,7 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( } } else { // generic code + // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays) for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; @@ -3092,7 +4392,9 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma( __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) ? (dataInd + prefdist_T0) + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 402f3bb92a415..7e4208caf6556 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -4,7 +4,7 @@ import sys -sizeof = {"float": 4, "at::Half": 2, "uint8_t": 1} +sizeof = {"float": 4, "at::Half": 2, "at::BFloat16": 2, "uint8_t": 1} def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets): @@ -24,6 +24,16 @@ def compute(regid, InType, use_weights, isa, prefetch): " _mm_loadu_si128(reinterpret_cast(ip + (%d)))),\n" # noqa " vop%d);" % (regid, regid, regid) ) + elif InType == "at::BFloat16": + code.append( + " vop%d = _mm256_fmadd_ps(\n" + " vwgt,\n" + " _mm256_castsi256_ps(_mm256_slli_epi32(\n" + " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n" + " reinterpret_cast(ip + (%d)))),\n" + " 16)),\n" # noqa + " vop%d);" % (regid, regid, regid) + ) elif InType == "uint8_t": code.append( " vop%d = _mm256_fmadd_ps(\n" @@ -104,6 +114,7 @@ def compute(regid, InType, use_weights, isa, prefetch): if InType == "uint8_t": code.append(" " + OutType + " wgt = 1.f;") + code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)") code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( @@ -133,7 +144,10 @@ def compute(regid, InType, use_weights, isa, prefetch): code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType)) code.append( " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n" - " ? (dataInd + prefdist_T0)\n : dataInd;".format( + " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n" + " ? (dataInd + prefdist_T0)\n" + " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n" + " : dataInd;".format( IndexType ) ) @@ -206,6 +220,18 @@ def compute(InType, use_weights, isa): " reinterpret_cast(&ip[j]))),\n" " _mm256_loadu_ps(&op[j])));" ) + elif InType == "at::BFloat16": + code.append( + " _mm256_storeu_ps(\n" + " &op[j],\n" + " _mm256_fmadd_ps(\n" + " vwgt,\n" + " _mm256_castsi256_ps(_mm256_slli_epi32(\n" + " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n" + " reinterpret_cast(&ip[j]))),\n" + " 16)),\n" + " _mm256_loadu_ps(&op[j])));" + ) elif InType == "uint8_t": code.append( " _mm256_storeu_ps(\n" @@ -229,7 +255,8 @@ def compute(InType, use_weights, isa): code = [] if InType == "at::Half": code.append(" alignas(64) at::Half vtmp1[8] = {0};") - + if InType == "at::BFloat16": + code.append(" alignas(64) at::BFloat16 vtmp1[8] = {0};") if use_offsets: @@ -291,6 +318,7 @@ def compute(InType, use_weights, isa): if InType == "uint8_t": code.append(" " + OutType + " wgt = 1.f;") + code.append(" // NOLINTNEXTLINE(cppcoreguidelines-init-variables)") code.append(" " + OutType + " bio;") code.append(" if (weights) {") code.append( @@ -320,7 +348,10 @@ def compute(InType, use_weights, isa): code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType)) code.append( " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n" - " ? (dataInd + prefdist_T0)\n : dataInd;".format( + " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n" + " ? (dataInd + prefdist_T0)\n" + " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n" + " : dataInd;".format( IndexType ) ) @@ -351,6 +382,14 @@ def compute(InType, use_weights, isa): " _mm256_cvtph_ps(*(reinterpret_cast(vtmp1)));" ) code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);") + elif InType == "at::BFloat16": + code.append(" vtmp1[0] = ip[j];") + code.append( + " __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(\n" + " _mm256_cvtepu16_epi32(*(reinterpret_cast(vtmp1))),\n" + " 16));" + ) + code.append(" op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);") elif InType == "uint8_t": code.append(" op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);") else: @@ -408,6 +447,8 @@ def compute(InType, use_weights, isa): ["int64_t", "int64_t", "float", "float", "float", "float"], ["int32_t", "int", "half", "at::Half", "float", "float"], ["int64_t", "int64_t", "half", "at::Half", "float", "float"], + ["int32_t", "int", "bfloat16", "at::BFloat16", "float", "float"], + ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"], ["int32_t", "int", "uint8_t", "uint8_t", "float", "float"], ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"], ] @@ -422,6 +463,7 @@ def compute(InType, use_weights, isa): code.append("//// --------------------------\n") code.append("#include ") +code.append("#include ") code.append("#include ") code.append("namespace caffe2 {\n") @@ -461,6 +503,7 @@ def compute(InType, use_weights, isa): code += args code.append(" const " + IndexType + " prefdist_T0 = 16;") + code.append(" // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)") # block_size is the number of elements and fused_block_size is the size of # an entire row, including scale and bias. offset = (8 // sizeof[InType]) if opts.fused else 0 @@ -484,6 +527,7 @@ def compute(InType, use_weights, isa): code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) code.append(" } else {") code.append(" // generic code") + code.append(" // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)") code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused, opts.use_offsets) code.append(" }") code.append(" return dataInd == index_size;") diff --git a/caffe2/perfkernels/lstm_unit_cpu-impl.h b/caffe2/perfkernels/lstm_unit_cpu-impl.h index 5e76e1aa39fe5..239d2807f7788 100644 --- a/caffe2/perfkernels/lstm_unit_cpu-impl.h +++ b/caffe2/perfkernels/lstm_unit_cpu-impl.h @@ -5,27 +5,7 @@ #include "c10/util/irange.h" #include "caffe2/utils/conversions.h" -#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) -#if defined(__clang__) && (__clang_major__ > 7) -#define IS_SANITIZER \ - ((__has_feature(address_sanitizer) == 1) || \ - (__has_feature(memory_sanitizer) == 1) || \ - (__has_feature(thread_sanitizer) == 1) || \ - (__has_feature(undefined_sanitizer) == 1)) - -#if IS_SANITIZER == 0 -#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") -#endif -#elif defined(_OPENMP) && (_OPENMP >= 201511) -// Support with OpenMP4.5 and above -#define VECTOR_LOOP _Pragma("omp for simd") -#endif -#endif - -#ifndef VECTOR_LOOP -// Not supported -#define VECTOR_LOOP -#endif +#include "vectorizer.h" namespace caffe2 { namespace perfkernels { diff --git a/caffe2/perfkernels/vectorizer.h b/caffe2/perfkernels/vectorizer.h new file mode 100644 index 0000000000000..be4e6bbc280f0 --- /dev/null +++ b/caffe2/perfkernels/vectorizer.h @@ -0,0 +1,28 @@ +#pragma once + +#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) +#if defined(__clang__) && (__clang_major__ > 7) +#define IS_SANITIZER \ + ((__has_feature(address_sanitizer) == 1) || \ + (__has_feature(memory_sanitizer) == 1) || \ + (__has_feature(thread_sanitizer) == 1) || \ + (__has_feature(undefined_sanitizer) == 1)) + +#if IS_SANITIZER == 0 +#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") +#define FAST_MATH _Pragma("clang fp contract(fast)") +#define VECTORIZED_KERNEL 1 +#endif +#elif defined(_OPENMP) && (_OPENMP >= 201511) +// Support with OpenMP4.5 and above +#define VECTOR_LOOP _Pragma("omp for simd") +#define VECTORIZED_KERNEL 1 +#define FAST_MATH +#endif +#endif + +#ifndef VECTOR_LOOP +// Not supported +#define VECTOR_LOOP +#define FAST_MATH +#endif diff --git a/caffe2/python/CMakeLists.txt b/caffe2/python/CMakeLists.txt index c092febee4a90..464aa24eadd29 100644 --- a/caffe2/python/CMakeLists.txt +++ b/caffe2/python/CMakeLists.txt @@ -1,6 +1,7 @@ # ---[ CPU files. set(Caffe2_CPU_PYTHON_SRCS "/pybind_state.cc" + "/pybind_workspace.cc" "/pybind_state_dlpack.cc" "/pybind_state_nomni.cc" "/pybind_state_registry.cc" diff --git a/caffe2/python/caffe_translator.py b/caffe2/python/caffe_translator.py index e0aebaf7b24ea..63b5706120ac0 100644 --- a/caffe2/python/caffe_translator.py +++ b/caffe2/python/caffe_translator.py @@ -210,9 +210,9 @@ def TranslateLayer(cls, layer, pretrained_blobs, is_test, **kwargs): try: caffe_ops, params = cls.registry_[layer.type]( layer, pretrained_blobs, is_test, **kwargs) - except KeyError: + except KeyError as e: raise KeyError('No translator registered for layer: %s yet.' % - str(layer)) + str(layer)) from e if caffe_ops is None: caffe_ops = [] if type(caffe_ops) is not list: diff --git a/caffe2/python/clean_workspace_test.py b/caffe2/python/clean_workspace_test.py new file mode 100644 index 0000000000000..c8285f4a1c5bd --- /dev/null +++ b/caffe2/python/clean_workspace_test.py @@ -0,0 +1,15 @@ +import unittest + +from caffe2.python import workspace + + +# This test is extracted out from workspace_test.py because it relies on the pristine +# state of the initial workspace. When tests are run in different orders, this test may +# become flaky because of global state modifications impacting what the root folder is +# after a reset. +class TestWorkspace(unittest.TestCase): + def testRootFolder(self): + self.assertEqual(workspace.ResetWorkspace(), True) + self.assertEqual(workspace.RootFolder(), ".") + self.assertEqual(workspace.ResetWorkspace("/tmp/caffe-workspace-test"), True) + self.assertEqual(workspace.RootFolder(), "/tmp/caffe-workspace-test") diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 4ae75272d3820..9c2efe282f136 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -640,7 +640,7 @@ def AppendSparseGenerators(self, sparse_generators): assert(g1 == g2) assert dev_1 == dev_2, ( "Unequal devices for sparse generators: " - "{} and {}".format(dev1, dev2) + "{} and {}".format(dev_1, dev_2) ) assert(op1_i is None or op2_i is None) assert(op1_v is None or op2_v is None) @@ -970,7 +970,7 @@ def DoGradientAccumulation(self, fwd_op_idx): input_name, err ) - ) + ) from err # Finally, let's create the sum operator. sum_ops, g = self._MakeSumOps(input_name, input_version) @@ -1175,7 +1175,7 @@ def GetGradientForOp(cls, op, g_output): raise Exception( "Exception when creating gradient for [{}]:{}.\nOp: \n{}". format(op.type, e, str(op)) - ) + ) from e if gradient_ops is None: return [], g_input diff --git a/caffe2/python/model_helper.py b/caffe2/python/model_helper.py index 5eb81d898b33f..18219d3923b42 100644 --- a/caffe2/python/model_helper.py +++ b/caffe2/python/model_helper.py @@ -540,8 +540,8 @@ def ExtractPredictorNet( 'StopGradient' ] ) - except ValueError: - raise Exception("No ops with input={}".format(input_blobs)) + except ValueError as e: + raise Exception("No ops with input={}".format(input_blobs)) from e try: last_op_with_output = max( [ @@ -549,8 +549,8 @@ def ExtractPredictorNet( if output_blobs.intersection(ops[j].output) ] ) - except ValueError: - raise Exception("No ops with output={}".format(output_blobs)) + except ValueError as e: + raise Exception("No ops with output={}".format(output_blobs)) from e def validate_op(op): # Check that the op does not have is_test = 0 set. This is a common diff --git a/caffe2/python/models/download.py b/caffe2/python/models/download.py index 7e735c726568e..895f87a4e4501 100644 --- a/caffe2/python/models/download.py +++ b/caffe2/python/models/download.py @@ -69,10 +69,10 @@ def downloadFromURLToFile(url, filename, show_progress=True): print("") # New line to fix for progress bar except HTTPError as e: raise Exception("Could not download model. [HTTP Error] {code}: {reason}." - .format(code=e.code, reason=e.reason)) + .format(code=e.code, reason=e.reason)) from e except URLError as e: raise Exception("Could not download model. [URL Error] {reason}." - .format(reason=e.reason)) + .format(reason=e.reason)) from e def getURLFromName(name, filename): diff --git a/caffe2/python/onnx/ONNXOpCoverage.md b/caffe2/python/onnx/ONNXOpCoverage.md index bb4b71f055356..66cf2d692e87c 100644 --- a/caffe2/python/onnx/ONNXOpCoverage.md +++ b/caffe2/python/onnx/ONNXOpCoverage.md @@ -19,7 +19,7 @@ This doc keeps tracking why operators are not covered by the testcases. |Atan|||💚OK| |AveragePool||OK|💚OK| |BatchNormalization||OK|💚OK| -|Cast|Yes||💔Need extendtion| +|Cast|Yes||💔Need extension| |Ceil|Yes||💚OK| |Clip|Yes|OK|💚OK| |Concat|Yes|OK|💚OK| diff --git a/caffe2/python/operator_test/_utils.py b/caffe2/python/operator_test/_utils.py new file mode 100644 index 0000000000000..3ee1def89e715 --- /dev/null +++ b/caffe2/python/operator_test/_utils.py @@ -0,0 +1,50 @@ +""" +This file only exists since `torch.testing.assert_allclose` is deprecated, but used extensively throughout the tests in +this package. The replacement `torch.testing.assert_close` doesn't support one feature that is needed here: comparison +between numpy arrays and torch tensors. See https://github.com/pytorch/pytorch/issues/61844 for the reasoning why this +was removed. +""" + +import torch +from typing import Tuple, Any, Optional + +_DTYPE_PRECISIONS = { + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-5), + torch.float64: (1e-5, 1e-8), +} + + +def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]: + actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0)) + expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0)) + return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol) + + +def assert_allclose( + actual: Any, + expected: Any, + rtol: Optional[float] = None, + atol: Optional[float] = None, + equal_nan: bool = True, + msg: str = "", +) -> None: + if not isinstance(actual, torch.Tensor): + actual = torch.tensor(actual) + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected, dtype=actual.dtype) + + if rtol is None and atol is None: + rtol, atol = _get_default_rtol_and_atol(actual, expected) + + torch.testing.assert_close( + actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + check_device=True, + check_dtype=False, + check_stride=False, + msg=msg or None, + ) \ No newline at end of file diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py index 32a2511e3e8e3..31ba78be0c19f 100644 --- a/caffe2/python/operator_test/layer_norm_op_test.py +++ b/caffe2/python/operator_test/layer_norm_op_test.py @@ -18,6 +18,8 @@ import unittest +from ._utils import assert_allclose + def _layer_norm_ref(axis, epsilon, X): left = int(np.prod(X.shape[:axis])) @@ -254,10 +256,9 @@ def test_layer_norm_op_c10_preallocated_outputs( actual_mean = self.ws.fetch_blob('mean') actual_std = self.ws.fetch_blob('std') - torch.testing.assert_allclose( - expected_norm, actual_norm, rtol=1e-4, atol=1e-4) - torch.testing.assert_allclose(expected_mean, actual_mean) - torch.testing.assert_allclose(expected_std, actual_std) + assert_allclose(expected_norm, actual_norm, rtol=1e-4, atol=1e-4) + assert_allclose(expected_mean, actual_mean) + assert_allclose(expected_std, actual_std) @given(X=hu.tensor(min_dim=2), eps=st.floats(1e-5, 1e-3), @@ -280,10 +281,9 @@ def test_layer_norm_op_pytorch(self, X, eps, elementwise_affine, gc, dc): actual_norm, actual_mean, actual_std = torch.ops._caffe2.LayerNorm( torch.tensor(X), None, None, axis, eps) - torch.testing.assert_allclose( - expected_norm, actual_norm, rtol=1e-4, atol=1e-4) - torch.testing.assert_allclose(expected_mean, actual_mean) - torch.testing.assert_allclose(expected_std, actual_std) + assert_allclose(expected_norm, actual_norm, rtol=1e-4, atol=1e-4) + assert_allclose(expected_mean, actual_mean) + assert_allclose(expected_std, actual_std) # Test case is using workspace.has_cuda_support and not # workspace.has_gpu_support to exclude it from HIP because tensor interop @@ -313,10 +313,9 @@ def test_layer_norm_op_pytorch_cuda(self, X, eps, elementwise_affine): actual_norm, actual_mean, actual_std = torch.ops._caffe2.LayerNorm( torch.tensor(X).cuda(), None, None, axis, eps) - torch.testing.assert_allclose( - expected_norm, actual_norm.cpu(), rtol=1e-4, atol=1e-4) - torch.testing.assert_allclose(expected_mean, actual_mean.cpu()) - torch.testing.assert_allclose(expected_std, actual_std.cpu()) + assert_allclose(expected_norm, actual_norm, rtol=1e-4, atol=1e-4) + assert_allclose(expected_mean, actual_mean) + assert_allclose(expected_std, actual_std) @given(X=hu.tensor(min_dim=2), eps=st.floats(1e-5, 1e-3), @@ -352,10 +351,9 @@ def jit_layer_norm( actual_norm, actual_mean, actual_std = jit_layer_norm( torch.tensor(X), None, None, axis, eps, elementwise_affine) - torch.testing.assert_allclose( - expected_norm, actual_norm, rtol=1e-4, atol=1e-4) - torch.testing.assert_allclose(expected_mean, actual_mean) - torch.testing.assert_allclose(expected_std, actual_std) + assert_allclose(expected_norm, actual_norm, rtol=1e-4, atol=1e-4) + assert_allclose(expected_mean, actual_mean) + assert_allclose(expected_std, actual_std) @given(X=hu.tensor(min_dim=2), **hu.gcs) def test_layer_norm_brew_wrapper(self, X, gc, dc): diff --git a/caffe2/python/operator_test/roi_align_rotated_op_test.py b/caffe2/python/operator_test/roi_align_rotated_op_test.py index ea835acead617..fcbcb555440bb 100644 --- a/caffe2/python/operator_test/roi_align_rotated_op_test.py +++ b/caffe2/python/operator_test/roi_align_rotated_op_test.py @@ -150,9 +150,9 @@ def roialign_flip(m, axis): indexer = [slice(None)] * m.ndim try: indexer[axis] = slice(None, None, -1) - except IndexError: + except IndexError as e: raise ValueError("axis=%i is invalid for the %i-dimensional input array" - % (axis, m.ndim)) + % (axis, m.ndim)) from e return m[tuple(indexer)] def roialign_ref(X, R): diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index f99a61688de6e..d143e0193dfd7 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -11,6 +11,8 @@ from hypothesis import given, settings from scipy.stats import norm +from ._utils import assert_allclose + def generate_rois(roi_counts, im_dims): assert len(roi_counts) == len(im_dims) @@ -172,7 +174,7 @@ def bbox_transform_ref(): legacy_plus_one=True, ) - torch.testing.assert_allclose(box_out, a) + assert_allclose(box_out, a) @given( roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10), @@ -268,7 +270,7 @@ def box_with_nms_limit_ref(): ) for o, o_ref in zip(outputs, output_refs): - torch.testing.assert_allclose(o, o_ref) + assert_allclose(o, o_ref) @given( dim_1=st.integers(min_value=10, max_value=10), @@ -314,7 +316,7 @@ def sparse_to_dense_mask_ref(return_presence_mask=False): mask=mask, ) - torch.testing.assert_allclose(output, a) + assert_allclose(output, a) # Testing return_presence_mask = True output, presence_mask = sparse_to_dense_mask_ref(return_presence_mask=True) @@ -330,8 +332,8 @@ def sparse_to_dense_mask_ref(return_presence_mask=False): return_presence_mask=True, ) - torch.testing.assert_allclose(output, a) - torch.testing.assert_allclose(presence_mask, b) + assert_allclose(output, a) + assert_allclose(presence_mask, b) @given( A=st.integers(min_value=4, max_value=4), @@ -382,8 +384,8 @@ def generate_proposals_ref(): 1.0, legacy_plus_one=True, ) - torch.testing.assert_allclose(rois, a) - torch.testing.assert_allclose(rois_probs, b) + assert_allclose(rois, a) + assert_allclose(rois_probs, b) @given( bsz=st.integers(1, 5), @@ -461,9 +463,9 @@ def inference_lstm_ref(): a, b, c = torch.ops._caffe2.InferenceLSTM( lstm_in, num_layers, has_biases, batch_first, is_bidirectional ) - torch.testing.assert_allclose(output, a) - torch.testing.assert_allclose(hidden, b) - torch.testing.assert_allclose(cell, c) + assert_allclose(output, a) + assert_allclose(hidden, b) + assert_allclose(cell, c) # Test case is using workspace.has_cuda_support and not workspace.has_gpu_support # to exclude it from HIP because tensor interop doesn't work for HIP tensors yet @@ -517,8 +519,8 @@ def generate_proposals_ref(): 1.0, legacy_plus_one=True, ) - torch.testing.assert_allclose(rois, a.cpu()) - torch.testing.assert_allclose(rois_probs, b.cpu()) + assert_allclose(rois, a.cpu()) + assert_allclose(rois_probs, b.cpu()) @given( N=st.integers(min_value=1, max_value=2), @@ -567,7 +569,7 @@ def roi_align_ref(_feature, _rois): sampling_ratio=0, aligned=False, ) - torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu()) + assert_allclose(roi_feature_ref, roi_feature.cpu()) def test_roi_align_cpu(self): self._test_roi_align(device="cpu") @@ -624,7 +626,7 @@ def roi_align_ref(_feature, _rois): sampling_ratio=0, aligned=False, ) - torch.testing.assert_allclose(roi_feature_ref, roi_feature.cpu()) + assert_allclose(roi_feature_ref, roi_feature.cpu()) def test_roi_align_rotated_cpu(self): self._test_roi_align_rotated(device="cpu") @@ -674,9 +676,9 @@ def test_collect_and_distribute_fpn_rpn_proposals_op(self, roi_counts): rois_idx_restore_int32 = fpn_outputs[-1] # [rois] + fpn_outputs should be equal to all_outputs - torch.testing.assert_allclose(rois, all_outputs[0]) + assert_allclose(rois, all_outputs[0]) for x, y in zip(fpn_outputs, all_outputs[1:]): - torch.testing.assert_allclose(x, y) + assert_allclose(x, y) @given(X=hu.tensor(), fast_gelu=st.booleans()) def _test_gelu_op(self, X, fast_gelu, device): @@ -688,7 +690,7 @@ def _gelu_ref(_X): rtol = 1e-3 if fast_gelu else 1e-4 atol = 1e-5 - torch.testing.assert_allclose( + assert_allclose( expected_output, actual_output.cpu(), rtol=rtol, atol=atol ) @@ -719,7 +721,7 @@ def _lengths_ref(X, Y): torch.tensor(data), torch.tensor(lengths, dtype=torch.int32) ) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def _test_lengths_sum_op(self, device): self._test_lengths_op("LengthsSum", torch.ops._caffe2.LengthsSum, device) @@ -775,7 +777,7 @@ def _resize_nearest_ref(X): height_scale=1.5, ) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def test_resize_nearest_op_cpu(self): return self._test_resize_nearest_op("cpu") @@ -838,16 +840,16 @@ def _piecewise_linear_ref(X): binary_input, ) - torch.testing.assert_allclose(torch.tensor(expected_output), actual_output) + assert_allclose(torch.tensor(expected_output), actual_output) def test_alias_with_name_is_in_place(self): device = "cuda" if workspace.has_cuda_support else "cpu" x = torch.tensor([3., 42.]).to(device=device) y = torch.ops._caffe2.AliasWithName(x, "new_name") x[1] = 6 - torch.testing.assert_allclose(x, torch.tensor([3., 6.]).to(device=device)) + assert_allclose(x, torch.tensor([3., 6.]).to(device=device)) # y should also change because y is alias of x - torch.testing.assert_allclose(y, torch.tensor([3., 6.]).to(device=device)) + assert_allclose(y, torch.tensor([3., 6.]).to(device=device)) @unittest.skipIf(not workspace.has_cuda_support, "No cuda support") def test_copy_between_cpu_and_gpu(self): @@ -855,9 +857,9 @@ def test_copy_between_cpu_and_gpu(self): x_gpu_ref = x_cpu_ref.to("cuda") x_gpu = torch.ops._caffe2.CopyCPUToGPU(x_cpu_ref) - torch.testing.assert_allclose(x_gpu, x_gpu_ref) + assert_allclose(x_gpu, x_gpu_ref) x_cpu = torch.ops._caffe2.CopyGPUToCPU(x_gpu) - torch.testing.assert_allclose(x_cpu, x_cpu_ref) + assert_allclose(x_cpu, x_cpu_ref) def test_index_hash_op(self): data = np.random.randint(low=0, high=1000, size=(4, 4, 4)) @@ -873,7 +875,7 @@ def _index_hash_ref(X): torch.tensor(data), seed=0, modulo=100 ) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def test_bucketize_op(self): data = np.random.rand(8, 10).astype(np.float32) * 1000 @@ -889,7 +891,7 @@ def _bucketize_ref(X): expected_output = _bucketize_ref(data) actual_output = torch.ops._caffe2.Bucketize(torch.tensor(data), boundaries) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) @given(X=hu.tensor(), eps=st.floats(min_value=1e-4, max_value=1e-2)) def test_logit(self, X, eps): @@ -901,7 +903,7 @@ def ref(X, eps): expected_output = ref(X, eps) actual_output = torch.ops._caffe2.Logit(torch.tensor(X), eps) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def test_percentile(self): original_values = np.array([[3.0, 5.0, 3], [5.0, 1.0, 6.0]]).astype(np.float32) @@ -926,7 +928,7 @@ def _percentile_ref(original_values, value_to_pct, lengths): torch.tensor(value_to_pct), torch.tensor(lengths), ) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def test_batch_bucket_one_hot_op(self): data = np.array([[2, 3], [4, 1], [2, 5]]).astype(np.float32) @@ -947,7 +949,7 @@ def _batch_bucket_one_hot_ref(data, lengths, boundaries): actual_output = torch.ops._caffe2.BatchBucketOneHot( torch.tensor(data), torch.tensor(lengths), torch.tensor(boundaries) ) - torch.testing.assert_allclose(expected_output, actual_output.cpu()) + assert_allclose(expected_output, actual_output.cpu()) def test_gather_ranges_to_dense_op(self): data = np.array([1, 2, 3, 4, 5, 6, 7, 8]) @@ -1033,8 +1035,8 @@ def _merge_id_lists(lengths, values): torch.tensor(values[1]), ] ) - torch.testing.assert_allclose(expected_merged_lengths, output_merged_lengths) - torch.testing.assert_allclose(expected_merged_values, output_merged_values) + assert_allclose(expected_merged_lengths, output_merged_lengths) + assert_allclose(expected_merged_values, output_merged_values) def test_learning_rate(self): base_lr = 0.05 @@ -1097,7 +1099,7 @@ def test_pack_segments(self): packed_tensor, _ = torch.ops._caffe2.PackSegments(lengths, s) self.assertEqual(packed_tensor.numpy().shape, (2, 2, 3, 3)) unpacked_tensor = torch.ops._caffe2.UnpackSegments(lengths, packed_tensor) - torch.testing.assert_allclose(s, unpacked_tensor) + assert_allclose(s, unpacked_tensor) if __name__ == "__main__": diff --git a/caffe2/python/operator_test/video_input_op_test.py b/caffe2/python/operator_test/video_input_op_test.py index f21f219bd90eb..24f9e57434d4f 100644 --- a/caffe2/python/operator_test/video_input_op_test.py +++ b/caffe2/python/operator_test/video_input_op_test.py @@ -13,8 +13,8 @@ try: import lmdb -except ImportError: - raise unittest.SkipTest("python-lmdb is not installed") +except ImportError as e: + raise unittest.SkipTest("python-lmdb is not installed") from e class VideoInputOpTest(unittest.TestCase): diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index a637f15e7a9d3..5b2c2f71a827a 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -33,6 +34,7 @@ #include "caffe2/predictor/emulator/data_filler.h" #include "caffe2/predictor/predictor.h" #include "caffe2/python/pybind_state_registry.h" +#include "caffe2/python/pybind_workspace.h" #include "caffe2/utils/cpuid.h" #include "caffe2/utils/proto_convert.h" #include "caffe2/utils/string_utils.h" @@ -56,14 +58,6 @@ constexpr bool kPyBindFalse = false; namespace py = pybind11; -// gWorkspaces allows us to define and switch between multiple workspaces in -// Python. -static std::map> gWorkspaces; -// gWorkspace is the pointer to the current workspace. The ownership is kept -// by the gWorkspaces map. -static Workspace* gWorkspace = nullptr; -static std::string gCurrentWorkspaceName; - // NOLINTNEXTLINE(modernize-use-equals-default) BlobFetcherBase::~BlobFetcherBase() {} // NOLINTNEXTLINE(modernize-use-equals-default) @@ -83,17 +77,6 @@ C10_DEFINE_TYPED_REGISTRY( REGISTER_BLOB_FETCHER((TypeMeta::Id()), TensorFetcher); REGISTER_BLOB_FEEDER(CPU, TensorFeeder); -Workspace* GetCurrentWorkspace() { - return gWorkspace; -} - -Workspace* GetWorkspaceByName(const std::string& name) { - if (gWorkspaces.count(name)) { - return gWorkspaces[name].get(); - } - return nullptr; -} - class StringFetcher : public BlobFetcherBase { public: py::object Fetch(const Blob& blob) override { @@ -180,20 +163,6 @@ std::function DefinitionGetter( return [registry](const string& name) { return registry->HelpMessage(name); }; } -void switchWorkspaceInternal(const std::string& name, bool create_if_missing) { - if (gWorkspaces.count(name)) { - gCurrentWorkspaceName = name; - gWorkspace = gWorkspaces[name].get(); - return; - } - - CAFFE_ENFORCE(create_if_missing); - std::unique_ptr new_workspace(new Workspace()); - gWorkspace = new_workspace.get(); - gWorkspaces.insert(std::make_pair(name, std::move(new_workspace))); - gCurrentWorkspaceName = name; -} - namespace python_detail { // Python Op implementations. using FuncRegistry = std::unordered_map; @@ -240,7 +209,7 @@ bool feedBlob( const py::object& arg, const py::object device_option) { DeviceOption option; - if (!device_option.is(py::none())) { + if (!device_option.is_none()) { // If we have a device option passed in, read it. CAFFE_ENFORCE(ParseProtoFromLargeString( py::bytes(device_option).cast(), &option)); @@ -652,10 +621,9 @@ void addObjectMethods(py::module& m) { return (int)self->last_failed_op_net_position; }) .def_property_readonly_static("current", [](py::object /* type */) { - auto ws = gWorkspaces.find(gCurrentWorkspaceName); - CAFFE_ENFORCE(ws != gWorkspaces.end()); - CAFFE_ENFORCE(ws->second.get()); - return py::cast(ws->second.get(), py::return_value_policy::reference); + auto ws = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(ws); + return py::cast(ws, py::return_value_policy::reference); }); py::class_>( @@ -784,7 +752,7 @@ void addObjectMethods(py::module& m) { .def( "reset", [](caffe2::onnx::DummyName& instance, const py::object& args) { - if (args.is(py::none())) { + if (args.is_none()) { instance.Reset(std::unordered_set()); } else { instance.Reset(args.cast>()); @@ -972,14 +940,15 @@ void addObjectMethods(py::module& m) { py::class_(m, "Predictor") .def(py::init([](py::bytes init_net, py::bytes predict_net) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); NetDef init_net_, predict_net_; CAFFE_ENFORCE(ParseProtoFromLargeString( init_net.cast(), &init_net_)); CAFFE_ENFORCE(ParseProtoFromLargeString( predict_net.cast(), &predict_net_)); return new Predictor( - makePredictorConfig(init_net_, predict_net_, gWorkspace)); + makePredictorConfig(init_net_, predict_net_, workspace)); })) .def( "run", @@ -1139,20 +1108,21 @@ void addGlobalMethods(py::module& m) { } return keys; }); - m.def("on_module_exit", []() { gWorkspaces.clear(); }); + m.def("on_module_exit", []() { caffe2::python::ClearWorkspaces(); }); // create_if_missing not used by necessary for pybind to do // properly do function overloading. m.def( - "switch_workspace", - [](Workspace* ws, py::object /*create_if_missing*/) { gWorkspace = ws; }); + "switch_workspace", [](Workspace* ws, py::object /*create_if_missing*/) { + // TODO + caffe2::python::SetCurrentWorkspace(ws); + }); m.def( "create_child_workspace", [](const std::string& parent_ws_name, const std::string& child_ws_name) { - CAFFE_ENFORCE( - gWorkspaces.count(parent_ws_name), "Parent ws does not exist."); - auto parent_gws = gWorkspaces[parent_ws_name].get(); + auto parent_gws = caffe2::python::GetWorkspaceByName(parent_ws_name); + CAFFE_ENFORCE(parent_gws, "Parent ws does not exist."); std::unique_ptr child_ws(new Workspace(parent_gws)); - gWorkspaces.insert(std::make_pair(child_ws_name, std::move(child_ws))); + caffe2::python::InsertWorkspace(child_ws_name, std::move(child_ws)); }, "Create and register child ws, sharing existing blobs in parent ws.", py::arg("parent_ws_name"), @@ -1160,10 +1130,11 @@ void addGlobalMethods(py::module& m) { m.def( "switch_workspace", [](const std::string& name, const py::object create_if_missing) { - if (create_if_missing.is(py::none())) { - return switchWorkspaceInternal(name, false); + if (create_if_missing.is_none()) { + return caffe2::python::SwitchWorkspaceInternal(name, false); } - return switchWorkspaceInternal(name, create_if_missing.cast()); + return caffe2::python::SwitchWorkspaceInternal( + name, create_if_missing.cast()); }, "Switch to the specified workspace, creating if necessary", py::arg("name"), @@ -1172,31 +1143,28 @@ void addGlobalMethods(py::module& m) { "reset_workspace", [](const py::object& root_folder) { VLOG(1) << "Resetting workspace."; - if (root_folder.is(py::none())) { - // NOLINTNEXTLINE(modernize-make-unique) - gWorkspaces[gCurrentWorkspaceName].reset(new Workspace()); + if (root_folder.is_none()) { + caffe2::python::ResetWorkspace(new Workspace()); } else { - // NOLINTNEXTLINE(modernize-make-unique) - gWorkspaces[gCurrentWorkspaceName].reset( + caffe2::python::ResetWorkspace( new Workspace(root_folder.cast())); } - gWorkspace = gWorkspaces[gCurrentWorkspaceName].get(); return true; }, "Reset the workspace", py::arg("root_folder") = py::none()); m.def("root_folder", []() { - CAFFE_ENFORCE(gWorkspace); - return gWorkspace->RootFolder(); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + return workspace->RootFolder(); + }); + m.def("current_workspace", []() { + return caffe2::python::GetCurrentWorkspaceName(); }); - m.def("current_workspace", []() { return gCurrentWorkspaceName; }); m.def("workspaces", []() { std::vector names; - for (const auto& kv : gWorkspaces) { - // NOLINTNEXTLINE(performance-inefficient-vector-operation) - names.push_back(kv.first); - } + caffe2::python::GetWorkspaceNames(names); return names; }); m.def("nearby_opnames", [](const std::string& name) { @@ -1211,41 +1179,46 @@ void addGlobalMethods(py::module& m) { return alternatives; }); m.def("local_blobs", []() { - CAFFE_ENFORCE(gWorkspace); - return gWorkspace->LocalBlobs(); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + return workspace->LocalBlobs(); }); m.def("blobs", []() { - CAFFE_ENFORCE(gWorkspace); - return gWorkspace->Blobs(); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + return workspace->Blobs(); }); m.def("has_blob", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - return gWorkspace->HasBlob(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + return workspace->HasBlob(name); }); m.def( "fill_random_network_inputs", [](const py::bytes& net_def, const std::vector>>& inputDims, const std::vector>& inputTypes) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); py::gil_scoped_release g; NetDef net; CAFFE_ENFORCE( ParseProtoFromLargeString(net_def.cast(), &net)); caffe2::emulator::fillRandomNetworkInputs( - net, inputDims, inputTypes, gWorkspace); + net, inputDims, inputTypes, workspace); }); m.def( "create_net", [](py::bytes net_def, bool overwrite) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); caffe2::NetDef proto; CAFFE_ENFORCE( ParseProtoFromLargeString(net_def.cast(), &proto), "Can't parse net proto: ", net_def.cast()); CAFFE_ENFORCE( - gWorkspace->CreateNet(proto, overwrite), + workspace->CreateNet(proto, overwrite), "Error creating net with proto: ", net_def.cast()); return true; @@ -1253,11 +1226,12 @@ void addGlobalMethods(py::module& m) { py::arg("net_def"), py::arg("overwrite") = kPyBindFalse); m.def("run_net", [](const std::string& name, int num_iter, bool allow_fail) { - CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE(gWorkspace->GetNet(name), "Can't find net ", name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + CAFFE_ENFORCE(workspace->GetNet(name), "Can't find net ", name); py::gil_scoped_release g; for (int i = 0; i < num_iter; i++) { - bool success = gWorkspace->RunNet(name); + bool success = workspace->RunNet(name); if (!allow_fail) { CAFFE_ENFORCE(success, "Error running net ", name); } else { @@ -1271,12 +1245,12 @@ void addGlobalMethods(py::module& m) { m.def( "add_observer_to_net", [](const std::string& net_name, const std::string& observer_type) { - CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE( - gWorkspace->GetNet(net_name), "Can't find net ", net_name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + CAFFE_ENFORCE(workspace->GetNet(net_name), "Can't find net ", net_name); py::gil_scoped_release g; - NetBase* net = gWorkspace->GetNet(net_name); + NetBase* net = workspace->GetNet(net_name); const Observable::Observer* observer = nullptr; #define REGISTER_PYTHON_EXPOSED_OBSERVER(ob_type) \ @@ -1303,12 +1277,12 @@ void addGlobalMethods(py::module& m) { m.def( "remove_observer_from_net", [](const std::string& net_name, const ObserverBase* observer) { - CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE( - gWorkspace->GetNet(net_name), "Can't find net ", net_name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + CAFFE_ENFORCE(workspace->GetNet(net_name), "Can't find net ", net_name); py::gil_scoped_release g; - NetBase* net = gWorkspace->GetNet(net_name); + NetBase* net = workspace->GetNet(net_name); net->DetachObserver(observer); }); m.def("clear_global_net_observer", []() { @@ -1316,11 +1290,12 @@ void addGlobalMethods(py::module& m) { caffe2::ClearGlobalNetObservers(); }); m.def("num_observers_on_net", [](const std::string& net_name) { - CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE(gWorkspace->GetNet(net_name), "Can't find net ", net_name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + CAFFE_ENFORCE(workspace->GetNet(net_name), "Can't find net ", net_name); py::gil_scoped_release g; - NetBase* net = gWorkspace->GetNet(net_name); + NetBase* net = workspace->GetNet(net_name); return net->NumObservers(); }); m.def( @@ -1329,8 +1304,9 @@ void addGlobalMethods(py::module& m) { size_t warmup_runs, size_t main_runs, bool run_individual) { - CAFFE_ENFORCE(gWorkspace); - auto* net = gWorkspace->GetNet(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* net = workspace->GetNet(name); CAFFE_ENFORCE(net, "Didn't find net: ", name); py::gil_scoped_release g; vector stat = @@ -1338,8 +1314,9 @@ void addGlobalMethods(py::module& m) { return stat; }); m.def("benchmark_net_once", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - auto* net = gWorkspace->GetNet(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* net = workspace->GetNet(name); CAFFE_ENFORCE(net, "Didn't find net: ", name); py::gil_scoped_release g; float stat = net->TEST_Benchmark_One_Run(); @@ -1347,28 +1324,35 @@ void addGlobalMethods(py::module& m) { }); m.def("delete_net", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - gWorkspace->DeleteNet(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + workspace->DeleteNet(name); return true; }); - m.def("nets", []() { return gWorkspace->Nets(); }); + m.def("nets", []() { + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + return workspace->Nets(); + }); m.def("run_operator_once", [](const py::bytes& op_def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); OperatorDef def; CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast(), &def)); py::gil_scoped_release g; - CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def)); + CAFFE_ENFORCE(workspace->RunOperatorOnce(def)); return true; }); // Run an operator multiple times. // This is needed for microbenchmarking as we want the benchmark loop to be in // C++ to minimize overhead. m.def("run_operator_multiple", [](const py::bytes& op_def, int num_runs) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); OperatorDef def; CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast(), &def)); py::gil_scoped_release g; - std::unique_ptr op(CreateOperator(def, gWorkspace)); + std::unique_ptr op(CreateOperator(def, workspace)); for (int i = 0; i < num_runs; i++) { if (!op->Run()) { return false; @@ -1379,7 +1363,8 @@ void addGlobalMethods(py::module& m) { m.def( "get_operator_cost", [](const py::bytes& op_def, const std::vector& input_blobs) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); OperatorDef def; CAFFE_ENFORCE( ParseProtoFromLargeString(op_def.cast(), &def), @@ -1389,37 +1374,40 @@ void addGlobalMethods(py::module& m) { CAFFE_ENFORCE(schema); vector shapes; for (const auto& blob_name : input_blobs) { - auto* blob = gWorkspace->GetBlob(blob_name); + auto* blob = workspace->GetBlob(blob_name); shapes.emplace_back(GetTensorShapeOfBlob(blob)); } const auto c = schema->InferCost(def, shapes); return std::make_tuple(c.flops, c.bytes_written, c.bytes_read); }); m.def("run_net_once", [](const py::bytes& net_def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); NetDef def; CAFFE_ENFORCE(ParseProtoFromLargeString(net_def.cast(), &def)); py::gil_scoped_release g; - CAFFE_ENFORCE(gWorkspace->RunNetOnce(def)); + CAFFE_ENFORCE(workspace->RunNetOnce(def)); return true; }); m.def("run_plan", [](const py::bytes& plan_def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); PlanDef def; CAFFE_ENFORCE( ParseProtoFromLargeString(plan_def.cast(), &def)); py::gil_scoped_release g; - CAFFE_ENFORCE(gWorkspace->RunPlan(def)); + CAFFE_ENFORCE(workspace->RunPlan(def)); return true; }); m.def("run_plan_in_background", [](const py::bytes& plan_def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); PlanDef def; CAFFE_ENFORCE( ParseProtoFromLargeString(plan_def.cast(), &def)); py::gil_scoped_release g; - auto background_plan = std::make_shared(gWorkspace, def); + auto background_plan = std::make_shared(workspace, def); background_plan->run(); return background_plan; }); @@ -1513,7 +1501,8 @@ void addGlobalMethods(py::module& m) { m.def( "infer_shapes_and_types_from_workspace", [](const std::vector& net_protos) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); // Parse protobuffers to NetDefs std::vector> nets; @@ -1527,7 +1516,7 @@ void addGlobalMethods(py::module& m) { } auto blob_info = - InferBlobShapesAndTypesFromWorkspace(gWorkspace, nets_ptr); + InferBlobShapesAndTypesFromWorkspace(workspace, nets_ptr); std::string protob; CAFFE_ENFORCE(blob_info.SerializeToString(&protob)); @@ -1593,23 +1582,27 @@ void addGlobalMethods(py::module& m) { return py::bytes(output_net_proto); }); m.def("create_blob", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - CAFFE_ENFORCE(gWorkspace->CreateBlob(name)); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + CAFFE_ENFORCE(workspace->CreateBlob(name)); return true; }); m.def("reset_blob", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - auto* b = gWorkspace->GetBlob(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* b = workspace->GetBlob(name); CAFFE_ENFORCE(b); b->Reset(); }); m.def("fetch_blob", [](const std::string& name) -> py::object { - return python_detail::fetchBlob(gWorkspace, name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + return python_detail::fetchBlob(workspace, name); }); m.def( "feed_blob", [](const std::string& name, py::object arg, py::object device_option) { - auto* blob = gWorkspace->CreateBlob(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + auto* blob = workspace->CreateBlob(name); return python_detail::feedBlob(blob, arg, device_option); }, "", @@ -1620,16 +1613,18 @@ void addGlobalMethods(py::module& m) { return python_detail::deserializeBlob(content); }); m.def("serialize_blob", [](const std::string& name) { - CAFFE_ENFORCE(gWorkspace); - auto* blob = gWorkspace->GetBlob(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* blob = workspace->GetBlob(name); CAFFE_ENFORCE(blob); return py::bytes(SerializeBlob(*blob, name)); }); m.def( "deserialize_blob", [](const std::string& name, const py::bytes& serialized) { - CAFFE_ENFORCE(gWorkspace); - auto* blob = gWorkspace->CreateBlob(name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* blob = workspace->CreateBlob(name); DeserializeBlob(serialized.cast(), blob); }); @@ -1639,7 +1634,7 @@ void addGlobalMethods(py::module& m) { "register_python_op", [](py::object func, bool pass_workspace, std::string name) { using namespace python_detail; - CAFFE_ENFORCE(!func.is(py::none())); + CAFFE_ENFORCE(!func.is_none()); if (!name.empty()) { name += ":"; } @@ -1655,7 +1650,7 @@ void addGlobalMethods(py::module& m) { "register_python_gradient_op", [](const std::string& token, py::object func) { using namespace python_detail; - CAFFE_ENFORCE(!func.is(py::none())); + CAFFE_ENFORCE(!func.is_none()); CAFFE_ENFORCE(gRegistry().find(token) != gRegistry().end()); // For global sanity gradient ops shouldn't access workspace gRegistry()[token + "_gradient"] = Func{func, false}; @@ -1695,8 +1690,9 @@ void addGlobalMethods(py::module& m) { m.def("is_numa_enabled", []() { return IsNUMAEnabled(); }); m.def("get_num_numa_nodes", []() { return GetNumNUMANodes(); }); m.def("get_blob_numa_node", [](const std::string& blob_name) { - CAFFE_ENFORCE(gWorkspace); - auto* blob = gWorkspace->GetBlob(blob_name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* blob = workspace->GetBlob(blob_name); CAFFE_ENFORCE(blob); const TensorCPU& tensor = blob->Get(); const void* raw_data = tensor.raw_data(); @@ -1704,8 +1700,9 @@ void addGlobalMethods(py::module& m) { return GetNUMANode(raw_data); }); m.def("get_blob_size_bytes", [](const std::string& blob_name) { - CAFFE_ENFORCE(gWorkspace); - auto* blob = gWorkspace->GetBlob(blob_name); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); + auto* blob = workspace->GetBlob(blob_name); CAFFE_ENFORCE(blob); return BlobStat::sizeBytes(*blob); }); @@ -1861,13 +1858,14 @@ void addGlobalMethods(py::module& m) { m.def( "run_workspace_transform", [](const std::string& transform_name, py::bytes def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); caffe2::NetDef proto; CAFFE_ENFORCE( ParseProtoFromLargeString(def.cast(), &proto)); auto nn = caffe2::convertToNNModule(proto); auto pass = WorkspaceOptimizationPassRegistry()->Create( - transform_name, &nn, gWorkspace); + transform_name, &nn, workspace); CAFFE_ENFORCE(pass, "Pass doesn't exist: ", transform_name); pass->run(); @@ -1897,7 +1895,8 @@ void addGlobalMethods(py::module& m) { CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast(), &proto)); auto nn = caffe2::convertToNNModule(proto); - opt::OptimizeForMkldnn(&nn, gWorkspace, training_mode); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + opt::OptimizeForMkldnn(&nn, workspace, training_mode); auto new_proto = caffe2::convertToCaffe2Proto(nn, proto); std::string out; @@ -1919,12 +1918,13 @@ void addGlobalMethods(py::module& m) { }); m.def("transform_fuseConvBN", [](py::bytes def) { - CAFFE_ENFORCE(gWorkspace); + Workspace* workspace = caffe2::python::GetCurrentWorkspace(); + CAFFE_ENFORCE(workspace); caffe2::NetDef proto; CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast(), &proto)); auto nn = caffe2::convertToNNModule(proto); - opt::fuseConvBN(&nn, gWorkspace); + opt::fuseConvBN(&nn, workspace); auto new_proto = caffe2::convertToCaffe2Proto(nn); std::string out; @@ -1959,8 +1959,7 @@ void addGlobalMethods(py::module& m) { return; } // We will create a default workspace for us to run stuff. - switchWorkspaceInternal("default", true); - gCurrentWorkspaceName = "default"; + caffe2::python::SwitchWorkspaceInternal("default", true); initialized = true; }; diff --git a/caffe2/python/pybind_workspace.cc b/caffe2/python/pybind_workspace.cc new file mode 100644 index 0000000000000..aa837b7b4dfe9 --- /dev/null +++ b/caffe2/python/pybind_workspace.cc @@ -0,0 +1,72 @@ +#include "caffe2/core/workspace.h" + +namespace caffe2 { +namespace python { + +// gWorkspace is the pointer to the current workspace. The ownership is kept +// by the gWorkspaces map. +static Workspace* gWorkspace = nullptr; +static std::string gCurrentWorkspaceName; +// gWorkspaces allows us to define and switch between multiple workspaces in +// Python. +static std::map> gWorkspaces; + +Workspace* GetCurrentWorkspace() { + return gWorkspace; +} + +void SetCurrentWorkspace(Workspace* workspace) { + gWorkspace = workspace; +} + +Workspace* NewWorkspace() { + std::unique_ptr new_workspace(new Workspace()); + gWorkspace = new_workspace.get(); + return gWorkspace; +} + +Workspace* GetWorkspaceByName(const std::string& name) { + if (gWorkspaces.count(name)) { + return gWorkspaces[name].get(); + } + return nullptr; +} + +std::string GetCurrentWorkspaceName() { + return gCurrentWorkspaceName; +} +void InsertWorkspace(const std::string& name, std::unique_ptr ws) { + gWorkspaces.insert(std::make_pair(name, std::move(ws))); +} + +void SwitchWorkspaceInternal(const std::string& name, bool create_if_missing) { + if (gWorkspaces.count(name)) { + gCurrentWorkspaceName = name; + gWorkspace = gWorkspaces[name].get(); + return; + } + + CAFFE_ENFORCE(create_if_missing); + std::unique_ptr new_workspace(new Workspace()); + gWorkspace = new_workspace.get(); + gWorkspaces.insert(std::make_pair(name, std::move(new_workspace))); + gCurrentWorkspaceName = name; +} + +void ResetWorkspace(Workspace* workspace) { + gWorkspaces[gCurrentWorkspaceName].reset(workspace); + gWorkspace = gWorkspaces[gCurrentWorkspaceName].get(); +} + +void GetWorkspaceNames(std::vector& names) { + for (const auto& kv : gWorkspaces) { + // NOLINTNEXTLINE(performance-inefficient-vector-operation) + names.emplace_back(kv.first); + } +} + +void ClearWorkspaces() { + gWorkspaces.clear(); +} +} // namespace python +} // namespace caffe2 diff --git a/caffe2/python/pybind_workspace.h b/caffe2/python/pybind_workspace.h new file mode 100644 index 0000000000000..0467d9ff6ccd3 --- /dev/null +++ b/caffe2/python/pybind_workspace.h @@ -0,0 +1,15 @@ +namespace caffe2 { +namespace python { + +Workspace* GetCurrentWorkspace(); +void SetCurrentWorkspace(Workspace* workspace); +Workspace* NewWorkspace(); +Workspace* GetWorkspaceByName(const std::string& name); +std::string GetCurrentWorkspaceName(); +void InsertWorkspace(const std::string& name, std::unique_ptr ws); +void SwitchWorkspaceInternal(const std::string& name, bool create_if_missing); +void ResetWorkspace(Workspace* workspace); +void GetWorkspaceNames(std::vector& names); +void ClearWorkspaces(); +} // namespace python +} // namespace caffe2 diff --git a/caffe2/python/schema.py b/caffe2/python/schema.py index 60353ac38a256..295b79feadca7 100644 --- a/caffe2/python/schema.py +++ b/caffe2/python/schema.py @@ -546,8 +546,8 @@ def __getattr__(self, item): raise AttributeError(item) try: return super(Struct, self).__getattribute__("fields")[item] - except KeyError: - raise AttributeError(item) + except KeyError as e: + raise AttributeError(item) from e def __setattr__(self, key, value): # Disable setting attributes after initialization to prevent false diff --git a/caffe2/python/trt/transform.py b/caffe2/python/trt/transform.py index 0e304ca4fae30..aee27d6826fbd 100644 --- a/caffe2/python/trt/transform.py +++ b/caffe2/python/trt/transform.py @@ -29,8 +29,8 @@ def _get_output_shapes(output_value_infos): def check_gpu_(): try: C.get_cuda_version() - except Exception as _: - raise Exception("TensorRT related functions require CUDA support") + except Exception as e: + raise Exception("TensorRT related functions require CUDA support") from e def convert_onnx_model_to_trt_op(onnx_model, max_batch_size=64, diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py index 2e2d284f92e43..b434b5e748cc1 100644 --- a/caffe2/python/workspace_test.py +++ b/caffe2/python/workspace_test.py @@ -24,12 +24,6 @@ def setUp(self): ) workspace.ResetWorkspace() - def testRootFolder(self): - self.assertEqual(workspace.ResetWorkspace(), True) - self.assertEqual(workspace.RootFolder(), ".") - self.assertEqual(workspace.ResetWorkspace("/tmp/caffe-workspace-test"), True) - self.assertEqual(workspace.RootFolder(), "/tmp/caffe-workspace-test") - def testWorkspaceHasBlobWithNonexistingName(self): self.assertEqual(workspace.HasBlob("non-existing"), False) diff --git a/caffe2/quantization/server/README.md b/caffe2/quantization/server/README.md index 4819b62fedb77..b7d22bf8bbfe6 100644 --- a/caffe2/quantization/server/README.md +++ b/caffe2/quantization/server/README.md @@ -19,8 +19,8 @@ To compute the quantization parameters of activation tensors, we need to know th * Floating-point requantization -Unlike gemmlowp using fixed-point operations that emulates floating point operations of requantization, fbgemm just uses single-precison floating-point operations. This is because in x86 just using single-precision floating-point operations is faster. Probably, gemmlowp used pure fixed-point operations for low-end mobile processors. QNNPACK also has similar constraints as gemmlowp and provides multiple options of requantization implementations. -The users could modify the code to use a different requantization implementation to be bit-wise idential to the HW they want to emulate for example. If there're enough requests, we could consider implementing a few popular fixed-point requantization as QNNPACK did. +Unlike gemmlowp using fixed-point operations that emulates floating point operations of requantization, fbgemm just uses single-precision floating-point operations. This is because in x86 just using single-precision floating-point operations is faster. Probably, gemmlowp used pure fixed-point operations for low-end mobile processors. QNNPACK also has similar constraints as gemmlowp and provides multiple options of requantization implementations. +The users could modify the code to use a different requantization implementation to be bit-wise identical to the HW they want to emulate for example. If there're enough requests, we could consider implementing a few popular fixed-point requantization as QNNPACK did. * 16-bit accumulation with outlier-aware quantization diff --git a/caffe2/release-notes.md b/caffe2/release-notes.md index e76b760a7ed5e..d449e98f78e3d 100644 --- a/caffe2/release-notes.md +++ b/caffe2/release-notes.md @@ -133,7 +133,7 @@ If you're running this all on a cloud computer, you probably won't have a UI or First configure your cloud server to accept port 8889, or whatever you want, but change the port in the following commands. On AWS you accomplish this by adding a rule to your server's security group allowing a TCP inbound on port 8889. Otherwise you would adjust iptables for this. -Next you launch the Juypter server. +Next you launch the Jupyter server. ``` jupyter notebook --no-browser --port=8889 diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 9d3cc332ae96e..54b94d31775de 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -338,8 +338,7 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name) } PyTorchStreamWriter::PyTorchStreamWriter( - // NOLINTNEXTLINE(modernize-pass-by-value) - const std::function& writer_func) + const std::function writer_func) : archive_name_("archive"), writer_func_(writer_func) { setup(archive_name_); @@ -416,6 +415,21 @@ void PyTorchStreamWriter::writeRecord( } void PyTorchStreamWriter::writeEndOfFile() { + // Ensurers that finalized is set to true even + // exception is raised during the method call. + // I.e. even partial call to writeEndOfFile() should mark + // file as finalized, otherwise double exception raised from + // destructor would would result in `std::terminate()` + // See https://github.com/pytorch/pytorch/issues/87997/ + struct Finalizer { + Finalizer(bool& var): var_(var) {} + ~Finalizer() { + var_ = true; + } + private: + bool& var_; + } f(finalized_); + auto allRecords = getAllWrittenRecords(); // If no ".data/version" or "version" record in the output model, rewrites version info if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) { diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 621ffbe9a41ab..3f0e661dd229f 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -130,7 +130,7 @@ class TORCH_API PyTorchStreamWriter final { public: explicit PyTorchStreamWriter(std::string archive_name); explicit PyTorchStreamWriter( - const std::function& writer_func); + const std::function writer_func); void setMinVersion(const uint64_t version); diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu index b41d2590e9296..d1911ae4db4c7 100644 --- a/caffe2/utils/math/elementwise.cu +++ b/caffe2/utils/math/elementwise.cu @@ -305,7 +305,7 @@ CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(float) return; \ } \ if (alpha == T(0)) { \ - cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream()); \ + C10_CUDA_CHECK(cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream())); \ } else { \ thrust::fill( \ thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha); \ diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu index 69a6469d8ed15..d59cbd387753e 100644 --- a/caffe2/utils/math/reduce.cu +++ b/caffe2/utils/math/reduce.cu @@ -418,12 +418,12 @@ void MomentsCUDA( return; } if (std::equal(X_dims, X_dims + ndim, Y_dims)) { - cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( mean, X, sizeof(T) * X_size, cudaMemcpyDeviceToDevice, - context->cuda_stream()); + context->cuda_stream())); Set(Y_size, T(0), var, context); return; } diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index 54b0a9391c263..4ad249dadc7e4 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -2685,12 +2685,12 @@ CAFFE2_CUDA_EXPORT void CopyVector( float* dst, CUDAContext* context) { if (src != dst && N > 0) { - cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( dst, src, sizeof(float) * N, cudaMemcpyDeviceToDevice, - context->cuda_stream()); + context->cuda_stream())); } } @@ -2701,12 +2701,12 @@ CAFFE2_CUDA_EXPORT void CopyVector( int* dst, CUDAContext* context) { if (src != dst && N > 0) { - cudaMemcpyAsync( + C10_CUDA_CHECK(cudaMemcpyAsync( dst, src, sizeof(int) * N, cudaMemcpyDeviceToDevice, - context->cuda_stream()); + context->cuda_stream())); } } diff --git a/caffe2/utils/threadpool/ThreadPool.cc b/caffe2/utils/threadpool/ThreadPool.cc index cbccf0749bef1..79fc279f3591b 100644 --- a/caffe2/utils/threadpool/ThreadPool.cc +++ b/caffe2/utils/threadpool/ThreadPool.cc @@ -103,12 +103,13 @@ size_t getDefaultNumThreads() { /* * For llvm-tsan, holding limit for the number of locks for a single thread - * is 64. pthreadpool's worst case is the number of threads in a pool. So we - * want to limit the threadpool size to 64 when running with tsan. However, - * sometimes it is tricky to detect if we are running under tsan, for now - * capping the default threadcount to the tsan limit unconditionally. + * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst + * case is the number of threads in a pool. So we want to limit the threadpool + * size to 64 when running with tsan. However, sometimes it is tricky to + * detect if we are running under tsan, for now capping the default + * threadcount to the tsan limit unconditionally. */ - int tsanThreadLimit = 64; + int tsanThreadLimit = 63; numThreads = std::min(numThreads, tsanThreadLimit); return numThreads; diff --git a/caffe2/video/video_decoder.cc b/caffe2/video/video_decoder.cc index 8993241d39dc5..86bfbfa5ad2a0 100644 --- a/caffe2/video/video_decoder.cc +++ b/caffe2/video/video_decoder.cc @@ -606,7 +606,7 @@ void VideoDecoder::decodeLoop( unique_ptr frame = make_unique(); frame->width_ = outWidth; frame->height_ = outHeight; - frame->data_ = move(buffer); + frame->data_ = std::move(buffer); frame->size_ = size; frame->index_ = frameIndex; frame->outputFrameIndex_ = outputFrameIndex; @@ -735,10 +735,10 @@ bool DecodeMultipleClipsFromVideo( } for (auto& frame : callback.frames) { - sampledFrames.push_back(move(frame)); + sampledFrames.push_back(std::move(frame)); } for (auto& audio_sample : callback.audio_samples) { - sampledAudio.push_back(move(audio_sample)); + sampledAudio.push_back(std::move(audio_sample)); } for (int i = 0; i < buffer_rgb.size(); i++) { diff --git a/caffe2/video/video_decoder.h b/caffe2/video/video_decoder.h index a091142389d63..ba607fd8da3f0 100644 --- a/caffe2/video/video_decoder.h +++ b/caffe2/video/video_decoder.h @@ -508,11 +508,11 @@ class CallbackImpl : public Callback { } void frameDecoded(std::unique_ptr frame) override { - frames.push_back(move(frame)); + frames.push_back(std::move(frame)); } void audioDecoded(std::unique_ptr audio_sample) override { - audio_samples.push_back(move(audio_sample)); + audio_samples.push_back(std::move(audio_sample)); } void videoDecodingStarted(const VideoMeta& /*videoMeta*/) override { diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 47f5be14ed9a6..8faeb401017b8 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1248,6 +1248,16 @@ if(ANDROID) list(APPEND Caffe2_DEPENDENCY_LIBS log) endif() +# ---[ Kernel asserts +# Kernel asserts are enabled by default for CUDA and disabled for ROCm. +# For ROCm, it can be enabled by setting ROCM_FORCE_ENABLE_GPU_ASSERTS +if(USE_ROCM AND ROCM_FORCE_ENABLE_GPU_ASSERTS) + message(STATUS "Forcefully enabling kernel asserts on ROCM") +elseif(USE_ROCM AND NOT ROCM_FORCE_ENABLE_GPU_ASSERTS) + message(STATUS "Disabling kernel asserts for ROCm") + caffe2_update_option(TORCH_DISABLE_GPU_ASSERTS ON) +endif() + # ---[ LLVM if(USE_LLVM) message(STATUS "Looking for LLVM in ${USE_LLVM}") @@ -1270,6 +1280,21 @@ endif() # ---[ HIP if(USE_ROCM) + # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo. + # Currently only active for Ubuntu 20.04 and greater versions. + if(UNIX AND EXISTS "/etc/os-release") + file(STRINGS /etc/os-release OS_RELEASE) + string(REGEX REPLACE "NAME=\"([A-Za-z]+).*" "\\1" OS_DISTRO ${OS_RELEASE}) + string(REGEX REPLACE ".*VERSION_ID=\"([0-9\.]+).*" "\\1" OS_VERSION ${OS_RELEASE}) + if(OS_DISTRO STREQUAL "Ubuntu" AND OS_VERSION VERSION_GREATER_EQUAL "20.04") + find_library(LIBTINFO_LOC tinfo NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH) + if(LIBTINFO_LOC) + get_filename_component(LIBTINFO_LOC_PARENT ${LIBTINFO_LOC} DIRECTORY) + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath-link,${LIBTINFO_LOC_PARENT}") + endif() + endif() + endif() + include(${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake) if(PYTORCH_FOUND_HIP) message(INFO "Compiling with HIP for AMD.") @@ -1296,7 +1321,7 @@ if(USE_ROCM) list(APPEND HIP_CXX_FLAGS -Wno-implicit-int-float-conversion) list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN) list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP) - list(APPEND HIP_CXX_FLAGS -std=c++14) + list(APPEND HIP_CXX_FLAGS -std=c++17) add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT}) add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION}) message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines") @@ -1560,6 +1585,9 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) add_subdirectory("${CMAKE_CURRENT_LIST_DIR}/../caffe2/onnx/torch_ops") if(NOT USE_SYSTEM_ONNX) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx EXCLUDE_FROM_ALL) + if(NOT MSVC) + set_target_properties(onnx_proto PROPERTIES CXX_STANDARD 17) + endif() endif() add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/foxi EXCLUDE_FROM_ALL) @@ -1662,7 +1690,7 @@ if(NOT INTERN_BUILD_MOBILE) string(APPEND CMAKE_CUDA_FLAGS " -Wno-deprecated-gpu-targets --expt-extended-lambda") if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") + set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") endif() # use cub in a safe manner, see: @@ -1927,6 +1955,7 @@ if(USE_KINETO) find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS ${CUDA_SOURCE_DIR} ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64 + ${CUDA_SOURCE_DIR}/lib ${CUDA_SOURCE_DIR}/lib64 NO_DEFAULT_PATH) @@ -1989,6 +2018,11 @@ if(USE_KINETO) string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO") if(LIBKINETO_NOCUPTI) string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI") + endif() + if(LIBKINETO_NOROCTRACER) + string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOROCTRACER") + endif() + if(LIBKINETO_NOCUPTI AND LIBKINETO_NOROCTRACER) message(STATUS "Configured Kineto (CPU)") else() message(STATUS "Configured Kineto") diff --git a/cmake/External/nccl.cmake b/cmake/External/nccl.cmake index cb928baf3a595..160d2b648c051 100644 --- a/cmake/External/nccl.cmake +++ b/cmake/External/nccl.cmake @@ -15,23 +15,24 @@ if(NOT __NCCL_INCLUDED) # this second replacement is needed when there are multiple archs string(REPLACE ";-gencode" " -gencode" NVCC_GENCODE "${NVCC_GENCODE}") - if("${CMAKE_GENERATOR}" MATCHES "Make") - # Recursive make with jobserver for parallelism - set(MAKE_COMMAND "$(MAKE)") + if(DEFINED ENV{MAX_JOBS}) + set(MAX_JOBS "$ENV{MAX_JOBS}") else() - if(DEFINED ENV{MAX_JOBS}) - set(MAX_JOBS "$ENV{MAX_JOBS}") - else() - include(ProcessorCount) - ProcessorCount(NUM_HARDWARE_THREADS) - # Assume 2 hardware threads per cpu core - math(EXPR MAX_JOBS "${NUM_HARDWARE_THREADS} / 2") - # ProcessorCount might return 0, set to a positive number - if(MAX_JOBS LESS 2) - set(MAX_JOBS 2) - endif() + include(ProcessorCount) + ProcessorCount(NUM_HARDWARE_THREADS) + # Assume 2 hardware threads per cpu core + math(EXPR MAX_JOBS "${NUM_HARDWARE_THREADS} / 2") + # ProcessorCount might return 0, set to a positive number + if(MAX_JOBS LESS 2) + set(MAX_JOBS 2) endif() + endif() + if("${CMAKE_GENERATOR}" MATCHES "Make") + # Recursive make with jobserver for parallelism, and also put a load limit + # here to avoid flaky OOM, https://www.gnu.org/software/make/manual/html_node/Parallel.html + set(MAKE_COMMAND "$(MAKE)" "-l${MAX_JOBS}") + else() # Parallel build with CPU load limit to avoid oversubscription set(MAKE_COMMAND "make" "-j${MAX_JOBS}" "-l${MAX_JOBS}") endif() diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index e2f427be67c89..30ac5401ddf32 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -76,6 +76,8 @@ IF(NOT MKLDNN_FOUND) SET(DNNL_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE) SET(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) + SET(DNNL_GRAPH_CPU_RUNTIME ${MKLDNN_CPU_RUNTIME} CACHE STRING "" FORCE) + IF(BUILD_ONEDNN_GRAPH) SET(DNNL_GRAPH_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) ENDIF(BUILD_ONEDNN_GRAPH) diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index 7f22d476d2fbe..65e7a6ac8993c 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -94,12 +94,28 @@ if(CUDA_VERSION VERSION_GREATER "10.5") endif() if(NOT CUDA_VERSION VERSION_LESS "11.1") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6") set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6") + if(CUDA_VERSION VERSION_LESS "11.8") + set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX") + endif() +endif() + +if(NOT CUDA_VERSION VERSION_LESS "11.8") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0") + if(CUDA_VERSION VERSION_LESS "12.0") set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX") endif() endif() @@ -237,6 +253,12 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) elseif(${arch_name} STREQUAL "Ampere") set(arch_bin 8.0) set(arch_ptx 8.0) + elseif(${arch_name} STREQUAL "Ada") + set(arch_bin 8.9) + set(arch_ptx 8.9) + elseif(${arch_name} STREQUAL "Hopper") + set(arch_bin 9.0) + set(arch_ptx 9.0) else() message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") endif() diff --git a/cmake/ProtoBufPatch.cmake b/cmake/ProtoBufPatch.cmake index 7f1de9a4a1de9..42696a0a068fc 100644 --- a/cmake/ProtoBufPatch.cmake +++ b/cmake/ProtoBufPatch.cmake @@ -31,12 +31,14 @@ if(NOT SYSTEM_PROTOBUF) # https://github.com/protocolbuffers/protobuf/commit/0400cca3236de1ca303af38bf81eab332d042b7c # changes PROTOBUF_CONSTEXPR to constexpr, which breaks windows # build. - string( - REGEX REPLACE - "static constexpr ([^ ]+) ([^ ]+) =" - "static \\1 const \\2 =" - content - "${content}") + if(MSVC) + string( + REGEX REPLACE + "static constexpr ([^ ]+) ([^ ]+) =" + "static \\1 const \\2 =" + content + "${content}") + endif() foreach(ns ${NAMESPACES}) # Insert "const ::std::string& GetEmptyStringAlreadyInited();" within diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index fd6444680e2d4..279d72a41e660 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -199,4 +199,5 @@ function(caffe2_print_configuration_summary) # coreml message(STATUS " USE_COREML_DELEGATE : ${USE_COREML_DELEGATE}") message(STATUS " BUILD_LAZY_TS_BACKEND : ${BUILD_LAZY_TS_BACKEND}") + message(STATUS " TORCH_DISABLE_GPU_ASSERTS : ${TORCH_DISABLE_GPU_ASSERTS}") endfunction() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 89a61b6242856..b51284115f144 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -143,9 +143,6 @@ message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}") # Add HIP to the CMAKE Module Path set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) -#Disable kernel assert due to performance regression -set(ROCM_ENABLE_KERNEL_ASSERTS FALSE CACHE BOOL "Kernel asserts are disabled by default for ROCm") - macro(find_package_and_print_version PACKAGE_NAME) find_package("${PACKAGE_NAME}" ${ARGN}) message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") @@ -286,19 +283,6 @@ if(HIP_FOUND) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) - if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0") - if(ROCM_ENABLE_KERNEL_ASSERTS) - message("ROCm version >= 4.1; enabling asserts") - else() - add_definitions(-DROCM_DISABLE_GPU_ASSERTS) - message("ROCm version >= 4.1; kernel asserts are disabled") - endif() - else() - # Disable Asserts In Code (Can't use asserts on HIP stack.) - add_definitions(-DNDEBUG) - message("ROCm version < 4.1; disablng asserts") - endif() - if(HIP_COMPILER STREQUAL clang) set(hip_library_name amdhip64) else() diff --git a/cmake/public/mkl.cmake b/cmake/public/mkl.cmake index 9515a4ae96813..f4ab1ffa9d0fe 100644 --- a/cmake/public/mkl.cmake +++ b/cmake/public/mkl.cmake @@ -9,4 +9,9 @@ set_property( ${MKL_INCLUDE_DIR}) set_property( TARGET caffe2::mkl PROPERTY INTERFACE_LINK_LIBRARIES - ${MKL_LIBRARIES}) + ${MKL_LIBRARIES} ${MKL_THREAD_LIB}) +# TODO: This is a hack, it will not pick up architecture dependent +# MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008 +set_property( + TARGET caffe2::mkl PROPERTY INTERFACE_LINK_DIRECTORIES + ${MKL_ROOT}/lib) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 5944a5a1a6269..9ad0a2f96f88f 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -407,7 +407,7 @@ endmacro() # Usage: # torch_compile_options(lib_name) function(torch_compile_options libname) - set_property(TARGET ${libname} PROPERTY CXX_STANDARD 14) + set_property(TARGET ${libname} PROPERTY CXX_STANDARD 17) set(private_compile_options "") # ---[ Check if warnings should be errors. diff --git a/docker.Makefile b/docker.Makefile index 0768f6ecf6ed8..f85a3c3a3fc15 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -8,7 +8,7 @@ $(warning WARNING: No docker user found using results from whoami) DOCKER_ORG = $(shell whoami) endif -CUDA_VERSION = 11.3.1 +CUDA_VERSION = 11.6.2 CUDNN_VERSION = 8 BASE_RUNTIME = ubuntu:18.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-ubuntu18.04 @@ -18,17 +18,20 @@ CUDA_CHANNEL = nvidia # The conda channel to use to install pytorch / torchvision INSTALL_CHANNEL ?= pytorch -PYTHON_VERSION ?= 3.8 +PYTHON_VERSION ?= 3.10 PYTORCH_VERSION ?= $(shell git describe --tags --always) # Can be either official / dev BUILD_TYPE ?= dev BUILD_PROGRESS ?= auto +# Intentionally left blank +TRITON_VERSION ?= BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ --build-arg CUDA_VERSION=$(CUDA_VERSION) \ --build-arg CUDA_CHANNEL=$(CUDA_CHANNEL) \ --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ - --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) + --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) \ + --build-arg TRITON_VERSION=$(TRITON_VERSION) EXTRA_DOCKER_BUILD_FLAGS ?= BUILD ?= build diff --git a/docs/Makefile b/docs/Makefile index 122bda6231e39..c506845fa92bc 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,8 +17,9 @@ figures: @$(PYCMD) source/scripts/build_activation_images.py @$(PYCMD) source/scripts/build_quantization_configs.py -onnx_supported_aten_ops: +onnx: @$(PYCMD) source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py + @$(PYCMD) source/scripts/onnx/build_onnx_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_diagnostics_rules docset: html doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/docs/ --force $(BUILDDIR)/html/ @@ -34,11 +35,11 @@ html-stable: # See conf.py for more details. RELEASE=1 make html -.PHONY: help Makefile docset onnx_supported_aten_ops +.PHONY: help Makefile docset onnx # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile figures onnx_supported_aten_ops +%: Makefile figures onnx @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) clean: diff --git a/docs/caffe2/.Doxyfile-c b/docs/caffe2/.Doxyfile-c index c4873d63841ca..b30ab661d24cb 100644 --- a/docs/caffe2/.Doxyfile-c +++ b/docs/caffe2/.Doxyfile-c @@ -1490,7 +1490,7 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images +# Use the FORMULA_TRANSPARENT tag to determine whether or not the images # generated for formulas are transparent PNGs. Transparent PNGs are not # supported properly for IE 6.0, but are supported on all modern browsers. # diff --git a/docs/caffe2/.Doxyfile-python b/docs/caffe2/.Doxyfile-python index 9d16671ffe3ba..514e580363996 100644 --- a/docs/caffe2/.Doxyfile-python +++ b/docs/caffe2/.Doxyfile-python @@ -1488,7 +1488,7 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images +# Use the FORMULA_TRANSPARENT tag to determine whether or not the images # generated for formulas are transparent PNGs. Transparent PNGs are not # supported properly for IE 6.0, but are supported on all modern browsers. # diff --git a/docs/cpp/source/notes/tensor_cuda_stream.rst b/docs/cpp/source/notes/tensor_cuda_stream.rst index b80615e8f7f10..4940317713635 100644 --- a/docs/cpp/source/notes/tensor_cuda_stream.rst +++ b/docs/cpp/source/notes/tensor_cuda_stream.rst @@ -144,7 +144,7 @@ CUDA Stream Usage Examples // sum() on tensor0 use `myStream0` as current CUDA stream on device 0 tensor0.sum(); - // change the current device index to 1 by using CUDA device guard within a braket scope + // change the current device index to 1 by using CUDA device guard within a bracket scope { at::cuda::CUDAGuard device_guard{1}; // create a tensor on device 1 @@ -206,7 +206,7 @@ CUDA Stream Usage Examples // sum() on tensor0 uses default CUDA stream as current CUDA stream on device 0 tensor0.sum(); - // sum() on tensor1 uses defualt CUDA stream as current CUDA stream on device 1 + // sum() on tensor1 uses default CUDA stream as current CUDA stream on device 1 tensor1.sum(); .. attention:: diff --git a/docs/requirements.txt b/docs/requirements.txt index 14c93adc22e90..fdbe10778bf98 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,3 +10,4 @@ tensorboard==2.10.0 python-etcd==0.4.5 sphinx-copybutton==0.5.0 sphinx-panels==0.4.1 +myst-parser==0.18.1 diff --git a/docs/source/_dynamo.rst b/docs/source/_dynamo.rst new file mode 100644 index 0000000000000..5e16dcf52ddee --- /dev/null +++ b/docs/source/_dynamo.rst @@ -0,0 +1,13 @@ +.. _torch_dynamo: + +torch._dynamo +-------------------------- + +.. warning :: + This module is an early prototype and is subject to change. + +.. currentmodule:: torch._dynamo + +.. automodule:: torch._dynamo + :members: + :member-order: bysource diff --git a/docs/source/_static/img/dynamo/TorchDynamo.png b/docs/source/_static/img/dynamo/TorchDynamo.png new file mode 100644 index 0000000000000..351689d80dc92 Binary files /dev/null and b/docs/source/_static/img/dynamo/TorchDynamo.png differ diff --git a/docs/source/_static/img/dynamo/td_stack.png b/docs/source/_static/img/dynamo/td_stack.png new file mode 100644 index 0000000000000..d20b3250453c5 Binary files /dev/null and b/docs/source/_static/img/dynamo/td_stack.png differ diff --git a/docs/source/_static/img/dynamo/torchinductor_backend.png b/docs/source/_static/img/dynamo/torchinductor_backend.png new file mode 100644 index 0000000000000..84e37aa7c4b63 Binary files /dev/null and b/docs/source/_static/img/dynamo/torchinductor_backend.png differ diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 31eaa85e05020..2a02b325341fb 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -52,8 +52,14 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.preferred_linalg_library +.. autoclass:: torch.backends.cuda.SDPBackend + .. autofunction:: torch.backends.cuda.flash_sdp_enabled +.. autofunction:: torch.backends.cuda.enable_mem_efficient_sdp + +.. autofunction:: torch.backends.cuda.mem_efficient_sdp_enabled + .. autofunction:: torch.backends.cuda.enable_flash_sdp .. autofunction:: torch.backends.cuda.math_sdp_enabled diff --git a/docs/source/community/contribution_guide.rst b/docs/source/community/contribution_guide.rst index a2a89721b64e2..30bd9c6cf9751 100644 --- a/docs/source/community/contribution_guide.rst +++ b/docs/source/community/contribution_guide.rst @@ -138,7 +138,7 @@ A great deal of the tutorials on `pytorch.org `__ come from the community itself and we welcome additional contributions. To learn more about how to contribute a new tutorial you can learn more here: `PyTorch.org Tutorial Contribution Guide on -Github `__ +GitHub `__ Improving Documentation & Tutorials ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index d011250d490d0..c6fc75b865f0c 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -7,7 +7,7 @@ Responsibilities * Triage and fix high priority issues assigned to the module or library * Triage, review, and land high priority pull requests assigned to the module or library * Answer module or library questions on `discuss.pytorch.org `__ - and `dev-discuss.pytorch.org `__ + and `dev-discuss.pytorch.org `__ * Maintain public user and development documentation * Run meetings and share minutes plus roadmap on a half or quarterly basis @@ -116,6 +116,22 @@ Sparse (torch.sparse) - Christian Puhrsch (`cpuhrsch `__) - Andrew James (`amjames `__) +NestedTensor (torch.nested) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Alban Desmaison (`albanD `__) +- Christian Puhrsch (`cpuhrsch `__) +- Driss Guessous (`drisspg `__) +- Joel Schlosser (`jbschlosser `__) +- Mikayla Gawarecki (`mikaylagawarecki `__) +- Natalia Gimelshein (`ngimel `__) + +MaskedTensor (torch.masked) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Christian Puhrsch (`cpuhrsch `__) +- (emeritus) George Qi (`george-qi `__) + Fast Fourier Transform (torch.fft) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -307,6 +323,11 @@ TorchAudio ~~~~~~~~~~ - Moto Hira (`mthrok `__) +- Jeff Hwang (`hwangjeff `__) +- Caroline Chen (`carolineechen `__) +- Xiaohui Zhang (`xiaohui-zhang `__) +- Zhaoheng Ni (`nateanl `__) +- (emeritus) Christian Puhrsch (`cpuhrsch `__) - (emeritus) Vincent QB (`vincentqb `__) TorchRec diff --git a/docs/source/conf.py b/docs/source/conf.py index 8c0eac82cf996..f4d1d8b68eb92 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -58,7 +58,8 @@ 'sphinxcontrib.katex', 'sphinx.ext.autosectionlabel', 'sphinx_copybutton', - 'sphinx_panels' + 'sphinx_panels', + 'myst_parser', ] # build the templated autosummary files @@ -335,8 +336,8 @@ "Quantize", # torch.utils.backcompat "Warning", - "SymIntNode", - "SymFloatNode", + "SymInt", + "SymFloat", ] # The suffix(es) of source filenames. diff --git a/docs/source/cuda._sanitizer.rst b/docs/source/cuda._sanitizer.rst index 097d26a324f12..658b975693112 100644 --- a/docs/source/cuda._sanitizer.rst +++ b/docs/source/cuda._sanitizer.rst @@ -29,7 +29,7 @@ Here is an example of a simple synchronization error in PyTorch: The ``a`` tensor is initialized on the default stream and, without any synchronization methods, modified on a new stream. The two kernels will run concurrently on the same tensor, -which might cause the second kernel to read unitialized data before the first one was able +which might cause the second kernel to read uninitialized data before the first one was able to write it, or the first kernel might overwrite part of the result of the second. When this script is run on the commandline with: :: diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 601bb078752fd..b14e5cec360db 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -114,6 +114,8 @@ Memory management caching_allocator_alloc caching_allocator_delete get_allocator_backend + CUDAPluggableAllocator + change_current_allocator .. FIXME The following doesn't seem to exist. Is it supposed to? https://github.com/pytorch/pytorch/issues/27785 .. autofunction:: reset_max_memory_reserved diff --git a/docs/source/data.rst b/docs/source/data.rst index db6957c8da787..b44096d101964 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -65,7 +65,7 @@ in real time. See :class:`~torch.utils.data.IterableDataset` for more details. -.. note:: When using an :class:`~torch.utils.data.IterableDataset` with +.. note:: When using a :class:`~torch.utils.data.IterableDataset` with `multi-process data loading `_. The same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See @@ -441,9 +441,6 @@ Example:: .. autoclass:: torch.utils.data.distributed.DistributedSampler -.. This module is experimental and should be private, adding it here for now -.. py:module:: torch.utils.data.communication - .. These modules are documented as part of torch/data listing them here for .. now until we have a clearer fix .. py:module:: torch.utils.data.datapipes diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst new file mode 100644 index 0000000000000..380ec0e6022a4 --- /dev/null +++ b/docs/source/distributed.checkpoint.rst @@ -0,0 +1,4 @@ +Distributed Checkpoint +======================== + +.. automodule:: torch.distributed.checkpoint diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 8b1186fb4ceec..777e8f5a2085f 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -190,6 +190,8 @@ joined. .. autofunction:: is_nccl_available +.. autofunction:: is_gloo_available + .. autofunction:: is_torchelastic_launched -------------------------------------------------------------------------------- @@ -331,6 +333,12 @@ an opaque group handle that can be given as a ``group`` argument to all collecti .. autofunction:: new_group +.. autofunction:: get_group_rank + +.. autofunction:: get_global_rank + +.. autofunction:: get_process_group_ranks + Point-to-point communication ---------------------------- @@ -350,6 +358,10 @@ as they should never be created manually, but they are guaranteed to support two .. autofunction:: irecv +.. autofunction:: batch_isend_irecv + +.. autoclass:: P2POp + Synchronous and asynchronous collective operations -------------------------------------------------- Every collective operation function supports the following two kinds of operations, @@ -433,6 +445,8 @@ Collective functions .. autofunction:: reduce_scatter_tensor +.. autofunction:: all_to_all_single + .. autofunction:: all_to_all .. autofunction:: barrier @@ -828,6 +842,13 @@ following matrix shows how the log level can be adjusted via the combination of | ``INFO`` | ``DETAIL`` | Trace (a.k.a. All) | +-------------------------+-----------------------------+------------------------+ +Distributed has a custom Exception type derived from `RuntimeError` called `torch.distributed.DistBackendError`. This exception is thrown when a backend-specific error occurs. For example, if +the `NCCL` backend is used and the user attempts to use a GPU that is not available to the `NCCL` library. + +.. autoclass:: torch.distributed.DistBackendError + +.. warning:: + The DistBackendError exception type is an experimental feature is subject to change. .. Distributed modules that are missing specific entries. .. Adding them here for tracking purposes until they are more permanently fixed. @@ -845,3 +866,4 @@ following matrix shows how the log level can be adjusted via the combination of .. py:module:: torch.distributed.pipeline .. py:module:: torch.distributed.pipeline.sync .. py:module:: torch.distributed.pipeline.sync.skip +.. py:module:: torch.distributed.tensor diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst new file mode 100644 index 0000000000000..64544539edd43 --- /dev/null +++ b/docs/source/distributed.tensor.parallel.rst @@ -0,0 +1,7 @@ +.. role:: hidden + :class: hidden-section + +Tensor Parallelism +======================== +.. py:module:: torch.distributed.tensor.parallel +.. currentmodule:: torch.distributed.tensor.parallel diff --git a/docs/source/dynamo/custom-backends.rst b/docs/source/dynamo/custom-backends.rst new file mode 100644 index 0000000000000..7322fceb51815 --- /dev/null +++ b/docs/source/dynamo/custom-backends.rst @@ -0,0 +1,157 @@ +Custom Backends +=============== + +Debugging Backend +----------------- + +If you want to better understand what is going on during a +compilation, you can create a custom compiler, which is referred to as +backend in this section, that will print pretty print the fx +``GraphModule`` extracted from Dynamo’s bytecode analysis +and return a ``forward()`` callable. + +For example: + +.. code-block:: python + + from typing import List + import torch + import torch._dynamo as dynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @dynamo.optimize(my_compiler) + def fn(x, y): + a = torch.cos(x) + b = torch.sin(y) + return a + b + fn(torch.randn(10), torch.randn(10)) + +Running the above example produces the following output: + +:: + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ------------------------------------------------------ ---------- -------- + placeholder x x () {} + placeholder y y () {} + call_function cos (x,) {} + call_function sin (y,) {} + call_function add (cos, sin) {} + output output output ((add,),) {} + +This works for ``torch.nn.Module`` as well as shown below: + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + def forward(self, x): + return self.relu(torch.cos(x)) + mod = MockModule() + optimized_mod = dynamo.optimize(my_compiler)(mod) + optimized_mod(torch.randn(10)) + +Let’s take a look at one more example with control flow: + +.. code-block:: python + + from typing import List + import torch + import torch._dynamo as dynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @dynamo.optimize(my_compiler) + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) + +Running this example produces the following output: + +:: + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------- ------------------------------------------------------ ---------------- -------- + placeholder a a () {} + placeholder b b () {} + call_function abs_1 (a,) {} + call_function add (abs_1, 1) {} + call_function truediv (a, add) {} + call_method sum_1 sum (b,) {} + call_function lt (sum_1, 0) {} + output output output ((truediv, lt),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- ----------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (b, -1) {} + call_function mul_1 (x, mul) {} + output output output ((mul_1,),) {} + + my_compiler() called with FX graph: + opcode name target args kwargs + ------------- ------ ----------------------- --------- -------- + placeholder b b () {} + placeholder x x () {} + call_function mul (x, b) {} + output output output ((mul,),) {} + +The order of the last two graphs is nondeterministic depending +on which one is encountered first by the just-in-time compiler. + +Speedy Backend +-------------- + +Integrating a custom backend that offers superior performance is also +easy and we’ll integrate a real one +with `optimize_for_inference `__: + +.. code-block:: python + + def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + scripted = torch.jit.trace(gm, example_inputs) + return torch.jit.optimize_for_inference(scripted) + +And then you should be able to optimize any existing code with: + +.. code-block:: python + + @dynamo.optimize(optimize_for_inference_compiler) + def code_to_accelerate(): + ... + +Composable Backends +------------------- + +TorchDynamo includes many backends, which can be found in +`backends.py `__ +or ``torchdynamo.list_backends()``. You can combine these backends +together with the following code: + +.. code-block:: python + + from torch._dynamo.optimizations import BACKENDS + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + trt_compiled = BACKENDS["tensorrt"](gm, example_inputs) + if trt_compiled is not None: + return trt_compiled + # first backend failed, try something else... + cudagraphs_compiled = BACKENDS["cudagraphs"](gm, example_inputs) + if cudagraphs_compiled is not None: + return cudagraphs_compiled + return gm.forward diff --git a/docs/source/dynamo/deep-dive.rst b/docs/source/dynamo/deep-dive.rst new file mode 100644 index 0000000000000..468fdc6ff9467 --- /dev/null +++ b/docs/source/dynamo/deep-dive.rst @@ -0,0 +1,145 @@ +TorchDynamo Deeper Dive +======================= +**Author**: `Jason Ansel `_ + +What is a guard? +---------------- + +TorchDynamo operates just-in-time and specializes graphs based on +dynamic properties. For example, the first graph above has the following +guards: + +:: + + GUARDS: + - local 'a' TENSOR_MATCH + - local 'b' TENSOR_MATCH + - global 'torch' FUNCTION_MATCH + +If any of those guards fail, the graph will be recaptured and +recompiled. The interesting guard type there is ``TENSOR_MATCH``, which +checks the following ``torch.Tensor`` properties: + +- Python class of the tensor (tensor subclassing, etc) +- dtype +- device +- requires_grad +- dispatch_key (with thread-local includes/excludes applied) +- ndim +- sizes\* (optional) +- strides\* (optional) + +For sizes/strides you can disable this specialization by setting the +following parameter: + +.. code-block:: python + + torch._dynamo.config.dynamic_shapes = True + +The full specialization mode allows the backend compiler to assume an +entirely static graph. Unfortunately, most backends require this. +Operators which return dynamic shapes will trigger a graph break when +not in dynamic shape mode. + +What is Dynamo doing? +--------------------- + +If you want to understand better what TorchDynamo is doing, you can set: + +.. code-block:: python + + torchdynamo.config.debug = True + +This code triggers useful (but spammy) printouts. + +For example, the printouts for the first graph in the ``toy_example`` +are: + +:: + + __compiled_fn_0 .1 + opcode name target args kwargs + ------------- ------- ------------------------------------------------------ ---------------- -------- + placeholder a a () {} + placeholder b b () {} + call_function abs_1 (a,) {} + call_function add (abs_1, 1) {} + call_function truediv (a, add) {} + call_method sum_1 sum (b,) {} + call_function lt (sum_1, 0) {} + output output output ((truediv, lt),) {} + + ORIGINAL BYTECODE toy_example example.py 9 + 10 0 LOAD_FAST 0 (a) + 2 LOAD_GLOBAL 0 (torch) + 4 LOAD_METHOD 1 (abs) + 6 LOAD_FAST 0 (a) + 8 CALL_METHOD 1 + 10 LOAD_CONST 1 (1) + 12 BINARY_ADD + 14 BINARY_TRUE_DIVIDE + 16 STORE_FAST 2 (x) + + 11 18 LOAD_FAST 1 (b) + 20 LOAD_METHOD 2 (sum) + 22 CALL_METHOD 0 + 24 LOAD_CONST 2 (0) + 26 COMPARE_OP 0 (<) + 28 POP_JUMP_IF_FALSE 38 + + 12 30 LOAD_FAST 1 (b) + 32 LOAD_CONST 3 (-1) + 34 BINARY_MULTIPLY + 36 STORE_FAST 1 (b) + + 13 >> 38 LOAD_FAST 2 (x) + 40 LOAD_FAST 1 (b) + 42 BINARY_MULTIPLY + 44 RETURN_VALUE + + MODIFIED BYTECODE + 9 0 LOAD_GLOBAL 3 (__compiled_fn_0) + 2 LOAD_FAST 0 (a) + 4 LOAD_FAST 1 (b) + 6 CALL_FUNCTION 2 + 8 UNPACK_SEQUENCE 2 + 10 STORE_FAST 2 (x) + 12 POP_JUMP_IF_FALSE 24 + 14 LOAD_GLOBAL 4 (__resume_at_30_1) + 16 LOAD_FAST 1 (b) + 18 LOAD_FAST 2 (x) + 20 CALL_FUNCTION 2 + 22 RETURN_VALUE + >> 24 LOAD_GLOBAL 5 (__resume_at_38_2) + 26 LOAD_FAST 1 (b) + 28 LOAD_FAST 2 (x) + 30 CALL_FUNCTION 2 + 32 RETURN_VALUE + + GUARDS: + - local 'a' TENSOR_MATCH + - local 'b' TENSOR_MATCH + - global 'torch' FUNCTION_MATCH + +At the top you can see the FX graph. +Next, you see the original bytecode of the function, followed by the +modified bytecode generated by TorchDynamo. Finally, you see the guards +which we covered above. + +In the modified bytecode, ``__compiled_fn_0`` is the return value of +``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and +``__resume_at_38_2`` are both generated continuation functions that pick +up execution after a graph break (at bytecode offsets 30 and 38). Each +of these functions take the form: + +:: + + __resume_at_: + ... restore stack state if needed ... + JUMP_ABSOLUTE into toy_example + ... original bytecode of toy_example ... + +By generating this `resume_at` function, we force the remainder of the +function to be executed in a new Python frame which recursively +triggers TorchDynamo to restart its capture once execution reaches that +point for the first time. diff --git a/docs/source/dynamo/faq.rst b/docs/source/dynamo/faq.rst new file mode 100644 index 0000000000000..2b66e81ebc694 --- /dev/null +++ b/docs/source/dynamo/faq.rst @@ -0,0 +1,376 @@ +Frequently Asked Questions +========================== + +At a high level, the TorchDynamo stack consists of a graph capture from +Python code using dynamo and a backend compiler. In this example the +backend compiler consists of backward graph tracing using AOTAutograd +and graph lowering using TorchInductor. There are of course many more +compilers available `here `__ +but for this document we will focus on inductor as a motivating example. + +Torchdynamo supports training, using AotAutograd to capture backwards: + + 1. the ``.forward()`` graph and ``optimizer.step()`` is captured by torchdynamo’s python evalframe frontend + 2. for each segment of ``.forward()`` that torchdynamo captures, it uses AotAutograd to generate a backward graph segment + 3. each pair of forward, backward graph are (optionally) min-cut partitioned to save the minimal state between forward/backward + 4. the forward, backward pairs are wrapped in autograd.function modules 5. usercode calling\ ``.backward()`` still triggers eager’s autograd engine, which runs each ‘compiled backward’ graph as if it were one op, also running any non-compiled eager ops’ .backward() functions + +Do you support Distributed code? +-------------------------------- + +DDP has been tested and works, support for other distributed training +libraries is under discussion. + +The main reason why Distributed code is challenging with dynamo is +because AOTAutograd unrolls both the forward and backward pass and +provides 2 graphs for backends to optimize. This is a problem for +distributed code because we’d like to ideally overlap communication +operations with computations. Eager pytorch accomplishes this in +different ways for DDP/FSDP- using autograd hooks, module hooks, and +modifications/mutations of module states. In a naive application of +dynamo, hooks that should run directly after an operation during +backwards may be delayed until after the entire compiled region of +backwards ops, due to how AOTAutograd compiled functions interact with +dispatcher hooks. + +The basic strategy for optimizing DDP with Dynamo is outlined in +`distributed.py `__ +where the main idea will be to graph break on `DDP bucket +boundaries `__. + +When each node in DDP needs to synchronize its weights with the other +nodes it organizes its gradients and parameters into buckets which +reduces communication times and allows a node to broadcast a fraction of +its gradients to other waiting nodes. + +Graph breaks in distributed code means you can expect dynamo and its +backends to optimize the compute overhead of a distributed program but +not its communication overhead. Graph-breaks may interfere with +compilation speedups, if the reduced graph-size robs the compiler of +fusion opportunities. However, there are diminishing returns with +increasing graph size since most of the current compute optimizations +are local fusions. So in practice this approach may be sufficient. + +Do I still need to export whole graphs? +--------------------------------------- + +For the vast majority of models you probably don’t and you can use +``torch._dynamo()`` optimize as is but there are a few situations where +full graphs are necessary and you can can ensure a full graph by simply +running ``torch.dynamo(..., nopython=True)`` \* Large scale training +runs, think $250K+ that require pipeline parallelism and other advanced +sharding strategies \* Inference optimizers like +`TensorRT `__ or +`AITemplate `__ that rely +on fusing much more aggressively than training optimizers \* Mobile training or +inference. + +Future work will include tracing communication operations into graphs, +coordinating these operations with compute optimizations, and optimizing +the communciation operations. + +Why is my code crashing? +------------------------ + +If your code ran just fine without dynamo and started to crash with it +enabled then the most important first step is figuring out which part of +the stack your failure occurred in so try running things in the below +order and only try the next step if the previous step succeeded. + +1. ``dynamo.optimize("eager")`` which only runs torchdynamo forward graph + capture and then runs the captured graph with PyTorch. If this fails + then there’s an issue with TorchDynamo. + +2. ``dynamo.optimize("aot_eager")`` + which runs torchdynamo to capture a forward graph, and then AOTAutograd + to trace the backward graph without any additional backend compiler + steps. PyTorch eager will then be used to run the forward and backward + graphs. If this fails then there’s an issue with AOTAutograd. + +3. ``dynamo.optimize("inductor")`` which runs torchdynamo to capture a + forward graph, and then AOTAutograd to trace the backward graph with the + TorchInductor compiler. If this fails then there’s an issue with TorchInductor + +TorchDynamo Errors +~~~~~~~~~~~~~~~~~~ + +If the error that is generated occurs with the ``"eager"`` backend, then +torchdynamo is the most likely source of the error. + +To debug these issues we recommend setting +``torch._dynamo.config.verbose=True`` to get a full stack trace to both +the error in torchdynamo and the user code. In addition to this flag, +you can also set the ``log_level`` of torchdynamo through +``torch._dynamo.config.log_level``. The available levels are the +following: - ``logging.DEBUG``: Print every instruction that is +encountered in addition to all below log levels - ``logging.INFO``: +Print each function that is compiled (original and modified bytecode) +and the graph that is captured in addition to all below log levels - +``logging.WARNING`` (default): Print graph breaks in addition to all +below log levels - ``logging.ERROR``: Print errors only + +If a model is sufficiently large, the logs can become overwhelming. If +an error occurs deep within a model’s python code, it can be useful to +execute only the frame in which the error occurs to enable easier +debugging. There are 2 tools available to enable this: + +* ``env TORCHDYNAMO_DEBUG_FUNCTION=`` will only run TorchDynamo on functions with that name. + +* ``env torch._dynamo.config.replay_record_enabled = True``) which dumps an execution record when an error is encountered. This record can then be replayed to run only the frame where an error occurred. + +TorchInductor Errors +-------------------- + +With TorchInductor as the chosen backend, AOTAutograd is used to +generate the backward graph from the forward graph captured by +torchdynamo. It’s important to note that errors can occur during this +tracing and also while TorchInductor lowers the forward and backward +graphs to GPU code or C++. + +A model can often consist of hundreds or thousands of FX nodes, so +narrowing the exact nodes where this problem occurred can be very +difficult which is why we highly recommend you use our minifier to +create tiny reproducible examples of failures you’re seeing. We can +minify errors that occur either at the AOTAutograd layer or Inductor +layer which you should try in the following order. + +1. ``env TORCHDYNAMO_REPRO_AFTER="aot" python your_model.py`` +2. ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py`` + +Minifying your error is the quickest path to getting it fixed. + +The minifier will actually create a ``repro.py`` for you at the location +set by ``env TORCHDYNAMO_REPRO_DIR`` so make you have right access to +that directory. You can then run ``python repro.py`` and confirm that +you are getting the same error. + +.. note:: + For other compilers such as nvfuser, the process is similar but + instead you would leverage ``env TORCHDYNAMO_REPRO_AFTER="dynamo" python your_model.py``. + +Why is compilation slow? +------------------------ + +Dynamo Compilation +~~~~~~~~~~~~~~~~~~ + +TorchDynamo has a builtin stats function for collecting and displaying +the time spent in each compilation phase. These stats can be accessed by +calling ``torch._dynamo.utils.compile_times()`` after executing +``torch._dynamo``. By default, this returns a string representation of +the compile times spent in each TorchDynamo function by name. + +Inductor Compilation +~~~~~~~~~~~~~~~~~~~~ + +TorchInductor has a builtin stats and trace function for displaying time +spent in each compilation phase, output code, output graph visualization +and IR dump. ``env TORCHINDUCTOR_TRACE=1 python repro.py``. This is a +debugging tool designed to make it easier to debug/understand the +internals of TorchInductor with an output that will look something like +`this `__ + +Each file in that debug trace can be enabled/disabled via +``torch._inductor.config.trace.*``. The profile and the diagram are both +disabled by default since they are expensive to generate. See the +`example debug directory +output `__ +for more examples. + +Excessive Recompilation +~~~~~~~~~~~~~~~~~~~~~~~ + +When TorchDynamo compiles a function (or part of one), it makes certain +assumptions about locals and globals in order to allow compiler +optimizations, and expresses these assumptions as guards that check +particular values at runtime. If any of these guards fail, Dynamo will +recompile that function (or part) up to +``torch._dynamo.config.cache_size_limit`` times. If your program is +hitting the cache limit, you will first need to determine which guard is +failing and what part of your program is triggering it. + +The `recompilation profiler <#recompilation-profiler>`__ automates the +process of setting TorchDynamo’s cache limit to 1 and running your +program under an observation-only ‘compiler’ that records the causes of +any guard failures. You should be sure to run your program for at least +as long (as many iterations) as you were running when you ran into +trouble, and the profiler will accumulate statistics over this duration. + +.. code-block:: python + + prof = dynamo.utils.CompilationProfiler() + @dynamo.optimize(prof) + def my_model(): + ... + my_model() + print(prof.report()) + +Many of the reasons for graph breaks and excessive recompilation will be +fixed with upcoming support for `tracing dynamic tensor +shapes `__, +more careful choices for guards and better tuned heuristics. + +Why are you recompiling in production? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In some cases, you may not want unexpected compiles after a program has +warmed up. For example, if you are serving production traffic in a +latency critical application. For this, TorchDynamo provides an +alternate mode where prior compiled graphs are used, but no new ones are +generated: + +.. code-block:: python + + frozen_toy_example = dynamo.run(toy_example) + frozen_toy_example(torch.randn(10), torch.randn(10)) + +How are you speeding up my code? +-------------------------------- + +There are 3 major ways to accelerat PyTorch code: + +1. Kernel fusion via vertical fusions which fuse sequential operations to avoid + excessive read/writes. For example, fuse 2 subsequent cosines means you + can can do 1 read 1 write instead 2 reads 2 writes 2. Horizontal fusion: + the simplest example being batching where a single matrix is multiplied + with a batch of examples but the more general scenario is a grouped GEMM + where a group of matrix multiplications are scheduled together + +2. Out of order execution: A general optimization for compilers, by looking ahead + at the exact data dependencies within a graph we can decide on the most + opportune time to execute a node and which buffers can be reused + +3. Automatic work placement: Similar of the out of order execution point, + but by matching nodes of a graph to resources like physical hardware or + memory we can design an appropriate schedule + +The above are general principles for accelerating PyTorch code but +different backends will each make different tradeoffs on what to +optimize. For example Inductor first takes care of fusing whatever it +can and only then generates `Triton `__ +kernels. It can also + +Triton in addition offers speedups because of automatic memory +coalescing, memory management and scheduling within each Streaming +Multiprocessor and has been designed to handle tiled computations. + +However, regardless of the backend you use it’s best to use a benchmark +and see approach so try out the PyTorch profiler, visually inspect the +generated kernels and try to see what’s going on for yourself. + +Why am I not seeing speedups? +----------------------------- + +Graph Breaks +~~~~~~~~~~~~ + +The main reason you won’t see the speedups you’d like to by using dynamo +is excessive graph breaks. So what’s a graph break? + +Given a program like: + +.. code-block:: python + + @dynamo.optimize(...) + def some_fun(x): + ... + some_fun(x) + ... + +Torchdynamo will attempt to compile all of the torch/tensor operations +within ``some_fun()`` into a single FX graph, but it may fail to capture +everything into one graph. + +Some graph break reasons are insurmountable to TorchDynamo like calling +into a C extension other than torch is invisible to torchdynamo, and +could do arbitrary things without TorchDynamo being able to introduce +necessary guards to ensure that the compiled program would be safe to reuse. + + To maximize performance, it’s important to have as few graph breaks + as possible. + +Identifying the cause of a graph break +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To identify all graph breaks in a program and the associated reasons for +the breaks, ``torch._dynamo.explain`` can be used. This tool runs +TorchDynamo on the supplied function and aggregates the graph breaks +that are encountered. Here is an example usage: + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + print("woo") + if b.sum() < 0: + b = b * -1 + return x * b + explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10)) + print(explanation) + """ + Dynamo produced 3 graphs, with 2 graph break and 6 ops. + Break reasons: + 1. call_function BuiltinVariable(print) [ConstantVariable(str)] {} + File "t2.py", line 16, in toy_example + print("woo") + + 2. generic_jump + File "t2.py", line 17, in toy_example + if b.sum() < 0: + """ + +To throw an error on the first graph break encountered you can use +disable python fallback by using ``nopython=True``, this should be +familiar if you’ve worked with export based compilers. + +.. code-block:: python + + @dynamo.optimize(, nopython=True) + def toy_example(a, b): + ... + +Why didn’t my code recompile when I changed it? +----------------------------------------------- + +If you went ahead and enabled dynamic shapes via +``env TORCHDYNAMO_DYNAMIC_SHAPES=1 python model.py`` then your code +won’t recompile on shape changes. We’ve added support for dynamic shapes +which avoids recompilations in the case when shapes vary by less than a +factor of 2. This is especially useful in scenarios like varying image +sizes in CV or variable sequence length in NLP. In inference scenarios +it’s often not possible to know what a batch size will be beforehand +because you take what you can get from different client apps. + +In general, TorchDynamo tries very hard not to recompile things +unnecessarily so if for example torchdynamo finds 3 graphs and your +change only modified one graph then only that graph will recompile. So +another tip to avoid potentially slow compilation times is to warmup a +model by compiling it once after which subsequent compilations will be +much faster. Cold start compile times is still a metric we track +visibly. + +Why am I getting incorrect results? +----------------------------------- + +Accuracy issues can also be minified if you set the environment variable +``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect +model and a full repro might be something like +``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason +we need this is downstream compilers will codegen code whether it’s +Triton code or the C++ backend, the numerics from those downstream +compilers can be different in subtle ways yet have dramatic impact on +your training stability. So the accuracy debugger is very useful for us +to detect bugs in our codegen or with a backend compiler. + +Why am I getting OOMs? +---------------------- + +Dynamo is still an alpha product so there’s a few sources of OOMs and if +you’re seeing an OOM try disabling the following configurations in this +order and then open an issue on Github so we can solve the root problem +1. If you’re using dynamic shapes try disabling them, we’ve disabled +them by default: ``env TORCHDYNAMO_DYNAMIC_SHAPES=0 python model.py`` 2. +CUDA graphs with Triton are enabled by default in inductor but removing +them may alleviate some OOM issues: ``torch._inductor.config.triton.cudagraphs = False``. \ No newline at end of file diff --git a/docs/source/dynamo/get-started.rst b/docs/source/dynamo/get-started.rst new file mode 100644 index 0000000000000..fa1be5d43764d --- /dev/null +++ b/docs/source/dynamo/get-started.rst @@ -0,0 +1,178 @@ +Getting Started +=============== + +Let’s start with a simple example. Note that you are likely to see more +significant speedups the newer your GPU is. + +.. code:: python + + from torch._dynamo import optimize + import torch + def fn(x, y): + a = torch.cos(x).cuda() + b = torch.sin(y).cuda() + return a + b + new_fn = optimize("inductor")(fn) + input_tensor = torch.randn(10000).to(device="cuda:0") + a = new_fn(input_tensor, input_tensor) + +This example will not actually run faster. Its purpose is to demonstrate +the ``torch.cos()`` and ``torch.sin()`` features which are +examples of pointwise ops as in they operate element by element on a +vector. A more famous pointwise op you might want to use would +be something like ``torch.relu()``. Pointwise ops in eager mode are +suboptimal because each one would need to read a tensor from +memory, make some changes, and then write back those changes. The single +most important optimization that inductor does is fusion. So back to our +example we can turn 2 reads and 2 writes into 1 read and 1 write which +is crucial especially for newer GPUs where the bottleneck is memory +bandwidth (how quickly you can send data to a GPU) rather than compute +(how quickly your GPU can crunch floating point operations). + +Another major optimization that inductor makes available is automatic +support for CUDA graphs. +CUDA graphs help eliminate the overhead from launching individual +kernels from a Python program which is especially relevant for newer GPUs. + +TorchDynamo supports many different backends but inductor specifically works +by generating `Triton `__ kernels and +we can inspect them by running ``TORCHINDUCTOR_TRACE=1 python trig.py`` +with the actual generated kernel being + +.. code-block:: python + + @pointwise(size_hints=[16384], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) + @triton.jit + def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): + xnumel = 10000 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK]) + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl.sin(tmp0) + tmp2 = tl.sin(tmp1) + tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask) + +And you can verify that fusing the two ``sins`` did actually occur +because the two ``sin`` operations occur within a single Triton kernel +and the temporary variables are held in registers with very fast access. + +You can read up a lot more on Triton’s performance +`here `__ but the key is it’s in Python +so you can easily understand it even if you have not written all that +many CUDA kernels. + +Next, let’s try a real model like resnet50 from the PyTorch +hub. + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True) + opt_model = dynamo.optimize("inductor")(model) + model(torch.randn(1,3,64,64)) + +And that is not the only available backend, you can run in a REPL +``dynamo.list_backends()`` to see all the available backends. Try out the +``aot_cudagraphs`` or ``nvfuser`` next as inspiration. + +Let’s do something a bit more interesting now, our community frequently +uses pretrained models from +`transformers `__ or +`TIMM `__ and one of +our design goals is for Dynamo and inductor to work out of the box with +any model that people would like to author. + +So we will directly download a pretrained model from the +HuggingFace hub and optimize it: + +.. code-block:: python + + import torch + from transformers import BertTokenizer, BertModel + import torch._dynamo as dynamo + # Copy pasted from here https://huggingface.co/bert-base-uncased + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0") + model = dynamo.optimize("inductor")(model) # This is the only line of code that we changed + text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0") + output = model(**encoded_input) + +If you remove the ``to(device="cuda:0")`` from the model and +``encoded_input``, then Triton will generate C++ kernels that will be +optimized for running on your CPU. You can inspect both Triton or C++ +kernels for BERT, they’re obviously more complex than the trigonometry +example we had above but you can similarly skim it and understand if you +understand PyTorch. + +Similarly let’s try out a TIMM example + +.. code-block:: python + + import timm + import torch._dynamo as dynamo + import torch + model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2) + opt_model = dynamo.optimize("inductor")(model) + opt_model(torch.randn(64,3,7,7)) + +Our goal with Dynamo and inductor is to build the highest coverage ML compiler +which should work with any model you throw at it. + +Existing Backends +~~~~~~~~~~~~~~~~~ + +TorchDynamo has a growing list of backends, which can be found in +`backends.py `__ +or ``torchdynamo.list_backends()`` each of which with its optional dependencies. + +Some of the most commonly used backends include: + +* **Debugging backends**: + * ``dynamo.optimize("eager")`` - Uses PyTorch + to run the extracted GraphModule. This is quite useful in debugging + TorchDynamo issues. + * ``dynamo.optimize("aot_eager")`` - Uses + AotAutograd with no compiler, for example, just using PyTorch eager for the + AotAutograd’s extracted forward and backward graphs. This is useful for + debugging, and unlikely to give speedups. + +* **Training & inference backends**: + * ``dynamo.optimize("inductor")`` - Uses ``TorchInductor`` backend + with AotAutograd and cudagraphs by leveraging + codegened Triton kernels `Read + more `__ + * ``dynamo.optimize("nvfuser")`` - nvFuser with TorchScript. `Read more `__ + * ``dynamo.optimize("aot_nvfuser")`` - nvFuser with AotAutograd. `Read more `__ + * ``dynamo.optimize("aot_cudagraphs")`` - cudagraphs with AotAutograd. `Read more `__ + +* **Inference-only backends**: + * ``dynamo.optimize("ofi")`` - Uses + Torchscript ``optimize_for_inference``. `Read + more `__ + * ``dynamo.optimize("fx2trt")`` - Uses Nvidia TensorRT for inference optimizations. `Read more `__ + * ``dynamo.optimize("onnxrt")`` - Uses ONNXRT for inference on CPU/GPU. `Read more `__ \* ``dynamo.optimize("ipex")`` - Uses IPEX for inference on CPU. `Read more `__ + +Why do you need another way of optimizing PyTorch code? +------------------------------------------------------- + +While a number of other code optimization tools exist in the PyTorch +ecosystem, each of them has its own flow. +Here is a few examples of existing methods and their limitations: + +- ``torch.jit.trace()`` is silently wrong if it cannot trace, for example: + during control flow +- ``torch.jit.script()`` requires modifications to user or library code + by adding type annotations and removing non PyTorch code +- ``torch.fx.symbolic_trace()`` either traces correctly or gives a hard + error but it’s limited to traceable code so still can’t handle + control flow +- ``torch._dynamo`` works out of the box and produces partial graphs. + It still has the option of producing a single graph with + ``nopython=True`` which are needed for `some + situations <./documentation/FAQ.md#do-i-still-need-to-export-whole-graphs>`__ + but allows a smoother transition where partial graphs can be + optimized without code modification diff --git a/docs/source/dynamo/guards-overview.rst b/docs/source/dynamo/guards-overview.rst new file mode 100644 index 0000000000000..a86cd202564b7 --- /dev/null +++ b/docs/source/dynamo/guards-overview.rst @@ -0,0 +1,511 @@ +Guards Overview +=============== + +From a UX perspective, TorchDynamo is very easy to use. The user invokes +``torchdynamo.optimize`` as an annotation: + +.. code-block:: python + + @torchdynamo.optimize(my_compiler) + def fn_foo(bar): + +Where a complete example looks like this: + +.. code-block:: python + + from typing import List + import torch + import torchdynamo + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("my_compiler() called with FX graph:") + gm.graph.print_tabular() + return gm.forward # return a python callable + @torchdynamo.optimize(my_compiler) + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + if b.sum() < 0: + b = b * -1 + return x * b + for _ in range(100): + toy_example(torch.randn(10), torch.randn(10)) + +This allows TorchDynamo to capture the interpreted Python frames, grab +any and all relevant information, and speed things up wherever it can. +The speedup comes from a few places, and can be rather dependent on the +backend (`my_compiler` in the example above) provided, but the one speedup +that is important in this section is **caching**. Caching itself is not +a direct speedup but a critical enablement that prevents +recompilation. We dig a hole with dynamo, and caching allows us to get +out. It enables us to hold perf +neutrality while then enabling backends - the true source of our +speedups. + +With even a pass-through no-op backend provided: + +.. code-block:: python + + def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + return gm.forward + +We can see TorchDynamo speeding up Python execution even on +regular Python, not just PyTorch. + +Caching and Guards Overview +--------------------------- + +TorchDynamo operates through caching transformed (by TorchDynamo) user +bytecode. When TorchDynamo receives a frame for evaluation, it checks if the +**objects referenced in the frame have changed** in certain ways, and if +not, TorchDynamo reads the previously transformed user bytecode to evaluate it. +In this section, we will focus on how we can identify whether or not the +**objects referenced in the frame have changed**. This is a critical +piece of functionality in TorchDynamo, because it drives the entire +invalidation lifecycle. This functionality is called **guards**. + +At a very high level, the flow can be summarized like this: + +1. TorchDynamo receives a Python frame. +2. It converts the frame (1) passing it through instruction + translation. +3. For the objects captured in (2), TorchDynamo creates tracking objects that + are: + * tracked on an output graph, which is an internal specialization + of a `torch.fx.Tracer` + * guards +4. TorchDynamo processes the guard objects created in (3), turning them into a + generated Python function, `check_fn`, associated with a piece of code. +5. The `check_fn` is evaluated whenever we encounter this code a + subsequent time - if a `check_fn` passes and evaluates to `True`, TorchDynamo + identifies the code in the cache and the code encountered here as same, and + can be safely used. If it fails and evaluates to `False`, TorchDynamo + identifies the code in the cache as not valid, and can be thrown out in + favor of a new entry, through recompilation or a graph break. + +Python Frame Evaluation and PEP 523 +----------------------------------- + +The functionality of TorchDynamo is based on +`PEP 523 `__. + +TorchDynamo installs a frame evaluation function on Python by using +`_PyInterpreterState_SetEvalFrameFunc`. TorchDynamo has a hook where +Python can hand control back to us during evaluation. + +The function we have installed is ``convert_frame`` or +``convert_frame_assert`` in the ``nopython=True`` case, but glossing +over that nuance for now, let’s take a look at ``convert_frame_assert``, +as ``convert_frame`` proxies to it. + +We can find it on `line 20 of convert_frame.py +`__, +with a signature as follows: + +.. code-block:: python + + def convert_frame_assert(compiler_fn: Callable, one_graph=True): + +This function wraps the entry point of where Python invokes TorchDynamo +with a frame: + +.. code-block:: python + + def _convert_frame_assert(frame: types.FrameType, cache_size: int): + +Here is what this function does: + +1. Checks if it has seen this ``code``\ (see: f_code `here + `__) before and exits + early if it did. +2. Checks if the code is an unsupported case. +3. Checks if the ``cache_size`` (second arg above) crosses the limit + defined in the config, ``cache_size_limit``. If it has, the function + drops the frame and logs warnings. This helps to avoid constant + recompilation of a frame as it generally means that the frame is hot + in an unexpected way and caching it produces needless overhead, + as it is likely to get evicted the next time it is encountered. +4. Passes the frame, alongside a function that creates an + ``InstructionTranslator`` through bytecode + transformation, via ``transform_code_object``. A few crucial things + happen under the hood here: + + 1. New code is produced through ``transform_code_object``. + + 2. An FX tracer named ``output`` is produced through + ``InstructionTranslator``. + + This can be a bit confusing, + as ``InstructionTranslator`` is not an `fx` tracer, but its stored + in a variable named tracer, and its output*\ **is**\ *an `fx`tracer.* + + 3. The function produces guards and stores them on ``output`` above. + + 4. The function produces ``output_instructions`` and stores them on + ``output`` above. + + 5. The function maps the newly produced transformed code to the initial code it + read off the frame. This mapping is worth remembering, we will + refer to it much later on below where we cover guard failures. + +5. Using the transformed code from 4.1 and the guards from 4.3, + the function produces a `GuardedCode`. + +Now that we have learned about frame evaluation, let’s review +``InstructionTranslator``, and see how it turns the frame we handed +it over into TorchDynamo internal types. + +InstructionTranslator +--------------------- + +`InstructionTranslator` does a lot! We won’t cover the details of +everything it does, but most importantly for this document, it produces +a mapping of ``symbolic_locals`` which maintains a mapping from the +frame’s ``f_locals`` to TorchDynamo internal Variable objects (more on these +in a moment. ``symbolic_locals`` is filled via traversing the frame’s +locals: + +.. code-block:: python + + self.symbolic_locals = collections.OrderedDict( + (k, VariableBuilder(self, LocalSource(k))(f_locals[k])) + for k in vars + if k in f_locals + ) + +The important component here is the invocation of a call +into ``VariableBuilder``. ``VariableBuilder``\ ’s call implementation +proxies into a function called ``_wrap``, which in turn both constructs +instances of ``VariableTracker`` and calls ``make_guards`` on them. More +on that later. + +This mapping, in turn, is critical as each Variable has associated +guards, which are then passed to ``self.output``, the instance of +``OutputGraph``, an fx tracer, mentioned in 4.2 of the section above. If +you recall, this ``OutputGraph``, stored in a variable called ``output`` +is where our guards are stored before being passed on to become +``GuardedCode`` + +How does ``InstructionTranslator`` do this? At the heart of it, there is +a loop that is pumped, which drives a function ``step``. + +``step`` is just that - a single processing step, taking exactly one +instruction and doing *something* with it. + +.. note:: These are real instructions processed by TorchDynamo’s + ``transform_code_object``, and it is pretty cool. + +.. note:: This section purposly skips the details of + `dis.get_instructions `__. + +For the example above, here is a snippet of a what a few +``Instruction``\'s may look like: + +.. code-block:: python + + Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None) + Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None) + Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None) + +This is the core functionality of this function. Take a look at the ``opname``, +and then take a look at this little snippet from inside ``step``; + +.. code-block:: python + + if not hasattr(self, inst.opname): + unimplemented(f"missing: {inst.opname}") + getattr(self, inst.opname)(inst) + +As we can see, the function checks if the current class, the +``InstructionTranslator`` has an attribute set matching the operator name +(for example, ``LOAD_CONST``). If it does, the function invokes it, passing the +whole instruction object in. If it does not, the function drops the frame as +unimplemented. + +For the ``LOAD_CONST`` example, we can see that we do indeed support it, +with a relatively straightforward definition: + +:: + + def LOAD_CONST(self, inst): + self.push(ConstantVariable(value=inst.argval)) + +We can see that this function creates a new instance of the class +``ConstantVariable`` , with a value, in our example case, -1, and then +pushes it onto the stack. + +There are dozens of such methods - see ``symbolic_convert.py`` for all of +them. Generally, we implement as many matching methods to Python +bytecode instructions as possible. + +Across both the logic downstream of ``step`` and the logic from invoking +``VariableBuilder`` - we now have a lot of ``VariableTracker``\ s and of +course, we’ve spoken about creating guards quiet a bit. Let’s dig into +what Variables are, and get a little closer to understanding guards. + +Variables +--------- + +A ``ConstantVariable`` is an instance of\ ``VariableTracker``. +``VariableTracker`` represents a tracked Python local or stack value. + +When it comes to representing an object inside TorchDynamo, a +``VariableTracker`` does exactly what it says - it tracks a given variable. +It is an extremely flexible class, but there are a few points to keep in +mind: + +- It manages the ``guard`` relationship around the underlying object + through: + + - ``make_guard`` + - ``replace_guards`` + - ``add_guard(s)`` + - ``propagate`` - ``propagate(*vars: List[List["VariableTracker"]])`` - + Perhaps the most important of all, in that it combines guards from + all the provided ``VariableTracker`` instances passed in. It visits + the guards and combines the guards from these onto itself. + +- It acts as a proxy on behalf of the underlying object, implementing + methods for the rest of TorchDynamo to get information about the + tracked object: + + - ``call_method`` + - ``call_function`` + - ``python_type`` + - ``as_proxy`` + - ``is/as_python_proxy`` + +- It stores the variable ``source`` of type ``Source``, from + ``torchdynamo/source.py``. This source type is a relatively self + contained class that helps us organize and bookeep where the original + source came from, and helps provide convenience methods for things + like getting the name, and importantly for us, producing guards. + +And this class (``VariableTracker``) is built around subclassing, +somewhere between a full Abstract Base Class and fully fleshed out class +- it leaves many methods raising ``NotImplementedError`` - with reliance on +subclasses. See ``torchdynamo/variables/`` for all subclasses to fulfill +contracts and custom behaviors. + +Knowing what we know now, we can see an example of how an instruction +from ``dis``, ``BUILD_TUPLE``: + + ``BUILD_TUPLE(count)`` Creates a tuple consuming count items from the + stack, and pushes the resulting tuple onto the stack. + +In our case, our signature will be a *little* different due to the way +we create ``Instruction`` objects, but the gist of it will be the same. +Instead of passing in ``count``, we pass in an object with a little +extra bookkeeping, and of course, we deal with turning regular old +python objects into TorchDynamo notions: + +:: + + def BUILD_TUPLE(self, inst): + items = self.popn(inst.argval) + options = VariableTracker.propagate(items) + self.push(TupleVariable(items, **options)) + +Here is what this code does: + +1. The function reads ``argval``, which in this case, is + analogous to ``counts`` in the pydoc for the equivalent instruction. + +2. The function ``popn`` the items, in this case, the signature is + ``def popn(self, n: int) -> List[TensorVariable]:`` this hints at an + underlying contract - we are returning ``TensorVariables``. If we + take a closer look at ``sybmolic_convert.py`` and + ``InstructionTranslatorBase``/``InstructionTranslator``\ we see that + the only thing pushed onto and popped from our stack are + ``VariableTracker``\ s. + +3) The function calls ``VariableTracker.propogate``. This + takes the guards from every single item popped off the stack in 2, + and recursively traverses it and combines all the guards into + ``options``: ``py return { "guards": guards, }`` + +4) The function then makes a new instance of a ``VariableTracker``, + ``TupleVariable``\ out of the ``items`` and ``options``. This then + allows us to install all the appropriate guards from the ``items`` + that make up the new ``TupleVariable`` + +.. note:: Where did the first guards come from? Propagation + is a good technique, but we need something created before it can be + propagated. ``VariableBuilder`` calls + ``make_guards`` as it creates ``VariableTracker`` instances, from + ``f_locals``. This in turn calls into the ``source``, to have it create + guards. + +After all this, bytecode translation is done and we are one step closer +to producing ``GuardedCode``. We now understand how locals become +``VariableTracker``\ s, how instructions are handled, and where guards +are called on for creation. Before we can go into seeing how code and +guards are combined into a GuardedCode object, we need to dig a little +bit into those ``make_guard`` and ``source.make_guard`` calls above. We +can then understand, what was going on when we made guards +alongside, and on, ``VariableTracker`` instances. + +Making Guards +------------- + +Guards are just Python objects, of the class ``Guard``. Let's look at them +in more detail. + +Looking at the definition of the dataclass (and therefore, ctor +signature), we see that it has a name, a source, and a create function. + +:: + + @dataclasses.dataclass + class Guard: + name: str + source: GuardSource + create_fn: Callable + +The name should be the name of the variable. + +The source here is an enum indicating what *kind* of source the guard +belongs to. + +.. note:: Not to be confused with ``Source`` and the other types + in ``source.py``, as stored on ``VariableTracker``. + +``create_fn`` provides the main functionality to transition from a simple +dataclass to actually producing valid Python code to be invoked for +knowing whether or not things have changed in between invocations, and +whether we can safely read from the code cache or not. + +The most common code paths for getting an instance of a guard are +through ``make_guards`` on ``VariableTracker``. +``make_guards``->``source.make_guard``->``return Guard(self.name(), self.guard_source(), fn)`` + +Or, in a concrete example: + +.. code-block:: python + + ... + elif istype(value, range): + guards = self.make_guards(GuardBuilder.EQUALS_MATCH) + return RangeVariable(value=value, guards=guards) + +Since ``source`` was set at the construction time of this +``VariableTracker``, all that was needed here was to provide the ``fn``, +``GuardBuilder.EQUALS_MATCH`` to the ``create_fn`` field. + +This ``create_fn`` must be a method on ``GuardBuilder``. The reason for +this becomes apparent in our next step. Once we have all the guards +created for a frame, we move on to ``CheckFunctionManager`` and +``compile_check_fn``. + +Before the ``convert_frame`` function can produce a ``GuardedCode``, +it needs to run the ``CheckFunctionManager``, with all the guards, to +produce a ``check_fn`` which will then, in turn get passed in alongside +the code into ``GuardedCode``. This is the same ``check_fn`` that we store in our +cache entry, and the same one we run to know whether or not to retrieve +the code stored alongside. For reference, here is that code: + +.. code-block:: cpp + + static CacheEntry *create_cache_entry(CacheEntry *next, + PyObject *guarded_code) { + CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry)); + DEBUG_NULL_CHECK(e); + e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); + NULL_CHECK(e->check_fn); + e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code"); + NULL_CHECK(e->code); + e->next = next; + return e; + } + +We now know how a ``check_fn`` function is used, and who makes it, and +what it is composed of, but what we do not yet know is how. How does a +list of ``Guard`` objects become a function we can run later on? + +First, we iterate these guards: + +.. code-block:: python + + for guard in sorted(guards or [], key=Guard.sort_key): + if not config.guard_nn_modules and guard.is_nn_module(): + continue + guard.create(local_builder, global_builder) + +Calling ``guard.create`` runs that ``create_fn`` we set on the ``Guard`` +class above (don’t confuse it with the ``check_fn`` we are working on +producing, the names are similar, so it can get a little confusing). In +our example above, our ``create_fn`` is ``GuardBuilder.EQUALS_MATCH``. +So we are now invoking it, passing in the ``self``, the guard itself, +in. + +The signature is: ``def EQUALS_MATCH(self, guard: Guard):`` + +And internally to that function, we can use the ``name`` on the guard to +get back our original object, querying it for data and type information, +which in turn gets us to the most important bit: appending code. + +At its simplest, ``EQUALS_MATCH`` appends just one line of code: +``self.code.append(f"{ref} == {val!r}")``. Where ``ref`` is the name of +the variable, and ``val`` is the value. It might produce code like this: + +.. code-block:: + + y == 2 + +This is a basic example. But if we append a few other kinds of ``GuardBuilder`` +functions and then combine them all with +``and`` in between each statement (as we do), we might get something +like this: + +.. code-block:: + + ___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x) + +Here is what this code performs: + +1. A check for ``.valid`` +2. A type ID check +3. A value check +4. A tensor check + +This becomes the heart of the code our ``check_fn``, which in turn +is evaluated the **next** time we encounter this code. It +will then check: + +1. Is this code still valid? +2. If (1), Does ``y`` still have a type of ``94367738391392``? +3. If (2), is ``y`` still 2? +4. If (3), let’s check on if tensor ``x`` changed in some specific ways. + +If all of these are still true, then we can use the code cached +alongside this ``check_fn``. + +.. note:: For a deeper dive for how and where this happens + you can read ``static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) {`` of + ``_eval_frame.c``. + +If not, then, we can move on to recompiling the code anew, and storing +that in the cache alongside this code, and a whole new ``check_fn``, +again to be checked on yet another subsequent frame. + +There are lots of other such functions on ``GuardBuilder`` which get +coalesced into, at times massive, strings which then get evaluated as +Python code and stored into ``check_fn``. The example above +illustrates of a simple case. To understand this functionality better, read +the other functions on ``GuardBuilder``, or better yet, dump the ``code`` variable +in ``compile_check_fn`` to see what is getting produced, +especially on larger, real models. + +Summary +------- + +In this section, we have reviewed: + +- The role of ``.valid`` and invalidation around weak references (and potentially soon to be NN Moduleinvalidations). +- How the C++ side of guard functions (``___check_type_id``, ``___check_tensors``, etc) operate +- What happens when guards fail. +- What happens if we produce invalid guard code. + +We covered how user provided code wrapped in a TorchDynamo context +goes on to get traced and tracked internally, organized into ``VariableTracker``\ s +``Source``\ s and subsequently ``Guard``\ s, and how those ``Guards`` in +turn guide cache entry selection and invalidation when handing Python +code. diff --git a/docs/source/dynamo/index.rst b/docs/source/dynamo/index.rst new file mode 100644 index 0000000000000..506880981b018 --- /dev/null +++ b/docs/source/dynamo/index.rst @@ -0,0 +1,45 @@ +TorchDynamo Overview +==================== + +**TorchDynamo** is a Python-level JIT compiler designed to make unmodified +PyTorch programs faster. TorchDynamo hooks into the frame evaluation API +in CPython (`PEP 523 `__) to +dynamically modify Python bytecode right before it is executed. It +rewrites Python bytecode in order to extract sequences of PyTorch +operations into an `FX Graph `__ +which is then just-in-time compiled with a customizable backend. +It creates this FX Graph through bytecode analysis and is designed to +mix Python execution with compiled backends to get the best of both +worlds — usability and performance. + +TorchDynamo makes it easy to experiment with different compiler +backends to make PyTorch code faster with a single line decorator +``torch._dynamo.optimize()`` + +.. image:: ../_static/img/dynamo/TorchDynamo.png + +`TorchInductor` is one of the backends +supported by `TorchDynamo Graph `__ +into `Triton `__ for GPUs or +`C++/OpenMP `__ for CPUs. We have a +`training performance dashboard `__ +that provides performance comparison for different training backends. You can read +more in the `TorchInductor post on PyTorch +dev-discuss `__. + +.. seealso:: + + * `TorchDynamo deep-dive video `__ + * `dev-discuss topics `__ + +.. toctree:: + :maxdepth: 1 + :hidden: + + installation + get-started + guards-overview + custom-backends + deep-dive + troubleshooting + faq diff --git a/docs/source/dynamo/installation.rst b/docs/source/dynamo/installation.rst new file mode 100644 index 0000000000000..687e9b072bafe --- /dev/null +++ b/docs/source/dynamo/installation.rst @@ -0,0 +1,85 @@ +Installing TorchDynamo +====================== + +This section describes how to install TorchDynamo. +TorchDynamo is included in the nightly binaries of PyTorch. For +more information, see `Getting Started `__. + +Requirements +------------ + +You must have the following prerequisites to use TorchDynamo: + +* A Linux or macOS environment +* Python 3.8 (recommended). Python 3.7 through 3.10 are supported and + tested. Make sure to have a development version of Python installed + locally as well. + +GPU/CUDA Requirements +~~~~~~~~~~~~~~~~~~~~~ + +To use GPU back ends, and in particular Triton, make sure that +the CUDA that you have installed locally matches the PyTorch version you +are running. + +The following command installs GPU PyTorch + TorchDynamo along with GPU +TorchDynamo dependencies (for CUDA 11.7): + +.. code-block:: shell + + pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 + +CPU requirements +~~~~~~~~~~~~~~~~ + +There are no additional requirements for CPU TorchDynamo. CPU +TorchDynamo is included in the nightly versions of PyTorch. +To install, run the following command: + +.. code-block:: shell + + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu + + +Install from Local Source +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Alternatively, you can build PyTorch from `source +`__, which has TorchDynamo +included. + +To install GPU TorchDynamo dependencies, run ``make triton`` in the +PyTorch repo root directory. + +Verify Installation +~~~~~~~~~~~~~~~~~~~ + +If you built PyTorch from source, then you can run the following +commands (from the PyTorch repo root directory) +to check that TorchDynamo is installed correctly: + +.. code-block:: shell + + cd tools/dynamo + python verify_dynamo.py + +If you do not have the PyTorch source locally, you can alternatively +copy the script (``tools/dynamo/verify_dynamo.py``) from the PyTorch +repository and run it locally. + +Docker Installation +------------------- + +We also provide all the required dependencies in the PyTorch nightly +binaries which you can download with the following command: + +.. code-block:: + + docker pull ghcr.io/pytorch/pytorch-nightly + +And for ad hoc experiments just make sure that your container has access +to all your GPUs: + +.. code-block:: bash + + docker run --gpus all -it ghcr.io/pytorch/pytorch-nightly:latest /bin/bash diff --git a/docs/source/dynamo/troubleshooting.rst b/docs/source/dynamo/troubleshooting.rst new file mode 100644 index 0000000000000..3fb33d91ddef8 --- /dev/null +++ b/docs/source/dynamo/troubleshooting.rst @@ -0,0 +1,667 @@ +TorchDynamo Troubleshooting +=========================== + +**Author**: `Michael Lazos `_ + +TorchDynamo is still in active development, and many of the reasons for +graph breaks and excessive recompilation will be fixed with upcoming +support for `tracing dynamic tensor +shapes `__, +more careful choices for guards and better tuned heuristics. + +In the meantime, you may need to diagnose a particular issue and +determine if it is easy to work around with a change to your model, or +file an issue for support. + +Also, we are actively developing debug tools, profilers, and improving our +errors/warnings. Please give us feedback if you have an issue with this +infra, or an idea for an improvement. Below is a table of the available +tools and their typical usage. For additional help see +`Diagnosing Runtime Errors <#diagnosing-runtime-errors>`__. + +.. list-table:: Title + :widths: 25 25 50 + :header-rows: 1 + + * - Tool + - Purpose + - Usage + * - Info logging + - View summarized steps of compilation + - ``torch._dynamo.config.log_level = logging.INFO`` + * - Debug logging + - View detailed steps of compilation (print every instruction traced) + - ``torch._dynamo.config.log_level = logging.DEBUG`` and + ``torch._dynamo.config.verbose = True`` + * - Minifier for any backend + - Find smallest subgraph which reproduces errors for any backend + - set environment variable ``TORCHDYNAMO_REPRO_AFTER="dynamo"`` + * - Minifier for ``TorchInductor`` + - If the error is known to occur after `AOTAutograd`` find + smallest subgraph wich reproduces errors during TorchInductor lowering + - set environment variable ``TORCHDYNAMO_REPRO_AFTER="aot"`` + * - Accuracy minifier + - Finds the smallest subgraph which reproduces an accuracy issue + between an eager model model and optimized model + - ``TORCHDYNAMO_REPRO_AFTER=<"aot"/"dynamo"> TORCHDYNAMO_REPRO_LEVEL=4`` + * - ``torch._dynamo.explain`` + - Find graph breaks and display reasoning for them + - ``torch._dynamo.explain(fn, *inputs)`` + * - Record/Replay + - Record and replay frames which to reproduce errors during graph capture + - ``torch._dynamo.config.replay_record_enabled = True`` + * - TorchDynamo function name filtering + - Only compile functions with the given name to reduce noise when + debugging an issue + - set environment variable ``TORCHDYNAMO_DEBUG_FUNCTION=`` + * - TorchInductor Debug logging + - Print general TorchInductor debug info and generated Triton/C++ code + - ``torch._inductor.config.debug = True`` + * - TorchInductor Tracing + - Show time taken in each TorchInductor stage + output code and graph + visualization + - set the environment variable TORCHINDUCTOR_TRACE=1 or + ``torch._inductor.config.trace.enabled = True`` + +Diagnosing Runtime Errors +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Below is the TorchDynamo compiler stack. + +At a high level, the TorchDynamo stack consists of a graph capture from +Python code (TorchDynamo) and a backend compiler. In this example, the +backend compiler consists of backward graph tracing (AOTAutograd) and +graph lowering (TorchInductor)*. Errors can occur in any component of +the stack and will provide full stack traces. + +You may use info logging +(``torch._dynamo.config.log_level = logging.INFO``) and look for +``Step #: ...`` outputs in order to determine in which component the +error has occurred. Logs are made at the beginning and end of each step, +so the step that an error should correspond to is the most recent logged +step whose end has not yet been logged. The steps correspond to the +following parts of the stack (according to the image above): + +==== ================ +Step Component +==== ================ +1 TorchDynamo +2 Compiler Backend +3 TorchInductor +==== ================ + +The beginning and end of AOTAutograd is currently not logged, but we +plan to add it soon. + +If info logging is insufficient, then there are also some backend +options which can enable you to determine which component is causing the +error if you’re unable to understand the error message that is +generated. These are the following: + +- ``"eager"``: only runs torchdynamo forward graph capture and then + runs the captured graph with PyTorch. This provides an indication as + to whether TorchDynamo is raising the error. + +- ``"aot_eager"``: runs torchdynamo to capture a forward graph, and + then AOTAutograd to trace the backward graph without any additional + backend compiler steps. PyTorch eager will then be used to run the + forward and backward graphs. This is useful to narrow down the issue + to AOTAutograd. + +The general procedure to narrow down an issue is the following: + +1. Run your program with the ``"eager"`` backend. If the error no longer + occurs, the issue is in the backend compiler that is being used (if + using TorchInductor, proceed to step 2. If not, see `this + section <#minifying-backend-compiler-errors>`__). If the error still + occurs with the ``"eager"`` backend, it is an `error while running + torchdynamo <#torchdynamo-errors>`__. + +2. This step is only necessary if ``TorchInductor`` is used as the backend + compiler. Run the model with the ``"aot_eager"`` backend. If this + backend raises an error then the error is occurring during + AOTAutograd tracing. If the error no longer occurs with this backend, + then `the error is in + TorchInductor\* <#minifying-torchinductor-errors>`__. + +Each of these cases are analyzed in the following sections. + +.. note:: The TorchInductor backend consists of + both AOTAutograd tracing and the TorchInductor compiler itself. We will + disambiguate by referring to ``TorchInductor`` as the backend, and + TorchInductor lowering as the phase which lowers the graph traced by + AOTAutograd. + +Torchdynamo Errors +------------------ + +If the error that is generated occurs with the ``"eager"`` backend, then +TorchDynamo is the most likely source of the error. Here is a sample code +which will generate an error. + +.. code-block:: py + + import torch + + import torch._dynamo as dynamo + + + @dynamo.optimize("eager") + def test_assertion_error(): + y = torch.ones(200, 200) + z = {y: 5} + return z + + + test_assertion_error() + +Which will generate the following error: + +:: + + torch._dynamo.convert_frame: [ERROR] WON'T CONVERT test_assertion_error /scratch/mlazos/torchdynamo/../test/errors.py line 26 + due to: + Traceback (most recent call last): + File "/scratch/mlazos/torchdynamo/torchdynamo/symbolic_convert.py", line 837, in BUILD_MAP + assert isinstance(k, ConstantVariable) or ( + AssertionError + + from user code: + File "/scratch/mlazos/torchdynamo/../test/errors.py", line 34, in test_assertion_error + z = {y: 5} + + Set torch._dynamo.config.verbose=True for more information + ========== + +As the message suggests you can set +``torch._dynamo.config.verbose=True`` to get a full stack trace to both +the error in TorchDynamo and the user code. In addition to this flag, +you can also set the ``log_level`` of torchdynamo through +``torch._dynamo.config.log_level``. The available levels are the +following: +- ``logging.DEBUG``: Print every instruction that is +encountered in addition to all below log levels. +- ``logging.INFO``: +Print each function that is compiled (original and modified bytecode) +and the graph that is captured in addition to all below log levels. +- ``logging.WARNING`` (default): Print graph breaks in addition to all +below log levels. +- ``logging.ERROR``: Print errors only. + +If a model is sufficiently large, the logs can become overwhelming. If +an error occurs deep within a model’s Python code, it can be useful to +execute only the frame in which the error occurs to enable easier +debugging. There are two tools available to enable this: + +- Setting the environment variable ``TORCHDYNAMO_DEBUG_FUNCTION`` to the desired function name will only run torchdynamo on functions with that name. +- Enabling the record/replay tool (set ``torch._dynamo.config.replay_record_enabled = True``) which dumps anexecution record when an error is encountered. This record can then be replayed to run only the frame where an error occurred. + +TorchInductor Errors +-------------------- + +If the error does not occur with the ``"eager"`` backend, then the +backend compiler is the source of the error (`example +error `__). +There are `different +choices `__ +for backend compilers for TorchDynamo, with TorchInductor or nvfuser +fitting the needs of most users. This section focuses on TorchInductor +as the motivating example, but some tools will be usable with other +backend compilers. + +Below is the portion of the stack which we are focusing on: + +With TorchInductor as the chosen backend, AOTAutograd is used to +generate the backward graph from the forward graph captured by +torchdynamo. It is important to note that errors can occur during this +tracing and also while TorchInductor lowers the forward and backward +graphs to GPU code or C++. A model can often consist of hundreds or +thousands of FX nodes, so narrowing the exact nodes where this problem +occurred can be very difficult. Fortunately, there are tools availabe to +automatically minify these input graphs to the nodes which are causing +the issue. The first step is to determine whether the error occurs +during tracing of the backward graph with AOTAutograd or during +TorchInductor lowering. As mentioned above in step 2, the +``"aot_eager"`` backend can be used to run only AOTAutograd in isolation +without lowering. If the error still occurs with this backend, this +indicates that the error is occurring during AOTAutograd tracing. + +Here is an example: + +.. code-block:: py + + import torch + + import torch._dynamo as dynamo + + model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) + @dynamo.optimize("inductor") + def test_backend_error(): + + y = torch.ones(200, 200) + x = torch.ones(200, 200) + z = x + y + a = torch.ops.aten._foobar(z) # dummy function which errors + return model(a) + + + test_backend_error() + +Running this should give you this error with a longer stack trace below +it: + +:: + + Traceback (most recent call last): + File "/scratch/mlazos/torchdynamo/torchinductor/graph.py", line 246, in call_function + return lowerings[target](*args, **kwargs) + File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 185, in wrapped + return decomp_fn(*args, **kwargs) + File "/scratch/mlazos/torchdynamo/torchinductor/lowering.py", line 810, in _foobar + assert False + AssertionError + ... + +`error with full stack +trace `__ + +If you then change ``@dynamo.optimize("inductor")`` to +``@dynamo.optimize("aot_eager")``, it will run without error, because +`the +issue `__ +is in the TorchInductor lowering process, not in AOTAutograd. + +Minifying TorchInductor Errors +------------------------------ + +From here, let’s run the minifier to get a minimal repro. Setting the +environment variable ``TORCHDYNAMO_REPRO_AFTER=“aot”`` (or setting +``torch._dynamo.config.repro_after="aot"`` directly) will generate a +Python program which reduces the graph produced by AOTAutograd to the +smallest subgraph which reproduces the error. (See below for an example +where we minify the graph produced by torchdynamo) Running the program +with this environment variable should show nearly `identical +output `__, +with an additional line indicating where ``minifier_launcher.py`` has +been written to. The output directory is configurable by setting +``torch._dynamo.config.base_dir`` to a valid directory name. The final +step is to run the minifier and check that it runs successfully. A +successful run looks like +`this `__. +If the minifier runs successfully, it generates runnable python code +which reproduces the exact error. For our example this is the following +code: + +.. code-block:: python + + import torch + from torch import tensor, device + import torch.fx as fx + from torch._dynamo.testing import rand_strided + from math import inf + from torch.fx.experimental.proxy_tensor import make_fx + + # torch version: 1.13.0a0+gitfddfc44 + # torch cuda version: 11.6 + # torch git version: fddfc4488afb207971c54ad4bf58130fdc8a4dc5 + + + # CUDA Info: + # nvcc: NVIDIA (R) Cuda compiler driver + # Copyright (c) 2005-2022 NVIDIA Corporation + # Built on Thu_Feb_10_18:23:41_PST_2022 + # Cuda compilation tools, release 11.6, V11.6.112 + # Build cuda_11.6.r11.6/compiler.30978841_0 + + # GPU Hardware Info: + # NVIDIA A100-SXM4-40GB : 8 + + + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + + + def forward(self, add): + _foobar = torch.ops.aten._foobar.default(add); add = None + return (_foobar,) + + args = [((200, 200), (200, 1), torch.float32, 'cpu')] + args = [rand_strided(shape, stride, dtype, device) for shape, stride, dtype, device in args] + mod = make_fx(Repro())(*args) + from torch._inductor.compile_fx import compile_fx_inner + + compiled = compile_fx_inner(mod, args) + compiled(*args) + +The ``forward`` method of the ``Repro`` module contains the exact op +which causes the issue. When filing an issue, please include any +minified repros to aid in debugging. + +Minifying Backend Compiler Errors +--------------------------------- + +With backend compilers other than TorchInductor the process for finding +the subgraph causing the error is nearly identical to the procedure in +`errors in TorchInductor <#torchinductor-errors>`__ with one important +caveat. Namely, that the minifier will now be run on the graph that is +traced by TorchDynamo, not the output graph of AOTAutograd. Let’s walk +through an example. + +.. code-block:: py + + import torch + + import torch._dynamo as dynamo + + model = torch.nn.Sequential(*[torch.nn.Linear(200, 200) for _ in range(5)]) + # toy compiler which fails if graph contains relu + def toy_compiler(gm: torch.fx.GraphModule, _): + for node in gm.graph.nodes: + if node.target == torch.relu: + assert False + + return gm + + + @dynamo.optimize(toy_compiler) + def test_backend_error(): + y = torch.ones(200, 200) + x = torch.ones(200, 200) + z = x + y + a = torch.relu(z) + return model(a) + + + test_backend_error() + +In order to run the code after TorchDynamo has traced the forward graph, +you can use the ``TORCHDYNAMO_REPRO_AFTER`` enviornment variable. Running +this program with ``TORCHDYNAMO_REPRO_AFTER=“dynamo”`` (or +``torch._dynamo.config.repro_after="dynamo"``) should produce `this +output `__\ and +the following code in ``{torch._dynamo.config.base_dir}/repro.py``. + +.. note:: The other option for TORCHDYNAMO_REPRO_AFTER are ``"aot"``, which + will run the minifier after the backward graph has been generated. + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + from torch import tensor, device + import torch.fx as fx + from torch._dynamo.testing import rand_strided + from math import inf + from torch._dynamo.debug_utils import run_fwd_maybe_bwd + + + from torch.nn import * + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + + + def forward(self, add): + relu = torch.relu(add); add = None + return (relu,) + + + mod = Repro().cuda() + opt_mod = dynamo.optimize("None")(mod) + + + args = [((200, 200), (200, 1), torch.float32, 'cpu', False)] + args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] + + + with torch.cuda.amp.autocast(enabled=False): + ref = run_fwd_maybe_bwd(mod, args) + res = run_fwd_maybe_bwd(opt_mod, args) + +The minifier successfully reduced the graph to the op that raises the +error in ``toy_compiler``. The other difference from the procedure in +`TorhInductor Errors <#torchinductor-errors>`__ is that the minifier is +automatically run after encountering a backend compiler error. After a +successful run, the minifier writes ``repro.py`` to +``torch._dynamo.config.base_dir``. + +Performance Profiling +~~~~~~~~~~~~~~~~~~~~~ + +Accessing TorchDynamo Profiler +------------------------------ + +TorchDynamo has a builtin stats function for collecting and displaying +the time spent in each compilation phase. These stats can be accessed by +calling ``torch._dynamo.utils.compile_times()`` after executing +Torch._Dynamo. By default, this returns a string representation of the +compile times spent in each TorchDynamo function by name. + +TorchInductor Debug Tracing +--------------------------- + +TorchInductor has a builtin stats and trace function for displaying time +spent in each compilation phase, output code, output graph visualization +and IR dump. This is a debugging tool designed to make it easier to +understand and troubleshoot the internals of TorchInductor. + +Setting the environment variable ``TORCHINDUCTOR_TRACE=1`` will cause a +debug trace directory to be created and printed: + +:: + + $ env TORCHINDUCTOR_TRACE=1 python repro.py + torch._inductor.debug: [WARNING] model_forward_0 debug trace: /tmp/torchinductor_jansel/rh/crhwqgmbqtchqt3v3wdeeszjb352m4vbjbvdovaaeqpzi7tdjxqr.debug + +Here is an `example debug directory +output `__ +for the test program: + +:: + + torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.LayerNorm(10), + torch.nn.ReLU(), + ) + +Each file in that debug trace can be enabled and disabled through +``torch._inductor.config.trace.*``. The profile and the diagram are both +disabled by default since they are expensive to generate. + +A single node in this new debug format looks like: + +:: + + buf1: SchedulerNode(ComputedBuffer) + buf1.writes = + { MemoryDep(name='buf1', index=0, size=()), + MemoryDep(name='buf1', index=0, size=(s0,))} + buf1.unmet_dependencies = {MemoryDep(name='buf0', index=c0, size=(s0,))} + buf1.met_dependencies = {MemoryDep(name='primals_2', index=c0, size=(s0,))} + buf1.group.device = cuda:0 + buf1.group.iteration = (1, s0) + buf1.sizes = ([], [s0]) + class buf1_loop_body: + var_ranges = {z0: s0} + index0 = z0 + index1 = 0 + def body(self, ops): + get_index = self.get_index('index0') + load = ops.load('buf0', get_index, False) + get_index_1 = self.get_index('index0') + load_1 = ops.load('primals_2', get_index_1, False) + add = ops.add(load, load_1) + get_index_2 = self.get_index('index1') + reduction = ops.reduction('buf1', torch.float32, torch.float32, 'sum', get_index_2, add) + return reduction + +See the `example debug directory +output `__ +for more examples. + +.. + _Memory Profiling + ---------------- + + TBD + +Graph Breaks +------------ + +Given a program like this: + +.. code-block:: python + + @dynamo.optimize(...) + def some_fun(x): + ... + some_fun(x) + ... + +TorchDynamo will attempt to compile all of the torch/tensor operations +within some_fun into a single FX graph, but it may fail to capture +everything into one graph. + +Some graph break reasons are insurmountable to TorchDynamo, and can’t be +easily fixed. - calling into a C extension other than torch is invisible +to torchdynamo, and could do arbitrary things without TorchDynamo being +able to introduce necessary `guards <./GuardsOverviewPt1.md>`__ to +ensure that the compiled program would be safe to reuse. Graph breaks +can hinder performance if the resulting fragments are small. To maximize +performance, it’s important to have as few graph breaks as possible. + +Identifying the Cause of a Graph Break +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To identify all graph breaks in a program and the associated reasons for +the breaks, ``torch._dynamo.explain`` can be used. This tool runs +TorchDynamo on the supplied function and aggregates the graph breaks +that are encountered. Here is an example usage: + +.. code-block:: python + + import torch + import torch._dynamo as dynamo + def toy_example(a, b): + x = a / (torch.abs(a) + 1) + print("woo") + if b.sum() < 0: + b = b * -1 + return x * b + explanation, out_guards, graphs, ops_per_graph = dynamo.explain(toy_example, torch.randn(10), torch.randn(10)) + print(explanation) + """ + Dynamo produced 3 graphs, with 2 graph break and 6 ops. + Break reasons: + 1. call_function BuiltinVariable(print) [ConstantVariable(str)] {} + File "t2.py", line 16, in toy_example + print("woo") + + 2. generic_jump + File "t2.py", line 17, in toy_example + if b.sum() < 0: + """ + +Outputs include: + +- ``out_guards`` - a list of lists where each sublist contains the guards that must pass to ensure the traced graphs are valid. +- ``graphs`` - a list of graph modules which were successfully traced. +- ``ops_per_graph`` - a list of lists where each sublist contains the ops that are run in the graph. + +To throw an error on the first graph break encountered, use the ``nopython`` +mode. This mode disables TorchDynamo’s Python fallback, and only +succeeds if the entire program is convertible into a single graph. Example +usage: + +.. code-block:: python + + @dynamo.optimize(, nopython=True) + def toy_example(a, b): + ... + +Excessive Recompilation +----------------------- + +When TorchDynamo compiles a function (or part of one), it makes certain +assumptions about locals and globals in order to allow compiler +optimizations, and expresses these assumptions as guards that check +particular values at runtime. If any of these guards fail, Dynamo will +recompile that function (or part) up to +``torch._dynamo.config.cache_size_limit`` times. If your program is +hitting the cache limit, you will first need to determine which guard is +failing and what part of your program is triggering it. + +The `recompilation profiler <#recompilation-profiler>`__ automates the +process of setting TorchDynamo’s cache limit to 1 and running your +program under an observation-only 'compiler' that records the causes of +any guard failures. You should be sure to run your program for at least +as long (as many iterations) as you were running when you ran into +trouble, and the profiler will accumulate statistics over this duration. + +If your program exhibits a bounded amount of dynamism, you may be able +to tune the TorchDynamo cache limit to allow for each variation to be +compiled and cached, but if the cache limit is too high you may find the +cost of recompilation outweighs any optimization benefits. + +:: + + torch._dynamo.config.cache_size_limit = + +Torchdynamo plans to support many common cases of dynamic tensor shapes, +such as varying batch size or sequence length. It does not plan to +support rank-dynamism. In the meantime, setting a specific cache limit +can be used in coordination with bucketing techniques to achieve an +acceptable number of recompilations for some dynamic models. + +.. code-block:: python + + prof = dynamo.utils.CompilationProfiler() + @dynamo.optimize(prof) + def my_model(): + ... + my_model() + print(prof.report()) + +Accuracy Debugging +~~~~~~~~~~~~~~~~~~ + +Accuracy issues can also be minified if you set the environment variable +``TORCHDYNAMO_REPRO_LEVEL=4``, it operates with a similar git bisect +model and a full repro might be something like +``TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4`` the reason +we need this is downstream compilers will codegen code whether it’s +Triton code or the C++ backend, the numerics from those downstream +compilers can be different in subtle ways yet have dramatic impact on +your training stability. So the accuracy debugger is very useful for us +to detect bugs in our codegen or with a backend compiler. + +File an Issue +~~~~~~~~~~~~~ + +If you experience problems with TorchDynamo, `file a github +issue `__. + +Before filing an issue, read over the `README <../README.md>`__, +`TROUBLESHOOTING <./TROUBLESHOOTING.md>`__, and search for similar +issues. + +When filing an issue, include the information about your +OS, Python< PyTorch, CUDA, and Triton versions info by running: + +.. code-block:: shell + + python tools/verify_install.py + +- A minimal repro script if possible, which can be generated by running + Minifier +- A description of the error +- The expected behavior +- A log (set ``torch._dynamo.config.log_file`` to a valid file name to + dump the logs to a file and + ``torch._dynamo.config.log_level = logging.DEBUG`` and + ``torch._dynamo.config.verbose = True``) diff --git a/docs/source/fsdp.rst b/docs/source/fsdp.rst index ff42770831b7e..feb6d8cd470b2 100644 --- a/docs/source/fsdp.rst +++ b/docs/source/fsdp.rst @@ -5,3 +5,15 @@ FullyShardedDataParallel .. autoclass:: torch.distributed.fsdp.FullyShardedDataParallel :members: + +.. autoclass:: torch.distributed.fsdp.BackwardPrefetch + :members: + +.. autoclass:: torch.distributed.fsdp.ShardingStrategy + :members: + +.. autoclass:: torch.distributed.fsdp.MixedPrecision + :members: + +.. autoclass:: torch.distributed.fsdp.CPUOffload + :members: diff --git a/docs/source/fx.rst b/docs/source/fx.rst index 988ae081125c7..29d73b3055dc9 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -36,7 +36,7 @@ What is an FX transform? Essentially, it's a function that looks like this. # Step 3: Construct a Module to return return torch.fx.GraphModule(m, graph) -Your transform will take in an :class:`torch.nn.Module`, acquire a :class:`Graph` +Your transform will take in a :class:`torch.nn.Module`, acquire a :class:`Graph` from it, do some modifications, and return a new :class:`torch.nn.Module`. You should think of the :class:`torch.nn.Module` that your FX transform returns as identical to a regular :class:`torch.nn.Module` -- you can pass it to another @@ -1039,7 +1039,7 @@ Miscellanea traced.eval() x = torch.randn(5, 3) - torch.testing.assert_allclose(traced(x), x) + torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! @@ -1071,7 +1071,7 @@ Miscellanea traced.eval() x = torch.randn(5, 3) - torch.testing.assert_allclose(traced(x), x) + torch.testing.assert_close(traced(x), x) - Because of this difference, consider marking modules that interact with the ``training`` flag dynamically as leaf modules. diff --git a/docs/source/index.rst b/docs/source/index.rst index b9d097f551913..93fccd2cb66f4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,6 +42,21 @@ Features described in this documentation are classified by release status: notes/* +.. toctree:: + :glob: + :maxdepth: 1 + :caption: torch.compile + :hidden: + + dynamo/index + dynamo/installation + dynamo/get-started + dynamo/guards-overview + dynamo/custom-backends + dynamo/deep-dive + dynamo/troubleshooting + dynamo/faq + .. toctree:: :maxdepth: 1 :caption: Language Bindings @@ -51,7 +66,8 @@ Features described in this documentation are classified by release status: torch::deploy .. toctree:: - :maxdepth: 1 + :glob: + :maxdepth: 2 :caption: Python API torch @@ -70,7 +86,10 @@ Features described in this documentation are classified by release status: torch.distributed.elastic torch.distributed.fsdp torch.distributed.optim + torch.distributed.tensor.parallel + torch.distributed.checkpoint torch.distributions + torch._dynamo <_dynamo> torch.fft futures fx @@ -85,6 +104,7 @@ Features described in this documentation are classified by release status: profiler nn.init onnx + onnx_diagnostics optim complex_numbers ddp_comm_hooks diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 02950ff971a62..aec7031e2248e 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -6,6 +6,8 @@ torch.linalg Common linear algebra operations. +See :ref:`Linear Algebra Stability` for some common numerical edge-cases. + .. automodule:: torch.linalg .. currentmodule:: torch.linalg @@ -43,6 +45,8 @@ Decompositions svd svdvals +.. _linalg solvers: + Solvers ------- @@ -55,6 +59,8 @@ Solvers lu_solve lstsq +.. _linalg inverses: + Inverses -------- diff --git a/docs/source/masked.rst b/docs/source/masked.rst index d6ae9f7d56728..60b9af7ebcccb 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -56,10 +56,10 @@ There are already a number of existing tutorials that we've written to help user - `Advanced semantics - discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), differences with NumPy's MaskedArray, and reduction semantics`_ -.. _Overview - the place to start for new users, discusses how to use MaskedTensors and why they're useful: https://pytorch.org/tutorials/prototype/maskedtensor_overview.html/ -.. _Sparsity - MaskedTensor supports sparse COO and CSR data and mask Tensors: https://pytorch.org/tutorials/prototype/maskedtensor_sparsity.html/ -.. _Adagrad sparse semantics - a practical example of how MaskedTensor can simplify sparse semantics and implementations: https://pytorch.org/tutorials/prototype/maskedtensor_adagrad_semantics.html> -.. _Advanced semantics - discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), differences with NumPy's MaskedArray, and reduction semantics: https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics.html/ +.. _Overview - the place to start for new users, discusses how to use MaskedTensors and why they're useful: https://pytorch.org/tutorials/prototype/maskedtensor_overview +.. _Sparsity - MaskedTensor supports sparse COO and CSR data and mask Tensors: https://pytorch.org/tutorials/prototype/maskedtensor_sparsity +.. _Adagrad sparse semantics - a practical example of how MaskedTensor can simplify sparse semantics and implementations: https://pytorch.org/tutorials/prototype/maskedtensor_adagrad +.. _Advanced semantics - discussion on why certain decisions were made (e.g. requiring masks to match for binary/reduction operations), differences with NumPy's MaskedArray, and reduction semantics: https://pytorch.org/tutorials/prototype/maskedtensor_advanced_semantics Supported Operators +++++++++++++++++++ @@ -157,7 +157,7 @@ Binary Operators As you may have seen in the tutorial, :class:`MaskedTensor` also has binary operations implemented with the caveat that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you need support for a particular operator or have proposed semantics for how they should behave instead, please open -an issue on Github. For now, we have decided to go with the most conservative implementation to ensure that users +an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users know exactly what is going on and are being intentional about their decisions with masked semantics. The available binary operators are: diff --git a/docs/source/mobile_optimizer.rst b/docs/source/mobile_optimizer.rst index bb11abf82dbac..4df148dc707b8 100644 --- a/docs/source/mobile_optimizer.rst +++ b/docs/source/mobile_optimizer.rst @@ -7,13 +7,16 @@ torch.utils.mobile_optimizer Torch mobile supports ``torch.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode. The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set and a preserved method list -By default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations: +For CPU Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations: - **Conv2D + BatchNorm fusion** (blocklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated. - **Insert and Fold prepacked ops** (blocklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops. - **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping. That is clamping of output activation is done as part of the kernel, including for 2D convolution and linear op kernels. Thus clamping effectively comes for free. Thus any op that can be expressed as clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with previous ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites graph by finding ``ReLU/hardtanh`` ops that follow XNNPACK ``Conv2D/linear`` ops, written by the previous pass, and fuses them together. - **Dropout removal** (blocklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false. - **Conv packed params hoisting** (blocklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics. +for Vulkan Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the folllwing optimization: + - **Automatic GPU Transfer** (blocklisting option `MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER`): This optimization pass rewrites the graph such that inputs are transferred to Vulkan backend, and outputs are transferred to CPU backend + ``optimize_for_mobile`` will also invoke freeze_module pass which only preserves ``forward`` method. If you have other method to that needed to be preserved, add them into the preserved method list and pass into the method. diff --git a/docs/source/nested.rst b/docs/source/nested.rst index 4cfb5bdf701ae..ac07f8acb5a23 100644 --- a/docs/source/nested.rst +++ b/docs/source/nested.rst @@ -196,11 +196,12 @@ NestedTensor and any constraints they have. :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors." :func:`torch.relu`; "Behavior is the same as on regular tensors." :func:`torch.gelu`; "Behavior is the same as on regular tensors." + :func:`torch.neg`; "Behavior is the same as on regular tensors." :func:`torch.add`; "Supports elementwise addition of two nested tensors. Supports addition of a scalar to a nested tensor." :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. - Supports multipication of a nested tensor by a scalar." - :func:`torch.select`; "Supports selecting along ``dim=0`` only (analogously ``nt[i]``)." + Supports multiplication of a nested tensor by a scalar." + :func:`torch.select`; "Supports selecting along all dimensions." :func:`torch.clone`; "Behavior is the same as on regular tensors." :func:`torch.detach`; "Behavior is the same as on regular tensors." :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only." diff --git a/docs/source/notes/autograd.rst b/docs/source/notes/autograd.rst index 6eec13a7de557..08ae3957b00a0 100644 --- a/docs/source/notes/autograd.rst +++ b/docs/source/notes/autograd.rst @@ -13,7 +13,7 @@ programs, and can aid you in debugging. How autograd encodes the history -------------------------------- -Autograd is reverse automatic differentiation system. Conceptually, +Autograd is a reverse automatic differentiation system. Conceptually, autograd records a graph recording all of the operations that created the data as you execute operations, giving you a directed acyclic graph whose leaves are the input tensors and roots are the output tensors. @@ -23,11 +23,11 @@ compute the gradients using the chain rule. Internally, autograd represents this graph as a graph of :class:`Function` objects (really expressions), which can be :meth:`~torch.autograd.Function.apply` ed to compute the result of -evaluating the graph. When computing the forwards pass, autograd +evaluating the graph. When computing the forward pass, autograd simultaneously performs the requested computations and builds up a graph representing the function that computes the gradient (the ``.grad_fn`` attribute of each :class:`torch.Tensor` is an entry point into this graph). -When the forwards pass is completed, we evaluate this graph in the +When the forward pass is completed, we evaluate this graph in the backwards pass to compute the gradients. An important thing to note is that the graph is recreated from scratch at every @@ -119,7 +119,7 @@ For more fine-grained exclusion of subgraphs from gradient computation, there is setting the ``requires_grad`` field of a tensor. Below, in addition to discussing the mechanisms above, we also describe -evaluation mode (:meth:`nn.Module.eval()`), a method that is not actually used +evaluation mode (:meth:`nn.Module.eval()`), a method that is not used to disable gradient computation but, because of its name, is often mixed up with the three. Setting ``requires_grad`` @@ -164,8 +164,8 @@ of the module's parameters (which have ``requires_grad=True`` by default). Grad Modes ^^^^^^^^^^ -Apart from setting ``requires_grad`` there are also three possible modes -enableable from Python that can affect how computations in PyTorch are +Apart from setting ``requires_grad`` there are also three grad modes that can +be selected from Python that can affect how computations in PyTorch are processed by autograd internally: default mode (grad mode), no-grad mode, and inference mode, all of which can be togglable via context managers and decorators. @@ -173,7 +173,7 @@ decorators. Default Mode (Grad Mode) ^^^^^^^^^^^^^^^^^^^^^^^^ -The "default mode" is actually the mode we are implicitly in when no other modes like +The "default mode" is the mode we are implicitly in when no other modes like no-grad and inference mode are enabled. To be contrasted with "no-grad mode" the default mode is also sometimes called "grad mode". @@ -237,7 +237,7 @@ For implementation details of inference mode see Evaluation Mode (``nn.Module.eval()``) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Evaluation mode is not actually a mechanism to locally disable gradient computation. +Evaluation mode is not a mechanism to locally disable gradient computation. It is included here anyway because it is sometimes confused to be such a mechanism. Functionally, ``module.eval()`` (or equivalently ``module.train(False)``) are completely @@ -263,7 +263,7 @@ In-place operations with autograd Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd's aggressive buffer freeing and reuse makes it very efficient and there are very few occasions when in-place operations -actually lower memory usage by any significant amount. Unless you're operating +lower memory usage by any significant amount. Unless you're operating under heavy memory pressure, you might never need to use them. There are two main reasons that limit the applicability of in-place operations: @@ -271,13 +271,13 @@ There are two main reasons that limit the applicability of in-place operations: 1. In-place operations can potentially overwrite values required to compute gradients. -2. Every in-place operation actually requires the implementation to rewrite the +2. Every in-place operation requires the implementation to rewrite the computational graph. Out-of-place versions simply allocate new objects and keep references to the old graph, while in-place operations, require changing the creator of all inputs to the :class:`Function` representing this operation. This can be tricky, especially if there are many Tensors that reference the same storage (e.g. created by indexing or transposing), - and in-place functions will actually raise an error if the storage of + and in-place functions will raise an error if the storage of modified inputs is referenced by any other :class:`Tensor`. In-place correctness checks @@ -338,18 +338,18 @@ serializing all the backward calls in a specific order during execution Non-determinism ^^^^^^^^^^^^^^^ -If you are calling ``backward()`` on multiple thread concurrently but with -shared inputs (i.e. Hogwild CPU training). Since parameters are automatically -shared across threads, gradient accumulation might become non-deterministic on -backward calls across threads, because two backward calls might access and try -to accumulate the same ``.grad`` attribute. This is technically not safe, and -it might result in racing condition and the result might be invalid to use. +If you are calling ``backward()`` from multiple threads concurrently and have +shared inputs (i.e. Hogwild CPU training), then non-determinsim should be expected. +This can occur because parameters are automatically shared across threads, +as such, multiple threads may access and try to accumulate the same ``.grad`` +attribute during gradient accumulation. This is technically not safe, and +it might result in race condition and the result might be invalid to use. -But this is expected pattern if you are using the multithreading approach to -drive the whole training process but using shared parameters, user who use -multithreading should have the threading model in mind and should expect this -to happen. User could use the functional API :func:`torch.autograd.grad` to -calculate the gradients instead of ``backward()`` to avoid non-determinism. +Users developing multithreaded models featuring shared parameters should have the +threading model in mind and should understand the issues described above. + +The functional API :func:`torch.autograd.grad` may be used to calculate the +gradients instead of ``backward()`` to avoid non-determinism. Graph retaining ^^^^^^^^^^^^^^^ @@ -368,9 +368,9 @@ Thread Safety on Autograd Node Since Autograd allows the caller thread to drive its backward execution for potential parallelism, it's important that we ensure thread safety on CPU with -parallel backwards that share part/whole of the GraphTask. +parallel ``backward()`` calls that share part/whole of the GraphTask. -Custom Python ``autograd.Function`` is automatically thread safe because of GIL. +Custom Python ``autograd.Function``\s are automatically thread safe because of GIL. For built-in C++ Autograd Nodes (e.g. AccumulateGrad, CopySlices) and custom ``autograd::Function``\s, the Autograd Engine uses thread mutex locking to ensure thread safety on autograd Nodes that might have state write/read. @@ -440,8 +440,8 @@ It also turns out that no interesting real-valued objective fulfill the Cauchy-Riemann equations. So the theory with homomorphic function cannot be used for optimization and most people therefore use the Wirtinger calculus. -Wirtinger Calculus comes in picture ... -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Wirtinger Calculus comes into the picture ... +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ So, we have this great theory of complex differentiability and holomorphic functions, and we can’t use any of it at all, because many diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index b376adcff2554..4a1538900a88c 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -394,6 +394,12 @@ Available options: the size 1200 lies between 1024 and 2048 and if we do 4 divisions between them, the values are 1024, 1280, 1536, and 1792. So, allocation size of 1200 will be rounded to 1280 as the nearest ceiling of power-2 division. + Specify a single value to apply for all allocation sizes or specify an + array of key value pairs to set power-2 division individually for each + power of two interval. For example to set 1 division for all allocations + under 256MB, 2 division for allocations between 256MB and 512MB, 4 divisions + for allocations between 512MB and 1GB and 8 divisions for any larger allocations, + set the knob value to: [256:1,512:2,1024:4,>:8]. ``roundup_power2_divisions`` is only meaningful with ``backend:native``. With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored. * ``roundup_bypass_threshold_mb`` bypass rounding the requested allocation size, @@ -424,6 +430,66 @@ Available options: .. _CUDA's built-in asynchronous allocator: https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/ +.. _cuda-memory-custom-allocator: + +Using custom memory allocators for CUDA +--------------------------------------- + +It is possible to define allocators as simple functions in C/C++ and compile +them as a shared library, the code below shows a basic allocator that just +traces all the memory operations. + +.. code:: C++ + + #include + #include + #include + // Compile with g++ alloc.cc -o alloc.so -I/usr/local/cuda/include -shared -fPIC + extern "C" { + void* my_malloc(ssize_t size, int device, cudaStream_t stream) { + void *ptr; + cudaMalloc(&ptr, size); + std::cout<<"alloc "<`_ for more information.) +TorchDynamo support for DDP currently requires setting `static_graph=True` and `find_unused_parameters=True`, due to +interactions between the graph tracing process and DDP's mechanism for observing operations happening on its module, +but this should be fixed ultimately. + +.. code:: + + ddp_model = DDP(model, device_ids=[rank]) + ddp_model = torch.compile(ddp_model) Internal Design ^^^^^^^^^^^^^^^ @@ -193,3 +204,24 @@ DistributedDataParallel .. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png :alt: ddp_code.png :width: 400 px + + +TorchDynamo DDPOptimizer +------------------------ + +DDP's performance advantage comes from overlapping allreduce collectives with computations during backwards. +AotAutograd prevents this overlap when used with TorchDynamo for compiling a whole forward and whole backward graph, +becuase allreduce ops are launched by autograd hooks _after_ the whole optimized backwards computation finishes. + +TorchDynamo's DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP's allreduce buckets +during backwards. Note: the goal is to break the graph during backwards, and the simplest implementation is to +break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP's allreduce hooks +to fire in-between sections of backwards, and schedule communications to overlap with compute. + +See `this blog post `_ for +a more in-depth explanation and experimental results, or read the docs and code at +`torch/_dynamo/optimizations/distributed.py `_ + +To Debug DDPOptimizer, set `torch._dynamo.config.log_level` to DEBUG (for full graph dumps) or INFO +(for basic info about bucket boundaries). To disable DDPOptimizer, set `torch._dynamo.config.optimize_ddp=False`. +DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation. \ No newline at end of file diff --git a/docs/source/notes/hip.rst b/docs/source/notes/hip.rst index a9c94e2a4febb..103c5db7d460a 100644 --- a/docs/source/notes/hip.rst +++ b/docs/source/notes/hip.rst @@ -130,17 +130,35 @@ NOTE: The CUDA_VERSION macro, cudaRuntimeGetVersion and cudaDriverGetVersion API semantically map to the same values as HIP_VERSION macro, hipRuntimeGetVersion and hipDriverGetVersion APIs. Please do not use them interchangeably when doing version checks. -Eg: Instead of -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -If it is desired to not take the code path for ROCm/HIP: -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(USE_ROCM) -If it is desired to take the code path for ROCm/HIP: -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || defined(USE_ROCM) -If it is desired to take the code path for ROCm/HIP only for specific HIP versions: -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 40300) +For example: Instead of using + +``#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000`` to implicitly exclude ROCm/HIP, + +use the following to not take the code path for ROCm/HIP: + +``#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && !defined(USE_ROCM)`` + +Alternatively, if it is desired to take the code path for ROCm/HIP: + +``#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || defined(USE_ROCM)`` + +Or if it is desired to take the code path for ROCm/HIP only for specific HIP versions: + +``#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11000) || (defined(USE_ROCM) && ROCM_VERSION >= 40300)`` Refer to CUDA Semantics doc --------------------------- For any sections not listed here, please refer to the CUDA semantics doc: :ref:`cuda-semantics` + + +Enabling kernel asserts +----------------------- + +Kernel asserts are supported on ROCm, but they are disabled due to performance overhead. It can be enabled +by recompiling the PyTorch from source. + +Please add below line as an argument to cmake command parameters:: + + -DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON diff --git a/docs/source/notes/modules.rst b/docs/source/notes/modules.rst index 7eea02dfa857f..49b27a0ae0142 100644 --- a/docs/source/notes/modules.rst +++ b/docs/source/notes/modules.rst @@ -599,7 +599,7 @@ PyTorch provides two types of hooks for modules: * **Forward hooks** are called during the forward pass. They can be installed for a given module with :func:`~torch.nn.Module.register_forward_pre_hook` and :func:`~torch.nn.Module.register_forward_hook`. These hooks will be called respectively just before the forward function is called and just after it is called. - Alternatively, these hooks can be installed globally for all modules with the analagous + Alternatively, these hooks can be installed globally for all modules with the analogous :func:`~torch.nn.modules.module.register_module_forward_pre_hook` and :func:`~torch.nn.modules.module.register_module_forward_hook` functions. * **Backward hooks** are called during the backward pass. They can be installed with diff --git a/docs/source/notes/numerical_accuracy.rst b/docs/source/notes/numerical_accuracy.rst index b1d05f9460419..82e0bb253129e 100644 --- a/docs/source/notes/numerical_accuracy.rst +++ b/docs/source/notes/numerical_accuracy.rst @@ -34,9 +34,10 @@ even though mathematically it's an identical computation. Similarly, an operation applied to a tensor slice is not guaranteed to produce results that are identical to the slice of the result of the same operation applied to the full tensor. E.g. let -``A`` be a 2-dimentional tensor. ``A.sum(-1)[0]`` is not guaranteed to be bitwise equal to +``A`` be a 2-dimensional tensor. ``A.sum(-1)[0]`` is not guaranteed to be bitwise equal to ``A[:,0].sum()``. + Extremal values --------------- @@ -51,6 +52,40 @@ datatype. E.g.: a.norm() # produces tensor(inf) a.double().norm() # produces tensor(1.4142e+20, dtype=torch.float64), representable in fp32 +.. _Linear Algebra Stability: + +Linear algebra (``torch.linalg``) +--------------------------------- + +Non-finite values +""""""""""""""""" + +The external libraries (backends) that ``torch.linalg`` uses provide no guarantees on their behaviour +when the inputs have non-finite values like ``inf`` or ``NaN``. As such, neither does PyTorch. +The operations may return a tensor with non-finite values, or raise an exception, or even segfault. + +Consider using :func:`torch.isfinite` before calling these functions to detect this situation. + +Extremal values in linalg +""""""""""""""""""""""""" + +Functions within ``torch.linalg`` have more `Extremal Values`_ than other PyTorch functions. + +:ref:`linalg solvers` and :ref:`linalg inverses` assume that the input matrix ``A`` is invertible. If it is close to +being non-invertible (for example, if it has a very small singular value), then these algorithms may silently return +incorrect results. These matrices are said to be `ill-conditioned `_. +If provided with ill-conditioned inputs, the result of these functions they may vary when using the same inputs on different +devices or when using different backends via the keyword ``driver``. + +Spectral operations like ``svd``, ``eig``, and ``eigh`` may also return incorrect results (and their gradients may be infinite) +when their inputs have singular values that are close to each other. This is because the algorithms used to compute these decompositions +struggle to converge for these inputs. + +Running the computation in ``float64`` (as NumPy does by default) often helps, but it does not solve these issues in all cases. +Analyzing the spectrum of the inputs via :func:`torch.linalg.svdvals` or their condition number via :func:`torch.linalg.cond` +may help to detect these issues. + + TensorFloat-32(TF32) on Nvidia Ampere devices --------------------------------------------- diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 78ef3cd93663c..8f52be124e2ea 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -499,6 +499,7 @@ ONNX operators that represent the function's behavior in ONNX. For example:: Inline Autograd Function ~~~~~~~~~~~~~~~~~~~~~~~~ + In cases where a static symbolic method is not provided for its subsequent :class:`torch.autograd.Function` or where a function to register ``prim::PythonOp`` as custom symbolic functions is not provided, :func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function` such that @@ -526,6 +527,73 @@ If you need to avoid inlining of :class:`torch.autograd.Function`, you should ex Custom operators ^^^^^^^^^^^^^^^^ +You can export your model with custom operators that includes a combination of many standard ONNX ops, +or are driven by self-defined C++ backend. + +ONNX-script functions +~~~~~~~~~~~~~~~~~~~~~ + +If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize +`ONNX-script `_ to create an external ONNX function to support the operator. +You can export it by following this example:: + + import onnxscript + # There are three opset version needed to be aligned + # This is (1) the opset version in ONNX function + from onnxscript.onnx_opset import opset15 as op + opset_version = 15 + + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model = torch.nn.SELU() + + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + # setType API provides shape/type to ONNX shape/type inference + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + # Register custom symbolic function + # There are three opset version needed to be aligned + # This is (2) the opset version in registry + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=opset_version, + ) + + # There are three opset version needed to be aligned + # This is (2) the opset version in exporter + torch.onnx.export( + model, + x, + "model.onnx", + opset_version=opset_version, + # only needed if you want to specify an opset version > 1. + custom_opsets={"onnx-script": 2} + ) + +The example above exports it as a custom operator in the "onnx-script" opset. +When exporting a custom operator, you can specify the custom domain version using the +``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. + +NOTE: Be careful to align the opset version mentioned in the above example, and make sure they are consumed in exporter step. +The example usage of how to write a onnx-script function is a beta version in terms of the active development on onnx-script. +Please follow the latest `ONNX-script `_ + +C++ Operators +~~~~~~~~~~~~~ + If a model uses a custom operator implemented in C++ as described in `Extending TorchScript with Custom C++ Operators `_, you can export it by following this example:: @@ -563,8 +631,6 @@ you can export it by following this example:: custom_opsets={"custom_domain": 2} ) -You can export your model as one or a combination of many standard ONNX ops, or as a custom ONNX operator. - The example above exports it as a custom operator in the "custom_domain" opset. When exporting a custom operator, you can specify the custom domain version using the ``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. @@ -594,7 +660,7 @@ all of the unconvertible ops in one go you can:: The set is approximated because some ops may be removed during the conversion process and don't need to be converted. Some other ops may have partial support that will fail conversion with particular inputs, but this should give you a -general idea of what ops are not supported. Please feel free to open Github Issues +general idea of what ops are not supported. Please feel free to open GitHub Issues for op support requests. Frequently Asked Questions diff --git a/docs/source/onnx_diagnostics.rst b/docs/source/onnx_diagnostics.rst new file mode 100644 index 0000000000000..ec2edd4cbdbe7 --- /dev/null +++ b/docs/source/onnx_diagnostics.rst @@ -0,0 +1,35 @@ +torch.onnx diagnostics +====================== + +.. contents:: :local: +.. automodule:: torch.onnx._internal.diagnostics +.. currentmodule:: torch.onnx._internal.diagnostics + +Overview +-------- + +NOTE: This feature is underdevelopment and is subject to change. + +The goal is to improve the diagnostics to help users debug and improve their model export to ONNX. + +- The diagnostics are emitted in machine parsable `Static Analysis Results Interchange Format (SARIF) `__. +- A new clearer, structured way to add new and keep track of diagnostic rules. +- Serve as foundation for more future improvements consuming the diagnostics. + + +Diagnostic Rules +---------------- + +.. toctree:: + :glob: + + generated/onnx_diagnostics_rules/* + +API Reference +------------- + +.. autoclass:: torch.onnx._internal.diagnostics.ExportDiagnostic + :members: + +.. autoclass:: torch.onnx._internal.diagnostics.infra.DiagnosticEngine + :members: diff --git a/docs/source/quantization-accuracy-debugging.rst b/docs/source/quantization-accuracy-debugging.rst index 69bda8706cc67..0fa590abd2f0c 100644 --- a/docs/source/quantization-accuracy-debugging.rst +++ b/docs/source/quantization-accuracy-debugging.rst @@ -6,7 +6,7 @@ accuracy. If a quantized model has error compared to the original model, we can categorize the error into: 1. **data insensitive error** - caused by intrinsic model quantization error, - large portion of input data has large errror + large portion of input data has large error 2. **data sensitive error** - caused by outlier input data, small portion of input data has large error 3. **implementation error** - quantized kernel is not matching reference implementation diff --git a/docs/source/quantization-backend-configuration.rst b/docs/source/quantization-backend-configuration.rst index 07fd875fa9b34..bfe93ce701e62 100644 --- a/docs/source/quantization-backend-configuration.rst +++ b/docs/source/quantization-backend-configuration.rst @@ -13,7 +13,7 @@ Default values for native configurations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Below is the output of the configuration for quantization of ops -in fbgemm and qnnpack (PyTorch's default quantized backends). +in x86 and qnnpack (PyTorch's default quantized backends). Results: diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index 681e25b1172bc..d57a4b822f5c5 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -529,7 +529,7 @@ Quantized dtypes and quantization schemes Note that operator implementations currently only support per channel quantization for weights of the **conv** and **linear** operators. Furthermore, the input data is -mapped linearly to the the quantized data and vice versa +mapped linearly to the quantized data and vice versa as follows: .. math:: diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 34cbad9b52cc3..3d95f72bf2b5c 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -256,11 +256,14 @@ PTSQ API Example:: model_fp32.eval() # attach a global qconfig, which contains information about what kind - # of observers to attach. Use 'fbgemm' for server inference and - # 'qnnpack' for mobile inference. Other quantization configurations such - # as selecting symmetric or assymetric quantization and MinMax or L2Norm - # calibration techniques can be specified here. - model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm') + # of observers to attach. Use 'x86' for server inference and 'qnnpack' + # for mobile inference. Other quantization configurations such as selecting + # symmetric or assymetric quantization and MinMax or L2Norm calibration techniques + # can be specified here. + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default + # for server inference. + # model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm') + model_fp32.qconfig = torch.quantization.get_default_qconfig('x86') # Fuse the activations to preceding layers, where applicable. # This needs to be done manually depending on the model architecture. @@ -352,11 +355,14 @@ QAT API Example:: model_fp32.eval() # attach a global qconfig, which contains information about what kind - # of observers to attach. Use 'fbgemm' for server inference and - # 'qnnpack' for mobile inference. Other quantization configurations such - # as selecting symmetric or assymetric quantization and MinMax or L2Norm - # calibration techniques can be specified here. - model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') + # of observers to attach. Use 'x86' for server inference and 'qnnpack' + # for mobile inference. Other quantization configurations such as selecting + # symmetric or assymetric quantization and MinMax or L2Norm calibration techniques + # can be specified here. + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default + # for server inference. + # model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm') + model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('x86') # fuse the activations to preceding layers, where applicable # this needs to be done manually depending on the model architecture @@ -427,7 +433,11 @@ There are multiple quantization types in post training quantization (weight only FXPTQ API Example:: import torch - from torch.ao.quantization import QConfigMapping + from torch.ao.quantization import ( + get_default_qconfig_mapping, + get_default_qat_qconfig_mapping, + QConfigMapping, + ) import torch.quantization.quantize_fx as quantize_fx import copy @@ -454,7 +464,7 @@ FXPTQ API Example:: # model_to_quantize = copy.deepcopy(model_fp) - qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qconfig('qnnpack')) + qconfig_mapping = get_default_qconfig_mapping("qnnpack") model_to_quantize.eval() # prepare model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) @@ -467,7 +477,7 @@ FXPTQ API Example:: # model_to_quantize = copy.deepcopy(model_fp) - qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qat_qconfig('qnnpack')) + qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack") model_to_quantize.train() # prepare model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs) @@ -728,7 +738,7 @@ Backend/Hardware Support | | |Quantization|Mode |Mode Support| | | | |Quantization| | +-----------------+---------------+------------+------------+------------+ -|server CPU |fbgemm |Supported |All | +|server CPU |fbgemm/onednn |Supported |All | | | | |Supported | +-----------------+---------------+ | + |mobile CPU |qnnpack/xnnpack| | | @@ -742,30 +752,31 @@ Backend/Hardware Support Today, PyTorch supports the following backends for running quantized operators efficiently: -* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via `fbgemm `_ +* x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via `x86` optimized by `fbgemm `_ and `onednn `_ (see the details at `RFC `_) * ARM CPUs (typically found in mobile/embedded devices), via `qnnpack `_ * (early prototype) support for NVidia GPU via `TensorRT `_ through `fx2trt` (to be open sourced) Note for native CPU backends ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -We expose both `fbgemm` and `qnnpack` with the same native pytorch quantized operators, so we need additional flag to distinguish between them. The corresponding implementation of `fbgemm` and `qnnpack` is chosen automatically based on the PyTorch build mode, though users have the option to override this by setting `torch.backends.quantization.engine` to `fbgemm` or `qnnpack`. +We expose both `x86` and `qnnpack` with the same native pytorch quantized operators, so we need additional flag to distinguish between them. The corresponding implementation of `x86` and `qnnpack` is chosen automatically based on the PyTorch build mode, though users have the option to override this by setting `torch.backends.quantization.engine` to `x86` or `qnnpack`. When preparing a quantized model, it is necessary to ensure that qconfig and the engine used for quantized computations match the backend on which the model will be executed. The qconfig controls the type of observers used -during the quantization passes. The qengine controls whether `fbgemm` or -`qnnpack` specific packing function is used when packing weights for linear -and convolution functions and modules. For example: +during the quantization passes. The qengine controls whether `x86` or `qnnpack` +specific packing function is used when packing weights for +linear and convolution functions and modules. For example: -Default settings for fbgemm:: +Default settings for x86:: # set the qconfig for PTQ - qconfig = torch.quantization.get_default_qconfig('fbgemm') + # Note: the old 'fbgemm' is still available but 'x86' is the recommended default on x86 CPUs + qconfig = torch.quantization.get_default_qconfig('x86') # or, set the qconfig for QAT - qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') + qconfig = torch.quantization.get_default_qat_qconfig('x86') # set the qengine to control weight packing - torch.backends.quantized.engine = 'fbgemm' + torch.backends.quantized.engine = 'x86' Default settings for qnnpack:: @@ -992,11 +1003,30 @@ Custom API Example:: Best Practices -------------- -1. If you are using the ``fbgemm`` backend, we need to use 7 bits instead of 8 bits. Make sure you reduce the range for the ``quant\_min``, ``quant\_max``, e.g. +1. If you are using the ``x86`` backend, we need to use 7 bits instead of 8 bits. Make sure you reduce the range for the ``quant\_min``, ``quant\_max``, e.g. if ``dtype`` is ``torch.quint8``, make sure to set a custom ``quant_min`` to be ``0`` and ``quant_max`` to be ``127`` (``255`` / ``2``) if ``dtype`` is ``torch.qint8``, make sure to set a custom ``quant_min`` to be ``-64`` (``-128`` / ``2``) and ``quant_max`` to be ``63`` (``127`` / ``2``), we already set this correctly if you call the `torch.ao.quantization.get_default_qconfig(backend)` or `torch.ao.quantization.get_default_qat_qconfig(backend)` function to get the default ``qconfig`` for -``fbgemm`` or ``qnnpack`` backend +``x86`` or ``qnnpack`` backend + +Frequently Asked Questions +-------------------------- + +1. How can I do quantized inference on GPU?: + + We don't have official GPU support yet, but this is an area of active development, you can find more information + `here `_ + +2. Where can I get ONNX support for my quantized model?: + + You can open an issue in `GitHub - onnx/onnx `_ when you encounter problems with ONNX, + or reach out to people in this list: `PyTorch Governance | Maintainers | ONNX exporter `_ + +3. How can I use quantization with LSTM's?: + + LSTM is supported through our custom module api in both eager mode and fx graph mode quantization. Examples can be found at + Eager Mode: `pytorch/test_quantized_op.py TestQuantizedOps.test_custom_module_lstm `_ + FX Graph Mode: `pytorch/test_quantize_fx.py TestQuantizeFx.test_static_lstm `_ Common Errors --------------------------------------- diff --git a/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py b/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py new file mode 100644 index 0000000000000..3c2895f6fe769 --- /dev/null +++ b/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py @@ -0,0 +1,37 @@ +import argparse +import os +from dataclasses import fields + +from torch.onnx._internal import diagnostics +from torch.onnx._internal.diagnostics import infra + + +def gen_docs(out_dir: str): + os.makedirs(out_dir, exist_ok=True) + for field in fields(diagnostics.rules): + rule = getattr(diagnostics.rules, field.name) + if not isinstance(rule, infra.Rule): + continue + title = f"{rule.id}:{rule.name}" + full_description_markdown = rule.full_description_markdown + assert ( + full_description_markdown is not None + ), f"Expected {title} to have a full description in markdown" + with open(f"{out_dir}/{title}.md", "w") as f: + f.write(f"# {title}\n") + f.write(full_description_markdown) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate ONNX diagnostics rules doc in markdown." + ) + parser.add_argument( + "out_dir", metavar="OUT_DIR", help="path to output directory for docs" + ) + args = parser.parse_args() + gen_docs(args.out_dir) + + +if __name__ == "__main__": + main() diff --git a/docs/source/signal.rst b/docs/source/signal.rst index e304092ede5ed..a450c92727f35 100644 --- a/docs/source/signal.rst +++ b/docs/source/signal.rst @@ -18,6 +18,13 @@ torch.signal.windows :toctree: generated :nosignatures: + bartlett + blackman cosine exponential gaussian + general_cosine + general_hamming + hamming + hann + kaiser diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 2da6a6faaee55..77e8dabec2744 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -10,7 +10,7 @@ torch.sparse .. warning:: The PyTorch API of sparse tensors is in beta and may change in the near future. - We highly welcome feature requests, bug reports and general suggestions as Github issues. + We highly welcome feature requests, bug reports and general suggestions as GitHub issues. Why and when to use sparsity ++++++++++++++++++++++++++++ @@ -40,7 +40,7 @@ Like many other performance optimization sparse storage formats are not always advantageous. When trying sparse formats for your use case you might find your execution time to decrease rather than increase. -Please feel encouraged to open a Github issue if you analytically +Please feel encouraged to open a GitHub issue if you analytically expected to see a stark increase in performance but measured a degradation instead. This helps us prioritize the implementation of efficient kernels and wider performance optimizations. @@ -117,7 +117,7 @@ Operator overview Fundamentally, operations on Tensor with sparse storage formats behave the same as operations on Tensor with strided (or other) storage formats. The particularities of storage, that is the physical layout of the data, influences the performance of -an operation but shhould not influence the semantics. +an operation but should not influence the semantics. We are actively increasing operator coverage for sparse tensors. Users should not diff --git a/docs/source/storage.rst b/docs/source/storage.rst index 28cf4444fbc97..84fed2f659a7b 100644 --- a/docs/source/storage.rst +++ b/docs/source/storage.rst @@ -22,6 +22,10 @@ holds the data as an untyped array of bytes. Every strided :class:`torch.Tensor` contains a :class:`torch.TypedStorage`, which stores all of the data that the :class:`torch.Tensor` views. +.. warning:: + All storage classes except for :class:`torch.UntypedStorage` will be removed + in the future, and :class:`torch.UntypedStorage` will be used in all cases. + .. autoclass:: torch.TypedStorage :members: :undoc-members: diff --git a/docs/source/testing.rst b/docs/source/testing.rst index 122aa651b9579..8837c4a0ec1a7 100644 --- a/docs/source/testing.rst +++ b/docs/source/testing.rst @@ -6,3 +6,4 @@ torch.testing .. autofunction:: assert_close .. autofunction:: make_tensor +.. autofunction:: assert_allclose diff --git a/docs/source/torch.rst b/docs/source/torch.rst index f5e06d5ea438d..23d63bcd750c0 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -613,6 +613,14 @@ Utilities vmap _assert +Optimizations +------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + compile + Operator Tags ------------------------------------ .. autoclass:: Tag diff --git a/functorch/.flake8 b/functorch/.flake8 deleted file mode 100644 index a6d73773e3b55..0000000000000 --- a/functorch/.flake8 +++ /dev/null @@ -1,20 +0,0 @@ -[flake8] -select = B,C,E,F,P,T4,W,B9 -max-line-length = 120 -# C408 ignored because we like the dict keyword argument syntax -# E501 is not flexible enough, we're using B950 instead -ignore = - E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # shebang has extra meaning in fbcode lints, so I think it's not worth trying - # to line this up with executable bit - EXE001, - # these ignores are from flake8-bugbear; please fix! - B007,B008, - # these ignores are from flake8-comprehensions; please fix! - C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -exclude = - ./.git, - ./benchmarks, - ./docs, - ./examples, - ./notebooks diff --git a/functorch/CMakeLists.txt b/functorch/CMakeLists.txt index d203043243829..911f251e88623 100644 --- a/functorch/CMakeLists.txt +++ b/functorch/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.12) project(functorch) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) include(GNUInstallDirs) include(CMakePackageConfigHelpers) diff --git a/functorch/__init__.py b/functorch/__init__.py index 971ce793d7203..c02ae3c443b6f 100644 --- a/functorch/__init__.py +++ b/functorch/__init__.py @@ -8,19 +8,19 @@ # Top-level APIs. Please think carefully before adding something to the # top-level namespace: -# - private helper functions should go into functorch._src +# - private helper functions should go into torch._functorch # - very experimental things should go into functorch.experimental # - compilation related things should go into functorch.compile # functorch transforms -from ._src.vmap import vmap -from ._src.eager_transforms import ( +from torch._functorch.vmap import vmap +from torch._functorch.eager_transforms import ( grad, grad_and_value, vjp, jacrev, jvp, jacfwd, hessian, functionalize ) -from ._src.python_key import make_fx +from torch._functorch.python_key import make_fx # utilities. Maybe these should go in their own namespace in the future? -from ._src.make_functional import ( +from torch._functorch.make_functional import ( make_functional_with_buffers, make_functional, combine_state_for_ensemble, diff --git a/functorch/_src/__init__.py b/functorch/_src/__init__.py index 10a55772ab58b..e69de29bb2d1d 100644 --- a/functorch/_src/__init__.py +++ b/functorch/_src/__init__.py @@ -1,5 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py deleted file mode 100644 index b1e29b6ac4103..0000000000000 --- a/functorch/_src/aot_autograd.py +++ /dev/null @@ -1,890 +0,0 @@ -import collections -import dataclasses -import warnings -from contextlib import contextmanager, nullcontext -from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Tuple -from torch.fx.experimental.proxy_tensor import is_sym_node - -import torch -import torch.fx.traceback as fx_traceback -import torch.nn as nn -import torch.utils._pytree as pytree -import torch.utils.dlpack -from torch import Tensor -from torch._subclasses import FakeTensorMode, CrossRefFakeMode -from torch.fx import immutable_collections, Interpreter -from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.nn.utils import stateless - -from functorch import make_fx -from functorch.experimental import functionalize -from torch._dispatch.python import enable_python_dispatcher -from . import config -from .named_members_polyfill import _named_buffers, _named_parameters -from .partitioners import default_partition - -try: - from torchdynamo import disable as disable_torchdynamo -except ImportError: - - def disable_torchdynamo(x): - return x - - -try: - from torchdynamo.utils import dynamo_timed -except ImportError: - - def dynamo_timed(x): - return x - - -pytree._register_pytree_node( - immutable_collections.immutable_list, - lambda x: (list(x), None), - lambda x, c: immutable_collections.immutable_list(x), -) -pytree._register_pytree_node( - immutable_collections.immutable_dict, - lambda x: (list(x.values()), list(x.keys())), - lambda x, c: immutable_collections.immutable_dict( - {key: value for key, value in zip(c, x)} - ), -) - -aten = torch.ops.aten - - -@contextmanager -def preserve_rng_state(): - rng_state = torch.clone(torch.random.get_rng_state()) - if torch.cuda.is_available(): - cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) - try: - yield - finally: - torch.random.set_rng_state(rng_state) - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) - - -# Set up hooks so that during backward the fx's stack_trace is properly set -callback_set = False - - -def setup_stacktrace_preservation_hooks(roots: List): - def iter_graph(roots): - if not roots: - return - seen = set() - q = collections.deque() - for node in roots: - if node is not None: - seen.add(node) - q.append(node) - - while q: - node = q.popleft() - for fn, _idx in node.next_functions: - if fn in seen or fn is None: - continue - seen.add(fn) - q.append(fn) - - yield node - - def get_callback(saved_stack_): - def callback(): - global callback_set - fx_traceback.set_stack_trace(saved_stack_) - callback_set = False - - return callback - - def get_prehook(stack_): - def prehook(grad_output): - global callback_set - - if not callback_set: - torch.autograd.variable.Variable._execution_engine.queue_callback( - get_callback(fx_traceback.format_stack()) - ) - callback_set = True - - fx_traceback.set_stack_trace(stack_) - - return prehook - - def get_posthook(special_stack_): - def posthook(grad_input, grad_output): - fx_traceback.set_stack_trace(special_stack_) - - return posthook - - for node in iter_graph(roots): - forward_node_stack = node.metadata.get("traceback_", []) - node.register_prehook(get_prehook(forward_node_stack)) - - special_stack = forward_node_stack.copy() - special_stack.append( - "Gradient addition node due to multiple use of tensor around:" - ) - node.register_hook(get_posthook(special_stack)) - - -def create_joint_forward_backward(fn): - def joint_forward_backward( - primals: List[Any], tangents: List[Any] - ) -> Tuple[List[Any], List[Any]]: - # Call the forward pass - outs = fn(*primals) - # Get the inputs that need gradients - grad_primals = [] - inputs_needs_grads = [] - for p in primals: - is_grad_tensor = isinstance(p, Tensor) and p.requires_grad - inputs_needs_grads.append(is_grad_tensor) - if is_grad_tensor: - grad_primals.append(p) - - # Get the outputs that need gradients - assert len(tangents) == len(outs) - needed_outs = [] - needed_tangents = [] - for out, tangent in zip(outs, tangents): - if isinstance(out, Tensor) and out.requires_grad: - needed_outs.append(out) - needed_tangents.append(tangent) - - setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) - - backward_out = [] - # Call the backwards pass - if grad_primals: - with fx_traceback.override_stack_trace(): - backward_out = torch.autograd.grad( - needed_outs, - grad_primals, - grad_outputs=needed_tangents, - allow_unused=True, - ) - backward_out_iter = iter(backward_out) - return outs, [ - next(backward_out_iter) if i else None for i in inputs_needs_grads - ] - - return joint_forward_backward - - -def normalize_as_list(x): - if isinstance(x, tuple): - return list(x) - elif isinstance(x, list): - return x - return [x] - - -aot_autograd_decompositions = {} - - -# This is a list since looking forward, we can have this arbitrarily nested. -graph_being_compiled: List[str] = [] -nth_graph: int = 0 -model_name: str = "model" - - -def set_model_name(name): - global model_name - model_name = name - - -def get_aot_compilation_context() -> Tuple[List[str], str, int]: - return list(graph_being_compiled), model_name, nth_graph - - -def get_aot_graph_name() -> str: - """ - Returns the name of the graph being compiled. - """ - global model_name, graph_being_compiled, nth_graph - return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}" - - -get_graph_being_compiled = get_aot_graph_name - - -@contextmanager -def track_graph_compiling(graph_name, increment_index=False): - global graph_being_compiled - graph_being_compiled = [graph_name] - yield - if increment_index: - global nth_graph - nth_graph += 1 - graph_being_compiled = [] - - -def make_boxed_func(f): - def g(args): - return f(*args) - - g._boxed_call = True - return g - - -def make_boxed_compiler(compiler): - @wraps(compiler) - def f(fx_g, inps): - out_f = compiler(fx_g, inps) - fx_g = make_boxed_func(out_f) - return fx_g - - return f - - -def call_func_with_args(f, args, steal_args=False, disable_amp=False): - if not steal_args: - args = list(args) - assert isinstance(args, list) - - if disable_amp: - guard = torch._C._DisableAutocast() - try: - if hasattr(f, "_boxed_call"): - out = normalize_as_list(f(args)) - else: - # TODO: Please remove soon - # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 - warnings.warn( - "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. " - "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " - "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale." - ) - out = normalize_as_list(f(*args)) - finally: - if disable_amp: - del guard - return out - - -@dataclasses.dataclass -class AOTConfig: - """ - Configuration for AOTDispatcher - """ - - fw_compiler: Callable - bw_compiler: Callable - partition_fn: Callable - decompositions: Dict[Callable, Callable] - num_params_buffers: int - - -def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig): - fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args) - if config.debug_graphs: - print("====== Forward (only) graph ======") - fw_module.print_readable() - - - disable_amp = torch._C._is_any_autocast_enabled() - context = disable_autocast_manager if disable_amp else nullcontext - - with context(), track_graph_compiling("inference"): - compiled_fw = aot_config.fw_compiler(fw_module, flat_args) - - @wraps(compiled_fw) - def new_fn(args): - fw_outs = call_func_with_args(compiled_fw, args, disable_amp=disable_amp) - return fw_outs - - return new_fn - - -@contextmanager -def disable_autocast_manager(): - guard = torch._C._DisableAutocast() - try: - yield - finally: - del guard - - -def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig): - # Deduplicate inputs. Suppose you have: - # - # [a, b, a, c] - # - # We want: - # - # remove_dupe_args([a, b, a, c]) == [a, b, c] - # add_dupe_args([a, b, c]) == [a, b, a, c] - # - # This is done via (respectively): - # - # seen_args = {2} # what to drop - # add_dupe_map = { # how to get args from the deduped list - # 0: 0, - # 1: 1, - # 2: 0, - # 3: 2, - # } - # - # Whether to use flat_args or deduped_flat_args? flat_fn takes flat_args, - # and the autograd.Function must take deduped_flat_args; everything - # else is just getting the types right. - - seen_args = {} - keep_arg_mask = [] - dropped_args = False - add_dupe_map = {} - duped_arg_len = len(flat_args) - - j = 0 # index into deduped_flat_args - for i, t in enumerate(flat_args): - if t in seen_args: - keep_arg_mask.append(False) - dropped_args = True - add_dupe_map[i] = seen_args[t] - continue - keep_arg_mask.append(True) - seen_args[t] = j - add_dupe_map[i] = j - j += 1 - - # NB: Hot path, avoid set lookups here - def remove_dupe_args(args): - if not dropped_args: - return args - return [t for t, keep in zip(args, keep_arg_mask) if keep] - - def add_dupe_args(args): - if not dropped_args: - return args - return [args[add_dupe_map[i]] for i in range(duped_arg_len)] - - deduped_flat_args = remove_dupe_args(flat_args) - - joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args))) - - out = flat_fn(*flat_args) - # Collect info on which output tensors require gradients, - # so we can mark them properly in the returned autograd.Function - _flat_outs_not_requiring_grad, _ = pytree.tree_flatten( - pytree.tree_map( - lambda x: isinstance(x, Tensor) and not x.requires_grad, out - ) - ) - out = pytree.tree_map( - lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, - out, - ) - - if isinstance(out, (list, tuple)): - _num_outs = len(out) - else: - _num_outs = 1 - - joint_inputs = (deduped_flat_args, out) - - disable_amp = torch._C._is_any_autocast_enabled() - - if config.use_functionalize: - # Trace once without decompositions, into a graph of ATen ops. - # NB: tracing_mode is real, as it's assumed the calling context setup - # fake tensor mode / symbolic shapes if that is needed - fx_g = make_fx(joint_forward_backward)(*joint_inputs) - - context = disable_autocast_manager if disable_amp else nullcontext - - def fake_fn(primals, tangents): - with torch.fx.traceback.override_stack_trace(): - return torch.fx.Interpreter(fx_g).run(primals, tangents) - - # Trace a second time, running functionalization, and THEN running decompositions. - # functionalization only acts on ATen today, and doesn't currently handle - # view and inplace ops that come from primtorch. - # Eventually, functionalization should support primtorch view/inplace ops, - # which will make it ok to run decompositions before functionalization. - with context(): - fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs) - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - else: - fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs) - - if config.debug_joint: - print("====== Joint graph ======") - fx_g.print_readable() - - with torch.no_grad(): - with track_graph_compiling("joint"): - fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs) - fw_outs = [n for n in fw_module.graph.nodes if n.op == "output"][0].args[0] - # we only need to bookkeep the symints that are saved for bw, not any symints - # the user forward might have returned in its own output - fw_outs = fw_outs[_num_outs:] - symint_outs = [n for n in fw_outs if is_sym_node(n)] - _num_symints = len(symint_outs) - - if config.debug_graphs: - print("====== Forward graph ======") - fw_module.print_readable() - print("====== Backward graph ======") - bw_module.print_readable() - - with track_graph_compiling("forward"): - compiled_fw_func = aot_config.fw_compiler(fw_module, deduped_flat_args) - - class CompiledFunction(torch.autograd.Function): - compiled_fw = compiled_fw_func - compiled_bw = None - num_outs = _num_outs - num_symints = _num_symints - flat_outs_not_requiring_grad = _flat_outs_not_requiring_grad - - @staticmethod - @disable_torchdynamo - def forward(ctx, *deduped_flat_tensor_args): - fw_outs = call_func_with_args( - CompiledFunction.compiled_fw, deduped_flat_tensor_args, disable_amp=disable_amp - ) - num_outs = CompiledFunction.num_outs - num_symints = CompiledFunction.num_symints - # Partitioners must put symint arguments at the end separate from tensor arguments - if num_symints > 0: - ctx.save_for_backward(*fw_outs[num_outs:-num_symints]) - ctx.symints = fw_outs[-num_symints:] - else: - ctx.save_for_backward(*fw_outs[num_outs:]) - ctx.symints = [] - - fw_outs_not_requiring_grad = [ - x for (i, x) in enumerate(fw_outs[0:num_outs]) if CompiledFunction.flat_outs_not_requiring_grad[i] - ] - ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) - - return tuple(fw_outs[0:num_outs]) - - @staticmethod - @disable_torchdynamo - def backward(ctx, *flat_args): - contiguous_args = [t.contiguous() if torch.is_tensor(t) else t for t in flat_args] - all_args = list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args) - if CompiledFunction.compiled_bw is None: - context = disable_autocast_manager if disable_amp else nullcontext - with context(), track_graph_compiling("backward", True): - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, all_args - ) - ctx.maybe_clear_saved_tensors() - out = call_func_with_args( - CompiledFunction.compiled_bw, all_args, steal_args=True, disable_amp=disable_amp - ) - return tuple(out) - - @wraps(CompiledFunction.apply) - def compiled_function(*args): - return CompiledFunction.apply(*remove_dupe_args(args)) - - return compiled_function - - -@dynamo_timed -def create_aot_dispatcher_function( - flat_fn, flat_args: List[Tensor], aot_config: AOTConfig -): - """ - Traces the forward and backward graphs of the attr:`flat_fn` to generate a - joint graph. The joint graph is an Fx graph with Aten ops. Please refer to - the tracing mechanism to understand the graph capturing details. - - The joint graph is then passed through attr:`partition_fn` to isolate the - forward and backward portions, which are then respectively compiled via the - provided attr:`fw_compiler` and attr:`bw_compiler`. - - The resulting compiled forward and backward graphs are then wrapped up in a - ``torch.autograd.Function`` object. - - The calling convention here is that the first aot_config.num_params_buffers - inputs in flat_args are parameters and buffers, and the rest are inputs. - - We use this to assume that parameters/buffer's shapes don't change. - """ - - # This is the main entry point. - # TODO: Chillee argues that dynamo itself should pass in fake tensors to - # the list of arguments when compiling; at the moment we do not do this - - if aot_config.decompositions is None: - aot_config.decompositions = {} - - aot_config.decompositions = { - **aot_autograd_decompositions, - **aot_config.decompositions, - } - # NB: don't bother setting allow_fallback_kernels; this should not actually - # be configurable in fake tensor, we should automatically do the right - # thing - if config.debug_fake_cross_ref: - # This is a little messy but TorchDynamo directly changes `use_fake_tensor` - # so it's not enough for user to change the config manually - # TODO: have TorchDynamo read in `use_fake_tensor` from os environ / - # coordinate flags - config.use_fake_tensor = False - - if config.use_dynamic_shapes: - assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor" - - shape_env = ShapeEnv() if config.use_dynamic_shapes else None - fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext() - cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext() - python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext() - - with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode: - - def process_inputs(flat_args): - if config.use_fake_tensor: - def convert(idx, x): - if not isinstance(x, torch.Tensor): - return x - if idx < aot_config.num_params_buffers and config.static_weight_shapes: - return fake_mode.from_tensor(x, static_shapes=True) - return fake_mode.from_tensor(x, static_shapes=False) - - return [convert(idx, x) for idx, x in enumerate(flat_args)] - else: - return flat_args - - fake_flat_tensor_args = process_inputs(flat_args) - - needs_autograd = ( - any( - [ - x.requires_grad - for x in fake_flat_tensor_args - if isinstance(x, Tensor) - ] - ) - and torch.is_grad_enabled() - ) - # crappy version of dispatcher - # TODO: Do this properly - if needs_autograd: - return make_boxed_func( - aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config) - ) - else: - return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config) - - -# Inspired by autodidax (thanks!) -class PytreeThunk: - spec = None - # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. - is_simple = ( - None # if the output spec is a tuple/list, we won't bother unflattening it. - ) - is_really_simple = None # if the output spec is a LeafSpec - - def set(self, spec): - assert self.spec is None or self.spec == spec - self.spec = spec - if type(self.spec) in [tuple, list] and all( - isinstance(i, pytree.LeafSpec) for i in spec.children_specs - ): - self.is_simple = True - if isinstance(self.spec, pytree.LeafSpec): - self.is_really_simple = True - - def unflatten(self, x): - if self.is_really_simple: - return x[0] - if self.is_simple: - return x - return pytree.tree_unflatten(x, self.spec) - -KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymIntNode, torch.SymFloatNode] - - -def aot_function( - fn: Callable, - fw_compiler: Callable, - bw_compiler: Optional[Callable] = None, - partition_fn: Callable = default_partition, - decompositions: Optional[Dict] = None, - num_params_buffers: int = 0, - hasher_type=None, # deprecated - static_argnums: Optional[Tuple[int]] = None, # deprecated -) -> Callable: - """ - Traces the forward and backward graph of :attr:`fn` using torch dispatch - mechanism, and then compiles the generated forward and backward graphs - through :attr:`fw_compiler` and :attr:`bw_compiler`. - - :func:`aot_function` traces the forward and backward graph ahead of time, - and generates a joint forward and backward graph. :attr:`partition_fn` is - then used to separate out forward and backward graphs. The partitioner - function can be used to perform optimizations such as recomputation. One can - set `decompositions` dictionary to decompose the operators into a sequence - of core or simpler operators supported by the backend compilers. - - :func:`aot_function` uses a compilation cache, based on input tensor - properties, to detect when there is a need of recompilation. - - .. warning:: - This API is experimental and likely to change. - - Args: - fn (Callable): A Python function that takes one ore more arguments. Must - return one or more Tensors. - fw_compiler (Callable): A Python function that accepts an Fx graph with - Aten ops and input args, and returns a Callable that semantically is - equivalent to the input Fx graph. - bw_compiler (Optional[Callable]): A Python function that accepts an - Fx graph with Aten ops and input args, and returns a Callable that - semantically is equivalent to the input Fx graph. Default: None - (when None, it defaults to the :attr:`fw_compiler`) - partition_fn (Callable): A Python function that takes a joint forward - and backward graph, and partitions it into separate forward and - backward graphs. - decompositions (Dict): A dictionary to define the decomposition of - larger Aten ops into simpler or core Aten ops. - - Returns: - Returns a ``Callable`` that retains the eager behavior of the original - :attr:`fn`, but with forward and backward graph compiled via - :attr:`fw_compile` and :attr:`bw_compile`. - - A simple example usage of :func:`aot_function` is as follows. This example - will print the forward and backward graphs of the function ``fn`` - - >>> fn = lambda x : x.sin().cos() - >>> def print_compile_fn(fx_module, args): - >>> print(fx_module) - >>> return fx_module - >>> aot_fn = aot_function(fn, print_compile_fn) - >>> x = torch.randn(4, 5, requires_grad=True) - >>> aot_fn(x) - """ - if static_argnums is not None: - raise RuntimeError("static_argnums has been deprecated - manually wrap your function or use torchdynamo.") - - if bw_compiler is None: - bw_compiler = fw_compiler - aot_config = AOTConfig( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - decompositions=decompositions, - num_params_buffers=num_params_buffers, - ) - cached_res = None - - @wraps(fn) - def returned_function(*args, **kwargs): - nonlocal cached_res - # Now flatten the tensor args - flat_args, _ = pytree.tree_flatten((args, kwargs)) - - # Compile the function and save it in the cache - if cached_res is None: - # Save the args_spec for flat_tensor_args to unflatten while tracing - _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) - out_spec = PytreeThunk() - - def flat_fn(*flat_args): - # The input are flattened tensor args. Prepare the args in the - # order that original function expects. Add static args as well. - # They will appear as tensor constants in the traced graph. - nonlocal out_spec - args, kwargs = pytree.tree_unflatten( - flat_args, tensor_args_spec - ) - tree_out = fn(*args, **kwargs) - flat_out, spec = pytree.tree_flatten(tree_out) - for i in flat_out: - is_known_type = False - for j in KNOWN_TYPES: - if isinstance(i, j): - is_known_type = True - break - if not is_known_type: - raise RuntimeError( - f"Found {type(i)} in output, which is not a known type. " - "If this type holds tensors, you need to register a pytree for it. " - "See https://github.com/pytorch/functorch/issues/475 for a brief " - "explanation why. If you don't need to register a pytree, please " - "leave a comment explaining your use case and we'll make this more " - "ergonomic to deal with" - ) - out_spec.set(spec) - return flat_out - - compiled_fn = create_aot_dispatcher_function( - flat_fn, - flat_args, - aot_config, - ) - cached_res = (compiled_fn, out_spec) - - cached_fn, out_spec = cached_res - out = cached_fn(flat_args) - return out_spec.unflatten(out) - - return returned_function - - -def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: - """ - Traces the forward and backward graph of :attr:`mod` using torch dispatch - tracing mechanism. It is wrapper function, that underneath uses - :func:`aot_function` to perform tracing and compilation. - - :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs - to a new callable which is then compiled through :func:`aot_function`. - - .. warning:: - This API is experimental and likely to change. - - Args: - mod (Callable): A ``nn.Module`` module. - args : args to be passed to :func:`aot_function` - kwargs : kwargs to be passed to :func:`aot_function` - - Returns: - Returns a ``nn.Module`` that retains the eager behavior of the original - :attr:`mod`, but with forward and backward graph compiled. - - """ - - def functional_call(named_params, named_buffers, *args, **kwargs): - params_and_buffers = {**named_params, **named_buffers} - return stateless.functional_call(mod, params_and_buffers, args, kwargs) - - named_params = dict(_named_parameters(mod, remove_duplicate=False)) - named_buffers = dict(_named_buffers(mod, remove_duplicate=False)) - num_params_buffers = len(named_params) + len(named_buffers) - compiled_f = aot_function(functional_call, num_params_buffers=num_params_buffers, *args, **kwargs) - - class AOTModule(nn.Module): - def __init__(self): - super(AOTModule, self).__init__() - self.orig_module = mod - - def forward(self, *args, **kwargs): - return compiled_f( - named_params, - named_buffers, - *args, - **kwargs, - ) - - return AOTModule() - - -def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module: - """ - This is the simplified or low overhead version of aot_module. For frontends - like TorchDynamo, the input functions/modules to AOT are static and have - unpacked inputs/outputs. This gives us an opportunity to remove the - (1) pytree overhead to parse inputs/outputs, - (2) AOT Autograd cache, - (3) Reading of params/buffers in every forward call - - :func:`aot_module_simplified` removes these overheads. - """ - ######################################################### - - params = { - **dict(_named_parameters(mod, remove_duplicate=False)), - **dict(_named_buffers(mod, remove_duplicate=False)), - } - params_flat, params_spec = pytree.tree_flatten(params) - params_flat = tuple(params_flat) - params_len = len(params_flat) - - def functional_call(*args, **kwargs): - with stateless._reparametrize_module( - mod, pytree.tree_unflatten(args[:params_len], params_spec) - ): - if isinstance(mod, torch.fx.GraphModule): - with fx_traceback.override_stack_trace(), warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "Anomaly Detection has been enabled." - ) - with torch.autograd.detect_anomaly(check_nan=False): - out = Interpreter(mod).run(*args[params_len:], **kwargs) - else: - out = mod(*args[params_len:], **kwargs) - - if not isinstance(out, (tuple, list)): - raise RuntimeError( - "Graph output must be a tuple(). This is so that we can avoid " - "pytree processing of the ouputs. Please change the module to " - "have tuple outputs or use aot_module instead." - ) - return out - - def aot_function_simplified( - fn: Callable, - fw_compiler: Callable, - bw_compiler: Optional[Callable] = None, - partition_fn: Callable = default_partition, - decompositions: Optional[Dict] = None, - hasher_type=None, - static_argnums=None, - ) -> Callable: - assert static_argnums is None - if bw_compiler is None: - bw_compiler = fw_compiler - aot_config = AOTConfig( - fw_compiler=fw_compiler, - bw_compiler=bw_compiler, - partition_fn=partition_fn, - decompositions=decompositions, - num_params_buffers=params_len, - ) - - compiled_fn = None - - @wraps(fn) - def new_func(*args): - nonlocal compiled_fn - if compiled_fn is None: - compiled_fn = create_aot_dispatcher_function( - fn, - args, - aot_config, - ) - return compiled_fn(args) - - return new_func - - compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs) - - if top_kwargs: - - def forward(*args, **kwargs): - return compiled_f( - *params_flat, - *args, - **kwargs, - ) - - else: - - def forward(*args): - return compiled_f( - *params_flat, - *args, - ) - - forward.zero_grad = mod.zero_grad - forward.named_parameters = mod.named_parameters - return forward - - -compiled_function = aot_function -compiled_module = aot_module diff --git a/functorch/_src/aot_autograd/__init__.py b/functorch/_src/aot_autograd/__init__.py new file mode 100644 index 0000000000000..94f258df84ba8 --- /dev/null +++ b/functorch/_src/aot_autograd/__init__.py @@ -0,0 +1,8 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.aot_autograd import ( + aot_autograd_decompositions, + KNOWN_TYPES, + PytreeThunk, +) diff --git a/functorch/_src/eager_transforms/__init__.py b/functorch/_src/eager_transforms/__init__.py new file mode 100644 index 0000000000000..e3e587c0978fa --- /dev/null +++ b/functorch/_src/eager_transforms/__init__.py @@ -0,0 +1,7 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.eager_transforms import ( + _unwrap_functional_tensor, + _assert_wrapped_functional, +) diff --git a/functorch/_src/make_functional/__init__.py b/functorch/_src/make_functional/__init__.py new file mode 100644 index 0000000000000..3de7787df0c33 --- /dev/null +++ b/functorch/_src/make_functional/__init__.py @@ -0,0 +1,4 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.make_functional import _swap_state diff --git a/functorch/_src/vmap/__init__.py b/functorch/_src/vmap/__init__.py new file mode 100644 index 0000000000000..792a2fde38bb3 --- /dev/null +++ b/functorch/_src/vmap/__init__.py @@ -0,0 +1,16 @@ +# This file has moved to under torch/_functorch. It is not public API. +# If you are not a PyTorch developer and you are relying on the following +# imports, please file an issue. +from torch._functorch.vmap import ( + _add_batch_dim, + _broadcast_to_and_flatten, + _get_name, + _remove_batch_dim, + _validate_and_get_batch_size, + Tensor, + tree_flatten, + tree_unflatten, + _process_batched_inputs, + _create_batched_inputs, + _unwrap_batched, +) diff --git a/functorch/benchmarks/chrome_trace_parser.py b/functorch/benchmarks/chrome_trace_parser.py index 54d2bf1447fb1..ccc8b89544bc3 100755 --- a/functorch/benchmarks/chrome_trace_parser.py +++ b/functorch/benchmarks/chrome_trace_parser.py @@ -5,7 +5,7 @@ import logging import pandas as pd -from functorch._src.benchmark_utils import compute_utilization +from torch._functorch.benchmark_utils import compute_utilization # process the chrome traces output by the pytorch profiler # require the json input file's name to be in format {model_name}_chrome_trace_*.json diff --git a/functorch/benchmarks/cse.py b/functorch/benchmarks/cse.py index 028677d6ee259..14cde14eb3085 100644 --- a/functorch/benchmarks/cse.py +++ b/functorch/benchmarks/cse.py @@ -3,7 +3,7 @@ from functorch import make_fx from torch.profiler import profile, ProfilerActivity -from functorch._src.compile_utils import fx_graph_cse +from torch._functorch.compile_utils import fx_graph_cse def profile_it(f, inp): for _ in range(5): diff --git a/functorch/benchmarks/operator_authoring.py b/functorch/benchmarks/operator_authoring.py index 88e558bdafc1a..cbd816e2ad132 100644 --- a/functorch/benchmarks/operator_authoring.py +++ b/functorch/benchmarks/operator_authoring.py @@ -77,7 +77,7 @@ def setup(n): assert result_nnc.dtype == result_aten.dtype assert result_nnc.size() == result_aten.size() assert result_nnc.stride() == result_aten.stride() - torch.testing.assert_allclose(result_aten, result_nnc) + torch.testing.assert_close(result_aten, result_nnc) return (lambda: nnc(*args), lambda: aten(*args)) return benchmark_loop(setup) @@ -90,7 +90,7 @@ def inplace_setup(n): result_nnc = torch.clone(a) nnc(result_nnc, b, out=result_nnc) aten(result_aten, b, out=result_aten) - torch.testing.assert_allclose(result_aten, result_nnc) + torch.testing.assert_close(result_aten, result_nnc) return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a)) return benchmark_loop(inplace_setup) @@ -103,7 +103,7 @@ def out_setup(n): result_nnc = out(n) aten(*args, out=result_aten) nnc(*args, out=result_nnc) - torch.testing.assert_allclose(result_aten, result_nnc) + torch.testing.assert_close(result_aten, result_nnc) result = out(n) return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result)) @@ -118,7 +118,7 @@ def backwards_setup(n): correct = grad_var.grad.clone() grad_var.grad.zero_() nnc(*args).sum().backward() - torch.testing.assert_allclose(correct, grad_var.grad) + torch.testing.assert_close(correct, grad_var.grad) return ( lambda: nnc(*args).sum().backward(), lambda: aten(*args).sum().backward(), diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py index ac4cf5f386dcf..15863dc3510cf 100644 --- a/functorch/benchmarks/pointwise_scorecard.py +++ b/functorch/benchmarks/pointwise_scorecard.py @@ -195,13 +195,13 @@ def micros(s): if shape == medium_transpose: raise RuntimeError("pointwise_operator hangs on medium_transpose") pw_op = pointwise_operator(operator) - torch.testing.assert_allclose(operator(*args), pw_op(*args)) + torch.testing.assert_close(operator(*args), pw_op(*args)) except Exception: print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}") nope.add((operator, shape)) ts_op = torch.jit.script(operator) - torch.testing.assert_allclose(operator(*args), ts_op(*args)) + torch.testing.assert_close(operator(*args), ts_op(*args)) print("fuser,device,operator,shape,time") diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py index 12549dceda9fb..569c1b6819bdd 100644 --- a/functorch/compile/__init__.py +++ b/functorch/compile/__init__.py @@ -1,6 +1,6 @@ -from .._src.python_key import pythonkey_decompose -from .._src.fx_minifier import minifier -from .._src.aot_autograd import ( +from torch._functorch.python_key import pythonkey_decompose +from torch._functorch.fx_minifier import minifier +from torch._functorch.aot_autograd import ( aot_function, aot_module, compiled_function, @@ -12,7 +12,7 @@ make_boxed_func, make_boxed_compiler ) -from .._src.compilers import ( +from torch._functorch.compilers import ( ts_compile, draw_graph_compile, nop, @@ -22,10 +22,10 @@ print_compile, default_decompositions ) -from .._src.partitioners import ( +from torch._functorch.partitioners import ( min_cut_rematerialization_partition, default_partition, draw_graph, draw_joint_graph, ) -from .._src import config +from torch._functorch import config diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index abdebc24e0112..6fc0038bfc958 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -1158,7 +1158,7 @@ struct EnableAllLayers { } } private: - int64_t levels_start_; + int64_t levels_start_{}; Slice> levels_to_dim_; }; @@ -1767,6 +1767,9 @@ static PyObject* order(PyObject *_, PyObject *kwnames) { Arena A; PY_BEGIN + if (kwnames) { + py::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } AT_ASSERT(nargs-- > 0); Slice orig_levels; Slice levels; @@ -2684,7 +2687,7 @@ static PyObject* py_stack(PyObject *_, auto d = _wrap_dim(dim, ndim, false); auto idx = result_levels.index(d); if (!idx) { - py::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim); + py::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); } rawdim = *idx; } diff --git a/functorch/dim/README.md b/functorch/dim/README.md index 750c8847c8502..5ed7bbd3d5284 100644 --- a/functorch/dim/README.md +++ b/functorch/dim/README.md @@ -7,7 +7,7 @@ _An implementation of [named tensors](https://namedtensor.github.io) with the fu The tensor input to a resnet might have the shape [8, 3, 224, 224] but informally we think of those dimensions as 'batch', 'channel', 'width', and 'height'. Eventhough 'width' and 'height' have the same _size_ we still think of them as separate dimensions, and if we have two _different_ images, we think of both as sharing the _same_ 'channel' dimension. -Named tensors gives these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearragement, and loop-style tensor indexing. +Named tensors gives these dimensions names. [PyTorch's current implementation](https://pytorch.org/docs/stable/named_tensor.html) uses strings to name dimensions. Instead, this library introduces a Python object, a `Dim`, to represent the concept. By expanding the semantics of tensors with dim objects, in addition to naming dimensions, we can get behavior equivalent to batching transforms (xmap, vmap), einops-style rearrangement, and loop-style tensor indexing. A preview: @@ -85,11 +85,11 @@ from torchdim import dims batch, channel, width, height = dims(4) ``` -The existing implemention of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch, or [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) use strings to name dimensions. We call these dimensions _first class_ because they are Python objects. +The existing implementation of [Named Tensors](https://pytorch.org/docs/stable/named_tensor.html) in PyTorch, or [JAX's xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) use strings to name dimensions. We call these dimensions _first class_ because they are Python objects. In addition to the normal _positional_ dimensions in a tensor, tensors can also have a separate set of first-class dimensions. -You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimesions, while the new `dims` property lists all the bound first-class dimensions. +You can create tensors with first-class dimensions by indexing the normal positional dimensions of a tensor with a dimension object. The `ndim` property continues to list the number of positional dimensions, while the new `dims` property lists all the bound first-class dimensions. ```py input = torch.rand(2, 3, 224, 224) @@ -101,7 +101,7 @@ print(input_fc.dims) # first class dimensions > (batch, channel, width, height) -# since we converted all the positional dimesions +# since we converted all the positional dimensions # first class `input_fc` has 0 positional dimensions now. print(input_fc.ndim) > 0 @@ -266,7 +266,7 @@ print(i <= j) > with dims=(i, j) sizes=(4, 4) ``` -Because of the intentional similarity to loop-level code, using dimsions as tensors makes complicated indexing arithmetic easier to read. +Because of the intentional similarity to loop-level code, using dimensions as tensors makes complicated indexing arithmetic easier to read. Here is code that lookups up features in an embedding table given a sequence of ids: @@ -296,7 +296,7 @@ Unbinding Dims ------------- The `order` method converts first-class dimensions in a tensor back to normal positional dimensions by specifying an order for those dimensions.[^4] -By specifiying a different order from how things were originally bound, it is easy to do transpositions. +By specifying a different order from how things were originally bound, it is easy to do transpositions. ```py i, j = dims(2) @@ -305,7 +305,7 @@ A_T = A[i, j].order(j, i) assert torch.allclose(A.T, A_T) ``` -Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it possible to work on tensors that have mixed positonal and first-class dimensions: +Indexing acts left-to-right, and `order` also places the new dimensions back on the left, so it possible to work on tensors that have mixed positional and first-class dimensions: ```py B = torch.rand(3, 4, 5) @@ -313,7 +313,7 @@ B_T = B[i, j].order(j, i) assert torch.allclose(B.permute(1, 0, 2), B_T) ``` -[^4] `order` is actually just a synonym for the already-existing `permute` method, which takes a list a dimension specifiers and puts the tensor in that order because rule #2 says that first-class dims can be passed as arguments to functions that previousely took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code. +[^4] `order` is actually just a synonym for the already-existing `permute` method, which takes a list a dimension specifiers and puts the tensor in that order because rule #2 says that first-class dims can be passed as arguments to functions that previously took only integers as dimensions. However, the name `permute` is confusing in this context since it implies dim objects have an original order, so we prefer to use `order` when writing code. Flattening and Splitting Dims ----------------------------- @@ -412,7 +412,7 @@ Named tensors with first-class dimensions can accomplish the same goal, but usin Automatically batching Code (`vmap`, `xmap`) ----------------------------- -The implicit batching of Rule #1 means it is easy to created batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicictly batched over: +The implicit batching of Rule #1 means it is easy to created batched versions of existing PyTorch code. Simply bind a dim to the dimensions that should act as a batch, and then pass the tensor to the unbatched function. Since the unbatched function does not know about the dim, the dim will be implicitly batched over: ```py batch_size, feature_size = 3, 5 @@ -501,7 +501,7 @@ def multiheadattention(q, k, v, num_attention_heads, dropout_prob, use_positiona Indexing -------- -Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior. The simplest might be using the dimensions to compute masks, such as extracing the upper triangular part of a matrix: +Rule #3 enables indexing because dimensions act as loop indices when used as a tensor. This allows for a lot of powerful behavior. The simplest might be using the dimensions to compute masks, such as extracting the upper triangular part of a matrix: ```py from torch import where @@ -745,7 +745,7 @@ The semantics and surface syntax of dimension objects resembles the kind of code These compilers and language have syntax and semantics that resemble the loop-level analogy similar to first-class dimensions. However, as compilers or statically typed languages, they require some binding code to go from running deep learning framework code in Python to using the compiled language. This often at least requires refactoring the compiled parts into their own functions, and may require defining a gradient function. Similar to graph mode frameworks, this adds friction to using and debugging the code. -Dimension objects are just an extension of the existing PyTorch tensors and eager sematics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensors containing dimensions can compute through code that is oblivous to the dimension such as batching examples. There is no need to separate code into 'compiled' vs 'eager'. +Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensors containing dimensions can compute through code that is oblivious to the dimension such as batching examples. There is no need to separate code into 'compiled' vs 'eager'. In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries. diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index 4f1cd84e44a18..6d36a8994dfe9 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -102,9 +102,9 @@ def _def(name, *args, **kwargs): del _Tensor.ndim if use_c: - _Tensor.permute = _Tensor.order = _C._instancemethod(_C.order) + _Tensor.order = _C._instancemethod(_C.order) else: - _Tensor.permute = _Tensor.order = reference.positional + _Tensor.order = reference.positional _def('mean') _def('sum') diff --git a/functorch/docs/source/batch_norm.rst b/functorch/docs/source/batch_norm.rst index 09eb6001b5b66..8ccd4ee587d35 100644 --- a/functorch/docs/source/batch_norm.rst +++ b/functorch/docs/source/batch_norm.rst @@ -11,7 +11,7 @@ we end up with this error How to fix ---------- All of these options assume that you don't need running stats. If you're using a module this means -that it's assumed you won't use batch norm in evalution mode. If you have a use case that involves +that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves running batch norm with vmap in evaluation mode, please file an issue Option 1: Change the BatchNorm diff --git a/functorch/docs/source/ux_limitations.rst b/functorch/docs/source/ux_limitations.rst index e0090047752e0..4fee30e432881 100644 --- a/functorch/docs/source/ux_limitations.rst +++ b/functorch/docs/source/ux_limitations.rst @@ -290,5 +290,5 @@ Under "same" randomness, elements in a batch produce same random values. For ins .. note:: Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the - most common forms of randmoness that we see. If your use case does not fit these forms of randomness, please + most common forms of randomness that we see. If your use case does not fit these forms of randomness, please file an issue. diff --git a/functorch/examples/compilation/fuse_module.py b/functorch/examples/compilation/fuse_module.py index dafbc80711a3a..ec091eb24435a 100644 --- a/functorch/examples/compilation/fuse_module.py +++ b/functorch/examples/compilation/fuse_module.py @@ -36,7 +36,7 @@ def forward(self, x): compiled_mod = compiled_module(mod, fw_compiler, bw_compiler) for a, b in zip(run(mod, input), run(compiled_mod, input)): - torch.testing.assert_allclose(a, b) + torch.testing.assert_close(a, b) out = mod(input) out.sum().backward() @@ -45,7 +45,7 @@ def forward(self, x): compiled_mod.orig_module.param.grad = None for a, b in zip(run(mod, input), run(compiled_mod, input)): - torch.testing.assert_allclose(a, b) + torch.testing.assert_close(a, b) for _ in range(5): i = 10000 diff --git a/functorch/examples/maml_omniglot/README.md b/functorch/examples/maml_omniglot/README.md index dfb6077814bfe..afc3f55023d47 100644 --- a/functorch/examples/maml_omniglot/README.md +++ b/functorch/examples/maml_omniglot/README.md @@ -1,6 +1,6 @@ # Omniglot MAML examples -In this directory we've provided some examples of traning omniglot that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400). +In this directory we've provided some examples of training omniglot that reproduce the experiments from [the original MAML paper](https://arxiv.org/abs/1703.03400). They can be run via `python {filename}`. diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py index ea874acafc425..dde503f93bb62 100644 --- a/functorch/experimental/__init__.py +++ b/functorch/experimental/__init__.py @@ -1,5 +1,5 @@ -from .batch_norm_replacement import replace_all_batch_norm_modules_ # PyTorch forward-mode is not mature yet -from .._src.eager_transforms import jvp, jacfwd, hessian -from .._src.vmap import chunk_vmap +from torch._functorch.eager_transforms import hessian, jacfwd, jvp +from torch._functorch.vmap import chunk_vmap +from .batch_norm_replacement import replace_all_batch_norm_modules_ from functorch import functionalize diff --git a/functorch/experimental/cond.py b/functorch/experimental/_cond.py similarity index 73% rename from functorch/experimental/cond.py rename to functorch/experimental/_cond.py index 6f7bcbf506d8d..a3c1936560439 100644 --- a/functorch/experimental/cond.py +++ b/functorch/experimental/_cond.py @@ -1,23 +1,35 @@ import torch + +import torch.utils._pytree as pytree + from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard from torch._ops import PyOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + get_isolated_graphmodule, + get_proxy_slot, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode, + _pop_mode_temporarily, +) from torch.utils._pytree import tree_flatten -from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot -import torch.utils._pytree as pytree -from torch.utils._python_dispatch import _get_current_dispatch_mode, _pop_mode_temporarily -from torch.fx.experimental.proxy_tensor import track_tensor_tree -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode """ We're going to define a `cond` operation. In order to do this, we need implementations for each of the dispatch keys. """ -cond = PyOperator('cond') +cond = PyOperator("cond") def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) assert isinstance(operands, list), "Cond operands must be a list of tensors" @@ -113,6 +125,30 @@ def inner(pred, true_fn, false_fn, operands): return res +@cond.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(pred, true_fn, false_fn, operands): + true_outs = true_fn(*operands) + flat_true_outs, _ = pytree.tree_flatten(true_outs) + flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands)) + if len(flat_true_outs) != len(flat_false_outs): + raise RuntimeError("Unmatched number of outputs from cond() branches.") + + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + true_meta = _extract_tensor_metadata(true_out) + false_meta = _extract_tensor_metadata(false_out) + if true_meta != false_meta: + raise RuntimeError( + f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}") + return true_outs + + +# We cannot directly call fallthrough here due to issue #89037. +@cond.py_impl(DispatchKey.PythonDispatcher) +def cond_python_dispatcher(*args): + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher)) + return cond(*args) + + # TODO(voz): Make this automatic for keys, this is very ugly atm cond.fallthrough(DispatchKey.PythonTLSSnapshot) cond.fallthrough(DispatchKey.ADInplaceOrView) diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py new file mode 100644 index 0000000000000..d681526da4b34 --- /dev/null +++ b/functorch/experimental/_map.py @@ -0,0 +1,105 @@ +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard +from torch._ops import PyOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_slot, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode, + _pop_mode_temporarily, +) +from torch.utils._pytree import tree_flatten + + +map = PyOperator("map") + + +def trace_map(proxy_mode, func_overload, f, xs, *args): + def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e + return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) + + + if not isinstance(xs, torch.Tensor): + raise ValueError("map() must loop over a tensor") + if len(xs.shape) == 0 or xs.shape[0] == 0: + raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors") + if not all(isinstance(o, (torch.Tensor, torch.nn.Module)) for o in args): + raise ValueError("map() operands must be a list of tensors or modules") + + with disable_proxy_modes_tracing(): + body_graph = make_fx(f)(xs[0], *args) + + next_name = None + i = 0 + while not next_name: + candidate = f"body_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + proxy_mode.tracer.root.register_module(next_name, body_graph) + node_args = (body_graph, xs, *args) + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, + name="map") + outs = [body_graph(x, *args) for x in xs] + # Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly + # because stack([...]) takes a fixed size list which will specialize dynamic shape here. + # Meanwhile we want to preserve the looped over dimension as symbolic shape, such that: + # ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args) + out = xs.new_empty([xs.shape[0], *outs[0].shape]) + out.copy_(torch.stack(outs)) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@map.py_impl(DispatchKey.CPU) +def map_cpu(f, xs, *args): + mode = _get_current_dispatch_mode() + assert (mode is None), "Mode should never be enabled for CPU key" + return torch.stack([f(x, *args) for x in xs]) + + +@map.py_impl(DispatchKey.AutogradCPU) +def map_autograd(f, xs, *args): + # TODO: support autograd + flat_operands, _ = tree_flatten([f, xs, args]) + assert all([not f.requires_grad for f in flat_operands + if isinstance(f, torch.Tensor)]) + + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)) + return map(f, xs, *args) + + +@map.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(f, xs, *args): + mode = _get_current_dispatch_mode() + assert (mode is not None), "Mode should always be enabled for python fallback key" + with _pop_mode_temporarily() as mode: + res = trace_map(mode, map, f, xs, *args) + return res + + +@map.py_impl(FakeTensorMode) +def map_fake_tensor_mode(f, xs, *args): + return torch.stack([f(x, *args) for x in xs]) + +# We cannot directly call fallthrough here due to issue #89037. +@map.py_impl(DispatchKey.PythonDispatcher) +def map_python_dispatcher(*args): + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher)) + return map(*args) + + +# TODO(voz) Make this automatic for keys, this is very ugly atm +map.fallthrough(DispatchKey.PythonTLSSnapshot) +map.fallthrough(DispatchKey.ADInplaceOrView) +map.fallthrough(DispatchKey.BackendSelect) diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py new file mode 100644 index 0000000000000..fb235b10cc460 --- /dev/null +++ b/functorch/experimental/control_flow.py @@ -0,0 +1,2 @@ +from ._map import map # noqa: F401 +from ._cond import cond # noqa: F401 diff --git a/functorch/packaging/build_wheel.sh b/functorch/packaging/build_wheel.sh deleted file mode 100644 index 074e7dde77141..0000000000000 --- a/functorch/packaging/build_wheel.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -set -ex - -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. "$script_dir/pkg_helpers.bash" - -export BUILD_TYPE=wheel -setup_env 0.2.0 -setup_wheel_python -pip_install numpy pyyaml future ninja -pip_install --upgrade setuptools -setup_pip_pytorch_version -python setup.py clean - -if [[ "$OSTYPE" == "msys" ]]; then - "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel -else - python setup.py bdist_wheel -fi diff --git a/functorch/packaging/pkg_helpers.bash b/functorch/packaging/pkg_helpers.bash deleted file mode 100644 index 329891a07216c..0000000000000 --- a/functorch/packaging/pkg_helpers.bash +++ /dev/null @@ -1,414 +0,0 @@ -# A set of useful bash functions for common functionality we need to do in -# many build scripts - - -# Setup CUDA environment variables, based on CU_VERSION -# -# Inputs: -# CU_VERSION (cpu, cu92, cu100) -# NO_CUDA_PACKAGE (bool) -# BUILD_TYPE (conda, wheel) -# -# Outputs: -# VERSION_SUFFIX (e.g., "") -# PYTORCH_VERSION_SUFFIX (e.g., +cpu) -# WHEEL_DIR (e.g., cu100/) -# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) -# FORCE_CUDA (respected by torchvision setup.py) -# NVCC_FLAGS (respected by torchvision setup.py) -# -# Precondition: CUDA versions are installed in their conventional locations in -# /usr/local/cuda-* -# -# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building -# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == -# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a -# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always -# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU -# version of a Python package. But that doesn't apply if you're on OS X, -# since the default CU_VERSION on OS X is cpu. -setup_cuda() { - - # First, compute version suffixes. By default, assume no version suffixes - export VERSION_SUFFIX="" - export PYTORCH_VERSION_SUFFIX="" - export WHEEL_DIR="" - # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) - if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then - export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" - # Match the suffix scheme of pytorch, unless this package does not have - # CUDA builds (in which case, use default) - if [[ -z "$NO_CUDA_PACKAGE" ]]; then - export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" - export WHEEL_DIR="$CU_VERSION/" - fi - fi - - # Now work out the CUDA settings - case "$CU_VERSION" in - cu115) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5" - else - export CUDA_HOME=/usr/local/cuda-11.5/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu113) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" - else - export CUDA_HOME=/usr/local/cuda-11.3/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu112) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" - else - export CUDA_HOME=/usr/local/cuda-11.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu111) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" - else - export CUDA_HOME=/usr/local/cuda-11.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu110) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" - else - export CUDA_HOME=/usr/local/cuda-11.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" - ;; - cu102) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" - else - export CUDA_HOME=/usr/local/cuda-10.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu101) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" - else - export CUDA_HOME=/usr/local/cuda-10.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu100) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0" - else - export CUDA_HOME=/usr/local/cuda-10.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu92) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" - else - export CUDA_HOME=/usr/local/cuda-9.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" - ;; - cpu) - ;; - rocm*) - export FORCE_CUDA=1 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - if [[ -n "$CUDA_HOME" ]]; then - # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one - export PATH="$CUDA_HOME/bin:$PATH" - export FORCE_CUDA=1 - fi -} - -# Populate build version if necessary, and add version suffix -# -# Inputs: -# BUILD_VERSION (e.g., 0.2.0 or empty) -# VERSION_SUFFIX (e.g., +cpu) -# -# Outputs: -# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) -# -# Fill BUILD_VERSION if it doesn't exist already with a nightly string -# Usage: setup_build_version 0.2.0 -setup_build_version() { - if [[ -z "$BUILD_VERSION" ]]; then - export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" - else - export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" - fi - - # Set build version based on tag if on tag - if [[ -n "${CIRCLE_TAG}" ]]; then - # Strip tag - export BUILD_VERSION="$(echo "${CIRCLE_TAG}" | sed -e 's/^v//' -e 's/-.*$//')${VERSION_SUFFIX}" - fi -} - -# Set some useful variables for OS X, if applicable -setup_macos() { - if [[ "$(uname)" == Darwin ]]; then - export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ - fi -} - - -# Top-level entry point for things every package will need to do -# -# Usage: setup_env 0.2.0 -setup_env() { - setup_cuda - setup_build_version "$1" - setup_macos -} - -# Function to retry functions that sometimes timeout or have flaky failures -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Inputs: -# PYTHON_VERSION (3.7, 3.8, 3.9) -# UNICODE_ABI (bool) -# -# Outputs: -# PATH modified to put correct Python version in PATH -# -# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image -setup_wheel_python() { - if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then - eval "$(conda shell.bash hook)" - conda env remove -n "env$PYTHON_VERSION" || true - conda create ${CONDA_CHANNEL_FLAGS} -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" - conda activate "env$PYTHON_VERSION" - # Install libpng from Anaconda (defaults) - conda install ${CONDA_CHANNEL_FLAGS} libpng "jpeg<=9b" -y - else - # Install native CentOS libJPEG, freetype and GnuTLS - yum install -y libjpeg-turbo-devel freetype gnutls - case "$PYTHON_VERSION" in - 3.7) python_abi=cp37-cp37m ;; - 3.8) python_abi=cp38-cp38 ;; - 3.9) python_abi=cp39-cp39 ;; - 3.10) python_abi=cp310-cp310 ;; - *) - echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - ;; - esac - # Download all the dependencies required to compile image and video_reader - # extensions - - mkdir -p ext_libraries - pushd ext_libraries - popd - export PATH="/opt/python/$python_abi/bin:$(pwd)/ext_libraries/bin:$PATH" - fi -} - -# Install with pip a bit more robustly than the default -pip_install() { - retry pip install --progress-bar off "$@" -} - -# Install torch with pip, respecting PYTORCH_VERSION, and record the installed -# version into PYTORCH_VERSION, if applicable -setup_pip_pytorch_version() { - if [[ -z "$PYTORCH_VERSION" ]]; then - # Install latest prerelease version of torch, per our nightlies, consistent - # with the requested cuda version - pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # CUDA and CPU are ABI compatible on the CPU-only parts, so strip - # in this case - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" - else - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//')" - fi - else - pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ - -f "https://download.pytorch.org/whl/${CU_VERSION}/torch_stable.html" \ - -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" - fi -} - -# Fill PYTORCH_VERSION with the latest conda nightly version, and -# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions -# -# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. -setup_conda_pytorch_constraint() { - if [[ -z "$PYTORCH_VERSION" ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly -c pytorch" - export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | \ - python -c "import os, sys, json, re; cuver = os.environ.get('CU_VERSION'); \ - cuver_1 = cuver.replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - cuver_2 = (cuver[:-1] + '.' + cuver[-1]).replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - print(re.sub(r'\\+.*$', '', \ - [x['version'] for x in json.load(sys.stdin)['pytorch'] \ - if (x['platform'] == 'darwin' or cuver_1 in x['fn'] or cuver_2 in x['fn']) \ - and 'py' + os.environ['PYTHON_VERSION'] in x['fn']][-1]))")" - if [[ -z "$PYTORCH_VERSION" ]]; then - echo "PyTorch version auto detection failed" - echo "No package found for CU_VERSION=$CU_VERSION and PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - fi - else - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}" - fi - if [[ "$CU_VERSION" == cpu ]]; then - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" - else - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - fi - if [[ "$OSTYPE" == msys && "$CU_VERSION" == cu92 ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev" - fi -} - -# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT -setup_conda_cudatoolkit_constraint() { - export CONDA_BUILD_VARIANT="cuda" - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" - ;; - cu110) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -setup_conda_cudatoolkit_plain_constraint() { - export CONDA_BUILD_VARIANT="cuda" - export CMAKE_USE_CUDA=1 - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.5" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.2" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.1" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.2" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.1" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.0" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=9.2" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -# Build the proper compiler package before building the final package -setup_visual_studio_constraint() { - if [[ "$OSTYPE" == "msys" ]]; then - export VSTOOLCHAIN_PACKAGE=vs$VC_YEAR - conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE - cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchvision/conda_build_config.yaml - fi -} - -setup_junit_results_folder() { - if [[ "$CI" == "true" ]]; then - export CONDA_PYTORCH_BUILD_RESULTS_DIRECTORY="${SOURCE_ROOT_DIR}/build_results/results.xml" - fi -} - - -download_copy_ffmpeg() { - if [[ "$OSTYPE" == "msys" ]]; then - # conda install -yq ffmpeg=4.2 -c pytorch - # curl -L -q https://anaconda.org/pytorch/ffmpeg/4.3/download/win-64/ffmpeg-4.3-ha925a31_0.tar.bz2 --output ffmpeg-4.3-ha925a31_0.tar.bz2 - # bzip2 --decompress --stdout ffmpeg-4.3-ha925a31_0.tar.bz2 | tar -x --file=- - # cp Library/bin/*.dll ../torchvision - echo "FFmpeg is disabled currently on Windows" - else - if [[ "$(uname)" == Darwin ]]; then - conda install -yq ffmpeg=4.2 -c pytorch - conda install -yq wget - else - # pushd ext_libraries - # wget -q https://anaconda.org/pytorch/ffmpeg/4.2/download/linux-64/ffmpeg-4.2-hf484d3e_0.tar.bz2 - # tar -xjvf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # rm -rf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # ldconfig - # which ffmpeg - # popd - echo "FFmpeg is disabled currently on Linux" - fi - fi -} diff --git a/functorch/packaging/windows/internal/cuda_install.bat b/functorch/packaging/windows/internal/cuda_install.bat deleted file mode 100644 index 41960224ebaed..0000000000000 --- a/functorch/packaging/windows/internal/cuda_install.bat +++ /dev/null @@ -1,264 +0,0 @@ -@echo on - -if "%CU_VERSION%" == "cpu" ( - echo Skipping for CPU builds - exit /b 0 -) - -set SRC_DIR=%~dp0\.. - -if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" - -rem in unit test workflow, we get CUDA_VERSION, for example 11.1 -if defined CUDA_VERSION ( - set CUDA_VER=%CUDA_VERSION:.=% -) else ( - set CUDA_VER=%CU_VERSION:cu=% -) - -set /a CUDA_VER=%CU_VERSION:cu=% -set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% -set CUDA_VER_MINOR=%CUDA_VER:~-1,1% -set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% - - -if %CUDA_VER% EQU 92 goto cuda92 -if %CUDA_VER% EQU 100 goto cuda100 -if %CUDA_VER% EQU 101 goto cuda101 -if %CUDA_VER% EQU 102 goto cuda102 -if %CUDA_VER% EQU 110 goto cuda110 -if %CUDA_VER% EQU 111 goto cuda111 -if %CUDA_VER% EQU 112 goto cuda112 -if %CUDA_VER% EQU 113 goto cuda113 -if %CUDA_VER% EQU 115 goto cuda115 - - -echo CUDA %CUDA_VERSION_STR% is not supported -exit /b 1 - -:cuda92 -if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" -) - -goto cuda_common - -:cuda100 - -if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" -) - -goto cuda_common - -:cuda101 - -if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" -) - -goto cuda_common - -:cuda102 - -if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" -) - -rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. -if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( - curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" - if errorlevel 1 exit /b 1 -) - -echo Installing GPU driver DLLs -7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" - -goto cuda_common - -:cuda110 - -if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" -) - -goto cuda_common - -:cuda111 - -if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" -) - -goto cuda_common - -:cuda112 - -if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( - curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" -) - -goto cuda_common - -:cuda113 - -set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" - -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda115 - -set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda_common - -if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( - curl -k -L https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" - if errorlevel 1 exit /b 1 -) - -echo Installing CUDA toolkit... -7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" -pushd "%SRC_DIR%\temp_build\cuda" -sc config wuauserv start= disabled -sc stop wuauserv -sc query wuauserv - -start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" -echo %errorlevel% - -popd - -echo Installing VS integration... -rem It's for VS 2019 -if "%CUDA_VER_MAJOR%" == "10" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) -if "%CUDA_VER_MAJOR%" == "11" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) - -echo Installing NvToolsExt... -7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" - -echo Setting up environment... -set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" -set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" - -if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( - echo CUDA %CUDA_VERSION_STR% installed failed. - echo --------- RunDll32.exe.log - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" - echo --------- setup.exe.log ------- - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" - exit /b 1 -) - -echo Installing cuDNN... -7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" - -echo Cleaning temp files -rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/functorch/packaging/windows/internal/driver_update.bat b/functorch/packaging/windows/internal/driver_update.bat deleted file mode 100644 index 00b43affc01cc..0000000000000 --- a/functorch/packaging/windows/internal/driver_update.bat +++ /dev/null @@ -1,25 +0,0 @@ -set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -if errorlevel 1 exit /b 1 - -start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot -if errorlevel 1 exit /b 1 - -del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL - -setlocal EnableDelayedExpansion -set NVIDIA_GPU_EXISTS=0 -for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( - set GPUS=%%i - if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( - SET NVIDIA_GPU_EXISTS=1 - goto gpu_check_end - ) -) -:gpu_check_end -endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% - -if "%NVIDIA_GPU_EXISTS%" == "0" ( - echo "CUDA Driver installation Failed" - exit /b 1 -) diff --git a/functorch/packaging/windows/internal/vc_env_helper.bat b/functorch/packaging/windows/internal/vc_env_helper.bat deleted file mode 100644 index e85a372f93d58..0000000000000 --- a/functorch/packaging/windows/internal/vc_env_helper.bat +++ /dev/null @@ -1,43 +0,0 @@ -@echo on - -set VC_VERSION_LOWER=16 -set VC_VERSION_UPPER=17 -if "%VC_YEAR%" == "2017" ( - set VC_VERSION_LOWER=15 - set VC_VERSION_UPPER=16 -) - -for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( - if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( - set "VS15INSTALLDIR=%%i" - set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" - goto vswhere - ) -) - -:vswhere -if "%VSDEVCMD_ARGS%" == "" ( - call "%VS15VCVARSALL%" x64 || exit /b 1 -) else ( - call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 -) - -@echo on - -set DISTUTILS_USE_SDK=1 - -set args=%1 -shift -:start -if [%1] == [] goto done -set args=%args% %1 -shift -goto start - -:done -if "%args%" == "" ( - echo Usage: vc_env_helper.bat [command] [args] - echo e.g. vc_env_helper.bat cl /c test.cpp -) - -%args% || exit /b 1 diff --git a/functorch/packaging/windows/internal/vc_install_helper.sh b/functorch/packaging/windows/internal/vc_install_helper.sh deleted file mode 100644 index cdae18065b9f6..0000000000000 --- a/functorch/packaging/windows/internal/vc_install_helper.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -set -ex - -if [[ "$CU_VERSION" == "cu92" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="-vcvars_ver=14.13" - powershell packaging/windows/internal/vs2017_install.ps1 -elif [[ "$CU_VERSION" == "cu100" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="" - powershell packaging/windows/internal/vs2017_install.ps1 -else - export VC_YEAR=2019 - export VSDEVCMD_ARGS="" -fi diff --git a/ios/LibTorch-Lite.podspec b/ios/LibTorch-Lite.podspec index 9814eaa367586..96b759f221504 100644 --- a/ios/LibTorch-Lite.podspec +++ b/ios/LibTorch-Lite.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch-Lite' - s.version = '1.12.0' + s.version = '1.13.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' @@ -33,4 +33,5 @@ Pod::Spec.new do |s| 'VALID_ARCHS' => 'x86_64 arm64' } s.library = ['c++', 'stdc++'] + s.frameworks = 'Accelerate' end diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec index 3c197f0f103b9..6cee4993cca14 100644 --- a/ios/LibTorch.podspec +++ b/ios/LibTorch.podspec @@ -1,6 +1,6 @@ Pod::Spec.new do |s| s.name = 'LibTorch' - s.version = '1.12.0' + s.version = '1.13.0' s.authors = 'PyTorch Team' s.license = { :type => 'BSD' } s.homepage = 'https://github.com/pytorch/pytorch' @@ -33,4 +33,5 @@ Pod::Spec.new do |s| 'VALID_ARCHS' => 'x86_64 arm64' } s.library = ['c++', 'stdc++'] + s.frameworks = 'Accelerate' end diff --git a/ios/TestApp/TestApp.xcodeproj/project.pbxproj b/ios/TestApp/TestApp.xcodeproj/project.pbxproj index 09aeeada17239..ff84280f02ebd 100644 --- a/ios/TestApp/TestApp.xcodeproj/project.pbxproj +++ b/ios/TestApp/TestApp.xcodeproj/project.pbxproj @@ -253,7 +253,7 @@ ALWAYS_SEARCH_USER_PATHS = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; CLANG_CXX_LIBRARY = "libc++"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; @@ -312,7 +312,7 @@ ALWAYS_SEARCH_USER_PATHS = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; CLANG_CXX_LIBRARY = "libc++"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; diff --git a/mypy-nofollow.ini b/mypy-nofollow.ini new file mode 100644 index 0000000000000..5b5358643774f --- /dev/null +++ b/mypy-nofollow.ini @@ -0,0 +1,34 @@ +[mypy] +plugins = mypy_plugins/check_mypy_version.py + +cache_dir = .mypy_cache/nofollow +warn_unused_configs = True +warn_redundant_casts = True +show_error_codes = True +show_column_numbers = True +check_untyped_defs = True +follow_imports = skip + +# do not reenable this: +# https://github.com/pytorch/pytorch/pull/60006#issuecomment-866130657 +warn_unused_ignores = False +disallow_any_generics = True + +files = + torch/_dynamo + +# Minimum version supported - variable annotations were introduced +# in Python 3.7 +python_version = 3.7 + +[mypy-sympy] +ignore_missing_imports = True + +[mypy-sympy.*] +ignore_missing_imports = True + +[mypy-torch._C] +ignore_errors = True + +[mypy-torch._C.*] +ignore_errors = True diff --git a/mypy-strict.ini b/mypy-strict.ini index 460599699c46f..81c66d5239ebc 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -40,6 +40,7 @@ files = .github, benchmarks/instruction_counts, tools, + torch/profiler/_memory_profiler.py, torch/utils/_pytree.py, torch/utils/benchmark/utils/common.py, torch/utils/benchmark/utils/timer.py, diff --git a/scripts/buck_setup.sh b/scripts/buck_setup.sh index 8e60d92a5fd15..f6152537435c2 100644 --- a/scripts/buck_setup.sh +++ b/scripts/buck_setup.sh @@ -22,16 +22,16 @@ python3 generate-xnnpack-wrappers.py # bazel-skylib printf "\nDownloading bazel-skylib\n" rm -rf bazel-skylib; mkdir bazel-skylib -curl -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib +curl --retry 3 -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib # glog printf "\nDownloading glog\n" rm -rf glog; mkdir glog -curl -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 +curl --retry 3 -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 # ruy printf "\nDownloading ruy\n" -curl -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip +curl --retry 3 -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip unzip -q /tmp/ruy.zip -d /tmp/ rm -rf ruy/ mv /tmp/ruy-a09683b8da7164b9c5704f88aef2dc65aa583e5d ruy/ diff --git a/scripts/build_android.sh b/scripts/build_android.sh index 2d6f051ea19fe..e2be6c88e9893 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -165,6 +165,11 @@ fi # Use-specified CMake arguments go last to allow overridding defaults CMAKE_ARGS+=($@) +# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 +if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then + sed -i -e "s/#if __cplusplus >= 201703L/#if 0/" third_party/pocketfft/pocketfft_hdronly.h +fi + # Now, actually build the Android target. BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} INSTALL_PREFIX=${BUILD_ROOT}/install diff --git a/setup.py b/setup.py index f844c690b74fb..1f069f4d4bdf9 100644 --- a/setup.py +++ b/setup.py @@ -446,8 +446,8 @@ def build_deps(): def check_pydep(importname, module): try: importlib.import_module(importname) - except ImportError: - raise RuntimeError(missing_pydep.format(importname=importname, module=module)) + except ImportError as e: + raise RuntimeError(missing_pydep.format(importname=importname, module=module)) from e class build_ext(setuptools.command.build_ext.build_ext): @@ -762,6 +762,14 @@ def run(self): super().run() +def get_cmake_cache_vars(): + try: + return defaultdict(lambda: False, cmake.get_cmake_cache_variables()) + except FileNotFoundError: + # CMakeCache.txt does not exist. Probably running "python setup.py clean" over a clean directory. + return defaultdict(lambda: False) + + def configure_extension_build(): r"""Configures extension build options according to system environment and user's choice. @@ -769,11 +777,7 @@ def configure_extension_build(): The input to parameters ext_modules, cmdclass, packages, and entry_points as required in setuptools.setup. """ - try: - cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) - except FileNotFoundError: - # CMakeCache.txt does not exist. Probably running "python setup.py clean" over a clean directory. - cmake_cache_vars = defaultdict(lambda: False) + cmake_cache_vars = get_cmake_cache_vars() ################################################################################ # Configure compile flags @@ -791,7 +795,7 @@ def configure_extension_build(): # /EHsc is about standard C++ exception handling # /DNOMINMAX removes builtin min/max functions # /wdXXXX disables warning no. XXXX - extra_compile_args = ['/MD', '/EHsc', '/DNOMINMAX', + extra_compile_args = ['/MD', '/FS', '/EHsc', '/DNOMINMAX', '/wd4267', '/wd4251', '/wd4522', '/wd4522', '/wd4838', '/wd4305', '/wd4244', '/wd4190', '/wd4101', '/wd4996', '/wd4275'] @@ -848,7 +852,7 @@ def configure_extension_build(): pytorch_extra_install_requirements = os.getenv("PYTORCH_EXTRA_INSTALL_REQUIREMENTS", "") if pytorch_extra_install_requirements: report(f"pytorch_extra_install_requirements: {pytorch_extra_install_requirements}") - extra_install_requires += pytorch_extra_install_requirements.split(";") + extra_install_requires += pytorch_extra_install_requirements.split("|") # Cross-compile for M1 @@ -877,7 +881,12 @@ def make_relative_rpath_args(path): ################################################################################ extensions = [] - packages = find_packages(exclude=('tools', 'tools.*')) + excludes = ['tools', 'tools.*'] + if not cmake_cache_vars['BUILD_CAFFE2']: + excludes.extend(['caffe2', 'caffe2.*']) + if not cmake_cache_vars['BUILD_FUNCTORCH']: + excludes.extend(['functorch', 'functorch.*']) + packages = find_packages(exclude=excludes) C = Extension("torch._C", libraries=main_libraries, sources=main_sources, @@ -1027,9 +1036,11 @@ def main(): 'lib/*.pdb', 'lib/torch_shm_manager', 'lib/*.h', + 'include/*.h', 'include/ATen/*.h', 'include/ATen/cpu/*.h', 'include/ATen/cpu/vec/vec256/*.h', + 'include/ATen/cpu/vec/vec256/vsx/*.h', 'include/ATen/cpu/vec/vec512/*.h', 'include/ATen/cpu/vec/*.h', 'include/ATen/core/*.h', @@ -1055,8 +1066,7 @@ def main(): 'include/ATen/native/quantized/*.h', 'include/ATen/native/quantized/cpu/*.h', 'include/ATen/quantized/*.h', - 'include/caffe2/utils/*.h', - 'include/caffe2/utils/**/*.h', + 'include/caffe2/serialize/*.h', 'include/c10/*.h', 'include/c10/macros/*.h', 'include/c10/core/*.h', @@ -1070,7 +1080,6 @@ def main(): 'include/c10/cuda/impl/*.h', 'include/c10/hip/*.h', 'include/c10/hip/impl/*.h', - 'include/caffe2/**/*.h', 'include/torch/*.h', 'include/torch/csrc/*.h', 'include/torch/csrc/api/include/torch/*.h', @@ -1097,7 +1106,8 @@ def main(): 'include/torch/csrc/autograd/generated/*.h', 'include/torch/csrc/autograd/utils/*.h', 'include/torch/csrc/cuda/*.h', - 'include/torch/csrc/distributed/c10d/exception.h', + 'include/torch/csrc/distributed/c10d/*.h', + 'include/torch/csrc/distributed/c10d/*.hpp', 'include/torch/csrc/distributed/rpc/*.h', 'include/torch/csrc/jit/*.h', 'include/torch/csrc/jit/backends/*.h', @@ -1139,6 +1149,7 @@ def main(): 'include/THH/*.cuh', 'include/THH/*.h*', 'include/THH/generic/*.h', + 'include/sleef.h', "_inductor/codegen/*.h", "_inductor/codegen/*.j2", 'share/cmake/ATen/*.cmake', @@ -1157,6 +1168,13 @@ def main(): 'utils/model_dump/code.js', 'utils/model_dump/*.mjs', ] + + if get_cmake_cache_vars()['BUILD_CAFFE2']: + torch_package_data.extend([ + 'include/caffe2/**/*.h', + 'include/caffe2/utils/*.h', + 'include/caffe2/utils/**/*.h', + ]) torchgen_package_data = [ # Recursive glob doesn't work in setup.py, # https://github.com/pypa/setuptools/issues/1806 diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 20edd93d7dc2c..94ff57700af67 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -525,6 +525,7 @@ "Optional" ], "torch.nested": [ + "nested_tensor", "to_padded_tensor" ], "torch.nn.common_types": [ @@ -622,7 +623,7 @@ "OrderedDict" ], "torch.nn.qat.dynamic.modules.linear": [ - "activation_is_memoryless" + "_activation_is_memoryless" ], "torch.nn.qat.modules.conv": [ "Tuple", @@ -735,10 +736,10 @@ "QuantType", "QuantWrapper", "RecordingObserver", - "add_module_to_qconfig_obs_ctr", + "_add_module_to_qconfig_obs_ctr", "add_observer_", "add_quant_dequant", - "assert_valid_qconfig", + "_assert_valid_qconfig", "convert", "convert_dynamic_jit", "convert_jit", @@ -794,7 +795,7 @@ "prepare_qat", "propagate_qconfig_", "qconfig_equals", - "quant_type_to_str", + "_get_quant_type_to_str", "quantize", "quantize_dynamic", "quantize_dynamic_jit", @@ -865,15 +866,15 @@ "QConfig", "QConfigAny", "QConfigDynamic", - "add_module_to_qconfig_obs_ctr", - "assert_valid_qconfig", + "_add_module_to_qconfig_obs_ctr", + "_assert_valid_qconfig", "get_default_qat_qconfig", "get_default_qconfig", "qconfig_equals" ], "torch.quantization.quant_type": [ "QuantType", - "quant_type_to_str" + "_get_quant_type_to_str" ], "torch.quantization.quantization_mappings": [ "get_default_compare_output_module_list", diff --git a/test/ao/sparsity/test_pruner.py b/test/ao/sparsity/test_pruner.py deleted file mode 100644 index 295939cb3e39f..0000000000000 --- a/test/ao/sparsity/test_pruner.py +++ /dev/null @@ -1,394 +0,0 @@ -# -*- coding: utf-8 -*- -# Owner(s): ["module: unknown"] - - -import copy -import logging - -import torch -from torch import nn -from torch.ao.pruning._experimental.pruner import BasePruner, PruningParametrization, ZeroesParametrization -from torch.nn.utils import parametrize - -from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo - -logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO) - -DEVICES = { - torch.device("cpu"), - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") -} - -NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed - nn.BatchNorm2d -} - - -class Linear(nn.Module): - r"""Model with Linear layers, in Sequential and outside, without biases""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Linear(16, 16, bias=False) - ) - self.linear = nn.Linear(16, 16, bias=False) - - def forward(self, x): - x = self.seq(x) - x = self.linear(x) - return x - - -class LinearB(nn.Module): - r"""Model with Linear layers, in Sequential and outside, with biases""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Linear(16, 16, bias=True) - ) - self.linear = nn.Linear(16, 16, bias=True) - - def forward(self, x): - x = self.seq(x) - x = self.linear(x) - return x - - -class MultipleLinear(nn.Module): - r"""Model with multiple Linear layers, in Sequential and outside, without biases - and with activation functions""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Linear(7, 5, bias=False), - nn.ReLU(), - nn.Linear(5, 8, bias=False), - nn.ReLU(), - nn.Linear(8, 6, bias=False) - ) - self.linear = nn.Linear(6, 4, bias=False) - - def forward(self, x): - x = self.seq(x) - x = self.linear(x) - return x - - -class MultipleLinearB(nn.Module): - r"""Model with multiple Linear layers, in Sequential and outside, with biases - and with activation functions""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Linear(7, 5, bias=True), - nn.ReLU(), - nn.Linear(5, 8, bias=True), - nn.ReLU(), - nn.Linear(8, 6, bias=True) - ) - self.linear = nn.Linear(6, 4, bias=True) - - def forward(self, x): - x = self.seq(x) - x = self.linear(x) - return x - - -class MultipleLinearMixed(nn.Module): - r"""Model with multiple Linear layers, in Sequential and outside, some with biases - and with activation functions""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Linear(7, 5, bias=True), - nn.ReLU(), - nn.Linear(5, 8, bias=False), - nn.ReLU(), - nn.Linear(8, 6, bias=True) - ) - self.linear = nn.Linear(6, 4, bias=False) - - def forward(self, x): - x = self.seq(x) - x = self.linear(x) - return x - - -class Conv2dA(nn.Module): - r"""Model with Conv2d layers, in Sequential and outside, without biases""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Conv2d(1, 32, 3, 1, bias=False), - ) - self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False) - - def forward(self, x): - x = self.seq(x) - x = self.conv2d(x) - return x - - -class Conv2dB(nn.Module): - r"""Model with Conv2d layers, in Sequential and outside, with biases""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Conv2d(1, 32, 3, 1, bias=True), - ) - self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True) - - def forward(self, x): - x = self.seq(x) - x = self.conv2d(x) - return x - - -class Conv2dC(nn.Module): - r"""Model with Conv2d layers, in Sequential and outside, with and without biases""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Conv2d(1, 32, 3, 1, bias=True), - ) - self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False) - - def forward(self, x): - x = self.seq(x) - x = self.conv2d(x) - return x - - -class Conv2dBN(nn.Module): - r"""Model with Conv2d layers and BatchNorms""" - def __init__(self): - super().__init__() - self.seq = nn.Sequential( - nn.Conv2d(1, 32, 3, 1, bias=True), - nn.BatchNorm2d(32) - ) - self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True) - self.bn = nn.BatchNorm2d(64) - - def forward(self, x): - x = self.seq(x) - x = self.conv2d(x) - x = self.bn(x) - return x - - -class SimplePruner(BasePruner): - def update_mask(self, module, tensor_name, **kwargs): - getattr(module.parametrizations, tensor_name)[0].pruned_outputs.add(1) - - -class MultiplePruner(BasePruner): - def update_mask(self, module, tensor_name, **kwargs): - getattr(module.parametrizations, tensor_name)[0].pruned_outputs.update([1, 2]) - - -class TestBasePruner(TestCase): - def _check_pruner_prepared(self, model, pruner, device): - for config in pruner.groups: - modules = [] - if type(config['module']) is tuple: - for module in config['module']: - modules.append(module) - else: - module = config['module'] - modules.append(module) - for module in modules: - assert module.weight.device.type == device.type - # Check mask exists - assert hasattr(module, 'mask') - # Check parametrization exists and is correct - assert parametrize.is_parametrized(module) - assert hasattr(module, "parametrizations") - # Assume that this is the 1st/only parametrization - if isinstance(module, tuple(NEEDS_ZEROS)): - assert type(module.parametrizations.weight[0]) == ZeroesParametrization - else: - assert type(module.parametrizations.weight[0]) == PruningParametrization - - def _check_pruner_mask_squashed(self, model, pruner, device): - for config in pruner.groups: - modules = [] - if type(config['module']) is tuple: - for module in config['module']: - modules.append(module) - else: - module = config['module'] - modules.append(module) - for module in modules: - assert module.weight.device.type == device.type - assert not hasattr(module, "parametrizations") - assert not hasattr(module, 'mask') - - def _check_pruner_valid_before_step(self, model, pruner, device): - for config in pruner.groups: - modules = [] - if type(config['module']) is tuple: - for module in config['module']: - modules.append(module) - else: - module = config['module'] - modules.append(module) - for module in modules: - assert module.weight.device.type == device.type - assert module.parametrizations.weight[0].pruned_outputs == set() - - def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device): - for config in pruner.groups: - modules = [] - if type(config['module']) is tuple: - for module in config['module']: - modules.append(module) - else: - module = config['module'] - modules.append(module) - for module in modules: - assert module.weight.device.type == device.type - assert module.parametrizations.weight[0].pruned_outputs == pruned_set - - def _test_constructor_on_device(self, model, device): - self.assertRaisesRegex(TypeError, 'BasePruner .* update_mask', - BasePruner) - model1 = copy.deepcopy(model).to(device) - pruner = SimplePruner(None) - pruner.prepare(model1, None) - for g in pruner.groups: - module = g['module'] - assert module.weight.device.type == device.type - assert len(pruner.groups) == 2 - pruner.step() - # Can instantiate the model with configs - model2 = copy.deepcopy(model).to(device) - pruner = SimplePruner({'test': 3}) - pruner.prepare(model2, [model2.linear]) - assert len(pruner.groups) == 1 - assert pruner.groups[0]['module_fqn'] == 'linear' - assert 'test' in pruner.groups[0] - assert pruner.groups[0]['test'] == 3 - - def test_constructor(self): - model = Linear() - for device in DEVICES: - self._test_constructor_on_device(model, torch.device(device)) - - def _test_prepare_linear_on_device(self, model, device): - model = copy.deepcopy(model).to(device) - x = torch.ones(128, 16, device=device) - pruner = SimplePruner(None) - pruner.prepare(model, None) - self._check_pruner_prepared(model, pruner, device) - assert model(x).shape == (128, 16) - - def test_prepare_linear(self): - models = [Linear(), LinearB()] # without and with bias - for device in DEVICES: - for model in models: - self._test_prepare_linear_on_device(model, torch.device(device)) - - def _test_prepare_conv2d_on_device(self, model, config, device): - x = torch.ones((1, 1, 28, 28), device=device) - pruner = SimplePruner(None) - pruner.prepare(model, config) - self._check_pruner_prepared(model, pruner, device) - assert model(x).shape == (1, 64, 24, 24) - - def test_prepare_conv2d(self): - bn_model = Conv2dBN() - bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)] - - models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] - configs = [None, None, None, bn_config] - for device in DEVICES: - for model, config in zip(models, configs): - model = model.to(device) - self._test_prepare_conv2d_on_device(model, config, torch.device(device)) - - def _test_squash_mask_linear_on_device(self, model, device): - model = copy.deepcopy(model).to(device) - x = torch.ones(128, 16, device=device) - pruner = SimplePruner(None) - pruner.prepare(model, None) - pruner.squash_mask() - self._check_pruner_mask_squashed(model, pruner, device) - assert model(x).shape == (128, 16) - - def test_squash_mask_linear(self): - models = [Linear(), LinearB()] # without and with bias - for device in DEVICES: - for model in models: - self._test_squash_mask_linear_on_device(model, torch.device(device)) - - def _test_squash_mask_conv2d_on_device(self, model, config, device): - model = copy.deepcopy(model).to(device) - x = torch.ones((1, 1, 28, 28), device=device) - pruner = SimplePruner(None) - pruner.prepare(model, config) - pruner.squash_mask() - self._check_pruner_mask_squashed(model, pruner, device) - assert model(x).shape == (1, 64, 24, 24) - - def test_squash_mask_conv2d(self): - bn_model = Conv2dBN() - bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)] - - models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] - configs = [None, None, None, bn_config] - for device in DEVICES: - for model, config in zip(models, configs): - model = model.to(device) - self._test_squash_mask_conv2d_on_device(model, config, torch.device(device)) - - def _test_step_linear_on_device(self, model, is_basic, device): - model = model.to(device) - if is_basic: - x = torch.ones(16, 16) - pruner = SimplePruner(None) - pruner.prepare(model, None) - self._check_pruner_valid_before_step(model, pruner, device) - pruner.step() - self._check_pruner_valid_after_step(model, pruner, {1}, device) - else: - x = torch.ones(7, 7) - pruner = MultiplePruner(None) - pruner.prepare(model, None) - self._check_pruner_valid_before_step(model, pruner, device) - pruner.step() - self._check_pruner_valid_after_step(model, pruner, {1, 2}, device) - - def test_step_linear(self): - basic_models = [Linear(), LinearB()] - complex_models = [MultipleLinear(), MultipleLinearB(), MultipleLinearMixed()] - for device in DEVICES: - for model in basic_models: - self._test_step_linear_on_device(model, True, torch.device(device)) - for model in complex_models: - self._test_step_linear_on_device(model, False, torch.device(device)) - - def _test_step_conv2d_on_device(self, model, config, device): - model = model.to(device) - x = torch.ones((1, 1, 28, 28)).to(device) - pruner = SimplePruner(None) - pruner.prepare(model, config) - self._check_pruner_valid_before_step(model, pruner, device) - pruner.step() - if type(model) is Conv2dBN: - assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0]) - assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d) - self._check_pruner_valid_after_step(model, pruner, {1}, device) - assert model(x).shape == (1, 64, 24, 24) - - @skipIfTorchDynamo("TorchDynamo fails with unknown reason") - def test_step_conv2d(self): - bn_model = Conv2dBN() - bn_config = [(bn_model.seq[0], bn_model.seq[1]), - (bn_model.conv2d, bn_model.bn)] - - models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model] - configs = [None, None, None, None, bn_config] - for device in DEVICES: - for model, config in zip(models, configs): - self._test_step_conv2d_on_device(model, config, torch.device(device)) diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index 415679337ff2e..512c58b188367 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -18,14 +18,16 @@ class Model(nn.Module): def __init__(self): super().__init__() self.seq = nn.Sequential( - nn.Linear(16, 16) + nn.Linear(37, 39) ) - self.linear = nn.Linear(16, 16) - self.head = nn.Linear(16, 4) + self.linear = nn.Linear(39, 33) + self.head = nn.Linear(33, 13) def forward(self, x): x = self.seq(x) + x = torch.relu(x) x = self.linear(x) + x = torch.relu(x) x = self.head(x) return x diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py new file mode 100644 index 0000000000000..19e5a03640d00 --- /dev/null +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -0,0 +1,641 @@ +# -*- coding: utf-8 -*- +# Owner(s): ["module: unknown"] + + +import copy +import logging +import random + +import torch +from torch.ao.pruning._experimental.pruner import ( + BaseStructuredSparsifier, + FakeStructuredSparsity, +) +from torch.nn.utils import parametrize + +from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo +from torch.testing._internal.common_pruning import ( + SimpleLinear, + LinearBias, + LinearActivation, + LinearActivationFunctional, + SimpleConv2d, + Conv2dBias, + Conv2dActivation, + Conv2dPadBias, + Conv2dPool, + Conv2dPoolFlatten, + Conv2dPoolFlattenFunctional, +) + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +DEVICES = { + torch.device("cpu"), + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), +} + + +class SimplePruner(BaseStructuredSparsifier): + def update_mask(self, module, tensor_name, **kwargs): + getattr(module.parametrizations, tensor_name)[0].mask[1] = False + + +class ImplementedPruner(BaseStructuredSparsifier): + def update_mask(self, module, tensor_name, **kwargs): + """Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning""" + num_rows = len(module.parametrizations[tensor_name][0].mask) + prune = random.sample(list(range(num_rows)), num_rows // 3) + module.parametrizations[tensor_name][0].mask[prune] = False + + +class TestBaseStructuredSparsifier(TestCase): + def _check_pruner_prepared(self, model, pruner, device): + for config in pruner.groups: + module = config["module"] + assert module.weight.device.type == device.type + # Check mask exists + assert config["tensor_fqn"] in pruner.state + # Check parametrization exists and is correct + assert parametrize.is_parametrized(module) + assert hasattr(module, "parametrizations") + # Assume that this is the 1st/only parametrization + assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity + + def _check_pruner_valid_before_step(self, model, pruner, device): + for config in pruner.groups: + modules = [] + if type(config["module"]) is tuple: + for module in config["module"]: + modules.append(module) + else: + module = config["module"] + modules.append(module) + for module in modules: + assert module.weight.device.type == device.type + assert module.parametrizations.weight[0].mask.dtype == torch.bool + + def _check_pruner_valid_after_step(self, model, pruner, mask, device): + for config in pruner.groups: + modules = [] + if type(config["module"]) is tuple: + for module in config["module"]: + modules.append(module) + else: + module = config["module"] + modules.append(module) + for module in modules: + assert module.weight.device.type == device.type + total = module.parametrizations.weight[0].mask.numel() + assert ( + module.parametrizations.weight[0].mask.count_nonzero() + == total - mask + ) + + def _test_constructor_on_device(self, model, device): + self.assertRaisesRegex( + TypeError, + "BaseStructuredSparsifier.* update_mask", + BaseStructuredSparsifier, + ) + model1 = copy.deepcopy(model).to(device) + pruner = SimplePruner(None) + pruner.prepare(model1, None) + pruner.enable_mask_update = True + for g in pruner.groups: + module = g["module"] + assert module.weight.device.type == device.type + assert len(pruner.groups) == 5 + pruner.step() + # Can instantiate the model with configs + model2 = copy.deepcopy(model).to(device) + pruner = SimplePruner({"test": 3}) + pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}]) + assert len(pruner.groups) == 1 + assert pruner.groups[0]["module_fqn"] == "seq.0" + assert "test" in pruner.groups[0] + assert pruner.groups[0]["test"] == 3 + + def test_constructor(self): + model = SimpleLinear() + for device in DEVICES: + self._test_constructor_on_device(model, torch.device(device)) + + def _test_prepare_linear_on_device(self, model, device): + model = copy.deepcopy(model).to(device) + x = torch.ones(128, 7, device=device) + pruner = SimplePruner(None) + pruner.prepare(model, None) + self._check_pruner_prepared(model, pruner, device) + assert model(x).shape == (128, 10) + + def test_prepare_linear(self): + models = [ + SimpleLinear(), + LinearBias(), + LinearActivation(), + LinearActivationFunctional(), + ] # without and with bias + for device in DEVICES: + for model in models: + self._test_prepare_linear_on_device(model, torch.device(device)) + + def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device): + x = torch.ones((1, 1, 28, 28), device=device) + pruner = SimplePruner(None) + pruner.prepare(model, config) + self._check_pruner_prepared(model, pruner, device) + assert model(x).shape == expected_shape + + def test_prepare_conv2d(self): + models = [ + SimpleConv2d(), + Conv2dBias(), + Conv2dActivation(), + Conv2dPadBias(), + Conv2dPool(), + ] + shapes = [ + (1, 52, 20, 20), + (1, 52, 18, 18), + (1, 52, 18, 18), + (1, 52, 24, 24), + (1, 52, 3, 3), + ] + configs = [None, None, None, None, None] + for device in DEVICES: + for model, shape, config in zip(models, shapes, configs): + model = model.to(device) + self._test_prepare_conv2d_on_device( + model, shape, config, torch.device(device) + ) + + def _test_step_linear_on_device(self, model, device): + model = model.to(device) + x = torch.ones(7, 7, device=device) + pruner = SimplePruner(None) + pruner.prepare(model, None) + pruner.enable_mask_update = True + self._check_pruner_valid_before_step(model, pruner, device) + pruner.step() + self._check_pruner_valid_after_step(model, pruner, 1, device) + + def test_step_linear(self): + models = [ + SimpleLinear(), + LinearBias(), + LinearActivation(), + LinearActivationFunctional(), + ] + for device in DEVICES: + for model in models: + self._test_step_linear_on_device(model, torch.device(device)) + + def _test_step_conv2d_on_device(self, model, expected_shape, config, device): + model = model.to(device) + x = torch.ones((1, 1, 28, 28), device=device) + pruner = SimplePruner(None) + pruner.prepare(model, config) + pruner.enable_mask_update = True + self._check_pruner_valid_before_step(model, pruner, device) + pruner.step() + self._check_pruner_valid_after_step(model, pruner, 1, device) + assert model(x).shape == expected_shape + + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") + def test_step_conv2d(self): + models = [ + SimpleConv2d(), + Conv2dBias(), + Conv2dActivation(), + Conv2dPadBias(), + Conv2dPool(), + ] + shapes = [ + (1, 52, 20, 20), + (1, 52, 18, 18), + (1, 52, 18, 18), + (1, 52, 24, 24), + (1, 52, 3, 3), + ] + configs = [None, None, None, None, None] + for device in DEVICES: + for model, shape, config in zip(models, shapes, configs): + self._test_step_conv2d_on_device( + model, shape, config, torch.device(device) + ) + + def _check_pruner_pruned(self, model, pruner, device): + for config in pruner.groups: + module = config["module"] + assert not hasattr(module, "parametrizations") + assert not hasattr(module, "mask") + + def _test_linear_on_device( + self, model, config, expected_shape, device, also_prune_bias + ): + model = model.to(device) + model.eval() + num_original_params = sum(p.numel() for p in model.parameters()) + x = torch.ones(128, 7, device=device) + + pruner = ImplementedPruner({"prune_bias": also_prune_bias}) + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + + y_expected = model(x) + + assert y_expected.shape == (128, 10) + self._check_pruner_prepared(model, pruner, device) + + # Pruning step + pruned = pruner.prune() + y_pruned = pruned(x) + num_pruned_params = sum(p.numel() for p in pruned.parameters()) + + assert y_pruned.shape == expected_shape + self._check_pruner_pruned(model, pruner, device) + if y_pruned.shape == y_expected.shape: + assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all() + assert num_pruned_params < num_original_params + + def test_prune_linear_linear(self): + r"""test pruning linear-> linear modules""" + configs, shapes = [], [] + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "linear1.weight"}, + ] + ) + shapes.append((128, 10)) + + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + for device in DEVICES: + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_linear_on_device( + SimpleLinear(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_linear_bias_linear(self): + # linear(bias) -> linear(no bias) + configs, shapes = [], [] + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + ] + ) + shapes.append((128, 10)) + + # linear(bias) -> linear(bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.3.weight"}, + ] + ) + shapes.append((128, 10)) + + # linear(no bias) -> linear(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + + for device in DEVICES: + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_linear_on_device( + LinearBias(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_linear_activation_linear(self): + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.4.weight"}, + {"tensor_fqn": "linear1.weight"}, + ] + shape = (128, 10) + + for device in DEVICES: + for also_prune_bias in [True, False]: + # test version with nn.Modules + self._test_linear_on_device( + LinearActivation(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + # test functional version + self._test_linear_on_device( + LinearActivationFunctional(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def _test_conv2d_on_device( + self, model, config, x, expected_shape, device, also_prune_bias + ): + model = model.to(device) + num_original_params = sum(p.numel() for p in model.parameters()) + model.eval() + + pruner = ImplementedPruner({"prune_bias": also_prune_bias}) + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + + y_expected = model(x) + assert y_expected.shape == expected_shape + + self._check_pruner_prepared(model, pruner, device) + + # Fusion step + pruned = pruner.prune() + y_pruned = pruned(x) + num_pruned_params = sum(p.numel() for p in pruned.parameters()) + + assert y_pruned.shape == expected_shape + self._check_pruner_pruned(model, pruner, device) + if y_pruned.shape == y_expected.shape: + # TODO This rtol is a little high, need to double check if something specific is causing this to fail + assert torch.isclose( + y_expected, y_pruned, rtol=1e-1 + ).all(), f"fail for {type(model)}" + # only time this should be equal is when all layers have padding and we can't prune + assert num_pruned_params <= num_original_params + + def test_prune_conv2d_conv2d(self): + configs, shapes = [], [] + # all within sequential blocks + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + ] + ) + shapes.append((1, 52, 20, 20)) + # prune across sequential blocks + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 20, 20)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + SimpleConv2d(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_bias_conv2d(self): + # Conv2d with Bias and no Activation + configs, shapes = [], [] + # conv2d(bias) -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(no bias) -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dBias(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_activation_conv2d(self): + # Conv2d with Activation and no Bias + configs, shapes = [], [] + + # conv2d(no bias) -> activatation -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> activatation -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> activation -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(no bias) -> activation -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dActivation(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_padding_conv2d(self): + # Conv2d with Padded layers after Bias layers + configs, shapes = [], [] + + # conv(padded, bias) -> conv(padded, bias) + configs.append( + [ + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + # conv(no bias, no pad) -> conv(padded, bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + # conv(padded, bias) -> conv ( no bias ,no pad) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + # conv(pad, bias) -> conv(no pad, bias) + configs.append( + [ + {"tensor_fqn": "seq.6.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + # conv(no pad, bias) -> conv(pad, bias) + configs.append( + [ + {"tensor_fqn": "seq.8.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dPadBias(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_pool_conv2d(self): + # Conv2d with Pooling layers + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.3.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + {"tensor_fqn": "conv2d2.weight"}, + ] + shape = (1, 52, 3, 3) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + self._test_conv2d_on_device( + Conv2dPool(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") + def test_complex_conv2d(self): + """Test fusion for models that contain Conv2d & Linear modules. + Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add""" + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.3.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + {"tensor_fqn": "conv2d2.weight"}, + ] + shape = (1, 13) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + self._test_conv2d_on_device( + Conv2dPoolFlattenFunctional(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + self._test_conv2d_on_device( + Conv2dPoolFlatten(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp index 3d24749a96532..76aab44ac290d 100644 --- a/test/cpp/api/nn_utils.cpp +++ b/test/cpp/api/nn_utils.cpp @@ -92,7 +92,7 @@ TEST_F(NNUtilsTest, ClipGradNorm) { ASSERT_LE(norm_after, max_norm); auto scaled = compare_scaling(grads); ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7); - ASSERT_EQ(scaled[0].item().toFloat(), 1); + ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1); } // should accept a single tensor as input auto p1 = torch::randn({10, 10}); diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 0cf8ed88c4188..20d572853d3a1 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -257,6 +257,47 @@ TEST(SerializeTest, Basic) { ASSERT_TRUE(x.allclose(y)); } +TEST(SerializeTest, MathBits) { + torch::manual_seed(0); + + auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat); + auto x = torch::randn({5, 5}, options); + { + auto expected = torch::conj(x); + auto actual = save_and_load(expected); + + ASSERT_TRUE(actual.defined()); + ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); + ASSERT_TRUE(actual.allclose(expected)); + } + + { + auto expected = torch::_neg_view(x); + auto actual = save_and_load(expected); + + ASSERT_TRUE(actual.defined()); + ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); + ASSERT_TRUE(actual.allclose(expected)); + } + + { + auto expected = torch::conj(torch::_neg_view(x)); + auto actual = save_and_load(expected); + + ASSERT_TRUE(actual.defined()); + ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); + ASSERT_TRUE(actual.allclose(expected)); + } + + { + // We don't support serializing `ZeroTensor` as it is not public facing yet. + // If in future, `ZeroTensor` serialization is supported, this test should + // start failing! + auto t = torch::_efficientzerotensor({5, 5}); + ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,"); + } +} + TEST(SerializeTest, BasicToFile) { torch::manual_seed(0); diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 0d566344f2ced..083c4770e0ae3 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -355,7 +355,7 @@ void testAllreduce(const std::string& path, int rank, int size) { const auto* const data = tensor.data_ptr(); for (const auto k : c10::irange(tensor.numel())) { EXPECT_EQ(data[k], expected) - << "Allreduce ouputs do not match expected outputs"; + << "Allreduce outputs do not match expected outputs"; } } } diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 947b13897cf1d..b8b765a68d8b4 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -99,7 +99,9 @@ if(USE_CUDA) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp) - list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu.cpp) + list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu1.cpp) + list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu2.cpp) + list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu3.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp) @@ -107,7 +109,7 @@ if(USE_CUDA) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp) list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu) - list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp) + list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp) endif() add_executable(test_jit diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 63c6b70133062..16e690d99d8a1 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -222,7 +222,7 @@ struct ElementwiseInterpreter : torch::CustomClassHolder { } if (!output_name_) { - throw std::runtime_error("Output name not specififed!"); + throw std::runtime_error("Output name not specified!"); } return environment.at(*output_name_); diff --git a/test/cpp/jit/test_graph_executor.cpp b/test/cpp/jit/test_graph_executor.cpp index 6913e5f3ac2a8..acda804453f56 100644 --- a/test/cpp/jit/test_graph_executor.cpp +++ b/test/cpp/jit/test_graph_executor.cpp @@ -59,7 +59,7 @@ TEST(GraphExecutorTest, runAsync_executor) { mtx.lock(); ++asyncCounter; mtx.unlock(); - at::launch(move(f)); + at::launch(std::move(f)); }; std::vector stack; // NOLINTNEXTLINE(modernize-use-emplace) diff --git a/test/cpp/jit/test_jit_logging_levels.cpp b/test/cpp/jit/test_jit_logging_levels.cpp index ca2e8c5156e6d..6b92bf7d270ce 100644 --- a/test/cpp/jit/test_jit_logging_levels.cpp +++ b/test/cpp/jit/test_jit_logging_levels.cpp @@ -41,7 +41,15 @@ TEST(JitLoggingTest, CheckOutputStreamSetting) { ::torch::jit::set_jit_logging_levels("test_jit_logging_levels"); std::ostringstream test_stream; ::torch::jit::set_jit_logging_output_stream(test_stream); - JIT_LOG(::torch::jit::JitLoggingLevels::GRAPH_DUMP, "Message"); + /* Using JIT_LOG checks if this file has logging enabled with + is_enabled(__FILE__, level) making the test fail. since we are only testing + the OutputStreamSetting we can forcefully output to it directly. + */ + ::torch::jit::get_jit_logging_output_stream() << ::torch::jit::jit_log_prefix( + ::torch::jit::JitLoggingLevels::GRAPH_DUMP, + __FILE__, + __LINE__, + ::c10::str("Message")); ASSERT_TRUE(test_stream.str().size() > 0); } diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 930b26076bbb1..c45ca96383e9f 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -680,7 +679,6 @@ void backportAllVersionCheck( #if !defined FB_XPLAT_BUILD TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { - torch::jit::register_flatbuffer_all(); torch::jit::Module module("m"); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp index 10ba11dc1b4ae..311a818c4bfd0 100644 --- a/test/cpp/jit/test_lite_trainer.cpp +++ b/test/cpp/jit/test_lite_trainer.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -175,7 +174,6 @@ TEST(MobileTest, SaveParametersDefaultsToZip) { TEST(MobileTest, SaveParametersCanUseFlatbuffer) { // Save some empty parameters using flatbuffer. - register_flatbuffer_all(); std::map empty_parameters; std::stringstream ss_data; _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true); @@ -192,7 +190,6 @@ TEST(MobileTest, SaveParametersCanUseFlatbuffer) { TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) { // Create some simple parameters to save. - register_flatbuffer_all(); std::map input_params; input_params["four_by_ones"] = 4 * torch::ones({}); input_params["three_by_ones"] = 3 * torch::ones({}); diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 6e3283f62a5b8..3be0b8598b733 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -491,10 +491,20 @@ TEST(ControlFlowTest, Basic) { ASSERT_EQ(256, run_binary("while_test", 2, 0)); } +#if defined(__has_feature) +#if __has_feature(address_sanitizer) +#define HAS_ASANUBSAN 1 +#endif +#endif + +#ifndef HAS_ASANUBSAN +// This test fails vptr UBSAN checks + TEST(ProtoTest, Basic) { ::ONNX_NAMESPACE::ModelProto proto; proto.set_producer_name("foo"); } +#endif // test a few features that are not directly used in schemas yet TEST(SchemaParserTest, NestedArrays) { @@ -1447,35 +1457,29 @@ TEST(TestSymInt, AddSymbolicInt) { } #ifndef C10_MOBILE -TEST(TestSymInt, TestIntrusive) { - auto a = c10::make_intrusive(); - auto b = c10::make_intrusive(); - ASSERT_EQ(a.use_count(), 1); - ASSERT_EQ(b.use_count(), 1); - auto as = a->toSymInt(); - auto bs = b->toSymInt(); - ASSERT_EQ(a.use_count(), 2); - ASSERT_EQ(b.use_count(), 2); - as = bs; - ASSERT_EQ(a.use_count(), 1); - ASSERT_EQ(b.use_count(), 3); -} - -class TestSymIntNodeImpl : public c10::SymIntNodeImpl { +class TestSymNodeImpl : public c10::SymNodeImpl { public: - TestSymIntNodeImpl(int64_t i) : i_(i) {} + explicit TestSymNodeImpl(int64_t i) : i_(i) {} + + bool is_int() override { + return true; + }; + + bool is_float() override { + return false; + }; bool bool_() override { return static_cast(i_); }; -#define OPDEF3(NAME, OP, RET) \ - RET NAME(const c10::SymIntNode& other) override { \ - return make_intrusive( \ - this->i_ OP dynamic_cast(other.get())->i_); \ +#define OPDEF3(NAME, OP, RET) \ + RET NAME(const c10::SymNode& other) override { \ + return make_intrusive( \ + this->i_ OP dynamic_cast(other.get())->i_); \ } -#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymIntNode) +#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode) OPDEF2(add, +) OPDEF2(sub, -) OPDEF2(mul, *) @@ -1494,17 +1498,19 @@ class TestSymIntNodeImpl : public c10::SymIntNodeImpl { int64_t i_; }; -TEST(TestSymInt, TestSymIntToSymIntNodeDispatch) { +TEST(TestSymInt, TestSymIntToSymNodeDispatch) { auto get = [](c10::SymInt si) { - auto node = si.toSymIntNodeImpl(); - return dynamic_cast(node.get())->i_; + auto node = si.toSymNodeImpl(); + return dynamic_cast(node.get())->i_; }; std::vector inputs{0, 1, -1, 4, -4, 777, -777}; for (auto i : inputs) { for (auto j : inputs) { - auto a = c10::make_intrusive(i)->toSymInt(); - auto b = c10::make_intrusive(j)->toSymInt(); + auto a = c10::SymInt( + static_cast(c10::make_intrusive(i))); + auto b = c10::SymInt( + static_cast(c10::make_intrusive(j))); ASSERT_EQ(get(a + b), i + j); ASSERT_EQ(get(a - b), i - j); ASSERT_EQ(get(a * b), i * j); diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp index adaad24203c95..f5535eb64c8ed 100644 --- a/test/cpp/jit/test_module_api.cpp +++ b/test/cpp/jit/test_module_api.cpp @@ -66,7 +66,7 @@ TEST(ModuleAPITest, MethodRunAsync) { mtx.lock(); ++counter; mtx.unlock(); - at::launch(move(f)); + at::launch(std::move(f)); }; auto method = m.get_method("forward"); diff --git a/test/cpp/lazy/test_ir_util.cpp b/test/cpp/lazy/test_ir_util.cpp index 2befb04236ab5..0b2bfc7614b10 100644 --- a/test/cpp/lazy/test_ir_util.cpp +++ b/test/cpp/lazy/test_ir_util.cpp @@ -52,7 +52,7 @@ TEST(IrUtilTest, BasicTest) { dynamic_cast(b.get())->AddOperand(Value(d, 0)); dynamic_cast(c.get())->AddOperand(Value(d, 0)); - std::vector postorder = Util::ComputePostOrder({a.get()}); + auto postorder = Util::ComputePostOrder({a.get()}); EXPECT_EQ(postorder.size(), 4); EXPECT_EQ(postorder.at(0), d.get()); EXPECT_EQ(postorder.at(1), c.get()); diff --git a/test/cpp/lite_interpreter_runtime/resources.h b/test/cpp/lite_interpreter_runtime/resources.h new file mode 100644 index 0000000000000..0be5928b299ba --- /dev/null +++ b/test/cpp/lite_interpreter_runtime/resources.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace torch { +namespace testing { + +namespace detail { +class Path; +} + +/// Gets the path to the resource identified by name. +/// +/// @param name identifies a resource, relative path starting from the +/// repo root +inline auto getResourcePath(std::string name) -> detail::Path; + +// End interface: implementation details follow. + +namespace detail { + +class Path { + public: + explicit Path(std::string rep) : rep_(std::move(rep)) {} + + auto string() const -> std::string const& { + return rep_; + } + + private: + std::string rep_; +}; + +} // namespace detail + +inline auto getResourcePath(std::string name) -> detail::Path { + return detail::Path(std::move(name)); +} + +} // namespace testing +} // namespace torch diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index 867b775c1adb4..df9cb9cea28c6 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -9,6 +9,10 @@ #include +#include + +#include "test/cpp/lite_interpreter_runtime/resources.h" + #ifdef EDGE_PROFILER_USE_KINETO namespace torch { namespace jit { @@ -25,7 +29,10 @@ bool checkMetaData( if (line.find(op_name) != std::string::npos) { while (std::getline(trace_file, line)) { if (line.find(metadata_name) != std::string::npos) { - if (line.find(metadata_val) != std::string::npos) { + if (line.find(metadata_val) != std::string::npos || + !metadata_val.size()) { + /* if found the right metadata_val OR if expected + * metadata value is an empty string then ignore the matadata_val */ return true; } } @@ -37,16 +44,15 @@ bool checkMetaData( } // namespace TEST(MobileProfiler, ModuleHierarchy) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("to_be_profiled_module.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/to_be_profiled_module.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { KinetoEdgeCPUProfiler profiler( bc, @@ -90,16 +96,15 @@ TEST(MobileProfiler, ModuleHierarchy) { } TEST(MobileProfiler, Backend) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { KinetoEdgeCPUProfiler profiler( bc, @@ -125,16 +130,15 @@ TEST(MobileProfiler, Backend) { } TEST(MobileProfiler, BackendMemoryEvents) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { mobile::KinetoEdgeCPUProfiler profiler( bc, @@ -157,6 +161,51 @@ TEST(MobileProfiler, BackendMemoryEvents) { ASSERT_TRUE(checkMetaData("[memory]", metadata_name, "49152", trace_file)); } +TEST(MobileProfiler, ProfilerEvent) { + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); + + std::vector inputs; + inputs.emplace_back(at::rand({64, 64})); + inputs.emplace_back(at::rand({64, 64})); + std::string trace_file_name("/tmp/test_trace_profiler_event.trace"); + + std::vector events( + torch::profiler::ProfilerPerfEvents.begin(), + torch::profiler::ProfilerPerfEvents.end()); + + mobile::Module bc = _load_for_mobile(testModelFile.string()); + { + // Bail if something goes wrong here + try { + KinetoEdgeCPUProfiler profiler( + bc, + trace_file_name, + false, // record input_shapes + false, // profile memory + true, // record callstack + false, // record flops + true, // record module hierarchy + events); // performance events + bc.forward(inputs); + } catch (...) { + return; + } + } // End of profiler + std::ifstream trace_file(trace_file_name); + std::string line; + ASSERT_TRUE(trace_file.is_open()); + + for (auto& event : events) { + trace_file.seekg(0, std::ios_base::beg); + /* + * Just checking if the event entry exists in the chrometrace. + * Checking the value in a hardware independent matter is tricky. + */ + ASSERT_TRUE(checkMetaData("aten::__getitem__", event, "", trace_file)); + } +} + } // namespace mobile } // namespace jit } // namespace torch diff --git a/test/cpp/profiler/perf_events.cpp b/test/cpp/profiler/perf_events.cpp new file mode 100644 index 0000000000000..7740f42da4b52 --- /dev/null +++ b/test/cpp/profiler/perf_events.cpp @@ -0,0 +1,248 @@ + +#include + +#include +#include + +double calc_pi() { + volatile double pi = 1.0; + for (int i = 3; i < 100000; i += 2) { + pi += (((i + 1) >> 1) % 2) ? 1.0 / i : -1.0 / i; + } + return pi * 4.0; +} + +TEST(ProfilerTest, LinuxPerf) { + torch::profiler::impl::linux_perf::PerfProfiler profiler; + + std::vector standard_events( + std::begin(torch::profiler::ProfilerPerfEvents), + std::end(torch::profiler::ProfilerPerfEvents)); + torch::profiler::perf_counters_t counters; + counters.resize(standard_events.size(), 0); + + // Use try..catch HACK to check TORCH_CHECK because we don't yet fail + // gracefully if the syscall were to fail + try { + profiler.Configure(standard_events); + + profiler.Enable(); + auto pi = calc_pi(); + profiler.Disable(counters); + } catch (const c10::Error&) { + // Bail here if something bad happened during the profiling, we don't want + // to make the test fail + return; + } catch (...) { + // something else went wrong - this should be reported + ASSERT_EQ(0, 1); + } + + // Should have counted something if worked, so lets test that + // And if it not supported the counters should be zeros. +#if defined(__ANDROID__) || defined(__linux__) + for (auto counter : counters) { + ASSERT_GT(counter, 0); + } +#else /* __ANDROID__ || __linux__ */ + for (auto counter : counters) { + ASSERT_EQ(counter, 0); + } +#endif /* __ANDROID__ || __linux__ */ +} + +TEST(ProfilerTest, LinuxPerfNestedDepth) { + torch::profiler::impl::linux_perf::PerfProfiler profiler; + + // Only monotonically increasing events will work + std::vector standard_events( + std::begin(torch::profiler::ProfilerPerfEvents), + std::end(torch::profiler::ProfilerPerfEvents)); + + torch::profiler::perf_counters_t counters_A; + torch::profiler::perf_counters_t counters_B; + torch::profiler::perf_counters_t counters_C; + + counters_A.resize(standard_events.size(), 0); + counters_B.resize(standard_events.size(), 0); + counters_C.resize(standard_events.size(), 0); + + // Use try..catch HACK to check TORCH_CHECK because we don't yet fail + // gracefully if the syscall were to fail + try { + profiler.Configure(standard_events); + + // * = work kernel calc_pi() + // + // A --*---+ +--*-- A + // | | + // | | + // B +-*--+ +--*-+ B + // | | + // | | + // C +-*--+ C + // + + profiler.Enable(); + auto A = calc_pi(); + + profiler.Enable(); + auto B = calc_pi(); + + profiler.Enable(); + auto C = calc_pi(); + profiler.Disable(counters_C); + + auto B2 = calc_pi(); + profiler.Disable(counters_B); + + auto A2 = calc_pi(); + profiler.Disable(counters_A); + } catch (const c10::Error&) { + // Bail here if something bad happened during the profiling, we don't want + // to make the test fail + return; + } catch (...) { + // something else went wrong - this should be reported + ASSERT_EQ(0, 1); + } + +// for each counter, assert A > B > C +#if defined(__ANDROID__) || defined(__linux__) + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_GT(counters_A[i], counters_B[i]); + ASSERT_GT(counters_A[i], counters_C[i]); + ASSERT_GT(counters_B[i], counters_C[i]); + ASSERT_GT(counters_A[i], counters_B[i] + counters_C[i]); + } +#else /* __ANDROID__ || __linux__ */ + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_EQ(counters_A[i], 0); + ASSERT_EQ(counters_B[i], 0); + ASSERT_EQ(counters_C[i], 0); + } +#endif /* __ANDROID__ || __linux__ */ +} + +TEST(ProfilerTest, LinuxPerfNestedMultiple) { + torch::profiler::impl::linux_perf::PerfProfiler profiler; + + // Only monotonically increasing events will work + std::vector standard_events( + std::begin(torch::profiler::ProfilerPerfEvents), + std::end(torch::profiler::ProfilerPerfEvents)); + + torch::profiler::perf_counters_t counters_A; + torch::profiler::perf_counters_t counters_B; + torch::profiler::perf_counters_t counters_C; + + counters_A.resize(standard_events.size(), 0); + counters_B.resize(standard_events.size(), 0); + counters_C.resize(standard_events.size(), 0); + + // Use try..catch HACK to check TORCH_CHECK because we don't yet fail + // gracefully if the syscall were to fail + try { + profiler.Configure(standard_events); + + // * = work kernel calc_pi() + // + // A --*---+ +---*----+ +--*-- A + // | | | | + // | | | | + // B +-**-+ B C +-*--+ C + + profiler.Enable(); + auto A1 = calc_pi(); + + profiler.Enable(); + auto B1 = calc_pi(); + auto B2 = calc_pi(); + profiler.Disable(counters_B); + + auto A2 = calc_pi(); + + profiler.Enable(); + auto C1 = calc_pi(); + profiler.Disable(counters_C); + + auto A3 = calc_pi(); + profiler.Disable(counters_A); + } catch (const c10::Error&) { + // Bail here if something bad happened during the profiling, we don't want + // to make the test fail + return; + } catch (...) { + // something else went wrong - this should be reported + ASSERT_EQ(0, 1); + } + +// for each counter, assert A > B > C +#if defined(__ANDROID__) || defined(__linux__) + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_GT(counters_A[i], counters_B[i]); + ASSERT_GT(counters_A[i], counters_C[i]); + ASSERT_GT(counters_B[i], counters_C[i]); + ASSERT_GT(counters_A[i], counters_B[i] + counters_C[i]); + } +#else /* __ANDROID__ || __linux__ */ + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_EQ(counters_A[i], 0); + ASSERT_EQ(counters_B[i], 0); + ASSERT_EQ(counters_C[i], 0); + } +#endif /* __ANDROID__ || __linux__ */ +} + +TEST(ProfilerTest, LinuxPerfNestedSingle) { + torch::profiler::impl::linux_perf::PerfProfiler profiler; + + // Only monotonically increasing events will work + std::vector standard_events( + std::begin(torch::profiler::ProfilerPerfEvents), + std::end(torch::profiler::ProfilerPerfEvents)); + + torch::profiler::perf_counters_t counters_A; + torch::profiler::perf_counters_t counters_B; + torch::profiler::perf_counters_t counters_C; + + counters_A.resize(standard_events.size(), 0); + counters_B.resize(standard_events.size(), 0); + counters_C.resize(standard_events.size(), 0); + + // Use try..catch HACK to check TORCH_CHECK because we don't yet fail + // gracefully if the syscall were to fail + try { + profiler.Configure(standard_events); + + profiler.Enable(); + profiler.Enable(); + profiler.Enable(); + auto A1 = calc_pi(); + profiler.Disable(counters_C); + profiler.Disable(counters_B); + profiler.Disable(counters_A); + } catch (const c10::Error&) { + // Bail here if something bad happened during the profiling, we don't want + // to make the test fail + return; + } catch (...) { + // something else went wrong - this should be reported + ASSERT_EQ(0, 1); + } + +// for each counter, assert A > B > C +#if defined(__ANDROID__) || defined(__linux__) + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_GE(counters_A[i], counters_B[i]); + ASSERT_GE(counters_A[i], counters_C[i]); + ASSERT_GE(counters_B[i], counters_C[i]); + } +#else /* __ANDROID__ || __linux__ */ + for (auto i = 0; i < standard_events.size(); ++i) { + ASSERT_EQ(counters_A[i], 0); + ASSERT_EQ(counters_B[i], 0); + ASSERT_EQ(counters_C[i], 0); + } +#endif /* __ANDROID__ || __linux__ */ +} diff --git a/test/cpp_extensions/extension.cpp b/test/cpp_extensions/extension.cpp index c8772dc1b0ffe..37ed516ca99c2 100644 --- a/test/cpp_extensions/extension.cpp +++ b/test/cpp_extensions/extension.cpp @@ -27,6 +27,10 @@ bool function_taking_optional(c10::optional tensor) { return tensor.has_value(); } +torch::Tensor random_tensor() { + return torch::randn({1}); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)"); m.def( @@ -37,4 +41,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def(py::init()) .def("forward", &MatrixMultiplier::forward) .def("get", &MatrixMultiplier::get); + + m.def("get_complex", []() { return c10::complex(1.0, 2.0); }); + m.def("get_device", []() { return at::device_of(random_tensor()).value(); }); + m.def("get_generator", []() { return at::detail::getDefaultCPUGenerator(); }); + m.def("get_intarrayref", []() { return at::IntArrayRef({1, 2, 3}); }); + m.def("get_memory_format", []() { return c10::get_contiguous_memory_format(); }); + m.def("get_storage", []() { return random_tensor().storage(); }); + m.def("get_symfloat", []() { return c10::SymFloat(1.0); }); + m.def("get_symint", []() { return c10::SymInt(1); }); + m.def("get_symint_symbolic", []() { return c10::SymInt(c10::SymInt::UNCHECKED, INT64_MIN); }); + m.def("get_symintarrayref", []() { return at::SymIntArrayRef({1, 2, 3}); }); + m.def("get_tensor", []() { return random_tensor(); }); } diff --git a/test/cuda_results.yaml b/test/cuda_results.yaml new file mode 100644 index 0000000000000..1b7c131dc81f0 --- /dev/null +++ b/test/cuda_results.yaml @@ -0,0 +1,53 @@ +{ + nn.functional.conv_transpose2d: + [[[7.399066925048828, 4.4053635597229, -25.85348129272461, + 58.88909149169922, -88.75193786621094, -18.98126983642578, 9.437820434570312], + [-59.78305435180664, -65.34088134765625, -108.04747009277344, 196.6062469482422, + 71.39350891113281, 37.8786735534668, -69.55322265625], [92.78504943847656, + 91.24403381347656, -94.33301544189453, 9.261059761047363, -182.10206604003906, + 141.4270477294922, 146.89010620117188], [-14.363212585449219, 43.454036712646484, + -76.1098403930664, 242.9479522705078, 198.1458282470703, -49.77315139770508, + 5.891449451446533], [-43.56822967529297, 4.782844066619873, -29.526945114135742, + 65.15388488769531, 161.29757690429688, 118.60847473144531, 27.08570671081543], + [68.29853057861328, -11.507468223571777, 2.044086217880249, 11.003862380981445, + 34.993282318115234, -21.256723403930664, 91.49512481689453], [-70.4466781616211, + 69.04386138916016, 7.764842987060547, 7.61972713470459, -28.99899673461914, + 54.575748443603516, -5.762258052825928]], [[-36.238487243652344, 37.29551696777344, + -22.012331008911133, -30.1353702545166, 33.82851028442383, 33.00322341918945, + 2.7218000888824463], [-7.999058246612549, 122.72489929199219, -1.0639530420303345, + 2.9564287662506104, -143.1276092529297, -110.75650024414062, 48.0764274597168], + [-91.0599136352539, -11.656601905822754, 69.62447357177734, 88.12522888183594, + 337.3008728027344, -76.9416732788086, -110.24406433105469], [-108.1512451171875, + 98.42401123046875, 142.46144104003906, -127.48089599609375, -3.367496967315674, + 86.82833099365234, 86.29623413085938], [-14.339198112487793, -52.287410736083984, + 171.43614196777344, 200.14817810058594, 200.35476684570312, -189.4150390625, + -46.86980056762695], [30.196495056152344, 25.22877311706543, 95.29426574707031, + 4.455311298370361, 118.48747253417969, 87.11080932617188, -83.6124038696289], + [-2.5434072017669678, 91.8791732788086, -10.615175247192383, -12.58531379699707, + -49.3439826965332, 33.37324523925781, -5.983145713806152]], [[4.551003932952881, + 15.84842586517334, -46.354671478271484, 14.721636772155762, 39.01048278808594, + 49.70054244995117, -18.268564224243164], [16.728954315185547, 129.43505859375, + -4.6139116287231445, -3.382319688796997, -238.76353454589844, 13.42194938659668, + 40.393280029296875], [-2.335604429244995, -85.94283294677734, -142.2253875732422, + 135.27537536621094, 18.01512336730957, -26.331714630126953, -33.35443878173828], + [-79.17593383789062, -93.72674560546875, -110.94194030761719, -61.455223083496094, + 6.811624526977539, 129.06478881835938, 12.435402870178223], [10.859378814697266, + 41.3059196472168, 143.55824279785156, -41.754737854003906, -235.32406616210938, + -70.98460388183594, 130.46929931640625], [193.57574462890625, -142.5060272216797, + -102.45012664794922, 124.68048095703125, 136.05215454101562, -9.650590896606445, + -45.59521484375], [-37.829593658447266, 39.12519454956055, 9.293094635009766, + -18.8004093170166, -0.7294210195541382, 51.884910583496094, 36.15913391113281]], + [[-15.651233673095703, 16.31340980529785, -26.752052307128906, 6.281721115112305, + 43.765541076660156, -13.097319602966309, -30.443206787109375], [10.67841911315918, + 66.1829605102539, -9.394262313842773, -131.45101928710938, -38.621002197265625, + 65.9507064819336, 48.76960372924805], [-76.0918197631836, -9.108996391296387, + 13.64936637878418, 96.7411880493164, 124.2474365234375, -111.50318145751953, + -42.397071838378906], [-83.31562805175781, 32.27967071533203, 250.08163452148438, + 58.24131393432617, 129.95318603515625, -10.683560371398926, -123.84668731689453], + [-11.536887168884277, -15.220125198364258, 197.18821716308594, -31.680112838745117, + -81.35874938964844, 157.96974182128906, 105.61251831054688], [78.15926361083984, + -84.49744415283203, -73.91180419921875, 86.370361328125, 77.87918090820312, + 55.3555908203125, -7.273794651031494], [25.232547760009766, 30.352109909057617, + 53.722267150878906, 44.87421798706055, 44.618812561035156, 4.511796951293945, + 9.039834976196289]]] +} diff --git a/test/custom_backend/CMakeLists.txt b/test/custom_backend/CMakeLists.txt index 71f83442e085f..835f17850a842 100644 --- a/test/custom_backend/CMakeLists.txt +++ b/test/custom_backend/CMakeLists.txt @@ -9,9 +9,9 @@ endif() find_package(Torch REQUIRED) add_library(custom_backend SHARED custom_backend.cpp) -set_property(TARGET custom_backend PROPERTY CXX_STANDARD 14) +set_property(TARGET custom_backend PROPERTY CXX_STANDARD 17) target_link_libraries(custom_backend "${TORCH_LIBRARIES}") add_executable(test_custom_backend test_custom_backend.cpp) -set_property(TARGET test_custom_backend PROPERTY CXX_STANDARD 14) +set_property(TARGET test_custom_backend PROPERTY CXX_STANDARD 17) target_link_libraries(test_custom_backend custom_backend) diff --git a/test/custom_operator/CMakeLists.txt b/test/custom_operator/CMakeLists.txt index 47c1c9d45e814..6d1a4988fe382 100644 --- a/test/custom_operator/CMakeLists.txt +++ b/test/custom_operator/CMakeLists.txt @@ -9,12 +9,12 @@ endif() find_package(Torch REQUIRED) add_library(custom_ops SHARED op.cpp) -set_property(TARGET custom_ops PROPERTY CXX_STANDARD 14) +set_property(TARGET custom_ops PROPERTY CXX_STANDARD 17) target_compile_features(custom_ops PUBLIC cxx_range_for) target_link_libraries(custom_ops "${TORCH_LIBRARIES}") target_compile_definitions(custom_ops PRIVATE custom_ops_EXPORTS) add_executable(test_custom_ops test_custom_ops.cpp) -set_property(TARGET test_custom_ops PROPERTY CXX_STANDARD 14) +set_property(TARGET test_custom_ops PROPERTY CXX_STANDARD 17) target_link_libraries(test_custom_ops custom_ops) diff --git a/test/distributed/_composable/test_checkpoint.py b/test/distributed/_composable/test_checkpoint.py new file mode 100644 index 0000000000000..e2907bcb9fcb7 --- /dev/null +++ b/test/distributed/_composable/test_checkpoint.py @@ -0,0 +1,127 @@ +# Owner(s): ["oncall: distributed"] + +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) + +import torch +import torch.nn as nn +from torch.distributed._composable import checkpoint + +import unittest +from collections import deque +from contextlib import ContextDecorator +from copy import deepcopy + + +class MemoryDelta(ContextDecorator): + def __init__(self, device: torch.device): + self.device: torch.device = device + self.active_memory_enter: int = 0 + self.active_memory_exit: int = 0 + + def __enter__(self): + self.active_memory_enter = ( + torch.cuda.memory_stats()["active_bytes.all.current"] + if self.device.type == "cuda" + else 0 + ) + return self + + def __exit__(self, *exc): + self.active_memory_exit = ( + torch.cuda.memory_stats()["active_bytes.all.current"] + if self.device.type == "cuda" + else 0 + ) + + def delta(self) -> int: + return self.active_memory_exit - self.active_memory_enter + + +class ToyModel(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(100, 100) + self.seq = nn.Sequential( + nn.ReLU(), + nn.Linear(100, 100), + nn.ReLU(), + ) + + def forward(self, x): + return self.seq(self.l1(x)) + + +class TestCheckpoint(TestCase): + def _get_graph_size(self, out: torch.Tensor) -> int: + q = deque([out.grad_fn]) + num_functions = 0 + while len(q): + fn = q.pop() + num_functions += 1 + for next_fn, _ in fn.next_functions: + if next_fn: + q.append(next_fn) + + return num_functions + + def _test_tensor_only( + self, + net: nn.Module, + x: torch.Tensor, + use_reentrant: bool, + ) -> None: + x1 = x.clone() + x2 = x.clone() + x1.requires_grad = True + x2.requires_grad = True + + net1 = net + net2 = deepcopy(net) + + # no checkpoint + with MemoryDelta(x.device) as mem1: + loss1 = net1(x1).sum() + graph_size1 = self._get_graph_size(loss1) + loss1.backward() + + # with checkpoint + checkpoint(net2.seq, use_reentrant=use_reentrant) + with MemoryDelta(x.device) as mem2: + loss2 = net2(x2).sum() + graph_size2 = self._get_graph_size(loss2) + loss2.backward() + + if use_reentrant: + self.assertTrue(graph_size2 < graph_size1) + + if x.is_cuda: + self.assertTrue(mem2.delta() < mem1.delta()) + + for p1, p2 in zip(net1.parameters(), net2.parameters()): + self.assertEqual(p1.grad, p2.grad) + + @parametrize("use_reentrant", [True, False]) + def test_tensor_only_cpu(self, use_reentrant: bool): + x = torch.randn(20, 100) + net = ToyModel() + self._test_tensor_only(net, x, use_reentrant) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + @parametrize("use_reentrant", [True, False]) + def test_tensor_only_gpu(self, use_reentrant: bool): + x = torch.randn(20, 100, device="cuda:0") + net = ToyModel().to("cuda:0") + self._test_tensor_only(net, x, use_reentrant) + + +instantiate_parametrized_tests(TestCheckpoint) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_composable/test_compose.py b/test/distributed/_composable/test_compose.py new file mode 100644 index 0000000000000..20c285711d70d --- /dev/null +++ b/test/distributed/_composable/test_compose.py @@ -0,0 +1,179 @@ +# Owner(s): ["oncall: distributed"] + +import copy +import sys + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable import checkpoint, fully_shard +from torch.distributed.fsdp.api import ShardingStrategy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.testing._internal.common_dist_composable import ( + CompositeModel, + CompositeParamModel, + UnitModule, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_WITH_DEV_DBG_ASAN, +) + + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + + +class TestFSDPCheckpoint(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + def _test_parity( + self, + base_model: nn.Module, + test_model: nn.Module, + x: torch.Tensor, + grad_to_none: bool, + ): + LR = 0.01 + base_optim = torch.optim.Adam(base_model.parameters(), lr=LR) + test_optim = torch.optim.Adam(test_model.parameters(), lr=LR) + + for _ in range(5): + test_loss = test_model(x).sum() + base_loss = base_model(x).sum() + + self.assertEqual(test_loss, base_loss) + + test_loss.backward() + test_optim.step() + test_optim.zero_grad(set_to_none=grad_to_none) + + base_loss.backward() + base_optim.step() + base_optim.zero_grad(set_to_none=grad_to_none) + + @skip_if_lt_x_gpu(2) + @parametrize("use_reentrant", [True, False]) + def test_wrap_same_submodule(self, use_reentrant: bool): + model = UnitModule(device=torch.device("cuda")) + + base_model = copy.deepcopy(model) + + test_model = copy.deepcopy(model) + # compose checkpoint and fully_shard + test_model.seq = checkpoint(test_model.seq, use_reentrant=use_reentrant) + test_model.seq = fully_shard( + test_model.seq, + policy=ModuleWrapPolicy({nn.Linear}), + ) + + self.run_subtests( + { + "base_model": [base_model], + "test_model": [test_model], + "x": [torch.randn(2, 100, device="cuda")], + "grad_to_none": [True, False], + }, + self._test_parity, + ) + + def _test_checkpoint_fsdp_submodules(self, use_reentrant): + model = CompositeModel(device=torch.device("cuda")) + + base_model = copy.deepcopy(model) + + test_model = copy.deepcopy(model) + test_model.u1 = fully_shard(test_model.u1, policy=None) + test_model.u2 = fully_shard(test_model.u2) + + test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=use_reentrant) + test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=use_reentrant) + + self.run_subtests( + { + "base_model": [base_model], + "test_model": [test_model], + "x": [torch.randn(2, 100, device="cuda")], + "grad_to_none": [True, False], + }, + self._test_parity, + ) + + @skip_if_lt_x_gpu(2) + def test_checkpoint_fsdp_submodules_use_reentrant(self): + # Escape the brackets like `\[` since `[` has special meaning in regex + with self.assertRaisesRegex( + RuntimeError, + r"setStorage: sizes \[100, 100\], strides \[100, 1\], storage " + "offset 0, and itemsize 4 requiring a storage size of 40000 are " + "out of bounds for storage of size 0", + ): + self._test_checkpoint_fsdp_submodules(True) + + @skip_if_lt_x_gpu(2) + def test_checkpoint_fsdp_submodules_non_reentrant(self): + self._test_checkpoint_fsdp_submodules(False) + + @skip_if_lt_x_gpu(2) + def test_checkpoint_fsdp_submodules_with_param(self): + model = CompositeParamModel(device=torch.device("cuda")) + + base_model = copy.deepcopy(model) + + test_model = copy.deepcopy(model) + test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=False) + test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=False) + test_model = fully_shard(test_model) + + self.run_subtests( + { + "base_model": [base_model], + "test_model": [test_model], + "x": [torch.randn(2, 100, device="cuda")], + "grad_to_none": [True, False], + }, + self._test_parity, + ) + + @skip_if_lt_x_gpu(2) + def test_checkpoint_fsdp_submodules_with_param_no_shard(self): + model = CompositeParamModel(device=torch.device("cuda")) + + base_model = copy.deepcopy(model) + + test_model = copy.deepcopy(model) + test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=False) + test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=False) + test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD) + + self.run_subtests( + { + "base_model": [base_model], + "test_model": [test_model], + "x": [torch.randn(2, 100, device="cuda")], + "grad_to_none": [True, False], + }, + self._test_parity, + ) + + +instantiate_parametrized_tests(TestFSDPCheckpoint) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_composable/test_contract.py b/test/distributed/_composable/test_contract.py new file mode 100644 index 0000000000000..d510af6d7b2b0 --- /dev/null +++ b/test/distributed/_composable/test_contract.py @@ -0,0 +1,122 @@ +# Owner(s): ["oncall: distributed"] + +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + skipIfTorchDynamo, +) + +import torch +import torch.nn as nn +from torch.distributed._composable import contract + +from copy import deepcopy +from typing import Tuple + + +class ToyModel(nn.Module): + def __init__(self): + super().__init__() + self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) + self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) + self.p = nn.Parameter(torch.randn(10, 10), requires_grad=True) + self.b = torch.zeros(1) # buffer + + def forward(self, x, y): + with torch.no_grad(): + self.b += x.sum() + y.sum() + + return self.p + self.seq1(x) + self.seq2(y) + + +class TestContract(TestCase): + @skipIfTorchDynamo("Dynamo does not yet capture module hooks") + def test_add_hooks(self): + def forward_pre_hook( + module: nn.Module, inp: Tuple[torch.Tensor] + ) -> Tuple[torch.Tensor]: + return inp + + def forward_hook( + module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor + ) -> torch.Tensor: + return out + + def backward_pre_hook( + module: nn.Module, grad_output: torch.Tensor + ) -> torch.Tensor: + return grad_output + + def backward_hook( + module: nn.Module, + grad_input: Tuple[torch.Tensor], + grad_output: torch.Tensor, + ) -> Tuple[torch.Tensor]: + return grad_input + + @contract + def noop_api(module: nn.Module) -> nn.Module: + module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_hook(forward_hook) + module.register_full_backward_pre_hook(backward_pre_hook) + module.register_full_backward_hook(backward_hook) + return module + + model = ToyModel() + model_with_hooks = deepcopy(model) + noop_api(model.seq1) + noop_api(model.seq2) + + x, y = torch.randn(10, 10), torch.randn(10, 10) + model(x, y).sum().backward() + model_with_hooks(x, y).sum().backward() + + for p1, p2 in zip(model.parameters(), model_with_hooks.parameters()): + self.assertEqual(p1, p2) + + @skipIfTorchDynamo("Dynamo does not yet capture module hooks") + def test_modify_fqn(self): + class ModelWrapper(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, x): + return self.module(x) + + @contract + def wrap_module(module: nn.Module) -> nn.Module: + return ModelWrapper(module) + + model = ToyModel() + + with self.assertRaisesRegex(RuntimeError, "cannot modify FQNs"): + wrap_module(model.seq1) + + @skipIfTorchDynamo("Dynamo does not yet capture module hooks") + def test_state(self): + def check_and_update_state_hook( + module: nn.Module, inp: Tuple[torch.Tensor] + ) -> Tuple[torch.Tensor]: + self.assertEqual(api.state(module).dummy_state, 7) + api.state(module).dummy_state = 8 + return inp + + # FIXME: circular reference looks a bit weird. Shall we make .state a + # top-level API instead attached to contract API? + @contract + def api(module: nn.Module) -> nn.Module: + api.state(module).dummy_state = 7 + module.register_forward_pre_hook(check_and_update_state_hook) + return module + + model = ToyModel() + api(model.seq1) + + self.assertEqual(api.state(model.seq1).dummy_state, 7) + model(torch.zeros(10, 10), torch.zeros(10, 10)) + self.assertEqual(api.state(model.seq1).dummy_state, 8) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py new file mode 100644 index 0000000000000..71903a2f66544 --- /dev/null +++ b/test/distributed/_composable/test_fully_shard.py @@ -0,0 +1,411 @@ +# Owner(s): ["oncall: distributed"] + +import contextlib +import copy +import functools +import sys +from typing import Callable, Iterable, List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable import fully_shard +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp._common_utils import ( + _all_handles, + _FSDPState, + _is_fsdp_flattened, +) +from torch.distributed.fsdp.flat_param import _HandlesKey, FlatParamHandle +from torch.distributed.fsdp.wrap import _FSDPPolicy, ModuleWrapPolicy +from torch.testing._internal.common_dist_composable import ( + CompositeParamModel, + UnitModule, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + + +class TestFSDPInitialization(FSDPTest): + """Tests composable FSDP initialization.""" + + @property + def world_size(self) -> int: + return 2 + + @skip_if_lt_x_gpu(2) + def test_policy(self): + """Tests passing a ``policy`` for pseudo-auto-wrapping.""" + self.run_subtests( + {"policy": [None, ModuleWrapPolicy({UnitModule})]}, + self._test_policy, + ) + + def _test_policy(self, policy: Optional[_FSDPPolicy]): + local_model = CompositeParamModel(torch.device("cuda")) + fsdp_wrapped_model = FSDP( + copy.deepcopy(local_model), + auto_wrap_policy=policy, + use_orig_params=True, + ) + composable_module = copy.deepcopy(local_model) + fully_shard( + composable_module, + policy=policy, + ) + + # Check that the composable module has the same names as the local + # model and the same sharded parameters as the FSDP-wrapped model + for ( + (local_name, _), + (composable_name, composable_param), + (_, fsdp_wrapped_param), + ) in zip( + local_model.named_parameters(), + composable_module.named_parameters(), + fsdp_wrapped_model.named_parameters(), + ): + self.assertEqual(local_name, composable_name) + self.assertEqual(fsdp_wrapped_param, composable_param) + + # Check that the composable module has the same `FlatParameter` + # construction as the FSDP-wrapped model + composable_handles = fully_shard.state(composable_module)._handles + fsdp_wrapped_handles = FSDP._fsdp_handles(fsdp_wrapped_model) + self.assertEqual(len(composable_handles), len(fsdp_wrapped_handles)) + for (composable_handle, fsdp_wrapped_handle) in zip( + composable_handles, fsdp_wrapped_handles + ): + self.assertEqual( + composable_handle.flat_param.shape, fsdp_wrapped_handle.flat_param.shape + ) + + # Check that the composable module does not add any wrapper class + local_module_classes = set() + composable_module_classes = set() + for submodule in local_model.modules(): + local_module_classes.add(type(submodule)) + for submodule in composable_module.modules(): + composable_module_classes.add(type(submodule)) + self.assertEqual(local_module_classes, composable_module_classes) + + @skip_if_lt_x_gpu(2) + def test_device_id(self): + """Tests passing a ``device_id``.""" + cpu_device = torch.device("cpu") + composable_module = CompositeParamModel(device=cpu_device) + for param in composable_module.parameters(): + assert ( + param.device == cpu_device + ), "Expects module to be initialized on CPU for this unit test" + fully_shard( + composable_module, + policy=ModuleWrapPolicy({UnitModule}), + device_id=self.rank, + ) + for param in composable_module.parameters(): + self.assertEqual(param.device, torch.device("cuda", self.rank)) + + @skip_if_lt_x_gpu(2) + def test_sync_module_states(self): + """Tests passing ``sync_module_states=True``.""" + local_model = CompositeParamModel(device=torch.device("cuda")) + composable_module = copy.deepcopy(local_model) + # Check that the parameters are broadcast from rank 0 by comparing + # against an equivalent FSDP-wrapped module + if self.rank != 0: + for param in composable_module.parameters(): + with torch.no_grad(): + param.zero_() + policy = ModuleWrapPolicy({UnitModule}) + fsdp_wrapped_model = FSDP( + copy.deepcopy(local_model), + auto_wrap_policy=policy, + use_orig_params=True, + ) + fully_shard( + composable_module, + policy=policy, + sync_module_states=True, + ) + for (composable_param, fsdp_wrapped_param) in zip( + composable_module.parameters(), + fsdp_wrapped_model.parameters(), + ): + self.assertEqual(composable_param, fsdp_wrapped_param) + + @skip_if_lt_x_gpu(2) + def test_materialize_meta_module(self): + """Tests materializing a meta-device module.""" + + def _param_init_fn(module: nn.Module): + """ + This is an example ``param_init_fn`` for composable FSDP. + + TODO: This function is not satisfactory because: + (1) This requires guarding with ``_is_fsdp_flattened()``. This + guard is needed to avoid re-initializing parameters for nested + cases since some initialization methods strictly require non-1D + shape (e.g. ``kaiming_uniform_()``), while FSDP replaces the + original parameters with their 1D shards. + (2) This requires module-by-module traversal and manual ``setattr`` + usage as opposed to first calling ``module.to_empty()`` and then + initializing each parameter after. The latter will override the + initialization of already-initialized nested parameters. In other + words, this parameter initialization function must strictly modify + only the parameters on meta device. + """ + torch.manual_seed(0) + for submodule in module.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if not _is_fsdp_flattened(param) and param.is_meta: + materialized_param = nn.Parameter( + torch.empty_like(param, device=torch.device("cuda")) + ) + nn.init.uniform_(materialized_param) + setattr(submodule, param_name, materialized_param) + + composable_module = CompositeParamModel(device=torch.device("meta")) + meta_model = CompositeParamModel(device=torch.device("meta")) + fsdp_wrapped_model = FSDP( + meta_model, + auto_wrap_policy=ModuleWrapPolicy({UnitModule}), + param_init_fn=_param_init_fn, + use_orig_params=True, + ) + fully_shard( + composable_module, + policy=ModuleWrapPolicy({UnitModule}), + param_init_fn=_param_init_fn, + ) + for ( + (composable_param_name, composable_param), + (fsdp_wrapped_param_name, fsdp_wrapped_param), + ) in zip( + composable_module.named_parameters(), + fsdp_wrapped_model.named_parameters(), + ): + self.assertEqual(composable_param_name, fsdp_wrapped_param_name) + self.assertEqual( + composable_param.device, + torch.device("cuda", torch.cuda.current_device()), + ) + self.assertEqual(composable_param, fsdp_wrapped_param) + + +class TestFSDPRuntime(FSDPTest): + """Tests composable FSDP runtime.""" + + @property + def world_size(self) -> int: + return 2 + + def _init_models_and_optims( + self, device: torch.device + ) -> Tuple[nn.Module, torch.optim.Optimizer, nn.Module, torch.optim.Optimizer]: + local_model = CompositeParamModel(device=device) + fsdp_wrapped_model = FSDP( + copy.deepcopy(local_model), + auto_wrap_policy=ModuleWrapPolicy({UnitModule}), + use_orig_params=True, + ) + composable_module = copy.deepcopy(local_model) + fully_shard( + composable_module, + policy=ModuleWrapPolicy({UnitModule}), + ) + LR = 1e-2 + fsdp_wrapped_optim = torch.optim.Adam(fsdp_wrapped_model.parameters(), lr=LR) + composable_optim = torch.optim.Adam(composable_module.parameters(), lr=LR) + return ( + composable_module, + composable_optim, + fsdp_wrapped_model, + fsdp_wrapped_optim, + ) + + @skip_if_lt_x_gpu(2) + def test_training(self): + """Tests training (forward, backward, optimizer).""" + device = torch.device("cuda") + ( + composable_module, + composable_optim, + fsdp_wrapped_model, + fsdp_wrapped_optim, + ) = self._init_models_and_optims(device) + for _ in range(5): + inp = torch.randn(2, 100, device="cuda") + losses: List[torch.Tensor] = [] + for model, optim in ( + (fsdp_wrapped_model, fsdp_wrapped_optim), + (composable_module, composable_optim), + ): + optim.zero_grad(set_to_none=True) + out = model(inp) + loss = out.sum() + losses.append(loss) + loss.backward() + optim.step() + self.assertEqual(losses[0], losses[1]) + + @skip_if_lt_x_gpu(2) + def test_unshard_reshard_order(self): + """ + Tests that the unshard/reshard order matches between ``fully_shard`` + and ``FullyShardedDataParallel`` for the same policy. + + NOTE: We use FQNs as the proxy for checking the order across the two + versions. See ``_check_same_param_handles()`` for details. + """ + device = torch.device("cuda") + ( + composable_module, + composable_optim, + fsdp_wrapped_model, + fsdp_wrapped_optim, + ) = self._init_models_and_optims(device) + # Before checking the unshard/reshard order, sanity check that the + # assumption about wrapper FQN being a suffix of composable FQN holds + all_composable_handles = _all_handles(fully_shard.state(composable_module)) + all_wrapped_handles = _all_handles(fsdp_wrapped_model) + self._check_same_param_handles(all_composable_handles, all_wrapped_handles) + num_handles = len(all_composable_handles) + + orig_unshard = torch.distributed.fsdp._runtime_utils._unshard + orig_reshard = torch.distributed.fsdp._runtime_utils._reshard + UnshardReshardEvent = Tuple[str, _HandlesKey] + + def patched_unshard( + unshard_reshard_order: List[UnshardReshardEvent], + state: _FSDPState, + handles: List[FlatParamHandle], + *args, + **kwargs, + ): + handles_key = tuple(handles) + unshard_reshard_order.append(("unshard", handles_key)) + return orig_unshard(state, handles, *args, **kwargs) + + def patched_reshard( + unshard_reshard_order: List[UnshardReshardEvent], + state: _FSDPState, + handles: List[FlatParamHandle], + *args, + **kwargs, + ): + handles_key = tuple(handles) + unshard_reshard_order.append(("reshard", handles_key)) + return orig_reshard(state, handles, *args, **kwargs) + + @contextlib.contextmanager + def patch_unshard(_patched_unshard: Callable): + _orig_unshard = torch.distributed.fsdp._runtime_utils._unshard + torch.distributed.fsdp._runtime_utils._unshard = _patched_unshard + try: + yield + finally: + torch.distributed.fsdp._runtime_utils._unshard = _orig_unshard + + @contextlib.contextmanager + def patch_reshard(_patched_reshard: Callable): + _orig_reshard = torch.distributed.fsdp._runtime_utils._reshard + torch.distributed.fsdp._runtime_utils._reshard = _patched_reshard + try: + yield + finally: + torch.distributed.fsdp._runtime_utils._unshard = _orig_reshard + + composable_order: List[UnshardReshardEvent] = [] + wrapped_order: List[UnshardReshardEvent] = [] + + inp = torch.randn(2, 100, device="cuda") + losses: List[torch.Tensor] = [] + + for order, model, optim in ( + (composable_order, composable_module, composable_optim), + (wrapped_order, fsdp_wrapped_model, fsdp_wrapped_optim), + ): + with patch_unshard( + functools.partial(patched_unshard, order) + ), patch_reshard(functools.partial(patched_reshard, order)): + optim.zero_grad(set_to_none=True) + out = model(inp) + loss = out.sum() + losses.append(loss) + loss.backward() + optim.step() + self.assertEqual(losses[0], losses[1]) + + # Sanity check that the unshard/reshard events were recorded, where we + # expect one unshard/reshard pair for forward, one pair for backward, + # and possibly some extra unshards from backward prefetching (in this + # case, we expect exactly 2 extra since there are 3 handles) + self.assertGreaterEqual(len(composable_order), 2 * 2 * num_handles) + self.assertGreaterEqual(len(wrapped_order), 2 * 2 * num_handles) + self.assertGreaterEqual( + len([e for e in composable_order if e[0] == "unshard"]), 2 * num_handles + ) + self.assertGreaterEqual( + len([e for e in wrapped_order if e[0] == "unshard"]), 2 * num_handles + ) + self.assertGreaterEqual( + len([e for e in composable_order if e[0] == "reshard"]), 2 * num_handles + ) + self.assertGreaterEqual( + len([e for e in wrapped_order if e[0] == "reshard"]), 2 * num_handles + ) + + # Check that the unshard/reshard order matches + self.assertEqual(len(composable_order), len(wrapped_order)) + for ( + (composable_event, composable_handles_key), + (wrapped_event, wrapped_handles_key), + ) in zip(composable_order, wrapped_order): + self.assertEqual(composable_event, wrapped_event) + self._check_same_param_handles(composable_handles_key, wrapped_handles_key) + + def _check_same_param_handles( + self, + composable_handles: Iterable[FlatParamHandle], + wrapped_handles: Iterable[FlatParamHandle], + ) -> None: + """ + Checks that ``composable_handles`` matches ``wrapped_handles`` by + checking FQNs. + + For ``fully_shard``, each ``FlatParamHandle`` 's saved FQNs are + prefixed from the local FSDP root, while for wrapper FSDP, they are + prefixed from its owning FSDP instance, which may not be the local FSDP + root. Thus, we relax the check to only that the wrapper FQN is a suffix + of the composable FQN. + + If this check passes for the entire model and we separately unit-test + parity for wrapping policies, then we can be sure that the handles + actually match. + """ + self.assertEqual(len(composable_handles), len(wrapped_handles)) + for composable_handle, wrapped_handle in zip( + composable_handles, wrapped_handles + ): + composable_fqns = composable_handle.flat_param._fqns + wrapped_fqns = wrapped_handle.flat_param._fqns + self.assertEqual(len(composable_fqns), len(wrapped_fqns)) + for composable_fqn, wrapped_fqn in zip(composable_fqns, wrapped_fqns): + self.assertTrue(composable_fqn.endswith(wrapped_fqn)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py new file mode 100644 index 0000000000000..de9fbfdbbc376 --- /dev/null +++ b/test/distributed/_composable/test_replicate.py @@ -0,0 +1,114 @@ +# Owner(s): ["oncall: distributed"] + +import os +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn +from torch.distributed._composable.replicate import replicate +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import run_tests + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 50, bias=False) + self.fc3 = nn.Linear(50, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return F.softmax(x, dim=1) + + +class ReplicateTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + def _compare_module(self, mod, replicate_mod): + dist.init_process_group( + backend="gloo", + rank=self.rank, + world_size=self.world_size, + store=dist.FileStore(self.file_name, self.world_size), + ) + + local_batch_size = 1 + global_batch_size = self.world_size * local_batch_size + input = torch.randn(global_batch_size, 2) + target = torch.randn(global_batch_size, 4) + + def step_model(model, input, target): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + loss.backward() + for param in model.parameters(): + with torch.no_grad(): + param -= param.grad + param.grad = None + + for iteration in range(2): + step_model(mod, input, target) + step_model( + replicate_mod, + input[ + self.rank + * local_batch_size : (self.rank + 1) + * local_batch_size + ], + target[ + self.rank + * local_batch_size : (self.rank + 1) + * local_batch_size + ], + ) + + self.assertEqual( + len(list(mod.parameters())), + len(list(replicate_mod.parameters())), + ) + for i, j in zip(mod.parameters(), replicate_mod.parameters()): + self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5) + + # Shuffle the input so that DDP input is different + torch.manual_seed(iteration) + input = input[torch.randperm(global_batch_size)] + + def test_replicate_single_module(self): + model = Net() + replicate_model = replicate(deepcopy(model)) + self._compare_module(model, replicate_model) + + def test_replicate_multi_module(self): + model = Net() + replicate_model = deepcopy(model) + replicate(replicate_model.fc1) + replicate(replicate_model.fc2) + replicate(replicate_model.fc3) + self._compare_module(model, replicate_model) + + def test_replicate_with_kwargs(self): + model = Net() + replicate_model = replicate( + deepcopy(model), bucket_cap_mb=1, gradient_as_bucket_view=True + ) + self._compare_module(model, replicate_model) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/README.md b/test/distributed/_tensor/README.md new file mode 100644 index 0000000000000..6235f9657d5fe --- /dev/null +++ b/test/distributed/_tensor/README.md @@ -0,0 +1,11 @@ +## Run distributed tensor tests: + +from root, run (either CPU or GPU) + +`pytest test/spmd/tensor/test_tensor.py` + +`pytest test/spmd/tensor/test_ddp.py` + +run specific test case and print stdout/stderr: + +`pytest test/spmd/tensor/test_tensor.py -s -k test_tensor_from_local` diff --git a/test/distributed/_tensor/__init__.py b/test/distributed/_tensor/__init__.py new file mode 100644 index 0000000000000..087882b22d1f0 --- /dev/null +++ b/test/distributed/_tensor/__init__.py @@ -0,0 +1 @@ +# shut up pylint diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py new file mode 100644 index 0000000000000..a4b5e84bce862 --- /dev/null +++ b/test/distributed/_tensor/test_api.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +import torch.nn as nn +from torch.distributed._tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class MyModel(nn.Module): + def __init__(self, n_features, n_layers, device): + super().__init__() + self.seq = nn.Sequential( + *[nn.Linear(n_features, n_features, device=device) for _ in range(n_layers)] + ) + + def forward(self, x): + return self.seq(x) + + def reset_parameters(self): + for m in self.seq: + m.reset_parameters() + + +class DTensorAPITest(DTensorTestBase): + @property + def world_size(self) -> int: + # hard code world size to 4 as we need to test + # at least with 2d mesh + return 4 + + @with_comms + def test_distribute_tensor(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + for requires_grad in [True, False]: + + tensor_to_shard = torch.randn( + 3 * self.world_size, 3, requires_grad=requires_grad + ) + dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3])) + local_tensor = dist_tensor.to_local() + self.assertEqual(local_tensor.size(), torch.Size([3, 3])) + if requires_grad: + self.assertTrue(dist_tensor.requires_grad) + self.assertTrue(dist_tensor.is_leaf) + + @with_comms + def test_distribute_tensor_errors(self): + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size).reshape(2, 2) + ) + tensor_shape = [3 * self.world_size, 3 * self.world_size] + tensor_to_distribute = torch.randn(*tensor_shape) + + with self.assertRaisesRegex(ValueError, "must have the same length"): + shard_spec = [Shard(0)] + distribute_tensor(tensor_to_distribute, device_mesh, shard_spec) + + spec = [Shard(0), Shard(1)] + dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec) + + with self.assertRaisesRegex(ValueError, "to a different device mesh"): + new_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + distribute_tensor(dtensor, new_mesh, [Shard(0)]) + + with self.assertRaisesRegex(ValueError, "to a different placements"): + new_spec = [Shard(0), Replicate()] + distribute_tensor(dtensor, device_mesh, new_spec) + + @with_comms + def test_distribute_tensor_uneven_sharding(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_sizes_and_shard_dims = [ + ((self.world_size * 3 + 1, 3, 3), 0), + ((self.world_size * 3 + 2, 3, 3), 0), + ((3, self.world_size * 3 + 1, 3), 1), + ((3, self.world_size * 3 + 2, 3), 1), + ((3, 3, self.world_size * 3 + 1), 2), + ((3, 3, self.world_size * 3 + 2), 2), + ] + for input_size, shard_dim in input_sizes_and_shard_dims: + shard_spec = [Shard(shard_dim)] + tensor_to_shard = torch.randn(input_size) + splitted_tensor_list = tensor_to_shard.tensor_split( + self.world_size, dim=shard_dim + ) + dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size(input_size)) + local_tensor = dist_tensor.to_local() + self.assertEqual(local_tensor, splitted_tensor_list[self.rank]) + + @with_comms + def test_distribute_module(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # fully shard all linear modules on dim 0 + module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type) + shard_spec = [Shard(0)] + + def shard_fn(name, module, device_mesh): + if isinstance(module, nn.Linear): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, shard_spec) + ) + module.register_parameter(name, dist_param) + + sharded_module = distribute_module(module_to_shard, device_mesh, shard_fn) + for param in sharded_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, shard_spec) + + replica_spec = [Replicate()] + # fully replicate all modules without passing in partition_fn + module_to_replicate = MyModel(5, 20, device=self.device_type) + replica_module = distribute_module(module_to_replicate, device_mesh) + for param in replica_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, replica_spec) + + # fully replicate all modules by passing in partition_fn + def replicate_fn(name, module, device_mesh): + if isinstance(module, nn.Linear): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, replica_spec) + ) + module.register_parameter(name, dist_param) + + module_to_replicate = MyModel(5, 20, device=self.device_type) + replica_module = distribute_module( + module_to_replicate, device_mesh, replicate_fn + ) + for param in replica_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, replica_spec) + + # only shard part of module, and rest of module should be replicate + def shard_fn(name, module, device_mesh): + if isinstance(module, nn.Linear) and (name == "seq.0" or name == "seq.8"): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, shard_spec) + ) + module.register_parameter(name, dist_param) + + module_to_distribute = MyModel(5 * self.world_size, 20, device=self.device_type) + dist_module = distribute_module(module_to_distribute, device_mesh, shard_fn) + for name, param in dist_module.named_parameters(): + self.assertIsInstance(param, DTensor) + if name.startswith("seq.0") or name.startswith("seq.8"): + self.assertEqual(param.placements, shard_spec) + else: + self.assertEqual(param.placements, replica_spec) + + @with_comms + def test_distribute_module_input_fn_output_fn(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + # fully replicate all linear modules + module_to_replicate = MyModel(20, 1, device=self.device_type) + + # mark input sharding on dim 0 + def input_fn(inputs, device_mesh): + return DTensor.from_local(inputs[0], device_mesh, [Shard(0)]) + + def output_fn(outputs, device_mesh): + assert isinstance(outputs, DTensor) + return outputs.to_local() + + replica_module = distribute_module( + module_to_replicate, + device_mesh, + input_fn=input_fn, + output_fn=output_fn, + ) + + input_tensor = torch.randn(5, 20, device=self.device_type) + local_out = replica_module(input_tensor) + self.assertIsInstance(local_out, torch.Tensor) + self.assertNotIsInstance(local_out, DTensor) + + # full replicate (even on inputs) + model = MyModel(10, 10, device=self.device_type) + + def replicate_input_fn(inputs, device_mesh): + return DTensor.from_local(inputs[0], device_mesh, [Replicate()]) + + replica_model = distribute_module( + model, + device_mesh, + input_fn=replicate_input_fn, + ) + input = torch.randn(10, 10, requires_grad=True) + output = replica_model(input) + output.sum().backward() + param_grad = list(replica_model.parameters())[0].grad + self.assertTrue(isinstance(param_grad, DTensor)) + self.assertTrue(isinstance(param_grad.placements[0], Replicate)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py new file mode 100644 index 0000000000000..fe89b6c4c40d7 --- /dev/null +++ b/test/distributed/_tensor/test_common_rules.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh +from torch.distributed._tensor.dispatch import OpSchema + +from torch.distributed._tensor.ops.common_rules import ( + einop_rule, + pointwise_rule, + reduction_rule, +) +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch._C import parse_schema + + +class CommonRulesTest(DTensorTestBase): + @property + def world_size(self) -> int: + # hard code world size to 4 as we need to test + # at least with 2d mesh + return 4 + + @with_comms + def test_einop_basic_propagation(self): + # plain einsum, mm + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor") + # propagate col-wise sharding + mat1, mat2 = [-1, -1], [-1, 0] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8])) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # propagate row-wise sharding + mat1, mat2 = [0, -1], [-1, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8])) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # generate partial + mat1, mat2 = [-1, 0], [0, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8])) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertTrue(output_spec.placements[0].is_partial()) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + @with_comms + def test_einop_pointwise_propagation(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + # addition + mat1 = [0, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 8])) + output_sharding = einop_rule( + "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat1_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # broadcast addition + mat1 = [-1, 0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4, 2]) + ) + mat2_spec = DTensorSpec.from_dim_map(mesh, [-1], [], shape=torch.Size([2])) + output_sharding = einop_rule( + "ijk,k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 4, 2])) + + # broadcast to a common shape + mat1_spec = DTensorSpec.from_dim_map( + mesh, [0, -1, -1], [], shape=torch.Size([8, 8, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, [-1, -1], [], shape=torch.Size([1, 8]) + ) + output_sharding = einop_rule( + "ijk,1k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8, 8])) + + @with_comms + def test_einop_merge_sharding(self): + # 2d mesh einop merge sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor") + + mat1, mat2 = [0, -1], [-1, 1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8])) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, 1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + @with_comms + def test_einop_linearity(self): + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + mm_func_schema = parse_schema( + "aten::mm(Tensor self, Tensor mat2) -> Tensor" + ) + + mat1, mat2 = [0, -1], [-1, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([4, 8])) + # if not turn on linearity, partial sum is not eligible to propagate, we return + # suggestion to reshard inputs with no partial sum (i.e. all_reduce one input) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(mm_func_schema, (mat1_spec, mat2_spec), {}) + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + suggested_spec = suggestions[0].args_schema[0] + self.assertFalse(suggested_spec.placements[1].is_partial()) + + # einop prop with linearity on mm, should give back suggestion + # on converting placements to partial + output_sharding = einop_rule( + "mk,kn->mn", + OpSchema(mm_func_schema, (mat1_spec, mat2_spec), {}), + linearity=True, + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + mat2_spec = suggestions[0].args_schema[1] + # mat2 mesh dim 1 should become partial now! + self.assertTrue(mat2_spec.placements[1].is_partial()) + + # einop prop with linearity on point-wise, should give back suggestion + # on converting placements to partial + add_func_schema = parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + mat1, mat2 = [0, -1], [0, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [1], shape=torch.Size([8, 6])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 6])) + + output_sharding = einop_rule( + "ij,ij->ij", + OpSchema(add_func_schema, (mat1_spec, mat2_spec), {}), + linearity=True, + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + mat2_spec = suggestions[0].args_schema[1] + # mat2 mesh dim 1 should become partial now! + self.assertTrue(mat2_spec.placements[1].is_partial()) + + @with_comms + def test_einop_multi_sharding_on_mesh_dim(self): + # einop prop with multi sharding on same mesh dim + mesh_shape = torch.arange(self.world_size) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = parse_schema("aten::mm(Tensor self, Tensor mat2) -> Tensor") + mat1, mat2 = [0, -1], [0, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 12])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([12, 4])) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the second + # arg by all_gather its tensor dim sharding + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) + + @with_comms + def test_einop_errors(self): + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + mat1, mat2 = [0, -1], [1, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([8, 4])) + + with self.assertRaisesRegex(RuntimeError, "sharded two different ways:"): + einop_rule("ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {})) + + @with_comms + def test_pointwise_rules_broadcasting(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = parse_schema( + "where.self(Tensor condition, Tensor self, Tensor other) -> Tensor" + ) + inp1, inp2, inp3 = [0], [], [-1, -1] + condition = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8])) + self_tensor = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([])) + other_tensor = DTensorSpec.from_dim_map( + mesh, inp3, [], shape=torch.Size([1, 1]) + ) + # propagate point-wise sharding with broadcasting + output_sharding = pointwise_rule( + OpSchema(func_schema, (condition, self_tensor, other_tensor), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + self.assertEqual(output_spec.shape, [1, 8]) + + @with_comms + def test_pointwise_rules_suggestion(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = parse_schema( + "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor" + ) + # propagate point-wise sharding + inp1, inp2 = [-1, -1], [-1, 0] + mat1_spec = DTensorSpec.from_dim_map(mesh, inp1, [], shape=torch.Size([8, 4])) + mat2_spec = DTensorSpec.from_dim_map(mesh, inp2, [], shape=torch.Size([8, 4])) + # adding a positional argument -1 to arg schema + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec, -1), {}) + ) + self.assertIsNone(output_sharding.output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion from pointwise rules still have + # the positional args that are not DTensorSpec + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(len(schema_suggestion.args_schema), 3) + self.assertEqual(schema_suggestion.args_schema[2], -1) + + @with_comms + def test_pointwise_multi_sharding_on_mesh_dim(self): + # 2d mesh pointwise sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + + # basic case to test implicit broadcasting shape alignment + mat1, mat2 = [-1, 0], [0] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([20, 6])) + mat2_spec = DTensorSpec.from_dim_map(mesh, mat2, [], shape=torch.Size([6])) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + + # more advanced case that needs reshard one input to align sharding + mat1, mat2 = [0, -1, -1, 1], [0, -1, 1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([12, 1, 1, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([12, 4, 8]) + ) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the first + # arg by all_gather first tensor dim sharding + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1]) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) + + @with_comms + def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): + # 2d mesh pointwise sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = parse_schema( + "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)" + ) + + # more advanced case that needs reshard one input to align sharding + mat1, mat2 = [0, -1, 1], [-1, -1, 0] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([12, 4, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([12, 1, 8]) + ) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the second + # arg as we should enforce the sharding of the first arg + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1) + + @with_comms + def test_reduction_rule(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = parse_schema( + "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor" + ) + # reduction on a 2d mat + mat1 = [0, -1] + mat1_spec = DTensorSpec.from_dim_map(mesh, mat1, [], shape=torch.Size([8, 4])) + # reduction on dim 0 + output_sharding_0 = reduction_rule( + OpSchema(func_schema, (mat1_spec, 0), {}), + dims=[0], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_0.output_spec) + self.assertEqual(output_sharding_0.output_spec.dim_map, [-1]) + # pending sum on dim 0 + self.assertEqual(output_sharding_0.output_spec.sums, [0]) + self.assertEqual(output_sharding_0.output_spec.shape, torch.Size([4])) + + # reduction on dim 1 + output_sharding_1 = reduction_rule( + OpSchema(func_schema, (mat1_spec, 1), {}), + dims=[1], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_1.output_spec) + self.assertEqual(output_sharding_1.output_spec.dim_map, [0]) + self.assertEqual(output_sharding_1.output_spec.sums, []) + self.assertEqual(output_sharding_1.output_spec.shape, torch.Size([8])) + + # full reduction if not specify dim + output_sharding_all_dim = reduction_rule( + OpSchema(func_schema, (mat1_spec,), {}), + dims=[0, 1], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_all_dim.output_spec) + self.assertEqual(output_sharding_all_dim.output_spec.dim_map, []) + # pending sum on mesh + self.assertEqual(output_sharding_all_dim.output_spec.sums, [0]) + self.assertEqual(output_sharding_all_dim.output_spec.shape, torch.Size([])) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py new file mode 100644 index 0000000000000..49013a8640a6e --- /dev/null +++ b/test/distributed/_tensor/test_device_mesh.py @@ -0,0 +1,495 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import Shard + +from torch.distributed.distributed_c10d import ( + get_global_rank, + get_world_size, + new_group, + ProcessGroup, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class DeviceMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_device_mesh_2d(self): + mesh_tensor = torch.arange(4).reshape(2, 2) + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < 2) + dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + current_rank_expected_group_ranks = ( + dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] + ) + self.assertEqual(global_ranks, current_rank_expected_group_ranks) + + @with_comms + def test_device_mesh_2d_from_dim_groups(self): + # construct a two dimension subgroups + dim_groups = [] + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim_group_ranks in expected_ranks_by_dim: + for subgroup_ranks in dim_group_ranks: + subgroup = new_group(ranks=subgroup_ranks) + if self.rank in subgroup_ranks: + dim_groups.append(subgroup) + + # construct a device mesh from the subgroups + mesh = DeviceMesh(self.device_type, [[0, 1], [2, 3]], dim_groups=dim_groups) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < 2) + dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + current_rank_expected_group_ranks = ( + dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] + ) + self.assertEqual(global_ranks, current_rank_expected_group_ranks) + + @with_comms + def test_device_mesh_dim_groups_error(self): + # construct a two dimension subgroups + dim_groups = [] + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim_group_ranks in expected_ranks_by_dim: + for subgroup_ranks in dim_group_ranks: + subgroup = new_group(ranks=subgroup_ranks) + if self.rank in subgroup_ranks: + dim_groups.append(subgroup) + + if len(dim_groups) > 0: + # dim_groups is not a list + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=dim_groups[0], + ) + + # dim_groups is a list, but not a list of ProcessGroup + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=[dim_groups[0], "dummy"], + ) + + # dim_groups has incorrect length + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=[dim_groups[0]], + ) + + @with_comms + def test_device_mesh_nd(self): + # construct a cuda device mesh + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < mesh_tensor.ndim) + dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2) + # print(dim_ranks) + # dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + for ranks in dim_ranks: + if self.rank in ranks: + self.assertEqual(global_ranks, ranks.tolist()) + + +class DeviceMeshCollectiveTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_all_reduce_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + mesh.all_reduce(local_tensor, mesh_dim=0) + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + self.assertEqual(local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_broadcast_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + mesh.broadcast(local_tensor, mesh_dim=0) + self.assertEqual(local_tensor, torch.zeros(3, 3)) + + @with_comms + def test_scatter_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + scatter_tensor_shape = [3, 3, 3] + for scatter_dim in range(len(scatter_tensor_shape)): + shard_placement = Shard(scatter_dim) + scatter_tensor_shape[scatter_dim] *= self.world_size + # make the random seed same across rank + torch.manual_seed(0) + global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type) + splitted_list, _ = shard_placement._split_tensor( + global_tensor, mesh.size(), with_padding=True, contiguous=True + ) + recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()]) + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.scatter(recv_tensor, splitted_list, mesh_dim=0) + self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) + + @with_comms + def test_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.randn(device_mesh.size() + 3, device_mesh.size() + 1) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_to_scatter = tensor_to_split.clone() + tensor_splitted_list = tensor_to_split.tensor_split( + device_mesh.size(), dim=shard_dim + ) + padded_tensor_list, pad_idx = shard_placement._split_tensor( + tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) + device_mesh.scatter(scattered_tensor, padded_tensor_list, mesh_dim=0) + # unpad scattered_tensor + if pad_idx != 0 and my_rank >= pad_idx: + scattered_tensor = shard_placement._unpad_tensor(scattered_tensor) + + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) + + @with_comms + def test_all_gather_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dims_to_gather = [0, 1] + for dim in dims_to_gather: + output_size = [3, 3] + output_size[dim] *= self.world_size + # each rank have its own tensor, all_gather gives a list + local_tensor = torch.ones(3, 3, device=self.device_type) + gathered_list = [] + for _ in range(self.world_size): + gathered_list.append(torch.zeros_like(local_tensor)) + mesh.all_gather(gathered_list, local_tensor, mesh_dim=0) + gathered_tensor = torch.cat(gathered_list, dim=dim) + self.assertEqual(gathered_tensor, torch.ones(output_size)) + + @with_comms + def test_all_gather_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_padded_list, pad_idx = shard_placement._split_tensor( + tensor_to_split, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + local_tensor = tensor_padded_list[my_rank] + gathered_list = [] + for _ in range(device_mesh.size()): + gathered_list.append(torch.empty_like(local_tensor)) + + device_mesh.all_gather( + gathered_list, + local_tensor, + mesh_dim=0, + ) + if pad_idx != 0: + gathered_list = [ + shard_placement._unpad_tensor(gathered_tensor) + if i >= pad_idx + else gathered_tensor + for i, gathered_tensor in enumerate(gathered_list) + ] + all_gathered_tensor = torch.cat(gathered_list, dim=shard_dim) + self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size()) + self.assertEqual(all_gathered_tensor, tensor_to_split) + + @with_comms + def test_reduce_scatter_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dims_to_scatter = [0, 1] + for dim in dims_to_scatter: + input_size = [3, 3] + scattered_tensor = torch.empty(input_size, device=self.device_type) + input_size[dim] *= self.world_size + + input_rs_list = ( + torch.ones(input_size, device=self.device_type) * self.rank + ).tensor_split(self.world_size, dim=dim) + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + mesh.reduce_scatter(scattered_tensor, input_rs_list, mesh_dim=0) + self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_reduce_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = ( + torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + * self.rank + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_to_scatter = tensor_to_split.clone() + tensor_splitted_list = tensor_to_split.tensor_split( + device_mesh.size(), dim=shard_dim + ) + padded_tensor_list, pad_idx = shard_placement._split_tensor( + tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) + device_mesh.reduce_scatter(scattered_tensor, padded_tensor_list, mesh_dim=0) + # unpad scattered_tensor + if pad_idx != 0 and my_rank >= pad_idx: + scattered_tensor = shard_placement._unpad_tensor(scattered_tensor) + + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual( + scattered_tensor, + torch.ones_like(tensor_splitted_list[my_rank]) * res_num, + ) + + @with_comms + def test_all_gather_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + gathered_tensor_list = list( + torch.empty( + (dim_group_size * 3, 3), device=self.device_type + ).tensor_split(dim_group_size, dim=0) + ) + mesh.all_gather(gathered_tensor_list, local_tensor, mesh_dim=dim) + gathered_tensor = torch.cat(gathered_tensor_list) + exp_tensor = torch.ones(3 * dim_group_size, 3) + for i in range(len(global_ranks)): + exp_tensor[i * 3 : (i + 1) * 3] = torch.ones(3, 3) * global_ranks[i] + self.assertEqual(gathered_tensor, exp_tensor) + + @with_comms + def test_reduce_scatter_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + local_rs_list = ( + torch.ones(dim_group_size * 3, 3, device=self.device_type) * self.rank + ).tensor_split(dim_group_size, dim=0) + scattered_tensor = torch.empty_like( + local_rs_list[mesh.get_coordinate_on_dim(dim)], + device=self.device_type, + ) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + mesh.reduce_scatter(scattered_tensor, local_rs_list, mesh_dim=dim) + res_num = torch.sum(torch.tensor(global_ranks)) + self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_all_reduce_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + cloned_local_tensor = local_tensor.clone() + mesh.all_reduce(cloned_local_tensor, mesh_dim=dim) + res_num = sum(global_ranks) + self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_broadcast_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + cloned_local_tensor = local_tensor.clone() + mesh.broadcast(cloned_local_tensor, mesh_dim=dim) + res_num = global_ranks[0] + self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_scatter_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + scattered_tensors = [ + torch.ones(3, 3, device=self.device_type) * global_rank + for global_rank in global_ranks + ] + received_tensor = torch.empty_like( + scattered_tensors[mesh.get_coordinate_on_dim(dim)] + ) + mesh.scatter(received_tensor, scattered_tensors, mesh_dim=dim) + self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) + + @with_comms + def test_all_to_all_1d(self): + # transpose on a 2D tensor distributed over N nodes: + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + tensor_shape = [3, 3] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (rank + self.rank * self.world_size) + for rank in range(self.world_size) + ] + expected_tensor_list = [ + torch.ones(tensor_shape, device=self.device_type) + * (self.rank + rank * self.world_size) # i.e. transpose + for rank in range(self.world_size) + ] + for scatter_dim in range(len(tensor_shape)): + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.all_to_all(output_tensor_list, input_tensor_list, mesh_dim=0) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) + + self.assertEqual(output_tensor, expected_tensor) + + @with_comms + def test_all_to_all_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + tensor_shape = [3, 3, 3] + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + my_coordinate = mesh.get_coordinate_on_dim(dim) + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (i + self.rank * dim_group_size) + for i in range(dim_group_size) + ] + expected_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (my_coordinate + global_rank * dim_group_size) # i.e. transpose + for global_rank in global_ranks + ] + for scatter_dim in range(len(tensor_shape)): + # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim) + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.all_to_all(output_tensor_list, input_tensor_list, mesh_dim=dim) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) + self.assertEqual(output_tensor, expected_tensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py new file mode 100644 index 0000000000000..8d29f4d3fea67 --- /dev/null +++ b/test/distributed/_tensor/test_dtensor.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard + +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class DTensorTest(DTensorTestBase): + # @with_comms + # def test_tensor_constructor(self): + # import torch.distributed._tensor as dist_tensor + # shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)]) + # empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec) + # zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec) + # one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec) + + # zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec) + + # dist_tensor.empty_like(empty_tensor) + # dist_tensor.zero_like(empty_tensor) + # dist_tensor.one_like(empty_tensor) + + @with_comms + def test_dtensor_constructor(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3, requires_grad=True) + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + dist_tensor = DTensor( + local_tensor, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3))) + + with self.assertWarnsRegex(UserWarning, "To construct"): + DTensor(local_tensor, device_mesh, shard_spec, size=dist_tensor_shape) + + local_tensor = torch.randn(3, 3, requires_grad=False) + with self.assertWarnsRegex(UserWarning, "To construct"): + dist_tensor = DTensor( + local_tensor, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + + @with_comms + def test_dtensor_stride(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard0_spec = [Shard(0)] + local_tensor = torch.randn(4, 8) + global_shape = torch.Size([self.world_size * 4, 8]) + dist_tensor = DTensor(local_tensor, device_mesh, shard0_spec, size=global_shape) + # won't affect stride + self.assertEqual(dist_tensor.stride(), (8, 1)) + + shard1_spec = [Shard(1)] + local_tensor = torch.randn(8, 4) + global_shape = torch.Size([8, self.world_size * 4]) + dist_tensor = DTensor(local_tensor, device_mesh, shard1_spec, size=global_shape) + # will affect stride after DT initialized + self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1)) + + # if initialized from a transposed mat + local_tensor = torch.randn(8, 4, 8) + local_tensor_t = local_tensor.permute(1, 2, 0) + global_shape = torch.Size([4, self.world_size * 8, 8]) + self.assertEqual(local_tensor_t.stride(), (8, 1, 32)) + dist_tensor = DTensor( + local_tensor_t, device_mesh, shard1_spec, size=global_shape + ) + global_stride = (8 * self.world_size, 1, 32 * self.world_size) + self.assertEqual(dist_tensor.stride(), global_stride) + + @with_comms + def test_from_local(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) + self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3])) + + replica_spec = [Replicate()] + ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec) + self.assertEqual(ddp_tensor.size(), local_tensor.size()) + + partial_spec = [_Partial()] + partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + + # test dist tensor works with torch.Tensor during backwards + local_tensor_with_grad = torch.randn(3, 3, requires_grad=True) + # do some operations on local tensor + local_tensor_temp = local_tensor_with_grad * 3 + # create the dist tensor with non leaf local tensor, dist tensor created + # should also be non leaf node + dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec) + self.assertFalse(dist_tensor.is_leaf) + # do some random operations on dist tensor + output = dist_tensor * 3 + self.assertIsInstance(output, DTensor) + # trigger .backward() on dist tensor directly + local_grad = torch.ones(3, 3) + grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec) + # run backward directly on dist tensor + output.backward(grad_output) + # check it gradients flow back to original torch.Tensor + self.assertIsNotNone(local_tensor_with_grad.grad) + expected_grad = torch.ones(3, 3) * 9 + self.assertEqual(local_tensor_with_grad.grad, expected_grad) + + @with_comms + def test_to_local(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + local_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + + sharded_tensor = DTensor( + local_tensor_with_grad, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + self.assertEqual(sharded_tensor.size(), dist_tensor_shape) + self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad) + + # test dist tensor works with torch.Tensor during backwards + # dist tensor created is a leaf node, do some operation on dist tensor + temp_st = sharded_tensor * 3 + + # do some operation on local tensor of the dist tensor + new_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + res = temp_st.to_local() + new_tensor_with_grad + # call backward directly on torch.Tensor, and see if it works by + # propagating through dist tensor + res.sum().backward() + self.assertIsNotNone(sharded_tensor.grad) + + self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3) + + @with_comms + def test_from_local_then_to_local(self): + # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + # step 1. construct from construct local tensor + local_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + # do some operations on local tensor + local_tensor_temp = local_tensor_with_grad + 8 + # step 2. create the dist tensor with non leaf local tensor, dist tensor + # created should also be non leaf node + dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, shard_spec) + self.assertFalse(dist_tensor.is_leaf) + # do some random operations on dist tensor + output = dist_tensor * 6 + self.assertIsInstance(output, DTensor) + + # step 3. do some operation on local tensor of the dist tensor + new_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + res = output.to_local() + new_tensor_with_grad + # call backward directly on torch.Tensor, and see if it works by + # propagating all the way back to the original torch.Tensor + res.sum().backward() + self.assertIsNotNone(local_tensor_with_grad.grad) + + expected_grad = torch.ones(3, 3) * 6 + self.assertEqual(local_tensor_with_grad.grad, expected_grad) + + @with_comms + def test_dtensor_spec_read_only_after_set(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) + + # modify shard_spec, and dist_tensor's spec should not be changed + shard_spec[0] = Replicate() + self.assertTrue(sharded_tensor.placements is not shard_spec) + self.assertNotEqual(sharded_tensor.placements, shard_spec) + + @with_comms + def test_dtensor_properties(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec) + self.assertEqual(sharded_tensor.device.type, self.device_type) + + +class DTensorMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_dtensor_device_mesh_device_conversion(self): + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # construct from a cpu local tensor with cuda device mesh + # should automatically convert the dist tensor to cuda + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + @with_comms + def test_dtensor_api_device_mesh_context_manager(self): + with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, device_mesh=mesh, placements=shard_spec + ) + + with DeviceMesh(self.device_type, list(range(self.world_size))): + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local(local_tensor, placements=shard_spec) + replica_spec = [Replicate()] + replica_tensor = sharded_tensor.redistribute(placements=replica_spec) + self.assertEqual( + replica_tensor.size(), torch.Size([3 * self.world_size, 3]) + ) + + @with_comms + def test_dtensor_2d_mesh(self): + mesh_tensor = torch.arange(self.world_size).reshape(2, 4) + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # construct a dist tensor on 2d device mesh and test if works + shard_spec = [Shard(0), Shard(1)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual( + dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)]) + ) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + # if shard on the same tensor dimension + # we should correctly construct the global tensor size + shard_same_dim_spec = [Shard(0), Shard(0)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec) + self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3])) + + @with_comms + def test_device_mesh_nd(self): + # construct a cuda device mesh + mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + # construct a dist tensor on 3d device mesh and test if works + shard_spec = [Shard(0), Shard(1), Shard(2)] + local_tensor = torch.randn(3, 3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6])) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + # construct a dist tensor on 3d device mesh with some shards on same dim + shard_spec = [Shard(0), Shard(0), Shard(2)] + local_tensor = torch.randn(3, 3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6])) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + @with_comms + def test_dtensor_spec_local_shard_offset(self): + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size).reshape(2, 4) + ) + tensor_shape = (3 * self.world_size, 3 * self.world_size) + # sharding specs and its corresponding local shard offsets + shard_spec_and_offsets = [ + ( + [Shard(0), Replicate()], + (3 * (self.world_size // 2) * (self.rank // 4), 0), + ), + ( + [Shard(1), Replicate()], + (0, 3 * (self.world_size // 2) * (self.rank // 4)), + ), + ( + [Replicate(), Shard(0)], + (3 * (self.world_size // 4) * (self.rank % 4), 0), + ), + ( + [Replicate(), Shard(1)], + (0, 3 * (self.world_size // 4) * (self.rank % 4)), + ), + ] + + # loop through all sharding specs and check local shard offsets + logical_tensor = torch.randn(tensor_shape) + for shard_spec, expected_shard_offsets in shard_spec_and_offsets: + dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec) + self.assertEqual(expected_shard_offsets, dtensor._spec.local_offsets) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py new file mode 100644 index 0000000000000..198fec9a5d192 --- /dev/null +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -0,0 +1,705 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import sys +import unittest +import warnings + +import torch +import torch.distributed as dist +import torch.testing._internal.common_methods_invocations as common_ops + +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate + +from torch.overrides import resolve_name +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + ops, +) +from torch.testing._internal.common_methods_invocations import DecorateInfo +from torch.testing._internal.common_utils import ( + run_tests, + suppress_warnings, + TEST_WITH_ASAN, +) +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DEVICE_TYPE, + DTensorConverter, + DTensorTestBase, + TEST_SKIPS, +) +from torch.testing._internal.distributed._tensor.dtensor_lagging_op_db import ( + dtensor_lagging_op_db, +) +from torch.utils._pytree import tree_flatten, tree_map + +# rewrite common size variables to sth can be sharded evenly +# we can enable uneven shards later, but need to adjust more on +# sample inputs (i.e. view/reshape need to adjust shape size as well) +common_ops.L = 24 +common_ops.M = 12 +common_ops.S = 4 +common_ops.XS = 2 + + +def assert_ref_dtensor_equal(test_case, dtensor_rs, rs): + flat_dtensor_rs, _ = tree_flatten(dtensor_rs) + flat_rs, _ = tree_flatten(rs) + test_case.assertEqual(len(flat_dtensor_rs), len(flat_rs)) + for dtensor_r, r in zip(flat_dtensor_rs, flat_rs): + + if not isinstance(r, torch.Tensor): + continue + + test_case.assertIsInstance(dtensor_r, torch.Tensor) + test_case.assertEqual( + dtensor_r.shape, + r.shape, + f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}", + ) + test_case.assertEqual( + dtensor_r.requires_grad, + r.requires_grad, + "op result requires_grad mismatch!" + f"original requires_grad: {r.requires_grad}, " + f"dtensor requires_grad: {dtensor_r.requires_grad}", + ) + + test_case.assertEqual(dtensor_r.to_local(), r) + + +# Copied from functorch +def xfail(op_name, variant_name="", *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + + +def skip(op_name, variant_name="", *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = dtensor_lagging_op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [ + o + for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name + ] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for opinfo in matching_opinfos: + decorators = list(opinfo.decorators) + if expected_failure: + decorator = DecorateInfo( + unittest.expectedFailure, + test_case_name, + base_test_name, + device_type=device_type, + dtypes=dtypes, + ) + decorators.append(decorator) + else: + decorator = DecorateInfo( + unittest.skip("Skipped!"), + test_case_name, + base_test_name, + device_type=device_type, + dtypes=dtypes, + ) + decorators.append(decorator) + opinfo.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + + return wrapped + + +# Re-generate this failed list, turn on dry_run of the below func +# check_dtensor_func(self, test, op, dry_run=True), then run sth +# like python test/spmd/tensor/test_dtensor_ops.py > failed.expect +dtensor_fails = { + # these sometimes pass and sometimes fail + # we need to remove many of them from list once op + # get full support with varying sharding specs + xfail("__getitem__"), + xfail("__rsub__"), + xfail("masked.amax"), + xfail("masked.amin"), + xfail("masked.argmax"), + xfail("masked.argmin"), + xfail("masked.cumprod"), + xfail("masked.cumsum"), + xfail("masked.log_softmax"), + xfail("masked.logaddexp"), + xfail("masked.logsumexp"), + xfail("masked.median"), + xfail("masked.norm"), + xfail("masked.prod"), + xfail("masked.softmin"), + xfail("masked.softmax"), + xfail("masked.sum"), + xfail("addbmm"), + xfail("addmv"), + xfail("addr"), + xfail("all"), + xfail("allclose"), + xfail("amax"), + xfail("amin"), + xfail("aminmax"), + xfail("any"), + xfail("arange"), + xfail("argmax"), + xfail("argmin"), + xfail("argsort"), + xfail("as_strided"), + xfail("as_strided_scatter"), + xfail("baddbmm"), + xfail("bernoulli"), + xfail("block_diag"), + xfail("broadcast_shapes"), + xfail("cat"), + xfail("cartesian_prod"), + xfail("cdist"), + xfail("cholesky"), + xfail("cholesky_inverse"), + xfail("cholesky_solve"), + xfail("chunk"), + xfail("clamp"), + xfail("clamp_max"), + xfail("clamp_min"), + xfail("column_stack"), + xfail("combinations"), + xfail("complex"), + xfail("constant_pad_nd"), + xfail("copysign"), + xfail("corrcoef"), + xfail("count_nonzero"), + xfail("cov"), + xfail("cross"), + xfail("cummax"), + xfail("cummin"), + xfail("cumsum"), + xfail("cumulative_trapezoid"), + xfail("diag"), + xfail("diag_embed"), + xfail("diagflat"), + xfail("diagonal"), + xfail("diagonal_copy"), + xfail("diagonal_scatter"), + xfail("diff"), + xfail("dist"), + xfail("dot"), + xfail("dstack"), + xfail("einsum"), + xfail("empty"), + xfail("empty_like"), + xfail("eq"), + xfail("eye"), + xfail("fft.fft2"), + xfail("fft.fft"), + xfail("fft.fftn"), + xfail("fft.fftshift"), + xfail("fft.ifft2"), + xfail("fft.ifft"), + xfail("fft.ifftshift"), + xfail("fft.ihfft2"), + xfail("fft.ihfft"), + xfail("fft.ihfftn"), + xfail("fft.irfft2"), + xfail("fft.irfftn"), + xfail("fft.rfft2"), + xfail("fft.rfft"), + xfail("fft.rfftn"), + xfail("flip"), + xfail("fliplr"), + xfail("flipud"), + xfail("floor_divide"), + xfail("fmax"), + xfail("fmin"), + xfail("frexp"), + xfail("full"), + xfail("gather"), + xfail("geqrf"), + xfail("gradient"), + xfail("heaviside"), + xfail("histc"), + xfail("histogram"), + xfail("histogramdd"), + xfail("hstack"), + xfail("index_add"), + xfail("index_copy"), + xfail("index_fill"), + xfail("index_put"), + xfail("index_reduce"), + xfail("index_select"), + xfail("isfinite"), + xfail("isin"), + xfail("isinf"), + xfail("isnan"), + xfail("isneginf"), + xfail("isposinf"), + xfail("kthvalue"), + xfail("linalg.cholesky"), + xfail("linalg.cholesky_ex"), + xfail("linalg.cond"), + xfail("linalg.cross"), + xfail("linalg.det"), + xfail("linalg.det", "singular"), + xfail("linalg.eig"), + xfail("linalg.eigh"), + xfail("linalg.eigvals"), + xfail("linalg.eigvalsh"), + xfail("linalg.householder_product"), + xfail("linalg.inv"), + xfail("linalg.inv_ex"), + xfail("linalg.ldl_factor"), + xfail("linalg.ldl_factor_ex"), + xfail("linalg.ldl_solve"), + xfail("linalg.lstsq"), + xfail("linalg.lstsq", "grad_oriented"), + xfail("linalg.lu"), + xfail("linalg.lu_factor"), + xfail("linalg.lu_factor_ex"), + xfail("linalg.lu_solve"), + xfail("linalg.matrix_norm"), + xfail("linalg.matrix_power"), + xfail("linalg.matrix_rank"), + xfail("linalg.matrix_rank", "hermitian"), + xfail("linalg.multi_dot"), + xfail("linalg.norm"), + xfail("linalg.norm", "subgradients_at_zero"), + xfail("linalg.pinv"), + xfail("linalg.pinv", "hermitian"), + xfail("linalg.qr"), + xfail("linalg.slogdet"), + xfail("linalg.solve"), + xfail("linalg.solve_ex"), + xfail("linalg.solve_triangular"), + xfail("linalg.svd"), + xfail("linalg.svdvals"), + xfail("linalg.tensorinv"), + xfail("linalg.tensorsolve"), + xfail("linalg.vander"), + xfail("linalg.vecdot"), + xfail("linalg.vector_norm"), + xfail("linspace"), + xfail("log_softmax"), + xfail("log_softmax", "with_dtype"), + xfail("logcumsumexp"), + xfail("logdet"), + xfail("logical_not"), + xfail("logspace"), + xfail("logsumexp"), + xfail("lt"), + xfail("lu"), + xfail("lu_solve"), + xfail("lu_unpack"), + xfail("masked_fill"), + xfail("masked_scatter"), + xfail("masked_select"), + xfail("matrix_exp"), + xfail("max", "binary"), + xfail("max", "reduction_no_dim"), + xfail("max", "reduction_with_dim"), + xfail("maximum"), + xfail("median"), + xfail("min", "binary"), + xfail("min", "reduction_no_dim"), + xfail("min", "reduction_with_dim"), + xfail("minimum"), + xfail("mode"), + xfail("msort"), + xfail("multinomial"), + xfail("mv"), + xfail("max_pool2d_with_indices_backward", ""), + xfail("nanmean"), + xfail("nanmedian"), + xfail("nanquantile"), + xfail("nansum"), + xfail("native_batch_norm"), + xfail("native_layer_norm"), + xfail("narrow_copy"), + xfail("ne"), + xfail("new_empty"), + xfail("new_empty_strided"), + xfail("transpose"), + xfail("nn.functional.adaptive_avg_pool1d"), + xfail("nn.functional.adaptive_avg_pool2d"), + xfail("nn.functional.adaptive_avg_pool3d"), + xfail("nn.functional.adaptive_max_pool1d"), + xfail("nn.functional.adaptive_max_pool2d"), + xfail("nn.functional.adaptive_max_pool3d"), + xfail("nn.functional.alpha_dropout"), + xfail("nn.functional.avg_pool1d"), + xfail("nn.functional.avg_pool2d"), + xfail("nn.functional.avg_pool3d"), + xfail("nn.functional.batch_norm"), + xfail("nn.functional.batch_norm", "without_cudnn"), + xfail("nn.functional.bilinear"), + xfail("nn.functional.binary_cross_entropy"), + xfail("nn.functional.binary_cross_entropy_with_logits"), + xfail("nn.functional.celu"), + xfail("nn.functional.conv1d"), + xfail("nn.functional.conv2d"), + xfail("nn.functional.conv_transpose1d"), + xfail("nn.functional.conv_transpose2d"), + xfail("nn.functional.conv_transpose3d"), + xfail("nn.functional.cosine_similarity"), + xfail("nn.functional.cross_entropy"), + xfail("nn.functional.ctc_loss"), + xfail("nn.functional.dropout"), + xfail("nn.functional.dropout2d"), + xfail("nn.functional.dropout3d"), + xfail("nn.functional.elu"), + xfail("nn.functional.fractional_max_pool2d"), + xfail("nn.functional.fractional_max_pool3d"), + xfail("nn.functional.gaussian_nll_loss"), + xfail("nn.functional.glu"), + xfail("nn.functional.grid_sample"), + xfail("nn.functional.group_norm"), + xfail("nn.functional.hardshrink"), + xfail("nn.functional.hardsigmoid"), + xfail("nn.functional.hardswish"), + xfail("nn.functional.hardtanh"), + xfail("nn.functional.huber_loss"), + xfail("nn.functional.instance_norm"), + xfail("nn.functional.interpolate", "area"), + xfail("nn.functional.interpolate", "bicubic"), + xfail("nn.functional.interpolate", "bilinear"), + xfail("nn.functional.interpolate", "linear"), + xfail("nn.functional.interpolate", "nearest"), + xfail("nn.functional.interpolate", "trilinear"), + xfail("nn.functional.layer_norm"), + xfail("nn.functional.leaky_relu"), + xfail("nn.functional.linear"), + xfail("nn.functional.local_response_norm"), + xfail("nn.functional.logsigmoid"), + xfail("nn.functional.margin_ranking_loss"), + xfail("nn.functional.max_pool1d"), + xfail("nn.functional.max_pool2d"), + xfail("nn.functional.max_pool3d"), + xfail("nn.functional.max_unpool1d"), + xfail("nn.functional.max_unpool1d", "grad"), + xfail("nn.functional.max_unpool2d"), + xfail("nn.functional.max_unpool2d", "grad"), + xfail("nn.functional.max_unpool3d"), + xfail("nn.functional.max_unpool3d", "grad"), + xfail("nn.functional.mish"), + xfail("nn.functional.mse_loss"), + xfail("nn.functional.multi_margin_loss"), + xfail("nn.functional.multilabel_margin_loss"), + xfail("nn.functional.multilabel_soft_margin_loss"), + xfail("nn.functional.nll_loss"), + xfail("nn.functional.normalize"), + xfail("nn.functional.pad", "circular"), + xfail("nn.functional.pad", "constant"), + xfail("nn.functional.pad", "reflect"), + xfail("nn.functional.pad", "replicate"), + xfail("nn.functional.pairwise_distance"), + xfail("nn.functional.pdist"), + xfail("nn.functional.pixel_shuffle"), + xfail("nn.functional.pixel_unshuffle"), + xfail("nn.functional.poisson_nll_loss"), + xfail("nn.functional.prelu"), + xfail("nn.functional.relu6"), + xfail("nn.functional.rrelu"), + xfail("nn.functional.selu"), + xfail("nn.functional.silu"), + xfail("nn.functional.smooth_l1_loss"), + xfail("nn.functional.soft_margin_loss"), + xfail("nn.functional.softplus"), + xfail("nn.functional.softshrink"), + xfail("nn.functional.threshold"), + xfail("nn.functional.triplet_margin_loss"), + xfail("nn.functional.triplet_margin_with_distance_loss"), + xfail("nn.functional.unfold"), + xfail("nn.functional.upsample_bilinear"), + xfail("nn.functional.upsample_nearest"), + xfail("nonzero"), + xfail("norm"), + xfail("norm", "fro"), + xfail("norm", "inf"), + xfail("norm", "nuc"), + xfail("normal"), + xfail("normal", "number_mean"), + xfail("ormqr"), + xfail("ones"), + xfail("pca_lowrank"), + xfail("pinverse"), + xfail("polar"), + xfail("put"), + xfail("qr"), + xfail("quantile"), + xfail("rad2deg"), + xfail("rand_like"), + xfail("randint_like"), + xfail("randint"), + xfail("randn"), + xfail("randn_like"), + xfail("renorm"), + xfail("repeat_interleave"), + xfail("resize_"), + xfail("resize_as_"), + xfail("roll"), + xfail("rot90"), + xfail("rsub"), + xfail("scalar_tensor"), + xfail("scatter_add"), + xfail("scatter"), + xfail("scatter_reduce", "amax"), + xfail("scatter_reduce", "amin"), + xfail("scatter_reduce", "mean"), + xfail("scatter_reduce", "prod"), + xfail("scatter_reduce", "sum"), + xfail("searchsorted"), + xfail("select"), + xfail("select_scatter"), + xfail("signbit"), + xfail("sort"), + xfail("sparse.sampled_addmm"), + xfail("special.airy_ai"), + xfail("special.bessel_j0"), + xfail("special.bessel_j1"), + xfail("special.bessel_y0"), + xfail("special.bessel_y1"), + xfail("special.chebyshev_polynomial_t"), + xfail("special.chebyshev_polynomial_u"), + xfail("special.entr"), + xfail("special.erfcx"), + xfail("special.hermite_polynomial_h"), + xfail("special.hermite_polynomial_he"), + xfail("special.i0e"), + xfail("special.i1"), + xfail("special.i1e"), + xfail("special.laguerre_polynomial_l"), + xfail("special.log_ndtr"), + xfail("special.modified_bessel_i0"), + xfail("special.modified_bessel_i1"), + xfail("special.modified_bessel_k0"), + xfail("special.modified_bessel_k1"), + xfail("special.ndtri"), + xfail("special.scaled_modified_bessel_k0"), + xfail("special.scaled_modified_bessel_k1"), + xfail("special.spherical_bessel_j0"), + xfail("special.xlog1py"), + xfail("special.zeta"), + xfail("split"), + xfail("split", "list_args"), + xfail("split_with_sizes"), + xfail("signal.windows.cosine"), + xfail("signal.windows.exponential"), + xfail("signal.windows.gaussian"), + xfail("signal.windows.kaiser"), + xfail("squeeze"), + xfail("stack"), + xfail("std"), + xfail("std_mean"), + xfail("stft"), + xfail("svd"), + xfail("svd_lowrank"), + xfail("symeig"), + xfail("t"), + xfail("take_along_dim"), + xfail("take"), + xfail("tensor_split"), + xfail("to_sparse"), + xfail("topk"), + xfail("trace"), + xfail("trapezoid"), + xfail("trapz"), + xfail("triangular_solve"), + xfail("tril"), + xfail("triu"), + xfail("unbind"), + xfail("unfold"), + xfail("unfold_copy"), + xfail("uniform"), + xfail("unflatten"), + xfail("unique_consecutive"), + xfail("unique"), + xfail("var_mean"), + xfail("vdot"), + xfail("view_as_complex"), + xfail("vstack"), + xfail("zeros"), + # ops inside this might even fail without dtensor + # tests, as we rescale op db common test size factor (i.e. L, M, S) + # which triggered the orignal function run failures with input + # generation becomes wrong, we skip them for now but should enable later. + # TODO: need to clean this list and remove all cases + skip("argwhere"), + skip("cumprod"), + skip("__rmatmul__"), + skip("meshgrid", "list_of_tensors"), + skip("meshgrid", "variadic_tensors"), + skip("nn.functional._scaled_dot_product_attention"), + skip("nn.functional.softmin"), + skip("nn.functional.embedding"), + skip("nn.functional.embedding_bag"), + skip("nn.functional.feature_alpha_dropout", "with_train"), + skip("nn.functional.feature_alpha_dropout", "without_train"), + skip("nn.functional.hinge_embedding_loss"), + skip("nn.functional.cosine_embedding_loss"), + skip("fft.hfft"), + skip("fft.hfft2"), + skip("fft.hfft2"), + skip("fft.hfftn"), + skip("fft.ifftn"), + skip("fft.irfft"), + skip("istft"), + skip("isclose"), + skip("isreal"), + skip("matmul"), + skip("masked.mean"), + skip("masked.var"), + skip("masked.std"), + skip("masked.normalize"), + skip("prod"), + skip("segment_reduce", "lengths"), + skip("segment_reduce", "offsets"), +} + + +# Add a list of ops that are currently failing BW pass +skip_bw = [ + None, # corresponds to the transpose ops 'H' and 'T' + "torch.bucketize", + "torch.conj_physical", + "torch.eq", + "torch.isfinite", + "torch.isnan", +] + + +def run_dtensor_crossref(test_case, func, args, kwargs): + to_dtensor = DTensorConverter(test_case.mesh, args, kwargs) + + # TODO: also handle cases where func raise an exception + rs = func(*args, **kwargs) + + def to_replicate(e: object) -> object: + return ( + e.redistribute(test_case.mesh, test_case.mesh.ndim * [Replicate()]) + if isinstance(e, DTensor) + else e + ) + + try: + # Suppress warnings, this doesn't matter for test_meta.py + # but it does matter if you want to use this decorator + # for cross-ref testing, as some tests may be looking at + # errors + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # for every comb of sharding choices, we test if it works + for dtensor_args, dtensor_kwargs in to_dtensor: + # Only attempt if we managed to convert all tensors to DTensor + # (if any of them failed, we're in a mixed tensor situation and + # this is not allowed in DTensor) + if to_dtensor.successful(): + # Handle special cases first if there's any + # Suppress warnings, this doesn't matter for test_meta.py + # but it does matter if you want to use this decorator + # for cross-ref testing, as some tests may be looking at + # errors + dtensor_rs = func(*dtensor_args, **dtensor_kwargs) + + # we need to skip tests containing tensors of zero elmeents for now. + # see issue: https://github.com/pytorch/tau/issues/470 + # TODO remove this once issue above fixed. + flat_args, _ = tree_flatten(dtensor_rs) + if any( + isinstance(e, torch.Tensor) and e.numel() == 0 + for e in flat_args + ): + continue + + # redistribute/all_gather the results to compare with normal output + dtensor_rs = tree_map(to_replicate, dtensor_rs) + try: + if resolve_name(func) not in skip_bw: + if isinstance(dtensor_rs, DTensor): + dtensor_rs.to_local().sum().backward() + elif isinstance(dtensor_rs, tuple): + dtensor_rs[0].to_local().sum().backward() + + except Exception as e: + # TODO(anj): Remove this guard exception after gaining more confidence. + if torch.distributed.get_rank() == 0: + print( + f"failed to run BW: {resolve_name(func)}, {func}, {str(e)})" + ) + assert_ref_dtensor_equal(test_case, dtensor_rs, rs) + else: + raise RuntimeError( + f"failed to convert args to DTensor; " + f"originally (*{args}, **{kwargs})" + ) + except Exception as e: + raise RuntimeError( + f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})" + ) from e + + return rs + + +def check_dtensor_func(test_case, test_func, opinfo, dry_run=False): + try: + test_func() + except Exception: + test_case.destroy_pg() + if not dry_run: + raise + if dist.get_rank() == 0: + if opinfo.variant_test_name: + print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),") + else: + print(f"xfail('{opinfo.name}'),") + else: + test_case.destroy_pg() + + +class TestDTensorOps(DTensorTestBase): + @property + def world_size(self) -> int: + return 4 + + # only allow float dytpe for now, we can relax this constraint + # when feel necessary later (i.e when adding quantization support). + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @suppress_warnings + @ops(dtensor_lagging_op_db, allowed_dtypes=(torch.float,)) + @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails) + def test_dtensor_op_db(self, dtype, op): + pg_backend = "nccl" if DEVICE_TYPE == "cuda" else "gloo" + if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + self.init_pg(backend=pg_backend) + self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size)) + + # test each op with dist tensor inputs and normal inputs + def test(): + samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True) + for sample_input in samples: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + + run_dtensor_crossref(self, op.op, args, kwargs) + # we need to figure out a way to test the out variant, out variant testing + # is tricky, as we need to pre allocate the dtensor out, some of them rely + # on sharding placements to be pre-known (i.e. mm.out) + # if isinstance(expected, torch.Tensor) and op.supports_out: + # func(*args, **kwargs, out=expected) + + check_dtensor_func(self, test, op) + + +# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU) +instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py new file mode 100644 index 0000000000000..72bfd9c9d6d05 --- /dev/null +++ b/test/distributed/_tensor/test_math_ops.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import itertools + +import torch + +from torch.distributed._tensor import distribute_tensor +from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + skip_unless_torch_gpu, + with_comms, +) + + +class DistMathOpsTest(DTensorTestBase): + @with_comms + def test_sum(self): + device_mesh = self.build_device_mesh() + + shard_spec = [Shard(0)] + + tensor_to_sum = torch.randn(12, 8, 8) + + mat1 = distribute_tensor(tensor_to_sum, device_mesh, shard_spec) + + keep_dim_or_not = [True, False, None] + for dim in range(tensor_to_sum.ndim): + for keep_dim in keep_dim_or_not: + sum_args = (dim, keep_dim) if keep_dim is not None else (dim,) + dim_sumed_tensor = tensor_to_sum.sum(*sum_args) + dt_dim_sumed_tensor = mat1.sum(*sum_args).redistribute( + device_mesh, [Replicate()] * device_mesh.ndim + ) + self.assertEqual(dt_dim_sumed_tensor.to_local(), dim_sumed_tensor) + + full_sumed_tensor = tensor_to_sum.sum() + dt_sum = mat1.sum().redistribute(device_mesh, [Replicate()] * device_mesh.ndim) + self.assertEqual(dt_sum.to_local(), full_sumed_tensor) + + # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU + @with_comms + def test_softmax_fwd(self): + device_mesh = self.build_device_mesh() + + x = torch.rand(8, 12, 16, device=self.device_type) + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for softmax_dim, shard_dim in test_list: + local_y = torch.nn.functional.softmax( + x, dim=softmax_dim, dtype=torch.float32 + ) + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + if dims[shard_dim] == dims[softmax_dim]: + with self.assertRaisesRegex( + Exception, "Cannot run .* on sharding dimension!$" + ): + dist_y = torch.nn.functional.softmax( + dist_x, dim=softmax_dim, dtype=torch.float32 + ) + else: + dist_y = torch.nn.functional.softmax( + dist_x, dim=softmax_dim, dtype=torch.float32 + ) + self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim)) + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + + # TODO: get test_softmax_with_bwd pass on CPU + # DTensor's _softmax_backward_data produces wrong result on CPU on certain dimension. + # fail_on_cpu_list = [(0, -1), (1, -1)] + @with_comms + @skip_unless_torch_gpu + def test_softmax_with_bwd(self): + device_mesh = self.build_device_mesh() + + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for params in test_list: + softmax_dim, shard_dim = params + x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True) + self.assertTrue(x.requires_grad) + local_y = torch.nn.functional.softmax( + x, dim=softmax_dim, dtype=torch.float32 + ).sum() + local_y.backward() + + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + self.assertTrue(dist_x.requires_grad) + if dims[softmax_dim] == dims[shard_dim]: + with self.assertRaisesRegex( + Exception, "Cannot run .* on sharding dimension!$" + ): + dist_softmax = dist_x.softmax(dim=softmax_dim) + else: + dist_softmax = dist_x.softmax(dim=softmax_dim) + self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim)) + dist_y = dist_softmax.sum() + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + self.assertIsNone(dist_x.grad) + dist_y.backward() + self.assertIsNotNone(dist_x.grad) + dist_x_grad = dist_x.grad.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_x_grad.to_local(), x.grad) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py new file mode 100644 index 0000000000000..af9e16dc2c241 --- /dev/null +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import itertools +from typing import cast, List, Optional + +import torch +from torch.distributed._tensor import DeviceMesh, distribute_tensor +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import ( + _Partial, + Placement, + Replicate, + Shard, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + skip_unless_torch_gpu, + with_comms, +) + + +class DistMatrixOpsTest(DTensorTestBase): + @with_comms + def test_addmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + replica_spec = [Replicate()] + + tensor_to_shard = torch.randn(12, 8) + mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) + tensor_to_replicate = torch.randn(8, 4) + mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec) + input_tensor = torch.randn(4) + input = distribute_tensor(input_tensor, device_mesh, replica_spec) + + dist_res = torch.addmm(input, mat1, mat2) + local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate) + self.assertEqual( + dist_res.redistribute(device_mesh, replica_spec).to_local(), + local_res, + ) + + @with_comms + def test_addmm_auto_redistribute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard0_spec = [Shard(0)] + shard1_spec = [Shard(1)] + replica_spec = [Replicate()] + + tensor_to_shard1 = torch.randn(12, 8, requires_grad=True) + mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec) + tensor_to_shard0 = torch.randn(8, 4, requires_grad=True) + mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec) + input_tensor = torch.randn(4, requires_grad=True) + input = distribute_tensor(input_tensor, device_mesh, replica_spec) + + local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0) + dist_res = torch.addmm(input, mat1, mat2) + + # test if addmm output is a partial + self.assertIsInstance(dist_res, DTensor) + self.assertIsInstance(dist_res.placements[0], _Partial) + + # test if result is the same as tensor + replica_res = dist_res.redistribute(device_mesh, replica_spec) + dist_local_res = replica_res.to_local() + self.assertEqual(local_res, dist_local_res) + + # backward checks + dist_local_res.sum().backward() + local_res.sum().backward() + self.assertIsNotNone(mat2.grad) + mat2_grad = mat2.grad.redistribute(device_mesh, replica_spec) + self.assertEqual(mat2_grad.to_local(), tensor_to_shard0.grad) + + @with_comms + def test_mm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard0_spec = Shard(0) + shard1_spec = Shard(1) + replica_spec = Replicate() + + t1 = torch.randn(12, 8, requires_grad=True) + t2 = torch.randn(8, 16, requires_grad=True) + local_res = torch.mm(t1, t2) + + def test_placement_comb( + placements1: List[Placement], placements2: List[Placement] + ) -> None: + dt1 = distribute_tensor(t1, device_mesh, placements1) + dt2 = distribute_tensor(t2, device_mesh, placements2) + dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute( + device_mesh, [replica_spec] + ) + self.assertEqual(dist_res.to_local(), local_res) + # backward + grad_dist_res = torch.ones_like(dist_res) + dist_res.backward(grad_dist_res) + self.assertIsNotNone(dt1.grad) + + placement_specs = [shard0_spec, shard1_spec, replica_spec] + shard_specs_comb = list(itertools.product(placement_specs, placement_specs)) + for spec in shard_specs_comb: + test_placement_comb([spec[0]], [spec[1]]) + + @with_comms + def test_t(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + tensor_to_transpose = torch.randn(12, 8, requires_grad=True) + mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec) + tranposed_mat = mat.t() + self.assertEqual(tranposed_mat.size(), torch.Size([8, 12])) + self.assertEqual(tranposed_mat.placements, [Shard(1)]) + tranposed_mat2 = tranposed_mat.t() + self.assertEqual(tranposed_mat2.size(), torch.Size([12, 8])) + self.assertEqual(tranposed_mat2.placements, shard_spec) + + @with_comms + def test_t_partial(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + a = torch.randn(12, 8) + b = torch.randn(8, 4) + c = torch.mm(a, b).t() + + da = distribute_tensor(a, device_mesh, [Shard(1)]) + db = distribute_tensor(b, device_mesh, [Shard(0)]) + + # mm(da, db) should return a _Partial tensor. + # transposing it should keep it _Partial + dc = torch.mm(da, db).t() + + self.assertTrue(isinstance(dc.placements[0], _Partial)) + + # check that the local and distributed op results match + self.assertEqual( + c, + dc.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_baddbmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) + batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) + batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) + + def test_placement_comb( + tensor_placements: List[Placement], + batch_1_placements: List[Placement], + batch_2_placements: List[Placement], + beta: int, + alpha: int, + batch_1_grad: Optional[torch.Tensor], + ) -> None: + tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements) + batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements) + batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements) + dist_res = cast( + DTensor, + torch.baddbmm( + tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha + ), + ).redistribute(device_mesh, [Replicate()]) + dist_local_res = dist_res.to_local() + assert not torch.isnan(local_result).any() + assert not torch.isnan(dist_local_res).any() + self.assertEqual(dist_local_res.detach(), local_result.detach()) + + # TODO: add test backward + # grad_dist_res = torch.ones_like(dist_res) + # dist_res.backward(grad_dist_res) + # self.assertIsNotNone(batch_1_dt.grad) + # batch_1_grad_local = batch_1_dt.grad.redistribute( + # device_mesh, [Replicate()] + # ).to_local() + # self.assertEqual(batch_1_grad_local, batch_1_grad) + + shard0_spec = Shard(0) + shard1_spec = Shard(1) + shard2_spec = Shard(2) + replica_spec = Replicate() + shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec] + shard_specs_comb = list( + itertools.product(shard_specs, shard_specs, shard_specs) + ) + passlist = [ + (shard0_spec, shard0_spec, shard0_spec), + (shard0_spec, shard0_spec, replica_spec), + (shard0_spec, shard1_spec, shard0_spec), + (shard0_spec, shard2_spec, shard0_spec), + (shard1_spec, shard1_spec, replica_spec), + (shard0_spec, replica_spec, shard0_spec), + (shard2_spec, replica_spec, shard2_spec), + (shard2_spec, shard0_spec, shard2_spec), + (shard2_spec, shard1_spec, shard2_spec), + (shard2_spec, shard2_spec, shard2_spec), + (replica_spec, shard0_spec, shard0_spec), + (replica_spec, shard1_spec, replica_spec), + (replica_spec, shard2_spec, shard1_spec), + (replica_spec, replica_spec, shard2_spec), + (replica_spec, replica_spec, replica_spec), + ] + # If beta is 0, input tensor will be ignored + numeric_params_comb = [ + (0.0, 0.5), # zero-beta + (0.8, 0.5), # non-zero-beta + ] + + for beta, alpha in numeric_params_comb: + local_result = torch.baddbmm( + tensor, batch_1, batch_2, beta=beta, alpha=alpha + ) + grad_local_res = torch.ones_like(local_result) + local_result.backward(grad_local_res) + # tests that currently pass + for spec in passlist: + test_placement_comb( + [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad + ) + # TODO: support these tests + shard_specs_comb = [ + spec for spec in shard_specs_comb if spec not in passlist + ] + for spec in shard_specs_comb: + with self.assertRaises(Exception): + test_placement_comb( + [spec[0]], + [spec[1]], + [spec[2]], + beta, + alpha, + batch_1.grad, + ) + + @with_comms + def test_bmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) + mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) + local_result = torch.bmm(mat1, mat2) + grad_local_res = torch.ones_like(local_result) + local_result.backward(grad_local_res) + + def test_placement_comb( + placements1: List[Placement], + placements2: List[Placement], + ) -> None: + mat1_dt = distribute_tensor(mat1, device_mesh, placements1) + mat2_dt = distribute_tensor(mat2, device_mesh, placements2) + dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute( + device_mesh, [Replicate()] + ) + dist_local_res = dist_res.to_local() + self.assertEqual(dist_local_res, local_result) + + # test backward + # TODO: figure out (replicate, shard1) fail on backward + # it generates a different grad shape + grad_dist_res = torch.ones_like(dist_res) + dist_res.backward(grad_dist_res) + self.assertIsNotNone(mat1_dt.grad) + mat1_dt_grad = cast(DTensor, mat1_dt.grad) + mat1_grad_local = mat1_dt_grad.redistribute( + device_mesh, [Replicate()] + ).to_local() + self.assertEqual(mat1_grad_local, mat1.grad) + + shard0_spec = Shard(0) + shard1_spec = Shard(1) + shard2_spec = Shard(2) + replica_spec = Replicate() + placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec] + shard_specs_comb = list(itertools.product(placement_specs, placement_specs)) + + # tests that currently pass + for spec in shard_specs_comb: + test_placement_comb([spec[0]], [spec[1]]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py new file mode 100644 index 0000000000000..5b5eccfcb2ec8 --- /dev/null +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +from typing import Any, Callable, Dict, Optional, Sequence +from unittest import skip + +import torch + +import torch.utils._pytree as pytree +from torch import Tensor + +from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed._tensor.placement_types import ( + _Partial, + Placement, + Replicate, + Shard, +) +from torch.distributed.distributed_c10d import ReduceOp +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + skip_unless_torch_gpu, + with_comms, +) + + +def no_op(): + return None + + +def deepcopy_convert_to_dtensor( + val: Any, + device_mesh: DeviceMesh, + placements: Sequence[Placement], +) -> Any: + """ + Recursively convert (over Sequence and Dict types) Tensors into DTensors. + + :param device_mesh: the DeviceMesh to use. + :param placements: the Placement list to use. + :return: the transformed structure. + """ + + def f(x): + if isinstance(x, Tensor) and not isinstance(x, DTensor): + return distribute_tensor( + x, + device_mesh=device_mesh, + placements=placements, + ) + return x + + return pytree.tree_map(f, [val])[0] + + +def deepcopy_convert_from_dtensor(val: Any) -> Any: + """ + Recursive convert any DTensor to local Tensor. + + :param val: the structure to coerce. + :return: the coerced structure. + """ + + def f(x): + if isinstance(x, DTensor): + return x.redistribute( + device_mesh=x.device_mesh, + placements=[Replicate()] * x.device_mesh.ndim, + ).to_local() + return x + + return pytree.tree_map(f, [val])[0] + + +class DistElementwiseOpsTest(DTensorTestBase): + def _compare_pairwise_ops( + self, + *, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + op: Callable, + pre_op_fn: Optional[Callable] = None, + args: Sequence[Any] = tuple(), + kwargs: Optional[Dict[str, Any]] = None, + ): + if pre_op_fn is None: + pre_op_fn = no_op + + if not kwargs: + kwargs = {} + + dargs = deepcopy_convert_to_dtensor( + args, + device_mesh=device_mesh, + placements=placements, + ) + dkwargs = deepcopy_convert_to_dtensor( + kwargs, + device_mesh=device_mesh, + placements=placements, + ) + + pre_op_fn() + + # run the reference first, in case the call is broken; + # it's better to debug an incorrect call at this point. + reference_result = op(*args, **kwargs) + + pre_op_fn() + + dist_result = op(*dargs, **dkwargs) + + collected_result = deepcopy_convert_from_dtensor(dist_result) + + self.assertEqual(reference_result, collected_result) + + # TODO: We need to add CPU tests for ops in the future. + def _run_sharded_elementwise_ops( + self, + *, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + pre_op_fn: Optional[Callable] = None, + input_size: Sequence[int], + op: Callable, + **kwargs, + ): + if pre_op_fn is None: + pre_op_fn = no_op + + input_tensor = torch.randn( + *input_size, + device=self.device_type, + requires_grad=True, + ) + + self._compare_pairwise_ops( + device_mesh=device_mesh, + placements=placements, + pre_op_fn=pre_op_fn, + op=op, + args=(input_tensor,), + kwargs=kwargs, + ) + + @with_comms + def test_activations(self): + device_mesh = self.build_device_mesh() + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.nn.functional.gelu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.nn.functional.gelu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(1)], + input_size=(3, 12), + op=torch.nn.functional.relu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.nn.functional.relu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.sigmoid, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.sigmoid, + ) + + @with_comms + @skip("testing RNG based ops is broken: https://github.com/pytorch/tau/issues/494") + def test_dropout(self): + device_mesh = self.build_device_mesh() + + def _reset_random_seed(): + torch.manual_seed(self.rank + 4) + + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.nn.functional.dropout, + pre_op_fn=_reset_random_seed, + p=0.4, + training=False, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(1)], + input_size=(3, 14), + op=torch.nn.functional.dropout, + pre_op_fn=_reset_random_seed, + p=0.5, + training=True, + ) + + @with_comms + @skip_unless_torch_gpu + def test_dropout_backward(self): + device_mesh = self.build_device_mesh() + placements = [Shard(0)] + + input_size = (8, 5) + + grad_output = torch.rand( + input_size, + device=self.device_type, + requires_grad=True, + ) + mask = ( + torch.rand( + input_size, + device=self.device_type, + requires_grad=False, + ) + < 0.8 + ) + + self._compare_pairwise_ops( + device_mesh=device_mesh, + placements=placements, + op=torch.ops.aten.native_dropout_backward, + kwargs=dict( + grad_output=grad_output, + mask=mask, + scale=0.3, + ), + ) + + @with_comms + def test_dropout_errors(self): + device_mesh = self.build_device_mesh() + with self.assertRaisesRegex(RuntimeError, "supported"): + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[_Partial(ReduceOp.SUM)], + input_size=(8, 5), + op=torch.nn.functional.dropout, + ) + + @with_comms + def test_mul_out(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(self.rank) + shard_spec = [Shard(0)] + input_size = (8, 4) + input_tensor = torch.randn(*input_size, device=self.device_type) + dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + + other_tensor = torch.randn(*input_size, device=self.device_type) + other_dtensor = DTensor.from_local(other_tensor, device_mesh, shard_spec) + + output_tensor = torch.randn(*input_size, device=self.device_type) + output_dtensor = DTensor.from_local(output_tensor, device_mesh, shard_spec) + dt = torch.mul(dtensor, other_dtensor, out=output_dtensor) + expected = torch.mul(input_tensor, other_tensor, out=output_tensor) + self.assertEqual(input_tensor, dtensor.to_local()) + self.assertEqual(expected, dt.to_local()) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py new file mode 100644 index 0000000000000..70489e26791f4 --- /dev/null +++ b/test/distributed/_tensor/test_redistribute.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import itertools + +import torch +from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard + +from torch.testing._internal.common_utils import run_tests + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class RedistributeTest(DTensorTestBase): + @with_comms + def test_shard_to_replicate_forward_backward(self): + # 1) test shard -> replicate forward + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + expected_tensor = torch.randn( + input_size, device=self.device_type, requires_grad=True + ) + dtensor = distribute_tensor( + expected_tensor.clone(), device_mesh, shard_spec + ) + reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec) + self.assertEqual(reshard_dtensor.size(), torch.Size(input_size)) + self.assertEqual(expected_tensor, reshard_dtensor.to_local()) + + # 2) test shard -> replicate backward: + # should give gradient as shard + grad_output = torch.ones_like(reshard_dtensor) + reshard_dtensor.backward(grad_output) + grad_input = dtensor.grad + self.assertEqual(grad_input.placements, shard_spec) + self.assertEqual( + grad_input.to_local(), torch.ones(dtensor.to_local().size()) + ) + + @with_comms + def test_replicate_to_replicate_forward_backward(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) + # 1) test replicate -> replicate forward + replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec) + reshard_replica_tensor = replica_tensor.redistribute(device_mesh, replica_spec) + self.assertEqual(replica_tensor.size(), local_tensor.size()) + self.assertEqual(replica_tensor, reshard_replica_tensor) + + # 2) test replicate -> replicate backward: + # should give gradient as replicate + grad_output = torch.ones_like(reshard_replica_tensor) + reshard_replica_tensor.backward(grad_output) + grad_input = replica_tensor.grad + self.assertEqual(grad_input.placements, replica_spec) + self.assertEqual(grad_input.to_local(), torch.ones(12, 3)) + + @with_comms + def test_replicate_to_shard_forward_backward(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + # 1) test replicate -> shard forward + local_replica = torch.randn( + input_size, device=self.device_type, requires_grad=True + ) + splitted_list = local_replica.tensor_split(self.world_size, shard_dim) + # make local tensor as the element of the corresponding chunked list + local_tensor = splitted_list[self.rank] + replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec) + reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec) + self.assertEqual(reshard_tensor.size(), replica_tensor.size()) + self.assertEqual(reshard_tensor.placements, shard_spec) + self.assertEqual(reshard_tensor.to_local(), local_tensor) + + # 2) test replicate -> shard backward: + # should give gradient as replicate + grad_output = torch.ones_like(reshard_tensor) + reshard_tensor.backward(grad_output) + grad_input = replica_tensor.grad + self.assertEqual(grad_input.placements, replica_spec) + self.assertEqual(grad_input.to_local(), torch.ones(input_size)) + + @with_comms + def test_partial_to_replicate_forward_backward(self): + # Although we don't allow user to reshard to produce a partial + # placement (i.e. user can't reshard to partial), we do allow + # replicate to partial internally, and also partial to replicate + # backward should work as expected + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + partial_local = torch.randn(12, 3, device=self.device_type, requires_grad=True) + partial_spec = [_Partial()] + replica_spec = [Replicate()] + # test partial -> replicate, which trigger all_reduce + partial_tensor = DTensor.from_local(partial_local, device_mesh, partial_spec) + global_partial_tensor = partial_tensor.redistribute(device_mesh, replica_spec) + + self.assertEqual(partial_tensor.size(), partial_local.size()) + self.assertEqual( + partial_local * self.world_size, global_partial_tensor.to_local() + ) + + # test backward to have replicate grad on partial + global_partial_tensor.backward(torch.ones_like(global_partial_tensor)) + self.assertIsNotNone(partial_local.grad) + if device_mesh.get_rank() == 0: + self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) + + @with_comms + def test_replicate_to_partial(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) + partial_spec = _Partial() + replica_spec = Replicate() + # 1) test replicate -> partial forward + replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec]) + with self.assertRaisesRegex(RuntimeError, "Can not redistribute to _Partial"): + partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) + + from torch.distributed._tensor.redistribute import Redistribute + + partial_tensor = Redistribute.apply(replica_tensor, device_mesh, [partial_spec]) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + # test it successfully zero out the contents on other ranks + if self.rank == 0: + self.assertEqual(replica_tensor.to_local(), partial_tensor.to_local()) + else: + self.assertEqual(partial_tensor.to_local(), torch.zeros_like(local_tensor)) + + # replicate to partial on sub groups + local_tensor = torch.randn(12, 3, device=self.device_type) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + # 1) test replicate -> partial on 2d-mesh subgroups + replica_tensor = distribute_tensor( + local_tensor, device_mesh, [replica_spec, replica_spec] + ) + partial_tensor = Redistribute.apply( + replica_tensor, device_mesh, [partial_spec, partial_spec] + ) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + + if self.rank != 3: + # replicate to partial should only zero out rank 3, and leave + # rank 0/2 (rank0 on mesh dim 1) and 0, 1 (rank0 on mesh dim 1) un-touched + self.assertEqual(replica_tensor.to_local(), partial_tensor.to_local()) + else: + self.assertEqual(replica_tensor.to_local(), torch.zeros_like(local_tensor)) + + @with_comms + def test_partial_to_shard(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + partial_spec = [_Partial()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + + partial_local = torch.ones(input_size, device=self.device_type) + partial_tensor = DTensor.from_local( + partial_local, device_mesh, partial_spec, run_check=False + ) + + quot, rem = divmod(input_size[shard_dim], self.world_size) + local_shape = list(input_size) + local_shape[shard_dim] = quot + (1 if self.rank < rem else 0) + # test partial to shard, trigger reduce_scatter + scatter_shard_tensor = partial_tensor.redistribute(device_mesh, shard_spec) + self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size()) + self.assertEqual(scatter_shard_tensor.placements, shard_spec) + self.assertEqual( + scatter_shard_tensor.to_local(), + torch.ones(local_shape) * self.world_size, + ) + + +class MultiDimRedistributeTest(DTensorTestBase): + @property + def world_size(self) -> int: + return 8 + + @with_comms + def test_multi_dim_mesh(self): + devices = torch.arange(self.world_size) + for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]: + mesh_shape = torch.arange(self.world_size).view(-1, 2) + device_mesh = DeviceMesh(self.device_type, mesh_shape) + tensor_shape = (16, 24) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.randn(*tensor_shape) + else: + # these should be entirely ignored + # because distribute_tensor is expected to override shards in ranks != 0 + full_tensor = torch.ones(*tensor_shape) + + possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)] + all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities]))) + all_inputs = list( + itertools.product(*(mesh_shape.ndim * [possibilities + [_Partial()]])) + ) + + for inputs in all_inputs: + # if partial, temporarily make it Replicated, then replace replicated with partial afterwards + repl_inputs = [Replicate() if s.is_partial() else s for s in inputs] + dt = distribute_tensor(full_tensor, device_mesh, repl_inputs) + + if repl_inputs != inputs: + # create a new DTensor reinterpreting some of the replicated entires as "Partial" + dt = DTensor.from_local( + dt.to_local(), device_mesh, inputs, run_check=False + ) + + for outputs in all_outputs: + # redistribute on target outputs + dt2 = dt.redistribute(device_mesh, outputs) + + # replicate and then get first shard + local_full = dt2.redistribute( + device_mesh, device_mesh.ndim * [Replicate()] + ).to_local() + + if torch.distributed.get_rank() == 0: + self.assertEqual(local_full.shape, full_tensor.shape) + + num_sums = 1 + for idx, input in enumerate(inputs): + if input.is_partial(): + num_sums *= mesh_shape.size(idx) + expected = num_sums * full_tensor + self.assertEqual(local_full, expected) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py new file mode 100644 index 0000000000000..254b365e34dc2 --- /dev/null +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorConverter, + DTensorTestBase, + with_comms, +) + + +class DistTensorOpsTest(DTensorTestBase): + @with_comms + def test_aten_contiguous(self): + # this op not covered by dtensor_ops + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + self._test_op( + mesh, + lambda x: torch.ops.aten.contiguous(x), + torch.randn(16, 32), + ) + + @with_comms + def test_detach(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + tensor_to_detach = torch.randn(12, 8, requires_grad=True) + mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec) + detached_mat = mat.detach() + self.assertFalse(detached_mat is mat) + + @with_comms + def test_clone(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + specs = [[Replicate()], [Shard(0)]] + tensor_to_clone = torch.randn(12, 8, requires_grad=True) + for spec in specs: + mat = distribute_tensor(tensor_to_clone, device_mesh, spec) + cloned_mat = mat.clone() + self.assertFalse(cloned_mat is mat) + self.assertEqual(cloned_mat.to_local(), mat.to_local()) + + @with_comms + def test_contiguous(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + tensor = torch.rand(3, 5, 6, requires_grad=True) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + self.assertTrue(dist_tensor.is_contiguous()) + # shard on dim 0 should not change stride (30, 6, 1) + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + new_dt = dist_tensor.transpose(0, 2) + self.assertFalse(new_dt.is_contiguous()) + self.assertFalse(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(new_dt.stride(), (1, 6, 30)) + + new_dt = new_dt.contiguous() + self.assertTrue(new_dt.is_contiguous()) + self.assertTrue(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + # check backward + new_dt.to_local().sum().backward() + self.assertEqual(tensor.grad, torch.ones(3, 5, 6)) + + @with_comms + def test_inplace_op(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) + dt_to_mul = dt_to_add.clone() + expected_add_dt = dt_to_add.clone() + 3 + add_res = dt_to_add.add_(3) + expected_mul_dt = dt_to_mul.clone() * 3 + mul_res = dt_to_mul.mul_(3) + # inplace op should be the same instance before and after + self.assertTrue(add_res is dt_to_add) + self.assertEqual(add_res.to_local(), expected_add_dt.to_local()) + + self.assertTrue(mul_res is dt_to_mul) + self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local()) + + # test inplace op self and other dtensor with other specs + # and make sure out spec not change + shard_spec = [Shard(0)] + partial_spec = [_Partial()] + dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec) + partial_grad = DTensor.from_local(torch.randn(12, 3), mesh, partial_spec) + res = dt_to_inplace_add.add_(partial_grad) + self.assertTrue(res is dt_to_inplace_add) + self.assertTrue(res.placements == shard_spec) + + @with_comms + def test_op_out_variant(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) + expected_dt = sharded_dt_input.clone() + 3 + sharded_dt_out = sharded_dt_input.clone() + res = torch.add(sharded_dt_input, 3, out=sharded_dt_out) + # op out variant should be the same instance before and after + self.assertTrue(res is sharded_dt_out) + self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local()) + + # test op out variant with other spec and make sure out spec not change + replica_spec = [Replicate()] + replicate_out = distribute_tensor(input_tensor, mesh, replica_spec) + expected_dt = replicate_out.clone() + 3 + res = torch.add(sharded_dt_input, 3, out=replicate_out) + self.assertTrue(res is replicate_out) + self.assertTrue(res.placements == replica_spec) + self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) + + @with_comms + def test_empty_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + empty_like_dt = torch.empty_like(dist_tensor) + # empty is not deterministic, so we only check that the shard propagation worked + self.assertEqual((4, 8), empty_like_dt.to_local().shape) + + @with_comms + def test_fill_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.fill_(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + self.assertEqual(full_expected, dist_tensor.to_local()) + + @with_comms + def test_full_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.full_like(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + + @with_comms + def test_ones_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(4, 8) + self.assertEqual(ones_expected, ones_like_dt.to_local()) + + @with_comms + def test_ones_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(dist_tensor.shape) + self.assertEqual( + ones_expected, + ones_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_fill_inplace_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + torch.fill_(dist_tensor, 42) + fill_expected = torch.full(dist_tensor.shape, 42, dtype=input_tensor.dtype) + self.assertEqual( + fill_expected, + dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_zeros_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + zeros_like_dt = torch.zeros_like(dist_tensor) + zeros_expected = torch.zeros(dist_tensor.shape) + self.assertEqual( + zeros_expected, + zeros_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_zero_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zero_(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + self.assertEqual(zeros_expected, dist_tensor.to_local()) + + @with_comms + def test_zeros_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zeros_like(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + + def _test_op(self, mesh, op_call, *args, **kwargs): + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + self.assertTrue(dtc.successful()) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual( + d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(), + out, + ) + + @with_comms + def test_index(self): + meshes = [ + DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh + # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh + # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh + ] + for mesh in meshes: + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(1, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(0, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (12,)), + ) + self._test_op( + mesh, + lambda x, y: x[:, y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 12)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8, 16)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + # broadcast in inner dimensions + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 1, 12)), + ) + # implicit (left-padded) broadcast + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y, :, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 1)), + torch.randint(5, (12, 8, 12)), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_tp_sharding_ops.py b/test/distributed/_tensor/test_tp_sharding_ops.py new file mode 100644 index 0000000000000..ef4d635f6ef76 --- /dev/null +++ b/test/distributed/_tensor/test_tp_sharding_ops.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import ( + DeviceMesh, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class TPShardingOpsTest(DTensorTestBase): + @property + def world_size(self) -> int: + return 4 + + @with_comms + def test_sharded_view(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(0) + tensor = torch.rand(16, 35, 26) + sharding = [Shard(0)] + st = distribute_tensor(tensor, device_mesh, sharding).view(8, 4, 35, 13) + st_new = distribute_tensor(tensor.view(8, 4, 35, 13), device_mesh, sharding) + self.assertEqual(st.to_local(), st_new.to_local()) + self.assertEqual(st.placements[0], st_new.placements[0]) + + @with_comms + def test_sharded_transpose(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.transpose(0, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=2)) + self.assertEqual(new_dt.to_local(), tensor.transpose(0, 2)) + new_dt = dist_tensor.transpose(1, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=0)) + self.assertEqual(new_dt.to_local(), tensor.transpose(1, 2)) + + @with_comms + def test_sharded_permute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.permute(1, 0, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=1)) + self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2)) + + @with_comms + def test_replicated_permute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(0) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Replicate()] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.permute(1, 0, 2) + self.assertTrue(new_dt.placements[0].is_replicate()) + self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2)) + self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride()) + + @with_comms + def test_sharded_cat(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor_1 = torch.rand(3, 5, 6) + tensor_2 = torch.rand(3, 5, 6) + tensor_3 = torch.rand(3, 5, 6) + sharding = [Shard(0)] + dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding) + dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding) + dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding) + new_dt = torch.cat([dt_1, dt_2, dt_3]) + cat_dt = DTensor.from_local( + torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding + ) + self.assertEqual(new_dt.to_local(), cat_dt.to_local()) + self.assertEqual(new_dt.size(), cat_dt.size()) + + @with_comms + def test_sharded_split(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(2)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + dt_list = dist_tensor.split(dist_tensor.size(-1) // 2, dim=-1) + local_tensors = tensor.split(3, dim=-1) + for idx, dt in enumerate(dt_list): + self.assertTrue(dt.placements[0].is_shard(dim=2)) + self.assertEqual(dt.to_local(), local_tensors[idx]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py new file mode 100644 index 0000000000000..fa502d2b56031 --- /dev/null +++ b/test/distributed/_tensor/test_view_ops.py @@ -0,0 +1,464 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import itertools +from typing import cast, List + +import torch +import torch.distributed as dist +from torch import rand, randn, Tensor +from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed._tensor.ops.view_ops import ( + Broadcast, + Flatten, + InputDim, + ops, + Repeat, + Singleton, + Split, + view_groups, +) +from torch.distributed._tensor.placement_types import Placement +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + redistribute_profiler, + with_comms, +) +from torch.utils._pytree import tree_flatten + + +class TestViewOps(DTensorTestBase): + def test_view_groups(self): + self.assertEquals( + view_groups([2, 3], [3, 2]), + ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ), + ) + self.assertEquals( + view_groups([3, 4, 5], [12, 5]), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + self.assertEquals( + view_groups([2, 3, 4, 5, 7], [12, 70]), + ( + Split( + Flatten( + ( + InputDim(0), + InputDim(1), + InputDim(2), + InputDim(3), + InputDim(4), + ) + ), + (12, 70), + 0, + ), + Split( + Flatten( + ( + InputDim(0), + InputDim(1), + InputDim(2), + InputDim(3), + InputDim(4), + ) + ), + (12, 70), + 1, + ), + ), + ) + self.assertEquals( + view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]), + ( + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0), + Split(Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1), + Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0), + Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1), + ), + ) + self.assertEquals( + view_groups([3, 4, 8, 3], [12, 4, 2, 3]), + ( + Flatten((InputDim(0), InputDim(1))), + Split(InputDim(2), (4, 2), 0), + Split(InputDim(2), (4, 2), 1), + InputDim(3), + ), + ) + self.assertEquals( + view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]), + ( + Singleton(), + InputDim(0), + Split(InputDim(1), (2, 4, 3), 0), + Split(InputDim(1), (2, 4, 3), 1), + Singleton(), + Split(InputDim(1), (2, 4, 3), 2), + Singleton(), + ), + ) + self.assertEquals( + view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]), + ( + Flatten((InputDim(2), InputDim(3))), + Singleton(), + Singleton(), + Singleton(), + ), + ) + self.assertEquals( + view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]), + ( + Split(InputDim(2), (3, 4), 0), + Split(InputDim(2), (3, 4), 1), + Singleton(), + Flatten((InputDim(6), InputDim(7))), + ), + ) + self.assertEquals( + view_groups([2, 3, 4], [2, -1, 4]), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + @property + def world_size(self) -> int: + return 6 + + def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): + spec = ops[op] + rules = spec.dim_map(*args, **kwargs) + outputs = op(*args, **kwargs) + flat_args, _ = tree_flatten(args) + in_shape = flat_args[0].shape + + no_shard_dims = set() + for rule in rules: + if isinstance(rule, Repeat): + if isinstance(rule.input_dim, InputDim): + no_shard_dims.add(rule.input_dim.input_dim) + elif isinstance(rule, Flatten): + for dim in rule.input_dims[1:]: + if isinstance(dim, InputDim): + no_shard_dims.add(dim.input_dim) + elif isinstance(rule, Split): + if isinstance(rule.input_dim, Flatten): + for dim in rule.input_dim.input_dims[1:]: + if isinstance(dim, InputDim): + no_shard_dims.add(dim.input_dim) + + if op == torch.unbind: + no_shard_dims.add(kwargs.get("dim", 0)) + + sharding_choices = cast(List[Placement], [Replicate()]) + [ + Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims + ] + + all_sharding_choices = itertools.product( + *(device_mesh.ndim * [sharding_choices]) + ) + + for in_shard in all_sharding_choices: + # print(f' |--- {in_shard}') + in_dt = distribute_tensor(args[0], device_mesh, in_shard) + + with redistribute_profiler() as profiler: + out_dt = op(in_dt, *args[1:], **kwargs) + + self.assertEqual(profiler.num_calls, 0, "Expected no redistribution.") + + full_out = out_dt.redistribute( + device_mesh, device_mesh.ndim * [Replicate()] + ).to_local() + + if dist.get_rank() == 0: + self.assertEqual(outputs, full_out) + + def dimmap_test(self, op, args, expected_rule_output): + rules = ops[op].dim_map(*args) + self.assertEquals(rules, expected_rule_output) + self.call_dt_test(op, args, {}, self.device_mesh) + + @with_comms + def test_view_ops(self): + self.device_mesh = DeviceMesh( + self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) + ) + self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) + self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) + self.dimmap_test(torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1))) + + self.dimmap_test(torch.atleast_2d, (randn(()),), (Singleton(), Singleton())) + self.dimmap_test(torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0))) + self.dimmap_test(torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1))) + self.dimmap_test( + torch.atleast_2d, + (randn(24, 36, 48),), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + torch.atleast_3d, + (randn(()),), + (Singleton(), Singleton(), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24),), + (Singleton(), InputDim(0), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36),), + (InputDim(0), InputDim(1), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36, 42),), + (InputDim(0), InputDim(1), InputDim(2)), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36, 42, 24),), + (InputDim(0), InputDim(1), InputDim(2), InputDim(3)), + ) + + with self.assertRaises(AssertionError): + ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4)) + + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (1, 24, 36)), + (Singleton(), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (42, 24, 36)), + (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 1, 36), (12, 24, 24, 36)), + ( + Broadcast(Singleton(), 12), + InputDim(0), + Broadcast(InputDim(1), 24), + InputDim(2), + ), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (-1, 36)), + (InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 1, 36), (-1, 1, 36)), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + torch.broadcast_to, + (randn(36, 1, 24), (12, 36, 42, 24)), + ( + Broadcast(Singleton(), 12), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + ), + ) + + self.dimmap_test( + Tensor.expand, + (randn(24, 1, 36, 1), 36, 24, 42, -1, 24), + ( + Broadcast(Singleton(), 36), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + Broadcast(InputDim(3), 24), + ), + ) + + self.dimmap_test( + Tensor.expand, + (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)), + ( + Broadcast(Singleton(), 36), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + Broadcast(InputDim(3), 24), + ), + ) + + self.dimmap_test( + torch.flatten, + (randn(24, 36),), + (Flatten((InputDim(0), InputDim(1))),), + ) + self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),)) + self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),)) + + self.dimmap_test( + torch.movedim, + (randn(12, 24, 48, 96), 1, 2), + (InputDim(0), InputDim(2), InputDim(1), InputDim(3)), + ) + self.dimmap_test( + torch.movedim, + (randn(6, 12, 24), 1, 0), + (InputDim(1), InputDim(0), InputDim(2)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 12, 6), (1, 2), (0, 1)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 12), (1, 0), (0, 1)), + (InputDim(1), InputDim(0)), + ) + + self.dimmap_test( + torch.movedim, + (randn(36, 24, 12), (1, 2), (0, 1)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(36, 24, 12), (1, 2), (-3, -2)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + + self.dimmap_test( + torch.permute, + (randn(24, 36, 42), (2, 0, 1)), + (InputDim(2), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.permute, + (randn(24, 36, 42), (-1, -3, -2)), + (InputDim(2), InputDim(0), InputDim(1)), + ) + + self.dimmap_test( + torch.ravel, + (randn(24, 36),), + (Flatten((InputDim(0), InputDim(1))),), + ) + self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),)) + self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),)) + + self.dimmap_test( + Tensor.repeat, + (randn(24, 36), 1, 2, 1, 1, 2), + ( + Singleton(), + Broadcast(Singleton(), 2), + Singleton(), + InputDim(0), + Repeat(InputDim(1), 2), + ), + ) + + self.dimmap_test( + torch.reshape, + (randn(6, 12, 24), (72, 24)), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + + self.dimmap_test( + torch.tile, + (randn(24, 36), (1, 2, 1, 1, 2)), + ( + Singleton(), + Broadcast(Singleton(), 2), + Singleton(), + InputDim(0), + Repeat(InputDim(1), 2), + ), + ) + self.dimmap_test( + torch.tile, + (randn(42, 24, 36), (1, 3)), + (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)), + ) + + self.dimmap_test( + torch.transpose, + (randn(24, 60, 42, 60), 2, 0), + (InputDim(2), InputDim(1), InputDim(0), InputDim(3)), + ) + self.dimmap_test( + torch.transpose, + (randn(24, 60, 42, 60), -1, 0), + (InputDim(3), InputDim(1), InputDim(2), InputDim(0)), + ) + + self.dimmap_test( + torch.unsqueeze, + (randn(42, 24, 36), 1), + (InputDim(0), Singleton(), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + Tensor.view, + (randn(6, 12, 24), 72, 24), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + + self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),)) + + self.dimmap_test( + Tensor.view, + (randn(1, 1, 42, 24), -1), + (Flatten((InputDim(2), InputDim(3))),), + ) + + self.dimmap_test( + Tensor.view, + (randn(1, 1, 42, 1, 24, 1), -1), + (Flatten((InputDim(2), InputDim(4))),), + ) + + self.dimmap_test( + Tensor.view, + (randn(48, 35, 26), (24, 4, 35, 13)), + ( + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=0, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=1, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=2, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=3, + ), + ), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tools/test_memory_tracker.py b/test/distributed/_tools/test_memory_tracker.py new file mode 100644 index 0000000000000..2e19ef6bf7294 --- /dev/null +++ b/test/distributed/_tools/test_memory_tracker.py @@ -0,0 +1,67 @@ +# Owner(s): ["oncall: distributed"] + +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +import torch +import torch.nn as nn + +from torch.distributed._tools import MemoryTracker + +import unittest + + +class TestMemoryTracker(TestCase): + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_local_model(self): + """ + Minimal test case to check the memory tracker can collect the expected + memory stats at operator level, as well as can print the summary result + without crash. + """ + # Create a model with a hierarchy of modules + torch.manual_seed(0) + model = nn.Sequential( + nn.Sequential( + nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=False), + nn.AdaptiveAvgPool2d(output_size=(1, 1)), + ), + nn.Flatten(start_dim=1), + nn.Sequential(nn.Linear(64, 2), nn.ReLU(inplace=True)), + ).cuda() + + # Run one iteration of forward and backward pass + tracker = MemoryTracker() + tracker.start_monitor(model) + + x = torch.randn(size=(2, 3, 224, 224), device=torch.device("cuda")) + # torch.LongTensor expects cpu device type, not cuda device type in + # constructor, so calling .cuda() outside constructor here. + target = torch.LongTensor([0, 1]).cuda() + criterion = nn.CrossEntropyLoss() + criterion(model(x), target).backward() + + self.assertTrue(len(tracker._hooks) > 0) + + tracker.stop() + + self.assertTrue(len(tracker._hooks) == 0) + + tracker.summary() + + self.assertTrue(tracker._op_index > 0) + self.assertTrue(len(tracker._operator_names) > 0) + self.assertEqual(len(tracker.memories_allocated), tracker._op_index) + self.assertEqual(len(tracker.memories_active), tracker._op_index) + self.assertEqual(len(tracker.memories_reserved), tracker._op_index) + self.assertTrue(len(tracker._markers) == 2) + self.assertTrue(tracker._cur_module_name != "") + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index ead934eb83e73..d3ea932b05fca 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -124,7 +124,7 @@ def test_ddp_comm_hook_allreduce_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE) - torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -141,7 +141,7 @@ def test_ddp_comm_hook_fp16compress_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS) - torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -158,7 +158,7 @@ def test_ddp_comm_hook_quantize_per_tensor_hook(self): # Register hook case, get the hook grads. hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR) - torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) @requires_nccl() @skip_if_lt_x_gpu(2) @@ -177,7 +177,7 @@ def test_ddp_comm_hook_quantize_per_channel_hook(self): process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL ) - torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) @requires_nccl() @@ -198,7 +198,7 @@ def test_ddp_comm_hook_noop_hook(self): hook_grads.div_(self.world_size) dist.all_reduce(hook_grads, group=process_group) - torch.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) + torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0) @requires_nccl() @skip_if_lt_x_gpu(2) diff --git a/test/distributed/_shard/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py similarity index 79% rename from test/distributed/_shard/checkpoint/test_checkpoint.py rename to test/distributed/checkpoint/test_checkpoint.py index 1b3cf04eb2ccf..96c98116328c4 100644 --- a/test/distributed/_shard/checkpoint/test_checkpoint.py +++ b/test/distributed/checkpoint/test_checkpoint.py @@ -2,9 +2,9 @@ import sys from typing import Optional, List, cast -from torch.distributed._shard.checkpoint.storage import WriteResult +from torch.distributed.checkpoint.storage import WriteResult -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( StorageReader, StorageWriter, CheckpointException, @@ -20,17 +20,17 @@ from torch.distributed._shard import sharded_tensor -from torch.distributed._shard.checkpoint.default_planner import ( +from torch.distributed.checkpoint.default_planner import ( _create_default_local_metadata, ) -from torch.distributed._shard.checkpoint.metadata import ( +from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, Metadata, TensorStorageMetadata, ) -from torch.distributed._shard.checkpoint.planner import ( +from torch.distributed.checkpoint.planner import ( SavePlan, SavePlanner, LoadPlan, @@ -63,6 +63,7 @@ ) sys.exit(0) + class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -121,34 +122,44 @@ def test_default_metadata(self) -> None: ) state_dict = { - 'sharded': sharded_tensor.rand(spec, (10, 10, )), - 'replicated': torch.rand(4, device=device), - 'bytes': [1, 2, 3, 4], + "sharded": sharded_tensor.rand( + spec, + ( + 10, + 10, + ), + ), + "replicated": torch.rand(4, device=device), + "bytes": [1, 2, 3, 4], } metadata = _create_default_local_metadata(state_dict) - self.assertTrue('bytes' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata) + self.assertTrue("bytes" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["bytes"], BytesStorageMetadata + ) - self.assertTrue('replicated' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata) - md = metadata.state_dict_metadata['replicated'] - self.assertEqual(md.size, state_dict['replicated'].size()) + self.assertTrue("replicated" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["replicated"], TensorStorageMetadata + ) + md = metadata.state_dict_metadata["replicated"] + self.assertEqual(md.size, state_dict["replicated"].size()) self.assertEqual(md.properties.dtype, torch.float32) self.assertEqual(1, len(md.chunks)) - self.assertTrue('sharded' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata) - md = metadata.state_dict_metadata['sharded'] + self.assertTrue("sharded" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["sharded"], TensorStorageMetadata + ) + md = metadata.state_dict_metadata["sharded"] self.assertEqual(md.properties.dtype, torch.float32) - self.assertEqual(md.size, state_dict['sharded'].size()) + self.assertEqual(md.size, state_dict["sharded"].size()) self.assertEqual(2, len(md.chunks)) + class TestStorageBase: - def __init__( - self, - fail_conf - ): + def __init__(self, fail_conf): self.fail_conf = fail_conf self.rank = 0 if not dist.is_initialized() else dist.get_rank() @@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None): ranks = self._get_ranks(name) fut = Future() if ranks is not None and self.rank in ranks: - fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}")) + fut.set_exception( + ValueError(f"async rank fail {self.rank} for {name}") + ) else: fut.set_result(result) return fut + class FaultyStorageWriter(TestStorageBase, StorageWriter): - def __init__( - self, - fail_conf - ): + def __init__(self, fail_conf): super(FaultyStorageWriter, self).__init__(fail_conf) def init(self, is_coordinator: bool) -> None: @@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: return plans def write_data( - self, - plan: SavePlan, - planner: SavePlanner + self, plan: SavePlan, planner: SavePlanner ) -> Future[List[WriteResult]]: self._fail_rank("fail_write_data") return self._fail_rank_async("fail_write_data_async", []) - def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + def finish( + self, metadata: Metadata, results: List[List[WriteResult]] + ) -> None: self._fail_rank("fail_finish") class FaultyStorageReader(TestStorageBase, StorageReader): - def __init__( - self, - metadata, - fail_conf - ): + def __init__(self, metadata, fail_conf): super(FaultyStorageReader, self).__init__(fail_conf) self.metadata = metadata @@ -219,11 +226,7 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: self._fail_rank("fail_prepare_global_plan") return plans - def read_data( - self, - plan: LoadPlan, - planner: LoadPlanner - ) -> Future[None]: + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: self._fail_rank("fail_read_data") return self._fail_rank_async("fail_read_data_async") @@ -231,13 +234,14 @@ def read_metadata(self) -> Metadata: self._fail_rank("fail_read_metadata") return self.metadata + class TestDistributedFailure(ShardedTensorTestBase): def get_spec(self): return ChunkShardingSpec( dim=0, placements=[ f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size()) - ] + ], ) @with_comms(init_rpc=False) @@ -245,9 +249,9 @@ def get_spec(self): @requires_nccl() def test_dummy_writer_works(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } save_state_dict(state_dict, FaultyStorageWriter({})) @@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None: @requires_nccl() def test_dummy_reader_works(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } metadata = _create_default_local_metadata(state_dict) @@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs): failed_ranks = e.failures.keys() for rank in bad_ranks: - self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail was fine") - + self.assertTrue( + rank in failed_ranks, + msg=f"{rank} was supposed to fail was fine", + ) def _test_save(self, state_dict, coordinator=0, **kwargs): no_dist = not dist.is_initialized() @@ -296,6 +302,7 @@ def _save(): coordinator_rank=coordinator, no_dist=no_dist, ) + self._test_dist_failure(_save, kwargs) def _test_load(self, state_dict, coordinator=0, **kwargs): @@ -317,9 +324,9 @@ def _load(): @requires_nccl() def test_save_error_handling(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } self._test_save(state_dict, fail_init=[0]) @@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None: self._test_save(state_dict, coordinator=1, fail_finish=[1]) def test_save_error_handling_no_dist(self) -> None: - state_dict = { - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] - } + state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]} self.assertFalse(dist.is_initialized()) @@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None: @requires_nccl() def test_load_error_handling(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } self._test_load(state_dict) @@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None: self._test_load(state_dict, coordinator=3, fail_read_data_async=[2]) self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1]) - def test_load_error_handling_no_dist(self) -> None: - state_dict = { - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] - } + state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]} self._test_load(state_dict) self._test_load(state_dict, fail_init=[0]) self._test_load(state_dict, fail_read_metadata=[0]) @@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None: self._test_load(state_dict, fail_read_data=[0]) self._test_load(state_dict, fail_read_data_async=[0]) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_dedup_tensors.py b/test/distributed/checkpoint/test_dedup_tensors.py new file mode 100644 index 0000000000000..6f2b81c298df7 --- /dev/null +++ b/test/distributed/checkpoint/test_dedup_tensors.py @@ -0,0 +1,45 @@ +# Owner(s): ["oncall: distributed"] + +import dataclasses +import torch +from torch.distributed.checkpoint._dedup_tensors import dedup_tensors +from torch.distributed.checkpoint.planner import SavePlan, WriteItemType +from torch.distributed.checkpoint.planner_helpers import ( + _create_write_item_for_tensor, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +# TODO: add comments for create_plan +def create_plan(second_fqn) -> SavePlan: + # the first write item is for a duplicated shard (that covers the whole tensor) + write_item_1 = _create_write_item_for_tensor("tensor_0", torch.rand(4)) + write_item_1 = dataclasses.replace(write_item_1, type=WriteItemType.SHARD) + + # the second write item has different keys + write_item_2 = _create_write_item_for_tensor(second_fqn, torch.rand(10)) + + return SavePlan([write_item_1, write_item_2]) + + +# TODO: add comments for TestDedupTensor +class TestDedupTensor(TestCase): + def test_dedup_shards(self): + rank0 = create_plan("r0") + rank1 = create_plan("r1") + + dedup_plans = dedup_tensors([rank0, rank1]) + + self.assertEqual(2, len(dedup_plans[0].items)) + self.assertEqual(1, len(dedup_plans[1].items)) + + self.assertIn( + "tensor_0", (item.index.fqn for item in dedup_plans[0].items) + ) + self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items)) + + self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py similarity index 83% rename from test/distributed/_shard/checkpoint/test_file_system_checkpoint.py rename to test/distributed/checkpoint/test_file_system_checkpoint.py index b5cc38767c962..016467144e8ff 100644 --- a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py +++ b/test/distributed/checkpoint/test_file_system_checkpoint.py @@ -8,21 +8,27 @@ import torch import torch.distributed as dist from torch.distributed._shard import sharded_tensor -from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook +from torch.distributed._shard.sharded_tensor import ( + ShardedTensor, + state_dict_hook, +) from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, ShardingSpec, ShardMetadata, ) -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_nccl, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import TestCase from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( - MyShardedModel1 + MyShardedModel1, ) @@ -31,7 +37,7 @@ run_tests, ) -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, load_state_dict, @@ -73,7 +79,8 @@ def assert_state_dict_equal( ) elif isinstance(value_1, torch.Tensor): self.assertTrue( - torch.equal(value_1, value_2), f"Key {key}'s tensor does not match" + torch.equal(value_1, value_2), + f"Key {key}'s tensor does not match", ) return True @@ -105,35 +112,59 @@ def test_read_write_only_tensor(self) -> None: state_dict_to_save = MyTestModule().state_dict() fs_writer = FileSystemWriter(path=path) - save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True) + save_state_dict( + state_dict=state_dict_to_save, + storage_writer=fs_writer, + no_dist=True, + ) state_dict_to_load_to = MyTestModule().state_dict() with self.assertRaises(AssertionError): - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) # Load from file without any resharding fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True) + load_state_dict( + state_dict=state_dict_to_load_to, + storage_reader=fs_reader, + no_dist=True, + ) - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) with tempfile.TemporaryDirectory() as path: state_dict_to_save = MyTestModule().state_dict() fs_writer = FileSystemWriter(path=path, single_file_per_rank=True) - save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True) + save_state_dict( + state_dict=state_dict_to_save, + storage_writer=fs_writer, + no_dist=True, + ) state_dict_to_load_to = MyTestModule().state_dict() with self.assertRaises(AssertionError): - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) # Load from file without any resharding fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True) + load_state_dict( + state_dict=state_dict_to_load_to, + storage_reader=fs_reader, + no_dist=True, + ) - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): @@ -180,11 +211,15 @@ def test_read_write_shard_tensor(self) -> None: dist.barrier() with self.assertRaises(AssertionError): - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) # Test load. fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + load_state_dict( + state_dict=state_dict_to_load_to, storage_reader=fs_reader + ) assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) dist.barrier() @@ -201,7 +236,11 @@ def get_file_path(self) -> str: return paths[0] def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: - res = torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None + res = ( + torch.zeros(tensor.shape, device="cuda:0") + if dist.get_rank() == 0 + else None + ) tensor.gather(out=res) return res @@ -295,7 +334,9 @@ def test_load_with_different_shard_plan(self) -> None: state_dict_to_save = model_to_save.state_dict() fs_writer = FileSystemWriter(path=path) - save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + save_state_dict( + state_dict=state_dict_to_save, storage_writer=fs_writer + ) dist.barrier() @@ -316,7 +357,8 @@ def test_load_with_different_shard_plan(self) -> None: if dist.get_rank() == 0: self.assertTrue( - torch.allclose(store_tensor, load_tensor), msg=f"{s0} vs {s1}" + torch.allclose(store_tensor, load_tensor), + msg=f"{s0} vs {s1}", ) @with_comms(init_rpc=False) @@ -361,7 +403,9 @@ def test_load_rowwise_to_colwise(self) -> None: fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + load_state_dict( + state_dict=state_dict_to_load_to, storage_reader=fs_reader + ) # We can't use torch.allclose since each ST has a different sharding spec store_tensor = self.load_tensor(model_to_save.sharded_tensor) @@ -370,32 +414,24 @@ def test_load_rowwise_to_colwise(self) -> None: if dist.get_rank() == 0: self.assertTrue(torch.allclose(store_tensor, load_tensor)) - @with_comms(init_rpc=False) @skip_if_lt_x_gpu(2) @requires_nccl() def test_save_load_bytes(self) -> None: path = self.get_file_path() - state_dict_to_save = { - 'bytes0': [1], - 'bytes1': 'string' - } + state_dict_to_save = {"bytes0": [1], "bytes1": "string"} fs_writer = FileSystemWriter(path=path) save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) - state_dict_to_load = { - 'bytes0': [2], - 'bytes1': 'other' - } + state_dict_to_load = {"bytes0": [2], "bytes1": "other"} fs_reader = FileSystemReader(path=path) load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader) - self.assertEqual([1], state_dict_to_load['bytes0']) - self.assertEqual('string', state_dict_to_load['bytes1']) - + self.assertEqual([1], state_dict_to_load["bytes0"]) + self.assertEqual("string", state_dict_to_load["bytes1"]) @with_comms(init_rpc=False) @skip_if_lt_x_gpu(2) @@ -454,8 +490,8 @@ def test_switch_between_sharded_tensor_to_tensor(self) -> None: for save_spec in specs: for load_spec in specs: save_dict = { - 'sharded': sharded_tensor.rand(save_spec, tensor_size), - 'replicated': torch.rand(tensor_size, device=self.rank) + "sharded": sharded_tensor.rand(save_spec, tensor_size), + "replicated": torch.rand(tensor_size, device=self.rank), } fs_writer = FileSystemWriter(path=path) @@ -463,25 +499,28 @@ def test_switch_between_sharded_tensor_to_tensor(self) -> None: # Freaky Friday the tensors load_dict = { - 'sharded': torch.zeros(tensor_size, device=self.rank), - 'replicated': sharded_tensor.zeros(load_spec, tensor_size) + "sharded": torch.zeros(tensor_size, device=self.rank), + "replicated": sharded_tensor.zeros(load_spec, tensor_size), } fs_reader = FileSystemReader(path=path) load_state_dict(state_dict=load_dict, storage_reader=fs_reader) - save_dict_sharded = self.load_tensor(save_dict['sharded']) - load_dict_replicated = self.load_tensor(load_dict['replicated']) + save_dict_sharded = self.load_tensor(save_dict["sharded"]) + load_dict_replicated = self.load_tensor(load_dict["replicated"]) if dist.get_rank() == 0: self.assertTrue( - torch.allclose(save_dict_sharded, load_dict['sharded']), - f"save-spec {save_spec} load-spec {load_spec}" + torch.allclose(save_dict_sharded, load_dict["sharded"]), + f"save-spec {save_spec} load-spec {load_spec}", ) self.assertTrue( - torch.allclose(save_dict['replicated'], load_dict_replicated), - f"save-spec {save_spec} load-spec {load_spec}" + torch.allclose( + save_dict["replicated"], load_dict_replicated + ), + f"save-spec {save_spec} load-spec {load_spec}", ) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py similarity index 76% rename from test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py rename to test/distributed/checkpoint/test_file_system_checkpoint_cpu.py index 321dc2f546883..52e414545c049 100644 --- a/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py +++ b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py @@ -9,7 +9,10 @@ import torch import torch.distributed as dist from torch.distributed._shard import sharded_tensor -from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook +from torch.distributed._shard.sharded_tensor import ( + ShardedTensor, + state_dict_hook, +) from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, @@ -22,16 +25,18 @@ with_comms, ) from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( - MyShardedModel1 + MyShardedModel1, ) from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, load_state_dict, @@ -47,6 +52,9 @@ sys.exit(0) +_THREAD_COUNTS = {1, 2} + + def assert_state_dict_equal( self: TestCase, state_dict_1: Dict[str, torch.Tensor], @@ -73,7 +81,8 @@ def assert_state_dict_equal( ) elif isinstance(value_1, torch.Tensor): self.assertTrue( - torch.equal(value_1, value_2), f"Key {key}'s tensor does not match" + torch.equal(value_1, value_2), + f"Key {key}'s tensor does not match", ) return True @@ -100,23 +109,36 @@ def __init__( class TestDistributedStateDictSaveLoad(TestCase): - def test_read_write_only_tensor(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_read_write_only_tensor(self, thread_count) -> None: with tempfile.TemporaryDirectory() as path: state_dict_to_save = MyTestModule().state_dict() - fs_writer = FileSystemWriter(path=path) - save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True) + fs_writer = FileSystemWriter(path=path, thread_count=thread_count) + save_state_dict( + state_dict=state_dict_to_save, + storage_writer=fs_writer, + no_dist=True, + ) state_dict_to_load_to = MyTestModule().state_dict() with self.assertRaises(AssertionError): - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) # Load from file without any resharding fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True) + load_state_dict( + state_dict=state_dict_to_load_to, + storage_reader=fs_reader, + no_dist=True, + ) - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase): @@ -125,7 +147,8 @@ def world_size(self) -> int: return 2 @with_comms(init_rpc=False, backend="gloo") - def test_read_write_shard_tensor(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_read_write_shard_tensor(self, thread_count) -> None: paths = [tempfile.mkdtemp()] dist.broadcast_object_list(paths) @@ -146,7 +169,7 @@ def test_read_write_shard_tensor(self) -> None: model_to_save._register_state_dict_hook(state_dict_hook) state_dict_to_save = model_to_save.state_dict() - fs_writer = FileSystemWriter(path=path) + fs_writer = FileSystemWriter(path=path, thread_count=thread_count) save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) dist.barrier() @@ -161,11 +184,15 @@ def test_read_write_shard_tensor(self) -> None: dist.barrier() with self.assertRaises(AssertionError): - assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) + assert_state_dict_equal( + self, state_dict_to_load_to, state_dict_to_save + ) # Test load. fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + load_state_dict( + state_dict=state_dict_to_load_to, storage_reader=fs_reader + ) assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save) dist.barrier() @@ -182,12 +209,17 @@ def get_file_path(self) -> str: return paths[0] def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor: - res = torch.zeros(tensor.shape, device="cpu") if dist.get_rank() == 0 else None + res = ( + torch.zeros(tensor.shape, device="cpu") + if dist.get_rank() == 0 + else None + ) tensor.gather(out=res) return res @with_comms(init_rpc=False, backend="gloo") - def test_load_with_different_shard_plan(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_load_with_different_shard_plan(self, thread_count) -> None: path = self.get_file_path() # We hardcode the assumption of how many shards are around @@ -273,8 +305,12 @@ def test_load_with_different_shard_plan(self) -> None: model_to_save._register_state_dict_hook(state_dict_hook) state_dict_to_save = model_to_save.state_dict() - fs_writer = FileSystemWriter(path=path) - save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) + fs_writer = FileSystemWriter( + path=path, thread_count=thread_count + ) + save_state_dict( + state_dict=state_dict_to_save, storage_writer=fs_writer + ) dist.barrier() @@ -295,11 +331,13 @@ def test_load_with_different_shard_plan(self) -> None: if dist.get_rank() == 0: self.assertTrue( - torch.allclose(store_tensor, load_tensor), msg=f"{s0} vs {s1}" + torch.allclose(store_tensor, load_tensor), + msg=f"{s0} vs {s1}", ) @with_comms(init_rpc=False, backend="gloo") - def test_load_rowwise_to_colwise(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_load_rowwise_to_colwise(self, thread_count) -> None: path = self.get_file_path() self.assertEqual(self.world_size, dist.get_world_size()) @@ -329,7 +367,7 @@ def test_load_rowwise_to_colwise(self) -> None: model_to_save._register_state_dict_hook(state_dict_hook) state_dict_to_save = model_to_save.state_dict() - fs_writer = FileSystemWriter(path=path) + fs_writer = FileSystemWriter(path=path, thread_count=thread_count) save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank()) @@ -338,7 +376,9 @@ def test_load_rowwise_to_colwise(self) -> None: fs_reader = FileSystemReader(path=path) - load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader) + load_state_dict( + state_dict=state_dict_to_load_to, storage_reader=fs_reader + ) # We can't use torch.allclose since each ST has a different sharding spec store_tensor = self.load_tensor(model_to_save.sharded_tensor) @@ -347,33 +387,29 @@ def test_load_rowwise_to_colwise(self) -> None: if dist.get_rank() == 0: self.assertTrue(torch.allclose(store_tensor, load_tensor)) - @with_comms(init_rpc=False, backend="gloo") - def test_save_load_bytes(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_save_load_bytes(self, thread_count) -> None: path = self.get_file_path() - state_dict_to_save = { - 'bytes0': [1], - 'bytes1': 'string' - } + state_dict_to_save = {"bytes0": [1], "bytes1": "string"} - fs_writer = FileSystemWriter(path=path) + fs_writer = FileSystemWriter(path=path, thread_count=thread_count) save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer) - state_dict_to_load = { - 'bytes0': [2], - 'bytes1': 'other' - } + state_dict_to_load = {"bytes0": [2], "bytes1": "other"} fs_reader = FileSystemReader(path=path) load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader) - self.assertEqual([1], state_dict_to_load['bytes0']) - self.assertEqual('string', state_dict_to_load['bytes1']) - + self.assertEqual([1], state_dict_to_load["bytes0"]) + self.assertEqual("string", state_dict_to_load["bytes1"]) @with_comms(init_rpc=False, backend="gloo") - def test_switch_between_sharded_tensor_to_tensor(self) -> None: + @parametrize("thread_count", _THREAD_COUNTS) + def test_switch_between_sharded_tensor_to_tensor( + self, thread_count + ) -> None: path = self.get_file_path() tensor_size = 32 @@ -427,34 +463,47 @@ def test_switch_between_sharded_tensor_to_tensor(self) -> None: for save_spec in specs: for load_spec in specs: save_dict = { - 'sharded': sharded_tensor.rand(save_spec, tensor_size), - 'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}") + "sharded": sharded_tensor.rand(save_spec, tensor_size), + "replicated": torch.rand( + tensor_size, device=f"cpu:{self.rank}" + ), } - fs_writer = FileSystemWriter(path=path) + fs_writer = FileSystemWriter( + path=path, thread_count=thread_count + ) save_state_dict(state_dict=save_dict, storage_writer=fs_writer) # Freaky Friday the tensors load_dict = { - 'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"), - 'replicated': sharded_tensor.zeros(load_spec, tensor_size) + "sharded": torch.zeros( + tensor_size, device=f"cpu:{self.rank}" + ), + "replicated": sharded_tensor.zeros(load_spec, tensor_size), } fs_reader = FileSystemReader(path=path) load_state_dict(state_dict=load_dict, storage_reader=fs_reader) - save_dict_sharded = self.load_tensor(save_dict['sharded']) - load_dict_replicated = self.load_tensor(load_dict['replicated']) + save_dict_sharded = self.load_tensor(save_dict["sharded"]) + load_dict_replicated = self.load_tensor(load_dict["replicated"]) if dist.get_rank() == 0: self.assertTrue( - torch.allclose(save_dict_sharded, load_dict['sharded']), - f"save-spec {save_spec} load-spec {load_spec}" + torch.allclose(save_dict_sharded, load_dict["sharded"]), + f"save-spec {save_spec} load-spec {load_spec}", ) self.assertTrue( - torch.allclose(save_dict['replicated'], load_dict_replicated), - f"save-spec {save_spec} load-spec {load_spec}" + torch.allclose( + save_dict["replicated"], load_dict_replicated + ), + f"save-spec {save_spec} load-spec {load_spec}", ) + +instantiate_parametrized_tests(TestDistributedStateDictSaveLoad) +instantiate_parametrized_tests(TestDistributedStateDictSaveLoadWithSharedTensor) +instantiate_parametrized_tests(TestDistributedReshardOnLoad) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_fsdp_optim_state.py b/test/distributed/checkpoint/test_fsdp_optim_state.py new file mode 100644 index 0000000000000..173542ee91253 --- /dev/null +++ b/test/distributed/checkpoint/test_fsdp_optim_state.py @@ -0,0 +1,112 @@ +# Owner(s): ["oncall: distributed"] + +import torch + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +import torch.distributed.checkpoint as dist_cp +import torch.distributed as dist + +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) +from torch.distributed.checkpoint.optimizer import ( + load_sharded_optimizer_state_dict, +) + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir + + +class FsdpOptimStateCheckpoint(DTensorTestBase): + @with_comms + @skip_if_lt_x_gpu(4) + @with_temp_dir + def test_distributed_tensor_planner(self) -> None: + CHECKPOINT_DIR = self.temp_dir + + model = FSDP(torch.nn.Linear(8, 8, device="meta")) + optim = torch.optim.Adam(model.parameters(), lr=0.1) + + model(torch.rand(8, 8, device=dist.get_rank())).sum().backward() + optim.step() + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + state_dict = { + "model": model.state_dict(), + "optim": FSDP.sharded_optim_state_dict(model, optim), + } + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), + planner=DefaultSavePlanner( + flatten_state_dict=True, + flatten_sharded_tensors=True, + ), + ) + + # now load the model and ensure the values are the same + model_2 = FSDP(torch.nn.Linear(8, 8, device="meta")) + optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1) + + with FSDP.summon_full_params(model): + with FSDP.summon_full_params(model_2): + self.assertNotEqual(model.weight, model_2.weight) + self.assertNotEqual(model.bias, model_2.bias) + + # Adam lazily creates its state + self.assertEqual(0, len(optim_2.state)) + + with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT): + state_dict = { + "model": model_2.state_dict(), + # cannot load the optimizer together with the model + } + + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + planner=DefaultLoadPlanner( + flatten_state_dict=True, + flatten_sharded_tensors=True, + ), + ) + model_2.load_state_dict(state_dict["model"]) + + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=state_dict["model"], + optimizer_key="optim", + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + ) + + flattened_osd = FSDP.flatten_sharded_optim_state_dict( + optim_state["optim"], model_2 + ) + optim_2.load_state_dict(flattened_osd) + + with FSDP.summon_full_params(model): + with FSDP.summon_full_params(model_2): + self.assertEqual(model.weight, model_2.weight) + self.assertEqual(model.bias, model_2.bias) + + def opt_at(opt, idx): + return list(iter(opt.state.values()))[idx] + + # Adam lazily creates its state + self.assertEqual( + opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"] + ) + self.assertEqual( + opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"] + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_nested_dict.py b/test/distributed/checkpoint/test_nested_dict.py new file mode 100644 index 0000000000000..115982e818127 --- /dev/null +++ b/test/distributed/checkpoint/test_nested_dict.py @@ -0,0 +1,62 @@ +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.distributed.checkpoint._nested_dict import ( + flatten_state_dict, + unflatten_state_dict, +) + + +class TestFlattening(TestCase): + def test_flattening_round_trip(self) -> None: + state_dict = { + "key0": 1, + "key1": [1, 2], + "key2": {1: 2, 2: 3}, + "key3": torch.tensor([1]), + "key4": [[torch.tensor(2), "x"], [1, 2, 3], {"key6": [44]}], + } + + flatten_dict, mapping = flatten_state_dict(state_dict) + """ + flatten_dict: + { + 'key0': 1, + 'key1': [1, 2], + 'key2': {1: 2, 2: 3}, + 'key3': tensor([1]), + 'key4.0.0': tensor(2), + 'key4.0.1': 'x', + 'key4.1': [1, 2, 3], + 'key4.2': {'key6': [44]} + } + """ + restored = unflatten_state_dict(flatten_dict, mapping) + + self.assertEqual(state_dict, restored) + + def test_mapping(self) -> None: + state_dict = { + "k0": [1], + "k2": [torch.tensor([1]), 99, [{"k3": torch.tensor(1)}]], + "k3": ["x", 99, [{"k3": "y"}]], + } + + flatten_dict, mapping = flatten_state_dict(state_dict) + """ + flatten_dict: + {'k0': [1], 'k2.0': tensor([1]), 'k2.1': 99, 'k2.2.0.k3': tensor(1), 'k3': ['x', 99, [{'k3': 'y'}]]} + mapping: + {'k0': ('k0',), 'k2.0': ('k2', 0), 'k2.1': ('k2', 1), 'k2.2.0.k3': ('k2', 2, 0, 'k3'), 'k3': ('k3',)} + """ + + self.assertEqual(("k0",), mapping["k0"]) + self.assertEqual(("k2", 0), mapping["k2.0"]) + self.assertEqual(("k2", 1), mapping["k2.1"]) + self.assertEqual(("k2", 2, 0, "k3"), mapping["k2.2.0.k3"]) + self.assertEqual(("k3",), mapping["k3"]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_shard/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py similarity index 97% rename from test/distributed/_shard/checkpoint/test_planner.py rename to test/distributed/checkpoint/test_planner.py index 56373bd67c6d9..334fba237a9ba 100644 --- a/test/distributed/_shard/checkpoint/test_planner.py +++ b/test/distributed/checkpoint/test_planner.py @@ -3,7 +3,7 @@ import sys import torch -from torch.distributed._shard.checkpoint.planner import LoadItemType, WriteItemType +from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType from torch.distributed._shard.sharded_tensor import ( Shard, @@ -18,13 +18,13 @@ TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed._shard.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata from torch.testing._internal.distributed.distributed_utils import ( with_fake_comms, with_dist ) -from torch.distributed._shard.checkpoint.default_planner import ( +from torch.distributed.checkpoint.default_planner import ( create_default_global_save_plan, create_default_local_save_plan, create_default_local_load_plan, diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py new file mode 100644 index 0000000000000..3a47311e702bd --- /dev/null +++ b/test/distributed/checkpoint/test_traverse.py @@ -0,0 +1,176 @@ +# Owner(s): ["oncall: distributed"] + +from collections import OrderedDict +import torch + +import torch.distributed.checkpoint._traverse as _traverse +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.testing._internal.common_utils import run_tests, TestCase + + +# TODO: add comments for TestTraverse +class TestTraverse(TestCase): + def test_traverse_shallow(self) -> None: + state_dict = { + "key0": 1, + "key1": [1, 2], + "key2": {1: 2, 2: 3}, + "key3": torch.tensor([1]), + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + _traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0",), data) + self.assertEqual(data[("key0",)], 1) + + self.assertIn(("key1",), data) + self.assertEqual(data[("key1",)], [1, 2]) + + self.assertIn(("key2",), data) + self.assertEqual(data[("key2",)], {1: 2, 2: 3}) + + self.assertIn(("key3",), data) + self.assertEqual(data[("key3",)], torch.tensor([1])) + + def test_traverse_nested_list(self) -> None: + state_dict = { + "key1": [ + torch.tensor([1]), + [33, torch.tensor([2]), [44, 55]], + [66, 77], + ], + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + _traverse.traverse_state_dict(state_dict, collect_data) + + self.assertNotIn(("key1"), data) + + self.assertIn(("key1", 0), data) + self.assertEqual(data[("key1", 0)], torch.tensor([1])) + + self.assertIn(("key1", 1, 0), data) + self.assertEqual(data[("key1", 1, 0)], 33) + + self.assertIn(("key1", 1, 1), data) + self.assertEqual(data[("key1", 1, 1)], torch.tensor([2])) + + self.assertIn(("key1", 1, 2), data) + self.assertEqual(data[("key1", 1, 2)], [44, 55]) + self.assertNotIn(("key1", 1, 2, 0), data) + + self.assertIn(("key1", 2), data) + self.assertEqual(data[("key1", 2)], [66, 77]) + + def test_traverse_nested_dict(self) -> None: + state_dict = { + "key0": {"key1": 99, "key2": torch.tensor([1])}, + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + _traverse.traverse_state_dict(state_dict, collect_data) + + self.assertNotIn(("key0",), data) + + self.assertIn(("key0", "key1"), data) + self.assertEqual(data[("key0", "key1")], 99) + + self.assertIn(("key0", "key2"), data) + self.assertEqual(data[("key0", "key2")], torch.tensor([1])) + + def test_traverse_doesnt_ignore_intermediate_collections(self) -> None: + state_dict: STATE_DICT_TYPE = { + "key0": [{"key1": {"key2": torch.tensor([1])}}] + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + _traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0", 0, "key1", "key2"), data) + self.assertEqual( + data[("key0", 0, "key1", "key2")], + torch.tensor([1]), + ) + + def test_traverse_with_ordered_dict(self) -> None: + state_dict = OrderedDict( + { + "key0": [ + 99, + torch.tensor([3]), + ] + } + ) + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + _traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0", 0), data) + self.assertEqual(data[("key0", 0)], 99) + + self.assertIn(("key0", 1), data) + self.assertEqual(data[("key0", 1)], torch.tensor([3])) + + def test_set_element(self) -> None: + state_dict: STATE_DICT_TYPE = {} + + _traverse.set_element(state_dict, ("k",), 10) + self.assertEqual(state_dict["k"], 10) + + _traverse.set_element(state_dict, ("k1", 2), 1) + self.assertEqual(state_dict["k1"], [None, None, 1]) + + _traverse.set_element(state_dict, ("k1", 1), 99) + self.assertEqual(state_dict["k1"], [None, 99, 1]) + + _traverse.set_element(state_dict, ("k1", 3), 88) + self.assertEqual(state_dict["k1"], [None, 99, 1, 88]) + + _traverse.set_element(state_dict, ("k2", "k3"), 3) + self.assertEqual(state_dict["k2"], {"k3": 3}) + + _traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99) + self.assertEqual(state_dict["k2"]["k4"][0], [99]) + + def test_get_element(self) -> None: + state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]} + self.assertEqual(_traverse.get_element(state_dict, ("a",)), [0, 1]) + self.assertEqual(_traverse.get_element(state_dict, ("b", 0)), 2) + self.assertEqual(_traverse.get_element(state_dict, ("b", 1, "c")), "d") + + self.assertIsNone(_traverse.get_element(state_dict, ("c",))) + self.assertIsNone(_traverse.get_element(state_dict, ("a", 33))) + self.assertIsNone(_traverse.get_element(state_dict, ("b", 88))) + self.assertIsNone(_traverse.get_element(state_dict, ("b", 0, 2))) + self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, 2))) + self.assertIsNone(_traverse.get_element(state_dict, ("b", 1, "d"))) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_shard/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py similarity index 96% rename from test/distributed/_shard/checkpoint/test_utils.py rename to test/distributed/checkpoint/test_utils.py index e99a9cf863e4f..e2b4aac605bf6 100644 --- a/test/distributed/_shard/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -17,8 +17,8 @@ TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed._shard.checkpoint.utils import find_state_dict_object -from torch.distributed._shard.checkpoint.metadata import MetadataIndex +from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.checkpoint.metadata import MetadataIndex from torch.testing._internal.distributed.distributed_utils import ( with_fake_comms ) diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py index 8bd2b74695d3b..d8e005fcf82be 100644 --- a/test/distributed/fsdp/test_checkpoint_wrapper.py +++ b/test/distributed/fsdp/test_checkpoint_wrapper.py @@ -1,30 +1,25 @@ # Owner(s): ["oncall: distributed"] +import unittest from copy import deepcopy from functools import partial import torch import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - offload_wrapper, apply_activation_checkpointing, + checkpoint_wrapper, + CheckpointImpl, CheckpointWrapper, + offload_wrapper, OffloadWrapper, - CheckpointImpl ) - +from torch.testing._internal.common_utils import run_tests, TestCase from torch.utils.checkpoint import checkpoint -from torch.testing._internal.common_utils import ( - run_tests, - TestCase, -) - -import unittest +_SAVED_PREFIX = "_saved_" +GRAD_FN_NEXT_FUNCTIONS = "next_functions" -_SAVED_PREFIX = '_saved_' -GRAD_FN_NEXT_FUNCTIONS = 'next_functions' class CheckpointWrapperTest(TestCase): def setUp(self): @@ -66,13 +61,7 @@ def __init__(self): self.lin = nn.Linear(10, 10) def forward(self, a, b, c=None, d=None, **kwargs): - return ( - self.lin(a), - self.lin(b), - self.lin(c), - self.lin(d) - ) - + return (self.lin(a), self.lin(b), self.lin(c), self.lin(d)) for wrapper in [ partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT), @@ -113,7 +102,6 @@ def forward(self, *, a=None, b=None): out = model(a=inp, b=inp) self.assertEqual(2, len(out)) - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") def test_checkpoint_wrapper_parity(self): """ @@ -122,13 +110,14 @@ def test_checkpoint_wrapper_parity(self): results in the same maximum memory usage, i.e. they are equivalent memory usage wise. """ + class Model(nn.Module): def __init__( self, n: int, use_cp: bool, use_wrapper: bool = False, - use_reentrant: bool = True + use_reentrant: bool = True, ): super().__init__() self.layers = nn.ModuleList() @@ -138,10 +127,14 @@ def __init__( self.use_reentrant = use_reentrant wrp = partial( checkpoint_wrapper, - checkpoint_impl=CheckpointImpl.REENTRANT if use_reentrant else CheckpointImpl.NO_REENTRANT + checkpoint_impl=CheckpointImpl.REENTRANT + if use_reentrant + else CheckpointImpl.NO_REENTRANT, ) for i in range(self.n): - l = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)) + l = nn.Sequential( + nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256) + ) use_checkpoint_wrapper = self.use_wrapper if use_checkpoint_wrapper: l = wrp(l) @@ -149,29 +142,41 @@ def __init__( def forward(self, x): for i in range(self.n): - if ( - self.use_wrapper or - not self.use_cp - ): + if self.use_wrapper or not self.use_cp: x = self.layers[i](x) else: - x = checkpoint(self.layers[i], x, use_reentrant=self.use_reentrant) + x = checkpoint( + self.layers[i], x, use_reentrant=self.use_reentrant + ) return x def test(use_checkpointing, use_wrapper, use_reentrant): - a = Model(8, use_checkpointing, use_wrapper=use_wrapper, use_reentrant=use_reentrant).cuda() + a = Model( + 8, + use_checkpointing, + use_wrapper=use_wrapper, + use_reentrant=use_reentrant, + ).cuda() x = torch.randn(10000, 256, requires_grad=True).cuda() torch.cuda.reset_peak_memory_stats() loss = a(x).sum() loss.backward() return torch.cuda.max_memory_allocated() - functional_no_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=False) - wrapper_no_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=False) + functional_no_reentrant = test( + use_checkpointing=True, use_wrapper=False, use_reentrant=False + ) + wrapper_no_reentrant = test( + use_checkpointing=False, use_wrapper=True, use_reentrant=False + ) self.assertEqual(functional_no_reentrant, wrapper_no_reentrant) - functional_reentrant = test(use_checkpointing=True, use_wrapper=False, use_reentrant=True) - wrapper_reentrant = test(use_checkpointing=False, use_wrapper=True, use_reentrant=True) + functional_reentrant = test( + use_checkpointing=True, use_wrapper=False, use_reentrant=True + ) + wrapper_reentrant = test( + use_checkpointing=False, use_wrapper=True, use_reentrant=True + ) self.assertEqual(functional_reentrant, wrapper_reentrant) def test_forward_missing_attributes(self): @@ -181,8 +186,8 @@ def test_forward_missing_attributes(self): # Test indexing is forwarded self.assertEqual(wrapped[0], lin) # Test missing attributes are forwarded. - m._foo = 'bar' - self.assertEqual(wrapped._foo, 'bar') + m._foo = "bar" + self.assertEqual(wrapped._foo, "bar") def test_apply_activation_checkpointing(self): """ @@ -190,6 +195,7 @@ def test_apply_activation_checkpointing(self): to swap modules for their checkpoint-wrapped counterparts given a model. """ + class LinearWithBatchNorm(nn.Module): def __init__(self): super().__init__() @@ -210,7 +216,6 @@ def __init__(self): def forward(self, x): return self.seq(x) - def check_fn(l): return isinstance(l, nn.Linear) @@ -231,13 +236,27 @@ def check_fn(l): apply_activation_checkpointing( model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn ) - n_linear_wrapped = sum(1 if isinstance(x, nn.Linear) else 0 for x in model.modules()) - n_checkpointed = sum(1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 for x in model.modules()) + n_linear_wrapped = sum( + 1 if isinstance(x, nn.Linear) else 0 for x in model.modules() + ) + n_checkpointed = sum( + 1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 + for x in model.modules() + ) self.assertEqual(n_checkpointed, n_linear_wrapped) self.assertEqual(n_linear, n_linear_wrapped) for j in range(3): - self.assertTrue(isinstance(model.seq[j].lin, (CheckpointWrapper, OffloadWrapper))) - self.assertTrue(isinstance(model.seq[j].nested_linear[0], (CheckpointWrapper, OffloadWrapper))) + self.assertTrue( + isinstance( + model.seq[j].lin, (CheckpointWrapper, OffloadWrapper) + ) + ) + self.assertTrue( + isinstance( + model.seq[j].nested_linear[0], + (CheckpointWrapper, OffloadWrapper), + ) + ) inp = torch.randn(4, 10, requires_grad=True) for i in range(6): @@ -249,9 +268,22 @@ def check_fn(l): for j in range(3): weight_lin = model.seq[j].lin._checkpoint_wrapped_module.weight bias_lin = model.seq[j].lin._checkpoint_wrapped_module.bias - weight_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.weight - bias_nested_lin = model.seq[j].nested_linear[0]._checkpoint_wrapped_module.bias - for param in [weight_lin, bias_lin, weight_nested_lin, bias_nested_lin]: + weight_nested_lin = ( + model.seq[j] + .nested_linear[0] + ._checkpoint_wrapped_module.weight + ) + bias_nested_lin = ( + model.seq[j] + .nested_linear[0] + ._checkpoint_wrapped_module.bias + ) + for param in [ + weight_lin, + bias_lin, + weight_nested_lin, + bias_nested_lin, + ]: self.assertTrue(param.requires_grad) self.assertFalse(param.grad is None) @@ -287,7 +319,7 @@ def testing_cpu_offload_unpack_hook(packed): model = offload_wrapper(model) - inp = torch.randn(3, 10, device='cuda') + inp = torch.randn(3, 10, device="cuda") loss = model(inp).sum() # All autograd saved tensors should be offloaded to CPU. @@ -314,5 +346,6 @@ def dfs(grad_fn): torch.autograd.graph.saved_tensors_hooks.__init__ = orig_init + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index ef95973764c43..3e9b967e0d114 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -5,23 +5,17 @@ import torch from torch import distributed as dist -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, - save_state_dict, load_state_dict, + save_state_dict, ) -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, -) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap, wrap from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, - SkipModel, -) +from torch.testing._internal.common_fsdp import FSDPTest, SkipModel from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -29,7 +23,6 @@ TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -75,16 +68,16 @@ def test_distributed_checkpoint(self, state_dict_type) -> None: path = paths[0] writer = FileSystemWriter(path) reader = FileSystemReader(path) - with FSDP.state_dict_type( - model, state_dict_type - ), FSDP.state_dict_type(new_model, state_dict_type): + with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( + new_model, state_dict_type + ): state_dict = model.state_dict() save_state_dict(state_dict, writer) - with FSDP.state_dict_type( - model, state_dict_type - ), FSDP.state_dict_type(new_model, state_dict_type): + with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( + new_model, state_dict_type + ): state_dict = new_model.state_dict() load_state_dict(state_dict, reader) new_model.load_state_dict(state_dict) diff --git a/test/distributed/fsdp/test_fsdp_apply.py b/test/distributed/fsdp/test_fsdp_apply.py index d72d57d133b0d..d44239a329344 100644 --- a/test/distributed/fsdp/test_fsdp_apply.py +++ b/test/distributed/fsdp/test_fsdp_apply.py @@ -14,10 +14,7 @@ NestedWrappedModule, TransformerWithSharedParams, ) -from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, -) +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) diff --git a/test/distributed/fsdp/test_fsdp_checkpoint.py b/test/distributed/fsdp/test_fsdp_checkpoint.py index 14456df92f84f..994f591ec5e7d 100644 --- a/test/distributed/fsdp/test_fsdp_checkpoint.py +++ b/test/distributed/fsdp/test_fsdp_checkpoint.py @@ -1,37 +1,52 @@ # Owner(s): ["oncall: distributed"] import contextlib +import sys from copy import deepcopy from functools import partial import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel as FSDP, - CPUOffload, -) from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, offload_wrapper, ) -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_fsdp import ( - FSDPTest, - _maybe_wrap_fsdp, +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + CPUOffload, + FullyShardedDataParallel as FSDP, ) + +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest from torch.testing._internal.common_utils import ( - run_tests, - parametrize, instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_WITH_DEV_DBG_ASAN, ) from torch.utils.checkpoint import checkpoint +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + + _save_on_cpu_called = False + + def get_patched_save_on_cpu(): - orig_save_on_cpu = torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu + orig_save_on_cpu = ( + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu + ) def patched_save_on_cpu(*args, **kwargs): global _save_on_cpu_called @@ -40,14 +55,22 @@ def patched_save_on_cpu(*args, **kwargs): return patched_save_on_cpu + @contextlib.contextmanager def patch_save_on_cpu(new_save_on_cpu): - orig_save_on_cpu = torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu - torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = new_save_on_cpu + orig_save_on_cpu = ( + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu + ) + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = ( + new_save_on_cpu + ) try: yield finally: - torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = orig_save_on_cpu + torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = ( + orig_save_on_cpu + ) + class TestFSDPCheckpoint(FSDPTest): class SequentialModule(nn.Module): @@ -111,16 +134,24 @@ def _verify_parity(self, losses, outputs, models): [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) @parametrize("offload_activations", [True, False]) - def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations): + @parametrize("use_orig_params", [False, True]) + def test_checkpoint_fsdp_wrapping( + self, + cpu_offload: CPUOffload, + offload_activations: bool, + use_orig_params: bool, + ): # Test checkpoint(FSDP(layer1), FSDP(layer2), ....) if offload_activations: wrapper_to_use = offload_wrapper else: wrapper_to_use = checkpoint_wrapper + fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params} ckpt_sequential_wrapped_fsdp = wrapper_to_use( TestFSDPCheckpoint.SequentialModule( - wrap_fsdp=True, cpu_offload=cpu_offload + wrap_fsdp=True, + **fsdp_kwargs, ), ) # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), .... @@ -128,11 +159,12 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations): checkpoint_layer=True, offload_activations=offload_activations, wrap_fsdp=True, - cpu_offload=cpu_offload, + **fsdp_kwargs, ) baseline = TestFSDPCheckpoint.SequentialModule( - wrap_fsdp=True, cpu_offload=cpu_offload + wrap_fsdp=True, + **fsdp_kwargs, ) # note that reentrant-based checkpointing requires inputs to have grad @@ -168,12 +200,19 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations): [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) @parametrize("offload_activations", [True, False]) - def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): + @parametrize("use_orig_params", [False, True]) + def test_basic_checkpoint_end_to_end( + self, + cpu_offload: CPUOffload, + offload_activations: bool, + use_orig_params: bool, + ): + fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params} global _save_on_cpu_called with patch_save_on_cpu(get_patched_save_on_cpu()): seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device()) # Runs FSDP with no checkpointing - fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload) + fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs) # Runs checkpoint-wrapped FSDP if offload_activations: wrapper_to_use = offload_wrapper @@ -181,19 +220,21 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): wrapper_to_use = checkpoint_wrapper checkpointed_fsdp = wrapper_to_use( - FSDP(deepcopy(seq), cpu_offload=cpu_offload), + FSDP(deepcopy(seq), **fsdp_kwargs), ) # Runs FSDP-wrapped checkpointed module fsdp_wrapped_checkpoint = FSDP( wrapper_to_use(deepcopy(seq)), - cpu_offload=cpu_offload, + **fsdp_kwargs, ) # Runs FSDP with manual calls to checkpoint. - fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload) + fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs) # note that reentrant-based checkpointing requires inputs to have grad # flag set. - inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True) + inp = torch.randn( + 10, 3, device=torch.cuda.current_device(), requires_grad=True + ) models = [ fsdp_only_seq, @@ -207,7 +248,9 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): losses = [] outputs = [] for m in models: - check_offload = m != fsdp_only_seq and i == 0 and offload_activations + check_offload = ( + m != fsdp_only_seq and i == 0 and offload_activations + ) if m == fsdp_call_checkpoint: # _save_on_cpu should not be called yet self.assertFalse(_save_on_cpu_called) @@ -235,7 +278,91 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): dist.barrier() + instantiate_parametrized_tests(TestFSDPCheckpoint) + +class CheckpointModule(nn.Module): + def __init__(self, checkpoint: bool = False, use_reentrant: bool = True): + super().__init__() + self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)]) + self.checkpoint = checkpoint + self.use_reentrant = use_reentrant + + def forward(self, x): + return ( + checkpoint(self.seq, x, use_reentrant=self.use_reentrant) + if self.checkpoint + else self.seq(x) + ) + + +class ModelWithCheckpointSubmodule(nn.Module): + def __init__(self, checkpoint: bool = False, use_reentrant: bool = True): + super().__init__() + self.l1 = nn.Linear(100, 100) + self.s1 = CheckpointModule(checkpoint, use_reentrant) + self.s2 = CheckpointModule(checkpoint, use_reentrant) + self.relu = nn.ReLU() + self.l2 = nn.Linear(100, 100) + + def forward(self, x): + return self.l2(self.relu(self.s2(self.s1(self.l1(x))))) + + +class TestModel(nn.Module): + def __init__(self, checkpoint: bool = False, use_reentrant: bool = True): + super().__init__() + self.l1 = nn.Linear(100, 100) + self.relu = nn.ReLU() + self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant) + self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant) + self.l2 = nn.Linear(100, 100) + + def forward(self, x): + return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x))))) + + +class TestFSDPCheckpointSubmodule(FSDPTest): + + # TODO: grad value checks occasionally fails when use_reentrant = True + @skip_if_lt_x_gpu(2) + @parametrize("use_reentrant", [False]) + def test_checkpoint_submodule(self, use_reentrant: bool): + model = TestModel(use_reentrant=use_reentrant).cuda() + model_ac = deepcopy(model) + + for _, m in model_ac.named_modules(): + if isinstance(m, CheckpointModule): + m.checkpoint = True + + self.assertTrue(model_ac.checkpoint1.s1.checkpoint) + self.assertTrue(model_ac.checkpoint2.s2.checkpoint) + + fsdp_kwargs = { + "device_id": torch.cuda.current_device(), + "sharding_strategy": ShardingStrategy.NO_SHARD, + } + + # Wrap no checkpointing model submodules with FSDP + model.m1 = FSDP(module=model.checkpoint1, **fsdp_kwargs) + model.m2 = FSDP(module=model.checkpoint2, **fsdp_kwargs) + + # Wrap checkpointing model submodules with FSDP + model_ac.m1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs) + model_ac.m2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs) + + x = torch.randn(2, 100, device="cuda") + + model(x).sum().backward() + model_ac(x).sum().backward() + + for p1, p2 in zip(model.parameters(), model_ac.parameters()): + self.assertTrue(p1.grad.allclose(p2.grad)) + + +instantiate_parametrized_tests(TestFSDPCheckpointSubmodule) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py index 9e39254ec423a..81b9f4c37f06e 100644 --- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py +++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py @@ -1,31 +1,35 @@ # Owner(s): ["oncall: distributed"] +import itertools import sys -from math import inf +from typing import Union import torch +import torch.nn as nn from torch import distributed as dist +from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel as FSDP, CPUOffload, - _calc_grad_norm, + FullyShardedDataParallel as FSDP, + MixedPrecision, ) -from torch.nn import utils as nn_utils +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( - DeterministicModel, + CUDAInitMode, + FSDPInitMode, FSDPTest, - _collect_total_grad_norm_fsdp, - _collect_total_grad_norm_local, + NestedWrappedModule, + TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, - parametrize, instantiate_parametrized_tests, + run_tests, + TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -39,67 +43,265 @@ class TestClipGradNorm(FSDPTest): - def _run_fsdp_one_iteration(self, norm_type, nested_fsdp, cpu_offload): - """Test FSDP with clip grad norm.""" - fsdp_model = DeterministicModel(nested_fsdp, cpu_offload=cpu_offload) - local_model = DeterministicModel(False) - input = torch.rand(14, 2, device=self.rank) - fsdp_model = FSDP(fsdp_model, cpu_offload=cpu_offload) - self.assertTrue(len(input) >= self.world_size) - out = local_model(input[: self.world_size]) - out.sum().backward() - in_data = torch.tensor(input[self.rank], device=self.rank) - out_fsdp = fsdp_model(in_data) - out_fsdp.sum().backward() - total_norms_fsdp = _collect_total_grad_norm_fsdp( - fsdp_model, norm_type, self.rank + """Tests :meth:`FullyShardedDataParallel.clip_grad_norm_`.""" + + @skip_if_lt_x_gpu(2) + def test_non_root(self): + """ + Tests that calling ``clip_grad_norm_()`` on a non-root FSDP instance + raises an error. + """ + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin1 = nn.Linear(5, 5) + self.lin2 = nn.Linear(5, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.lin1(x)) + + model = Model().cuda() + model.lin2 = FSDP(model.lin2) + fsdp_model = FSDP(model) + fsdp_model(torch.randn((2, 5), device=torch.device("cuda"))).sum().backward() + error_regex = "should only be called on the root FSDP instance" + with self.assertRaisesRegex(RuntimeError, error_regex): + fsdp_model.lin2.clip_grad_norm_(max_norm=2) + + @skip_if_lt_x_gpu(2) + def test_ddp_parity(self): + """ + Tests FSDP with ``FullyShardedDataParallel.clip_grad_norm_()` against + DDP with ``torch.nn.utils.clip_grad_norm_()` when using full precision. + """ + self.run_subtests( + { + "max_norm": [1, 2.5], + "norm_type": [1, 2, float("inf")], + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + ShardingStrategy.NO_SHARD, + "mixed_strategy", + ], + "use_orig_params": [False, True], + "offload_params": [False, True], + }, + self._test_ddp_parity, ) - total_norms_local = _collect_total_grad_norm_local(local_model, norm_type) - total_norms_local /= self.world_size - norm_cap = total_norms_fsdp / 2.0 - self.assertEqual(total_norms_local, total_norms_fsdp) - fsdp_model.clip_grad_norm_(norm_cap, norm_type=norm_type) - nn_utils.clip_grad_norm_( - local_model.parameters(), norm_cap, norm_type=norm_type + + def _test_ddp_parity( + self, + max_norm: Union[float, int], + norm_type: Union[float, int], + sharding_strategy: Union[ShardingStrategy, str], + use_orig_params: bool, + offload_params: bool, + ): + local_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, ) - total_norms_after_clip_fsdp = _collect_total_grad_norm_fsdp( - fsdp_model, norm_type, self.rank + ddp_model = DDP(local_model, device_ids=[self.rank]) + fsdp_kwargs = { + "cpu_offload": CPUOffload(offload_params=offload_params), + "use_orig_params": use_orig_params, + } + if sharding_strategy == "mixed_strategy": + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) + # Apply `NO_SHARD` to the encoder + fsdp_model.transformer.encoder = FSDP( + fsdp_model.transformer.encoder, + sharding_strategy=ShardingStrategy.NO_SHARD, + **fsdp_kwargs, + ) + # Apply `FULL_SHARD` to the decoder + fsdp_model.transformer.decoder = FSDP( + fsdp_model.transformer.decoder, + sharding_strategy=ShardingStrategy.FULL_SHARD, + **fsdp_kwargs, + ) + # TODO: FSDP's `clip_grad_norm_()` is not a static method, so we + # must make the root module an FSDP instance + fsdp_model = FSDP( + fsdp_model, sharding_strategy=ShardingStrategy.FULL_SHARD, **fsdp_kwargs + ) + else: + fsdp_kwargs.update( + { + "sharding_strategy": sharding_strategy, + "auto_wrap_policy": ModuleWrapPolicy( + { + TransformerEncoderLayer, + TransformerDecoderLayer, + } + ), + } + ) + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + fsdp_kwargs=fsdp_kwargs, + ) + LR = 1e-2 + ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR) + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR) + device = torch.device("cuda") + LARGE_FACTOR = 100 + inp = ddp_model.module.get_input(device) + for model in (ddp_model, fsdp_model): + out = model(*inp) + if isinstance(model, (DDP, FSDP)): + loss = model.module.get_loss(inp, out) + else: + loss = model.get_loss(inp, out) + loss.backward() + + # Multiply gradients by a large factor to ensure that gradients will + # actually be clipped + for param in itertools.chain(ddp_model.parameters(), fsdp_model.parameters()): + if ( + param.grad is not None + ): # gradients may be `None` for `use_orig_params=True` + param.grad *= LARGE_FACTOR + orig_ddp_grads = [ + param.grad.detach().clone() for param in ddp_model.parameters() + ] + orig_fsdp_grads = [ + param.grad.detach().clone() if param.grad is not None else None + for param in fsdp_model.parameters() + ] + + ddp_total_norm = torch.nn.utils.clip_grad_norm_( + ddp_model.parameters(), + max_norm=max_norm, + norm_type=norm_type, ) - total_norms_after_clip_local = _collect_total_grad_norm_local( - local_model, norm_type + fsdp_total_norm = fsdp_model.clip_grad_norm_( + max_norm=max_norm, norm_type=norm_type ) - self.assertTrue(total_norms_after_clip_fsdp <= norm_cap) - self.assertEqual(total_norms_after_clip_local, total_norms_after_clip_fsdp) + self.assertEqual(ddp_total_norm, fsdp_total_norm) - @skip_if_lt_x_gpu(2) - @parametrize("norm_type", [2.0, inf]) - @parametrize("nested_fsdp", [True, False]) - @parametrize( - "cpu_offload", - [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], - ) - def test_fsdp_clip_grad_norm(self, norm_type, nested_fsdp, cpu_offload): - """Test FSDP with clip grad norm.""" - self._run_fsdp_one_iteration(norm_type, nested_fsdp, cpu_offload) + # Check that the gradients were modified by `clip_grad_norm_()` + for param, orig_grad in zip(ddp_model.parameters(), orig_ddp_grads): + assert not torch.equal(param.grad, orig_grad) + for param, orig_grad in zip(fsdp_model.parameters(), orig_fsdp_grads): + if param.grad is None: + self.assertEqual(param.grad, orig_grad) # `None` + else: + assert not torch.equal(param.grad, orig_grad) + + # Run an optimizer step to ensure gradients matched after clipping + ddp_optim.step() + fsdp_optim.step() + with FSDP.summon_full_params(fsdp_model): + for (n1, p1), (n2, p2) in zip( + ddp_model.module.named_parameters(), + fsdp_model.named_parameters(), + ): + self.assertEqual(n1, n2) + self.assertEqual(p1, p2) + + if offload_params: + # TODO: Gradient computation on CPU and GPU differ slightly causing + # drift unrelated to `clip_grad_norm_()`. + # https://github.com/pytorch/pytorch/issues/89133 + return + # Run a few more iterations + # TODO: We cannot run too many iterations, or else there is drift: + # https://github.com/pytorch/pytorch/issues/89136 + for i in range(3): + set_to_none = i % 2 == 0 # exercise both + ddp_optim.zero_grad(set_to_none=set_to_none) + fsdp_optim.zero_grad(set_to_none=set_to_none) + inp = ddp_model.module.get_input(device) + for model in (ddp_model, fsdp_model): + out = model(*inp) + out.sum().backward() + ddp_total_norm = torch.nn.utils.clip_grad_norm_( + ddp_model.parameters(), + max_norm=max_norm, + norm_type=norm_type, + ) + fsdp_total_norm = fsdp_model.clip_grad_norm_( + max_norm=max_norm, norm_type=norm_type + ) + self.assertEqual(ddp_total_norm, fsdp_total_norm) + ddp_optim.step() + fsdp_optim.step() -class TestCalcuGradNorm(FSDPTest): @skip_if_lt_x_gpu(2) - @parametrize("norm_type", [2.0, inf, 1.3, 2.5]) - @parametrize("nested_fsdp", [True, False]) - def test_fsdp_calc_grad_norm(self, norm_type, nested_fsdp): - """Test grad norm cal API.""" - model = FSDP(DeterministicModel(nested_fsdp)) - input = torch.rand(15, 2, device=self.rank) - out = model(input) + def test_low_precision_grads(self): + """Tests ``clip_grad_norm_()`` when using low precision gradients.""" + self.run_subtests( + { + "max_norm": [1, 2.5], + "norm_type": [1, 2, float("inf")], + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + ShardingStrategy.NO_SHARD, + ], + "use_orig_params": [False, True], + }, + self._test_low_precision_grads, + ) + + def _test_low_precision_grads( + self, + max_norm: Union[float, int], + norm_type: Union[float, int], + sharding_strategy: ShardingStrategy, + use_orig_params: bool, + ): + fsdp_kwargs = { + "sharding_strategy": sharding_strategy, + "use_orig_params": use_orig_params, + "mixed_precision": MixedPrecision( + param_dtype=torch.float16, + reduce_dtype=torch.float16, + keep_low_precision_grads=True, + ), + } + fsdp_model = FSDP( + NestedWrappedModule.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + fsdp_kwargs=fsdp_kwargs, + ), + **fsdp_kwargs, + ) + inp = fsdp_model.module.get_input(torch.device("cuda")) + out = fsdp_model(*inp) out.sum().backward() - total_norm = _calc_grad_norm(model.params_with_grad, norm_type) - total_norm_expected = _collect_total_grad_norm_local(model, norm_type) - self.assertEqual(total_norm, total_norm_expected) + for param in fsdp_model.parameters(): + if param.grad is not None: + self.assertEqual(param.grad.dtype, torch.float16) + total_norm = fsdp_model.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) + # Check that the total norm is in FP16 to match the gradient dtype + self.assertEqual(total_norm.dtype, torch.float16) + # As a best effort, check that each gradient has norm at most the max + # norm (since DDP does not support mixed precision natively, we cannot + # directly compare for parity) + for param in fsdp_model.parameters(): + if param.grad is not None: + self.assertTrue( + torch.linalg.vector_norm(param.grad, norm_type).item() <= max_norm, + ) instantiate_parametrized_tests(TestClipGradNorm) -instantiate_parametrized_tests(TestCalcuGradNorm) if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py index c9946a9dd5665..117e756da252e 100644 --- a/test/distributed/fsdp/test_fsdp_comm.py +++ b/test/distributed/fsdp/test_fsdp_comm.py @@ -2,7 +2,7 @@ import sys from contextlib import suppress -from enum import Enum, auto +from enum import auto, Enum from typing import Optional from unittest.mock import patch @@ -19,10 +19,10 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -45,6 +45,7 @@ class PassType(Enum): class TestCommunication(FSDPTest): """Tests ``FullyShardedDataParallel``'s collective communication usage.""" + def _init_model( self, nested_model: bool, @@ -106,7 +107,8 @@ def _get_ref_num_all_gathers( pass_type, is_first_iter, is_last_iter_no_sync, - ) for pass_type in PassType + ) + for pass_type in PassType ) def _get_ref_num_all_gathers_in_pass( @@ -121,9 +123,11 @@ def _get_ref_num_all_gathers_in_pass( if sharding_strategy is None: sharding_strategy = ShardingStrategy.FULL_SHARD # default # Forward pass: - if pass_type == PassType.FWD and \ - sharding_strategy == ShardingStrategy.SHARD_GRAD_OP and \ - is_last_iter_no_sync: + if ( + pass_type == PassType.FWD + and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP + and is_last_iter_no_sync + ): # Modules do not free the full parameters in the last # iteration's backward pass if it was in `no_sync()` num_all_gathers = 0 @@ -132,21 +136,27 @@ def _get_ref_num_all_gathers_in_pass( # forward pass num_all_gathers = num_fsdp # Backward pass: - elif pass_type == PassType.BWD and \ - sharding_strategy == ShardingStrategy.FULL_SHARD: + elif ( + pass_type == PassType.BWD + and sharding_strategy == ShardingStrategy.FULL_SHARD + ): # Root does not free the full parameters at the end of the # forward pass num_all_gathers = num_fsdp - 1 - elif pass_type == PassType.BWD and \ - sharding_strategy == ShardingStrategy.SHARD_GRAD_OP: + elif ( + pass_type == PassType.BWD + and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP + ): # Modules do not free the full parameters at the end of the # forward pass num_all_gathers = 0 else: - assert 0, f"Unsupported: add a branch for pass_type={pass_type} " \ - f"is_first_iter={is_first_iter} " \ - f"is_last_iter_no_sync={is_last_iter_no_sync} " \ + assert 0, ( + f"Unsupported: add a branch for pass_type={pass_type} " + f"is_first_iter={is_first_iter} " + f"is_last_iter_no_sync={is_last_iter_no_sync} " f"sharding_strategy={sharding_strategy}" + ) if is_first_iter and pass_type == PassType.FWD: # With execution order validation, on the first iteration, we have # an additional two all-gathers before every actual all-gather in @@ -167,7 +177,10 @@ def _print_ref_num_all_gathers_in_pass( if self.rank != 0: return # only print on one rank num_all_gathers = self._get_ref_num_all_gathers_in_pass( - num_fsdp, sharding_strategy, pass_type, is_first_iter, + num_fsdp, + sharding_strategy, + pass_type, + is_first_iter, is_last_iter_no_sync, ) print( @@ -211,8 +224,7 @@ def test_communication( # Count the number of FSDP instances that manage parameters since the # number of collectives are a function of this number num_fsdp = sum( - (isinstance(m, FSDP) and len(m.params) > 0) - for m in fsdp_model.modules() + (isinstance(m, FSDP) and len(m.params) > 0) for m in fsdp_model.modules() ) # If `use_no_sync=True`, we run `num_iters` iterations inside @@ -220,11 +232,16 @@ def test_communication( # and if `use_no_sync=False`, we only run `num_iters` iterations # outside `no_sync()` num_iters = 3 - with patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, \ - patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter: + with patch( + "torch.distributed.all_gather_into_tensor" + ) as mock_all_gather, patch( + "torch.distributed.reduce_scatter_tensor" + ) as mock_reduce_scatter: + def reset_mocks(): mock_all_gather.reset_mock() mock_reduce_scatter.reset_mock() + # Check the communication cost when using `no_sync()` if use_no_sync: for i in range(num_iters): @@ -233,11 +250,14 @@ def reset_mocks(): num_all_gathers = mock_all_gather.call_count num_reduce_scatters = mock_reduce_scatter.call_count ref_num_all_gathers = self._get_ref_num_all_gathers( - num_fsdp, sharding_strategy, is_first_iter=i == 0, + num_fsdp, + sharding_strategy, + is_first_iter=i == 0, is_last_iter_no_sync=i > 0, ) ref_num_reduce_scatters = self._get_ref_num_reduce_scatters( - num_fsdp, in_no_sync=True, + num_fsdp, + in_no_sync=True, ) self.assertEqual(num_all_gathers, ref_num_all_gathers) self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters) @@ -248,12 +268,14 @@ def reset_mocks(): num_all_gathers = mock_all_gather.call_count num_reduce_scatters = mock_reduce_scatter.call_count ref_num_all_gathers = self._get_ref_num_all_gathers( - num_fsdp, sharding_strategy, + num_fsdp, + sharding_strategy, is_first_iter=not use_no_sync and i == 0, is_last_iter_no_sync=use_no_sync and i == 0, ) ref_num_reduce_scatters = self._get_ref_num_reduce_scatters( - num_fsdp, in_no_sync=False, + num_fsdp, + in_no_sync=False, ) self.assertEqual(num_all_gathers, ref_num_all_gathers) self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters) diff --git a/test/distributed/fsdp/test_fsdp_comm_hooks.py b/test/distributed/fsdp/test_fsdp_comm_hooks.py index bfd710cdac486..125606fbff5cb 100644 --- a/test/distributed/fsdp/test_fsdp_comm_hooks.py +++ b/test/distributed/fsdp/test_fsdp_comm_hooks.py @@ -7,10 +7,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import distributed as dist -from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.algorithms._comm_hooks import default_hooks -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy from torch.testing._internal.common_distributed import ( requires_nccl, @@ -26,7 +25,6 @@ run_tests, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -35,10 +33,11 @@ BFLOAT16_AVAILABLE = ( torch.cuda.is_available() and torch.version.cuda is not None - and int(torch.version.cuda.split('.')[0]) >= 11) + and int(torch.version.cuda.split(".")[0]) >= 11 +) -class Net(nn.Module): +class Net(nn.Module): def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): # to ensure determinism torch.manual_seed(0) @@ -46,45 +45,40 @@ def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None): super().__init__() if has_wrapping: - self.net = FSDP(nn.Sequential( - nn.Linear(8, 16), - nn.ReLU(), - FSDP( - nn.Linear(16, 8), - device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - ) - ), + self.net = FSDP( + nn.Sequential( + nn.Linear(8, 16), + nn.ReLU(), + FSDP( + nn.Linear(16, 8), + device_id=torch.cuda.current_device(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + ), + ), device_id=torch.cuda.current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, ) else: - self.net = nn.Sequential( - nn.Linear(8, 16), - nn.ReLU(), - nn.Linear(16, 8) - ) + self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)) self.out = nn.Linear(8, 4) def forward(self, x): return self.out(F.relu(self.net(x))) + class DummyState(object): - __slots__ = [ - "process_group", - "noise" - ] + __slots__ = ["process_group", "noise"] def __init__(self, process_group: dist.ProcessGroup, noise: int): self.process_group = process_group self.noise = noise -class DummyHook(object): +class DummyHook(object): def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor): """ This communication hook is for illustration and testing purpose only. @@ -104,7 +98,9 @@ def custom_reduce_scatter(self, output, input, group=None): """ pass - def dummy_hook_for_sharded_fsdp(self, state: DummyState, grad: torch.Tensor, output: torch.Tensor): + def dummy_hook_for_sharded_fsdp( + self, state: DummyState, grad: torch.Tensor, output: torch.Tensor + ): """ This communication hook is for illustration and testing purposes only. This communication hook is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training. @@ -112,23 +108,21 @@ def dummy_hook_for_sharded_fsdp(self, state: DummyState, grad: torch.Tensor, out ``reduce_scatter`` for gradient communication and stores a sharded gradient in ``output``. """ grad.add_(state.noise) - self.custom_reduce_scatter( - output, grad, group=state.process_group - ) + self.custom_reduce_scatter(output, grad, group=state.process_group) -class TestCommunicationHooks(FSDPTest): +class TestCommunicationHooks(FSDPTest): @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_default_communication_hook_behavior( - self, - sharding_strategy: Optional[ShardingStrategy] + self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's default communication hook's behavior and correctness. @@ -148,14 +142,16 @@ def test_default_communication_hook_behavior( net_default_hook = FSDP( net, device_id=torch.cuda.current_device(), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ).to(self.rank) # Check that default hook is set to `all_reduce` for `NO_SHARD` # or `reduce_scatter` for sharded cases - default_hook = default_hooks.reduce_scatter_hook\ - if sharding_strategy != ShardingStrategy.NO_SHARD\ + default_hook = ( + default_hooks.reduce_scatter_hook + if sharding_strategy != ShardingStrategy.NO_SHARD else default_hooks.allreduce_hook + ) for entry in FSDP.fsdp_modules(net_default_hook): self.assertEqual(entry._communication_hook, default_hook) @@ -176,11 +172,13 @@ def test_default_communication_hook_behavior( self.assertEqual( grad[0].item(), expected_grad, - msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}") + msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}", + ) def _get_submodules(self, fsdp_net): return [ - submodule for submodule in FSDP.fsdp_modules(fsdp_net) + submodule + for submodule in FSDP.fsdp_modules(fsdp_net) if not submodule.check_is_root() ] @@ -201,12 +199,11 @@ def _init_model(self, core, sharding_strategy, mixed_precision=None): [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_default_communication_hook_initialization( - self, - has_wrapping: bool, - sharding_strategy: Optional[ShardingStrategy] + self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook interface behavior. @@ -219,45 +216,39 @@ def test_default_communication_hook_initialization( # Initialize a model fsdp_model_with_hook = self._init_model( Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ) # Check that default hook is set to `all_reduce` for `NO_SHARD` # or `reduce_scatter` for sharded cases - default_hook = default_hooks.reduce_scatter_hook\ - if sharding_strategy != ShardingStrategy.NO_SHARD\ + default_hook = ( + default_hooks.reduce_scatter_hook + if sharding_strategy != ShardingStrategy.NO_SHARD else default_hooks.allreduce_hook + ) for entry in FSDP.fsdp_modules(fsdp_model_with_hook): self.assertEqual(entry._communication_hook, default_hook) dummy_state = DummyState(process_group=None, noise=1234) - dummy_hook = DummyHook.dummy_hook_for_no_shard_fsdp\ - if sharding_strategy != ShardingStrategy.NO_SHARD\ + dummy_hook = ( + DummyHook.dummy_hook_for_no_shard_fsdp + if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp - - fsdp_model_with_hook.register_comm_hook( - dummy_state, - dummy_hook ) + fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) + # Check that we can't register comm hook twice - with self.assertRaisesRegex(AssertionError, '^communication hook can be only registered once$'): - fsdp_model_with_hook.register_comm_hook( - dummy_state, - dummy_hook - ) + with self.assertRaisesRegex( + AssertionError, "^communication hook can be only registered once$" + ): + fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) # Check dummy hook was registered for the root and all submodules if any for entry in FSDP.fsdp_modules(fsdp_model_with_hook): - self.assertEqual( - entry._communication_hook, - dummy_hook - ) - self.assertEqual( - entry._communication_hook_state, - dummy_state - ) + self.assertEqual(entry._communication_hook, dummy_hook) + self.assertEqual(entry._communication_hook_state, dummy_state) for entry in FSDP.fsdp_modules(fsdp_model_with_hook): entry._communication_hook = None @@ -277,18 +268,17 @@ def test_default_communication_hook_initialization( with self.assertRaises(AssertionError): loss.backward() - @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy", [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_registering_hook_non_root( - self, - sharding_strategy: Optional[ShardingStrategy] + self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook registering for submodules. @@ -301,16 +291,21 @@ def test_registering_hook_non_root( fsdp_model_with_hook = self._init_model( Net(has_wrapping=True, sharding_strategy=sharding_strategy), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ) dummy_state = DummyState(process_group=None, noise=1234) - dummy_hook = DummyHook.dummy_hook_for_no_shard_fsdp\ - if sharding_strategy != ShardingStrategy.NO_SHARD\ + dummy_hook = ( + DummyHook.dummy_hook_for_no_shard_fsdp + if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp + ) # Creating a list of non-root submodules to test submodules = self._get_submodules(fsdp_model_with_hook) # Check that assertion is raised for registering a comm hook on a non-root - with self.assertRaisesRegex(AssertionError, '^register_comm_hook can only be called on a root instance.$'): + with self.assertRaisesRegex( + AssertionError, + "^register_comm_hook can only be called on a root instance.$", + ): submodules[1].register_comm_hook(dummy_state, dummy_hook) @skip_if_lt_x_gpu(2) @@ -319,11 +314,11 @@ def test_registering_hook_non_root( [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_registering_hook_submodules( - self, - sharding_strategy: Optional[ShardingStrategy] + self, sharding_strategy: Optional[ShardingStrategy] ): """ Tests FSDP's communication hook registering for submodules. @@ -336,24 +331,28 @@ def test_registering_hook_submodules( fsdp_model_with_hook = self._init_model( Net(has_wrapping=True, sharding_strategy=sharding_strategy), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ) dummy_state = DummyState(process_group=None, noise=1234) - dummy_hook = DummyHook.dummy_hook_for_no_shard_fsdp\ - if sharding_strategy != ShardingStrategy.NO_SHARD\ + dummy_hook = ( + DummyHook.dummy_hook_for_no_shard_fsdp + if sharding_strategy != ShardingStrategy.NO_SHARD else DummyHook.dummy_hook_for_sharded_fsdp + ) submodules = self._get_submodules(fsdp_model_with_hook) # Simulate a registration of a hook on a submodule submodules[1]._hook_registered = True # Check that an error is raised when some of submodules have a non-default hook assigned - with self.assertRaisesRegex(AssertionError, '^communication hook can be only registered once$'): + with self.assertRaisesRegex( + AssertionError, "^communication hook can be only registered once$" + ): fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) # Reinitialize the model fsdp_model_with_hook = self._init_model( Net(has_wrapping=True, sharding_strategy=sharding_strategy), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ) submodules = self._get_submodules(fsdp_model_with_hook) submodules[1]._communication_hook = dummy_hook @@ -361,29 +360,32 @@ def test_registering_hook_submodules( # Check that an error is raised when some of submodules have a non-default hook assigned with self.assertRaisesRegex( AssertionError, - f'^communication hook should be default, but it is {submodules[1]._communication_hook.__name__} instead$' + f"^communication hook should be default, but it is {submodules[1]._communication_hook.__name__} instead$", ): - fsdp_model_with_hook.register_comm_hook( - dummy_state, - dummy_hook - ) + fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook) - def _check_low_precision_hook(self, state, hook, sharding_strategy, dtype, has_wrapping): + def _check_low_precision_hook( + self, state, hook, sharding_strategy, dtype, has_wrapping + ): # keep everything deterministic for input data torch.manual_seed(0) torch.cuda.manual_seed(0) fsdp_with_hook = self._init_model( Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy), - sharding_strategy=sharding_strategy + sharding_strategy=sharding_strategy, ) fsdp_with_hook.register_comm_hook(state, hook) mp_only_grad = MixedPrecision(reduce_dtype=dtype) fsdp_with_mp = self._init_model( - Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy, mixed_precision=mp_only_grad), + Net( + has_wrapping=has_wrapping, + sharding_strategy=sharding_strategy, + mixed_precision=mp_only_grad, + ), sharding_strategy=sharding_strategy, - mixed_precision=mp_only_grad + mixed_precision=mp_only_grad, ) optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1) @@ -403,7 +405,9 @@ def _check_low_precision_hook(self, state, hook, sharding_strategy, dtype, has_w dist.barrier() - for hook_param, mp_param in zip(fsdp_with_hook.parameters(), fsdp_with_mp.parameters()): + for hook_param, mp_param in zip( + fsdp_with_hook.parameters(), fsdp_with_mp.parameters() + ): self.assertEqual(hook_param.grad, mp_param.grad) @requires_nccl() @@ -414,18 +418,19 @@ def _check_low_precision_hook(self, state, hook, sharding_strategy, dtype, has_w [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_fp16_hook( - self, - has_wrapping: bool, - sharding_strategy: Optional[ShardingStrategy] + self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): state = default_hooks.LowPrecisionState(process_group=_get_default_group()) hook = default_hooks.fp16_compress_hook - self._check_low_precision_hook(state, hook, sharding_strategy, torch.float16, has_wrapping) + self._check_low_precision_hook( + state, hook, sharding_strategy, torch.float16, has_wrapping + ) @requires_nccl() @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS") @@ -441,18 +446,19 @@ def test_fp16_hook( [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, - ShardingStrategy.SHARD_GRAD_OP - ]) + ShardingStrategy.SHARD_GRAD_OP, + ], + ) def test_bf16_hook( - self, - has_wrapping: bool, - sharding_strategy: Optional[ShardingStrategy] + self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy] ): state = default_hooks.LowPrecisionState(process_group=_get_default_group()) hook = default_hooks.bf16_compress_hook - self._check_low_precision_hook(state, hook, sharding_strategy, torch.bfloat16, has_wrapping) + self._check_low_precision_hook( + state, hook, sharding_strategy, torch.bfloat16, has_wrapping + ) instantiate_parametrized_tests(TestCommunicationHooks) diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 9557f2abcfbcb..c77378384ef60 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -24,14 +24,14 @@ MixtureOfExperts, NestedWrappedModule, NestedWrappedModuleWithDelay, - TransformerWithSharedParams, subtest_name, + TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -47,7 +47,11 @@ params = "cpu_offload,sharding_strategy" cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] -sharding_strategy_config = [None, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD] +sharding_strategy_config = [ + None, + ShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy.NO_SHARD, +] configs = list(itertools.product(cpu_offload_config, sharding_strategy_config)) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", @@ -134,14 +138,10 @@ def test_nested_wrapped_model_single_iteration_mixed_precision( @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - # TODO (awgu): 2.0 fails tests - # @parametrize("norm_type", [2.0, None]) - @parametrize("norm_type", [None]) def test_nested_always_wrap_model( self, cpu_offload: CPUOffload, sharding_strategy: Optional[ShardingStrategy], - norm_type: Optional[float], ): self.run_subtests( self._get_subtest_config(cpu_offload), @@ -150,19 +150,14 @@ def test_nested_always_wrap_model( FSDPInitMode.RECURSIVE, cpu_offload=cpu_offload, sharding_strategy=sharding_strategy, - norm_type=norm_type, ) @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - # TODO (awgu): 2.0 fails tests - # @parametrize("norm_type", [2.0, None]) - @parametrize("norm_type", [None]) def test_transformer( self, cpu_offload: CPUOffload, sharding_strategy: Optional[ShardingStrategy], - norm_type: Optional[float], ): self.run_subtests( self._get_subtest_config(cpu_offload), @@ -170,7 +165,6 @@ def test_transformer( TransformerWithSharedParams, FSDPInitMode.RECURSIVE, cpu_offload=cpu_offload, - norm_type=norm_type, sharding_strategy=sharding_strategy, ) @@ -224,14 +218,10 @@ def _dummy_ddp_fn(self, model): @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) - # TODO (awgu): 2.0 fails tests - # @parametrize("norm_type", [2.0, None]) - @parametrize("norm_type", [None]) def test_mixture_of_experts( self, cpu_offload: CPUOffload, sharding_strategy: Optional[ShardingStrategy], - norm_type: Optional[float], ): self.run_subtests( self._get_subtest_config(cpu_offload), @@ -241,7 +231,6 @@ def test_mixture_of_experts( ref_init_fn=self._dummy_ddp_fn, cpu_offload=cpu_offload, sharding_strategy=sharding_strategy, - norm_type=norm_type, ) @skip_if_lt_x_gpu(2) @@ -259,7 +248,7 @@ def test_mixture_of_experts_with_delay_before_free( ref_init_fn=self._dummy_ddp_fn, cpu_offload=cpu_offload, sharding_strategy=sharding_strategy, - init_kwargs={"delay_before_free_ms": 250} + init_kwargs={"delay_before_free_ms": 250}, ) @@ -361,13 +350,30 @@ def test_register_functions_called(self, cuda_first: bool, mixed_precision: bool fsdp_kwargs, ) input = fsdp_model.module.get_input(torch.device("cuda")) - fsdp_model._register_pre_backward_hooks = mock.MagicMock(return_value=None) - fsdp_model._register_post_backward_hooks = mock.MagicMock(return_value=None) - self.assertFalse(fsdp_model._register_post_backward_hooks.called) - self.assertFalse(fsdp_model._register_pre_backward_hooks.called) - fsdp_model(*input) - self.assertTrue(fsdp_model._register_post_backward_hooks.called) - self.assertTrue(fsdp_model._register_pre_backward_hooks.called) + + # Since `_register_pre_backward_hooks()` modifies the forward output, + # we cannot directly mock it. We implement our own counter instead. + orig_register_pre_backward_hooks = ( + torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks + ) + register_pre_backward_hooks_call_count = 0 + + def _register_pre_backward_hooks_with_count(*args, **kwargs): + nonlocal register_pre_backward_hooks_call_count + register_pre_backward_hooks_call_count += 1 + return orig_register_pre_backward_hooks(*args, **kwargs) + + with mock.patch( + "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", + _register_pre_backward_hooks_with_count, + ), mock.patch( + "torch.distributed.fsdp._runtime_utils._register_post_backward_hooks" + ) as register_post_bwd_mock: + self.assertEqual(register_pre_backward_hooks_call_count, 0) + self.assertFalse(register_post_bwd_mock.called) + fsdp_model(*input) + self.assertTrue(register_pre_backward_hooks_call_count > 0) + self.assertTrue(register_post_bwd_mock.called) class TestNoGrad(FSDPTest): @@ -397,7 +403,7 @@ def test_transformer_no_grad(self, mixed_precision): fsdp_model, num_steps=1, autocast=False, - mixed_precision=fsdp_kwargs["mixed_precision"] + mixed_precision=fsdp_kwargs["mixed_precision"], ) input = fsdp_model.module.get_input(torch.device("cuda")) # Run a forward in eval mode diff --git a/test/distributed/fsdp/test_fsdp_exec_order.py b/test/distributed/fsdp/test_fsdp_exec_order.py index eaf3066d1bad0..6cd00e5302181 100644 --- a/test/distributed/fsdp/test_fsdp_exec_order.py +++ b/test/distributed/fsdp/test_fsdp_exec_order.py @@ -11,10 +11,10 @@ from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): diff --git a/test/distributed/fsdp/test_flatten_params_wrapper.py b/test/distributed/fsdp/test_fsdp_flatten_params.py similarity index 51% rename from test/distributed/fsdp/test_flatten_params_wrapper.py rename to test/distributed/fsdp/test_fsdp_flatten_params.py index 016398c88deba..5b60ed9820617 100644 --- a/test/distributed/fsdp/test_flatten_params_wrapper.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -1,44 +1,45 @@ # Owner(s): ["oncall: distributed"] import sys -import unittest import torch +import torch.nn as nn from torch import distributed as dist +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.flat_param import ( + FlatParamHandle, FlatParamShardMetadata, HandleConfig, HandleShardingStrategy, ) -from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) -class TestFlattenParams(TestCase): - """Base test class and used for CPU case.""" - - def _get_default_config(self): - return HandleConfig(HandleShardingStrategy.FULL_SHARD, False, None, None) - - def _get_empty_module(self, seed=0): - torch.manual_seed(seed) # keep everything deterministic - class Test(torch.nn.Module): - def forward(self, x): - return x + 1 +class TestFlattenParams(FSDPTest): + """Tests parameter flattening and shard metadata logic.""" - module = Test() + @property + def world_size(self) -> int: + # Clamp the world size to 1 since these unit tests either exercise only + # the flattening logic or check sharding subroutines directly without + # requiring multiple ranks + return 1 - def get_input(device, dtype): - torch.manual_seed(1) # keep everything deterministic - return torch.rand(1).to(device=device, dtype=dtype) - - module.get_input = get_input - return module + def _get_default_config(self): + return HandleConfig(HandleShardingStrategy.FULL_SHARD, False, None, None, False) def _get_transformer(self, seed=0): torch.manual_seed(seed) # keep everything deterministic @@ -68,152 +69,247 @@ def _get_shared_params_transformer(self, seed=0): dec_layer.linear2.weight = enc_layer.linear2.weight return module - def _get_output(self, module): - device = next(module.parameters()).device - dtype = next(module.parameters()).dtype - input = module.get_input(device, dtype) - return module(*input) - - def _get_pnorm_after_step(self, module): - optim = torch.optim.SGD(module.parameters(), lr=0.01) - loss = self._get_output(module).sum() - loss.backward() - optim.step() - return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()])) - - def _test_num_params(self, module): - ref_num_params = sum(p.numel() for p in module.parameters()) - - params_to_flatten = list(module.parameters()) - flat_module = FlattenParamsWrapper( - module, - params_to_flatten, - torch.device("cuda"), - self._get_default_config(), - False, - ) - flat_num_params = sum(p.numel() for p in flat_module.parameters()) - - self.assertEqual(ref_num_params, flat_num_params) - self.assertEqual(flat_num_params, flat_module.flat_param.numel()) - - def _test_output(self, module): - ref_output = self._get_output(module) - - params_to_flatten = list(module.parameters()) - flat_module = FlattenParamsWrapper( - module, - params_to_flatten, - torch.device("cuda"), - self._get_default_config(), - False, + @skip_if_lt_x_gpu(1) + def test_partial_flattening(self): + """Tests flattening some submodules but not others.""" + self.run_subtests( + {"half": [False, True]}, + self._test_partial_flattening, ) - flat_output = self._get_output(flat_module) - self.assertEqual(ref_output, flat_output) - def test_partial_flattening(self): + def _test_partial_flattening(self, half: bool): module = self._get_transformer() - num_params = sum(p.numel() for p in module.parameters()) - - params_to_flatten = list(module.encoder.layers[1].parameters()) + list( - module.decoder.layers[0].parameters() + if half: + module = module.half() + numel = sum(p.numel() for p in module.parameters()) + + encoder_1_params = list(module.encoder.layers[1].parameters()) + decoder_0_params = list(module.decoder.layers[0].parameters()) + params_to_flatten = encoder_1_params + decoder_0_params + num_params = [len(encoder_1_params), len(decoder_0_params)] + numel_to_flatten = sum(p.numel() for p in params_to_flatten) + module.encoder.layers[1] = FSDP(module.encoder.layers[1]) + module.decoder.layers[0] = FSDP(module.decoder.layers[0]) + flat_params = [ + module.encoder.layers[1]._flat_param, + module.decoder.layers[0]._flat_param, + ] + + self.assertEqual(sum(fp.numel() for fp in flat_params), numel_to_flatten) + self.assertEqual(sum(p.numel() for p in module.parameters()), numel) + + # Check that flattened parameters have been replaced with a single + # `FlatParameter` + self.assertEqual(len(list(module.encoder.layers[1].parameters())), 1) + self.assertEqual(len(list(module.decoder.layers[0].parameters())), 1) + + # Check that non-flattened parameters remain + self.assertEqual( + len(list(module.encoder.layers[0].parameters())), num_params[0] ) - num_params_to_flatten = sum(p.numel() for p in params_to_flatten) - - module = FlattenParamsWrapper( - module, - params_to_flatten, - torch.device("cuda"), - self._get_default_config(), - False, + self.assertEqual( + len(list(module.decoder.layers[1].parameters())), num_params[1] ) - self.assertEqual(module.flat_param.numel(), num_params_to_flatten) - self.assertEqual(sum(p.numel() for p in module.parameters()), num_params) - # flattened parameters are removed - self.assertEqual(len(list(module.encoder.layers[1].parameters())), 0) - self.assertEqual(len(list(module.decoder.layers[0].parameters())), 0) - - # non-flattened parameters remain - self.assertGreater(len(list(module.encoder.layers[0].parameters())), 0) - self.assertGreater(len(list(module.decoder.layers[1].parameters())), 0) - - # test that changing the module dtype works properly + # Check that calling `module.to()` affects the `FlatParameter`s orig_dtype = params_to_flatten[0].dtype new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16 - self.assertEqual(module.flat_param.dtype, orig_dtype) + for flat_param in flat_params: + self.assertEqual(flat_param.dtype, orig_dtype) self.assertTrue( all(p.dtype == orig_dtype for p in module.encoder.layers[0].parameters()) ) module = module.to(dtype=new_dtype) - self.assertEqual(module.flat_param.dtype, new_dtype) + for flat_param in flat_params: + self.assertEqual(flat_param.dtype, new_dtype) self.assertTrue( all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters()) ) def test_flatten_nothing(self): - module = self._get_transformer() - module = FlattenParamsWrapper( - module, - [], - torch.device("cuda"), - self._get_default_config(), - False, + """ + Tests that constructing a ``FlatParamHandle`` with no parameters + raises an error. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_flatten_nothing, ) - self.assertIsNone(module.flat_param) + def _test_flatten_nothing(self, half: bool): + module = self._get_transformer() + if half: + module = module.half() + with self.assertRaisesRegex( + ValueError, + "Cannot initialize a `FlatParameter` from an empty parameter list", + ): + FlatParamHandle( + [], + module, + module, + torch.device("cuda"), + self._get_default_config(), + self.process_group, + False, + ) + + @skip_if_lt_x_gpu(1) def test_empty_module(self): + """ + Tests flattening an empty module (i.e. one without any parameters). + """ module = self._get_empty_module() in_data = torch.rand(1) ref_out = module(in_data) - module = FlattenParamsWrapper( - module, - [], - torch.device("cuda"), - self._get_default_config(), - False, + fsdp_module = FSDP(module) + self.assertEqual(len(list(fsdp_module.parameters())), 0) + self.assertIsNone(fsdp_module._flat_param) + fsdp_out = fsdp_module(in_data) + self.assertEqual(ref_out, fsdp_out) + + def _get_empty_module(self): + """Returns a module with no parameters.""" + torch.manual_seed(0) # keep everything deterministic + + class EmptyModule(torch.nn.Module): + def forward(self, x): + return x + 1 + + def get_input(self, device, dtype): + torch.manual_seed(1) # keep everything deterministic + return torch.rand(1).to(device=device, dtype=dtype) + + return EmptyModule() + + def test_numel_without_shared_params(self): + """ + Tests that numel is preserved after flattening when there are no shared + parameters in the module. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_numel_without_shared_params, ) - self.assertEqual(len(list(module.parameters())), 0) - self.assertIsNone(module.flat_param) - fpw_out = module(in_data) - self.assertEqual(ref_out, fpw_out) - def test_num_params(self): + def _test_numel_without_shared_params(self, half: bool): module = self._get_transformer() - self._test_num_params(module) + if half: + module = module.half() + self._test_numel(module) + + def test_numel_with_shared_params(self): + """ + Tests that numel is preserved after flattening when there are shared + parameters in the module. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_numel_with_shared_params, + ) - def test_shared_params_num_params(self): + def _test_numel_with_shared_params(self, half: bool): module = self._get_shared_params_transformer() - self._test_num_params(module) + if half: + module = module.half() + self._test_numel(module) - def test_output(self): + def _test_numel(self, module): + ref_numel = sum(p.numel() for p in module.parameters()) + params_to_flatten = list(module.parameters()) + flat_param_handle = FlatParamHandle( + params_to_flatten, + module, + module, + torch.device("cuda"), + self._get_default_config(), + self.process_group, + False, + ) + self.assertEqual(ref_numel, flat_param_handle.flat_param.numel()) + + @skip_if_lt_x_gpu(1) + def test_output_without_shared_params(self): + """ + Tests a forward pass after flattening when there are no shared + parameters in the module. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_output_without_shared_params, + ) + + def _test_output_without_shared_params(self, half: bool): module = self._get_transformer() + if half: + module = module.half() self._test_output(module) - def test_shared_params_output(self): + @skip_if_lt_x_gpu(1) + def test_output_with_shared_params(self): + """ + Tests a forward pass after flattening when there are shared parameters + in the module. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_output_with_shared_params, + ) + + def _test_output_with_shared_params(self, half: bool): module = self._get_shared_params_transformer() + if half: + module = module.half() self._test_output(module) - def test_shared_params_pnorm_after_step(self): - # incorrect parameter sharing is likely to cause problems after an - # optimization step - module = self._get_shared_params_transformer() - ref_pnorm_after_step = self._get_pnorm_after_step(module) + def _test_output(self, module: nn.Module): + module = module.to(self.rank) + ref_output = self._get_output(module) + fsdp_module = FSDP(module) + fsdp_output = self._get_output(fsdp_module) + self.assertEqual(ref_output, fsdp_output) - module = self._get_shared_params_transformer() # recreate - params_to_flatten = list(module.parameters()) - flat_module = FlattenParamsWrapper( - module, - params_to_flatten, - torch.device("cuda"), - self._get_default_config(), - False, + def _get_output(self, module): + device = next(module.parameters()).device + dtype = next(module.parameters()).dtype + input = module.get_input(device, dtype) + return module(*input) + + @skip_if_lt_x_gpu(1) + def test_pnorm_after_step_with_shared_params(self): + """ + Tests for parameter Frobenius norm parity after an optimizer step when + there are shared parameters in the module. If the parameter sharing is + handled incorrectly, then an optimizer step should reveal that. + """ + self.run_subtests( + {"half": [False, True]}, + self._test_pnorm_after_step_with_shared_params, ) - flat_pnorm_after_step = self._get_pnorm_after_step(flat_module) - self.assertEqual(ref_pnorm_after_step, flat_pnorm_after_step) + def _test_pnorm_after_step_with_shared_params(self, half: bool): + module = self._get_shared_params_transformer().to(self.rank) + if half: + module = module.half() + ref_pnorm_after_step = self._get_pnorm_after_step(module) + module = self._get_shared_params_transformer().to(self.rank) # recreate + if half: + module = module.half() + fsdp_module = FSDP(module) + fsdp_pnorm_after_step = self._get_pnorm_after_step(fsdp_module) + self.assertEqual(ref_pnorm_after_step, fsdp_pnorm_after_step) - def test_sharded_flat_param(self): + def _get_pnorm_after_step(self, module): + optim = torch.optim.SGD(module.parameters(), lr=0.01) + loss = self._get_output(module).sum() + loss.backward() + optim.step() + return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()])) + + def test_flat_param_shard_metadata(self): + """ + Tests that ``FlatParameter`` shard metadata are computed as expected. + """ module = torch.nn.Sequential( torch.nn.Linear(10, 10, bias=False), torch.nn.ReLU(), @@ -223,14 +319,15 @@ def test_sharded_flat_param(self): torch.nn.ReLU(), ) params_to_flatten = list(module.parameters()) - flat_module = FlattenParamsWrapper( - module, + flat_param_handle = FlatParamHandle( params_to_flatten, + module, + module, torch.device("cuda"), self._get_default_config(), + self.process_group, False, ) - flat_param_handle = flat_module.handle def _test(kwargs, expected): """ @@ -244,9 +341,11 @@ def _test(kwargs, expected): ``init_shard_info()`` with the start and end indices fixed based on rank and world size. """ - flat_param = flat_module.flat_param - flat_param._shard_param_offsets, flat_param._shard_indices = \ - flat_param_handle._get_shard_metadata(kwargs["start"], kwargs["end"]) + flat_param = flat_param_handle.flat_param + ( + flat_param._shard_param_offsets, + flat_param._shard_indices, + ) = flat_param_handle._get_shard_metadata(kwargs["start"], kwargs["end"]) self.assertEqual( flat_param_handle.shard_metadata(), expected, @@ -345,19 +444,5 @@ def _test(kwargs, expected): ) -@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") -class TestFlattenParamsCUDA(TestFlattenParams): - def _get_transformer(self, seed=0): - module = super()._get_transformer(seed=seed) - return module.cuda() - - -@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") -class TestFlattenParamsCUDAHalf(TestFlattenParams): - def _get_transformer(self, seed=0): - module = super()._get_transformer(seed=seed) - return module.cuda().half() - - if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_freezing_weights.py b/test/distributed/fsdp/test_fsdp_freezing_weights.py index 23836130818c9..430e47adf71e0 100644 --- a/test/distributed/fsdp/test_fsdp_freezing_weights.py +++ b/test/distributed/fsdp/test_fsdp_freezing_weights.py @@ -10,18 +10,14 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, - get_full_params, -) +from torch.testing._internal.common_fsdp import FSDPTest, get_full_params from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) diff --git a/test/distributed/fsdp/test_fsdp_fx.py b/test/distributed/fsdp/test_fsdp_fx.py index 7b0e0a3ddf2f2..43f8de2150f92 100644 --- a/test/distributed/fsdp/test_fsdp_fx.py +++ b/test/distributed/fsdp/test_fsdp_fx.py @@ -1,13 +1,11 @@ # Owner(s): ["oncall: distributed"] -from typing import Any - import torch -from torch.distributed.fsdp._symbolic_trace import _init_execution_info, _patch_tracer -from torch.testing._internal.common_fsdp import FSDPTest +from torch.distributed.fsdp._trace_utils import _ExecOrderTracer from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, + TestCase, ) @@ -26,38 +24,37 @@ def __init__(self) -> None: ) self.relu = torch.nn.ReLU() - def forward(self, x: Any, run_all_layers: bool): + def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor: z = self.relu(self.layer0(x)) z = self.relu(self.layer2(z)) z = z @ self.weight1 if run_all_layers: z = self.relu(self.layer1(z)) z = z @ self.weight2 - # used to test the case where a module is called more than once + # Use `layer0` twice to check the handling of multiplicity in the + # saved data structures z = self.relu(self.layer0(x)) return z -class TestSymbolicTracing(FSDPTest): +class TestSymbolicTracing(TestCase): def test_symbolic_tracing_outputs(self): """ - test ``execution_info.module_forward_order`` and ``execution_info.module_to_execution_infos`` - after running ``tracer.trace()`` inside ``_patch_tracer``. + Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking + the saved data structures. """ model = Model() tracer = torch.fx.Tracer() - execution_info = _init_execution_info(model) - original_call_module = tracer.call_module - original_create_proxy = tracer.create_proxy - with _patch_tracer( - tracer=tracer, root_module=model, execution_info=execution_info - ): + orig_call_module = tracer.call_module + orig_create_proxy = tracer.create_proxy + exec_order_tracer = _ExecOrderTracer() + with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model): concrete_args = {"run_all_layers": True} tracer.trace(model, concrete_args) - # the member functions of tracer should not be changed - self.assertEqual(original_call_module, tracer.call_module) - self.assertEqual(original_create_proxy, tracer.create_proxy) - # test tracer.module_forward_order + # Check that the tracer methods are unchanged after exiting the context + self.assertEqual(orig_call_module, tracer.call_module) + self.assertEqual(orig_create_proxy, tracer.create_proxy) + # Check `module_forward_order` correct_module_forward_order = [ model, model.layer0, @@ -72,12 +69,11 @@ def test_symbolic_tracing_outputs(self): model.layer0, model.relu, ] + exec_info = exec_order_tracer.exec_info + self.assertEqual(exec_info.module_forward_order, correct_module_forward_order) + # Check `module_to_param_usage_infos` self.assertEqual( - execution_info.module_forward_order, correct_module_forward_order - ) - # test execution_info.module_to_execution_infos - self.assertEqual( - execution_info.module_to_execution_infos[model], + exec_info.module_to_param_usage_infos[model], [ (model.layer0, list(model.layer0.named_parameters())), (model.layer2, list(model.layer2.named_parameters())), @@ -88,22 +84,22 @@ def test_symbolic_tracing_outputs(self): ], ) self.assertEqual( - execution_info.module_to_execution_infos[model.layer0], + exec_info.module_to_param_usage_infos[model.layer0], [(model.layer0, list(model.layer0.named_parameters()))], ) self.assertEqual( - execution_info.module_to_execution_infos[model.layer1], + exec_info.module_to_param_usage_infos[model.layer1], [(model.layer1, list(model.layer1.named_parameters()))], ) self.assertEqual( - execution_info.module_to_execution_infos[model.layer2], + exec_info.module_to_param_usage_infos[model.layer2], [ (model.layer2[0], list(model.layer2[0].named_parameters())), (model.layer2[2], list(model.layer2[2].named_parameters())), ], ) - self.assertEqual(execution_info.module_to_execution_infos[model.relu], []) - # test tracer.param_exec_order + self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], []) + # Check `param_forward_order` correct_param_order = [ model.layer0.weight, model.layer0.bias, @@ -113,7 +109,12 @@ def test_symbolic_tracing_outputs(self): model.layer1.weight, model.weight2, ] - self.assertEqual(execution_info.param_exec_order, correct_param_order) + self.assertEqual(exec_info.param_forward_order, correct_param_order) + # Check `visited_params` + self.assertEqual( + len(exec_info.visited_params), len(exec_info.param_forward_order) + ) + self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order)) instantiate_parametrized_tests(TestSymbolicTracing) diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index 1e44f865027d0..ef20d2a2db76e 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -8,8 +8,7 @@ import torch from torch import distributed as dist -from torch.distributed.fsdp import CPUOffload -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, ShardingStrategy, @@ -22,10 +21,10 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -53,16 +52,14 @@ class _GradAccConfig: manager as the way to accumulate gradients. num_iters (int): Number of iterations to accumulate gradients. """ + use_no_sync: bool num_iters: int def __repr__(self) -> str: # Override to remove any spaces in the string to appease the internal # build's test name parser - return ( - f"(use_no_sync={self.use_no_sync}," - f"num_iters={self.num_iters})" - ) + return f"(use_no_sync={self.use_no_sync}," f"num_iters={self.num_iters})" @dataclass @@ -71,14 +68,13 @@ class _GradAccConfigs: This wraps a :class:`list` of :class:`_GradAccConfig` instances with the sole purpose of overriding :meth:`__repr__` to remove spaces. """ + configs: List[_GradAccConfig] def __repr__(self) -> str: # Override to remove any spaces in the string to appease the internal # build's test name parser - return ( - "[" + ",".join(config.__repr__() for config in self.configs) + "]" - ) + return "[" + ",".join(config.__repr__() for config in self.configs) + "]" class TestGradAcc(FSDPTest): @@ -118,9 +114,8 @@ def _test_grad_acc( """ # Gradient accumulation outside `no_sync()` is not currently compatible # with CPU offloading - if ( - cpu_offload.offload_params - and any(not config.use_no_sync for config in configs) + if cpu_offload.offload_params and any( + not config.use_no_sync for config in configs ): return old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 @@ -144,7 +139,9 @@ def _test_grad_acc( ) device = torch.device("cuda") optim = torch.optim.SGD( - fsdp_model.parameters(), lr=0.01, momentum=0.9, + fsdp_model.parameters(), + lr=0.01, + momentum=0.9, ) # Generate the sequence of batches, each containing the same data @@ -152,16 +149,16 @@ def _test_grad_acc( def permute_tensor(x: torch.Tensor): return x.view(-1)[torch.randperm(x.numel())].view_as(x) - batch: Tuple[torch.Tensor, ...] = \ - fsdp_model.module.get_input(device) + batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device) batches: List[Tuple[torch.Tensor, ...]] = [batch] num_iters_to_acc = sum(config.num_iters for config in configs) for _ in range(num_iters_to_acc - 1): batches.append(tuple(permute_tensor(t) for t in batch)) for (batch1, batch2) in itertools.combinations(batches, r=2): for t1, t2 in zip(batch1, batch2): - assert not torch.all(t1 == t2), \ - "Check the test to make sure that batches are distinct" + assert not torch.all( + t1 == t2 + ), "Check the test to make sure that batches are distinct" # Concatenate the batches along the given batch dimension concat_batch: Tuple[torch.Tensor, ...] = tuple( @@ -173,17 +170,18 @@ def permute_tensor(x: torch.Tensor): output = fsdp_model(*concat_batch) ref_loss = fsdp_model.module.get_loss(concat_batch, output) ref_loss.backward() - ref_grads = [ - p.grad.detach().clone() for p in fsdp_model.parameters() - ] + ref_grads = [p.grad.detach().clone() for p in fsdp_model.parameters()] # Compute and accumulate the gradients fsdp_model.zero_grad() losses = [] batch_idx = 0 for config in configs: - sync_context = fsdp_model.no_sync() if config.use_no_sync \ + sync_context = ( + fsdp_model.no_sync() + if config.use_no_sync else contextlib.suppress() + ) with sync_context: for _ in range(config.num_iters): if batch_idx == num_iters_to_acc - 1: @@ -199,9 +197,7 @@ def permute_tensor(x: torch.Tensor): loss.backward() losses.append(loss) acc_loss = sum(losses) - acc_grads = [ - p.grad.detach().clone() for p in fsdp_model.parameters() - ] + acc_grads = [p.grad.detach().clone() for p in fsdp_model.parameters()] # Compare the losses and gradients torch.testing.assert_close(ref_loss, acc_loss) @@ -231,17 +227,21 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]: @parametrize( "configs", [ - _GradAccConfigs([ - _GradAccConfig(use_no_sync=True, num_iters=3), - _GradAccConfig(use_no_sync=False, num_iters=3), - _GradAccConfig(use_no_sync=True, num_iters=3), - ]), - _GradAccConfigs([ - _GradAccConfig(use_no_sync=False, num_iters=3), - _GradAccConfig(use_no_sync=True, num_iters=3), - _GradAccConfig(use_no_sync=False, num_iters=3), - ]), - ] + _GradAccConfigs( + [ + _GradAccConfig(use_no_sync=True, num_iters=3), + _GradAccConfig(use_no_sync=False, num_iters=3), + _GradAccConfig(use_no_sync=True, num_iters=3), + ] + ), + _GradAccConfigs( + [ + _GradAccConfig(use_no_sync=False, num_iters=3), + _GradAccConfig(use_no_sync=True, num_iters=3), + _GradAccConfig(use_no_sync=False, num_iters=3), + ] + ), + ], ) @parametrize( "cpu_offload", @@ -253,7 +253,7 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]: ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD, - ] + ], ) def test_grad_acc( self, diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py new file mode 100644 index 0000000000000..cda15ef21d792 --- /dev/null +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -0,0 +1,243 @@ +# Owner(s): ["oncall: distributed"] + +import contextlib +from functools import partial +from collections import Counter +import sys + +import torch +import torch.nn as nn +import torch.distributed as dist + +from torch.distributed.distributed_c10d import _rank_not_in_group +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import ( + CUDAInitMode, + FSDPInitMode, + FSDPTest, + TransformerWithSharedParams, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, + TEST_WITH_DEV_DBG_ASAN, +) + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + + +@contextlib.contextmanager +def patch_allreduce(new_allreduce): + """ + Patches dist.all_reduce with a new all_reduce and + restores upon exiting. + """ + orig_ar = dist.all_reduce + dist.all_reduce = new_allreduce + try: + yield + finally: + dist.all_reduce = orig_ar + +@contextlib.contextmanager +def patch_reduce_scatter(new_reduce_scatter): + """ + Patches dist.reduce_scatter_tensor with a new reduce_scatter_tensor and + restores upon exiting. + """ + orig_reduce_scatter = dist.reduce_scatter_tensor + dist.reduce_scatter_tensor = new_reduce_scatter + try: + yield + finally: + dist.reduce_scatter_tensor = orig_reduce_scatter + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.lin1 = nn.Linear(10, 10) + self.lin2 = nn.Linear(10, 10) + self.lin3 = nn.Linear(10, 10) + +class TestFSDPHybridShard(FSDPTest): + @property + def world_size(self): + return max(torch.cuda.device_count(), 2) + + @property + def process_group(self): + return dist.distributed_c10d._get_default_group() + + @skip_if_lt_x_gpu(2) + def test_raises_manual_wrap_hybrid_shard_when_none_policy(self): + model = MyModel().cuda() + err_ctx = self.assertRaisesRegex( + ValueError, "requires explicit specification of process group" + ) + + with err_ctx: + model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD) + + with err_ctx: + model = FSDP(model, sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2) + + + @skip_if_lt_x_gpu(2) + def test_hybrid_shard_strategy_mismatch_raises(self): + for sharding_strategy in [ + ShardingStrategy._HYBRID_SHARD_ZERO2, + ShardingStrategy.HYBRID_SHARD + ]: + with self.subTest(sharding_strategy=sharding_strategy): + model = MyModel().cuda() + intra_pg = self.process_group + inter_pg = dist.new_group(ranks=[self.rank]) + model.lin1 = FSDP(model.lin1, process_group=(intra_pg, inter_pg), sharding_strategy=sharding_strategy) + self.assertEqual(model.lin1.process_group, intra_pg) + self.assertEqual(model.lin1._inter_node_pg, inter_pg) + model = FSDP(model, process_group=intra_pg) + inp = torch.randn(4, 10) + # Errors during _lazy_init + with self.assertRaisesRegex(ValueError, "expect sharding strategies to be the same"): + model(inp) + + @skip_if_lt_x_gpu(2) + def test_hybrid_shard_pg_mismatch_raises(self): + model = MyModel().cuda() + intra_pg = self.process_group + inter_pg = dist.new_group(ranks=[self.rank]) + # Mismatched process groups for intra-node + model.lin1 = FSDP( + model.lin1, process_group=(intra_pg, inter_pg), sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + model = FSDP( + model, process_group=(dist.new_group(), dist.new_group()), sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + # Errors during _lazy_init + inp = torch.randn(4, 10) + with self.assertRaisesRegex(ValueError, "intra-node process groups do not match"): + model(inp) + + # Mismatched process groups for inter-node + model = MyModel().cuda() + model.lin1 = FSDP( + model.lin1, process_group=(intra_pg, inter_pg), sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + model = FSDP( + model, process_group=(intra_pg, dist.new_group()), sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + with self.assertRaisesRegex(ValueError, "inter-node process groups do not match"): + model(inp) + + @skip_if_lt_x_gpu(2) + def test_invalid_pg_specification_raises(self): + pol = ModuleWrapPolicy({nn.Linear}) + model = MyModel().cuda() + with self.assertRaisesRegex( + ValueError, + "Expected process_group to be passed in" + ): + model = FSDP( + model, + auto_wrap_policy=pol, + process_group=self.process_group, + sharding_strategy=ShardingStrategy.HYBRID_SHARD + ) + + # TODO - add test for ZeRO-2 style sharding ensure params are not + # resharded after forward. + + @skip_if_lt_x_gpu(2) + def test_fsdp_hybrid_shard_basic_setup(self): + """ + Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2: + 1. Inter and intra-node process groups are correctly setup + 2. Process groups are the same across FSDP wrapped instances + 3. reduce_scatter and allreduce called the expected no. of times + """ + for sharding_strategy in [ + ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2 + ]: + with self.subTest(sharding_strategy=sharding_strategy): + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer}, + ) + fsdp_kwargs = { + "auto_wrap_policy": auto_wrap_policy, + "device_id": torch.cuda.current_device(), + "sharding_strategy": sharding_strategy, + } + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + fsdp_kwargs, + ) + # All FSDP modules should have state.process_group as the process group over which to + # shard (default process group), and state._inter_node_pg (process group containing only + # this rank) + intra_node_pgs = set() + inter_node_pgs = set() + for mod in fsdp_model.fsdp_modules(fsdp_model): + # process_group should be across the node, which is just the + # whole world here. + self.assertEqual( + dist.get_world_size(mod.process_group), + dist.get_world_size(self.process_group) + ) + intra_node_pgs.add(mod.process_group) + inter_node_pg = mod._inter_node_pg + inter_node_pgs.add(inter_node_pg) + self.assertEqual(1, dist.get_world_size(inter_node_pg)) + self.assertFalse(_rank_not_in_group(inter_node_pg)) + self.assertEqual( + sharding_strategy, mod.sharding_strategy + ) + # All fsdp modules should share the same process groups + self.assertEqual(1, len(intra_node_pgs)) + self.assertEqual(1, len(inter_node_pgs)) + + orig_ar = dist.all_reduce + orig_rs = dist.reduce_scatter_tensor + + def patched_collective(orig_collective, counter, *args, **kwargs): + counter[orig_collective] += 1 + return orig_collective(*args, **kwargs) + + cntr = Counter() + patched_allreduce = partial(patched_collective, orig_ar, cntr) + patched_reduce_scatter = partial(patched_collective, orig_rs, cntr) + with ( + patch_allreduce(patched_allreduce), + patch_reduce_scatter(patched_reduce_scatter), + ): + inp = fsdp_model.get_input(device=torch.cuda.current_device()) + out = fsdp_model(inp[0], inp[1]) + loss = fsdp_model.get_loss(inp, out) + loss.backward() + + num_flat_params = len(list(FSDP._fsdp_handles(fsdp_model))) + self.assertEqual(num_flat_params, cntr[orig_ar]) + self.assertEqual(num_flat_params, cntr[orig_rs]) + dist.barrier() + +instantiate_parametrized_tests(TestFSDPHybridShard) + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/fsdp/test_fsdp_ignored_modules.py b/test/distributed/fsdp/test_fsdp_ignored_modules.py index 60c3fd6f88110..297cd3f3ca606 100644 --- a/test/distributed/fsdp/test_fsdp_ignored_modules.py +++ b/test/distributed/fsdp/test_fsdp_ignored_modules.py @@ -14,10 +14,10 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -74,12 +74,15 @@ def forward(self, x): class ModelWithIgnoredModules(Model): """Adds a variable number of :class:`IgnoredModule` to ``self.layer1``.""" + def __init__(self, num_ignored: int) -> None: assert num_ignored >= 0 super().__init__() - layer1_modules = [torch.nn.Linear(5, 4), torch.nn.Linear(4, 4)] + \ - [IgnoredModule(4, 4) for _ in range(num_ignored)] + \ - [torch.nn.Linear(4, 4)] + layer1_modules = ( + [torch.nn.Linear(5, 4), torch.nn.Linear(4, 4)] + + [IgnoredModule(4, 4) for _ in range(num_ignored)] + + [torch.nn.Linear(4, 4)] + ) self.layer1 = torch.nn.Sequential(*layer1_modules) @@ -96,6 +99,12 @@ def _train_model(self, model, optim, num_iters, device=torch.device("cuda")): def test_ignored_modules_transformer(self): """Tests that ignored modules' parameters are not flattened for a transformer model with shared parameters.""" + self.run_subtests( + {"use_orig_params": [False, True]}, + self._test_ignored_modules_transformer, + ) + + def _test_ignored_modules_transformer(self, use_orig_params: bool): # Initialize an FSDP-wrapped transformer model that has FSDP ignore # the `nn.Transformer` module's parameters model: nn.Module = TransformerWithSharedParams.init( @@ -108,6 +117,7 @@ def test_ignored_modules_transformer(self): model, self.process_group, ignored_modules=[model.transformer], + use_orig_params=use_orig_params, ) # Check that the wrapped model's flattened parameter does not include # the ignored transformer module's parameters @@ -123,7 +133,8 @@ def test_ignored_modules_transformer(self): ) nonignored_numel = total_numel - ignored_numel with FSDP.summon_full_params(wrapped_model): - flat_param_numel = wrapped_model.params[0].numel() + flat_param = wrapped_model.params[0] + flat_param_numel = flat_param.numel() self.assertEqual(flat_param_numel, nonignored_numel) # Check that we can run a few iterations optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) @@ -133,22 +144,29 @@ def test_ignored_modules_transformer(self): def test_ignored_modules_nested(self): """Tests that passing a module with nested FSDP modules does not error and still ignores non-FSDP modules' parameters.""" + self.run_subtests( + {"use_orig_params": [False, True]}, + self._test_ignored_modules_nested, + ) + + def _test_ignored_modules_nested(self, use_orig_params: bool): # Initialize an FSDP-wrapped nested model that first wraps the nested # sequential's second linear layer (`layer1[1]`) and then wraps the # overall model while ignoring the nested sequential (`layer1`) model = Model().cuda() - model.layer1[1] = FSDP(model.layer1[1]) - wrapped_model = FSDP(model, ignored_modules=[model.layer1]) + model.layer1[1] = FSDP(model.layer1[1], use_orig_params=use_orig_params) + wrapped_model = FSDP( + model, ignored_modules=[model.layer1], use_orig_params=use_orig_params + ) # Check that the wrapped model's flattened parameter does not include # the ignored nested sequential's parameters nonwrapped_model = Model() total_numel = sum(p.numel() for p in nonwrapped_model.parameters()) - ignored_numel = sum( - p.numel() for p in nonwrapped_model.layer1.parameters() - ) + ignored_numel = sum(p.numel() for p in nonwrapped_model.layer1.parameters()) nonignored_numel = total_numel - ignored_numel with FSDP.summon_full_params(wrapped_model): - flat_param_numel = wrapped_model.params[0].numel() + flat_param = wrapped_model.params[0] + flat_param_numel = flat_param.numel() self.assertEqual(flat_param_numel, nonignored_numel) # Check that we can run a few iterations optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) @@ -176,7 +194,9 @@ def test_ignored_modules_invalid(self): @skip_if_lt_x_gpu(2) @parametrize("pass_ignored_modules_to_root", [False, True]) - def test_diff_ignored_modules_across_ranks(self, pass_ignored_modules_to_root: bool): + def test_diff_ignored_modules_across_ranks( + self, pass_ignored_modules_to_root: bool + ): """ Tests ignoring different modules across ranks. @@ -196,9 +216,11 @@ def test_diff_ignored_modules_across_ranks(self, pass_ignored_modules_to_root: b ] model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules) model.layer3 = FSDP(model.layer3) - model_ignored_modules = [ - m for m in model.modules() if isinstance(m, IgnoredModule) - ] if pass_ignored_modules_to_root else [] + model_ignored_modules = ( + [m for m in model.modules() if isinstance(m, IgnoredModule)] + if pass_ignored_modules_to_root + else [] + ) wrapped_model = FSDP(model, ignored_modules=model_ignored_modules) optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) self._train_model(wrapped_model, optim, 3) diff --git a/test/distributed/fsdp/test_fsdp_input.py b/test/distributed/fsdp/test_fsdp_input.py index 136b65c3b28ec..06a516faaa97b 100644 --- a/test/distributed/fsdp/test_fsdp_input.py +++ b/test/distributed/fsdp/test_fsdp_input.py @@ -8,18 +8,15 @@ from torch.nn import Linear, Module from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, -) +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, subtest, + TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) diff --git a/test/distributed/fsdp/test_fsdp_memory.py b/test/distributed/fsdp/test_fsdp_memory.py index b26aa249dc798..fe2ad8879ad1b 100644 --- a/test/distributed/fsdp/test_fsdp_memory.py +++ b/test/distributed/fsdp/test_fsdp_memory.py @@ -8,18 +8,15 @@ from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, -) +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) from torch.utils.checkpoint import checkpoint - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) diff --git a/test/distributed/fsdp/test_fsdp_meta.py b/test/distributed/fsdp/test_fsdp_meta.py index 1aa426800db62..09e5c7ae83292 100644 --- a/test/distributed/fsdp/test_fsdp_meta.py +++ b/test/distributed/fsdp/test_fsdp_meta.py @@ -6,20 +6,19 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import always_wrap_policy as always_wrap -from torch.distributed.fsdp.wrap import wrap, enable_wrap -from torch.testing._internal.common_fsdp import ( - FSDPTest, +from torch.distributed.fsdp.wrap import ( + always_wrap_policy as always_wrap, + enable_wrap, + wrap, ) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, - parametrize, instantiate_parametrized_tests, + parametrize, + run_tests, sandcastle_skip_if, -) -from torch.testing._internal.common_distributed import ( - skip_if_lt_x_gpu, + TEST_WITH_DEV_DBG_ASAN, ) _TORCHDISTX_AVAIL = True @@ -47,10 +46,12 @@ def _reset_params_if_meta(is_meta, model): if is_meta: model.reset_parameters() + class MyLinear(nn.Linear): """ Linear layer with deterministic reset_parameters for testing. """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -58,6 +59,7 @@ def reset_parameters(self, *args, **kwargs): with torch.no_grad(): self.weight.fill_(1) + class MyModel(nn.Module): def __init__(self, device): super().__init__() @@ -90,6 +92,7 @@ def reset_parameters(self): if not isinstance(m, FSDP): m.reset_parameters() + def _init_with_reset_params(module): """ to_empty + reset_parameters() init function example for modules @@ -101,6 +104,7 @@ def _init_with_reset_params(module): with torch.no_grad(): module.reset_parameters() + def _init_with_torchdistX(module): """ torchdistX-based deferred module initialization function example @@ -113,6 +117,7 @@ def check_fn(k): deferred_init.materialize_module(module, check_fn=check_fn) + class TestFSDPWithMetaDevice(FSDPTest): @property def world_size(self): @@ -148,7 +153,7 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) self._compare_fsdp(fsdp_meta, fsdp_regular) - inp = torch.randn(10, 2, device='cuda') + inp = torch.randn(10, 2, device="cuda") fsdp_meta(inp).sum().backward() fsdp_regular(inp).sum().backward() meta_opt.step() @@ -176,6 +181,7 @@ def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None): def test_simple_model_with_meta_device_reset_params(self): def meta_module_fn(): return MyModel(device="meta") + self._test_simple_model_with_meta_device( meta_module_fn, _init_with_reset_params ) @@ -184,11 +190,13 @@ def meta_module_fn(): def test_simple_model_with_meta_device_default_init(self): def meta_module_fn(): return MyModel(device="meta") + self._test_simple_model_with_meta_device(meta_module_fn) @skip_if_lt_x_gpu(2) @sandcastle_skip_if( - not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX" + not _TORCHDISTX_AVAIL, + "Test requires torchdistX: https://github.com/pytorch/torchdistX", ) def test_simple_model_with_torchdistX_default_init(self): def meta_module_fn(): @@ -198,15 +206,20 @@ def meta_module_fn(): @skip_if_lt_x_gpu(2) @sandcastle_skip_if( - not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX" + not _TORCHDISTX_AVAIL, + "Test requires torchdistX: https://github.com/pytorch/torchdistX", ) def test_simple_model_with_torchdistX_init_fn(self): def meta_module_fn(): return deferred_init.deferred_init(MyModel, device="cuda") - self._test_simple_model_with_meta_device(meta_module_fn, init_fn=_init_with_torchdistX) + self._test_simple_model_with_meta_device( + meta_module_fn, init_fn=_init_with_torchdistX + ) - def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn=None): + def _test_nested_model_with_meta_device( + self, auto_wrap, meta_module_fn, init_fn=None + ): if auto_wrap: module = meta_module_fn() is_meta = next(module.parameters()).is_meta @@ -225,7 +238,8 @@ def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3) else: with enable_wrap( - wrapper_cls=FSDP, param_init_fn=init_fn, + wrapper_cls=FSDP, + param_init_fn=init_fn, ): module = meta_module_fn() is_meta = next(module.parameters()).is_meta @@ -246,7 +260,7 @@ def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn # Compare it before training self._compare_fsdp(fsdp_meta, fsdp_regular) - inp = torch.randn(10, 2, device='cuda') + inp = torch.randn(10, 2, device="cuda") fsdp_meta(inp).sum().backward() fsdp_regular(inp).sum().backward() meta_opt.step() @@ -260,7 +274,9 @@ def meta_module_fn(): return NestedModel(device="meta") self._test_nested_model_with_meta_device( - auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, init_fn=_init_with_reset_params + auto_wrap=auto_wrap, + meta_module_fn=meta_module_fn, + init_fn=_init_with_reset_params, ) @skip_if_lt_x_gpu(2) @@ -270,12 +286,14 @@ def meta_module_fn(): return NestedModel(device="meta") self._test_nested_model_with_meta_device( - auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, + auto_wrap=auto_wrap, + meta_module_fn=meta_module_fn, ) @skip_if_lt_x_gpu(2) @sandcastle_skip_if( - not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX" + not _TORCHDISTX_AVAIL, + "Test requires torchdistX: https://github.com/pytorch/torchdistX", ) @parametrize("auto_wrap", [True, False]) def test_nested_model_with_torchdistX_default_init(self, auto_wrap): @@ -288,7 +306,8 @@ def meta_module_fn(): @skip_if_lt_x_gpu(2) @sandcastle_skip_if( - not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX" + not _TORCHDISTX_AVAIL, + "Test requires torchdistX: https://github.com/pytorch/torchdistX", ) @parametrize("auto_wrap", [True, False]) def test_nested_model_with_torchdistX_init_fn(self, auto_wrap): @@ -296,7 +315,9 @@ def meta_module_fn(): return deferred_init.deferred_init(NestedModel, device="cuda") self._test_nested_model_with_meta_device( - auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, init_fn=_init_with_torchdistX, + auto_wrap=auto_wrap, + meta_module_fn=meta_module_fn, + init_fn=_init_with_torchdistX, ) def _test_bad_arg(self, meta_module_fn): @@ -306,7 +327,8 @@ def _test_bad_arg(self, meta_module_fn): @skip_if_lt_x_gpu(2) @sandcastle_skip_if( - not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX" + not _TORCHDISTX_AVAIL, + "Test requires torchdistX: https://github.com/pytorch/torchdistX", ) def test_bad_arg_torchdistx(self): def meta_module_fn(): diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index f2ae0dcfcaeaf..d1b2445dc78b8 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -1,36 +1,41 @@ # Owner(s): ["oncall: distributed"] -from copy import deepcopy import functools import sys +import warnings from collections import namedtuple from contextlib import suppress +from copy import deepcopy import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed.fsdp import FlatParameter -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy, CPUOffload +from torch.distributed.fsdp import ( + CPUOffload, + FlatParameter, + FullyShardedDataParallel as FSDP, + ShardingStrategy, +) from torch.distributed.fsdp.wrap import ( always_wrap_policy, + ModuleWrapPolicy, transformer_auto_wrap_policy, ) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + _assert_module_states, CUDAInitMode, FSDPInitMode, FSDPTest, NestedWrappedModule, TransformerWithSharedParams, - _assert_module_states, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -71,9 +76,7 @@ def forward(self, x): t = torch.ones(1, device="cuda", requires_grad=True) MyOutputType = namedtuple( - "MyOutputType", - ["a", "b", "c", "d"], - defaults=(t, t, t, t) + "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t) ) inp = MyOutputType() @@ -89,7 +92,6 @@ def forward(self, x): @skip_if_lt_x_gpu(2) def test_fsdp_not_all_outputs_used_in_loss(self): - class MyModule(nn.Module): def __init__(self): super().__init__() @@ -108,20 +110,17 @@ def _check_resharded(fsdp_module): full_param = param._full_param_padded self.assertEqual(full_param.storage().size(), 0) - self.assertEqual( - param.data_ptr(), - param._local_shard.data_ptr() - ) + self.assertEqual(param.data_ptr(), param._local_shard.data_ptr()) def _check_equal(local, fsdp): with FSDP.summon_full_params(fsdp): for p1, p2 in zip(fsdp.parameters(), local.parameters()): - torch.testing.assert_allclose(p1, p2) + torch.testing.assert_close(p1, p2) for sharding_strategy in [ ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, - ShardingStrategy.NO_SHARD + ShardingStrategy.NO_SHARD, ]: with self.subTest(sharding_strategy=sharding_strategy): fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy) @@ -160,7 +159,10 @@ def _check_equal(local, fsdp): # Ensure at least some change from previous params, otherwise # above check would be vacuously true. self.assertTrue( - any(not torch.equal(p1, p2) for p1, p2 in zip(prev_params, m_local.parameters())) + any( + not torch.equal(p1, p2) + for p1, p2 in zip(prev_params, m_local.parameters()) + ) ) prev_params = [p.clone() for p in local_m.parameters()] opt.zero_grad() @@ -168,7 +170,6 @@ def _check_equal(local, fsdp): dist.barrier() - @skip_if_lt_x_gpu(2) @parametrize("use_second_layer", [True, False]) @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None]) @@ -193,10 +194,10 @@ def forward(self, x, y): fsdp = FSDP( MyModel().cuda(), sharding_strategy=sharding_strategy, - auto_wrap_policy=always_wrap_policy + auto_wrap_policy=always_wrap_policy, ) - x = torch.randn(10, 10, device='cuda') - y = torch.randn(10, 10, device='cuda') + x = torch.randn(10, 10, device="cuda") + y = torch.randn(10, 10, device="cuda") for i in range(4): if use_second_layer: a, b = fsdp(x, y) @@ -206,8 +207,8 @@ def forward(self, x, y): loss.backward() # self.a receives grad, self.b does not - a_grad = fsdp.module.a._fsdp_wrapped_module.flat_param.grad - b_grad = fsdp.module.b._fsdp_wrapped_module.flat_param.grad + a_grad = fsdp.module.a._handles[0].flat_param.grad + b_grad = fsdp.module.b._handles[0].flat_param.grad self.assertIsNotNone(a_grad) self.assertIsNone(b_grad) @@ -215,10 +216,20 @@ def forward(self, x, y): def test_device_id_auto_wrap(self): """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all nested FSDP instances.""" - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + self.run_subtests( + {"use_callable": [False, True]}, + self._test_device_id_auto_wrap, ) + + def _test_device_id_auto_wrap(self, use_callable: bool): + module_classes = {TransformerEncoderLayer, TransformerDecoderLayer} + if use_callable: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=module_classes, + ) + else: + auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, "device_id": torch.cuda.current_device(), @@ -241,6 +252,7 @@ def test_fsdp_device_id_cpu_offload(self): Ensures that even if device_id is specified but we have CPU offload, module is on CPU after init. """ + class MyModel(nn.Module): def __init__(self): super().__init__() @@ -256,7 +268,7 @@ def forward(self, x): model, auto_wrap_policy=always_wrap_policy, cpu_offload=CPUOffload(offload_params=True), - device_id=torch.cuda.current_device() + device_id=torch.cuda.current_device(), ) cpu_device = torch.device("cpu") @@ -281,7 +293,8 @@ def test_fsdp_device_id(self, use_index): without specifying a device ID (i.e. ``torch.device("cuda")``) warns """ dev_id = ( - torch.cuda.current_device() if use_index + torch.cuda.current_device() + if use_index else torch.device("cuda", torch.cuda.current_device()) ) @@ -289,8 +302,7 @@ def _check_device_matches(module, device_id): """Checks that the ``FlatParameter``s in ``module`` have device matching ``device_id``.""" devices = { - p.device for p in module.parameters() - if isinstance(p, FlatParameter) + p.device for p in module.parameters() if isinstance(p, FlatParameter) } assert len(devices) > 0 self.assertEqual(1, len(devices)) @@ -328,11 +340,10 @@ def _check_device_matches(module, device_id): self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, - fsdp_kwargs={"device_id": torch.device("cuda")} + fsdp_kwargs={"device_id": torch.device("cuda")}, ) _check_device_matches( - nested_wrapped_module, - torch.device("cuda", torch.cuda.current_device()) + nested_wrapped_module, torch.device("cuda", torch.cuda.current_device()) ) @skip_if_lt_x_gpu(2) @@ -340,10 +351,9 @@ def test_module_device_mismatches_device_id(self): """Tests that specifying a ``device_id`` argument to FSDP for a GPU module that does not match the GPU device ID raises an error.""" context = ( - self.assertRaisesRegex( - ValueError, - f"cuda:{self.rank} vs cuda:0" - ) if self.rank != 0 else suppress() + self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0") + if self.rank != 0 + else suppress() ) with context: NestedWrappedModule.init( @@ -360,6 +370,7 @@ def test_module_device_mismatches_device_id(self): def test_multi_device_not_supported(self): """Tests that wrapping a multi-device module (i.e. with submodules on both GPU and CPU) with FSDP raises an error.""" + class MultiDeviceModule(nn.Module): def __init__(self): super().__init__() @@ -392,11 +403,14 @@ def test_no_params(self): # is computed as torch.cuda.current_device when there are no params. no_params = nn.ReLU().cuda() context = ( - self.assertRaisesRegex( - ValueError, - f"Inconsistent.*cuda:{self.rank} vs cuda:0" + ( + self.assertRaisesRegex( + ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0" + ) ) - ) if self.rank != 0 else suppress() + if self.rank != 0 + else suppress() + ) with context: module = FSDP(no_params, device_id=0) @@ -406,7 +420,7 @@ def test_fsdp_cpu_init_stays_on_cpu(self): module is on CPU after FSDP initialization, albeit after loging a warning, and that FSDP moves CPU input to GPU before the forward.""" torch.cuda.set_device(self.rank) - regex = "Module is put on CPU" + regex = "passed-in `module` is on CPU" context = self.assertWarnsRegex( expected_warning=UserWarning, expected_regex=regex ) @@ -438,8 +452,7 @@ def test_cpu_init_with_sync_module_states(self): CUDAInitMode.CUDA_NEVER, ) with self.assertRaisesRegex( - ValueError, - "Module has CPU parameters, but sync_module_states=True is specified." + ValueError, "The module has CPU parameters when `sync_module_states=True`" ): FSDP(nested_wrapped_module, self.process_group, sync_module_states=True) @@ -457,6 +470,7 @@ def test_fsdp_same_model_across_ranks(self): FSDP broadcasts model from rank 0 to ensure it starts off with the same values. """ + class MyModel(nn.Module): def __init__(self, rank): super().__init__() @@ -467,19 +481,70 @@ def __init__(self, rank): self.register_buffer("buffer", torch.ones(1) * rank) m = MyModel(self.rank).cuda() - _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual) + _assert_module_states( + m, process_group=self.process_group, assert_fn=self.assertNotEqual + ) # Passing sync_module_states into FSDP makes model the same during init. fsdp = FSDP(m, sync_module_states=True) with fsdp.summon_full_params(fsdp): - _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) + _assert_module_states( + fsdp, process_group=self.process_group, assert_fn=self.assertEqual + ) # sync_module_states also works with CPU module with device_id passed in m = MyModel(self.rank) - _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual) + _assert_module_states( + m, process_group=self.process_group, assert_fn=self.assertNotEqual + ) # Passing sync_module_states into FSDP makes model the same during init. fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True) with fsdp.summon_full_params(fsdp): - _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual) + _assert_module_states( + fsdp, process_group=self.process_group, assert_fn=self.assertEqual + ) + + +class TestFSDPMiscWorldSize1(FSDPTest): + @property + def world_size(self) -> int: + return 1 + + @skip_if_lt_x_gpu(1) + def test_world_size_1_sharding_strategy_warning(self): + """ + Tests that FSDP issues a warning when it switches to using ``NO_SHARD`` + when the world size is 1. + """ + warning_prefix = "FSDP is switching to use `NO_SHARD` instead of" + # If the user already passes `NO_SHARD`, then there should not be a + # warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # trigger all warnings + FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD) + for warning in w: + self.assertTrue( + warning.category != UserWarning + or not str(warning.message).startswith(warning_prefix) + ) + + # Check that a warning is issued + warning_suffix = " since the world size is 1." + # - Pass `FULL_SHARD` or `None` + expected_regex_full_shard = ( + warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix + ) + with self.assertWarnsRegex(UserWarning, expected_regex_full_shard): + FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD) + with self.assertWarnsRegex(UserWarning, expected_regex_full_shard): + FSDP(nn.Linear(3, 3).cuda()) + # - Pass `SHARD_GRAD_OP` + expected_regex_shard_grad_op = ( + warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix + ) + with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op): + FSDP( + nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP + ) instantiate_parametrized_tests(TestFSDPMisc) diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index c803164bff4e5..9522f3a013420 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -11,9 +11,13 @@ import torch.nn as nn import torch.nn.functional as F from torch import distributed as dist -from torch.distributed.fsdp import BackwardPrefetch, CPUOffload -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import ( + BackwardPrefetch, + CPUOffload, + FullyShardedDataParallel as FSDP, + MixedPrecision, + ShardingStrategy, +) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from torch.nn.modules.batchnorm import _BatchNorm @@ -23,19 +27,20 @@ CUDAInitMode, FSDPInitMode, FSDPTest, - TransformerWithSharedParams, subtest_name, + TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, sandcastle_skip_if, + TEST_WITH_DEV_DBG_ASAN, ) try: import torchvision + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False @@ -66,7 +71,9 @@ mp_only_reduce = MixedPrecision(reduce_dtype=torch.float16) # Only parameters are cast (thus comm should happen in the param_dtype precision) -mp_only_param_and_buf = MixedPrecision(param_dtype=torch.float16, buffer_dtype=torch.float16) +mp_only_param_and_buf = MixedPrecision( + param_dtype=torch.float16, buffer_dtype=torch.float16 +) # Nothing is cast (thus param, comm, grad, and buffer should be in the full precision) mp_no_mixed_precision = MixedPrecision() @@ -80,7 +87,7 @@ mp_diff_buffer_and_reduce = MixedPrecision( param_dtype=torch.float16, buffer_dtype=torch.bfloat16, - reduce_dtype=torch.float32 + reduce_dtype=torch.float32, ) mp_configs.extend([mp_diff_buffer_and_reduce]) @@ -88,18 +95,18 @@ _BUFFER_ORIG_DTYPE = torch.float64 params = "mp_config,cpu_offload,full_precision_param_dtype,enable_sharded_grad_scaler" -cpu_offload_config = [ - CPUOffload(offload_params=True), CPUOffload(offload_params=False) -] +cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] full_precision_param_dtype_config = [torch.float32, torch.float64] enable_sharded_grad_scaler = ["enable_sharded_grad_scaler", None] -configs = list(product( - mp_configs, - cpu_offload_config, - full_precision_param_dtype_config, - enable_sharded_grad_scaler, -)) +configs = list( + product( + mp_configs, + cpu_offload_config, + full_precision_param_dtype_config, + enable_sharded_grad_scaler, + ) +) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", @@ -110,42 +117,50 @@ str(mp_no_mixed_precision): "mp_no_mp", str(torch.float32): "fp32", str(torch.float64): "fp64", - "enable_sharded_grad_scaler": "enable_sharded_grad_scaler" + "enable_sharded_grad_scaler": "enable_sharded_grad_scaler", } if nccl_supports_bf16: - test_name_mapping.update({ - str(mp_diff_buffer_and_reduce): "mp_diff_buffer_reduce", - }) + test_name_mapping.update( + { + str(mp_diff_buffer_and_reduce): "mp_diff_buffer_reduce", + } + ) subtest_name = partial(subtest_name, test_name_mapping) _CURRENT_FULL_PRECISION_PARAM_DTYPE = None + @contextlib.contextmanager def patch_reduce_scatter(new_reduce_scatter, full_precision_param_dtype): """ - Patches dist._reduce_scatter_base with a new reduce_scatter_base and - restores upon exiting. Used for validation of mixed precision + Patches ``dist.reduce_scatter_tensor`` with ``new_reduce_scatter`` and + restores upon exiting. Used for validation of mixed precision. """ - orig_reduce_scatter = dist._reduce_scatter_base - dist._reduce_scatter_base = new_reduce_scatter + orig_reduce_scatter = dist.reduce_scatter_tensor + dist.reduce_scatter_tensor = new_reduce_scatter global _CURRENT_FULL_PRECISION_PARAM_DTYPE _CURRENT_FULL_PRECISION_PARAM_DTYPE = full_precision_param_dtype try: yield finally: - dist._reduce_scatter_base = orig_reduce_scatter + dist.reduce_scatter_tensor = orig_reduce_scatter _CURRENT_FULL_PRECISION_PARAM_DTYPE = None + class LinearMixedPrecision(nn.Module): """ A linear module with extra checks for mixed precision training. """ - def __init__(self, param_dtype): + + def __init__(self, param_dtype, buffer_name="buffer"): super().__init__() self.lin = nn.Linear(10, 10, bias=False).to(param_dtype) - self.register_buffer('buffer', torch.randn((1, 2), dtype=_BUFFER_ORIG_DTYPE)) + # Use a configurable buffer name to avoid all submodules sharing the + # same buffer name, which may hide prefixed vs. unprefixed name bugs + self.buffer_name = buffer_name + self.register_buffer(buffer_name, torch.randn((1, 2), dtype=_BUFFER_ORIG_DTYPE)) self._orig_param_type = param_dtype self._orig_buffer_dtype = _BUFFER_ORIG_DTYPE @@ -153,16 +168,18 @@ def forward(self, tup): # Param and input should be the mixed precision type inp, cls, fsdp, mp_config, full_precision_param_dtype = tup expected_param_type = ( - mp_config.param_dtype if mp_config.param_dtype is not None + mp_config.param_dtype + if mp_config.param_dtype is not None else self._orig_param_type ) expected_buffer_type = ( - mp_config.buffer_dtype if mp_config.buffer_dtype is not None + mp_config.buffer_dtype + if mp_config.buffer_dtype is not None else self._orig_buffer_dtype ) cls.assertEqual(inp.dtype, expected_param_type) # Buffer should be in specified precision as well. - cls.assertEqual(self.buffer.dtype, expected_buffer_type) + cls.assertEqual(getattr(self, self.buffer_name).dtype, expected_buffer_type) # In FSDP, self.params should point to the right type. num_active_fsdp = 0 @@ -193,7 +210,7 @@ def forward(self, tup): if mp_config.param_dtype is not None: cls.assertEqual(0, param._mp_shard.storage().size()) else: - cls.assertFalse(hasattr(param, '_mp_shard')) + cls.assertFalse(hasattr(param, "_mp_shard")) elif param_is_sharded: # This FSDP unit is not active as full param has been # freed or not yet allocated. Ensure param points to full @@ -219,8 +236,12 @@ def world_size(self): def _get_simple_nested_model(self, param_dtype, *fsdp_args, **fsdp_kwargs): model = FSDP( nn.Sequential( - FSDP(LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs), - LinearMixedPrecision(param_dtype).cuda(), + FSDP( + LinearMixedPrecision(param_dtype, buffer_name="buffer0").cuda(), + *fsdp_args, + **fsdp_kwargs, + ), + LinearMixedPrecision(param_dtype, buffer_name="buffer1").cuda(), ), *fsdp_args, **fsdp_kwargs, @@ -228,7 +249,9 @@ def _get_simple_nested_model(self, param_dtype, *fsdp_args, **fsdp_kwargs): return model def _get_simple_model(self, param_dtype, *fsdp_args, **fsdp_kwargs): - model = FSDP(LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs) + model = FSDP( + LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs + ) return model def _validate_no_mp_shard(self, fsdp_model): @@ -239,7 +262,7 @@ def _validate_no_mp_shard(self, fsdp_model): fsdp_units = FSDP.fsdp_modules(fsdp_model) for fsdp in fsdp_units: for param in fsdp.params: - self.assertFalse(hasattr(param, '_mp_shard')) + self.assertFalse(hasattr(param, "_mp_shard")) def _validate_mp_shard_freed(self, fsdp_model): """ @@ -250,17 +273,13 @@ def _validate_mp_shard_freed(self, fsdp_model): for param in fsdp.params: self.assertEqual(0, param._mp_shard.storage().size()) - def _reduce_scatter_base_validate_mp( - self, - orig_reduce_scatter, - mp_config, - *args, - **kwargs + def _reduce_scatter_validate_mp( + self, orig_reduce_scatter, mp_config, *args, **kwargs ): """ - Performs dist._reduce_scatter_base but verifies mixed precision settings - before. This is to test mixed precision is working as expected during - backward pass. In particular it ensures that the gradients were cast to the right type + Runs reduce-scatter but verifies mixed precision settings before. This + is to test mixed precision is working as expected during backward pass. + In particular it ensures that the gradients were cast to the right type and comm. is going to happen in the right type. """ tensors = [] @@ -278,9 +297,11 @@ def _reduce_scatter_base_validate_mp( # If reduce_dtype is not specified (is None) we comm. in the param_dtype # if that is specified, otherwise full precision dtype. expected_dtype = ( - mp_config.reduce_dtype if mp_config.reduce_dtype is not None + mp_config.reduce_dtype + if mp_config.reduce_dtype is not None else ( - mp_config.param_dtype if mp_config.param_dtype is not None + mp_config.param_dtype + if mp_config.param_dtype is not None else _CURRENT_FULL_PRECISION_PARAM_DTYPE ) ) @@ -290,7 +311,9 @@ def _reduce_scatter_base_validate_mp( return orig_reduce_scatter(*args, **kwargs) - def _test_grads_reduced_precision(self, offload_params: bool): + def _test_grads_reduced_precision( + self, offload_params: bool, use_orig_params: bool + ): class MyModel(nn.Module): def __init__(self): super().__init__() @@ -310,6 +333,7 @@ def forward(self, x): fsdp_kwargs = { "mixed_precision": mp, "cpu_offload": CPUOffload(offload_params=offload_params), + "use_orig_params": use_orig_params, } m.lin1 = FSDP(m.lin1, **fsdp_kwargs) m = FSDP(m, **fsdp_kwargs) @@ -317,7 +341,8 @@ def forward(self, x): inp = torch.ones(1, 10) m(inp).sum().backward() for param in m.parameters(): - self.assertEqual(torch.float16, param.grad.dtype) + if param.grad is not None: + self.assertEqual(torch.float16, param.grad.dtype) dist.barrier() @@ -355,16 +380,20 @@ def _run_test_mixed_precision_e2e( model.cuda() # Patch reduce_scatter to add validation for mixed precision types. - orig_reduce_scatter = dist._reduce_scatter_base + orig_reduce_scatter = dist.reduce_scatter_tensor test_reduce_scatter = partial( - self._reduce_scatter_base_validate_mp, orig_reduce_scatter, mp_config, + self._reduce_scatter_validate_mp, + orig_reduce_scatter, + mp_config, ) with patch_reduce_scatter(test_reduce_scatter, full_precision_param_dtype): scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler) optim = torch.optim.Adam(model.parameters()) for _ in range(3): - inp = torch.randn(3, 10, device='cuda', dtype=full_precision_param_dtype) + inp = torch.randn( + 3, 10, device="cuda", dtype=full_precision_param_dtype + ) # Forward pass of LinearMixedPrecision check casting of # inputs, params, buffers. act, *_ = model( @@ -409,7 +438,9 @@ def _run_test_mixed_precision_e2e( for param in model.parameters(): self.assertEqual(param.dtype, full_precision_param_dtype) if param.grad is not None: - self.assertEqual(param.grad.dtype, full_precision_param_dtype) + self.assertEqual( + param.grad.dtype, full_precision_param_dtype + ) # Unscale the gradients and step scaler.step(optim) @@ -448,8 +479,9 @@ def _run_test_mixed_precision_e2e( self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE) else: self.assertEqual( - tensor.dtype, full_precision_param_dtype, - f"{name}: {tensor.dtype} vs {full_precision_param_dtype}" + tensor.dtype, + full_precision_param_dtype, + f"{name}: {tensor.dtype} vs {full_precision_param_dtype}", ) # After state_dict, buffer's dtype should have been restored @@ -475,7 +507,7 @@ def _get_subtest_config(self) -> Dict[str, List[Any]]: None, BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, - ] + ], } @skip_if_lt_x_gpu(2) @@ -516,9 +548,11 @@ def _test_mixed_precision_embedding_table(self, mp_config): # Basic test to ensure int inputs are not casted which would break # modules such as embedding tables. param_dtype = mp_config.param_dtype or torch.float32 - orig_reduce_scatter = dist._reduce_scatter_base + orig_reduce_scatter = dist.reduce_scatter_tensor test_reduce_scatter = partial( - self._reduce_scatter_base_validate_mp, orig_reduce_scatter, mp_config, + self._reduce_scatter_validate_mp, + orig_reduce_scatter, + mp_config, ) with patch_reduce_scatter(test_reduce_scatter, param_dtype): # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the @@ -570,9 +604,11 @@ def test_mp_embedding_params_and_reduce_diff(self): params_and_reduce_different = MixedPrecision( param_dtype=torch.float16, reduce_dtype=torch.float32, - buffer_dtype=torch.float16 + buffer_dtype=torch.float16, + ) + self._test_mixed_precision_embedding_table( + mp_config=params_and_reduce_different ) - self._test_mixed_precision_embedding_table(mp_config=params_and_reduce_different) @skip_if_lt_x_gpu(2) @skipIfNoTorchVision @@ -583,11 +619,12 @@ def test_mixed_precision_resnet(self): """ resnet_model = torchvision.models.resnet50().cuda() resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm( - resnet_model, - process_group=dist.distributed_c10d._get_default_group() + resnet_model, process_group=dist.distributed_c10d._get_default_group() ) - n_bn = sum(1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()) - inp = torch.ones(1, 3, 1000, 1000, device='cuda') + n_bn = sum( + 1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules() + ) + inp = torch.ones(1, 3, 1000, 1000, device="cuda") mp_config = MixedPrecision( param_dtype=torch.float16, reduce_dtype=torch.float16, @@ -596,7 +633,7 @@ def test_mixed_precision_resnet(self): fsdp = FSDP( resnet_model, auto_wrap_policy=size_based_auto_wrap_policy, - mixed_precision=mp_config + mixed_precision=mp_config, ) # Batchnorm units should be wrapped individually. Validate this by # ensuring there are equal no. of FSDP units that are BN as BN units @@ -615,7 +652,10 @@ def test_mixed_precision_resnet(self): @skip_if_lt_x_gpu(2) def test_grads_reduced_precision(self): self.run_subtests( - {"offload_params": [False, True]}, + { + "offload_params": [False, True], + "use_orig_params": [False, True], + }, self._test_grads_reduced_precision, ) @@ -652,7 +692,7 @@ def never_wrap_policy(*args, **kwargs): ) with self.assertWarnsRegex( expected_warning=UserWarning, - expected_regex="batch norm submodules will be wrapped as separate" + expected_regex="batch norm submodules will be wrapped as separate", ): model = FSDP( net, @@ -669,7 +709,7 @@ def never_wrap_policy(*args, **kwargs): self.assertEqual(no_mixed_precision, bn.mixed_precision) self.assertNotEqual(no_mixed_precision, model.mixed_precision) - inp = torch.randn((1, 2), device='cuda') + inp = torch.randn((1, 2), device="cuda") # Without FSDP BN mixed precision fix, this would result in # RuntimeError: Expected counts to have type Half but got Float # for syncBN @@ -680,6 +720,7 @@ class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision): """ Smaller test suite for unshared param (i.e. world_size == 1) case. """ + @property def world_size(self): return 1 @@ -687,7 +728,7 @@ def world_size(self): @skip_if_lt_x_gpu(1) def test_grads_reduced_precision(self): self.run_subtests( - {"offload_params": [False, True]}, + {"offload_params": [False, True], "use_orig_params": [False, True]}, self._test_grads_reduced_precision, ) @@ -719,7 +760,50 @@ def test_mixed_precision_e2e_full_shard(self): enable_sharded_grad_scaler=False, ) + instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded) + +class IgnoredModule(nn.Module): + def __init__(self): + super().__init__() + self.l = nn.Linear(100, 100) + + def forward(self, x): + return self.l(x) + + +class Model(nn.Module): + def __init__(self): + super().__init__() + self.l1 = nn.Linear(100, 100) + self.ignored = IgnoredModule() + self.l2 = nn.Linear(100, 100) + + def forward(self, x): + return self.l2(self.ignored(self.l1(x))) + + +class TestFSDPMixedPrecisionIgnoredModules(FSDPTest): + @property + def world_size(self): + return 1 + + @skip_if_lt_x_gpu(1) + def test_mixed_precision_with_ignored_module(self): + model = Model().cuda() + float16 = MixedPrecision(param_dtype=torch.float16) + model = FSDP( + model, + ignored_modules=[model.ignored], + mixed_precision=float16, + ) + + x = torch.ones(2, 100, device=torch.cuda.current_device()) + + with self.assertRaisesRegex(RuntimeError, "must have the same dtype"): + model(x).sum().backward() + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_multiple_forward.py b/test/distributed/fsdp/test_fsdp_multiple_forward.py index c9afbd465f28e..7823f9349a005 100644 --- a/test/distributed/fsdp/test_fsdp_multiple_forward.py +++ b/test/distributed/fsdp/test_fsdp_multiple_forward.py @@ -9,12 +9,8 @@ from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, - get_full_params, -) -from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN, run_tests - +from torch.testing._internal.common_fsdp import FSDPTest, get_full_params +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) diff --git a/test/distributed/fsdp/test_fsdp_multiple_wrapping.py b/test/distributed/fsdp/test_fsdp_multiple_wrapping.py index 0a3b9e2e2e068..58298fcce26ff 100644 --- a/test/distributed/fsdp/test_fsdp_multiple_wrapping.py +++ b/test/distributed/fsdp/test_fsdp_multiple_wrapping.py @@ -9,8 +9,7 @@ from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest -from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN, run_tests - +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) diff --git a/test/distributed/fsdp/test_fsdp_optim_state.py b/test/distributed/fsdp/test_fsdp_optim_state.py index e4199ad532a6b..5b714fe65c265 100644 --- a/test/distributed/fsdp/test_fsdp_optim_state.py +++ b/test/distributed/fsdp/test_fsdp_optim_state.py @@ -2,13 +2,14 @@ import bisect import sys -from enum import Enum, auto +from enum import auto, Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Type import torch +import torch.nn as nn from torch import distributed as dist from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - _CHECKPOINT_PREFIX, + _CHECKPOINT_WRAPPED_MODULE, apply_activation_checkpointing, ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -25,15 +26,13 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) -STATE_DICT_TYPE = [ - StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT -] +STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT] if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) @@ -49,6 +48,7 @@ class _OSDCommMethod(Enum): """Method for communicating the optimizer state dict for internal tests.""" + BROADCAST_OBJECT_LIST = auto() SCATTER_FULL_OSD = auto() FLATTEN_SHARDED_OSD = auto() @@ -56,12 +56,14 @@ class _OSDCommMethod(Enum): class _ModelClass(Enum): """Different model type to test.""" + NESTED = auto() TRANSFORMER = auto() class Bias(torch.nn.Module): """This module applies a 1D additive bias with dimension ``dim``.""" + def __init__(self, dim: int) -> None: super().__init__() assert dim > 0 @@ -82,6 +84,7 @@ class BlockA(torch.nn.Module): Bias1 bias """ + def __init__(self, in_dim: int, out_dim: int) -> None: super().__init__() assert all(v > 0 for v in (in_dim, out_dim)) @@ -98,6 +101,7 @@ def forward(self, x): x = self.bias_module1(x) return x + class BlockB(torch.nn.Module): """ Used to define interesting nested structure for FSDP wrapping. @@ -108,6 +112,7 @@ class BlockB(torch.nn.Module): Bias bias """ + def __init__(self, in_dim: int, out_dim: int) -> None: super().__init__() assert all(v > 0 for v in (in_dim, out_dim)) @@ -166,21 +171,30 @@ def wrap( fsdp_kwargs = {} # Flatten Bias0; then flatten weight and Bias1 together into `block1` model.block1.bias_module0 = FSDP( - model.block1.bias_module0, process_group=group, **fsdp_kwargs, + model.block1.bias_module0, + process_group=group, + **fsdp_kwargs, ) model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs) # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]` model.block2[1].bias_module0 = FSDP( - model.block2[1].bias_module0, process_group=group, **fsdp_kwargs, + model.block2[1].bias_module0, + process_group=group, + **fsdp_kwargs, ) model.block2[1].bias_module1 = FSDP( - model.block2[1].bias_module1, process_group=group, **fsdp_kwargs, + model.block2[1].bias_module1, + process_group=group, + **fsdp_kwargs, ) model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs) # Flatten weight, Bias, bias into `block2[2]` ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None model.block2[2] = FSDP( - model.block2[2], process_group=group, ignored_modules=ignored_modules, **fsdp_kwargs, + model.block2[2], + process_group=group, + ignored_modules=ignored_modules, + **fsdp_kwargs, ) return model @@ -193,7 +207,9 @@ def wrap_alt( if fsdp_kwargs is None: fsdp_kwargs = {} model.block0.bias_module0 = FSDP( - model.block0.bias_module0, process_group=group, **fsdp_kwargs, + model.block0.bias_module0, + process_group=group, + **fsdp_kwargs, ) model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs) return model @@ -211,7 +227,8 @@ def wrap_with_unmanaged_params( # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`) register_module = model.block2[2] if add_to_fsdp_module else model register_module.register_parameter( - "unmanaged_param", unmanaged_param, + "unmanaged_param", + unmanaged_param, ) # For simplicity, we only add a single unmanaged parameter, but should # be easy to generalize if needed @@ -256,8 +273,7 @@ def param_group0(self) -> List[torch.nn.Parameter]: def param_group1(self) -> List[torch.nn.Parameter]: # Deviate from the `model.parameters()` order further by rearranging # `block2`'s parameters to be before `block0`'s parameters - return list(self.block2.parameters()) + \ - list(self.block0.parameters()) + return list(self.block2.parameters()) + list(self.block0.parameters()) class TestFSDPOptimState(FSDPTest): @@ -281,14 +297,17 @@ def _init_nested_model( ): model = NestedModel().to(device) if wrap: - model = NestedModel.wrap_alt(model, group, fsdp_kwargs) if wrap_alt \ + model = ( + NestedModel.wrap_alt(model, group, fsdp_kwargs) + if wrap_alt else NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs) + ) if not use_multiple_param_groups: optim_input = list(model.parameters()) else: optim_input = [ {"params": model.param_group0()}, - {"params": model.param_group1(), "weight_decay": 0.9} + {"params": model.param_group1(), "weight_decay": 0.9}, ] # Use a reversed parameter order for the optimizer input on odd ranks if use_diff_optim_inputs and self.rank % 2 == 1: @@ -353,7 +372,9 @@ def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None): ``torch.save()`` and ``torch.load()`` so that all ranks can have it.""" obj_list = [full_osd] dist.broadcast_object_list( - obj_list, src=0, group=group, + obj_list, + src=0, + group=group, ) full_osd = obj_list[0] return full_osd @@ -375,8 +396,9 @@ def _are_equal_states( # Check the values on CPU to be device-agnostic value1 = value1.cpu() value2 = value2.cpu() - if value1.shape != value2.shape or \ - not torch.all(torch.isclose(value1, value2)): + if value1.shape != value2.shape or not torch.all( + torch.isclose(value1, value2) + ): return False else: # non-tensor state if value1 != value2: @@ -422,10 +444,12 @@ def _check_same_state( # Check for at least one match (may be > 1 in toy edge cases, e.g. # multiple biases); nonetheless, each having >= 1 match and the two # lists having equal length imply that the list contents are equal - self.assertTrue(any( - self._are_equal_states(fsdp_osd_state, ref_osd_state) - for ref_osd_state in ref_osd_states - )) + self.assertTrue( + any( + self._are_equal_states(fsdp_osd_state, ref_osd_state) + for ref_osd_state in ref_osd_states + ) + ) def _check_same_param_groups( self, @@ -443,10 +467,12 @@ def _check_same_param_groups( full_osd_param_groups = full_osd["param_groups"] self.assertTrue(len(full_osd_param_groups), len(ref_osd_param_groups)) for full_osd_pg, ref_osd_pg in zip( - full_osd_param_groups, ref_osd_param_groups, + full_osd_param_groups, + ref_osd_param_groups, ): self.assertEqual( - set(full_osd_pg.keys()), set(ref_osd_pg.keys()), + set(full_osd_pg.keys()), + set(ref_osd_pg.keys()), ) for name, full_osd_value in full_osd_pg.items(): if name == "params" and not check_same_param_keys: @@ -465,7 +491,7 @@ def _check_state_device(self, osd: Dict[str, Any], on_gpu: bool): self.assertFalse(value.is_cuda) @skip_if_lt_x_gpu(2) - @parametrize("state_dict_type", STATE_DICT_TYPE) + @parametrize("state_dict_type", STATE_DICT_TYPES) @parametrize("use_multiple_param_groups", [False, True]) @parametrize("rank0_only", [False, True]) @parametrize("use_diff_optim_inputs", [False, True]) @@ -477,7 +503,7 @@ def test_optim_state_dict_nested( use_diff_optim_inputs: bool, ) -> None: """ - Tests :meth:`full_optim_state_dict` and `sharded_optim_state_dict` + Tests :meth:`full_optim_state_dict` and meth:`sharded_optim_state_dict` by comparing the returned dict for an FSDP-wrapped model with that of an equivalent non-wrapped model. @@ -508,18 +534,24 @@ def _test_optim_state_dict_nested( return # not supported NUM_ITERS = 3 model1, optim1, optim_input = self._init_nested_model( - wrap=True, use_multiple_param_groups=use_multiple_param_groups, + wrap=True, + use_multiple_param_groups=use_multiple_param_groups, use_diff_optim_inputs=use_diff_optim_inputs, ) losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS) if state_dict_type == StateDictType.FULL_STATE_DICT: if use_optim_input: fsdp_osd = FSDP.full_optim_state_dict( - model1, optim1, optim_input, rank0_only=rank0_only, + model1, + optim1, + optim_input, + rank0_only=rank0_only, ) else: fsdp_osd = FSDP.full_optim_state_dict( - model1, optim1, rank0_only=rank0_only, + model1, + optim1, + rank0_only=rank0_only, ) else: if use_optim_input: @@ -531,7 +563,8 @@ def _test_optim_state_dict_nested( self.assertEqual(len(fsdp_osd), 0) return model2, optim2, _ = self._init_nested_model( - wrap=False, use_multiple_param_groups=use_multiple_param_groups, + wrap=False, + use_multiple_param_groups=use_multiple_param_groups, use_diff_optim_inputs=use_diff_optim_inputs, ) losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS) @@ -544,10 +577,14 @@ def _test_optim_state_dict_nested( # parameter IDs check_same_param_keys = False self._check_same_param_groups( - fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys, + fsdp_osd, + ref_osd, + check_same_param_keys=check_same_param_keys, ) self._check_same_state( - fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys, + fsdp_osd, + ref_osd, + check_same_param_keys=check_same_param_keys, ) @skip_if_lt_x_gpu(2) @@ -562,18 +599,19 @@ def test_full_optim_state_dict_keys(self): # Add checkpointing to ensure optim_state_dict and state_dict strip out # checkpointing prefixes. apply_activation_checkpointing( - model, - check_fn=lambda module: isinstance(module, torch.nn.Sequential) + model, check_fn=lambda module: isinstance(module, torch.nn.Sequential) ) optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3) self._step_model(model, optim, device) - optim_state_dict = FSDP.full_optim_state_dict(wrapped_model, optim, rank0_only=False) + optim_state_dict = FSDP.full_optim_state_dict( + wrapped_model, optim, rank0_only=False + ) with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT): state_dict = wrapped_model.state_dict() self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys()) # Check that checkpointing prefix was indeed stripped. for key in optim_state_dict["state"]: - self.assertNotIn(_CHECKPOINT_PREFIX, key) + self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, key) @skip_if_lt_x_gpu(2) def test_full_optim_state_dict_nested_invalid(self): @@ -771,11 +809,13 @@ def _test_load_optim_state( # First, run a wrapped model with full world size for a few iterations model1, optim1, optim_input1 = initializer( - wrap=True, use_multiple_param_groups=use_multiple_param_groups, + wrap=True, + use_multiple_param_groups=use_multiple_param_groups, ) self._step_model(model1, optim1, num_iters=NUM_ITERS) fsdp_osd1 = ( - osd_method(model1, optim1, optim_input1) if use_optim_input + osd_method(model1, optim1, optim_input1) + if use_optim_input else osd_method(model1, optim1) ) if halve_world_size: @@ -790,7 +830,8 @@ def _test_load_optim_state( # Second, run a wrapped model with (possibly) halved world size and # (possibly) differing `optim_input` across ranks model2, optim2, optim_input2 = initializer( - wrap=True, group=new_group, + wrap=True, + group=new_group, use_multiple_param_groups=use_multiple_param_groups, use_diff_optim_inputs=use_diff_optim_inputs, **new_model_kwargs, # specify `wrap_alt` to change wrapping @@ -807,13 +848,17 @@ def _test_load_optim_state( if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST: fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group) sharded_osd1 = ( - FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim_input=optim_input2) + FSDP.shard_full_optim_state_dict( + fsdp_osd1, model2, optim_input=optim_input2 + ) if use_optim_input else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2) ) fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group) sharded_osd2 = ( - FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim_input=optim_input2) + FSDP.shard_full_optim_state_dict( + fsdp_osd2, model2, optim_input=optim_input2 + ) if use_optim_input else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2) ) @@ -824,7 +869,8 @@ def _test_load_optim_state( model2, optim_input=optim_input2, group=new_group, - ) if use_optim_input + ) + if use_optim_input else FSDP.scatter_full_optim_state_dict( fsdp_osd1 if self.rank == 0 else None, model2, @@ -838,7 +884,8 @@ def _test_load_optim_state( model2, optim_input=optim_input2, group=new_group, - ) if use_optim_input + ) + if use_optim_input else FSDP.scatter_full_optim_state_dict( fsdp_osd2 if self.rank == 0 else None, model2, @@ -851,18 +898,28 @@ def _test_load_optim_state( elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD: sharded_osd1 = ( FSDP.flatten_sharded_optim_state_dict( - fsdp_osd1, model2, optim_input=optim_input2, - ) if use_optim_input + fsdp_osd1, + model2, + optim_input=optim_input2, + ) + if use_optim_input else FSDP.flatten_sharded_optim_state_dict( - fsdp_osd1, model2, optim=optim2, + fsdp_osd1, + model2, + optim=optim2, ) ) sharded_osd2 = ( FSDP.flatten_sharded_optim_state_dict( - fsdp_osd2, model2, optim_input=optim_input2, - ) if use_optim_input + fsdp_osd2, + model2, + optim_input=optim_input2, + ) + if use_optim_input else FSDP.flatten_sharded_optim_state_dict( - fsdp_osd2, model2, optim=optim2, + fsdp_osd2, + model2, + optim=optim2, ) ) @@ -872,22 +929,26 @@ def _test_load_optim_state( local_osd2 = optim2.state_dict() check_same_param_keys = True # should all have matching parameter IDs self._check_same_param_groups( - sharded_osd2, local_osd2, + sharded_osd2, + local_osd2, check_same_param_keys=check_same_param_keys, ) self._check_same_state( - sharded_osd2, local_osd2, + sharded_osd2, + local_osd2, check_same_param_keys=check_same_param_keys, ) # Check that sharding the first model's full/sharded optimizer state dict # according to the second model is equivalent to the second model's # local optimizer state dict self._check_same_param_groups( - sharded_osd1, local_osd2, + sharded_osd1, + local_osd2, check_same_param_keys=check_same_param_keys, ) self._check_same_state( - sharded_osd1, local_osd2, + sharded_osd1, + local_osd2, check_same_param_keys=check_same_param_keys, ) # As a sanity check, check that we can load and run a few iterations @@ -896,7 +957,7 @@ def _test_load_optim_state( self._step_model(model2, optim2, num_iters=NUM_ITERS) @skip_if_lt_x_gpu(2) - @parametrize("state_dict_type", STATE_DICT_TYPE) + @parametrize("state_dict_type", STATE_DICT_TYPES) @parametrize("add_to_fsdp_module", [False, True]) def test_shard_full_optim_state_dict_unmanaged_params( self, @@ -955,7 +1016,8 @@ def _test_shard_full_optim_state_dict_unmanaged_params( device = torch.device("cuda") model = NestedModel().to(device) model, unmanaged_params = NestedModel.wrap_with_unmanaged_params( - model, add_to_fsdp_module, + model, + add_to_fsdp_module, ) optim_input = list(model.parameters()) optim = torch.optim.Adam(optim_input, lr=1e-3) @@ -965,21 +1027,31 @@ def _test_shard_full_optim_state_dict_unmanaged_params( # unflattened parameters with zero-dimensional tensor state (i.e. # Adam "step") and others without (i.e. the unmanaged parameters), # which triggers an error that we have to ensure correctness - error_prefix = "^(All unflattened parameters comprising a " \ - "single flattened parameter must have scalar state with the " \ + error_prefix = ( + "^(All unflattened parameters comprising a " + "single flattened parameter must have scalar state with the " "same value and dtype)" + ) with self.assertRaisesRegex(ValueError, error_prefix): if state_dict_type == StateDictType.FULL_STATE_DICT: ( - FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim_input=optim_input) + FSDP.shard_full_optim_state_dict( + fsdp_osd, model, optim_input=optim_input + ) if use_optim_input - else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim) + else FSDP.shard_full_optim_state_dict( + fsdp_osd, model, optim=optim + ) ) else: ( - FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim_input=optim_input) + FSDP.flatten_sharded_optim_state_dict( + fsdp_osd, model, optim_input=optim_input + ) if use_optim_input - else FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim) + else FSDP.flatten_sharded_optim_state_dict( + fsdp_osd, model, optim=optim + ) ) else: # If we add the unmanaged parameters to a module not wrapped with @@ -988,26 +1060,34 @@ def _test_shard_full_optim_state_dict_unmanaged_params( # externally to FSDP if state_dict_type == StateDictType.FULL_STATE_DICT: flattened_osd = ( - FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim_input=optim_input) + FSDP.shard_full_optim_state_dict( + fsdp_osd, model, optim_input=optim_input + ) if use_optim_input else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim) ) else: flattened_osd = ( - FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim_input=optim_input) + FSDP.flatten_sharded_optim_state_dict( + fsdp_osd, model, optim_input=optim_input + ) if use_optim_input - else FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim) + else FSDP.flatten_sharded_optim_state_dict( + fsdp_osd, model, optim=optim + ) ) # Add entries for the unmanaged parameters to be able to load for unmanaged_param in unmanaged_params: NestedModel.add_unmanaged_param_entry( - flattened_osd, unmanaged_param, NUM_ITERS, + flattened_osd, + unmanaged_param, + NUM_ITERS, ) # Check that we can load the optimizer state dict optim.load_state_dict(flattened_osd) @skip_if_lt_x_gpu(2) - @parametrize("state_dict_type", STATE_DICT_TYPE) + @parametrize("state_dict_type", STATE_DICT_TYPES) @parametrize("use_multiple_param_groups", [False, True]) def test_rekey_optim_state_dict_to_ids( self, @@ -1035,7 +1115,8 @@ def _test_rekey_optim_state_dict_to_ids( NUM_ITERS = 3 # Run a wrapped model for a few iterations model1, optim1, optim_input1 = self._init_nested_model( - wrap=True, use_multiple_param_groups=use_multiple_param_groups, + wrap=True, + use_multiple_param_groups=use_multiple_param_groups, ) self._step_model(model1, optim1, num_iters=NUM_ITERS) if state_dict_type == StateDictType.FULL_STATE_DICT: @@ -1055,28 +1136,39 @@ def _test_rekey_optim_state_dict_to_ids( ) # Run a non-wrapped model for a few iterations model2, optim2, optim_input2 = self._init_nested_model( - wrap=False, use_multiple_param_groups=use_multiple_param_groups, + wrap=False, + use_multiple_param_groups=use_multiple_param_groups, ) self._step_model(model2, optim2, num_iters=NUM_ITERS) # Re-key the wrapped model's optimizer state dict using parameter IDs # according to the non-wrapped model rekeyed_osd = ( FSDP.rekey_optim_state_dict( - fsdp_osd, OptimStateKeyType.PARAM_ID, model2, optim_input=optim_input2, + fsdp_osd, + OptimStateKeyType.PARAM_ID, + model2, + optim_input=optim_input2, ) if use_optim_input else FSDP.rekey_optim_state_dict( - fsdp_osd, OptimStateKeyType.PARAM_ID, model2, optim=optim2, + fsdp_osd, + OptimStateKeyType.PARAM_ID, + model2, + optim=optim2, ) ) # Check that the re-keyed dict and actual dict are the same osd = optim2.state_dict() check_same_param_keys = True self._check_same_param_groups( - rekeyed_osd, osd, check_same_param_keys=check_same_param_keys, + rekeyed_osd, + osd, + check_same_param_keys=check_same_param_keys, ) self._check_same_state( - rekeyed_osd, osd, check_same_param_keys=check_same_param_keys, + rekeyed_osd, + osd, + check_same_param_keys=check_same_param_keys, ) # As a sanity check, check that we can load and run a few iterations if state_dict_type != StateDictType.SHARDED_STATE_DICT: @@ -1106,12 +1198,14 @@ def _test_rekey_optim_state_dict_to_names( NUM_ITERS = 3 # Run a wrapped model for a few iterations model1, optim1, optim_input1 = self._init_nested_model( - wrap=True, use_multiple_param_groups=use_multiple_param_groups, + wrap=True, + use_multiple_param_groups=use_multiple_param_groups, ) self._step_model(model1, optim1, num_iters=NUM_ITERS) # Run a non-wrapped model for a few iterations model2, optim2, optim_input2 = self._init_nested_model( - wrap=False, use_multiple_param_groups=use_multiple_param_groups, + wrap=False, + use_multiple_param_groups=use_multiple_param_groups, ) self._step_model(model2, optim2, num_iters=NUM_ITERS) # Re-key the non-wrapped model's optimizer state dict using parameter @@ -1119,20 +1213,32 @@ def _test_rekey_optim_state_dict_to_names( osd2 = optim2.state_dict() rekeyed_osd = ( FSDP.rekey_optim_state_dict( - osd2, OptimStateKeyType.PARAM_NAME, model2, optim_input=optim_input2, - ) if use_optim_input + osd2, + OptimStateKeyType.PARAM_NAME, + model2, + optim_input=optim_input2, + ) + if use_optim_input else FSDP.rekey_optim_state_dict( - osd2, OptimStateKeyType.PARAM_NAME, model2, optim=optim2, + osd2, + OptimStateKeyType.PARAM_NAME, + model2, + optim=optim2, ) ) # Shard the non-wrapped model's re-keyed optimizer state dict, which # maps back to (flattened) parameter IDs sharded_osd = ( FSDP.shard_full_optim_state_dict( - rekeyed_osd, model1, optim_input=optim_input1, - ) if use_optim_input + rekeyed_osd, + model1, + optim_input=optim_input1, + ) + if use_optim_input else FSDP.shard_full_optim_state_dict( - rekeyed_osd, model1, optim=optim1, + rekeyed_osd, + model1, + optim=optim1, ) ) # Check that this sharded optimizer state dict matches the wrapped @@ -1140,10 +1246,14 @@ def _test_rekey_optim_state_dict_to_names( osd1 = optim1.state_dict() check_same_param_keys = True self._check_same_param_groups( - sharded_osd, osd1, check_same_param_keys=check_same_param_keys, + sharded_osd, + osd1, + check_same_param_keys=check_same_param_keys, ) self._check_same_state( - sharded_osd, osd1, check_same_param_keys=check_same_param_keys, + sharded_osd, + osd1, + check_same_param_keys=check_same_param_keys, ) # As a sanity check, check that we can load and run a few iterations optim1.load_state_dict(sharded_osd) @@ -1153,6 +1263,7 @@ def _test_rekey_optim_state_dict_to_names( def test_optim_input_warning(self): """Tests that passing the ``optim_input`` argument into optimizer state checkpointing APIs issues a warning.""" + def should_check_method(method_name: str): # Check every method since they all accept `optim_input` return True @@ -1163,12 +1274,15 @@ def get_warning_context(): expected_warning=UserWarning, expected_regex=warning_regex ) - self._run_on_all_optim_state_apis(should_check_method, get_warning_context, fsdp_kwargs=None) + self._run_on_all_optim_state_apis( + should_check_method, get_warning_context, fsdp_kwargs=None + ) @skip_if_lt_x_gpu(2) def test_use_orig_params_error(self): """Tests that the optimizer state checkpointing APIs raise an error when ``use_orig_params=True``.""" + def should_check_method(method_name: str): # Skip `rekey_optim_state_dict` since that does not depend on # `use_orig_params=True` @@ -1181,7 +1295,9 @@ def get_error_context(): ) fsdp_kwargs = {"use_orig_params": True} - self._run_on_all_optim_state_apis(should_check_method, get_error_context, fsdp_kwargs) + self._run_on_all_optim_state_apis( + should_check_method, get_error_context, fsdp_kwargs + ) def _run_on_all_optim_state_apis( self, @@ -1195,12 +1311,10 @@ def _run_on_all_optim_state_apis( via ``should_check_method_fn``, which gets passed the string name of the method. """ - wrapped_model, wrapped_optim, wrapped_optim_input = ( - self._init_nested_model( - wrap=True, - use_multiple_param_groups=False, - fsdp_kwargs=fsdp_kwargs, - ) + wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model( + wrap=True, + use_multiple_param_groups=False, + fsdp_kwargs=fsdp_kwargs, ) self._step_model(wrapped_model, wrapped_optim, num_iters=2) @@ -1208,14 +1322,18 @@ def _run_on_all_optim_state_apis( if should_check_method_fn("sharded_optim_state_dict"): with context_fn(): fsdp_osd = FSDP.sharded_optim_state_dict( - wrapped_model, wrapped_optim, optim_input=wrapped_optim_input, + wrapped_model, + wrapped_optim, + optim_input=wrapped_optim_input, ) if "fsdp_osd" not in locals(): fsdp_osd = {} # may not be defined due to previous method erroring if should_check_method_fn("flatten_sharded_optim_state_dict"): with context_fn(): FSDP.flatten_sharded_optim_state_dict( - fsdp_osd, wrapped_model, optim_input=wrapped_optim_input, + fsdp_osd, + wrapped_model, + optim_input=wrapped_optim_input, ) # Full optim state dict if should_check_method_fn("full_optim_state_dict"): @@ -1229,17 +1347,23 @@ def _run_on_all_optim_state_apis( if should_check_method_fn("shard_full_optim_state_dict"): with context_fn(): FSDP.shard_full_optim_state_dict( - fsdp_osd, wrapped_model, optim_input=wrapped_optim_input, + fsdp_osd, + wrapped_model, + optim_input=wrapped_optim_input, ) if should_check_method_fn("scatter_full_optim_state_dict"): with context_fn(): FSDP.scatter_full_optim_state_dict( - fsdp_osd, wrapped_model, optim_input=wrapped_optim_input, + fsdp_osd, + wrapped_model, + optim_input=wrapped_optim_input, ) # Rekey optim state dict - nonwrapped_model, nonwrapped_optim, nonwrapped_optim_input = ( - self._init_nested_model(wrap=False, use_multiple_param_groups=False) - ) + ( + nonwrapped_model, + nonwrapped_optim, + nonwrapped_optim_input, + ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False) if should_check_method_fn("rekey_optim_state_dict"): with context_fn(): rekeyed_osd = FSDP.rekey_optim_state_dict( @@ -1259,6 +1383,61 @@ def _run_on_all_optim_state_apis( optim_input=nonwrapped_optim_input, ) + @skip_if_lt_x_gpu(2) + @parametrize("state_dict_type", STATE_DICT_TYPES) + def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType): + """ + Tests saving and loading an optim state dict for Adam optimizer (i.e. + any optimizer with a "step" key in its state) when the first parameter + does not have optimizer state (e.g. unused or frozen). + """ + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin1 = nn.Linear(5, 5) + self.lin2 = nn.Linear(5, 5) + self.relu = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Do not use `lin1`, which is the parameter passed to the + # optimizer and the one checked for "step" state to see if it + # is tensor or float + return self.relu(self.lin2(x)) + + model = Model().cuda() + model.lin1 = FSDP(model.lin1) + model.lin2 = FSDP(model.lin2) + fsdp_model = FSDP(model) + optim = torch.optim.Adam( + fsdp_model.parameters(), lr=1e-2 + ) # or any optimizer with "step" + + # Run an iteration to construct optimizer state + device = torch.device("cuda") + inp = torch.randn((2, 5), device=device) + loss = fsdp_model(inp).sum() + loss.backward() + optim.step() + + # Check that save and load does not error + if state_dict_type == StateDictType.FULL_STATE_DICT: + fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False) + flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model) + elif state_dict_type == StateDictType.SHARDED_STATE_DICT: + fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim) + flattened_osd = FSDP.flatten_sharded_optim_state_dict(fsdp_osd, fsdp_model) + optim.load_state_dict(flattened_osd) + # `__setstate__()` will check the 0th parameter to see if "step" is + # represented as a tensor or float, so it is imperative that its state + # is non-empty. + + # Run an iteration as a sanity check + inp = torch.randn((2, 5), device=device) + loss = fsdp_model(inp).sum() + loss.backward() + optim.step() + instantiate_parametrized_tests(TestFSDPOptimState) diff --git a/test/distributed/fsdp/test_fsdp_overlap.py b/test/distributed/fsdp/test_fsdp_overlap.py index 07e8eba09c6c2..8bd5354b2b701 100644 --- a/test/distributed/fsdp/test_fsdp_overlap.py +++ b/test/distributed/fsdp/test_fsdp_overlap.py @@ -11,16 +11,13 @@ from torch.cuda import Event from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, -) +from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, get_cycles_per_ms, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) - if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) diff --git a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py b/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py deleted file mode 100644 index a1c73d1cafb53..0000000000000 --- a/test/distributed/fsdp/test_fsdp_param_exec_order_wrap.py +++ /dev/null @@ -1,134 +0,0 @@ -# Owner(s): ["oncall: distributed"] - -from typing import Any, Callable - -import torch -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._symbolic_trace import TracingConfig -from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy -from torch.distributed.fsdp.wrap import always_wrap_policy, ParamExecOrderWrapPolicy -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import FSDPTest -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, - run_tests, -) - - -class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.layer0 = torch.nn.Linear(6, 6) - self.layer1 = torch.nn.Linear(6, 6, bias=False) - self.layer2 = torch.nn.Sequential( - torch.nn.Linear(6, 3, bias=False), - torch.nn.ReLU(), - torch.nn.Linear(3, 6, bias=False), - ) - self.relu = torch.nn.ReLU() - - def forward(self, x: Any, use_all_params: bool = True): - # `layer0` -> `layer2` -> `layer1` - # the forward execution order is NOT consistent with the model definition order. - z = self.relu(self.layer0(x)) - z = self.relu(self.layer2(z)) - if use_all_params: - z = self.relu(self.layer1(z)) - return z - - def get_input(self, device: torch.device): - return (torch.randn((8, 6)).to(device),) - - def get_loss(self, input, output): - return (output - input[0]).sum() - - @staticmethod - def wrap( - sharding_strategy: ShardingStrategy, - device: torch.device, - wrap_policy: Callable, - ) -> torch.nn.Module: - model = Model() - fsdp_model = FSDP( - model, auto_wrap_policy=wrap_policy, sharding_strategy=sharding_strategy - ) - return fsdp_model.to(device) - - -class TestFSDPExecOrder(FSDPTest): - @property - def device(self): - return torch.device("cuda") - - @skip_if_lt_x_gpu(2) - @parametrize( - "sharding_strategy", - [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP], - ) - def test_fsdp_flatten_params_exec_order( - self, - sharding_strategy: ShardingStrategy, - ): - """ - Test ``_fsdp_params_exec_order`` with ``ParamExecOrderWrapPolicy``, - after running one iteration of forward and backward pass. - Here ``torch.fx`` is not enabled inside ``ParamExecOrderWrapPolicy``. - """ - wrap_policy = ParamExecOrderWrapPolicy(init_policy=always_wrap_policy) - fsdp_model = Model.wrap(sharding_strategy, self.device, wrap_policy=wrap_policy) - self.assertTrue(fsdp_model._is_param_exec_order_prep_stage()) - # run one iteration to record the execution ordering - input = fsdp_model.module.get_input(self.device) - output = fsdp_model(*input) - loss = fsdp_model.module.get_loss(input, output).to(self.device) - loss.backward() - params_list = list(fsdp_model.parameters()) - # Since the forward execution order is NOT consistent with - # the model definition order, the ordering in flatten_named_params_exec_order - # should be different from named_parameters. - self.assertEqual( - fsdp_model._fsdp_params_exec_order, - [params_list[0], params_list[2], params_list[3], params_list[1]], - ) - self.assertTrue(fsdp_model._use_param_exec_order_policy()) - self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage()) - - @skip_if_lt_x_gpu(2) - @parametrize( - "sharding_strategy", - [ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP], - ) - def test_fsdp_flatten_params_exec_order_symbolic_trace( - self, - sharding_strategy: ShardingStrategy, - ): - """ - Tests ``ParamExecOrderWrapPolicy`` with symbolic tracing. - With symbolic tracing enabled, ``_is_param_exec_order_prep_stage`` - should always set as False. - """ - wrap_policy = ParamExecOrderWrapPolicy( - init_policy=always_wrap_policy, - tracing_config=TracingConfig(concrete_args={"use_all_params": False}), - ) - fsdp_model = Model.wrap( - sharding_strategy, - self.device, - wrap_policy=wrap_policy, - ) - params_list = list(fsdp_model.parameters()) - # Since the forward execution order is NOT consistent with the model definition order, - # the ordering in flatten_named_params_exec_order should be different from named_parameters - self.assertEqual( - fsdp_model._fsdp_params_exec_order, - [params_list[0], params_list[2], params_list[3]], - ) - self.assertTrue(fsdp_model._use_param_exec_order_policy()) - self.assertTrue(not fsdp_model._is_param_exec_order_prep_stage()) - - -instantiate_parametrized_tests(TestFSDPExecOrder) - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/fsdp/test_fsdp_pure_fp16.py b/test/distributed/fsdp/test_fsdp_pure_fp16.py index ed4aef39da0f9..e0033ef3d4b72 100644 --- a/test/distributed/fsdp/test_fsdp_pure_fp16.py +++ b/test/distributed/fsdp/test_fsdp_pure_fp16.py @@ -12,10 +12,10 @@ NestedWrappedModule, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -31,11 +31,10 @@ class TestPureFP16(FSDPTest): - @property def world_size(self): - # Test fails due to inaccuracies when using more than 5 GPUs - return min(5, super().world_size) + # Test fails due to inaccuracies when using more than 4 GPUs + return min(4, super().world_size) @skip_if_lt_x_gpu(2) @parametrize( diff --git a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py index 1c230cb7400c4..2124e6b0450f5 100644 --- a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py +++ b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py @@ -22,11 +22,11 @@ subtest_name, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - TestCase, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, + TestCase, ) if not dist.is_available(): @@ -47,21 +47,23 @@ sharding_strategy_config = [ShardingStrategy.SHARD_GRAD_OP, None] mixed_precision = ["enable_mixed_precision", None] -configs = list(itertools.product(cpu_offload_config, - sharding_strategy_config, - mixed_precision)) +configs = list( + itertools.product(cpu_offload_config, sharding_strategy_config, mixed_precision) +) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", str(CPUOffload(offload_params=False)): "offload_false", str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op", - "enable_mixed_precision": "mixed_precision" + "enable_mixed_precision": "mixed_precision", } subtest_name = functools.partial(subtest_name, test_name_mapping) class TestShardGradScaler(TestCase): - @unittest.skipIf(amp_definitely_not_available(), "no supported device (cuda, xla) found") + @unittest.skipIf( + amp_definitely_not_available(), "no supported device (cuda, xla) found" + ) def test_grad_scaling(self): pg = DummyProcessGroup(0, 1) scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True) @@ -69,21 +71,26 @@ def test_grad_scaling(self): t1 = torch.full((1,), 8.0, dtype=torch.float32, device="cpu") outputs = [t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), t1.clone()]] outputs = scaler.scale(outputs) - self.assertTrue(outputs[0] == 16.0 and outputs[1][0] == 8.0 and outputs[1][1] == 16.0) + self.assertTrue( + outputs[0] == 16.0 and outputs[1][0] == 8.0 and outputs[1][1] == 16.0 + ) self.assertTrue(outputs[2][0] == 8.0 and outputs[2][1] == 16.0) self.assertTrue(scaler._scale.device == t1.device) - @unittest.skipIf(amp_definitely_not_available(), "no supported device (cuda, xla) found") + @unittest.skipIf( + amp_definitely_not_available(), "no supported device (cuda, xla) found" + ) def test_scaling_unscaling_sparse(self): pg = DummyProcessGroup(0, 1) scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True) inv_scale = torch.full((1,), 0.5, dtype=torch.float, device="cpu") found_inf = torch.full((1,), 0, dtype=torch.float, device="cpu") - i = torch.tensor([[0, 1, 1], - [2, 0, 2]], device="cpu", dtype=torch.int64) + i = torch.tensor([[0, 1, 1], [2, 0, 2]], device="cpu", dtype=torch.int64) v = torch.tensor([16.0, 32.0, 64.0], dtype=torch.float, device="cpu") - s = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float) + s = torch.sparse_coo_tensor( + i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float + ) # unscale sparse tensors s1 = s.clone() @@ -95,29 +102,34 @@ def test_scaling_unscaling_sparse(self): self.assertEqual(s1.grad.to_dense(), (s / 2).to_dense()) # unscale sparse tensor: inf - v = torch.tensor([16.0, 32.0, float('inf')], dtype=torch.float, device="cpu") - s1.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float) + v = torch.tensor([16.0, 32.0, float("inf")], dtype=torch.float, device="cpu") + s1.grad = torch.sparse_coo_tensor( + i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float + ) found_inf.zero_() found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device] self.assertEqual(found_inf, 1.0) # unscale sparse tensor: overflow (marked as inf) - i = torch.tensor([[1, 1, 1], - [0, 0, 2]], device="cpu", dtype=torch.int64) + i = torch.tensor([[1, 1, 1], [0, 0, 2]], device="cpu", dtype=torch.int64) # coalescing sparse tensor here will cause the value to be Inf v = torch.tensor([2**15, 2**15, 1.0], dtype=torch.float16, device="cpu") - s1 = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float16) + s1 = torch.sparse_coo_tensor( + i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float16 + ) s1.grad = s1.clone() found_inf.zero_() found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device] self.assertEqual(found_inf, 1.0) - @unittest.skipIf(amp_definitely_not_available(), "no supported device (cuda, xla) found") + @unittest.skipIf( + amp_definitely_not_available(), "no supported device (cuda, xla) found" + ) def test_inf_gradients_skip_optim_step(self): pg = DummyProcessGroup(0, 1) scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True) loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu") - t0 = torch.tensor([float('inf')], dtype=torch.float32, device="cpu") + t0 = torch.tensor([float("inf")], dtype=torch.float32, device="cpu") t0.grad = t0.clone() opt = torch.optim.SGD([t0], lr=1.0) scaler.scale(loss) @@ -127,10 +139,7 @@ def test_inf_gradients_skip_optim_step(self): class TestShardedGradScalerParityWithDDP(FSDPTest): def _get_init_modes_for_test(self, cpu_offload): - modes = [ - CUDAInitMode.CUDA_AFTER, - CUDAInitMode.CUDA_BEFORE - ] + modes = [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE] # Note that CUDAInitMode.CUDA_NEVER works currently only with CPU # offload as we explicitly bring the param back to CUDA device. In # general, it will not work since we try to all_gather p.data which is @@ -149,11 +158,15 @@ def test_fsdp_ddp_parity_with_grad_scaler( mixed_precision: Optional[str], ): init_modes = self._get_init_modes_for_test(cpu_offload) - mp = MixedPrecision( - param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16, - ) if mixed_precision is not None else None + mp = ( + MixedPrecision( + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, + ) + if mixed_precision is not None + else None + ) for cuda_init_mode in init_modes: self._test_fsdp_parity( NestedWrappedModule, diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index af56ee956743f..0a453efe8ffba 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -11,51 +11,42 @@ import torch.nn as nn from torch import distributed as dist from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, checkpoint_wrapper, + CheckpointImpl, ) -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ( CPUOffload, FullStateDictConfig, + FullyShardedDataParallel as FSDP, LocalStateDictConfig, MixedPrecision, ShardedStateDictConfig, StateDictType, ) from torch.distributed.fsdp._shard_utils import _gather_state_dict -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel, -) -from torch.distributed.fsdp.wrap import ( - enable_wrap, - transformer_auto_wrap_policy, - wrap, -) -from torch.nn import ( - Linear, - Module, - TransformerDecoderLayer, - TransformerEncoderLayer, -) +from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM +from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap +from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + _assert_module_states, + _get_state_dict, + _zero_model, CUDAInitMode, FSDPInitMode, FSDPTest, + get_full_params, SkipModel, TransformerWithSharedParams, - _assert_module_states, - _get_state_dict, - _zero_model, - get_full_params, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -73,7 +64,7 @@ OUTER_SHAPE = [4, 5] BUFFER_SHAPE = [5, 5] -NON_ROOT_FSDP_PREFIX = 'non_fsdp_lin' +NON_ROOT_FSDP_PREFIX = "non_fsdp_lin" _UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"] _FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"] @@ -98,7 +89,9 @@ def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False): "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False ) if wrap_fsdp: - self.inner = FSDP(self.inner, ignored_modules=([self.inner] if ignore_inner else [])) + self.inner = FSDP( + self.inner, ignored_modules=([self.inner] if ignore_inner else []) + ) self.outer = Linear(*OUTER_SHAPE) if register_buffers: self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE)) @@ -118,22 +111,40 @@ class TestFSDPStateDict(FSDPTest): def world_size(self): return 2 - def _broadcast_state_dict(self, state_dict): + def _broadcast_state_dict(self, model, state_dict): + if not isinstance(model, FSDP): + # For non-FSDP root, some parts of the model state on rank 0 may + # not be on CPU, so we move everything to CPU to avoid issues like: + # https://github.com/pytorch/pytorch/issues/77113. + for param_name, param in state_dict.items(): + if param.device != torch.device("cpu"): + state_dict[param_name] = param.cpu() + olist = [state_dict if self.rank == 0 else None] dist.broadcast_object_list(olist) - return olist[0] + state_dict = olist[0] + # Ensure that the state is on CUDA + for param_name in state_dict.keys(): + state_dict[param_name] = state_dict[param_name].cuda() + return state_dict def _compare_models(self, model, model_new, assert_fn, check_fp16=False): - with FullyShardedDataParallel.summon_full_params(model): - with FullyShardedDataParallel.summon_full_params(model_new): + assert assert_fn in (self.assertEqual, self.assertNotEqual) + with FSDP.summon_full_params(model): + with FSDP.summon_full_params(model_new): params = list(model.parameters()) params_new = list(model_new.parameters()) + # Regardless of `assert_fn`, the number of parameters should be + # the same + self.assertEqual(len(params), len(params_new)) assert_fn(params, params_new) if check_fp16: for tensor in model_new.parameters(): self.assertEqual(tensor.dtype, torch.float16) - def _get_simple_nested_model(self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs): + def _get_simple_nested_model( + self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs + ): if wrap: lin1 = nn.Linear(10, 10, bias=False).cuda() lin2 = nn.Linear(10, 10, bias=False).cuda() @@ -146,7 +157,8 @@ def _get_simple_nested_model(self, *fsdp_args, wrap=True, checkpoint_wrap=False, model = FSDP(seq, *fsdp_args, **fsdp_kwargs) else: model = nn.Sequential( - nn.Linear(10, 10, bias=False).cuda(), nn.Linear(10, 10, bias=False).cuda() + nn.Linear(10, 10, bias=False).cuda(), + nn.Linear(10, 10, bias=False).cuda(), ) return model @@ -222,42 +234,131 @@ def _validate_state_dict_contents( @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS) - @parametrize("checkpoint_wrap", ["first", "second", "both"]) - def test_fsdp_state_dict_with_activation_checkpoint(self, state_dict_type, checkpoint_wrap): + @parametrize( + "checkpoint_wrap", + ["source", "dest", "both", "source_after_wrap", "both_after_wrap"], + ) + @parametrize("rank0_only_and_offload", [False, True]) + def test_fsdp_state_dict_with_activation_checkpoint( + self, state_dict_type, checkpoint_wrap, rank0_only_and_offload + ): """Tests saving the state dict, zeroing a target model's parameters, and loading the state dict, where the source and target models may have a checkpoint wrapper.""" + + def apply_ac_to_linears(model) -> None: + non_reentrant_wrapper = partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=lambda submodule: isinstance(submodule, nn.Linear), + ) + for model_call in [ partial(self._get_simple_model), - partial(self._get_simple_nested_model) + partial(self._get_simple_nested_model), ]: - model = model_call(checkpoint_wrap=(checkpoint_wrap in ["first", "both"])) - with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]): + model = model_call(checkpoint_wrap=(checkpoint_wrap in ("source", "both"))) + if checkpoint_wrap in ("source_after_wrap", "both_after_wrap"): + apply_ac_to_linears(model) + with self._get_state_dict_mgr( + model, state_dict_type, rank0_only_and_offload + ): state_dict = _gather_state_dict(_get_state_dict(model, False, False)) # Possibly wrap new model in activation checkpoint wrapper to test save/ # load with this wrapper - model_new = model_call(checkpoint_wrap=(checkpoint_wrap in ["second", "both"])) + model_new = model_call( + checkpoint_wrap=(checkpoint_wrap in ("dest", "both")) + ) + if checkpoint_wrap == "both_after_wrap": + apply_ac_to_linears(model_new) _zero_model(model_new) self._compare_models(model, model_new, self.assertNotEqual) + if rank0_only_and_offload: + state_dict = self._broadcast_state_dict(model, state_dict) # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks model_new.load_state_dict(state_dict, strict=True) self._compare_models(model, model_new, self.assertEqual) + @skip_if_lt_x_gpu(2) + @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS) + @parametrize("rank0_only_and_offload", [False, True]) + def test_state_dict_with_manual_ac_wrapper( + self, + state_dict_type: str, + rank0_only_and_offload: bool, + ): + """ + Tests saving and loading a state dict for a model manually wrapped with + ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is + wrapped before FSDP. + + TODO: Investigate why the test above does not cover everything in this + test and de-duplicate afterwards. + """ + if state_dict_type == "sharded_state_dict" and rank0_only_and_offload: + return # not supported + model_ac = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + ) + # Manually wrap FSDP without AC + model_no_ac = deepcopy(model_ac) + for i, layer in enumerate(model_no_ac.transformer.encoder.layers): + model_no_ac.transformer.encoder.layers[i] = FSDP(layer) + for i, layer in enumerate(model_no_ac.transformer.decoder.layers): + model_no_ac.transformer.decoder.layers[i] = FSDP(layer) + model_no_ac.transformer = FSDP(model_no_ac.transformer) + + # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))` + for i, layer in enumerate(model_ac.transformer.encoder.layers): + layer = checkpoint_wrapper(layer) + model_ac.transformer.encoder.layers[i] = FSDP(layer) + for i, layer in enumerate(model_ac.transformer.decoder.layers): + layer = checkpoint_wrapper(layer) + model_ac.transformer.decoder.layers[i] = FSDP(layer) + model_ac.transformer = FSDP(model_ac.transformer) + + # Save, load, and compare the two models + with self._get_state_dict_mgr( + model_no_ac, state_dict_type, rank0_only_and_offload + ): + state_dict_no_ac = model_no_ac.state_dict() + with self._get_state_dict_mgr( + model_ac, state_dict_type, rank0_only_and_offload + ): + state_dict_ac = model_ac.state_dict() + self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys()) + if rank0_only_and_offload: + state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac) + state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac) + with self._get_state_dict_mgr( + model_no_ac, state_dict_type, rank0_only_and_offload + ): + model_no_ac.load_state_dict(state_dict_no_ac) + with self._get_state_dict_mgr( + model_ac, state_dict_type, rank0_only_and_offload + ): + model_ac.load_state_dict(state_dict_ac) + self._compare_models(model_ac, model_no_ac, self.assertEqual) + @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_with_shared_parameters(self, state_dict_type): - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - TransformerEncoderLayer, TransformerDecoderLayer - }, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) model_creator = partial( TransformerWithSharedParams.init, self.process_group, FSDPInitMode.RECURSIVE, CUDAInitMode.CUDA_BEFORE, - {"auto_wrap_policy": auto_wrap_policy} + {"auto_wrap_policy": auto_wrap_policy}, ) fsdp_model = model_creator() @@ -275,9 +376,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): """Tests saving a model checkpoint only on rank 0 and loading it only on rank 0 with ``sync_module_states=True`` to emulate the workflow to avoid redundant CPU memory usage.""" - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, @@ -291,10 +391,14 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): ) # Force model parameters and buffers to be nonzero with FSDP.summon_full_params(fsdp_model): - for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()): + for tensor in itertools.chain( + fsdp_model.parameters(), fsdp_model.buffers() + ): if torch.count_nonzero(tensor) == 0: with torch.no_grad(): - tensor.add_(torch.tensor(1, dtype=tensor.dtype, device=tensor.device)) + tensor.add_( + torch.tensor(1, dtype=tensor.dtype, device=tensor.device) + ) with self._get_state_dict_mgr(fsdp_model, "state_dict", True): state_dict = deepcopy(_get_state_dict(fsdp_model)) # Initialize a non-wrapped model on all ranks @@ -327,8 +431,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): assert_fn=self.assertEqual, ) # Check FSDP models correctly loaded the checkpoint - with FullyShardedDataParallel.summon_full_params(fsdp_model): - with FullyShardedDataParallel.summon_full_params(new_fsdp_model): + with FSDP.summon_full_params(fsdp_model): + with FSDP.summon_full_params(new_fsdp_model): params = list(fsdp_model.parameters()) params_new = list(new_fsdp_model.parameters()) self.assertEqual(params, params_new) @@ -341,7 +445,7 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): ) @parametrize("fp16", [True, False]) @parametrize("state_dict_rank0_and_offload", [True, False]) - @parametrize("use_orig_params", [False, True]) + @parametrize("use_orig_params", [True, False]) def test_basic_save_and_load_state_dict( self, state_dict_type: StateDictType, @@ -355,15 +459,26 @@ def test_basic_save_and_load_state_dict( with various configs such as fp16 and cpu offload and parameters match as expected. """ - if ( - (state_dict_rank0_and_offload and state_dict_type != "state_dict") - or (use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS) + if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or ( + use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS ): return # not supported for model_call in [ - partial(self._get_non_fsdp_root_module, cpu_offload=cpu_offload, use_orig_params=use_orig_params), - partial(self._get_simple_nested_model, cpu_offload=cpu_offload, use_orig_params=use_orig_params), - partial(self._get_simple_model, cpu_offload=cpu_offload, use_orig_params=use_orig_params), + partial( + self._get_non_fsdp_root_module, + cpu_offload=cpu_offload, + use_orig_params=use_orig_params, + ), + partial( + self._get_simple_nested_model, + cpu_offload=cpu_offload, + use_orig_params=use_orig_params, + ), + partial( + self._get_simple_model, + cpu_offload=cpu_offload, + use_orig_params=use_orig_params, + ), ]: model = model_call() @@ -375,10 +490,15 @@ def test_basic_save_and_load_state_dict( model, cpu_offload.offload_params, fp16 ) - ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k] + ignore_keys = [ + k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k + ] self._validate_state_dict_contents( - model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=ignore_keys, + model, + fsdp_state_dict, + state_dict_rank0_and_offload, + ignore_keys=ignore_keys, ) if fp16: # Verify fp16 is the type @@ -397,17 +517,7 @@ def test_basic_save_and_load_state_dict( # Verify parameters are the same in the new model. if state_dict_rank0_and_offload: - # Broadcast the state dict and move it back to GPU in - # preparation for loading. - if not isinstance(model, FSDP): - # Move everything to CPU to avoid running into - # https://github.com/pytorch/pytorch/issues/77113, some params - # will still be on GPU for non FSDP root modules. - for k in fsdp_state_dict.keys(): - fsdp_state_dict[k] = fsdp_state_dict[k].cpu() - fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) - for key in fsdp_state_dict.keys(): - fsdp_state_dict[key] = fsdp_state_dict[key].cuda() + fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]): model_new.load_state_dict(fsdp_state_dict, strict=True) @@ -463,7 +573,9 @@ def test_save_and_load_after_forward_state_dict( for sharded_tensor in state_dict.values(): shard = sharded_tensor._local_shards[0] shard.tensor = shard.tensor.clone().detach_() - self._validate_state_dict_contents(model, state_dict, state_dict_rank0_and_offload) + self._validate_state_dict_contents( + model, state_dict, state_dict_rank0_and_offload + ) _zero_model(model) # Ensure checkpointed params have the full param dtype @@ -472,11 +584,7 @@ def test_save_and_load_after_forward_state_dict( # Load state_dict into zeroed model if state_dict_rank0_and_offload: - # Broadcast the state dict and move it back to GPU in - # preparation for loading. - state_dict = self._broadcast_state_dict(state_dict) - for key in state_dict.keys(): - state_dict[key] = state_dict[key].cuda() + state_dict = self._broadcast_state_dict(model, state_dict) with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]): model.load_state_dict(state_dict, strict=True) @@ -503,8 +611,8 @@ def _initialize_model( def _state_dict(model: Module, state_dict_type: str): try: enum_val = STATE_DICT_MAPPING[state_dict_type] - except KeyError: - raise ValueError(f"No state_dict type for {state_dict_type}") + except KeyError as e: + raise ValueError(f"No state_dict type for {state_dict_type}") from e with FSDP.state_dict_type(model, enum_val): return model.state_dict() @@ -515,8 +623,8 @@ def _load_state_dict( ): try: enum_val = STATE_DICT_MAPPING[state_dict_type] - except KeyError: - raise ValueError(f"No state_dict for {state_dict_type}") + except KeyError as e: + raise ValueError(f"No state_dict for {state_dict_type}") from e with FSDP.state_dict_type(model, enum_val): return model.load_state_dict(state_dict, strict=True) @@ -560,7 +668,9 @@ def test_state_dict_save_load_flow(self, state_dict_type): for move_to_cpu in [True, False]: with self.subTest(move_to_cpu=move_to_cpu): fsdp_params = self._dist_train( - wrap_fsdp=True, state_dict_type=state_dict_type, move_to_cpu=move_to_cpu, + wrap_fsdp=True, + state_dict_type=state_dict_type, + move_to_cpu=move_to_cpu, ) ddp_params = self._dist_train(wrap_fsdp=False) self.assertEqual(ddp_params, fsdp_params) @@ -570,7 +680,9 @@ def test_state_dict_save_load_flow(self, state_dict_type): def test_fsdp_state_dict_keys(self, state_dict_type): state_dict = self._state_dict(self._initialize_model(True), state_dict_type) if state_dict_type == "local_state_dict": - self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys()) + self.assertEqual( + set([FLAT_PARAM, f"inner.{FLAT_PARAM}"]), state_dict.keys() + ) elif state_dict_type in ("state_dict", "sharded_state_dict"): # Keys should match local model. local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) @@ -584,7 +696,10 @@ def test_fsdp_state_dict_keys(self, state_dict_type): @parametrize("state_dict_rank0_and_offload", [True, False]) @parametrize("fsdp_root", [True, False]) def test_state_dict_load_into_local_module( - self, state_dict_type, state_dict_rank0_and_offload, fsdp_root, + self, + state_dict_type, + state_dict_rank0_and_offload, + fsdp_root, ): """ Tests that FSDP's state_dict can be loaded into a local model. @@ -597,7 +712,9 @@ def test_state_dict_load_into_local_module( model = self._initialize_model(wrap_fsdp=True, register_buffers=True) optim = SGD(model.parameters(), lr=0.1) if not fsdp_root: - in_data = torch.randn(1, 10, requires_grad=True, device=torch.device("cuda")) + in_data = torch.randn( + 1, 10, requires_grad=True, device=torch.device("cuda") + ) else: in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): @@ -606,7 +723,7 @@ def test_state_dict_load_into_local_module( optim.step() optim.zero_grad() - with FullyShardedDataParallel.summon_full_params(model): + with FSDP.summon_full_params(model): fsdp_params = deepcopy(list(model.parameters())) # get FSDP state_dict. Note that by default we return full_state_dict. @@ -618,7 +735,10 @@ def test_state_dict_load_into_local_module( ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k] self._validate_state_dict_contents( - model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=ignore_keys, + model, + fsdp_state_dict, + state_dict_rank0_and_offload, + ignore_keys=ignore_keys, ) # Create zeroed local model if not fsdp_root: @@ -641,17 +761,7 @@ def test_state_dict_load_into_local_module( # Load fsdp's full state dict into the local and verify params are as # expected. if state_dict_rank0_and_offload: - # Broadcast + CUDA state_dict - if not isinstance(model, FSDP): - # Some portions of the model on rank 0 might not be on CPU, - # move everything to CPU to avoid running into - # https://github.com/pytorch/pytorch/issues/77113. - for k, t in fsdp_state_dict.items(): - if t.device != torch.device("cpu"): - fsdp_state_dict[k] = t.cpu() - fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) - for key in fsdp_state_dict.keys(): - fsdp_state_dict[key] = fsdp_state_dict[key].cuda() + fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) # if self.rank == 0: blank_local_model.load_state_dict(fsdp_state_dict, strict=True) @@ -747,10 +857,14 @@ def test_wrong_state_dict_config(self): @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS) @parametrize("prefix", [True, False]) @parametrize("ignore_inner", [True, False]) - def test_state_dict_with_ignored_modules(self, state_dict_type, prefix, ignore_inner): + def test_state_dict_with_ignored_modules( + self, state_dict_type, prefix, ignore_inner + ): # Initialize an FSDP-wrapped model with an ignored module that includes # both parameters and a buffer - model = Model(wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner).cuda() + model = Model( + wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner + ).cuda() ignored_modules = [model.outer] ignored_tensor_to_tensor_name = { model.outer.bias: "outer.bias", @@ -765,7 +879,8 @@ def test_state_dict_with_ignored_modules(self, state_dict_type, prefix, ignore_i # Note that when model.inner is not ignored this test also ensures # non-ignored buffers are not cloned. buffer_to_buffer_name = { - model.inner.buffer: "inner.buffer", model.outer.buffer: "outer.buffer", + model.inner.buffer: "inner.buffer", + model.outer.buffer: "outer.buffer", } fsdp_model = FSDP(model, ignored_modules=ignored_modules) prefix_str = "foo." if prefix else "" @@ -780,7 +895,11 @@ def test_state_dict_with_ignored_modules(self, state_dict_type, prefix, ignore_i }.items(): prefixed_tensor_name = f"{prefix_str}{tensor_name}" self.assertTrue(prefixed_tensor_name in sd1) - self.assertEqual(tensor.data_ptr(), sd1[prefixed_tensor_name].data_ptr(), f"{prefixed_tensor_name}") + self.assertEqual( + tensor.data_ptr(), + sd1[prefixed_tensor_name].data_ptr(), + f"{prefixed_tensor_name}", + ) # Check that the state dict can be loaded into a non-wrapped version of # the model nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda() @@ -788,7 +907,7 @@ def test_state_dict_with_ignored_modules(self, state_dict_type, prefix, ignore_i with torch.no_grad(): param.zero_() - to_load = {k[len(prefix_str):] : v for k, v in sd1.items()} + to_load = {k[len(prefix_str) :]: v for k, v in sd1.items()} nonwrapped_model.load_state_dict(to_load, strict=True) local_params = list(nonwrapped_model.parameters()) for fsdp_param, local_param in zip(fsdp_params, local_params): @@ -804,7 +923,10 @@ def test_state_dict_with_ignored_modules(self, state_dict_type, prefix, ignore_i prefixed_tensor_name = f"{prefix_str}{tensor_name}" self.assertTrue(prefixed_tensor_name in sd2) self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr()) - self.assertEqual(sd1[prefixed_tensor_name].data_ptr(), sd2[prefixed_tensor_name].data_ptr()) + self.assertEqual( + sd1[prefixed_tensor_name].data_ptr(), + sd2[prefixed_tensor_name].data_ptr(), + ) @skip_if_lt_x_gpu(2) def test_state_dict_type(self): diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py index 29bf252b796fd..18055dbebffbf 100644 --- a/test/distributed/fsdp/test_fsdp_summon_full_params.py +++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py @@ -9,9 +9,12 @@ import torch import torch.nn as nn from torch import distributed as dist -from torch.distributed.fsdp import CPUOffload -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from torch.distributed.fsdp import ( + CPUOffload, + FullyShardedDataParallel as FSDP, + MixedPrecision, + ShardingStrategy, +) from torch.distributed.fsdp.flat_param import FlatParamHandle from torch.distributed.fsdp.wrap import enable_wrap, wrap from torch.nn.parallel.distributed import DistributedDataParallel as DDP @@ -25,10 +28,10 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, parametrize, run_tests, + TEST_WITH_DEV_DBG_ASAN, ) if not dist.is_available(): @@ -52,10 +55,8 @@ def _run_test_summon_full_param_writeback( model = wrap(nn.Sequential(lin1, lin2)) # set the value - outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param") - inner_param = model.get_parameter( - "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param" - ) + outer_param = model._handles[0].flat_param + inner_param = model.module[0]._handles[0].flat_param p = outer_param if modify_outer else inner_param with torch.no_grad(): @@ -131,7 +132,9 @@ def test_summon_full_param_writeback(self): @skip_if_lt_x_gpu(2) @parametrize("mixed_precision", [True, False]) def test_summon_full_param_shard_value(self, mixed_precision): - mixed_precision = MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + mixed_precision = ( + MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + ) raw_model = nn.Linear(10, 11) raw_model_size = self.get_model_param_count(raw_model) expected_shard_size = self.get_expected_sharded_size(raw_model_size) @@ -161,7 +164,9 @@ def test_summon_full_param_shard_value(self, mixed_precision): @parametrize("summon_outer", [True, False]) @parametrize("mixed_precision", [True, False]) def test_summon_full_param_recursive(self, recurse, summon_outer, mixed_precision): - mixed_precision = MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + mixed_precision = ( + MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + ) model = FSDP( nn.Sequential( FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision), @@ -176,10 +181,8 @@ def test_summon_full_param_recursive(self, recurse, summon_outer, mixed_precisio shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size)) shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size)) - outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param") - inner_param = model.get_parameter( - "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param" - ) + outer_param = model._handles[0].flat_param + inner_param = model.module[0]._handles[0].flat_param self.assertEqual(shard_outer_numel, outer_param.numel()) self.assertEqual(shard_inner_numel, inner_param.numel()) @@ -209,7 +212,7 @@ def forward(self, fsdp_module): model = FSDP(MyModule()).cuda(self.rank) with self.assertRaisesRegex( - ValueError, "current state is TrainingState_.FORWARD" + ValueError, "Current handle state is HandleTrainingState.FORWARD" ): model(model) @@ -228,7 +231,7 @@ def bad_backwards_hook(tensor): output.register_hook(bad_backwards_hook) with self.assertRaisesRegex( - ValueError, "current state is TrainingState_.BACKWARD_PRE" + ValueError, "Current handle state is HandleTrainingState.BACKWARD_PRE" ): output.backward() @@ -243,9 +246,7 @@ def test_summon_full_params_respects_reshard_after_forward(self): ) def _test_summon_full_params_respects_reshard_after_forward( - self, - mixed_precision: Optional[MixedPrecision], - use_orig_params: bool + self, mixed_precision: Optional[MixedPrecision], use_orig_params: bool ): fsdp_kwargs = { "mixed_precision": mixed_precision, @@ -259,10 +260,8 @@ def _test_summon_full_params_respects_reshard_after_forward( **fsdp_kwargs, ).cuda(self.rank) - outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param") - inner_param = model.get_parameter( - "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param" - ) + outer_param = model._handles[0].flat_param + inner_param = model.module[0]._handles[0].flat_param outer_full_param_size = outer_param.numel() * self.world_size # trigger lazy init @@ -285,7 +284,7 @@ def _test_summon_full_params_respects_reshard_after_forward( def test_summon_single_param(self): model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank) - p = model.get_parameter("_fsdp_wrapped_module.flat_param") + p = model._handles[0].flat_param self.assertEqual(1, p.numel()) with torch.no_grad(): @@ -379,7 +378,9 @@ def __init__(self, fsdp_1, fsdp_2, fsdp_3): def test_reshard_outside_forward_backward_iteration( self, rank0_only, offload_to_cpu, mixed_precision ): - mixed_precision = MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + mixed_precision = ( + MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + ) model = FSDP( nn.Sequential( FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision), @@ -388,10 +389,8 @@ def test_reshard_outside_forward_backward_iteration( mixed_precision=mixed_precision, ).cuda(self.rank) - outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param") - inner_param = model.get_parameter( - "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param" - ) + outer_param = model._handles[0].flat_param + inner_param = model.module[0]._handles[0].flat_param outer_full_param_size = outer_param.numel() * self.world_size # First lets validate our assumption about resharding @@ -445,13 +444,15 @@ def test_reshard_outside_forward_backward_iteration( def test_params_are_unflattenned(self, rank0_only, offload_to_cpu, mixed_precision): layer_shape = (10, 12) model = nn.Linear(*layer_shape, bias=False).cuda(self.rank) - mixed_precision = MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + mixed_precision = ( + MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + ) fsdp_model = FSDP(deepcopy(model), mixed_precision=mixed_precision).cuda( self.rank ) def _get_flat_param(): - return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param") + return fsdp_model._handles[0].flat_param flattened_param = _get_flat_param() self.assertEqual(layer_shape[0] * layer_shape[1] / 2, flattened_param.numel()) @@ -494,7 +495,9 @@ def test_params_count_and_value( offload_to_cpu: bool, mixed_precision: bool, ): - mixed_precision = MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + mixed_precision = ( + MixedPrecision(param_dtype=torch.float16) if mixed_precision else None + ) model = NestedWrappedModule.init( self.process_group, FSDPInitMode.NO_FSDP, @@ -582,7 +585,8 @@ def test_named_parameters_buffers(self, prefix: str, recurse: bool): self.assertEqual(p1, p2) @skip_if_lt_x_gpu(2) - def test_with_grads(self): + def test_with_grads_core(self): + """Tests the core usage of ``summon_full_params(with_grads=True)``.""" self.run_subtests( { "writeback": [False, True], @@ -594,10 +598,10 @@ def test_with_grads(self): ], "use_orig_params": [True], }, - self._test_with_grads, + self._test_with_grads_core, ) - def _test_with_grads( + def _test_with_grads_core( self, writeback: bool, offload_to_cpu: bool, @@ -631,10 +635,13 @@ def _check_grads( assert torch.count_nonzero(p2.grad) > 0 p2.grad *= WRITEBACK_FACTOR new_fsdp_grads = [ - param.grad for param in fsdp_model.parameters() + param.grad + for param in fsdp_model.parameters() if param.grad is not None ] - writeback_persists = writeback or sharding_strategy == ShardingStrategy.NO_SHARD + writeback_persists = ( + writeback or sharding_strategy == ShardingStrategy.NO_SHARD + ) for old_grad, new_grad in zip(old_fsdp_grads, new_fsdp_grads): if writeback_persists: torch.testing.assert_close(old_grad * WRITEBACK_FACTOR, new_grad) @@ -647,14 +654,16 @@ def _check_grads( def _get_error_context(is_supported: bool): return ( - contextlib.suppress() if is_supported + contextlib.suppress() + if is_supported else self.assertRaises(NotImplementedError) ) # some configs not implemented yet def _get_fsdp_grads(fsdp_model: FSDP, is_supported: bool): if is_supported: return [ - param.grad.clone() for param in fsdp_model.parameters() + param.grad.clone() + for param in fsdp_model.parameters() if param.grad is not None ] return None # unused @@ -699,6 +708,41 @@ def _get_fsdp_grads(fsdp_model: FSDP, is_supported: bool): with _get_error_context(is_supported): _check_grads(ddp_model, fsdp_model, old_fsdp_grads) + @skip_if_lt_x_gpu(2) + def test_with_grads_none_grads(self): + """ + Tests that if all ranks' ``FlatParameter`` has ``None`` gradient, then + each original parameter sees ``None`` gradient as well. + """ + self.run_subtests( + { + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy.NO_SHARD, + ] + }, + self._test_with_grads_none_grads, + ) + + def _test_with_grads_none_grads(self, sharding_strategy: ShardingStrategy): + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + fsdp_kwargs={ + "use_orig_params": True, + "sharding_strategy": sharding_strategy, + }, + ) + for fsdp_module in FSDP.fsdp_modules(fsdp_model): + for handle in fsdp_module._handles: + assert handle.flat_param.grad is None + with FSDP.summon_full_params(fsdp_model, with_grads=True): + for param in fsdp_model.parameters(): + self.assertTrue(param.grad is None) + instantiate_parametrized_tests(TestSummonFullParams) instantiate_parametrized_tests(TestSummonFullParamsNoShard) diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index 7c6c15b2b422c..9b3ba3d5add80 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -11,9 +11,9 @@ from torch.distributed._shard.sharded_tensor.api import Shard, ShardedTensor from torch.distributed._shard.sharding_plan import ShardingPlan from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import _set_fsdp_extensions, FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor -from torch.distributed.fsdp._utils import _set_fsdp_flattened from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, diff --git a/test/distributed/fsdp/test_fsdp_traversal.py b/test/distributed/fsdp/test_fsdp_traversal.py index e1b0a77cfe791..b9c7a0aeac9b2 100644 --- a/test/distributed/fsdp/test_fsdp_traversal.py +++ b/test/distributed/fsdp/test_fsdp_traversal.py @@ -11,10 +11,7 @@ FSDPTest, NestedWrappedModule, ) -from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - run_tests, -) +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) @@ -42,18 +39,20 @@ def test_fsdp_modules(self): ) modules = FSDP.fsdp_modules(nested_wrapped_module) self.assertEquals( - modules, [ + modules, + [ nested_wrapped_module.module.get_submodule("1"), nested_wrapped_module.module.get_submodule("1").get_submodule("0"), nested_wrapped_module.module.get_submodule("2"), - ] + ], ) modules = FSDP.fsdp_modules(nested_wrapped_module, root_only=True) self.assertEqual( - modules, [ + modules, + [ nested_wrapped_module.module.get_submodule("1"), nested_wrapped_module.module.get_submodule("2"), - ] + ], ) diff --git a/test/distributed/fsdp/test_fsdp_uneven.py b/test/distributed/fsdp/test_fsdp_uneven.py index 295afbce508bc..6ffeb279b617b 100644 --- a/test/distributed/fsdp/test_fsdp_uneven.py +++ b/test/distributed/fsdp/test_fsdp_uneven.py @@ -8,11 +8,8 @@ from torch.nn import Linear from torch.optim import SGD from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_fsdp import ( - FSDPTest, -) -from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN, run_tests - +from torch.testing._internal.common_fsdp import FSDPTest +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 69b0645a3fa34..e61f2e4d96ded 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -1,8 +1,9 @@ # Owner(s): ["oncall: distributed"] import functools +import itertools import sys -from typing import Callable, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type import torch import torch.nn as nn @@ -13,8 +14,8 @@ FullyShardedDataParallel as FSDP, ShardingStrategy, ) -from torch.distributed.fsdp.fully_sharded_data_parallel import clean_tensor_name -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp._common_utils import clean_tensor_name +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -46,16 +47,14 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest): """Tests multiple parameter groups.""" - def _get_optim( - self, - model: nn.Module, - optim_class: Type[torch.optim.Optimizer], - multi_tensor: bool, - ) -> torch.optim.Optimizer: + @property + def world_size(self) -> int: + return 2 + + def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]: """ - Constructs an Adam optimizer with three parameter groups, one for - weights, one for biases, and one for everything else, each with - different weight decay and learning rates. + Constructs separate parameter groups for weights, biases, and other + parameters. """ param_groups = [ {"params": [], "weight_decay": 0.1, "lr": 1e-2}, @@ -69,18 +68,24 @@ def _get_optim( param_groups[1]["params"].append(param) else: param_groups[2]["params"].append(param) - return optim_class(param_groups, lr=5e-3, foreach=multi_tensor) + return param_groups - def _get_ddp_transformer_and_optim( + def _get_optim( self, + model: nn.Module, optim_class: Type[torch.optim.Optimizer], multi_tensor: bool, - find_unused_params: bool, - ) -> Tuple[DDP, torch.optim.Optimizer]: + ) -> torch.optim.Optimizer: """ - Returns a transformer with shared parameters wrapped with DDP and a - corresponding optimizer. + Constructs an Adam optimizer with three parameter groups, one for + weights, one for biases, and one for everything else, each with + different weight decay and learning rates. """ + param_groups = self._get_param_groups(model) + return optim_class(param_groups, lr=5e-3, foreach=multi_tensor) + + def _get_ddp_transformer(self, find_unused_params: bool) -> DDP: + """Returns a transformer with shared parameters wrapped with DDP.""" model = TransformerWithSharedParams.init( self.process_group, FSDPInitMode.NO_FSDP, @@ -92,8 +97,7 @@ def _get_ddp_transformer_and_optim( device_ids=[self.rank], find_unused_parameters=find_unused_params, ) - ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor) - return ddp_model, ddp_optim + return ddp_model def _get_fsdp_transformer_and_optim( self, @@ -113,12 +117,11 @@ def _get_fsdp_transformer_and_optim( # combination with the parameter group construction, ensures different # hyperparameter settings within one `FlatParameter` fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "use_orig_params": True, "sharding_strategy": sharding_strategy, @@ -174,11 +177,17 @@ def _check_train_parity( model.to(torch.device("cpu")) optim.step() if model is ddp_model and fsdp_model.cpu_offload.offload_params: - model.to(torch.device("cuda")) + model.to(device) torch.testing.assert_close(iter_losses[0], iter_losses[1]) iter_losses.clear() + self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model) + + def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP): with FSDP.summon_full_params(fsdp_model): - for p1, p2 in zip(ddp_model.parameters(), fsdp_model.parameters()): + for (n1, p1), (n2, p2) in zip( + ddp_model.module.named_parameters(), fsdp_model.named_parameters() + ): + self.assertEqual(n1, n2) torch.testing.assert_close(p1, p2) def _get_sharding_strategy_from_str( @@ -226,11 +235,12 @@ def test_diff_hyperparams(self, sharding_strategy_str: str): sharding_strategy=sharding_strategy, ) + @skip_if_lt_x_gpu(2) @parametrize( "sharding_strategy_str", ["no_shard", "shard_grad_op", "full_shard"], ) - def _test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str): + def test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str): """ Tests FSDP parity with DDP when using multiple parameter groups with different hyperparameter settings with CPU offloading enabled. This is @@ -271,11 +281,8 @@ def _test_diff_hyperparams( """ if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params: return # not supported - ddp_model, ddp_optim = self._get_ddp_transformer_and_optim( - optim_class=optim_class, - multi_tensor=multi_tensor, - find_unused_params=False, - ) + ddp_model = self._get_ddp_transformer(find_unused_params=False) + ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor) fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim( cuda_init_mode=cuda_init_mode, init_optim_before_wrap=init_optim_before_wrap, @@ -313,11 +320,8 @@ def _test_diff_trainability( sharding_strategy: ShardingStrategy, ): optim_class = torch.optim.Adam - ddp_model, ddp_optim = self._get_ddp_transformer_and_optim( - optim_class=optim_class, - multi_tensor=multi_tensor, - find_unused_params=True, - ) + ddp_model = self._get_ddp_transformer(find_unused_params=True) + ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor) fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim( cuda_init_mode=CUDAInitMode.CUDA_BEFORE, init_optim_before_wrap=False, @@ -336,10 +340,145 @@ def _test_diff_trainability( param.requires_grad_(False) self._check_train_parity(ddp_model, ddp_optim, fsdp_model, fsdp_optim, False) + @skip_if_lt_x_gpu(2) + def test_multiple_optimizers(self): + """ + Tests using two optimizers where only one sets gradients to ``None``. + """ + self.run_subtests( + { + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + # ShardingStrategy.SHARD_GRAD_OP, + ] + }, + self._test_multiple_optimizers, + ) + + def _test_multiple_optimizers(self, sharding_strategy: ShardingStrategy): + ddp_model = self._get_ddp_transformer(find_unused_params=True) + ddp_param_groups = self._get_param_groups(ddp_model) + assert len(ddp_param_groups) == 3, f"{len(ddp_param_groups)}" + ( + fsdp_model, + _, + ) = self._get_fsdp_transformer_and_optim( # ignore returned optimizer + cuda_init_mode=CUDAInitMode.CUDA_BEFORE, + init_optim_before_wrap=False, + optim_class=torch.optim.Adam, # ignored + multi_tensor=False, # ignored + sharding_strategy=sharding_strategy, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + cpu_offload=None, + ) + fsdp_param_groups = self._get_param_groups(fsdp_model) + assert len(fsdp_param_groups) == 3, f"{len(fsdp_param_groups)}" + ddp_optims = [] + fsdp_optims = [] + # For the transformer model, every parameter is either a weight or a + # bias, so we only use the first two parameter groups. Moreover, we use + # Adam and AdamW in particular since they both use bias correction + # dependent on the step, which is incremented even if a parameter has a + # zero gradient but not if the gradient is `None`. This is to test that + # we are differentiating between a zero and `None` gradient correctly. + optim_ctors = [ + functools.partial(torch.optim.Adam, lr=5e-3), + functools.partial(torch.optim.AdamW, lr=1e-2), + ] + + for optim_ctor, ddp_param_group, fsdp_param_group in zip( + optim_ctors, + ddp_param_groups[:2], + fsdp_param_groups[:2], + ): + ddp_optims.append(optim_ctor(ddp_param_group["params"])) + fsdp_optims.append(optim_ctor(fsdp_param_group["params"])) + device = torch.device("cuda") + + # Check that there exists a `FlatParameter` that has both a weight and + # a bias in this rank's shard + has_both = False + for fsdp_module in FSDP.fsdp_modules(fsdp_model): + for handle in fsdp_module._handles: + flat_param = handle.flat_param + assert flat_param._params is not None + has_weight = False + has_bias = False + for param, fqn in zip(flat_param._params, flat_param._fqns): + if "weight" in fqn and param.numel() > 0: + has_weight = True + elif "bias" in fqn and param.numel() > 0: + has_bias = True + has_both |= has_weight and has_bias + assert has_both, ( + f"Rank {self.rank} does not have a `FlatParameter` with both a " + "weight and a bias in its shard, meaning that this test is vacuous" + ) + + # Run one iteration to generate gradients + def run_iter(): + iter_losses = [] + for model, optims in ((ddp_model, ddp_optims), (fsdp_model, fsdp_optims)): + module = model.module + inp = module.get_input(device) + output = model(*inp) + loss = module.get_loss(inp, output).to(device) + iter_losses.append(loss) + module.run_backward(loss) + for optim in optims: + optim.step() + torch.testing.assert_close(iter_losses[0], iter_losses[1]) + iter_losses.clear() + self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model) + + run_iter() + + # Only set the weights' gradients to None + ddp_optims[0].zero_grad(set_to_none=True) + fsdp_optims[0].zero_grad(set_to_none=True) + inp = ddp_model.module.get_input(device) + ddp_output = ddp_model(*inp) + fsdp_output = fsdp_model(*inp) + + # Check that FSDP correctly exposes gradients even after forward + # (namely, `None` for weights and non-`None` for biases) + for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip( + ddp_model.module.named_parameters(), + fsdp_model.named_parameters(), + ): + self.assertEqual(ddp_n, fsdp_n) + if fsdp_p.numel() == 0: + # Not in this rank's shard + self.assertTrue(fsdp_p.grad is None) + continue + if ddp_p.grad is None: + self.assertTrue(fsdp_p.grad is None) + else: + self.assertEqual(ddp_p.flatten(), fsdp_p.flatten()) + self.assertEqual(ddp_p.grad.flatten(), fsdp_p.grad.flatten()) + self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model) + + # Finish the iteration (backward pass and optimizer step) + ddp_loss = ddp_model.module.get_loss(inp, ddp_output).to(device) + fsdp_loss = fsdp_model.module.get_loss(inp, fsdp_output).to(device) + ddp_model.module.run_backward(ddp_loss) + fsdp_model.module.run_backward(fsdp_loss) + for optim in itertools.chain(ddp_optims, fsdp_optims): + optim.step() + self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model) + + # Run one more iteration to confirm bias corrections are correct + run_iter() + self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model) + class TestFSDPUseOrigParamsUnshardReshard(FSDPTest): """Tests the unshard/reshard flow.""" + @property + def world_size(self) -> int: + return 2 + def _get_fsdp_models_and_optims( self, sharding_strategy: ShardingStrategy, @@ -867,6 +1006,47 @@ def forward(self, x): fsdp_buffer_names = [n for n, _ in fsdp_model.named_buffers()] self.assertEqual(buffer_names, fsdp_buffer_names) + @skip_if_lt_x_gpu(2) + def test_named_parameters_in_forward(self): + """ + Tests that calling ``named_parameters()`` during forward returns FQNs + and ``Tensor`` s corresponding to the original parameters. + """ + param_shapes = [None, None] + assert_equal_fn = self.assertEqual + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(5, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + nonlocal param_shapes + param_names = [tup[0] for tup in self.named_parameters()] + params = [tup[1] for tup in self.named_parameters()] + assert ( + param_shapes[0] is not None and param_shapes[1] is not None + ), "`param_sizes` should be set" + assert_equal_fn( + param_names, + [ + "lin.weight", + "lin.bias", + ], + ) + assert_equal_fn(params[0].shape, param_shapes[0]) + assert_equal_fn(params[1].shape, param_shapes[1]) + return self.lin(x) + + model = Model().cuda() + # Save the *unsharded* original parameter shapes and check the shapes + # match in the forward pass + param_shapes[0] = model.lin.weight.shape + param_shapes[1] = model.lin.bias.shape + fsdp_model = FSDP(model, use_orig_params=True) + inp = torch.randn((2, 5), device=torch.device("cuda")) + fsdp_model(inp) + instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups) instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard) diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index 2aa7fa0b6d97e..37c52547e8472 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -2,24 +2,27 @@ import random import sys -from typing import List import unittest from collections import OrderedDict +from dataclasses import dataclass +from enum import auto, Enum +from typing import List import torch import torch.nn as nn from torch import distributed as dist from torch.distributed.fsdp._utils import _apply_to_tensors +from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.utils import _replace_by_prefix from torch.testing._internal.common_utils import ( - TEST_WITH_DEV_DBG_ASAN, - TestCase, instantiate_parametrized_tests, parametrize, run_tests, subtest, + TEST_WITH_DEV_DBG_ASAN, + TestCase, ) -from dataclasses import dataclass if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) @@ -60,8 +63,6 @@ class SomeDataClass: some_float: float some_tensor: List[torch.Tensor] - - # create a mixed bag of data. data = [1, "str"] data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) @@ -100,7 +101,6 @@ def test_replace_by_prefix(self): _replace_by_prefix(state_dict, "module.layer.", "layer.") assert state_dict == original_state_dict - def test_packed_sequence(self): """Test to ensure RNN packed sequences are modified correctly.""" rnn = nn.RNN(5, 5) @@ -118,7 +118,110 @@ def fill_fn(x): self.assertEqual(torch.sum(x), 0) +class TestGetSubmoduleToStates(TestCase): + """Tests the function ``_get_submodule_to_states()``.""" + + class SharedParameterMode(Enum): + """ + - ``PARENT_CHILD``: A parent submodule shares a parameter with a child + submodule. + - ``SIBLING``: Two sibling submodules share a parameter. + """ + + PARENT_CHILD = auto() + SIBLING = auto() # TODO: not yet supported + + class Model(nn.Module): + """Nested model with buffers and a shared parameter.""" + + def __init__(self, shared_parameter_mode) -> None: + super().__init__() + self.seq1 = nn.Sequential( + nn.Linear(5, 5, bias=False), + nn.Linear(5, 5, bias=False), + ) + self.seq1.register_buffer("seq1_buffer", torch.randn((5,))) + self.lin = nn.Linear(5, 5, bias=False) + self.seq2 = nn.Sequential( + nn.Sequential(nn.Linear(5, 5, bias=False)), nn.Linear(5, 5, bias=False) + ) + if ( + shared_parameter_mode + == TestGetSubmoduleToStates.SharedParameterMode.PARENT_CHILD + ): + self.seq2[0][0].weight = self.lin.weight + elif ( + shared_parameter_mode + == TestGetSubmoduleToStates.SharedParameterMode.SIBLING + ): + self.seq2[0][0].weight = self.seq1[0].weight + self.seq2[1].register_buffer("seq2_1_buffer", torch.randn((5,))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.seq2(self.lin(self.seq1(x))) # equivalent to one matmul + + def test_module_wrap_policy(self): + """ + Tests the module wrap policy on a nested model with buffers and a + shared parameter. + + NOTE: This test is hard coded against ``Model``. + """ + model = self.Model(TestGetSubmoduleToStates.SharedParameterMode.PARENT_CHILD) + + # Compute the mapping from submodule to states according to a logical + # module wrap policy + module_classes = (nn.Sequential,) + auto_wrap_policy = ModuleWrapPolicy(set(module_classes)) + submodule_to_states = _get_submodule_to_states( + model, auto_wrap_policy, set(), set() + ) + # Check the number of submodules with states in the mapping + num_submodules_with_states = sum( + isinstance(submodule, module_classes) for submodule in model.modules() + ) # explicitly show how to compute the expected number + if not isinstance(model, module_classes): + num_submodules_with_states += 1 # always include the root + assert num_submodules_with_states == 4, f"{num_submodules_with_states}" + self.assertEqual(len(submodule_to_states), num_submodules_with_states) + + # Check the mapping, i.e. that the dict order follows a post-order + # traversal and that the contents are expected + submodules = list(submodule_to_states.keys()) + # - Root module `model` + self.assertEqual(submodules[0], model) + root_states = submodule_to_states[submodules[0]] + self.assertEqual(root_states.params, [model.lin.weight]) + self.assertEqual(root_states.param_names, ["lin.weight"]) + self.assertEqual(root_states.buffers, []) + self.assertEqual(root_states.buffer_names, []) + # # - `seq2` + self.assertEqual(submodules[1], model.seq2) + seq2_states = submodule_to_states[submodules[1]] + self.assertEqual(seq2_states.params, [model.seq2[1].weight]) + self.assertEqual(seq2_states.param_names, ["1.weight"]) + self.assertEqual(seq2_states.buffers, [model.seq2[1].seq2_1_buffer]) + self.assertEqual(seq2_states.buffer_names, ["1.seq2_1_buffer"]) + # - `seq2[0]` + self.assertEqual(submodules[2], model.seq2[0]) + seq2_0_states = submodule_to_states[submodules[2]] + self.assertEqual(seq2_0_states.params, []) # shared parameter + self.assertEqual(seq2_0_states.param_names, []) + self.assertEqual(seq2_0_states.buffers, []) + self.assertEqual(seq2_0_states.buffer_names, []) + # - `seq1` + self.assertEqual(submodules[3], model.seq1) + seq1_states = submodule_to_states[submodules[3]] + self.assertEqual( + seq1_states.params, [model.seq1[0].weight, model.seq1[1].weight] + ) + self.assertEqual(seq1_states.param_names, ["0.weight", "1.weight"]) + self.assertEqual(seq1_states.buffers, [model.seq1.seq1_buffer]) + self.assertEqual(seq1_states.buffer_names, ["seq1_buffer"]) + + instantiate_parametrized_tests(TestUtils) +instantiate_parametrized_tests(TestGetSubmoduleToStates) if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index 98ba324f46f18..e157f041ae1bd 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -4,7 +4,8 @@ import os import tempfile import unittest -from enum import Enum, auto +from enum import auto, Enum +from typing import Callable, Union import torch import torch.nn as nn @@ -12,15 +13,15 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, -) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _wrap_batchnorm_individually, always_wrap_policy, enable_wrap, + ModuleWrapPolicy, size_based_auto_wrap_policy, transformer_auto_wrap_policy, wrap, @@ -28,20 +29,20 @@ from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( + _maybe_cuda, CUDAInitMode, DummyProcessGroup, FSDPInitMode, FSDPTest, TransformerWithSharedParams, - _maybe_cuda, ) from torch.testing._internal.common_utils import ( FILE_SCHEMA, - TestCase, find_free_port, instantiate_parametrized_tests, parametrize, run_tests, + TestCase, ) @@ -54,6 +55,7 @@ def __init__(self): self.bn3 = nn.BatchNorm3d(10) self.sync_bn = nn.SyncBatchNorm(10) + class WrapMethod(Enum): FSDP_CTOR = auto() # FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss @@ -61,8 +63,6 @@ class WrapMethod(Enum): WRAP_API = auto() - - class TestFSDPWrap(FSDPTest): """ Tests main API for wrapping FSDP, which is to pass auto_wrap_policy into @@ -144,7 +144,9 @@ def test_error_already_wrapped(self, nested, cuda_init_mode): Test that an error is raised if we attempt to wrap when submodules are already FSDP. """ - wrapped_fsdp = self._get_already_wrapped_fsdp(nested=nested, cuda_init_mode=cuda_init_mode) + wrapped_fsdp = self._get_already_wrapped_fsdp( + nested=nested, cuda_init_mode=cuda_init_mode + ) if cuda_init_mode == CUDAInitMode.CUDA_AFTER: wrapped_fsdp = wrapped_fsdp.cuda() @@ -159,9 +161,10 @@ def never_wrap_policy(*args, **kwargs): policy = ( functools.partial( - _or_policy, - policies=[never_wrap_policy, _wrap_batchnorm_individually] - ) if use_or_policy else _wrap_batchnorm_individually + _or_policy, policies=[never_wrap_policy, _wrap_batchnorm_individually] + ) + if use_or_policy + else _wrap_batchnorm_individually ) model = BatchNormNet() fsdp = FSDP(model, auto_wrap_policy=policy) @@ -178,6 +181,7 @@ def test_bn_always_wrapped_individually(self): if the other policy results in a module containing a BN unit being wrapped, the contained BN unit will still be individually wrapped. """ + class MyModule(nn.Module): def __init__(self): super().__init__() @@ -189,8 +193,7 @@ def wrap_bn_container(module, recurse, *args, **kwargs): return isinstance(module, BatchNormNet) my_policy = functools.partial( - _or_policy, - policies=[wrap_bn_container, _wrap_batchnorm_individually] + _or_policy, policies=[wrap_bn_container, _wrap_batchnorm_individually] ) mod = MyModule() fsdp = FSDP(mod, auto_wrap_policy=my_policy) @@ -203,7 +206,7 @@ def wrap_bn_container(module, recurse, *args, **kwargs): fsdp.bn_container.bn1, fsdp.bn_container.bn2, fsdp.bn_container.bn3, - fsdp.bn_container.sync_bn + fsdp.bn_container.sync_bn, ]: self.assertTrue(isinstance(bn, FSDP)) @@ -216,24 +219,21 @@ def wrap_bn_container(module, recurse, *args, **kwargs): fsdp.bn_container.bn1, fsdp.bn_container.bn2, fsdp.bn_container.bn3, - fsdp.bn_container.sync_bn + fsdp.bn_container.sync_bn, ]: self.assertFalse(isinstance(bn, FSDP)) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", - [CPUOffload(offload_params=False), CPUOffload(offload_params=True)] + [CPUOffload(offload_params=False), CPUOffload(offload_params=True)], ) @parametrize( "backward_prefetch", - [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE] + [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE], ) @parametrize("forward_prefetch", [False, True]) - @parametrize( - "cuda_init_mode", - [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE] - ) + @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE]) def test_main_wrap_api( self, cpu_offload: CPUOffload, @@ -286,7 +286,7 @@ def forward(self, input): wrapped_model.module.lin3, wrapped_model.module.lin4.module.nested_lin, wrapped_model.module.lin4, - wrapped_model + wrapped_model, ] for module in modules_in_fsdp_graph_order: @@ -322,7 +322,9 @@ def test_wrap(self, wrap_method): layer = FSDP( nn.Linear(5, 5), process_group=self.process_group, - auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=1) + auto_wrap_policy=functools.partial( + size_based_auto_wrap_policy, min_num_params=1 + ), ) self.assertTrue(isinstance(layer, FSDP)) self.assertEqual(layer.rank, self.process_group.rank()) @@ -362,7 +364,9 @@ def test_always_wrap(self): passed into FSDP, all submodules are wrapped. """ seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True) - model = FSDP(seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy) + model = FSDP( + seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy + ) TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model) @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") @@ -372,6 +376,19 @@ def test_transformer_auto_wrap_policy(self): transformer_auto_wrap_policy, transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) + self._test_transformer_wrapping(auto_wrap_policy) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") + def test_module_wrap_policy(self): + """Tests the ``ModuleWrapPolicy``.""" + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} + ) + self._test_transformer_wrapping(auto_wrap_policy) + + def _test_transformer_wrapping( + self, auto_wrap_policy: Union[Callable, _FSDPPolicy] + ): fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} fsdp_model = TransformerWithSharedParams.init( self.process_group, @@ -383,7 +400,11 @@ def test_transformer_auto_wrap_policy(self): encoder_layers = set(fsdp_model.module.transformer.encoder.layers) decoder_layers = set(fsdp_model.module.transformer.decoder.layers) for module in modules: - if module is fsdp_model or module in encoder_layers or module in decoder_layers: + if ( + module is fsdp_model + or module in encoder_layers + or module in decoder_layers + ): self.assertTrue(isinstance(module, FSDP)) else: self.assertFalse(isinstance(module, FSDP)) @@ -401,7 +422,7 @@ def test_auto_wrap_api(self): model = FSDP( sequential, process_group=self.process_group, - auto_wrap_policy=my_auto_wrap_policy + auto_wrap_policy=my_auto_wrap_policy, ) TestFSDPWrap.NestedSequentialModel.verify_model(self, model) @@ -420,7 +441,7 @@ def test_auto_wrap_preset_exclude_wrap(self): model = FSDP( sequential, process_group=self.process_group, - auto_wrap_policy=my_auto_wrap_policy + auto_wrap_policy=my_auto_wrap_policy, ) self.assertTrue(isinstance(model, FSDP)) @@ -437,7 +458,11 @@ def test_auto_wrap_preset_exclude_wrap_include_children(self): my_auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=40 ) - model = FSDP(sequential, process_group=self.process_group, auto_wrap_policy=my_auto_wrap_policy) + model = FSDP( + sequential, + process_group=self.process_group, + auto_wrap_policy=my_auto_wrap_policy, + ) self.assertTrue(isinstance(model, FSDP)) self.assertTrue(isinstance(model[0], FSDP)) @@ -452,7 +477,11 @@ def test_auto_wrap_preset_force_leaf(self): my_auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=40 ) - model = FSDP(sequential, process_group=self.process_group, auto_wrap_policy=my_auto_wrap_policy) + model = FSDP( + sequential, + process_group=self.process_group, + auto_wrap_policy=my_auto_wrap_policy, + ) self.assertTrue(isinstance(model.module[0], FSDP)) # Assert children of multihead attention are not wrapped self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention)) @@ -473,7 +502,11 @@ def test_auto_wrap_preset_force_leaf_custom(self): sequential = nn.Sequential( nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)]) ) - model = FSDP(sequential, process_group=self.process_group, auto_wrap_policy=my_auto_wrap_policy) + model = FSDP( + sequential, + process_group=self.process_group, + auto_wrap_policy=my_auto_wrap_policy, + ) # Model was wrapped in FSDP as no inner modules were wrapped. self.assertTrue(isinstance(model, FSDP)) self.assertTrue(isinstance(model.module[0], nn.Linear)) @@ -483,14 +516,12 @@ def test_auto_wrap_preset_force_leaf_custom(self): @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER]) @parametrize( "cpu_offload", - [CPUOffload(offload_params=False), CPUOffload(offload_params=True)] + [CPUOffload(offload_params=False), CPUOffload(offload_params=True)], ) @parametrize("use_device_id", [True, False]) def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id): # CPU offload and CUDA after don't work together as expected. - if ( - cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER - ): + if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER: return device = torch.device("cuda") @@ -515,12 +546,17 @@ def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id): # cases where full model cannot be loaded onto GPU, but their shards can. cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER try: - sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init)) + sequential = TestFSDPWrap.NestedSequentialModel.get_model( + cuda=(not cuda_after_init) + ) my_auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=40 ) model = FSDP( - sequential, cpu_offload=cpu_offload, auto_wrap_policy=my_auto_wrap_policy, device_id=device_id + sequential, + cpu_offload=cpu_offload, + auto_wrap_policy=my_auto_wrap_policy, + device_id=device_id, ) TestFSDPWrap.NestedSequentialModel.verify_model(self, model) if cuda_after_init: @@ -568,7 +604,8 @@ def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod): sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False) ignored_modules = [sequential[1], sequential[2][0]] my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=40, + size_based_auto_wrap_policy, + min_num_params=40, ) fsdp_kwargs = { "process_group": self.process_group, diff --git a/test/distributed/optim/test_apply_optimizer_in_backward.py b/test/distributed/optim/test_apply_optimizer_in_backward.py new file mode 100644 index 0000000000000..ebf4c4d4e9c82 --- /dev/null +++ b/test/distributed/optim/test_apply_optimizer_in_backward.py @@ -0,0 +1,113 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest +from copy import deepcopy + +import torch +import torch.nn as nn + +from torch.distributed.optim import _apply_optimizer_in_backward + +# TODO (rohan-varma): Add FSDP & DDP tests once supported + + +def _validate_params(params_list, fn): + ref_params = params_list[0] + for param_list in params_list[1:]: + for p1, p2 in zip(ref_params, param_list): + fn(p1, p2) + + +class ApplyOverlappedOptimizerTest(unittest.TestCase): + def _run_training_loop_and_validate(self, inp, models, optimizers): + for i in range(6): + for model in models: + model(inp).sum().backward() + for opt in optimizers: + opt.step() + + with self.subTest(i): + _validate_params( + [model.parameters() for model in models], + torch.testing.assert_allclose, + ) + + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + def _test_apply_optimizer_in_backward(self, share_params) -> None: + weight_optimizer_kwargs = {"lr": 1.0} + bias_optimizer_kwargs = {"lr": 0.5} + model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10)) + if share_params: + model[0].weight = model[1].weight + + # Use different optimizers for weights & biases. + weights = [m.weight for m in model] + biases = [m.bias for m in model] + optim_weight = torch.optim.SGD(weights, **weight_optimizer_kwargs) + optim_bias = torch.optim.SGD(biases, **bias_optimizer_kwargs) + model_with_opt_in_bwd = deepcopy(model) + + # Apply different optimizer in backwards for weights and biases. + _apply_optimizer_in_backward( + torch.optim.SGD, + [m.weight for m in model_with_opt_in_bwd], + optimizer_kwargs=weight_optimizer_kwargs, + ) + + _apply_optimizer_in_backward( + torch.optim.SGD, + [m.bias for m in model_with_opt_in_bwd], + optimizer_kwargs=bias_optimizer_kwargs, + ) + + _validate_params( + [ + model.parameters(), + model_with_opt_in_bwd.parameters(), + ], + torch.testing.assert_allclose, + ) + + self._run_training_loop_and_validate( + torch.randn(4, 10), + [model, model_with_opt_in_bwd], + [optim_weight, optim_bias], + ) + + def test_apply_optimizer_in_backward(self) -> None: + self._test_apply_optimizer_in_backward(share_params=False) + + def test_apply_optimizer_in_backward_shared_params(self) -> None: + self._test_apply_optimizer_in_backward(share_params=True) + + def test_multiple_optim_for_params(self) -> None: + model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10)) + opt_0_kwargs = {"lr": 0.03} + opt_1_kwargs = {"lr": 0.01} + opt_0 = torch.optim.SGD(model.parameters(), **opt_0_kwargs) + opt_1 = torch.optim.SGD(model.parameters(), **opt_1_kwargs) + model_with_opt_in_bwd = deepcopy(model) + _apply_optimizer_in_backward( + torch.optim.SGD, + model_with_opt_in_bwd.parameters(), + optimizer_kwargs=opt_0_kwargs, + ) + _apply_optimizer_in_backward( + torch.optim.SGD, + model_with_opt_in_bwd.parameters(), + optimizer_kwargs=opt_1_kwargs, + ) + self._run_training_loop_and_validate( + torch.randn(4, 10), + [model, model_with_opt_in_bwd], + [opt_0, opt_1], + ) diff --git a/test/distributed/optim/test_named_optimizer.py b/test/distributed/optim/test_named_optimizer.py new file mode 100644 index 0000000000000..880dbb382aa6a --- /dev/null +++ b/test/distributed/optim/test_named_optimizer.py @@ -0,0 +1,245 @@ +# Owner(s): ["oncall: distributed"] + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn + +from torch.distributed.optim import _NamedOptimizer + + +class TestDummyModel(torch.nn.Module): + def __init__(self): + super(TestDummyModel, self).__init__() + torch.manual_seed(0) + self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) + self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) + self.net3 = nn.Linear(32, 64) + self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) + + def forward(self, x): + return self.net4(self.net3(self.net2(self.net1(x)))) + + +class NamedOptimizerTest(unittest.TestCase): + def _compare_state_dict_group(self, group, named_group, assert_equal=True): + for key, val in group.items(): + if key != "params": + self.assertTrue( + key in named_group, f"{key} not in named optimizer state dict" + ) + err_msg = ( + f"{key} state not equal" if assert_equal else f"{key} state equal" + ) + if isinstance(val, torch.Tensor): + fn = self.assertTrue if assert_equal else self.assertFalse + fn(torch.allclose(val, named_group[key]), err_msg) + else: + fn = self.assertEqual if assert_equal else self.assertNotEqual + fn(val, named_group[key], err_msg) + + def test_state_dict(self): + """Check that NamedOptimizer exposes the expected state dict + interface.""" + m = TestDummyModel() + m_dup = TestDummyModel() + optim_1 = torch.optim.SGD( + [ + {"params": m.net1.parameters()}, + {"params": m.net3.parameters(), "lr": 1e-3}, + ], + lr=1e-2, + momentum=0.9, + ) + + optim_2 = torch.optim.Adam( + [ + {"params": m.net2.parameters()}, + {"params": m.net4.parameters(), "lr": 1e-5}, + ] + ) + + named_optim_1 = _NamedOptimizer( + m_dup.named_parameters(), + torch.optim.SGD, + [ + {"params": m_dup.net1.parameters()}, + {"params": m_dup.net3.parameters(), "lr": 1e-3}, + ], + lr=1e-2, + momentum=0.9, + ) + + named_optim_2 = _NamedOptimizer( + m_dup.named_parameters(), + torch.optim.Adam, + [ + {"params": m_dup.net2.parameters()}, + {"params": m_dup.net4.parameters(), "lr": 1e-5}, + ], + ) + for i in range(2): + x = torch.rand(5, 8) + y = m(x) + y.sum().backward() + optim_1.step() + optim_2.step() + + y = m_dup(x) + y.sum().backward() + named_optim_1.step() + named_optim_2.step() + + sd_1 = optim_1.state_dict() + sd_2 = optim_2.state_dict() + named_sd_1 = named_optim_1.state_dict() + named_sd_2 = named_optim_2.state_dict() + + # Compare "state" in optim state dict + self._compare_state_dict_group( + sd_1["state"][0], + named_sd_1["state"]["net1.0.weight"], + assert_equal=True, + ) + self._compare_state_dict_group( + sd_2["state"][1], + named_sd_2["state"]["net2.0.bias"], + assert_equal=True, + ) + self._compare_state_dict_group( + sd_1["state"][2], + named_sd_1["state"]["net3.weight"], + assert_equal=True, + ) + self._compare_state_dict_group( + sd_2["state"][3], + named_sd_2["state"]["net4.1.bias"], + assert_equal=True, + ) + + # Compare "param_groups" in optim state dict + self._compare_state_dict_group( + sd_1["param_groups"][0], + named_sd_1["param_groups"][0], + assert_equal=True, + ) + self._compare_state_dict_group( + sd_2["param_groups"][1], named_sd_2["param_groups"][1], assert_equal=True + ) + + def test_load_state_dict(self): + """Check that NamedOptimizer exposes the expected state dict + interface.""" + m = TestDummyModel() + named_optim_1 = _NamedOptimizer( + m.named_parameters(), + torch.optim.SGD, + lr=1e-2, + momentum=0.9, + ) + + for _ in range(2): + x = torch.rand(5, 8) + y = m(x) + y.sum().backward() + named_optim_1.step() + + state_dict_to_load = named_optim_1.state_dict() + + named_optim_2 = _NamedOptimizer( + m.named_parameters(), + torch.optim.SGD, + lr=1e-2, + momentum=0.6, + ) + + for _ in range(2): + x = torch.rand(5, 8) + y = m(x) + y.sum().backward() + named_optim_2.step() + + state_dict_before_load = named_optim_2.state_dict() + + # Compare "state" in optim state dict + self._compare_state_dict_group( + state_dict_to_load["state"]["net1.0.weight"], + state_dict_before_load["state"]["net1.0.weight"], + assert_equal=False, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net2.0.bias"], + state_dict_before_load["state"]["net2.0.bias"], + assert_equal=False, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net3.weight"], + state_dict_before_load["state"]["net3.weight"], + assert_equal=False, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net4.1.bias"], + state_dict_before_load["state"]["net4.1.bias"], + assert_equal=False, + ) + + named_optim_2.load_state_dict(state_dict_to_load) + state_dict_after_load = named_optim_2.state_dict() + + # Compare "state" in optim state dict + self._compare_state_dict_group( + state_dict_to_load["state"]["net1.0.weight"], + state_dict_after_load["state"]["net1.0.weight"], + assert_equal=True, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net2.0.bias"], + state_dict_after_load["state"]["net2.0.bias"], + assert_equal=True, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net3.weight"], + state_dict_after_load["state"]["net3.weight"], + assert_equal=True, + ) + self._compare_state_dict_group( + state_dict_to_load["state"]["net4.1.bias"], + state_dict_after_load["state"]["net4.1.bias"], + assert_equal=True, + ) + + def test_load_state_dict_error(self): + m = TestDummyModel() + named_optim_1 = _NamedOptimizer( + m.named_parameters(), + torch.optim.SGD, + lr=1e-2, + momentum=0.9, + ) + + for _ in range(2): + x = torch.rand(5, 8) + y = m(x) + y.sum().backward() + named_optim_1.step() + + state_dict_to_load = named_optim_1.state_dict() + + named_optim_2 = _NamedOptimizer( + m.named_parameters(), + torch.optim.SGD, + lr=1e-2, + momentum=0.6, + ) + + err_msg = ( + "Expects the optim to be initialized before load but found not initialized" + ) + with self.assertRaisesRegex(ValueError, err_msg): + named_optim_2.load_state_dict(state_dict_to_load) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 2a3224122a640..3e0474c3a4494 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -10,7 +10,7 @@ import sys import unittest from contextlib import suppress -from typing import Any, List, cast +from typing import Any, cast, List import numpy as np @@ -24,26 +24,25 @@ hook_with_zero_step, hook_with_zero_step_interleaved, ) -from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import ( - allreduce_hook, -) +from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook from torch.distributed.algorithms.join import Join, Joinable, JoinHook from torch.distributed.optim import ZeroRedundancyOptimizer from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import SGD, AdamW +from torch.optim import AdamW, SGD from torch.testing._internal import common_distributed from torch.testing._internal.common_utils import ( - IS_WINDOWS, - TEST_WITH_ASAN, - TEST_WITH_DEV_DBG_ASAN, instantiate_parametrized_tests, + IS_WINDOWS, parametrize, run_tests, + TEST_WITH_ASAN, + TEST_WITH_DEV_DBG_ASAN, ) try: import torchvision + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False @@ -51,17 +50,18 @@ # Use GLOO on GPU when running CUDA + Windows def _get_backend_for_tests(): return ( - dist.Backend.NCCL if not IS_WINDOWS and torch.cuda.is_available() + dist.Backend.NCCL + if not IS_WINDOWS and torch.cuda.is_available() # Windows only has GLOO, but GLOO GPU works. And use GLOO CPU when # no GPUs are available. else dist.Backend.GLOO ) + BACKEND = _get_backend_for_tests() -@unittest.skipIf( - TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work." -) + +@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.") class TestZeroRedundancyOptimizer(common_distributed.MultiProcessTestCase): def setUp(self): super(TestZeroRedundancyOptimizer, self).setUp() @@ -70,8 +70,9 @@ def setUp(self): @property def device(self): - return torch.device("cuda") if torch.cuda.is_available() \ - else torch.device("cpu") + return ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) @property def world_size(self): @@ -88,18 +89,19 @@ def tearDown(self): pass def dist_init(self, rank, world_size=-1, backend=BACKEND): - if (world_size < 1): + if world_size < 1: world_size = self.world_size store = dist.FileStore(self.file_name, world_size) return dist.init_process_group( - backend=backend, store=store, rank=rank, world_size=world_size, + backend=backend, + store=store, + rank=rank, + world_size=world_size, ) # TODO: sandcastle_skip_if does not work here. -@unittest.skipIf( - TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work." -) +@unittest.skipIf(TEST_WITH_ASAN or TEST_WITH_DEV_DBG_ASAN, "CUDA + ASAN does not work.") class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer): def test_state_dict(self): """Check that ZeroRedundancyOptimizer exposes the expected state dict @@ -111,7 +113,10 @@ def test_state_dict(self): RECIPIENT_RANK = 0 # rank 0 is the only rank since the world size is 1 x = torch.tensor([1.0], device=self.device, requires_grad=True) o = ZeroRedundancyOptimizer( - [x], optimizer_class=SGD, lr=LR1, momentum=MOMENTUM, + [x], + optimizer_class=SGD, + lr=LR1, + momentum=MOMENTUM, ) x.backward() o.step() @@ -202,7 +207,9 @@ def step(self, closure=None, kwarg=None): kwarg: List[Any] = [] x = torch.tensor([1.0], device=self.device, requires_grad=True) o = ZeroRedundancyOptimizer( - [x], optimizer_class=SGDWithStepKWArg, lr=LR, + [x], + optimizer_class=SGDWithStepKWArg, + lr=LR, ) x.backward() o.step(0, kwarg=kwarg) @@ -241,7 +248,9 @@ def step(self): x = torch.tensor([1.0], device=self.device, requires_grad=True) o = ZeroRedundancyOptimizer( - [x], optimizer_class=SGDWithoutClosure, lr=LR, + [x], + optimizer_class=SGDWithoutClosure, + lr=LR, ) x.backward() o.step() @@ -274,22 +283,30 @@ def test_constructor(self): ) # Test various constructor inputs in the form: (input, expected error) ctor_inputs = [ - ([], ValueError), # empty parameter list - (torch.randn(1), TypeError), # non-iterable: `torch.Tensor` - (1.2, TypeError), # non-iterable: `float` - ([ - {"params": [l.weight for l in m]}, - {"params": [l.bias for l in m]}, - ], None), # iterable of dict - (list(m.parameters()) + [42], TypeError), # iterable containing invalid type - (m.parameters(), None), # `params` as a generator - (list(m.parameters()), None) # `params` as a list + ([], ValueError), # empty parameter list + (torch.randn(1), TypeError), # non-iterable: `torch.Tensor` + (1.2, TypeError), # non-iterable: `float` + ( + [ + {"params": [l.weight for l in m]}, + {"params": [l.bias for l in m]}, + ], + None, + ), # iterable of dict + ( + list(m.parameters()) + [42], + TypeError, + ), # iterable containing invalid type + (m.parameters(), None), # `params` as a generator + (list(m.parameters()), None), # `params` as a list ] for ctor_input, error in ctor_inputs: context = self.assertRaises(error) if error else suppress() with context: ZeroRedundancyOptimizer( - ctor_input, optimizer_class=SGD, lr=LR, + ctor_input, + optimizer_class=SGD, + lr=LR, ) # Test constructing with multiple parameter groups more thoroughly @@ -297,18 +314,23 @@ def test_constructor(self): BETAS = (0.9, 0.999) EPS = 1e-8 params = [ - {"params": [l.weight for l in m], "weight_decay": 0.}, + {"params": [l.weight for l in m], "weight_decay": 0.0}, {"params": [l.bias for l in m], "weight_decay": WD}, ] o = ZeroRedundancyOptimizer( - params, optimizer_class=AdamW, - lr=LR, betas=BETAS, eps=EPS, + params, + optimizer_class=AdamW, + lr=LR, + betas=BETAS, + eps=EPS, ) - assert len(o.param_groups) == 2, \ - f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" - assert len(o.optim.param_groups) == 2, \ - "Expected 2 local optimizer param groups, but got " \ + assert ( + len(o.param_groups) == 2 + ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + assert len(o.optim.param_groups) == 2, ( + "Expected 2 local optimizer param groups, but got " f"{len(o.optim.param_groups)}" + ) def test_same_dense_param_type(self): """Check that ZeroRedundancyOptimizer raises an exception if the input @@ -322,8 +344,11 @@ def test_same_dense_param_type(self): inputs = [ [torch.sparse_coo_tensor(size=(2, 3))], [torch.FloatTensor(1), torch.DoubleTensor(1)], - [torch.FloatTensor(1), torch.FloatTensor(1), - torch.sparse_coo_tensor(size=(2, 3))] + [ + torch.FloatTensor(1), + torch.FloatTensor(1), + torch.sparse_coo_tensor(size=(2, 3)), + ], ] for input in inputs: with self.assertRaises(ValueError): @@ -333,8 +358,11 @@ def test_same_dense_param_type(self): class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): @property def device(self): - return torch.device(self.rank) if torch.cuda.is_available() \ + return ( + torch.device(self.rank) + if torch.cuda.is_available() else torch.device("cpu") + ) @property def world_size(self): @@ -342,8 +370,11 @@ def world_size(self): @property def context(self): - return suppress() if not torch.cuda.is_available() \ + return ( + suppress() + if not torch.cuda.is_available() else torch.cuda.device(self.rank) + ) def _check_same_model_params( self, @@ -354,13 +385,17 @@ def _check_same_model_params( # Check that model parameters match for p_a, p_b in zip(model_a.parameters(), model_b.parameters()): torch.testing.assert_close( - p_a, p_b, atol=1e-3, rtol=1e-5, + p_a, + p_b, + atol=1e-3, + rtol=1e-5, msg=f"Model parameters differ:\n{p_a} {p_b}\n" + message, ) # Check that model buffers match for b_a, b_b in zip(model_a.buffers(), model_b.buffers()): torch.testing.assert_close( - b_a, b_b, + b_a, + b_b, msg=f"Model buffers differ:\n{b_a} {b_b}\n" + message, ) @@ -382,7 +417,9 @@ def test_step(self): o = SGD(m.parameters(), lr=LR) o_zero = ZeroRedundancyOptimizer( - m_zero.parameters(), optimizer_class=SGD, lr=LR, + m_zero.parameters(), + optimizer_class=SGD, + lr=LR, ) y = m(x) @@ -530,11 +567,7 @@ def all_trainable(): # all partitions have the same elements self.assertEqual(len(o.param_groups), 2) self.assertEqual( - sum([ - x.numel() - for g in o.optim.param_groups - for x in g["params"] - ]), + sum([x.numel() for g in o.optim.param_groups for x in g["params"]]), sum(sizes), ) self.assertEqual(len(o.optim.param_groups), 2) @@ -581,36 +614,39 @@ def test_multiple_param_groups(self): model2 = model2.to(self.device) model3 = model3.to(self.device) inputs = [ - torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) - for _ in range(NUM_ITERS) + torch.randn(BATCH_SIZE, INPUT_DIM).to(self.device) for _ in range(NUM_ITERS) ] # Construct `optim1` with both parameter groups upfront optim1 = ZeroRedundancyOptimizer( [ - {"params": [l.weight for l in model1], "weight_decay": 0.}, + {"params": [l.weight for l in model1], "weight_decay": 0.0}, {"params": [l.bias for l in model1], "weight_decay": WD}, ], - optimizer_class=AdamW, lr=LR, + optimizer_class=AdamW, + lr=LR, ) # Construct `optim2` by adding the second parameter after optim2 = ZeroRedundancyOptimizer( [l.weight for l in model2], - optimizer_class=AdamW, lr=LR, weight_decay=0., - ) - optim2.add_param_group( - {"params": [l.bias for l in model2], "weight_decay": WD} + optimizer_class=AdamW, + lr=LR, + weight_decay=0.0, ) + optim2.add_param_group({"params": [l.bias for l in model2], "weight_decay": WD}) # Construct `optim3` as a non-sharded optimizer optim3 = AdamW( [ - {"params": [l.weight for l in model3], "weight_decay": 0.}, + {"params": [l.weight for l in model3], "weight_decay": 0.0}, {"params": [l.bias for l in model3], "weight_decay": WD}, - ], lr=LR, + ], + lr=LR, ) # Check parity over a few iterations for input in inputs: for model, optim in ( - (model1, optim1), (model2, optim2), (model3, optim3), + (model1, optim1), + (model2, optim2), + (model3, optim3), ): optim.zero_grad() out = model(input) @@ -695,8 +731,7 @@ def test_nondefault_process_group(self): self.dist_init(self.rank, self.world_size, BACKEND) # Use GPU if enough are available, or fall back to CPU otherwise, which # is fine since Gloo backend supports both - if torch.cuda.is_available() and \ - torch.cuda.device_count() >= self.world_size: + if torch.cuda.is_available() and torch.cuda.device_count() >= self.world_size: device = torch.device(self.rank) else: device = torch.device("cpu") @@ -704,7 +739,8 @@ def test_nondefault_process_group(self): # the case where the global and local ranks do not necessarily match subgroup_ranks = [r for r in range(self.world_size) if r % 2 == 0] process_group = dist.new_group( - ranks=subgroup_ranks, backend=BACKEND, + ranks=subgroup_ranks, + backend=BACKEND, ) # Ranks not participating in the new process group are no longer needed if self.rank not in subgroup_ranks: @@ -719,8 +755,9 @@ def test_nondefault_process_group(self): LR = 1e-3 MOMENTUM = 0.99 REFERENCE_RANK = 0 - assert REFERENCE_RANK in subgroup_ranks, \ - "Reference rank must be in the new process group" + assert ( + REFERENCE_RANK in subgroup_ranks + ), "Reference rank must be in the new process group" loss_fn = torch.nn.L1Loss().to(device) def check(optimizer): @@ -742,11 +779,15 @@ def closure(): # Check that the parameters match across ranks after a step for pg in optimizer.param_groups: for p in pg["params"]: - receptacle = [ - p.clone() for _ in subgroup_ranks - ] if self.rank == REFERENCE_RANK else [] + receptacle = ( + [p.clone() for _ in subgroup_ranks] + if self.rank == REFERENCE_RANK + else [] + ) dist.gather( - p, receptacle, dst=REFERENCE_RANK, + p, + receptacle, + dst=REFERENCE_RANK, group=process_group, ) if self.rank == REFERENCE_RANK: @@ -814,31 +855,41 @@ def test_local_optimizer_parity( torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM), ).to(self.device) model.register_buffer( - "test_buffer", torch.ones((1), device=self.device) * self.rank, + "test_buffer", + torch.ones((1), device=self.device) * self.rank, ) # Define models/optimizers for DDP with ZeRO and DDP with local # optimizer defaults = {"maximize": True} if maximize else {} sharded_optimizer = ZeroRedundancyOptimizer( - params=model.parameters(), optimizer_class=optimizer_class, - lr=LR, **defaults, + params=model.parameters(), + optimizer_class=optimizer_class, + lr=LR, + **defaults, ) sharded_ddp_model = DDP( - module=model, device_ids=[self.rank], - broadcast_buffers=True, find_unused_parameters=True, + module=model, + device_ids=[self.rank], + broadcast_buffers=True, + find_unused_parameters=True, ) local_model = copy.deepcopy(model).to(self.device) ddp_optimizer = optimizer_class( - local_model.parameters(), lr=LR, **defaults, + local_model.parameters(), + lr=LR, + **defaults, ) ddp_model = DDP( - local_model, device_ids=[self.rank], - broadcast_buffers=True, find_unused_parameters=True, + local_model, + device_ids=[self.rank], + broadcast_buffers=True, + find_unused_parameters=True, ) # Check that the model is properly synchronized between ranks # at construction time self._check_same_model_params( - sharded_ddp_model, ddp_model, + sharded_ddp_model, + ddp_model, "Models differ from the start", ) @@ -858,18 +909,21 @@ def closure_sharded(input_tensor=input_tensor): return sharded_loss loss_ddp = cast( - torch.Tensor, ddp_optimizer.step(closure=closure_ddp), + torch.Tensor, + ddp_optimizer.step(closure=closure_ddp), ) loss_sharded_optim = cast( torch.Tensor, sharded_optimizer.step(closure=closure_sharded), ) torch.testing.assert_close( - loss_ddp, loss_sharded_optim, + loss_ddp, + loss_sharded_optim, msg="Losses differ between local optimizer and ZeRO", ) self._check_same_model_params( - sharded_ddp_model, ddp_model, + sharded_ddp_model, + ddp_model, "Models differ after a step", ) @@ -889,11 +943,11 @@ def closure_sharded(input_tensor=input_tensor): ddp_state_dict = ddp_optimizer.state_dict() sharded_optimizer.consolidate_state_dict(to=REFERENCE_RANK) sharded_optim_state_dict = [ - sharded_optimizer.state_dict() - if self.rank == REFERENCE_RANK else {} + sharded_optimizer.state_dict() if self.rank == REFERENCE_RANK else {} ] dist.broadcast_object_list( - sharded_optim_state_dict, src=REFERENCE_RANK, + sharded_optim_state_dict, + src=REFERENCE_RANK, group=dist.group.WORLD, ) sharded_optim_state_dict = sharded_optim_state_dict[0] @@ -941,14 +995,14 @@ def _test_zero_join(self, device): zero_model = copy.deepcopy(model) zero_model.to(device) zero_optim = ZeroRedundancyOptimizer( - zero_model.parameters(), torch.optim.Adam, lr=LR, + zero_model.parameters(), + torch.optim.Adam, + lr=LR, ) loss_fn = torch.nn.MSELoss() # Use uneven inputs: rank i has i extra inputs - inputs = [ - torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank) - ] + inputs = [torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)] labels = torch.randn(20, 3).to(device) # Save the gradients and parameters from DDP as the ground truth; do @@ -976,7 +1030,9 @@ def _test_zero_join(self, device): # ranks (which joined early) grads_and_params = [grads_at_each_iter, params_at_each_iter] grads_and_params = _broadcast_object( - grads_and_params, src_rank=world_size - 1, group=dist.group.WORLD, + grads_and_params, + src_rank=world_size - 1, + group=dist.group.WORLD, device=device, ) grads_at_each_iter = grads_and_params[0] @@ -987,7 +1043,7 @@ def _test_zero_join(self, device): # A process must still set the remaining gradients after joining, so we # define a join hook to do this before the ZeRO join hook - class _JoinGradInfo(): + class _JoinGradInfo: def __init__(self, grads): self.grads = grads # remaining gradients to set (in order) self.index = 0 @@ -1029,7 +1085,9 @@ def join_process_group(self): gradient_setter = _GradientSetter() iter = 0 with Join( - [gradient_setter, zero_optim], zero_optim=zero_optim, grads=grads, + [gradient_setter, zero_optim], + zero_optim=zero_optim, + grads=grads, ): for _ in range(NUM_EPOCHS): for input in inputs: @@ -1037,16 +1095,19 @@ def join_process_group(self): Join.notify_join_context(gradient_setter) # Set gradients manually for p, grad in zip( - zero_model.parameters(), grads_at_each_iter[iter], + zero_model.parameters(), + grads_at_each_iter[iter], ): p.grad = grad.detach().clone().to(device) # Perform optimizer step and check parity zero_optim.step() for p, ddp_p in zip( - zero_model.parameters(), params_at_each_iter[iter], + zero_model.parameters(), + params_at_each_iter[iter], ): torch.testing.assert_close( - p, ddp_p, + p, + ddp_p, msg="Parameters differ between using ZeRO and " "local optimizer", ) @@ -1127,6 +1188,7 @@ def copy_param(p): for _ in range(NUM_EPOCHS): for input in inputs: + def closure_local(): local_optim.zero_grad() local_loss = local_model(input).abs().sum() @@ -1139,25 +1201,26 @@ def closure_ddp(): ddp_loss.backward() return ddp_loss - local_loss = cast( - torch.Tensor, local_optim.step(closure=closure_local) - ) - ddp_loss = cast( - torch.Tensor, zero_optim.step(closure=closure_ddp) - ) + local_loss = cast(torch.Tensor, local_optim.step(closure=closure_local)) + ddp_loss = cast(torch.Tensor, zero_optim.step(closure=closure_ddp)) # Increased tolerances are needed to pass when using TF32 # See: https://github.com/pytorch/pytorch/issues/67764 torch.testing.assert_close( - local_loss.cpu(), ddp_loss.cpu(), rtol=1e-03, atol=1e-08, + local_loss.cpu(), + ddp_loss.cpu(), + rtol=1e-03, + atol=1e-08, ), "Losses differ between local optimizer and ZeRO" for local_p, ddp_p in zip( - local_model.parameters(), - ddp_model.parameters() + local_model.parameters(), ddp_model.parameters() ): torch.testing.assert_close( - local_p.cpu(), ddp_p.cpu(), rtol=1e-03, atol=1e-04, + local_p.cpu(), + ddp_p.cpu(), + rtol=1e-03, + atol=1e-04, ), "Models differ after a step" @common_distributed.skip_if_lt_x_gpu(4) @@ -1176,9 +1239,13 @@ def test_zero_model_parallel( # Disable DDP + ReplicatedTensor when `parameter_as_bucket_view=True` # since then ZeroRedundancyOptimizer modifies the model parameters in # place. - from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor - context = _ddp_replicated_tensor(False) if parameters_as_bucket_view \ - else suppress() + from torch.nn.parallel._replicated_tensor_ddp_utils import ( + _ddp_replicated_tensor, + ) + + context = ( + _ddp_replicated_tensor(False) if parameters_as_bucket_view else suppress() + ) with context: self.dist_init(self.rank, world_size=2) self._test_zero_model_parallel(parameters_as_bucket_view) @@ -1202,21 +1269,22 @@ def _test_ddp_zero_overlap( is_gpu = device.type == "cuda" if is_gpu: torch.cuda.set_device(device) - models_to_test = [( - torch.nn.Sequential( - torch.nn.Linear(1000, 2000), - torch.nn.Linear(2000, 500), - ), - [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)], - )] + models_to_test = [ + ( + torch.nn.Sequential( + torch.nn.Linear(1000, 2000), + torch.nn.Linear(2000, 500), + ), + [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)], + ) + ] if HAS_TORCHVISION: - models_to_test.append(( - torchvision.models.resnet50(), - [ - torch.randn(1, 3, 3, 1000).to(device) - for _ in range(NUM_INPUTS) - ] - )) + models_to_test.append( + ( + torchvision.models.resnet50(), + [torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)], + ) + ) for (model, inputs) in models_to_test: # Enable determinism in cudnn operators with torch.backends.cudnn.flags( @@ -1227,7 +1295,7 @@ def _test_ddp_zero_overlap( ddp_model_overlap = DDP( copy.deepcopy(model).to(device), device_ids=device_ids, - gradient_as_bucket_view=gradient_as_bucket_view + gradient_as_bucket_view=gradient_as_bucket_view, ) if static_graph: ddp_model_overlap._set_static_graph() @@ -1242,16 +1310,18 @@ def _test_ddp_zero_overlap( ddp_model_overlap.register_comm_hook( None, hook_constructor( - allreduce_hook, ddp_model_overlap, zero_optim, + allreduce_hook, + ddp_model_overlap, + zero_optim, **kwargs, - ) + ), ) # Set up the DDP model with local optimizer ddp_model_local = DDP( copy.deepcopy(model).to(device), device_ids=device_ids, - gradient_as_bucket_view=gradient_as_bucket_view + gradient_as_bucket_view=gradient_as_bucket_view, ) if static_graph: ddp_model_local._set_static_graph() @@ -1259,13 +1329,12 @@ def _test_ddp_zero_overlap( ddp_model_local.parameters(), lr=SGD_LR, momentum=SGD_MOMENTUM, - weight_decay=SGD_WEIGHT_DECAY + weight_decay=SGD_WEIGHT_DECAY, ) # Check that the parameters match initially for p1, p2 in zip( - ddp_model_overlap.parameters(), - ddp_model_local.parameters() + ddp_model_overlap.parameters(), ddp_model_local.parameters() ): self.assertEqual(p1, p2) @@ -1303,14 +1372,14 @@ def _test_ddp_zero_overlap( # Check that the parameters are equal for p1, p2 in zip( - ddp_model_overlap.parameters(), - ddp_model_local.parameters() + ddp_model_overlap.parameters(), ddp_model_local.parameters() ): self.assertEqual(p1, p2) # Check that the parameters were updated self.assertNotEqual( - init_params_overlap, list(ddp_model_overlap.parameters()), + init_params_overlap, + list(ddp_model_overlap.parameters()), ) # Ensure that this test runs independently @@ -1360,15 +1429,24 @@ def test_ddp_zero_overlap( device = torch.device(self.rank) if use_gpu else torch.device("cpu") backend = _get_backend_for_tests() self.dist_init(self.rank, self.world_size, backend) - hook_constructor = hook_with_zero_step if not use_interleaved_hook \ + hook_constructor = ( + hook_with_zero_step + if not use_interleaved_hook else hook_with_zero_step_interleaved + ) # Disable DDP + ReplicatedTensor since ZeroRedundancyOptimizer # modifies the model parameters in place. - from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor + from torch.nn.parallel._replicated_tensor_ddp_utils import ( + _ddp_replicated_tensor, + ) + with _ddp_replicated_tensor(False): self._test_ddp_zero_overlap( - device, hook_constructor, gradient_as_bucket_view, static_graph, + device, + hook_constructor, + gradient_as_bucket_view, + static_graph, shard_buckets=shard_buckets, ) diff --git a/test/distributed/tensor/parallel/__init__.py b/test/distributed/tensor/parallel/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/distributed/tensor/parallel/test_2d_parallel.py b/test/distributed/tensor/parallel/test_2d_parallel.py new file mode 100644 index 0000000000000..e71be70ae9ab8 --- /dev/null +++ b/test/distributed/tensor/parallel/test_2d_parallel.py @@ -0,0 +1,217 @@ +# Owner(s): ["oncall: distributed"] + +from typing import Any + +import torch +import torch.distributed as dist + +import torch.distributed.distributed_c10d as distributed_c10d +import torch.nn.functional as F +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._tensor import DeviceMesh, DTensor as DT, Replicate +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module +from torch.distributed.tensor.parallel.fsdp import is_available +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu + +from torch.testing._internal.common_utils import run_tests + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + +# Tensor-Parallel degree +TP_DEGREE = 2 +LR = 3e-5 + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super(SimpleModel, self).__init__() + self.net1 = torch.nn.Linear(5, 8) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(8, 4) + self.net3 = torch.nn.Linear(4, 12) + + def forward(self, x): + x = F.relu(self.net1(x)) + x = F.relu(self.net2(x)) + x = F.relu(self.net3(x)) + return x + + +def _distribute_and_fsdp_wrap_module( + module, module_shard, mesh_2d, fsdp_pg, use_orig_params, fsdp_nested +): + if module_shard: + module = parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1) + pg = fsdp_pg if module_shard else distributed_c10d._get_default_group() + + if fsdp_nested: + module.net1 = FSDP( + module.net1, process_group=pg, use_orig_params=use_orig_params + ) + module.net2 = FSDP( + module.net2, process_group=pg, use_orig_params=use_orig_params + ) + return FSDP(module, process_group=pg, use_orig_params=use_orig_params) + + +def init_model(model_parallel_size=TP_DEGREE, use_orig_params=False, fsdp_nested=False): + rank = dist.get_rank() + torch.cuda.set_device(rank) + world_size = dist.get_world_size() + + model = SimpleModel().cuda(rank) + + # 2-D mesh is [dp, tp] + twod_mesh = DeviceMesh( + device_type="cuda", + mesh=torch.arange(0, world_size).view(model_parallel_size, -1), + ) + + fsdp_pg = twod_mesh.get_dim_groups()[0] + + # Create Input + model = _distribute_and_fsdp_wrap_module( + model, True, twod_mesh, fsdp_pg, use_orig_params, fsdp_nested + ) + return model, fsdp_pg + + +def is_nested_tensor(val: Any) -> bool: + if isinstance(val, ShardedTensor): + if len(val.local_shards()) == 0: + return False + if isinstance(val.local_shards()[0].tensor, ShardedTensor): + return True + if isinstance(val.local_shards()[0].tensor, DT): + raise ValueError("Cannot handle DT nested insided ST") + # Safety valve for when this eventually happen + elif isinstance(val, DT) and isinstance(val._local_tensor, (DT, ShardedTensor)): + raise ValueError("Cannot handle nested DT") + return False + + +class Test2dParallelIntegration(DTensorTestBase): + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_functionality(self) -> None: + if not is_available(): + self.skipTest("FSDP 2d parallel integration not available") + + model_tp = init_model()[0] + + with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + state_dict = model_tp.state_dict() + # TODO once 2D is out, validate the nesting + self.assertTrue(is_nested_tensor(state_dict["net1.weight"])) + self.assertFalse(is_nested_tensor(state_dict["net3.bias"])) + + optim = torch.optim.Adam(model_tp.parameters(), lr=0.0001) + + # Create Input + input_seed = self.rank + torch.manual_seed(input_seed + 1) + input = torch.rand(4, 5).cuda(self.rank) + + model_tp(input).sum().backward() + optim.step() + + optim_state = FSDP.sharded_optim_state_dict(model_tp, optim) + # TODO once 2D is out, validate the nesting + self.assertTrue( + is_nested_tensor(optim_state["state"]["net1.weight"]["exp_avg"]) + ) + self.assertFalse(is_nested_tensor(optim_state["state"]["net3.bias"]["exp_avg"])) + + def _compare_params(self, m1, m2): + with FSDP.summon_full_params(m1): + with FSDP.summon_full_params(m2): + for n_p1, n_p2 in zip(m1.named_parameters(), m2.named_parameters()): + p1 = n_p1[1] + p2 = n_p2[1] + self.assertEqual(n_p1[0], n_p2[0]) + name = n_p1[0] + if name == "net2.bias" and self.rank != 0: + continue + if type(p2) is DT: + p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local() + self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}") + + def _test_2d_e2e_flow( + self, use_orig_params=False, fsdp_nested=False, multi_param_group=False + ) -> None: + if not is_available(): + self.skipTest("FSDP 2d parallel integration not available") + torch.manual_seed(0) + model = SimpleModel().cuda(self.rank) + model = FSDP(model, use_orig_params=use_orig_params) + torch.manual_seed(0) + model_2d, dp_pg = init_model( + use_orig_params=use_orig_params, fsdp_nested=fsdp_nested + ) + # Check named parameters are returning the same name at least. + param_names_2d = [name for name, _ in model_2d.named_parameters()] + for name, _ in model.named_parameters(): + self.assertTrue(name in param_names_2d) + self._compare_params(model, model_2d) + + if multi_param_group and use_orig_params: + param_group = [ + {"params": model.net1.parameters(), "lr": 0.02}, + {"params": model.net2.parameters(), "lr": 0.15}, + ] + optim = torch.optim.Adam(param_group, lr=0.01) + param_group = [ + {"params": model_2d.net1.parameters(), "lr": 0.02}, + {"params": model_2d.net2.parameters(), "lr": 0.15}, + ] + optim_2d = torch.optim.Adam(param_group, lr=0.01) + else: + optim = torch.optim.Adam(model.parameters(), lr=0.01) + optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01) + + for i in range(5): + # Ensure all input across TP ranks are same. + torch.manual_seed(i + dist.get_rank(dp_pg)) + input = torch.rand(4, 5).cuda(self.rank) + output = model(input) + output_2d = model_2d(input) + self.assertEqual(output, output_2d) + output.sum().backward() + output_2d.sum().backward() + optim.step() + optim_2d.step() + self.assertEqual(model(input), model_2d(input)) + + # Ensure all params are still the same after optimizer update. + self._compare_params(model, model_2d) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_correctness(self) -> None: + self._test_2d_e2e_flow() + + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_use_orig_params(self) -> None: + self._test_2d_e2e_flow(use_orig_params=True) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_fsdp_nested(self) -> None: + self._test_2d_e2e_flow(fsdp_nested=True) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_fsdp_nested_param_groups(self) -> None: + self._test_2d_e2e_flow( + fsdp_nested=True, use_orig_params=True, multi_param_group=True + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py new file mode 100644 index 0000000000000..7375de3ef1814 --- /dev/null +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -0,0 +1,201 @@ +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate +from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh +from torch.distributed.tensor.parallel.api import _parallelize_linear, _parallelize_mlp +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + make_input_replicate_1d, + make_output_replicate_1d, + PairwiseParallel, + ParallelStyle, + RowwiseParallel, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class MLPModule(torch.nn.Module): + def __init__(self, device): + super(MLPModule, self).__init__() + torch.manual_seed(5) + self.net1 = torch.nn.Linear(10, 16, device=device) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(16, 12, device=device) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +class TensorParallelAPITests(DTensorTestBase): + @property + def world_size(self): + gpu_num = torch.cuda.device_count() + return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 + + @with_comms + def test_creat_1d_device_mesh(self): + dim_one_size = 2 + mesh_shape = ( + torch.arange(self.world_size) + .reshape( + self.world_size // dim_one_size, + dim_one_size, + ) + .to(torch.int) + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + # When 1D dim is 1. + one_dimention_mesh_shape = mesh_shape[self.rank // dim_one_size, :] + pg = mesh.get_dim_groups()[1] + new_mesh = _create_1d_device_mesh(mesh, 1) + expected_mesh = DeviceMesh(self.device_type, one_dimention_mesh_shape, [pg]) + self.assertEqual(new_mesh.mesh, expected_mesh.mesh) + self.assertEqual(new_mesh.device_type, expected_mesh.device_type) + # When 1D dim is 0. + one_dimention_mesh_shape = mesh_shape[:, self.rank % dim_one_size] + pg = mesh.get_dim_groups()[0] + new_mesh = _create_1d_device_mesh(mesh, 0) + expected_mesh = DeviceMesh(self.device_type, one_dimention_mesh_shape, [pg]) + self.assertEqual(new_mesh.mesh, expected_mesh.mesh) + self.assertEqual(new_mesh.device_type, expected_mesh.device_type) + + @with_comms + def test_creat_1d_device_mesh_error(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + with self.assertRaisesRegex( + AssertionError, + "Expect tp_mesh_dim within range \\[-1, 1\\), but found 3.", + ): + _create_1d_device_mesh(mesh, 3) + + def _compare_params( + self, + local_module, + dist_module, + skip_rowwise_bias=False, + compare_grad=False, + ): + replicate = [Replicate()] + for name, param in local_module.named_parameters(): + dist_param = dist_module.get_parameter(name) + param = param.grad if compare_grad else param + dist_param = dist_param.grad if compare_grad else dist_param + if self.rank == 0 or ( + name not in ["net2.bias"] + and not skip_rowwise_bias + or name not in ["bias", "net2.bias"] + ): + self.assertEqual( + param, + dist_param.redistribute( + device_mesh=dist_param.device_mesh, placements=replicate + ).to_local(), + ) + + def _compare_module(self, local_module, dist_module, inp_size, rowwise=False): + LR = 0.25 # the learning rate we use for testing + local_optim = torch.optim.SGD(local_module.parameters(), lr=LR) + dist_optim = torch.optim.SGD(dist_module.parameters(), lr=LR) + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + self._compare_params(local_module, dist_module) + + # check forward correctness + local_output = local_module(inp) + inp = inp.chunk(self.world_size, dim=-1)[self.rank] if rowwise else inp + dist_output = dist_module(inp) + dist_output = ( + dist_output.to_local() if isinstance(dist_output, DTensor) else dist_output + ) + self.assertEqual(local_output, dist_output) + + local_output.sum().backward() + dist_output.sum().backward() + + # check backward and ensure gradients are same + self._compare_params(local_module, dist_module, rowwise, True) + + local_optim.step() + dist_optim.step() + self._compare_params(local_module, dist_module, rowwise) + + @with_comms + def test_parallelize_mlp(self): + inp_size = [12, 10] + model = MLPModule(self.device_type) + model_tp = MLPModule(self.device_type) + + # Ensure model are initialized the same way. + self.assertEqual(model.net1.weight, model_tp.net1.weight) + self.assertEqual(model.net1.bias, model_tp.net1.bias) + self.assertEqual(model.net2.weight, model_tp.net2.weight) + self.assertEqual(model.net2.bias, model_tp.net2.bias) + + # Parallelize module. + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) + self._compare_module(model, model_tp, inp_size) + + @with_comms + def test_parallelize_mlp_error(self): + class DummyParallel(ParallelStyle): + def __init__(self) -> None: + super().__init__(make_input_replicate_1d, make_output_replicate_1d) + + model_tp = MLPModule(self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + with self.assertRaisesRegex( + NotImplementedError, + "Only support PairwiseParallel for MLP parallelization.", + ): + _parallelize_mlp(model_tp, device_mesh, DummyParallel()) + + with self.assertRaisesRegex( + RuntimeError, "More than one nn.Linear needed for a MLP." + ): + _parallelize_mlp(torch.nn.Linear(10, 5), device_mesh, PairwiseParallel()) + + @with_comms + def test_linear_row_wise_parallel(self): + # test RowwiseParallel + inp_size = [9, 16] + rowwise = RowwiseParallel() + + torch.manual_seed(5) + model = torch.nn.Linear(16, 10, device=self.device_type) + torch.manual_seed(5) + model_tp = torch.nn.Linear(16, 10, device=self.device_type) + + # parallelize model_tp + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + model_tp = _parallelize_linear(model_tp, device_mesh, rowwise) + + # let each rank generate unique local input + torch.manual_seed(self.rank) + self._compare_module(model, model_tp, inp_size, True) + + @with_comms + def test_linear_col_wise_parallel(self): + # test ColwiseParallel + inp_size = [8, 10] + colwise = ColwiseParallel() + + torch.manual_seed(5) + model = torch.nn.Linear(10, 16, device=self.device_type) + torch.manual_seed(5) + model_tp = torch.nn.Linear(10, 16, device=self.device_type) + + # parallelize model_tp + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + model_tp = _parallelize_linear(model_tp, device_mesh, colwise) + + self._compare_module(model, model_tp, inp_size) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py new file mode 100644 index 0000000000000..12ee9b0b651c2 --- /dev/null +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -0,0 +1,428 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +import torch.nn as nn +from torch.distributed._tensor import DeviceMesh, Replicate +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, + TensorParallelMultiheadAttention, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + NUM_DEVICES, + skip_unless_torch_gpu, + with_comms, +) + + +class MLPModule(torch.nn.Module): + def __init__(self, device): + super(MLPModule, self).__init__() + torch.manual_seed(5) + self.net1 = torch.nn.Linear(10, 16, device=device) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(16, 12, device=device) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +class MultiheadAttnWrap(nn.Module): + def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None): + super().__init__() + self.attn = nn.MultiheadAttention( + embed_dim, num_heads, add_bias_kv=add_bias_kv, device=device + ) + + def forward(self, query, key, value): + return self.attn(query, key, value) + + +# TODO: replace repeated test code with _check_module +class DistTensorParallelExampleTest(DTensorTestBase): + @with_comms + def test_mlp_megatron_e2e(self): + inp_size = [5, 10] + # Ensure all tp ranks have same input. + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + model = MLPModule(self.device_type) + model_tp = MLPModule(self.device_type) + + # Ensure model are initialized the same way. + self.assertEqual(model.net1.weight, model_tp.net1.weight) + self.assertEqual(model.net1.bias, model_tp.net1.bias) + self.assertEqual(model.net2.weight, model_tp.net2.weight) + self.assertEqual(model.net2.bias, model_tp.net2.bias) + + # Shard module and initialize optimizer. + LR = 0.25 + device_mesh = DeviceMesh( + self.device_type, + torch.arange(0, NUM_DEVICES), + ) + model_tp = parallelize_module(model_tp, device_mesh, PairwiseParallel()) + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp) + output_tp = model_tp(inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.net1.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + + # Ensure gradients are same. + self.assertEqual( + model.net1.weight.grad, + model_tp.net1.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net1.bias.grad, + model_tp.net1.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.weight.grad, + model_tp.net2.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.bias.grad, + model_tp.net2.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. + self.assertEqual( + model.net1.weight, + model_tp.net1.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net1.bias, + model_tp.net1.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.weight, + model_tp.net2.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + # Due to the trick we use for Partial aggregation, we only check the weight when local_rank = 0. + if self.rank == 0: + self.assertEqual( + model.net2.bias, + model_tp.net2.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp) + output_tp = model_tp(inp) + self.assertEqual(output, output_tp) + + # TensorParallelMultiheadAttention == dist_module(TensorParallelMultiheadAttention) + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_self_attn_megatron_e2e(self): + inp_size = [8, 12, 16] + # Ensure all tp ranks have same input. + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + + # Initialize model using same seed. + torch.manual_seed(5) + model = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + torch.manual_seed(5) + model_tp = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + + # Ensure model are initialized the same way. + self.assertEqual(model.qkv.weight, model_tp.qkv.weight) + self.assertEqual(model.qkv.bias, model_tp.qkv.bias) + self.assertEqual(model.proj.weight, model_tp.proj.weight) + self.assertEqual(model.proj.bias, model_tp.proj.bias) + + # Shard module and initialize optimizer. + device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) + parallelize_module(model_tp, device_mesh, PairwiseParallel()) + + device_mesh = model_tp.qkv.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + # Ensure model are initialized the same way. + self.assertEqual( + model.qkv.weight, + model_tp.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + LR = 0.25 + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.qkv.weight.device_mesh + # Ensure gradients are same. + self.assertEqual( + model.qkv.weight.grad, + model_tp.qkv.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias.grad, + model_tp.qkv.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight.grad, + model_tp.proj.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias.grad, + model_tp.proj.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. + self.assertEqual( + model.qkv.weight, + model_tp.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + # TensorParallelMultiheadAttention == dist_module(torch.nn.MultiheadAttention) + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_self_attn_replacement_megatron_e2e(self): + inp_size = [8, 12, 16] + # Ensure all tp ranks have same input. + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + + # TODO: our sharding function cannot shard the root node + torch.manual_seed(5) + model = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + model_tp = MultiheadAttnWrap(16, 8, add_bias_kv=True, device=self.device_type) + + # TODO: somehow using torch.nn.MultiheadAttention's initial params does not work + # Use TensorParallelMultiheadAttention parameters instead + x = model.qkv.weight.clone().detach().requires_grad_() + model_tp.attn.register_parameter("in_proj_weight", torch.nn.Parameter(x)) + + x = model.qkv.bias.clone().detach().requires_grad_() + model_tp.attn.register_parameter("in_proj_bias", torch.nn.Parameter(x)) + + x = model.proj.weight.clone().detach().requires_grad_() + model_tp.attn.out_proj.register_parameter("weight", torch.nn.Parameter(x)) + + x = model.proj.bias.clone().detach().requires_grad_() + model_tp.attn.out_proj.register_parameter("bias", torch.nn.Parameter(x)) + + # check if parameters are same + self.assertEqual(model.qkv.weight, model_tp.attn.in_proj_weight) + self.assertEqual(model.qkv.bias, model_tp.attn.in_proj_bias) + self.assertEqual(model.proj.weight, model_tp.attn.out_proj.weight) + self.assertEqual(model.proj.bias, model_tp.attn.out_proj.bias) + + # Shard module and initialize optimizer. + device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) + parallelize_module(model_tp, device_mesh, PairwiseParallel()) + + device_mesh = model_tp.attn.qkv.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + # Ensure model are initialized the same way. + self.assertEqual( + model.qkv.weight, + model_tp.attn.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.attn.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.attn.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.attn.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + LR = 0.25 + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.attn.qkv.weight.device_mesh + # Ensure gradients are same. + self.assertEqual( + model.qkv.weight.grad, + model_tp.attn.qkv.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias.grad, + model_tp.attn.qkv.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight.grad, + model_tp.attn.proj.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias.grad, + model_tp.attn.proj.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. + self.assertEqual( + model.qkv.weight, + model_tp.attn.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.attn.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.attn.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.attn.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py new file mode 100644 index 0000000000000..7aeb086f03a4c --- /dev/null +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + make_input_replicate_1d, + make_input_shard_1d, + make_output_replicate_1d, + make_output_shard_1d, + make_output_tensor, + RowwiseParallel, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class TensorParallelStyleTest(DTensorTestBase): + @property + def world_size(self): + gpu_num = torch.cuda.device_count() + return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 + + def _1d_input_func_check( + self, input_local_tensor, expected_local_tensor, func + ) -> None: + with self.assertRaisesRegex( + RuntimeError, "device_mesh is not passed nor can be inferred" + ): + dtensor = func(input_local_tensor) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + with self.assertRaisesRegex( + RuntimeError, + "device_mesh has dims [0-9]+ but expcted to be 1 for input.", + ): + dtensor = func(input_local_tensor, device_mesh) + + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # test 1: replicate local tensor + dtensor = func(input_local_tensor, device_mesh) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + # test 2: replicate DTensor + dtensor = func(dtensor) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + # test 3: replicate DTensor with DeviceMesh passed + dtensor = func(dtensor, device_mesh) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + + @with_comms + def test_make_input_replicate_1d(self): + tensor = torch.rand(8, 16, device=self.device_type) + self._1d_input_func_check(tensor, tensor, make_input_replicate_1d) + + @with_comms + def test_make_input_shard_1d(self): + tensor = torch.rand(8, 16, device=self.device_type) + self._1d_input_func_check(tensor, tensor, make_input_shard_1d) + + # Common logic for testing prepare output funcs + def _test_prepare_output(self, func, spec, dim=None, device_mesh_input_none=False): + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + tensor = torch.rand(8, 16, device=self.device_type) + dtensor = distribute_tensor(tensor, device_mesh, spec) + device_mesh_input = None if device_mesh_input_none else device_mesh + if dim is not None: + output = func(dtensor, device_mesh_input, dim) + else: + output = func(dtensor, device_mesh_input) + return output, dtensor, device_mesh + + @with_comms + def test_make_output_shard_1d(self): + # test when output is sharded. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Shard(0)], 1 + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(1)])) + # test when output is replicated. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Replicate()], 0 + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(0)])) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Shard(0)], 1, True + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(1)])) + + @with_comms + def test_make_output_replicate_1d(self): + output, dtensor, device_mesh = self._test_prepare_output( + make_output_replicate_1d, [Shard(0)] + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Replicate()])) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_replicate_1d, [Shard(0)], None, True + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Replicate()])) + + @with_comms + def test_make_output_tensor(self): + # test when output is sharded. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Shard(0)] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + # test when output is replicated. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Replicate()] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Shard(0)], None, True + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + + # Common logic for testing prepare output funcs errors. + def _test_prepare_output_error(self, func): + tensor = torch.rand(8, 16, device=self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) + output = [dtensor] + with self.assertRaisesRegex( + AssertionError, + "Expect output of Tensor Parallel to be a DTensor, but found" + f" {type(output)}.", + ): + func(output, device_mesh) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + with self.assertRaisesRegex( + AssertionError, + "device_mesh has dims 2 but expcted to be 1 for output.", + ): + func(dtensor, device_mesh) + + @with_comms + def test_prepare_output_error(self): + self._test_prepare_output_error(make_output_shard_1d) + self._test_prepare_output_error(make_output_replicate_1d) + self._test_prepare_output_error(make_output_tensor) + + @with_comms + def test_rowwise_parallel_style(self): + tensor = torch.rand(8, 16, device=self.device_type) + rs = RowwiseParallel() + self._1d_input_func_check(tensor, tensor, rs._prepare_input) + # TODO: change output test + output, dtensor, device_mesh = self._test_prepare_output( + rs._prepare_input, [Shard(0)] + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Replicate()])) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + rs._prepare_input, [Shard(0)], None, True + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Replicate()])) + self._test_prepare_output_error(rs._prepare_output) + + @with_comms + def test_colwise_parallel_style(self): + tensor = torch.rand(8, 16, device=self.device_type) + cs = ColwiseParallel() + self._1d_input_func_check(tensor, tensor, cs._prepare_input) + self.assertEqual(make_output_replicate_1d, cs._prepare_output) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/tensor/parallel/test_view_sharding_dim_change.py b/test/distributed/tensor/parallel/test_view_sharding_dim_change.py new file mode 100644 index 0000000000000..4c1475ef5dba5 --- /dev/null +++ b/test/distributed/tensor/parallel/test_view_sharding_dim_change.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.distributed._tensor import DeviceMesh, DTensor, Shard +from torch.distributed.tensor.parallel._view_with_dim_change import ( + _view_with_sharding_dim_change, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + + +class TPViewShardingDimChangeTest(DTensorTestBase): + @with_comms + def test_view_with_sharding_dim_change(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(2)] + dt = DTensor.from_local(tensor, device_mesh, sharding) + dt = _view_with_sharding_dim_change(dt, 1, (3, -1, 6)) + self.assertTrue(dt.placements[0].is_shard(dim=1)) + self.assertEqual(dt.to_local(), tensor.view(3, -1, 6)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 68c760beacbbf..962c12dcba9d7 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2,6 +2,7 @@ import copy import os +import pickle import sys import tempfile import threading @@ -1427,6 +1428,9 @@ def test_send_recv(self): dist.send(input_tensor, (self.rank + 1) % self.world_size) self.assertEqual(input_tensor, torch.zeros(2, 2) + 1) + with self.assertRaises(ValueError): + dist.send(input_tensor, dist.get_rank()) + # test recv input_tensor = torch.zeros(2, 2) dist.recv(input_tensor, (self.rank + 1) % self.world_size) @@ -1464,8 +1468,19 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args): # ensure supported devices (cpu, cuda) succeeds during dispatch call tensor = torch.zeros(2, 2, device=torch.device(device)) # multi tensor collectives - if collective == dist.all_gather: + if collective == dist.barrier: + collective() + elif collective in (dist.all_gather, dist.gather): collective([tensor], tensor, *args) + elif collective == dist.scatter: + collective(tensor, [tensor], *args) + elif collective in (dist.reduce_scatter, dist.all_to_all): + # gloo does not support reduce_scatter or all_to_all + if backend != "gloo": + if collective == dist.reduce_scatter: + collective(tensor, [tensor], *args) + else: + collective([tensor], [tensor], *args) else: collective(tensor, *args) @@ -1482,12 +1497,45 @@ def _test_collectives(self, backend): (dist.reduce, self.rank), (dist.broadcast, self.rank), (dist.all_reduce,), - (dist.all_gather,) + (dist.all_gather,), + (dist.reduce_scatter,), + (dist.barrier,), + (dist.all_to_all,), + (dist.scatter,), ] for collective, *args in collectives_and_args: with self.subTest(collective=collective, args=args): self._call_collective_with_varying_tensors(backend, collective, *args) + def _test_allreduce_coalesced(self, backend): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) + # TODO: this will be updated in the future to not be backend specific + device = "cuda" if backend == "nccl" else "cpu" + tensors = [torch.ones(10, 10, device=torch.device(device))] + dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM) + for tensor in tensors: + self.assertEqual(tensor, torch.ones(10, 10) * self.world_size) + + def _test_all_to_all_single(self, backend): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" if backend == "nccl" else "cpu" + # test alltoall_base + input_tensor = torch.ones(2, 2, device=torch.device(device)) + output_tensor = torch.zeros(2, 2, device=torch.device(device)) + dist.all_to_all_single(output_tensor, input_tensor) + class CompilerTest(MultiProcessTestCase): def setUp(self): super(CompilerTest, self).setUp() @@ -1622,6 +1670,69 @@ def comm_fn(tensor, group=None): self._test_work_wait(tensor, comm_fn=comm_fn) +class ReduceOpTest(TestCase): + + # Ref: https://github.com/pytorch/pytorch/issues/87191 + def test_op_isinstance_of_reduceop(self): + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + self.assertTrue(isinstance(reduce_op, c10d.ReduceOp)) + for scale in (torch.tensor(1.0), 2.0): + self.assertTrue(isinstance(dist._make_nccl_premul_sum(scale), c10d.ReduceOp)) + + # Ref: https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700 + def test_reduceop_copyable(self): + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + self.assertEqual(copy.copy(reduce_op), reduce_op) + self.assertEqual(copy.deepcopy(reduce_op), reduce_op) + self.assertEqual(copy.copy(c10d.ReduceOp(reduce_op)), reduce_op) + self.assertEqual(copy.deepcopy(c10d.ReduceOp(reduce_op)), reduce_op) + + for scale in (torch.tensor(1.0), 2.0): + reduce_op = dist._make_nccl_premul_sum(scale) + self.assertEqual(copy.copy(reduce_op), reduce_op) + self.assertEqual(copy.deepcopy(reduce_op), reduce_op) + + def test_reduceop_pickle(self): + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + pickle.loads(pickle.dumps(reduce_op)) + orig = c10d.ReduceOp(reduce_op) + self.assertEqual(pickle.loads(pickle.dumps(orig)), orig) + for scale in (torch.tensor(1.0), 2.0): + reduce_op = dist._make_nccl_premul_sum(scale) + self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op) + + # Ref: https://github.com/pytorch/pytorch/issues/90072 + def test_reduceop_equal(self): + not_reduceop = "abc" + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + reduce_op_obj = c10d.ReduceOp(reduce_op) + # this calls `ReduceOp.__eq__(self, other)` + self.assertEqual(reduce_op_obj, reduce_op_obj) + self.assertEqual(reduce_op_obj, reduce_op) + self.assertNotEqual(reduce_op_obj, not_reduceop) + self.assertNotEqual(reduce_op, not_reduceop) + # TODO(crcrpar): This needs to be `assertEqual` for the associativity even though + # the comparison of `RedOpType` and `ReduceOp` sounds less likely to happen compared + # to that of `ReduceOp` and `RedOptype`. + # this calls `RedOpType.__eq__(self, other)` + self.assertNotEqual(reduce_op, reduce_op_obj) + + self.assertFalse(None in (reduce_op, reduce_op_obj)) + self.assertFalse(not_reduceop in (reduce_op, reduce_op_obj)) + + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/test/distributed/test_c10d_error_logger.py b/test/distributed/test_c10d_error_logger.py new file mode 100644 index 0000000000000..7c8a6241b76b5 --- /dev/null +++ b/test/distributed/test_c10d_error_logger.py @@ -0,0 +1,142 @@ +# Owner(s): ["oncall: distributed"] + +import json +import logging +import os +import re +import sys +from functools import partial, wraps + +import torch +import torch.distributed as dist + +from torch.distributed.c10d_error_logger import _get_or_create_logger +from torch.distributed.distributed_c10d import exception_handler + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS +from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip dev-asan as torch + multiprocessing spawn have known issues", + file=sys.stderr, + ) + sys.exit(0) + +BACKEND = dist.Backend.NCCL +WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) + + +def with_comms(func=None): + if func is None: + return partial( + with_comms, + ) + + @wraps(func) + def wrapper(self, *args, **kwargs): + if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + self.dist_init() + func(self) + self.destroy_comms() + + return wrapper + + +class C10dErrorLoggerTest(MultiProcessTestCase): + def setUp(self): + super(C10dErrorLoggerTest, self).setUp() + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["BACKEND"] = BACKEND + self._spawn_processes() + + @property + def device(self): + return ( + torch.device(self.rank) + if BACKEND == dist.Backend.NCCL + else torch.device("cpu") + ) + + @property + def world_size(self): + return WORLD_SIZE + + @property + def process_group(self): + return dist.group.WORLD + + def destroy_comms(self): + # Wait for all ranks to reach here before starting shutdown. + dist.barrier() + dist.destroy_process_group() + + def dist_init(self): + dist.init_process_group( + backend=BACKEND, + world_size=self.world_size, + rank=self.rank, + init_method=f"file://{self.file_name}", + ) + + # set device for nccl pg for collectives + if BACKEND == "nccl": + torch.cuda.set_device(self.rank) + + def test_get_or_create_logger(self): + logger = _get_or_create_logger() + self.assertIsNotNone(logger) + self.assertEqual(1, len(logger.handlers)) + self.assertIsInstance(logger.handlers[0], logging.NullHandler) + + @exception_handler + def failed_broadcast_raise_exception(self): + tensor = torch.arange(2, dtype=torch.int64) + dist.broadcast(tensor, self.world_size + 1) + + @exception_handler + def failed_broadcast_not_raise_exception(self): + try: + tensor = torch.arange(2, dtype=torch.int64) + dist.broadcast(tensor, self.world_size + 1) + except Exception as exception: + pass + + @with_comms + def test_exception_handler_with_dist(self) -> None: + with self.assertRaises(Exception) as exception: + self.failed_broadcast_raise_exception() + + with self.assertLogs(dist._c10d_error_logger, level="DEBUG") as captured: + self.failed_broadcast_not_raise_exception() + error_msg_dict = json.loads( + re.search("({.+})", captured.output[0]).group(0).replace("'", '"') + ) + self.assertEqual(len(error_msg_dict), 7) + + self.assertIn("func_name", error_msg_dict.keys()) + self.assertEqual("broadcast", error_msg_dict["func_name"]) + + self.assertIn("args", error_msg_dict.keys()) + + self.assertIn("backend", error_msg_dict.keys()) + self.assertEqual("nccl", error_msg_dict["backend"]) + + self.assertIn("world_size", error_msg_dict.keys()) + self.assertEqual(str(self.world_size), error_msg_dict["world_size"]) + + self.assertIn("global_rank", error_msg_dict.keys()) + self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"]) + + # In this test case, local_rank = global_rank, since we don't have multiple processes on one node. + self.assertIn("local_rank", error_msg_dict.keys()) + self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index e0c7c64f7b836..b26c9e9316f3e 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -23,28 +23,35 @@ import torch.nn.functional as F import torch.testing._internal.common_utils as common from test_c10d_common import ( - LOOPBACK, gpus_for_rank, - Task, + LOOPBACK, ModuleForDdpCommHook, SparseGradientModule, + Task, ) from torch import nn +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard, + ShardedTensor, + ShardMetadata, +) from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor from torch.testing._internal.common_distributed import ( + create_device, MultiProcessTestCase, requires_gloo, - skip_if_lt_x_gpu, simple_sparse_reduce_tests, + skip_if_lt_x_gpu, skip_if_win32, - create_device, verify_ddp_error_logged, ) from torch.testing._internal.common_utils import ( - TestCase, - run_tests, retry_on_connect_failures, + run_tests, sandcastle_skip, + TestCase, ) @@ -68,7 +75,7 @@ def simple_reduce_tests(rank, world_size): ( c10d.ReduceOp.MAX, torch.tensor([rank + 1.0]), - torch.tensor([world_size]), + torch.tensor([float(world_size)]), ), ] @@ -121,7 +128,7 @@ def simple_coalesced_reduce_tests(rank, world_size): return [ ( c10d.ReduceOp.SUM, - [torch.tensor([rank + 1]), torch.tensor([(rank + 1) ** 2])], + [torch.tensor([rank + 1.0]), torch.tensor([(rank + 1.0) ** 2])], [ torch.tensor([float(world_size * (world_size + 1) / 2)]), torch.tensor( @@ -145,7 +152,7 @@ def simple_coalesced_reduce_tests(rank, world_size): ( c10d.ReduceOp.MAX, [torch.tensor([rank + x]) for x in [1.0, 2.0]], - [torch.tensor([world_size]), torch.tensor([world_size + 1.0])], + [torch.tensor([float(world_size)]), torch.tensor([world_size + 1.0])], ), ] @@ -170,7 +177,7 @@ def simple_multi_input_reduce_tests(rank, world_size): ( c10d.ReduceOp.MAX, [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])], - torch.tensor([2 * world_size]), + torch.tensor([2.0 * world_size]), ), ] @@ -247,7 +254,7 @@ def test_empty_tensors(self): fut.wait() output = fut.value() self.assertEqual(0, output[0].numel()) - self.assertEqualIgnoreType(xs[0], output[0]) + self.assertEqual(xs[0], output[0]) @requires_gloo() def test_broadcast_checks(self): @@ -321,8 +328,7 @@ def broadcast(xs, rootRank, rootTensor): # Run with 1 input tensor x = fn(torch.tensor([self.rank])) output = broadcast([x], i, 0) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.tensor([i]), output[0]) + self.assertEqual(torch.tensor([i]), output[0]) # Run with 2 input tensors num = 2 @@ -333,10 +339,8 @@ def broadcast(xs, rootRank, rootTensor): ] output = broadcast(xs, i, j) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.tensor([i * num + j]), output[0]) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.tensor([i * num + j]), output[1]) + self.assertEqual(torch.tensor([i * num + j], dtype=torch.float32), output[0]) + self.assertEqual(torch.tensor([i * num + j], dtype=torch.float32), output[1]) # Test overloaded convenience function x = torch.tensor([self.rank + 1.0]) @@ -422,8 +426,7 @@ def _test_allreduce_basics(self, fn): fut = pg.allreduce([tensor], opts).get_future() fut.wait() result = fut.value() - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(expected, result[0]) + self.assertEqual(expected, result[0]) # Multi input tests tests = simple_multi_input_reduce_tests(self.rank, self.world_size) @@ -435,8 +438,7 @@ def _test_allreduce_basics(self, fn): fut.wait() result = fut.value() for tensor in result: - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(output, tensor) + self.assertEqual(output, tensor) # Test overloaded convenience function (defaults to using sum) x = fn(torch.tensor([self.rank + 1.0])) @@ -474,8 +476,7 @@ def _test_allreduce_basics_using_work_api(self, fn): work = pg.allreduce([tensor], opts) work.wait() result = work.result() - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(expected, result[0]) + self.assertEqual(expected, result[0]) # Multi input tests tests = simple_multi_input_reduce_tests(self.rank, self.world_size) @@ -487,8 +488,7 @@ def _test_allreduce_basics_using_work_api(self, fn): work.wait() result = work.result() for tensor in result: - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(output, tensor) + self.assertEqual(output, tensor) # Test overloaded convenience function (defaults to using sum) x = fn(torch.tensor([self.rank + 1.0])) @@ -519,12 +519,11 @@ def _test_allreduce_stress(self, inputs): ] for i, future_handle in enumerate(future_handles): future_handle.wait() - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( + self.assertEqual( torch.tensor( [ (i * self.world_size) - + (self.world_size * (self.world_size - 1) / 2) + + (self.world_size * (self.world_size - 1) // 2) ] ), future_handle.value()[0], @@ -598,8 +597,7 @@ def _test_allreduce_coalesced_basics(self, fn): fut.wait() result = fut.value() for result_tensor, expected in zip(result, outputs): - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(result_tensor, expected) + self.assertEqual(result_tensor, expected) @requires_gloo() def test_allreduce_coalesced_basics(self): @@ -607,7 +605,7 @@ def test_allreduce_coalesced_basics(self): def _expected_output(self, i): ws = self.world_size - return 2 * [torch.tensor([(i * ws) + (ws * (ws - 1) / 2)])] + return 2 * [torch.tensor([(i * ws) + (ws * (ws - 1) // 2)])] def _test_allreduce_coalesced_stress(self, inputs): store = c10d.FileStore(self.file_name, self.world_size) @@ -620,8 +618,7 @@ def _test_allreduce_coalesced_stress(self, inputs): for i, future_handle in enumerate(future_handles): future_handle.wait() result = future_handle.value() - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( + self.assertEqual( self._expected_output(i), result, msg="Mismatch in iteration {}".format(i), @@ -643,7 +640,7 @@ def test_allreduce_coalesced_async(self): futs = [c10d.all_reduce_coalesced(x, async_op=True) for x in xs] torch.futures.wait_all(futs) for i, fut in enumerate(futs): - self.assertEqualIgnoreType( + self.assertEqual( self._expected_output(i), fut.wait(), msg="Mismatch in iteration {}".format(i), @@ -1233,7 +1230,7 @@ def test_allgather_coalesced_async(self): # one output tensor list for y, z in zip(y_out, z_out): # one tensor in output tensor list - self.assertEqualIgnoreType(y, z) + self.assertEqual(y, z) # Added to address https://github.com/pytorch/pytorch/issues/65231 # In the failed tests, all assertEqualIgnoreType are passed on all @@ -1296,8 +1293,7 @@ def _test_reduce_basics(self, fn): fut.wait() result = fut.value() if root == self.rank: - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(output, result[0]) + self.assertEqual(output, result[0]) @requires_gloo() def test_reduce_basics(self): @@ -1330,12 +1326,11 @@ def _test_reduce_stress(self, inputs): iter = i // self.world_size root = i % self.world_size if root == self.rank: - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( + self.assertEqual( torch.tensor( [ (iter * self.world_size) - + (self.world_size * (self.world_size - 1) / 2) + + (self.world_size * (self.world_size - 1) // 2) ] ), result[0], @@ -1754,6 +1749,49 @@ def forward(self, x): loss = criterion(output, target) loss.backward() + @requires_gloo() + @skip_if_lt_x_gpu(2) + def test_ignored_sharded_tensor(self): + class MyModule(nn.Module): + def __init__(self, shard_tensor: ShardedTensor) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.st = nn.Parameter(shard_tensor) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + return F.softmax(x, dim=1) + pg = dist.init_process_group( + "gloo", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + device = torch.device(f"cuda:{self.rank}") + local_shard_metadata = ShardMetadata( + shard_offsets=[(self.rank % 2) * 5, 0], + shard_sizes=[5, 10], + placement=f"rank:{self.rank}/cuda:{self.rank}" + ) + local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)] + st = init_from_local_shards(local_shards, [10, 10]) + m = MyModule(st) + with _ddp_replicated_tensor(False): + DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + module=m, + params_and_buffers_to_ignore={'st'} + ) + # test to make DDP constructor will not fail when module includes a ShardedTensor when ignored + DistributedDataParallel( + m, + device_ids=[device] if device.type == "gpu" else None, + process_group=pg, + gradient_as_bucket_view=True, + broadcast_buffers=False, + static_graph=True, + ) + def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model): mult = 2 batch_size = mult * self.world_size @@ -2262,7 +2300,7 @@ def _test_broadcast_coalesced(self, process_group, device, root_rank): target += torch.arange(60, dtype=half, device=device).chunk(5) target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) - # The tensors to pass to broadcast are idential to the target + # The tensors to pass to broadcast are identical to the target # only on the process that is the root of the broadcast. if self.rank == root_rank: tensors = list(tensor.clone() for tensor in target) @@ -2363,6 +2401,39 @@ class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="gloo") + @requires_gloo() + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="gloo") + + @requires_gloo() + def test_all_to_all_single(self): + self._test_all_to_all_single(backend="gloo") + + @requires_gloo() + def test_allgather_coalesced(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "gloo", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + input_tensor = torch.ones(10, 10, dtype=torch.float32) + output_tensor_list = [torch.zeros_like(input_tensor)] + dist.all_gather_coalesced([output_tensor_list], [input_tensor]) + self.assertEqual(output_tensor_list, [input_tensor]) + + @requires_gloo() + def test_monitored_barrier(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "gloo", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + dist.monitored_barrier() + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 78b1cbbe676cf..ecea7c7811681 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -290,8 +290,7 @@ def broadcast(xs, rootRank, rootTensor): # Run with 1 input tensor x = torch.tensor([self.rank]).cuda(self.rank_to_GPU[self.rank][0]) output = broadcast([x], i, 0) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.tensor([i]), output[0]) + self.assertEqual(torch.tensor([i]), output[0]) expected_tensor = torch.empty([i + 1, i + 1]).fill_(i + 1) xs = [torch.empty([i + 1, i + 1]).fill_(-1).cuda(device=device_idx) for device_idx in self.rank_to_GPU[self.rank]] @@ -326,10 +325,9 @@ def allreduce(tensors, op): allreduce(tensors, c10d.ReduceOp.SUM) - ndev = float(self.world_size) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( - torch.tensor([ndev * (ndev + 1) / 2]), + ndev = self.world_size + self.assertEqual( + torch.tensor([ndev * (ndev + 1) // 2]), tensors[0], ) @@ -338,9 +336,8 @@ def allreduce(tensors, op): tensors = [torch.tensor([self.rank + 1.]).cuda(local_device_id)] allreduce(tensors, c10d.ReduceOp.AVG) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - ndev = float(self.world_size) - self.assertEqualIgnoreType( + ndev = self.world_size + self.assertEqual( torch.tensor([ndev * (ndev + 1.) / (2. * ndev)]), tensors[0], ) @@ -348,16 +345,14 @@ def allreduce(tensors, op): # Premul Sum if torch.cuda.nccl.version() >= (2, 11, 1): for dtype in torch.half, torch.float, torch.double: - for factor in (3.0, - (torch.tensor([5.0], device=local_device_id, dtype=dtype),)): + for factor in (3.0, torch.tensor([5.0], device=local_device_id, dtype=dtype)): tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id).to(dtype=dtype)] allreduce(tensors, c10d._make_nccl_premul_sum(factor)) - f = factor if isinstance(factor, float) else factor[0] - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( - f * torch.tensor([float(self.world_size * (self.world_size + 1) / 2)], device=local_device_id), + self.assertEqual( + factor * torch.tensor([self.world_size * (self.world_size + 1) / 2], + dtype=dtype, device=local_device_id), tensors[0], ) @@ -365,17 +360,15 @@ def allreduce(tensors, op): tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)] allreduce(tensors, c10d.ReduceOp.PRODUCT) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( - torch.tensor([float(math.factorial(self.world_size))]), tensors[0] + self.assertEqual( + torch.tensor([math.factorial(self.world_size)]), tensors[0] ) # Min tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)] allreduce(tensors, c10d.ReduceOp.MIN) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(torch.tensor([1.0]), tensors[0]) + self.assertEqual(torch.tensor([1]), tensors[0]) # Max tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id)] @@ -412,14 +405,13 @@ def reduce(xs, rootRank, rootTensor, op=None): reduce(tensors, rt, 0) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 if self.rank == rt: - self.assertEqualIgnoreType( - torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), + self.assertEqual( + torch.tensor([self.world_size * (self.world_size + 1) // 2]), tensors[0], ) else: - self.assertEqualIgnoreType( + self.assertEqual( torch.tensor([self.rank + 1]), tensors[0], ) @@ -435,9 +427,9 @@ def reduce(xs, rootRank, rootTensor, op=None): # Premul sum if torch.cuda.nccl.version() >= (2, 11, 1): - for factor in (3.0, (torch.tensor([5.0], device=local_device_id),)): - if isinstance(factor, tuple): - factor_ref = factor[0].cpu().item() + for factor in (3.0, torch.tensor([5.0], device=local_device_id)): + if isinstance(factor, torch.Tensor): + factor_ref = factor.cpu().item() else: factor_ref = factor float_tensors = [ @@ -513,7 +505,7 @@ def allgather_base(output_t, input_t): work = pg._allgather_base(output_t, input_t) work.wait() - # anticpate an error + # anticipate an error with self.assertRaisesRegex( RuntimeError, "output tensor size must be equal to world_size times input tensor size", @@ -525,7 +517,7 @@ def allgather_base(output_t, input_t): # fails the check because output_t is not correctly sized allgather_base(output_t, tensor) - # anticpate an error + # anticipate an error with self.assertRaisesRegex( RuntimeError, "output tensor must have the same type as input tensor" ): @@ -801,7 +793,7 @@ def reduce_scatter_base(output_t, input_t): work = pg._reduce_scatter_base(output_t, input_t) work.wait() - # anticpate an error + # anticipate an error with self.assertRaisesRegex( RuntimeError, "input tensor must be the same size as output size times world size", @@ -813,7 +805,7 @@ def reduce_scatter_base(output_t, input_t): # fails the check because output_t is not correctly sized reduce_scatter_base(output_t, input_t) - # anticpate an error + # anticipate an error with self.assertRaisesRegex( RuntimeError, "input tensor must be the same type as the output tensor." ): @@ -861,13 +853,12 @@ def reduce_scatter(outputs, input_lists, op): for i in range(num_gpus): expected = torch.tensor( [ - float((1 + self.world_size) * self.world_size / 2) + (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank ]) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(expected, output[i]) + self.assertEqual(expected, output[i]) # Min reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN) @@ -888,7 +879,7 @@ def reduce_scatter(outputs, input_lists, op): # Product reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT) - # math pakcage don't have math.perm until python 3.8, so + # math package don't have math.perm until python 3.8, so # we implement a naive version here. def perm(n, k): prod_val = n @@ -900,8 +891,7 @@ def perm(n, k): prod_val = perm(self.rank + self.world_size, self.world_size) expected = torch.tensor([prod_val]) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(expected, output[i]) + self.assertEqual(expected, output[i]) # Test the input params overridden scenarios, aka, when the input is # a list and output is just one tensor. @@ -910,19 +900,19 @@ def perm(n, k): input_list = [tensor[0].cuda(self.rank) for tensor in input_per_gpu] pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait() expected = torch.tensor( - float((1 + self.world_size) * self.world_size / 2) + self.world_size * self.rank + (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank ) - self.assertEqualIgnoreType(expected, output_tensor) + self.assertEqual(expected, output_tensor) # Min pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait() expected = torch.tensor(self.rank + 1) - self.assertEqualIgnoreType(expected, output_tensor) + self.assertEqual(expected, output_tensor) # Max pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait() expected = torch.tensor(self.rank + self.world_size) - self.assertEqualIgnoreType(expected, output_tensor) + self.assertEqual(expected, output_tensor) # Product pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait() @@ -930,12 +920,12 @@ def perm(n, k): for k in range(1, self.world_size): prod_val = prod_val * (self.rank + 1 + k) expected = torch.tensor(prod_val) - self.assertEqualIgnoreType(expected, output_tensor) + self.assertEqual(expected, output_tensor) if torch.cuda.nccl.version() >= (2, 11, 1): - for factor in (3.0, (torch.tensor([5.0], device=self.rank),),): - if isinstance(factor, tuple): - factor_ref = factor[0].cpu().item() + for factor in (3.0, torch.tensor([5.0], device=self.rank)): + if isinstance(factor, torch.Tensor): + factor_ref = factor.cpu().item() else: factor_ref = factor output = [t.float() for t in output] @@ -997,8 +987,7 @@ def allreduce(tensors): for i in range(1, len(local_device_ids) + 1): for j in range(i): - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( + self.assertEqual( torch.tensor([(j + 1) * self.world_size]), tensors_list[i - 1][j] ) @@ -1025,6 +1014,18 @@ def test_send_recv(self): with self.assertRaisesRegex(RuntimeError, 'Tensors must be contiguous'): dist.send(send_tensor_view, 1) + @requires_nccl() + @sandcastle_skip_if(torch.cuda.device_count() < 1, "NCCL test requires 1 GPU") + @skip_if_lt_x_gpu(1) + def test_nccl_dist_backend_error(self): + store = c10d.FileStore(self.file_name, self.world_size) + self._create_process_group_nccl(store, self.opts()) + + # Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage + with self.assertRaises(dist.DistBackendError) as cm: + dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0) + + self.assertIsInstance(cm.exception, RuntimeError) class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase @@ -1652,8 +1653,7 @@ def step_model(model, input, target): target[self.rank : (self.rank + 1)], ) for i, j in zip(model.parameters(), ddp_model.parameters()): - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(i.grad, j.grad, rtol=1.3e-06, atol=5e-5) + self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5) # Shuffle the input so that DDP input is different torch.manual_seed(1337 + iteration) @@ -2598,7 +2598,7 @@ def test_nccl_timeout(self): try: pg_gloo.barrier().wait() except Exception as e: - raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}") + raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}") from e # Now verify communicators on this rank have # been aborted by watchdog. self._wait_for_comm_abort(process_group, failed_collective_timeout) @@ -2943,6 +2943,48 @@ class NcclProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="nccl") + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_all_to_all_single(self): + self._test_all_to_all_single(backend="nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allgather_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_reduce_scatter_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + dist.reduce_scatter_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/test/distributed/test_c10d_spawn_ucc.py b/test/distributed/test_c10d_spawn_ucc.py new file mode 100644 index 0000000000000..eabd7e1cf45b5 --- /dev/null +++ b/test/distributed/test_c10d_spawn_ucc.py @@ -0,0 +1,110 @@ +# Owner(s): ["oncall: distributed"] + +import sys +import test_c10d_spawn +import torch +import torch.distributed as c10d +from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + requires_ucc, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + sandcastle_skip, + sandcastle_skip_if, + TEST_WITH_DEV_DBG_ASAN, +) + +NO_UCC = not hasattr(c10d, "ProcessGroupUCC") + +# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 +if sys.version_info < (3, 9): + + class ProcessGroupShareTensorTest( + test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase + ): + @classmethod + def _init_pg_ucc(cls, rank, filename, world_size): + store = c10d.FileStore(filename, world_size) + return c10d.ProcessGroupUCC(store, rank, world_size) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_broadcast_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_broadcast_process, + [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + 1, + ) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_allreduce_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_allreduce_process, + [torch.ones(2, 2).to(i) for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + 1, + ) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_allgather_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_allgather_process, + [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + self.world_size, + ) + + +# Skip dev-asan as torch + multiprocessing spawn have known issues +if not TEST_WITH_DEV_DBG_ASAN: + + class TestDistributedNNFunctionsUcc(TestDistributedNNFunctions): + # Test Common Ops First. + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if( + not _torch_dist_nn_available, "torch.distributed.nn is not available" + ) + def test_broadcast(self): + self._test_broadcast("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_reduce(self): + self._test_reduce("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_allreduce(self): + self._test_allreduce("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + @sandcastle_skip("runs into illegal memory access on first assertEqual check when run locally") + def test_all_gather(self): + self._test_all_gather("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all(self): + self._test_all_to_all("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all_single(self): + self._test_all_to_all_single("ucc") + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py new file mode 100644 index 0000000000000..ade7d92543995 --- /dev/null +++ b/test/distributed/test_dynamo_distributed.py @@ -0,0 +1,582 @@ +# Owner(s): ["module: dynamo"] +import copy +import functools +import os +import random +import unittest +from unittest.mock import patch +import numpy as np +import torch +import torch._dynamo +from torch._dynamo.optimizations.distributed import DDPOptimizer +import torch._dynamo.test_case +import torch.distributed as dist +from contextlib import contextmanager +from torch import nn +from torch._dynamo import config +from torch._dynamo.utils import same +from torch._dynamo.testing import collect_results +from torch._inductor.utils import has_triton +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + import_transformers_or_skip, + skip_if_lt_x_gpu, + requires_nccl +) +import torch._dynamo.logging + + +def reset_rng_state(): + torch.manual_seed(1337) + random.seed(1337) + np.random.seed(1337) + +def init_weights(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + +class ToyModel(nn.Module): + def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5): + super().__init__() + self.net = nn.Sequential( + *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + + [nn.Linear(hidden_feat, out_feat), nn.ReLU()] + ) + + def forward(self, inputs): + return self.net(inputs) + +def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5): + m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device) + m.apply(init_weights) + inputs = torch.rand(bsz, in_feat).to(device) + outputs = m(inputs) + return m, inputs, outputs + +def get_custom_model(device): + class MyCustomLinear(torch.nn.Module): + def __init__(self): + super(MyCustomLinear, self).__init__() + self.weight = nn.Parameter(torch.randn(512, 512)) + + def forward(self, x): + return torch.mm(x, self.weight.t()) + + class MyLinear(torch.nn.Module): + def __init__(self): + super(MyLinear, self).__init__() + self.linear = torch.nn.Linear(512, 512) + + def forward(self, x): + return self.linear(x) + + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + mods = [ + (MyLinear(), torch.nn.ReLU()), + # sandwitch the custom in the middle so it comes before and after + (MyCustomLinear(), torch.nn.ReLU()), + (MyLinear(), torch.nn.ReLU()), + ] + self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) + + def forward(self, x): + return self.seq(x) + + m = MyModule().to(device) + m.apply(init_weights) + inputs = torch.rand((512, 512)).to(device) + correct_outputs = m(inputs) + return m, inputs, correct_outputs + +def get_hf_bert(rank): + # Note: use @import_transformers_or_skip on your test case if you use this + # in a multiprocessing test + try: + from transformers import BertConfig, AutoModelForMaskedLM + except ImportError as e: + raise unittest.SkipTest("Unable to import transformers") from e + + batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}" + model = AutoModelForMaskedLM.from_config(config).to(device) + input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device) + decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device) + inputs = {'input_ids': input_ids, 'labels': decoder_ids} + model.train() + return model, inputs + +class CheckSplitsCompiler: + def __init__(self): + self.compiler_called = 0 + + def compile_fn(self, gm, example_inputs): + self.compiler_called += 1 + return gm + +@contextmanager +def _per_rank_init(rank, world_size): + # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase, + # Just manually implement the most important part of the dynamo behavior to reset/clear. + torch.cuda.set_device(rank) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '6789' + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + yield + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + dist.destroy_process_group() + + +# This simulates DDP, but it doesn't actually do any process communication; +# it just has enough properties so that the dynamo distributed optimization is +# able to optimize. Feel free to simulate more properties as necessary. The +# other important thing is patching _active_ddp_module, which is what actually +# triggers DDP optimization +class FakeDDP(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + bucket_cap_mb = 25 + self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) + + @contextmanager + def _inside_ddp_forward(self): + DDP._active_ddp_module = self + try: + yield + except Exception: + raise + finally: + DDP._active_ddp_module = None + + def forward(self, *inputs, **kwargs): + with self._inside_ddp_forward(): + return self.module.forward(*inputs, **kwargs) + +def run_hf_bert_ddp(self, model, inputs, backend): + reset_rng_state() + correct_outputs = model(**inputs) + correct_loss = correct_outputs.loss + correct_loss.backward() + + reset_rng_state() + opt_model = torch._dynamo.optimize(backend)(model) + opt_outputs = opt_model(**inputs) + opt_loss = opt_outputs.loss + opt_loss.backward() + + inputs_flat = [inputs[k] for k in inputs] + correct_results = collect_results(model, correct_outputs.logits, correct_loss, inputs_flat) + opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) + self.assertTrue(same(correct_results, opt_results)) + +class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase): + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @patch.object(config, "optimize_ddp", True) + @patch.object(torch._inductor.config, "fallback_random", True) + def test_hf_bert_ddp_inductor(self): + model, inputs = get_hf_bert(0) + model = FakeDDP(model) + run_hf_bert_ddp(self, model, inputs, "inductor") + + @patch.object(config, "optimize_ddp", True) + def test_hf_bert_ddp_aot_eager(self): + model, inputs = get_hf_bert(0) + model = FakeDDP(model) + run_hf_bert_ddp(self, model, inputs, "aot_eager") + + @patch.object(config, "optimize_ddp", True) + def test_issue90375(self): + class Model(nn.Module): + def forward(self): + return torch.randn(3) * torch.randn(3) + + model = Model() + model = FakeDDP(model) + + opt_model = torch._dynamo.optimize("aot_eager")(model) + opt_model() + + +# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a +# single process version; if it's just a problem in the Dynamo distributed +# optimizer, you should be able to repro it single process! +@requires_nccl() +class TestDistributedMultiProc(MultiProcessTestCase): + def setUp(self): + super(TestDistributedMultiProc, self).setUp() + self._spawn_processes() + + def tearDown(self): + super(TestDistributedMultiProc, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self) -> int: + return torch.cuda.device_count() + + @classmethod + def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None: + # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP + # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it? + # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor + # _set_ddp_with_replicated_tensor(True) + + # The rest is copypasta from MultiProcessTestCase._run + self = cls(test_name) + self.rank = rank + self.file_name = file_name + self.run_test(test_name, parent_pipe) + + @skip_if_lt_x_gpu(2) + @patch.object(config, "optimize_ddp", False) + def test_ddp_baseline_aot_eager_multiprocess(self): + with _per_rank_init(self.rank, self.world_size): + self.assertFalse(config.optimize_ddp) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + m = DDP(m, device_ids=[self.rank]) + m = torch._dynamo.optimize("aot_eager")(m) + outputs = m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @skip_if_lt_x_gpu(2) + @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @patch.object(config, "optimize_ddp", True) + @patch.object(torch._inductor.config, "fallback_random", True) + def test_hf_bert_ddp_inductor(self): + + with _per_rank_init(self.rank, self.world_size): + model, inputs = get_hf_bert(self.rank) + model = DDP(model) + run_hf_bert_ddp(self, model, inputs, "inductor") + + @skip_if_lt_x_gpu(2) + @import_transformers_or_skip() + @patch.object(config, "optimize_ddp", True) + def test_hf_bert_ddp_aot_eager(self): + with _per_rank_init(self.rank, self.world_size): + model, inputs = get_hf_bert(self.rank) + model = DDP(model) + run_hf_bert_ddp(self, model, inputs, "aot_eager") + + @skip_if_lt_x_gpu(1) + def test_fsdp_aot_eager(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @skip_if_lt_x_gpu(1) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_fsdp_inductor(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert + @patch.object(torch._inductor.config.triton, "cudagraphs", False) + @patch.object(torch._inductor.config, "fallback_random", True) + def test_hf_bert_fsdp(self): + from transformers.models.bert.modeling_bert import BertLayer + + def apply_fsdp(model, wrap_policy): + model = FSDP( + copy.deepcopy(model), + auto_wrap_policy=wrap_policy, + use_orig_params=True + ) + return model + + with _per_rank_init(self.rank, self.world_size): + for (wrap_policy, test_instance) in ( + ( + None, + "FSDP without recursive wrapping" + ), + ( + functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, ) + ), + "FSDP with recursive wrapping BertLayer instances" + ) + ): + print(f"Running hf_bert test for {test_instance}") + model, inputs = get_hf_bert(self.rank) + reset_rng_state() + eager_model = apply_fsdp(model, wrap_policy) + correct_outputs = eager_model(**inputs) + correct_loss = correct_outputs.loss + correct_loss.backward() + + reset_rng_state() + opt_model = apply_fsdp(model, wrap_policy) + + opt_model = torch._dynamo.optimize("inductor")(opt_model) + opt_outputs = opt_model(**inputs) + opt_loss = opt_outputs.loss + opt_loss.backward() + + inputs_flat = [inputs[k] for k in inputs] + correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat) + opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) + self.assertTrue(same(correct_results, opt_results)) + + +@requires_nccl() +class TestDistributed(torch._dynamo.test_case.TestCase): + """ + Test harness initializes dist process group + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + # _exit_stack is set up in TestCase + cls._exit_stack.enter_context( + patch.dict( + os.environ, + { + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12355", + }, + ) + ) + cls.rank = 0 + cls.device = f"cuda:{cls.rank}" + cls.device_ids = None if "cuda" in cls.device else [cls.rank] + dist.init_process_group("nccl", rank=cls.rank, world_size=1) + + @classmethod + def tearDownClass(cls): + dist.destroy_process_group() + super().tearDownClass() + + def get_model(self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5): + m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(self.device) + m.apply(init_weights) + inputs = torch.rand(bsz, in_feat).to(self.device) + outputs = m(inputs) + return m, inputs, outputs + + @patch.object(config, "optimize_ddp", False) + def test_ddp_baseline_aot_eager(self): + from torch.nn.parallel import DistributedDataParallel as DDP + + m, inputs, correct_outputs = self.get_model() + ddp_m = DDP(m, device_ids=self.device_ids) + ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m) + outputs = ddp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @patch.object(config, "optimize_ddp", False) + def test_ddp_baseline_inductor(self): + from torch.nn.parallel import DistributedDataParallel as DDP + + m, inputs, correct_outputs = self.get_model() + ddp_m = DDP(m, device_ids=self.device_ids) + ddp_m = torch._dynamo.optimize("inductor")(ddp_m) + outputs = ddp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @patch.object(config, "optimize_ddp", True) + def test_graph_split(self): + """ + Just ensures that the appropriate number of splits happen (based on + bucket size and model parameters) - verifies the number of times + the user-provided compiler is called by the DDPOptimizer which is + doing the graph splitting + """ + + m, inputs, correct_outputs = self.get_model() + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) + + check_splits_compiler = CheckSplitsCompiler() + + @torch._dynamo.optimize(check_splits_compiler.compile_fn) + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + self.assertEqual(check_splits_compiler.compiler_called, 3) + + @patch.object(config, "optimize_ddp", True) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_graph_split_inductor(self): + """ + Same as above, but using inductor backend. + We observed issues with inductor/fx interface in the past. + """ + m, inputs, correct_outputs = self.get_model() + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) + + @torch._dynamo.optimize("inductor") + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + + @patch.object(config, "optimize_ddp", True) + def test_no_split(self): + """ + Ensures the DDPOptimizer returns a correct, compiled module without + introducing graph splits. (Based on model parmeters fitting in the bucket) + """ + # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this + m, inputs, correct_outputs = self.get_model(hidden_feat=5) + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250) + check_splits_compiler = CheckSplitsCompiler() + + @torch._dynamo.optimize(check_splits_compiler.compile_fn) + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + self.assertEqual(check_splits_compiler.compiler_called, 1) + + @patch.object(config, "optimize_ddp", True) + def test_aot_autograd(self): + """ + Explicitly check AotAutograd family of compilers work, + since they require example inputs propagated between graph splits. + """ + m, inputs, correct_outputs = self.get_model() + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) + + @torch._dynamo.optimize("aot_eager") + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + opt_outputs.sum().backward() + self.assertTrue(same(correct_outputs, opt_outputs)) + + @patch.object(config, "optimize_ddp", True) + def test_custom_layer(self): + """ + Just ensures that the appropriate number of splits happen (based on + bucket size and model parameters) - verifies the number of times + the user-provided compiler is called by the DDPOptimizer which is + doing the graph splitting + """ + m, inputs, correct_outputs = get_custom_model(self.device) + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1) + + check_splits_compiler = CheckSplitsCompiler() + + @torch._dynamo.optimize(check_splits_compiler.compile_fn) + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + self.assertEqual(check_splits_compiler.compiler_called, 3) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_empty_graph_inductor(self): + def fn(): + get_world_size = torch.distributed.distributed_c10d.get_world_size() + return (get_world_size,) + + opt_fn = torch._dynamo.optimize("inductor")(fn) + res = None + try: + res = opt_fn()[0] + except Exception: + pass + self.assertEqual(res, 1) + + @patch.object(config, "optimize_ddp", False) + def test_ignored_parameters(self): + """ + Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module. + Hooks up graph-split optimizer manually so it can peek at internal state. + """ + m, inputs, correct_outputs = get_custom_model(self.device) + parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"] + DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore) + ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) + parameter_ids_to_ignore = [ + id(ddp_m.module.get_parameter(p)) + for p in ddp_m.parameters_to_ignore + ] + + check_splits_compiler = CheckSplitsCompiler() + ddp_optimizer = DDPOptimizer( + bucket_bytes_cap=ddp_m.bucket_bytes_cap, + backend_compile_fn=check_splits_compiler.compile_fn + ) + + @torch._dynamo.optimize(ddp_optimizer.compile_fn) + def opt_fn(inputs): + return ddp_m(inputs) + + opt_outputs = opt_fn(inputs) + self.assertTrue(same(correct_outputs, opt_outputs)) + self.assertEqual(check_splits_compiler.compiler_called, 2) + for b in ddp_optimizer.buckets: + for p_id in b.param_ids: + self.assertFalse(p_id in parameter_ids_to_ignore) + + def test_fsdp_orig_params_assert(self): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=False) + fsdp_m = torch._dynamo.optimize()(fsdp_m) + self.assertRaisesRegex(AssertionError, "Dynamo only supports FSDP with use_orig_params=True", fsdp_m, inputs) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py new file mode 100644 index 0000000000000..1ca04103ddfae --- /dev/null +++ b/test/distributed/test_multi_threaded_pg.py @@ -0,0 +1,132 @@ +# Owner(s): ["oncall: distributed"] + +import sys +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import ReduceOp + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +from torch.testing._internal.common_distributed import ( + spawn_threads_and_init_comms, + MultiThreadedTestCase + +) +from torch.testing._internal.common_utils import TestCase, run_tests + +DEFAULT_WORLD_SIZE = 4 + +class TestCollectivesWithWrapper(TestCase): + @spawn_threads_and_init_comms(world_size=4) + def test_broadcast_object_list(self): + val = 99 if dist.get_rank() == 0 else None + object_list = [val] * dist.get_world_size() + + dist.broadcast_object_list(object_list=object_list) + self.assertEqual(99, object_list[0]) + + def test_collective_error_on_rank_zero(self): + @spawn_threads_and_init_comms(world_size=4) + def _test_method(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather + output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(output_tensors, input_tensor) + + if dist.get_rank() == 0: + raise AssertionError("Mimic real test failure.") # fail on rank 0 + + dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather + + with self.assertRaisesRegex(AssertionError, "Mimic real test failure."): + _test_method(self) + + def test_collective_error_on_rank_non_zero(self): + @spawn_threads_and_init_comms(world_size=4) + def _test_method(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather + output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(output_tensors, input_tensor) + + if dist.get_rank() == 1: + raise AssertionError("Mimic real test failure.") # fail on rank 1 + + dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather + + with self.assertRaisesRegex(AssertionError, "Mimic real test failure."): + _test_method(self) + + def test_collective_error_on_rank_non_zero_all(self): + @spawn_threads_and_init_comms(world_size=4) + def _test_method(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() # perform 1st all gather + output_tensors = [torch.empty_like(input_tensor) for _ in range(dist.get_world_size())] + dist.all_gather(output_tensors, input_tensor) + + if dist.get_rank() > 0: + raise AssertionError("Mimic real test failure.") # fail on all non-zero rank + + dist.all_gather(output_tensors, input_tensor) # perform 2nd all gather + + with self.assertRaisesRegex(AssertionError, "Mimic real test failure."): + _test_method(self) + +class TestCollectivesWithBaseClass(MultiThreadedTestCase): + @property + def world_size(self): + return 4 + + def test_allgather(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() + output_tensors = [torch.empty_like(input_tensor) for _ in range(self.world_size)] + dist.all_gather(output_tensors, input_tensor) + for rank, out_tensor in enumerate(output_tensors): + self.assertEqual(out_tensor, torch.ones(3, 3) * rank) + + def test_broadcast(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() + for rank in range(self.world_size): + cloned_input = input_tensor.clone() + dist.broadcast(cloned_input, src=rank) + self.assertEqual(cloned_input, torch.ones(3, 3) * rank) + + def test_scatter(self): + if dist.get_rank() == 0: + scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)] + else: + scatter_list = None + output_tensor = torch.empty(3, 3) + + dist.scatter(output_tensor, scatter_list) + self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank()) + + def test_reduce_scatter(self): + to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(self.world_size)] + output_tensor = torch.empty(3, 3) + + dist.reduce_scatter(output_tensor, to_reduce_scatter) + expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size + self.assertEqual(output_tensor, expected_tensor) + + def test_broadcast_object_list(self): + val = 99 if dist.get_rank() == 0 else None + object_list = [val] * dist.get_world_size() + print(f"{dist.get_rank()} -> {dist.get_world_size()}") + + dist.broadcast_object_list(object_list=object_list) + self.assertEqual(99, object_list[0]) + + def test_all_reduce(self): + output = torch.ones(3, 3) * dist.get_rank() + dist.all_reduce(output) + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + self.assertEqual(output, torch.ones(3, 3) * res_num) + + # Test unimplemented error + with self.assertRaisesRegex(NotImplementedError, "only supports SUM on threaded pg for now"): + dist.all_reduce(output, op=ReduceOp.MAX) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py index f0c0023af3d34..475d9f33ec9a3 100644 --- a/test/distributions/test_constraints.py +++ b/test/distributions/test_constraints.py @@ -53,6 +53,7 @@ (constraints.simplex,), (constraints.corr_cholesky,), (constraints.lower_cholesky,), + (constraints.positive_definite,), ] diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index b6201d4d9e84d..a5687a4e1439e 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -1421,7 +1421,7 @@ def ref_log_prob(ref_rate, idx, x, log_prob): # theoretical results. dist = Poisson(rate_zero) dist.log_prob(torch.ones_like(rate_zero)).backward() - torch.testing.assert_allclose(rate_zero.grad, torch.inf) + self.assertEqual(rate_zero.grad, torch.inf) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") def test_poisson_sample(self): @@ -2036,7 +2036,7 @@ def test_lowrank_multivariate_normal_log_prob(self): unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t() self.assertEqual(batched_prob.shape, unbatched_prob.shape) - self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0) + self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") def test_lowrank_multivariate_normal_sample(self): @@ -2176,7 +2176,7 @@ def test_multivariate_normal_log_prob(self): unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t() self.assertEqual(batched_prob.shape, unbatched_prob.shape) - self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0) + self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") def test_multivariate_normal_sample(self): @@ -2331,7 +2331,7 @@ def test_wishart_log_prob(self): unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t() self.assertEqual(batched_prob.shape, unbatched_prob.shape) - self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0) + self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") def test_wishart_sample(self): @@ -2975,6 +2975,9 @@ def test_cdf_log_prob(self): # Tests if the differentiation of the CDF gives the PDF at a given value for Dist, params in EXAMPLES: for i, param in enumerate(params): + # We do not need grads wrt params here, e.g. shape of gamma distribution. + param = {key: value.detach() if isinstance(value, torch.Tensor) else value + for key, value in param.items()} dist = Dist(**param) samples = dist.sample() if not dist.support.is_discrete: @@ -3185,8 +3188,6 @@ def _test_discrete_distribution_mode(self, dist, sanitized_mode, batch_isfinite) self.assertTrue((-1e-12 < delta[mask].detach()).all()) # Allow up to 1e-12 rounding error. def _test_continuous_distribution_mode(self, dist, sanitized_mode, batch_isfinite): - if isinstance(dist, Wishart): - return # We perturb the mode in the unconstrained space and expect the log probability to decrease. num_points = 10 transform = transform_to(dist.support) diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py index ea99562b1f0c6..d922c83672287 100644 --- a/test/distributions/test_transforms.py +++ b/test/distributions/test_transforms.py @@ -14,7 +14,8 @@ LowerCholeskyTransform, PowerTransform, ReshapeTransform, SigmoidTransform, TanhTransform, SoftmaxTransform, SoftplusTransform, StickBreakingTransform, - identity_transform, Transform, _InverseTransform) + identity_transform, Transform, _InverseTransform, + PositiveDefiniteTransform) from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix @@ -43,6 +44,7 @@ def get_transforms(cache_size): StickBreakingTransform(cache_size=cache_size), LowerCholeskyTransform(cache_size=cache_size), CorrCholeskyTransform(cache_size=cache_size), + PositiveDefiniteTransform(cache_size=cache_size), ComposeTransform([ AffineTransform(torch.randn(4, 5), torch.randn(4, 5), @@ -118,10 +120,15 @@ def generate_data(transform): domain = domain.base_constraint codomain = transform.codomain x = torch.empty(4, 5) - if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: - x = torch.empty(6, 6) - x = x.normal_() + positive_definite_constraints = [constraints.lower_cholesky, constraints.positive_definite] + if domain in positive_definite_constraints: + x = torch.randn(6, 6) + x = x.tril(-1) + x.diag().exp().diag_embed() + if domain is constraints.positive_definite: + return x @ x.T return x + elif codomain in positive_definite_constraints: + return torch.randn(6, 6) elif domain is constraints.real: return x.normal_() elif domain is constraints.real_vector: @@ -189,6 +196,7 @@ def test_with_cache(transform): @pytest.mark.parametrize('test_cached', [True, False]) def test_forward_inverse(transform, test_cached): x = generate_data(transform).requires_grad_() + assert transform.domain.check(x).all() # verify that the input data are valid try: y = transform(x) except NotImplementedError: diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index 1532267a043d7..fe81a23cc3399 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -59,7 +59,7 @@ def fn(param, y): compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) aot_fn = torch._dynamo.optimize(compiler_fn)(fn) aot_fn(x, y) - self.assertTrue(not is_safe[0]) + self.assertTrue(is_safe[0]) def test_mutation1(self): def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): @@ -88,7 +88,7 @@ def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) aot_fn = torch._dynamo.optimize(compiler_fn)(fn) aot_fn(x, y) - self.assertTrue(not is_safe[0]) + self.assertTrue(is_safe[0]) def test_negative_testing_mutation(self): def fn(_stack0: torch.Tensor, diagonal_chunked_attention_scores: torch.Tensor): @@ -202,7 +202,7 @@ def forward(self, x, y): compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe) aot_fn = torch._dynamo.optimize(compiler_fn)(graph) aot_fn(x, y) - self.assertTrue(not is_safe[0]) + self.assertTrue(is_safe[0]) def test_call_fn_with_non_const_inputs_aot_unsafe_control_flow(self): class ModuleSpecialFwd(torch.nn.Module): diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index fdb7c88762b8b..5299e92a060f7 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -104,8 +104,7 @@ def fn(x, y): y = torch.randn((), device="cpu") fn(x, y) - @patch("functorch._src.config.use_functionalize", True) - @patch_all(ok=False) # input mutation not supported yet + @patch("torch._functorch.config.use_functionalize", True) def test_mutate_input(self): def model(x, y): y.add_(3) @@ -160,7 +159,7 @@ def fn(y): y = torch.randn(3, device="cuda:0", requires_grad=True) fn(y) - @patch("functorch._src.config.use_functionalize", True) + @patch("torch._functorch.config.use_functionalize", True) @patch_all() def test_mutated_metadata(self): # more tortured example at @@ -181,7 +180,7 @@ def fn(x): x = torch.empty(0, device="cuda:0") fn(x) - @patch("functorch._src.config.use_functionalize", True) + @patch("torch._functorch.config.use_functionalize", True) @patch_all() def test_dead_fill(self): def model(x): diff --git a/test/dynamo/test_distributed.py b/test/dynamo/test_distributed.py deleted file mode 100644 index 695e34817f37b..0000000000000 --- a/test/dynamo/test_distributed.py +++ /dev/null @@ -1,287 +0,0 @@ -# Owner(s): ["module: dynamo"] -import os -import unittest -from unittest.mock import patch - -import pytest -import torch - -import torch._dynamo -import torch._dynamo.test_case -import torch.distributed as dist -from torch import nn -from torch._dynamo import config -from torch._dynamo.testing import same - - -class ToyModel(nn.Module): - def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5): - super().__init__() - self.net = nn.Sequential( - *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] - + [nn.Linear(5000, 5000), nn.ReLU()] * num_hidden - + [nn.Linear(5000, 5), nn.ReLU()] - ) - - def forward(self, inputs): - return self.net(inputs) - - -class CheckSplitsCompiler: - def __init__(self): - self.compiler_called = 0 - - def compile_fn(self, gm, example_inputs): - self.compiler_called += 1 - return gm - - -def skip_if_no_active_ddp(): - from torch.nn.parallel import DistributedDataParallel as DDP - - if not hasattr(DDP, "_get_active_ddp_module"): - raise unittest.SkipTest("requires pytorch landing in parallel") - - -@pytest.mark.skip("Module hangs in PyTorch CI") -class TestDistributed(torch._dynamo.test_case.TestCase): - """ - Test harness initializes dist process group - """ - - @classmethod - def setUpClass(cls): - super().setUpClass() - # _exit_stack is set up in TestCase - cls._exit_stack.enter_context( - patch.dict( - os.environ, - { - "MASTER_ADDR": "localhost", - "MASTER_PORT": "12355", - }, - ) - ) - cls.rank = 0 - cls.device = f"cpu:{cls.rank}" - cls.device_ids = None if "cpu" in cls.device else [cls.rank] - dist.init_process_group("gloo", rank=cls.rank, world_size=1) - - @classmethod - def tearDownClass(cls): - dist.destroy_process_group() - super().tearDownClass() - - def get_model(self): - m = ToyModel().to(self.device) - inputs = torch.randn(20, 10).to(self.device) - outputs = m(inputs) - return m, inputs, outputs - - @patch.object(config, "optimize_ddp", False) - def test_ddp_baseline_aot_eager(self): - from torch.nn.parallel import DistributedDataParallel as DDP - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids) - ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m) - outputs = ddp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @patch.object(config, "optimize_ddp", False) - def test_ddp_baseline_inductor(self): - from torch.nn.parallel import DistributedDataParallel as DDP - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids) - ddp_m = torch._dynamo.optimize("inductor")(ddp_m) - outputs = ddp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - # can't run with gloo (no support for _allgather_base) and nccl not available in CI - @pytest.mark.xfail - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_aot_eager(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - # hangs/crashes with inductor currently - @pytest.mark.skip - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_inductor(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_graph_split(self): - """ - Just ensures that the appropriate number of splits happen (based on - bucket size and model parameters) - verifies the number of times - the user-provided compiler is called by the DDPOptimizer which is - doing the graph splitting - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - check_splits_compiler = CheckSplitsCompiler() - - @torch._dynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 3) - - # hangs/crashes with inductor currently - @pytest.mark.skip - @patch.object(config, "optimize_ddp", True) - def test_graph_split_inductor(self): - """ - Same as above, but using inductor backend. - We observed issues with inductor/fx interface in the past. - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - @torch._dynamo.optimize("inductor") - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_no_split(self): - """ - Ensures the DDPOptimizer returns a correct, compiled module without - introducing graph splits. (Based on model parmeters fitting in the bucket) - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250) - - check_splits_compiler = CheckSplitsCompiler() - - @torch._dynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 1) - - @patch.object(config, "optimize_ddp", True) - def test_aot_autograd(self): - """ - Explicitly check AotAutograd family of compilers work, - since they require example inputs propagated between graph splits. - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - m, inputs, correct_outputs = self.get_model() - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) - - @torch._dynamo.optimize("aot_eager") - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - opt_outputs.sum().backward() - self.assertTrue(same(correct_outputs, opt_outputs)) - - @patch.object(config, "optimize_ddp", True) - def test_custom_layer(self): - """ - Just ensures that the appropriate number of splits happen (based on - bucket size and model parameters) - verifies the number of times - the user-provided compiler is called by the DDPOptimizer which is - doing the graph splitting - """ - from torch.nn.parallel import DistributedDataParallel as DDP - - skip_if_no_active_ddp() - - class MyCustomLinear(torch.nn.Module): - def __init__(self): - super(MyCustomLinear, self).__init__() - self.weight = nn.Parameter(torch.randn(512, 512)) - - def forward(self, x): - return torch.mm(x, self.weight.t()) - - class MyLinear(torch.nn.Module): - def __init__(self): - super(MyLinear, self).__init__() - self.linear = torch.nn.Linear(512, 512) - - def forward(self, x): - return self.linear(x) - - class MyModule(torch.nn.Module): - def __init__(self): - super(MyModule, self).__init__() - mods = [ - (MyLinear(), torch.nn.ReLU()), - # sandwitch the custom in the middle so it comes before and after - (MyCustomLinear(), torch.nn.ReLU()), - (MyLinear(), torch.nn.ReLU()), - ] - self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) - - def forward(self, x): - return self.seq(x) - - m = MyModule().to(self.device) - inputs = torch.randn((512, 512)).to(self.device) - correct_outputs = m(inputs) - ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1) - - check_splits_compiler = CheckSplitsCompiler() - - @torch._dynamo.optimize(check_splits_compiler.compile_fn) - def opt_fn(inputs): - return ddp_m(inputs) - - opt_outputs = opt_fn(inputs) - self.assertTrue(same(correct_outputs, opt_outputs)) - self.assertEqual(check_splits_compiler.compiler_called, 3) - - def test_empty_graph(self): - def fn(): - get_world_size = torch.distributed.distributed_c10d.get_world_size() - return (get_world_size,) - - opt_fn = torch._dynamo.optimize("inductor")(fn) - res = None - try: - res = opt_fn()[0] - except Exception: - pass - self.assertEqual(res, 1) - - -# TODO(jansel): debug issues running this in CI -# if __name__ == "__main__": -# from torch._dynamo.testing import run_tests -# run_tests() diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index a2a94fce1e559..2eb16784514d0 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -3,14 +3,26 @@ from torch._dynamo.testing import make_test_cls_with_patches try: - from . import test_functions, test_misc, test_modules, test_repros, test_unspec + from . import ( + test_export, + test_functions, + test_misc, + test_modules, + test_repros, + test_subgraphs, + test_unspec, + ) except ImportError: + import test_export import test_functions import test_misc import test_modules import test_repros + import test_subgraphs import test_unspec +import unittest + def make_dynamic_cls(cls): return make_test_cls_with_patches( @@ -23,6 +35,76 @@ def make_dynamic_cls(cls): DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests) DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests) DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests) +DynamicShapesExportTests = make_dynamic_cls(test_export.ExportTests) +DynamicShapesSubGraphTests = make_dynamic_cls(test_subgraphs.SubGraphTests) + + +# DynamicShapesFunctionTests +unittest.expectedFailure( + DynamicShapesFunctionTests.test_len_tensor_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer +) + +unittest.expectedFailure( + DynamicShapesFunctionTests.test_tensor_len_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer +) + + +unittest.expectedFailure( + DynamicShapesReproTests.test_do_paste_mask_dynamic_shapes + # aten.min.dim - couldn't find symbolic meta function/decomposition +) + +unittest.expectedFailure( + DynamicShapesReproTests.test_convert_boxes_to_pooler_format_dynamic_shapes + # Could not infer dtype of torch._C.SymIntNode +) + +unittest.expectedFailure( + DynamicShapesReproTests.test_hf_t5_forward_dynamic_shapes + # Cannot call sizes() on tensor with symbolic sizes/strides +) + +# DynamicShapesExportTests +unittest.expectedFailure( + DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes +) +unittest.expectedFailure( + DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes +) +unittest.expectedFailure( + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes +) +unittest.expectedFailure( + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes +) + + +# DynamicShapesSubGraphTests +unittest.expectedFailure( + DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes +) +unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes) + +# DynamicShapesUnspecTests +# Missing decomp +# RuntimeError: Failed running call_function +# (*(FakeTensor(FakeTensor(..., device='meta', size=(5, 1, 28, 28)), cpu), +# FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), False, 0.1, +# FakeTensor(FakeTensor(..., device='meta', size=()), cpu)), **{}): +# aten._local_scalar_dense.default +unittest.expectedFailure(test_unspec.UnspecReproTests.test_batch_norm_act_unspec) + +# SymIntArrayRef expected to contain only concrete integers +unittest.expectedFailure( + DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes +) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 45939e12e767f..b0640f651194d 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -71,6 +71,32 @@ def func(x): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + def test_export_shape_control_flow_1(self): + def func(x): + if x.shape[0] > 10: + return x.cos() + return x.sin() + + opt_func = torch._dynamo.optimize("eager")(func) + real_result = opt_func(torch.ones(6, 4)) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, torch.ones(6, 4)) + out_graph, out_guards = exported + + dynamo_result = out_graph(torch.ones(6, 4)) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + hit = False + for guard in out_guards: + if guard.name == "symbolic_shape_expression": + hit = True + self.assertTrue("x.size()[0] <= 10" in guard.code_list) + + self.assertTrue(hit) + def test_export_graph_bypass(self): inp = [ torch.tensor([0.1, 0.1]), @@ -912,25 +938,25 @@ def func(x): torch._dynamo.reset() def compiler(gm, sample_inputs): - aten_gm = make_fx(gm)(*sample_inputs) - - self.assertEqual(len(aten_gm.graph.nodes), len(out_graph.graph.nodes)) - for node1, node2 in zip(aten_gm.graph.nodes, out_graph.graph.nodes): - self.assertEqual(node1.op, node2.op) - if node1.op == "call_function": - self.assertEqual(node1.target, node2.target) - self.assertEqual(len(node1.args), len(node2.args)) - for arg1, arg2 in zip(node1.args, node2.args): - self.assertEqual(type(arg1), type(arg2)) + def fw(*args): + aten_gm = make_fx(gm)(*args) + return aten_gm(*args) - return aten_gm.forward + return fw opt_func = torch._dynamo.optimize(compiler, nopython=True)(func) - make_fx_result = opt_func(inp) + make_fx_result_through_backend = opt_func(inp) - self.assertTrue(torch._dynamo.utils.same(make_fx_result, export_result)) + fx_g = make_fx(func)(inp) + make_fx_result_through_direct = fx_g(inp) + + self.assertTrue( + torch._dynamo.utils.same(make_fx_result_through_backend, export_result) + ) + self.assertTrue( + torch._dynamo.utils.same(make_fx_result_through_direct, export_result) + ) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_method_on_module(self): class MyModule(torch.nn.Module): def __init__(self): @@ -957,7 +983,6 @@ def forward(self, x): result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_method_on_module_invoke_twice(self): class MyModule(torch.nn.Module): def __init__(self): @@ -984,7 +1009,6 @@ def forward(self, x): result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_free_function(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1015,7 +1039,6 @@ def forward(self, x): result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_free_function_and_class_method(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1042,7 +1065,6 @@ def forward(self, x): result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_free_function_and_class_method_multiarg(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1077,7 +1099,6 @@ def forward(self, x, z): ) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_free_function_and_class_method_multiarg_diff(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1109,7 +1130,6 @@ def forward(self, x, z): ) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_tuple_nonzero(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1134,7 +1154,6 @@ def forward(self, x): result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_list_nonzero(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1159,7 +1178,6 @@ def forward(self, x): result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_list_nonzero_free_function(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1184,7 +1202,6 @@ def forward(self, x): result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_dict_values(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1207,7 +1224,6 @@ def forward(self, x): result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_none_control_flow(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1235,7 +1251,6 @@ def forward(self, x): # X is positive, but we compiled helper_fn to return None, so it will still return y self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_not_none_control_flow(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1263,7 +1278,6 @@ def forward(self, x): # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_none_control_flow_free_func(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1291,7 +1305,6 @@ def forward(self, x): # X is positive, but we compiled helper_fn to return None, so it will still return y self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_not_none_control_flow_pos(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1319,7 +1332,6 @@ def forward(self, x): # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_not_none_control_flow_free_func(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -1347,7 +1359,6 @@ def forward(self, x): # X is negative, but we compiled helper_fn to return x, so it will still return y * x self.assertTrue(torch._dynamo.utils.same(result, real_result)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_constant_not_return_const(self): class MyModule(torch.nn.Module): @torch._dynamo.assume_constant_result @@ -1423,15 +1434,8 @@ def nop(x): ) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_export_with_module_layer(self): - from functorch.experimental.cond import cond - - def true_fn(layer, val): - return layer(val) * torch.tensor(2) - - def false_fn(layer, val): - return layer(val) * torch.tensor(-1) + from functorch.experimental.control_flow import cond class Module(torch.nn.Module): def __init__(self): @@ -1439,7 +1443,13 @@ def __init__(self): self.linear = torch.nn.Linear(3, 3) def forward(self, pred, x): - return cond(pred, true_fn, false_fn, [self.linear, x]) + def true_fn(val): + return self.linear(val) * torch.tensor(2) + + def false_fn(val): + return self.linear(val) * torch.tensor(-1) + + return cond(pred, true_fn, false_fn, [x]) mod = Module() x = torch.randn([3, 3]) @@ -1461,6 +1471,22 @@ def forward(self, pred, x): dynamo_result_2 = out_graph(pred, x) self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2)) + def test_export_meta_val(self): + def f(x, y, z): + return x * y + z + + gm, _ = torch._dynamo.export( + f, + torch.ones(3, 2), + torch.zeros(3, 2), + torch.ones(3, 2), + aten_graph=True, + tracing_mode="symbolic", + ) + for node in gm.graph.nodes: + if node.op == "placeholder": + self.assertIn("val", node.meta) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_export_mutations.py b/test/dynamo/test_export_mutations.py new file mode 100644 index 0000000000000..218935d3f8cb8 --- /dev/null +++ b/test/dynamo/test_export_mutations.py @@ -0,0 +1,134 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch._dynamo.test_case +import torch._dynamo.testing + + +class MutationExportTests(torch._dynamo.test_case.TestCase): + def check_failure_on_export(self, mod, *args): + with self.assertRaises(AssertionError): + torch._dynamo.export(mod, *args) + + def check_same_with_export(self, mod, arg): + real_result = mod(arg) + graph, _ = torch._dynamo.export(mod, arg) + result = graph(arg) + self.assertTrue(torch._dynamo.utils.same(result, real_result)) + + def test_module_attribute_mutation_violation_positive_1(self): + # Mutating attribute with a Tensor type + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + + def forward(self, x): + self.a = self.a.to(torch.float64) + return x.sum() + self.a.sum() + + self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_positive_2(self): + # Mutating attribute with a scalar type + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = 2 + + def forward(self, x): + self.a = self.a * 3 + return x.sum() + self.a + + self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_positive_3(self): + # Setting a new attribute inside forward() + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + + def forward(self, x): + self.b = 2 + return x.sum() + self.a.sum() + self.b + + self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_positive_4(self): + # Mutating attribute with an inline function + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + + def add(self, a, b): + return a + b + + def forward(self, x): + self.a = self.add(1, 2) * self.add(3, 4) + return x.sum() + self.a + + self.check_failure_on_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_negative_1(self): + # Mutating attribute with a Tensor type inside __init__ but + # not in forward() + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + + def forward(self, x): + return x.sum() + self.a.to(torch.float64).sum() + + self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_negative_2(self): + # Mutating attribute with a Tensor type inside __init__ twice + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + self.a = self.a.to(torch.float64) + + def forward(self, x): + return x.sum() + self.a.sum() + + self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_negative_3(self): + # Mutating local variable inside forward() + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + + def forward(self, x): + b = 1 + b = b * 5 + return x.sum() + self.a.sum() + b + + self.check_same_with_export(Foo(), torch.Tensor(3, 2)) + + def test_module_attribute_mutation_violation_negative_4(self): + # Mutating attribute with a Tensor type + # But not exporting but using eager mode as well as dynamo optimize mode + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.Tensor(3, 2) + + def forward(self, x): + self.a = self.a.to(torch.float64) + return x.sum() + self.a.sum() + + mod = Foo() + arg = torch.Tensor(3, 2) + real_result = mod(arg) + opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) + self.assertTrue(torch._dynamo.utils.same(opt_mod(arg), real_result)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index d18ef7e1173fe..4b84d2ca3b71c 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -5,7 +5,9 @@ import inspect import itertools import operator +import unittest from typing import Any +from unittest.mock import patch import torch @@ -66,6 +68,22 @@ def test_inline_jit_annotations(x): def test_add(a, b): return a + b + @make_test + def test_add_(a, b): + a_copy = torch.tensor(a) + return a_copy.add_(b, alpha=5.0) + + @make_test + def test_addcdiv(a, b, c): + # dynamo decomposes this to avoid a graph break when + # the value kwarg is populated + return torch.addcdiv(a, b, c, value=5.0) + + @make_test + def test_addcdiv_(a, b, c): + a_copy = torch.tensor(a) + return a_copy.addcdiv_(b, c, value=5.0) + @make_test def test_is_not_null(a, b): if a is not None and b is not None: @@ -216,6 +234,17 @@ def test_slice5(a): def test_slice6(a): return torch.unsqueeze(a, 0)[:, 2:] + @make_test + def test_range1(a): + return torch.tensor(range(a.size(0))) + + @make_test + def test_range2(x, y): + r = x + y + for i in range(x.size(0) + 2): + r = r / y + return r + @make_test def test_unpack1(a): a, b = a[:5], a[5:] @@ -323,11 +352,26 @@ def test_device(x): if not x.is_cuda: return x + 1 + @make_test + def test_tensor_type(a, b): + m = a.to(torch.float16) + return b.type(m.type()) + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @make_test + def test_tensor_type2(a, b): + m = a.to("cuda") + return m + b.type(m.type()) + @make_test def test_ndim(x): if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2: return x + 1 + @make_test + def test_T(x): + return torch.ones_like(x.T) + @make_test def test_is_sparse(x): if not x.is_sparse: diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 4570d15b2d148..26b2c6ee557e2 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,95 +1,313 @@ # Owner(s): ["module: dynamo"] -import os -import shutil -from unittest.mock import patch +import functools +import re +import textwrap +import unittest import torch - import torch._dynamo -import torch._dynamo.test_case -import torch._dynamo.testing -from torch._dynamo.optimizations.backends import create_backend - - -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - for _ in range(10): - x = torch.sin(x) - x = torch._foobar(x) - for _ in range(10): - x = torch.cos(x) - return x - - -class MinfierTests(torch._dynamo.test_case.TestCase): - def test_after_dynamo(self): - @create_backend - def bad_dynamo_backend(subgraph): - import sys - - def f(*args): - # Shifted the forced exception to runtime as this is more common - # in JIT compilers. - for node in subgraph.model.graph.nodes: - if node.op == "call_function" and node.target is torch._foobar: - sys.stdout.write("Dynamo compiled failed\n") - raise NotImplementedError("foobar is not implemented") - return subgraph.model(*args) - - return f - - mod = MockModule() - opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) - repro_dir = "/tmp/test_minifier" - repro_file = os.path.join(repro_dir, "minifier_launcher.py") - shutil.rmtree(repro_dir, ignore_errors=True) - - @patch.object(torch._dynamo.config, "repro_after", "dynamo") - @patch.object(torch._dynamo.config, "repro_dir", repro_dir) - def inner(): - x = torch.randn(4) - try: - opt_mod(x) - except Exception: - pass - - inner() - self.assertTrue(os.path.exists(repro_file)) - - # If error_at_aot is True, an error will be produced when AOTAutograd - # attempts to generate the backward graph. - # If error_after_aot is False, an error will be produced in inductor. - def _test_around_aot(self, error_at_aot): - mod = MockModule() - opt_mod = torch._dynamo.optimize("inductor")(mod) - repro_dir = "/tmp/test_minifier" - repro_file = os.path.join(repro_dir, "minifier_launcher.py") - shutil.rmtree(repro_dir, ignore_errors=True) - - repro_after = "dynamo" if error_at_aot else "aot" - - @patch.object(torch._dynamo.config, "repro_after", repro_after) - @patch.object(torch._dynamo.config, "repro_dir", repro_dir) - def inner(): - x = torch.randn(4) - x.requires_grad = error_at_aot - try: - opt_mod(x) - except Exception: - pass - - inner() - - self.assertTrue(os.path.exists(repro_file)) - - def test_at_aot(self): - self._test_around_aot(True) - - def test_after_aot(self): - self._test_around_aot(False) +from torch._dynamo.test_minifier_common import MinifierTestBase + +requires_cuda = functools.partial( + unittest.skipIf, not torch.cuda.is_available(), "requires cuda" +) + +RELU_COMPILE_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend + +class DynamoCompileError(Exception): + pass + +@register_backend +def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise DynamoCompileError("relu found") + return gm +""" + +RELU_RUNTIME_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "DynamoRuntimeError") + gm.recompile() + return gm +""" + +RELU_ACCURACY_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm +""" + +RELU_CUSTOM_ERROR_BACKEND = """\ +class CustomError(Exception): + pass + +def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise CustomError("relu found") + return gm +""" + + +class MinifierTests(MinifierTestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) + def _test_after_dynamo(self, device, repro_level, backend_code, error_name): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{self._get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "dynamo", repro_level, backend_code + ) + + self.assertIn(error_name, test_proc.stderr.decode("utf-8")) + self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) + + def test_after_dynamo_cpu_compile_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + def test_after_dynamo_cpu_runtime_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + def test_after_dynamo_cpu_accuracy_error(self): + self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + @requires_cuda() + def test_after_dynamo_cuda_compile_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_error(self): + self._test_after_dynamo("cuda", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + # Ensure that the testing backends pass when relu is not present. + def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{self._get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_dynamo_cpu_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) + + def test_after_dynamo_cpu_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) + + def test_after_dynamo_cpu_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + + # Ensure that generated code with a custom backends generates a runnable minifier + # launcher script that results in a RuntimeError + def test_after_dynamo_custom_backend(self): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize({self._get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + repro_after = "dynamo" + repro_level = 2 + test_code = self._gen_test_code( + run_code, repro_after, repro_level, RELU_CUSTOM_ERROR_BACKEND + ) + _, repro_dir = self._run_test_code(test_code) + launch_proc, _ = self._run_minifier_launcher("", repro_dir) + self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) + + # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @requires_cuda() + def test_cpu_cuda_module_after_dynamo(self): + backend_name = self._get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + class CpuCudaModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.m_x = torch.nn.Linear(20, 20).cuda() + self.m_y = torch.nn.Linear(20, 20) + self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) + self.p_y = torch.nn.Parameter(torch.randn(20, 20)) + self.register_buffer("b_x", torch.ones(20, 20).cuda()) + self.register_buffer("b_y", torch.ones(20, 20)) + + def forward(self, x, y): + return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y + + mod = CpuCudaModule() + + @torch._dynamo.optimize("{backend_name}") + def inner(x1, y1): + x2 = torch.randn(20, 20).cuda() + y2 = torch.randn(20, 20) + x3, y3 = mod(x1 + x2, y1 + y2) + return torch.relu(x3.cpu() + y3) + + inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + # Check if generated minifier code covers all cpu/cuda cases + self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) + self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) + # search for Linear(...).cuda() + self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) + # search for Linear(...) + self.assertIsNotNone( + re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + # search for + # = torch.randn(...) + # ... = .cuda() + self.assertIsNotNone( + re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) + ) + # search for + # = torch.randn(...) + # no followup call to .cuda() + self.assertIsNotNone( + re.search( + r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL + ) + ) + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # Test if we can actually get a minified graph + def test_if_graph_minified(self): + backend_name = self._get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{backend_name}") + def inner(x): + for _ in range(20): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(20): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # compare the length of the forward functions + match = re.search(r"def forward.*return", launch_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertGreater(match.group(0).count("\n"), 40) + + match = re.search(r"def forward.*return", repro_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertLess(match.group(0).count("\n"), 5) if __name__ == "__main__": diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 542a0319a48d3..18132cad557d7 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1,4 +1,5 @@ # Owner(s): ["module: dynamo"] +import abc import collections import copy import dataclasses @@ -26,6 +27,7 @@ same, unsupported, ) +from torch.nn import functional as F from torch.testing._internal.common_utils import freeze_rng_state from torch.testing._internal.jit_utils import JitTestCase @@ -399,12 +401,29 @@ def fn(a, b): return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) + def test_is_tensor2(self): + def fn(x): + if torch.is_tensor(x): + return x + 1 + else: + return torch.ones([2, 3]) + + x1 = {"input": torch.rand(2, 3)} + x2 = torch.rand(2, 3) + ref1 = fn(x1) + ref2 = fn(x2) + opt_fn = torch._dynamo.optimize("eager")(fn) + res1 = opt_fn(x1) + res2 = opt_fn(x2) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + def test_numel(self): def fn(a): - return a + a.numel() + torch.numel(a) + return (a + a.numel() + torch.numel(a), a + a.nelement()) return torch._dynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=2, expected_ops_dynamic=4 + self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6 ) def test_pair(self): @@ -477,6 +496,20 @@ def fn(packed): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 3) + def test_namedtuple3(self): + def fn(x, packed): + if isinstance(packed, mytuple): + return x + 1 + else: + return x - 1 + + x = torch.rand([2, 3]) + packed = mytuple(1, 2, 3) + ref = fn(x, packed) + opt_fn = torch._dynamo.optimize("eager")(fn) + res = opt_fn(x, packed) + self.assertTrue(same(ref, res)) + def test_range_input(self): def fn(a, rng): x = a @@ -1143,6 +1176,7 @@ def fn(x): torch._dynamo.run()(fn2)(torch.randn(4)) self.assertEqual(cnts2.frame_count, 0) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_nested_disable_decorator(self): cnts = torch._dynamo.testing.CompileCounter() @@ -1242,6 +1276,58 @@ def f(x): self.assertTrue(same(ref0, res0)) self.assertTrue(same(ref1, res1)) + def test_is_tensor_like2(self): + class MyTensor(object): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.max: + return torch.tensor(123) + return func(*args, **kwargs) + + def fn(x): + if torch.overrides.is_tensor_like(x): + return torch.max(x) + else: + return torch.zeros(1) + + x = MyTensor() + ref0 = fn(x) + ref1 = fn(4) + opt_fn = torch._dynamo.optimize("eager")(fn) + res0 = opt_fn(x) + res1 = opt_fn(4) + self.assertTrue(same(ref0, res0)) + self.assertTrue(same(ref1, res1)) + + def test_tensor_data(self): + def fn(x, y): + return x[y.data] + + x = torch.rand(8) + y = torch.ones(8).to(torch.int) + ref = fn(x, y) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + + def test_tensor_layout(self): + def fn(x): + return torch.zeros( + [x.size()[0], x.size()[1]], + dtype=x.dtype, + layout=x.layout, + device=x.device, + ) + + x = torch.rand(2, 3) + ref = fn(x) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + def test_version_ci(self): # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) @@ -1343,7 +1429,7 @@ def fn(): self.assertTrue(result[1] == fn.__code__.co_lnotab) def test_torch_profiler(self): - # wrap torch.profiler.* as ProfilerContextWrapperVariable and do nothing + # wrap torch.profiler.* as NullContextVariable and do nothing def fn(x): y = x**2 with torch.profiler.profile(): @@ -1363,7 +1449,7 @@ def fn(x): self.assertEqual(cnts.frame_count, 2) def test_autograd_profiler(self): - # wrap torch.autograd.profiler.* as ProfilerContextWrapperVariable and do nothing + # wrap torch.autograd.profiler.* as NullContextVariable and do nothing def fn(x): y = x**2 with torch.autograd.profiler.profile(): @@ -1382,6 +1468,42 @@ def fn(x): self.assertTrue(same(ref, res)) self.assertEqual(cnts.frame_count, 2) + def test_autograd_profiler_enabled(self): + def fn(x): + if torch.autograd._profiler_enabled(): + return x + 1 + else: + return x - 1 + + x = torch.randn((2, 2), requires_grad=True) + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnts)(fn) + + if torch.autograd._profiler_enabled(): + torch.autograd._disable_profiler() + assert not torch.autograd._profiler_enabled() + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + with torch.autograd.profiler.profile(): + assert torch.autograd._profiler_enabled() + ref = fn(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + def test_tensor_is_contiguous(self): + def fn(x): + input = torch.randn((1, 16, 1, 1)) + weight = torch.randn((8, 16, 3, 3)) + weight = weight.to(memory_format=x) + output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) + return output.is_contiguous(memory_format=x) + + opt_fn = torch._dynamo.optimize("eager")(fn) + for x in [torch.contiguous_format, torch.channels_last]: + self.assertEqual(fn(x), opt_fn(x)) + def test_python_slice(self): def f1(input): y = 0 @@ -1405,10 +1527,12 @@ def f2(input): self.assertEqual(res2, 9) def test_const_dict_variable_python_type(self): - from torch._dynamo.variables import ConstDictVariable + from torch._dynamo.variables import ConstantVariable, ConstDictVariable - d1 = {"a": 10, "b": 20} - d2 = collections.OrderedDict([("x", 12), ("y", 22)]) + d1 = {"a": ConstantVariable(10), "b": ConstantVariable(20)} + d2 = collections.OrderedDict( + [("x", ConstantVariable(12)), ("y", ConstantVariable(22))] + ) self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict) self.assertEqual( ConstDictVariable(d2, collections.OrderedDict).python_type(), @@ -1578,24 +1702,6 @@ def fn(x, func): opt_fn(x, torch.mul) self.assertEqual(cnts.op_count, 1) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) - def test_unsupported_fake_tensor(self): - def f(x): - return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8) - - x = torch.randn(2, 2) - cnts = torch._dynamo.testing.CompileCounter() - opt_f = torch._dynamo.optimize(cnts)(f) - opt_f(x) - self.assertEqual(cnts.op_count, 0) - - torch._dynamo.reset() - with patch.object(torch._dynamo.config, "fake_tensor_propagation", False): - opt_f = torch._dynamo.optimize_assert( - torch._dynamo.testing.CompileCounter() - )(f) - opt_f(x) - def test_inline_list_mutation(self): def f1(x): x.append(torch.ones(8)) @@ -1949,6 +2055,23 @@ def test_cross_entropy_loss_simple_ctor(self): self.assertTrue(torch.allclose(dynamo_output, output)) + def test_nn_functional_reduction(self): + def fn(loss, reduction): + reduction_enum = F._Reduction.get_enum(reduction) + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + x = torch.rand([3, 5]) + y = "mean" + ref = fn(x, y) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, y) + self.assertTrue(torch.allclose(ref, res)) + def test_large_reduction_list(self): dtype = torch.float32 device = "cpu" @@ -1959,7 +2082,6 @@ def check_sum_all(tensor: torch.Tensor) -> None: check_sum_all(torch.randn(200000, dtype=dtype, device=device)) - @patch.object(torch._dynamo.config, "raise_on_backend_error", True) def test_raise_on_backend_error(self): def my_compiler(gm, _): raise RuntimeError("duck!") @@ -2055,7 +2177,6 @@ def __init__(self): self.names = [] def forward(self, idx, targets=None): - from torch.nn import functional as F b, t = idx.size() assert ( @@ -2215,7 +2336,7 @@ def f_onnx(x): self.assertEqual(f_onnx(input_two_dims), 8) def test_cond(self): - from functorch.experimental.cond import cond + from functorch.experimental.control_flow import cond def true_fn(x): return x.sin() @@ -2232,54 +2353,30 @@ def f(pred, x): b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) - def test_cond_nested(self): + @unittest.expectedFailure + def test_cond_side_effects(self): from functorch.experimental.cond import cond - def true_fn_nested(x): - return x * 10 - - def false_fn_nested(x): - return x * -1 + c = 0 - def true_fn(pred2, x): - return x.sin() - - def false_fn(pred2, x): - return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) - - def f(pred, pred2, x): - return cond(pred, true_fn, false_fn, [pred2, x]) - - cc = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cc)(f) - true_true_sin = opt_fn( - torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) + def true_fn(x): + return x - c - true_false_sin = opt_fn( - torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) + def false_fn(x): + return x + c - false_true_sum_mult = opt_fn( - torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([2.75, 2.75]), false_true_sum_mult) - ) # * 10 then add x + def f(pred, x): + nonlocal c + c = 1 + return cond(pred, true_fn, false_fn, [x]) - false_false_sum_neg = opt_fn( - torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) - ) - self.assertTrue( - same(torch.tensor([0.0, 0.0]), false_false_sum_neg) - ) # * -1 then add x - self.assertTrue(cc.frame_count, 2) + opt_fn = torch._dynamo.optimize("eager")(f) + c = 0 + a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) + self.assertTrue(same(torch.tensor([1.25, 1.25]), a)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) - def test_cond_nested_fake_tensor_off(self): - from functorch.experimental.cond import cond + def test_cond_nested(self): + from functorch.experimental.control_flow import cond def true_fn_nested(x): return x * 10 @@ -2321,11 +2418,10 @@ def f(pred, pred2, x): self.assertTrue( same(torch.tensor([0.0, 0.0]), false_false_sum_neg) ) # * -1 then add x - self.assertTrue(cc.frame_count, 1) + self.assertTrue(cc.frame_count, 2) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_cond_export(self): - from functorch.experimental.cond import cond + from functorch.experimental.control_flow import cond def true_fn_nested(x): return x * 10 @@ -2369,9 +2465,8 @@ def f(pred, pred2, x): same(torch.tensor([0.0, 0.0]), false_false_sum_neg) ) # * -1 then add x - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_cond_export_single_arg(self): - from functorch.experimental.cond import cond + from functorch.experimental.control_flow import cond def true_fn(x): return x @@ -2609,21 +2704,12 @@ def fn(): self.assertTrue(same(ref, res)) def test_autograd_function_equivalence(self): - m1 = Module1() - - @torch._dynamo.optimize("eager", nopython=True) - def f1(): - return m1(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f1(), torch.tensor([2.0]))) - - m2 = Module2() - - @torch._dynamo.optimize("eager", nopython=True) - def f2(): - return m2(torch.ones(2, 3)) - - self.assertTrue(torch.allclose(f2(), torch.tensor([2.0]))) + for i in range(1, 5): + model = globals()[f"Module{i}"]() + opt_model = torch._dynamo.optimize("eager", nopython=True)(model) + self.assertTrue( + torch.allclose(opt_model(torch.ones(2, 3)), torch.tensor([2.0])) + ) def test_object_classmethod(self): class C: @@ -2669,6 +2755,70 @@ def fn(x): res = opt_fn(x) self.assertTrue(torch.allclose(ref, res)) + def test_user_function_variable_supports_type_abcmeta_argument(self): + class Foo(metaclass=abc.ABCMeta): + @abc.abstractclassmethod + def read(self): + pass + + class Bar(Foo): + def read(self): + return "Hello World!" + + class Baz: + pass + + def gn(x, tys=(Bar, Baz)): + if Bar in tys: + return x - 1 + else: + return x + 1 + + def fn(x): + return gn(x) + + x = torch.randn(2, 3) + ref = fn(x) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x) + self.assertTrue(torch.allclose(ref, res)) + + def test_user_function_variable_supports_function_argument(self): + def add1(x): + return x + 1 + + def add2(x): + return x + 2 + + def gn(x, f=add1): + if f is add1: + return x + 1 + else: + return x + 2 + + def fn(x, f): + return gn(x, f) + + x = torch.randn(2, 3) + ref = fn(x, add2) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, add2) + self.assertTrue(torch.allclose(ref, res)) + + def test_typing_variable_isinstance(self): + def fn(x, m): + if isinstance(m, typing.Mapping): + return x + 1 + else: + return x - 1 + + x = torch.randn(2, 3) + m = {"x": torch.randn(3)} + ref = fn(x, m) + opt_fn = torch._dynamo.optimize("eager")(fn) + res = opt_fn(x, m) + self.assertTrue(torch.allclose(ref, res)) + def test_repro_graph_breaks_in__get_item_by_idx(self): class Mod(torch.nn.Module): def __init__(self): @@ -2731,8 +2881,184 @@ def forward(self, x): dynamo_result = graph(x) self.assertTrue(same(real, dynamo_result)) + def test_error_on_nested_fx_trace(self): + input = torch.rand(2, 3) + + def f(x): + x + x + + real = f(input) + + optimized = torch._dynamo.optimize("eager")(f) + self.assertTrue(same(optimized(input), real)) + + with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): + gm = torch.fx.symbolic_trace(optimized) + + @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) + def test_no_error_on_nested_fx_trace(self): + input = torch.rand(2, 3) + + def f(x): + x + x + + real = f(input) + + optimized = torch._dynamo.optimize("eager")(f) + self.assertTrue(same(optimized(input), real)) + + # should not error + gm = torch.fx.symbolic_trace(optimized) + self.assertTrue(same(gm(input), real)) + + def test_inference_mode(self): + @torch.inference_mode() + def func(x, y): + return x.add(1.0) + y + + x = torch.ones(4, requires_grad=True) + y = torch.ones(4, requires_grad=True) + ref = func(x, y) + opt_func = torch._dynamo.optimize("eager")(func) + + x1 = torch.ones(4, requires_grad=True) + res = opt_func(x1, y) + self.assertTrue(same(ref, res)) + self.assertTrue(same(x, x1)) + + def test_if_cond_nn_mod(self): + class MockModule(torch.nn.Module): + def __init__(self, output_relu=True): + super(MockModule, self).__init__() + self.relu = torch.nn.ReLU() if output_relu else None + + def forward(self, x): + x = torch.sin(x) + if self.relu: + x = self.relu(x) + return x + + model = MockModule() + opt_model = torch._dynamo.optimize("eager", nopython=True)(model) + + x = torch.rand(4) + ref = model(x) + res = opt_model(x) + self.assertTrue(same(ref, res)) + + model = MockModule(output_relu=False) + opt_model = torch._dynamo.optimize("eager", nopython=True)(model) + + x = torch.rand(4) + ref = model(x) + res = opt_model(x) + self.assertTrue(same(ref, res)) + + def test_torch_cuda_is_available(self): + def fn(x): + if torch.cuda.is_available(): + return x + 1 + else: + return x - 1 + + x = torch.rand(4) + ref = fn(x) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_get_device(self): + def fn(x, y): + x = x + 1 + y = y + 1 + return x.get_device(), y.get_device() + + x = torch.rand(4, device="cuda") + y = torch.rand(4, device="cpu") + ref = fn(x, y) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + + def test_disable_flag(self): + + cnt = torch._dynamo.testing.CompileCounter() + + with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}): + + def fn(x, y): + x = x + 1 + y = y + 1 + + opt_fn = torch._dynamo.optimize(cnt) + + self.assertEqual(cnt.frame_count, 0) + + def test_guard_failure_fn(self): + def fn(x, y, k): + x = x + 1 + y = y + 1 + return x * y * k + + x = torch.tensor([0.5, 0.5]) + y = torch.tensor([1.0, 1.0]) + + guard_failure = None + + def guard_failures(failure): + nonlocal guard_failure + guard_failure = failure + + opt_fn = torch._dynamo.optimize( + "eager", nopython=True, guard_fail_fn=guard_failures + )(fn) + + x2 = torch.tensor([0.5, 0.5, 1.0]) + y2 = torch.tensor([0.5, 0.5, 0.5]) + + opt_fn(x, y, 3) + opt_fn(x2, y2, 5) -class CustomFunc(torch.autograd.Function): + self.assertTrue(guard_failure is not None) + self.assertEqual(guard_failure[0], "k == 3") + + def test_guard_failure_fn2(self): + def fn(x, y): + x = x + 1 + y = y + 1 + return x * y + + x = torch.tensor([0.5, 0.5]) + y = torch.tensor([1.0, 1.0]) + + guard_failure = None + + def guard_failures(failure): + nonlocal guard_failure + guard_failure = failure + + opt_fn = torch._dynamo.optimize( + "eager", nopython=True, guard_fail_fn=guard_failures + )(fn) + + x2 = torch.tensor([0.5, 0.5, 1.0]) + y2 = torch.tensor([0.5, 0.5, 0.5]) + + opt_fn(x, y) + opt_fn(x2, y2) + + if torch._dynamo.config.dynamic_shapes: + self.assertTrue(guard_failure is None) + else: + self.assertTrue(guard_failure is not None) + self.assertEqual( + guard_failure[0], + "tensor 'x' size mismatch at index 0. expected 2, actual 3", + ) + + +class CustomFunc1(torch.autograd.Function): @staticmethod def forward(ctx, foo): return foo + foo @@ -2742,18 +3068,46 @@ def backward(ctx, grad_output): return grad_output +class CustomFunc2(torch.autograd.Function): + # the forward function can be staticmethod or classmethod + @classmethod + def forward(cls, ctx, foo): + return foo + foo + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + class Module1(torch.nn.Module): def __init__(self): super().__init__() def forward(self, foo): - return CustomFunc().apply(foo) + return CustomFunc1().apply(foo) class Module2(torch.nn.Module): def __init__(self): super().__init__() - self.fn = CustomFunc.apply + self.fn = CustomFunc1.apply + + def forward(self, foo): + return self.fn(foo) + + +class Module3(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, foo): + return CustomFunc2().apply(foo) + + +class Module4(torch.nn.Module): + def __init__(self): + super().__init__() + self.fn = CustomFunc2.apply def forward(self, foo): return self.fn(foo) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 2fb83b3add6cf..f510fb87522c5 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -595,6 +595,57 @@ def forward(self, x): return self.activation(self.linear(self.initializer + x)) * self.scale +class ModuleForwardHasGraphBreak(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = BasicModule() + self.layer2 = BasicModule() + self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) + self.layer4 = torch.nn.ModuleList( + [ + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + ] + ) + self.layer5 = torch.nn.ModuleDict( + { + "0": torch.nn.Linear(10, 10), + } + ) + self.scale = torch.randn(1, 10) + + def forward(self, x): + """ + This is used to test if the results of functions like `named_parameters` + can be reconstructed correctly after graph break. + + https://github.com/pytorch/torchdynamo/issues/1931 + """ + x = self.layer1(x) + params1 = dict(self.named_parameters()) + params2 = list(self.parameters()) + buffers1 = dict(self.named_buffers()) + buffers2 = list(self.buffers()) + modules1 = dict(self.named_modules()) + modules2 = list(self.modules()) + torch._dynamo.graph_break() + y = modules2 + y = modules1 + y = buffers2 + y = buffers1 + y = params2 + y = params1 + x = ( + self.layer2(x) + + y["layer3.1.linear1.weight"] + + y["layer4.2.weight"] + + y["layer5.0.weight"] + ) + return x * self.scale + + def make_test(fn, expected_ops=None): def test_fn(self): return torch._dynamo.testing.standard_test( @@ -646,6 +697,14 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): test_module_name_string = make_test(ModuleNameString()) test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) + def test_module_forward_has_graph_break(self): + m = ModuleForwardHasGraphBreak() + x = torch.rand([10, 10]) + ref = m(x) + opt_m = torch._dynamo.optimize("eager")(m) + res = opt_m(x) + self.assertTrue(torch.allclose(ref, res)) + def test_unsupportedmethod(self): m = UnsupportedMethodCall() i = torch.randn(10) @@ -720,16 +779,19 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) - x = torch.randn(1).as_subclass(TensorProxy) - cnt = torch._dynamo.testing.CompileCounter() - out1 = foo(x) - opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) - out2 = opt_foo(x) + try: - self.assertEqual(cnt.op_count, 4) - self.assertTrue(torch._dynamo.testing.same(out1, out2)) + x = torch.randn(1).as_subclass(TensorProxy) + cnt = torch._dynamo.testing.CompileCounter() + out1 = foo(x) + opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) + out2 = opt_foo(x) + + self.assertEqual(cnt.op_count, 4) + self.assertTrue(torch._dynamo.testing.same(out1, out2)) - torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) + finally: + torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) def test_torch_function_with_closure(self): def run(): @@ -756,17 +818,18 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) - x = torch.randn(1).as_subclass(TensorProxy) - x = torch.randn(1) - cnt = torch._dynamo.testing.CompileCounter() - out1 = foo(x) - opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) - out2 = opt_foo(x) + try: + x = torch.randn(1).as_subclass(TensorProxy) + x = torch.randn(1) + cnt = torch._dynamo.testing.CompileCounter() + out1 = foo(x) + opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) + out2 = opt_foo(x) - self.assertEqual(cnt.op_count, 4) - self.assertTrue(torch._dynamo.testing.same(out1, out2)) - - torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) + self.assertEqual(cnt.op_count, 4) + self.assertTrue(torch._dynamo.testing.same(out1, out2)) + finally: + torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) run() @@ -904,6 +967,141 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.relu(self.linear(x) + self.buf0) + + +class OptimizedModuleTest(torch._dynamo.test_case.TestCase): + def test_nn_module(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_to(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 1) + + opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + x = torch.randn(10, 10).to(dtype=torch.float64) + opt_mod(x) + # Ensure that there is a recompilation + self.assertEqual(cnt.frame_count, 2) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 2) + + torch._dynamo.reset() + opt_mod(x) + self.assertEqual(cnt.frame_count, 3) + + def test_attr(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.r(torch.sin(x)) + self.buf0 + + mod = MockModule() + opt_mod = torch._dynamo.optimize("eager")(mod) + + # Check parameteres and buffers + for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): + self.assertTrue(id(p1) == id(p2)) + + def test_recursion(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + + for _ in range(5): + opt_mod = torch._dynamo.optimize(cnt)(opt_mod) + opt_mod(torch.randn(10, 10)) + self.assertEqual(cnt.frame_count, 1) + + def test_composition(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + opt_inner_mod = InnerModule() + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_composition_with_opt_mod(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + inner_mod = InnerModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + # There will be a graph break for the inner mod being OptimizedModule + self.assertEqual(cnt.frame_count, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py deleted file mode 100644 index d65166f5762c5..0000000000000 --- a/test/dynamo/test_no_fake_tensors.py +++ /dev/null @@ -1,29 +0,0 @@ -# Owner(s): ["module: dynamo"] -from torch._dynamo.testing import make_test_cls_with_patches - -try: - from . import test_functions, test_misc, test_modules, test_repros, test_unspec -except ImportError: - import test_functions - import test_misc - import test_modules - import test_repros - import test_unspec - - -def make_no_fake_cls(cls): - return make_test_cls_with_patches( - cls, "NoFakeTensors", "_no_fake_tensors", ("fake_tensor_propagation", False) - ) - - -NoFakeTensorsFunctionTests = make_no_fake_cls(test_functions.FunctionTests) -NoFakeTensorsMiscTests = make_no_fake_cls(test_misc.MiscTests) -NoFakeTensorsReproTests = make_no_fake_cls(test_repros.ReproTests) -NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) -NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_nops.py b/test/dynamo/test_nops.py index 44e102699d091..c17b9528a4f8e 100644 --- a/test/dynamo/test_nops.py +++ b/test/dynamo/test_nops.py @@ -4,6 +4,7 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo import eval_frame +from torch._dynamo.hooks import Hooks c = 10 @@ -32,7 +33,7 @@ def fn3(): with_debug_nops = eval_frame._optimize_catch_errors( - torch._dynamo.testing.debug_insert_nops + torch._dynamo.testing.debug_insert_nops, Hooks(None, None) ) diff --git a/test/dynamo/test_optimizations.py b/test/dynamo/test_optimizations.py index d9f25c5954995..5bff327786fa6 100644 --- a/test/dynamo/test_optimizations.py +++ b/test/dynamo/test_optimizations.py @@ -3,7 +3,6 @@ import json import os import unittest -from unittest.mock import patch import torch @@ -121,11 +120,42 @@ def compiler_fn(graph, example_inputs): opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) r3 = opt_fn(a, (b, c), d) + self.assertIsNotNone(r1) + self.assertEqual(r1.size(), r2.size()) + self.assertEqual(r1.stride(), r2.stride()) + self.assertEqual(r1.dtype, r2.dtype) + + self.assertEqual(r1.size(), r3.size()) + self.assertEqual(r1.stride(), r3.stride()) + self.assertEqual(r1.dtype, r3.dtype) + + def test_example_inputs_runtime_use(self): + def fn(a, bc, d): + b, c = bc + return a / d - b / c + + def compiler_fn(graph, example_inputs): + def fwd(*args): + nonlocal r1 + r = graph.forward(*args) + r1 = r[0] + return r + + return fwd + + a = torch.empty(2).fill_(1) + b = torch.empty(2).fill_(2) + c = torch.empty(2).fill_(3) + d = 4 + r1 = None + r2 = fn(a, (b, c), d) + opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) + r3 = opt_fn(a, (b, c), d) + self.assertIsNotNone(r1) self.assertTrue(same(r1, r2)) self.assertTrue(same(r1, r3)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) @unittest.skipIf(not has_functorch(), "requires functorch") def test_log_conv_args(self): model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index ebb2cde24f6ad..c8b9ee663641e 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] +import contextlib import inspect -import sys import unittest import torch @@ -10,19 +10,30 @@ import torch._dynamo.test_case import torch._dynamo.testing + input = torch.ones([10, 10]) model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)]) model(input).sum().backward() -def make_test(optim_cls, exp_frame_cnt=1, closure=None, **kwargs): +# Include optimizer code for tracing +optim_filenames = set( + [ + inspect.getfile(obj) + for obj in torch.optim.__dict__.values() + if inspect.isclass(obj) + ] +) + + +optim_filenames |= {torch.optim._functional.__file__} + + +def make_test(optim_cls, exp_graph_count=1, closure=None, **kwargs): opt = optim_cls(model.parameters(), **kwargs) def test_fn(self): nonlocal opt - - counter = torch._dynamo.testing.CompileCounter() - if closure is not None: def fn(): @@ -31,18 +42,30 @@ def fn(): else: fn = opt.step - opt_fn = torch._dynamo.optimize(counter)(fn) - opt_fn() + _, _, graphs, _, _, _ = torch._dynamo.explain(fn) - self.assertEqual(counter.frame_count, exp_frame_cnt) + self.assertEqual(exp_graph_count, len(graphs)) return test_fn +@contextlib.contextmanager +def enable_optimizer_tracing(): + try: + old = set(torch._dynamo.skipfiles.FILENAME_ALLOWLIST) + + torch._dynamo.skipfiles.FILENAME_ALLOWLIST.update(optim_filenames) + yield + finally: + torch._dynamo.skipfiles.FILENAME_ALLOWLIST.clear() + torch._dynamo.skipfiles.FILENAME_ALLOWLIST.update(old) + + class OptimizerTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): super().setUpClass() + # needed until pytorch assertion is changed to enable Adam # to be called with capturable=True cls._exit_stack.enter_context( @@ -50,16 +73,7 @@ def setUpClass(cls): torch._dynamo.config, "capture_scalar_outputs", True ) ) - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torch._dynamo.config, "fake_tensor_propagation", False - ) - ) - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torch._dynamo.config, "raise_on_assertion_error", True - ) - ) + cls._exit_stack.enter_context(enable_optimizer_tracing()) test_sgd = make_test(torch.optim.SGD, lr=0.01) # lgbfs has data-dependent control and internally iterates @@ -68,24 +82,38 @@ def setUpClass(cls): # test_lbfgs = make_test( # torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum() # ) - # RAdam has data-dependent control which breaks the graph - test_radam = make_test(torch.optim.RAdam, exp_frame_cnt=1) + + # These optimizers are disabled until we remove item() calls + test_adam = make_test(torch.optim.Adam, exp_graph_count=0) + test_adamw = make_test(torch.optim.AdamW, exp_graph_count=0) + + # RAdam and Adagrad have data-dependent control which breaks the graph; + # furthermore, the break is inside a for loop, so we bail on the frame + # entirely. This is basically an xfail; if the frame count goes up + # you done good + test_radam = make_test(torch.optim.RAdam, exp_graph_count=0) # ASGD has a small optimization that avoids averaging # This will fully capture the graph once that optimization is removed - # NB: in python versions < 3.8, we don't capture graphs when breaks - # occur in a loop - - # Fails without fake tensor: - # TypeError: clamp() received an invalid combination of arguments - got (float, min=int) - # test_asgd = make_test( - # torch.optim.ASGD, exp_frame_cnt=(0 if sys.version_info < (3, 8) else 6) - # ) + # test_asgd = make_test(torch.optim.ASGD, exp_graph_count=0) # exclude SparseAdam because other areas of the stack don't support it yet # the others are handled specially above -exclude = set(["SGD", "Optimizer", "SparseAdam", "LBFGS", "RAdam", "ASGD"]) +exclude = set( + [ + "SGD", # Handled above + "ASGD", # Disabled pending item call removal + optimization removal + "Optimizer", + "SparseAdam", # Unsupported + "LBFGS", # Unsupported + "Adam", # Disabled pending item call removal + "AdamW", # Disabled pending item call removal + "RAdam", # Disabled pending item call removal + "ASGD", + ] +) + optimizers = [ opt for opt in torch.optim.__dict__.values() @@ -100,6 +128,10 @@ def setUpClass(cls): class End2EndTests(torch._dynamo.test_case.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context(enable_optimizer_tracing()) # https://github.com/pytorch/torchdynamo/issues/1604 def test_optimizing_over_tensor_with_requires_grad(self): @@ -131,7 +163,7 @@ def training_iter_fn(batch, model, optimizer): batch = {"x": input1, "y": input2} for _ in range(2): opt_training_iter_fn(batch, net, optimizer) - self.assertEqual(cnts.frame_count, (2 if sys.version_info < (3, 8) else 6)) + self.assertEqual(cnts.frame_count, 1) if __name__ == "__main__": diff --git a/test/dynamo/test_replay_record.py b/test/dynamo/test_replay_record.py index c158590a9d7f4..5235e355e0d1c 100644 --- a/test/dynamo/test_replay_record.py +++ b/test/dynamo/test_replay_record.py @@ -5,7 +5,6 @@ import unittest import torch - import torch._dynamo.test_case import torch._dynamo.testing @@ -29,17 +28,22 @@ def setUpClass(cls): cls._exit_stack.enter_context( unittest.mock.patch.object(torch._dynamo.config, "print_graph_breaks", True) ) + # Most of the tests are checking to see if errors got logged, so we + # ask for errors to be suppressed + cls._exit_stack.enter_context( + unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", True) + ) cls._exit_stack.enter_context( unittest.mock.patch.object( torch._dynamo.config, - "replay_record_dir_name", - "/tmp/torch._dynamo_error_records/", + "debug_dir_root", + "/tmp/_torchdynamo_debug_/", ) ) @classmethod def tearDownClass(cls): - shutil.rmtree(torch._dynamo.config.replay_record_dir_name, ignore_errors=True) + shutil.rmtree(torch._dynamo.config.debug_dir_root, ignore_errors=True) cls._exit_stack.close() def check_replay(self, fn, *args, exp_exc_name=None): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2bd3130958eb2..34f8db248c2b3 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -17,6 +17,14 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils + +import torch._functorch.config + +try: + from test_minifier import requires_cuda +except ImportError: + from .test_minifier import requires_cuda + from torch import nn from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, requires_static_shapes, same @@ -30,6 +38,16 @@ HAS_REFS = False +_orig_module_call = torch.nn.Module.__call__ + + +def is_fx_tracing_test() -> bool: + """ + Copied from the hpc trainer codebase + """ + return torch.nn.Module.__call__ is not _orig_module_call + + def ifdyn(count1, count2): if torch._dynamo.config.dynamic_shapes: return count1 @@ -793,13 +811,11 @@ def test_do_paste_mask(self): ) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) - # Graph break because of dynamic slicing self.assertEqual( torch._dynamo.utils.counters["frames"]["total"], torch._dynamo.utils.counters["frames"]["ok"] + 1, ) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) def test_convert_boxes_to_pooler_format(self): boxes1 = [ Boxes(torch.arange(0, 8).reshape((2, 4))), @@ -872,8 +888,9 @@ def test_longformer_chunk(self): self.assertTrue(same(opt_fn(input1), correct1)) self.assertTrue(same(opt_fn(input2), correct2)) - self.assertEqual(cnt.frame_count, ifdyn(1, 2)) - self.assertEqual(cnt.op_count, ifdyn(19, 4)) + # Dyn recompiles are due to changes in hidden_state (Should we be guarding on this?) + self.assertEqual(cnt.frame_count, ifdyn(4, 2)) + self.assertEqual(cnt.op_count, ifdyn(76, 4)) def test_hf_t5_forward(self): input = torch.randn([1, 2048, 512]) @@ -934,7 +951,10 @@ def test_chunk_reformer_ff(self): self.assertEqual(cnt.op_count, 4) # see: https://github.com/pytorch/pytorch/issues/80067 - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) + # NB: When you remove the expectedFailure, don't forget to + # uncomment/adjust the assertEqual below + @unittest.expectedFailure + @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) def test_maml_item_capture(self): a = torch.randn(5, 1, 28, 28) @@ -948,12 +968,11 @@ def test_maml_item_capture(self): for _ in range(10): self.assertTrue(same(opt_model(a, b, c, d), correct)) - self.assertEqual(cnt.frame_count, ifdyn(3, 2)) + # self.assertEqual(cnt.frame_count, ifdyn(3, 2)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (36, 35, 29, 28)) + self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27)) # see: https://github.com/pytorch/pytorch/issues/80067 - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) @patch.object(torch._dynamo.config, "capture_scalar_outputs", False) def test_maml_no_item_capture(self): a = torch.randn(5, 1, 28, 28) @@ -969,7 +988,7 @@ def test_maml_no_item_capture(self): self.assertEqual(cnt.frame_count, ifdyn(5, 4)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (31, 36, 35, 29, 28)) + self.assertIn(cnt.op_count, (31, 36, 35, 34, 29, 28)) def test_hf_model_output(self): ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) @@ -1028,9 +1047,14 @@ def fn(): before, after = opt_fn() self.assertTrue(same(before, after)) - self.assertEqual(cnt.frame_count, 1) - self.assertEqual(cnt.op_count, 4) # rand, rand - graph, _ = torch._dynamo.export(fn) + self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.op_count, 3) # rand, rand + try: + graph, _ = torch._dynamo.export(fn) + # See https://github.com/pytorch/pytorch/pull/87490 + self.fail("unexpected export success") + except torch._dynamo.exc.Unsupported: + pass def test_seq_append_list(self): x = torch.randn(4, 10) @@ -1084,7 +1108,6 @@ def fn(model, x): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 2) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) def test_nn_parameter(self): def test_fn(): a = torch.nn.Parameter(torch.randn(5, 5)) @@ -1258,6 +1281,21 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + # https://github.com/pytorch/torchdynamo/issues/1446 + def test_grad_mode_carrying_correct_state_after_graph_break(self): + def fn(x): + with torch.no_grad(): + y = x * 3 + print("Break") + z = x + 2 + return y, z + + x = torch.randn(3, requires_grad=True) + opt_fn = torch._dynamo.optimize("eager")(fn) + y, z = opt_fn(x) + self.assertFalse(y.requires_grad) + self.assertFalse(z.requires_grad) + def test_abc_setattr(self): # tests that we correctly bail out of __setattr__ calls @@ -1288,6 +1326,7 @@ def blah(self, x): self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 3) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_guard_fail_tensor_bool(self): @torch._dynamo.skip def fn(): @@ -1347,6 +1386,8 @@ def fn(args): self.assertTrue(same(ref, res)) + # AssertionError: ABCMeta + @unittest.expectedFailure def test_numpy_list(self): @torch._dynamo.disable def rand_gen(): @@ -1372,8 +1413,17 @@ def fn(x): self.assertTrue(same(ref1, res1)) @unittest.skipIf(not HAS_REFS, "requires recent PT version") - @unittest.expectedFailure def test_primtorch(self): + @torch._dynamo.optimize("eager") + def fn(x): + torch._refs.abs(x) + + fn(torch.randn(3)) + + @unittest.skipIf(not HAS_REFS, "requires recent PT version") + @unittest.expectedFailure + # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)] + def test_primtorch_no_graph_break(self): @torch._dynamo.optimize("eager", nopython=True) def fn(x): torch._refs.abs(x) @@ -1426,12 +1476,14 @@ def fn(x): fn(torch.randn(3)) + # Bug with storage meta - torch.BoolStorage is becoming torch.storage._LegacyStorageMeta + @unittest.expectedFailure def test_isinstance_storage(self): @torch._dynamo.optimize("eager") def fn(x): f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) bools = torch.BoolStorage.from_buffer(f, "big") - self.assertTrue(isinstance(bools, torch.BoolStorage)) + assert isinstance(bools, torch.BoolStorage) return x fn(torch.randn(3)) @@ -1541,7 +1593,6 @@ def __init__(self): self.assertTrue(same(ref, res)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_specialized_stride(self): def f(): e = torch.empty(4) @@ -1630,8 +1681,21 @@ def fn(x): opt_fn(x) self.assertEqual(cnt.frame_count, 1) - # This doesn't work without fake tensors but I don't care - @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) + @patch.object(torch._functorch.config, "use_dynamic_shapes", True) + def test_bigbird_unsqueeze_inplace(self): + def fn(reshape_2): + view_2 = reshape_2.clone() + view_2.unsqueeze_(2) + cat_11 = torch.cat([view_2], dim=2) + view_13 = cat_11.view((2, 12, 64, -1)) + return (view_13,) + + x = torch.randn(2, 12, 64, 64, requires_grad=True) + ref = fn(x) + opt_fn = torch._dynamo.optimize("aot_eager")(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + def test_issue1466_size_aot_autograd(self): def fn(x): # do a tensor op and a size compute @@ -1711,6 +1775,370 @@ def forward(self, getitem_1, getitem_2, add): ] self.assertTrue(same_two_models(mod, opt_mod, args)) + def test_optimized_deepcopy(self): + # See https://github.com/pytorch/pytorch/pull/88629 + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True) + + def forward(self, x): + return self.fc(x) + + mod = Foo() + opt_mod = torch._dynamo.optimize("eager")(mod) + args = [torch.randn(1, 2)] + self.assertTrue(same_two_models(mod, opt_mod, args)) + + def test_class_member(self): + class Foo(torch.nn.Module): + a = 4 + b = torch.ones(3, 4) + + def __init__(self): + super().__init__() + self.c = 4 + + def forward(self, x): + return x.cos() + self.a + self.b + self.c + + mod = Foo() + opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) + args = (torch.randn(3, 4),) + self.assertTrue(same(mod(*args), opt_mod(*args))) + + def test_named_buffers(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("x", torch.ones(3)) + self.register_buffer("y", torch.ones(3)) + + def forward(self, inp): + res = 0 + for name, buffer in self.named_buffers(): + res += buffer.sum() + + return inp.cos() + res + + mod = Foo() + opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) + args = (torch.randn(3, 4),) + self.assertTrue(same(mod(*args), opt_mod(*args))) + + def test_is_symbolic_tracing(self): + # Ensure no graph break here + def fn(x): + if is_fx_tracing_test(): + return x * 2 + return x * 4 + + a = torch.randn(4) + ref = fn(a) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(a) + self.assertTrue(same(ref, res)) + + def test_tokenization(self): + from collections import UserDict + + class BatchEncoding(UserDict): + """ + Copied from tokenization + """ + + def __init__( + self, + data, + ): + super().__init__(data) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError as e: + raise AttributeError from e + + def tokenization(x): + encoding = BatchEncoding({"key": x}) + return encoding["key"] + + opt_fn = torch._dynamo.optimize("eager")(tokenization) + x = torch.rand((1, 4)) + ref = tokenization(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + def test_modules(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 3) + + def forward(self, inp): + res = torch.zeros(3, 3) + for mod in self.modules(): + res += self.fc(inp) + return res + + mod = Foo() + args = (torch.ones(3, 4),) + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod) + self.assertTrue(same(mod(*args), opt_mod(*args))) + self.assertEqual(cnt.op_count, 5) + self.assertEqual(cnt.frame_count, 1) + + @requires_cuda() + def test_norm_dtype(self): + def foo(_stack0): + getitem = _stack0[(slice(None, None, None), -1)] + _stack0 = None + normalize = torch.nn.functional.normalize(getitem, p=2, dim=1) + getitem = None + return (normalize,) + + args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)] + args = [ + rand_strided(sh, st, dt, dev).requires_grad_(rg) + for (sh, st, dt, dev, rg) in args + ] + + opt_foo = torch._dynamo.optimize("aot_inductor_debug")(foo) + with torch.cuda.amp.autocast(enabled=True): + ref = foo(*args)[0] + res = foo(*args)[0] + self.assertEqual(ref.dtype, res.dtype) + + self.assertTrue(same(res, ref)) + + def test_for_loop_graph_break(self): + def inner(x): + return torch.sin(x) + + def fn(x): + for _ in range(100): + inner(x) + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + + def test_for_loop_graph_break_before(self): + # Checks that the backedge is calculated correctly + def inner(x): + return torch.sin(x) + + def fn(x): + torch._dynamo.graph_break() + for _ in range(100): + inner(x) + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 100) + + def test_avoid_dupe_specialization(self): + def f(x, y): + return (x + y) * 1 + + opt_f = torch._dynamo.optimize("aot_eager")(f) + + for b in [True, False]: + x = torch.randn(4, requires_grad=b) + y = torch.randn(4, requires_grad=b) + self.assertEqual(f(x, x), opt_f(x, x)) + self.assertEqual(f(x, y), opt_f(x, y)) + + def test_while_loop_graph_break(self): + # Repro of tacotron2 cache_size_recompilation + def inner(x): + return torch.sin(x) + + def fn(x): + i = 20 + while i > 10: + x = inner(x) + i -= 1 + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, "First dim need to be 3" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + cnt = torch._dynamo.testing.CompileCounter() + + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + self.assertEqual(cnt.op_count, 6) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_not_rewrite_assert_for_other_errors(self): + def f(x): + b = x.sin() + if not x.sum() <= 3: + raise ValueError("input sum needs to be 3") + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + opt_fn = torch._dynamo.optimize("eager")(f) + with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): + opt_fn(*args) + + # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement. + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_fstring_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, f"First dim need to be {x[0]}" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_without_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_noop(self): + def f(x): + b = x.sin() + assert True + assert x.dtype == torch.float32 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + cnt = torch._dynamo.testing.CompileCounter() + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + # torch._assert shouldn't be in the graph + self.assertEqual(cnt.op_count, 3) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False) + def test_not_rewrite_assert(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + + @patch.object(torch._functorch.config, "use_dynamic_shapes", True) + def test_batchnorm_e2e(self): + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d( + 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True + ) + self.conv1 = torch.nn.Conv2d( + 64, + 64, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + def forward(self, x): + x1 = self.bn(x) + x2 = self.conv1(x1) + out = torch.nn.functional.relu(x2) + return (out,) + + torch.manual_seed(1337) + + m_ref = Repro() + m_test = deepcopy(m_ref) + + @torch._dynamo.optimize("aot_inductor_debug") + def compiled_fn(x): + return m_test(x) + + x_ref = torch.randn(2, 64, 32, 32, requires_grad=True) + x_test = x_ref.clone() + + # Loop multiple times: each iteration the running_mean/var on batchnorm will update, + # which changes the output of the next iteration + for _ in range(3): + ref = m_ref(x_ref) + res = compiled_fn(x_test) + + self.assertTrue(same(ref, res)) + + for r in ref: + if r.requires_grad: + r.sum().backward() + for r in res: + if r.requires_grad: + r.sum().backward() + + for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()): + self.assertTrue(same(param_ref, param_test)) + # Assert running_mean/var + for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()): + self.assertTrue(same(buffer_ref, buffer_test)) + + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + def test_dynamic_shapes_right_side(self): + def f(x): + return torch.ones(5 * x.shape[0]) + + inp = torch.randn(6, 5) + + gm, _ = torch._dynamo.export( + f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic" + ) + self.assertEqual(gm(inp).shape, f(inp).shape) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 3a38561f16d2a..27f73026435cd 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -367,6 +367,18 @@ def fn(a, b): # just one graph now rather than 10 self.assertEqual(cnt_dynamic.frame_count, 1) + def test_dynamic_kwarg(self): + def fn(a, b): + return a - b * 10 + + torch._dynamo.reset() + cnt_dynamic = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) + for i in range(10): + opt_fn(torch.randn(i), torch.randn(i)) + # just one graph + self.assertEqual(cnt_dynamic.frame_count, 1) + @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) def test_no_graph_break_on_item(self): def fn(a, b): diff --git a/test/dynamo/test_torchxla_integration.py b/test/dynamo/test_torchxla_integration.py new file mode 100644 index 0000000000000..ecefed93e9370 --- /dev/null +++ b/test/dynamo/test_torchxla_integration.py @@ -0,0 +1,133 @@ +# Owner(s): ["module: dynamo"] +import copy +import unittest + +import torch + +try: + from .test_torchxla_util import maybe_skip_torchxla_test +except ImportError: + from test_torchxla_util import maybe_skip_torchxla_test + +try: + import torch._dynamo.optimizations.torchxla_integration as integration + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as metrics +except ImportError: + # tests using torch_xla will be skipped. It's fine to ignore the + # importing error here. + pass + +from torch import fx, nn + + +class BasicModule(nn.Module): + def __init__(self): + super(BasicModule, self).__init__() + + def forward(self, x, y): + return x + y + + def get_random_inputs(self): + return (torch.randn(10), torch.randn(10)) + + +class MatmulModule(nn.Module): + def __init__(self): + super(MatmulModule, self).__init__() + + def forward(self, x, y): + return x @ y + + def get_random_inputs(self): + return (torch.randn(5, 100), torch.randn(100, 5)) + + +class LinearModule(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + def get_random_inputs(self): + return (torch.randn(10),) + + +class ModuleInplaceUpdate(nn.Module): + def __init__(self): + super(ModuleInplaceUpdate, self).__init__() + + def forward(self, a, b): + a.sub_(b) + return b - 1, b + 1 + + def get_random_inputs(self): + return (torch.randn(10), torch.randn(10)) + + +def allclose(expected, actual): + def unwrap(cont): + if isinstance(cont, (list, tuple)) and len(cont) == 1: + return cont[0] + return cont + + expected = unwrap(expected) + actual = unwrap(actual) + + if isinstance(expected, torch.Tensor) and isinstance(actual, torch.Tensor): + return torch.allclose(expected, actual) + elif isinstance(expected, (tuple, list)) and isinstance(actual, (tuple, list)): + return len(expected) == len(actual) and all( + torch.allclose(a, b) for a, b in zip(expected, actual) + ) + else: + raise RuntimeError("Unexpected types") + + +def make_reuse_graph_test(module_class, niter=100): + @maybe_skip_torchxla_test + def test_wrapper(self): + xla_dev = xm.xla_device() + xla_module = module_class().to(device=xla_dev) + inputs = tuple(x.to(device=xla_dev) for x in xla_module.get_random_inputs()) + metrics.clear_counters() + optimized_mod = integration.extract_compiled_graph( + fx.symbolic_trace(xla_module), inputs + ) + + for i in range(niter): + xla_inputs = tuple( + inp.to(device=xla_dev) for inp in xla_module.get_random_inputs() + ) + xla_inputs_copy = copy.deepcopy(xla_inputs) + + expected = xla_module(*xla_inputs) + # make sure above lazy computation is executed. + xm.mark_step() + + actual = optimized_mod(*xla_inputs_copy) + + if not allclose(expected, actual): + print( + f"Incorrect results at iter {i}. expected\n{expected}, actual\n{actual}" + ) + self.assertTrue(False) + + # make sure arguments match after calling the model forward method + # to handle inplace updates. + if not allclose(xla_inputs, xla_inputs_copy): + print( + f"Incorrect updated arguments at iter {i}. expected\n{xla_inputs}, actual\n{xla_inputs_copy}" + ) + self.assertTrue(False) + + return test_wrapper + + +class TorchXLAReuseGraphTest(unittest.TestCase): + test_basic = make_reuse_graph_test(BasicModule) + test_matmul = make_reuse_graph_test(MatmulModule) + test_linear = make_reuse_graph_test(LinearModule) + test_inplace_update = make_reuse_graph_test(ModuleInplaceUpdate) diff --git a/test/dynamo/test_torchxla_num_output.py b/test/dynamo/test_torchxla_num_output.py new file mode 100644 index 0000000000000..0e91a358d4690 --- /dev/null +++ b/test/dynamo/test_torchxla_num_output.py @@ -0,0 +1,120 @@ +# Owner(s): ["module: dynamo"] +import unittest + +import torch +from torch import nn +from torch._dynamo.optimizations.torchxla_integration import GraphInputMatcher +from torch.utils._pytree import tree_map_only + +try: + from .test_torchxla_util import maybe_skip_torchxla_test +except ImportError: + from test_torchxla_util import maybe_skip_torchxla_test + +try: + import torch_xla + import torch_xla.core.xla_model as xm +except ImportError: + # tests using torch_xla will be skipped. It's fine to ignore the + # importing error here. + pass + + +class DirectReturnModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + """ + The XLA graph will only return the first 2 items + """ + return a + b, a + c, b + + def get_example_inputs(self): + return (torch.rand(2), torch.rand(2), torch.rand(2)) + + +class DirectReturnWithInplaceUpdateModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + """ + Inplace update on b cause it to be returned in XLA graph + """ + b.zero_() + return a + b, a + c, b + + def get_example_inputs(self): + return (torch.rand(2), torch.rand(2), torch.rand(2)) + + +class DirectReturnWithDuplicatedInplaceUpdateModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c): + """ + Even if we return b twice, the XLA graph only return b once. + """ + b.zero_() + return a + b, a + c, b, b + + def get_example_inputs(self): + return (torch.rand(2), torch.rand(2), torch.rand(2)) + + +class TestNumOutput(unittest.TestCase): + def do_test(self, model_class, expected_num_output): + xla_dev = xm.xla_device() + model = model_class().to(device=xla_dev) + inputs = tree_map_only( + torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs() + ) + + xm.mark_step() + args_tensor_ids = [ + torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs + ] + tensor_id_to_arg_idx = { + tensor_id: i for i, tensor_id in enumerate(args_tensor_ids) + } + outputs = model(*inputs) + xla_graph_hash = torch_xla._XLAC._get_graph_hash(outputs) + + ( + graph_input_tensor_ids, + graph_input_xla_values, + ) = torch_xla._XLAC._get_tensors_xla_device_data_node(outputs) + + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_xla_values + ) + torch_xla._XLAC._xla_sync_multi(outputs, []) + + def run_cached_graph(*inputs): + torch_xla._XLAC._xla_sync_multi(inputs, []) + xla_graph_inputs = graph_input_matcher(inputs) + xla_graph_outputs = torch_xla._XLAC._run_cached_graph( + xla_graph_hash, xla_graph_inputs + ) + return xla_graph_outputs + + test_inputs = tree_map_only( + torch.Tensor, lambda x: x.to(device=xla_dev), model.get_example_inputs() + ) + self.assertEqual(expected_num_output, len(run_cached_graph(*test_inputs))) + + @maybe_skip_torchxla_test + def test_direct_return(self): + self.do_test(DirectReturnModule, expected_num_output=2) + + @maybe_skip_torchxla_test + def test_direct_return_with_inplace_update(self): + self.do_test(DirectReturnWithInplaceUpdateModule, expected_num_output=3) + + @maybe_skip_torchxla_test + def test_direct_return_with_duplicated_inplace_update(self): + self.do_test( + DirectReturnWithDuplicatedInplaceUpdateModule, expected_num_output=3 + ) diff --git a/test/dynamo/test_torchxla_util.py b/test/dynamo/test_torchxla_util.py new file mode 100644 index 0000000000000..5c54af34678a6 --- /dev/null +++ b/test/dynamo/test_torchxla_util.py @@ -0,0 +1,26 @@ +# Owner(s): ["module: dynamo"] +import functools +import os +import unittest + + +@functools.lru_cache(None) +def should_run_torchxla_tests(): + """ + Run the tests if torch_xla is available and number of gpu devices is specified. + """ + has_torch_xla = True + try: + import torch_xla # noqa: F401 + except ImportError: + has_torch_xla = False + + gpu_device_specified = int(os.environ.get("GPU_NUM_DEVICES", "0")) > 0 + return has_torch_xla and gpu_device_specified + + +def maybe_skip_torchxla_test(test_case): + return unittest.skipIf( + not should_run_torchxla_tests(), + "Skip the tests since torch_xla is not available or XLA devices are not specified", + )(test_case) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index fbf3983661935..7ffed902fd9dc 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -137,7 +137,7 @@ def fn(x): res2 = opt_fn(x) self.assertTrue(same(res1, res2)) - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) def test_multiple_consecutive_random_calls_before_graph(self): def fn(x): dim1 = random.randrange(start=0, stop=5) @@ -171,6 +171,8 @@ def fn(x): res2 = opt_fn(x) self.assertTrue(same(res1, res2)) + # TypeError: zeros(): argument 'size' (position 1) must be tuple of SymInts, not FakeTensor + @unittest.expectedFailure def test_builtin_getitem(self): # builtin getitem args[0] is python list and args[1] is unspec def fn(x, idx): diff --git a/test/dynamo/test_verify_correctness.py b/test/dynamo/test_verify_correctness.py index 8e3624bfd9e7d..7a6f8e3d42639 100644 --- a/test/dynamo/test_verify_correctness.py +++ b/test/dynamo/test_verify_correctness.py @@ -100,8 +100,11 @@ def compiler_fn(graph, example_inputs): r3 = opt_fn(a, (b, c), d) self.assertIsNotNone(r1) - self.assertTrue(same(r1, r2)) - self.assertTrue(same(r1, r3)) + + self.assertEqual(r1.shape, r2.shape) + self.assertEqual(r1.shape, r3.shape) + self.assertEqual(r1.device, r2.device) + self.assertEqual(r1.device, r3.device) @patch.object(config, "verify_correctness", True) def test_nnc(self): diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 01ac1efffd29e..7bdd777ad4512 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -53,7 +53,7 @@ torch.fx.node.Node.__init__(self, graph: 'Graph', name: str, op: str, target: 'T torch.fx.node.Node.append(self, x: 'Node') -> None torch.fx.node.Node.format_node(self, placeholder_names: Optional[List[str]] = None, maybe_return_typename: Optional[List[str]] = None) -> Optional[str] torch.fx.node.Node.prepend(self, x: 'Node') -> None -torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >) -> List[Node] +torch.fx.node.Node.replace_all_uses_with(self, replace_with: 'Node', delete_user_cb: Callable[[Node], bool] = >, propagate_meta = False) -> List[Node] torch.fx.node.Node.replace_input_with(self, old_input: 'Node', new_input: 'Node') torch.fx.node.Node.update_arg(self, idx: int, arg: torch.fx.node.Argument) -> None torch.fx.node.Node.update_kwarg(self, key: str, arg: torch.fx.node.Argument) -> None @@ -71,4 +71,4 @@ torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy' torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool -torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Callable, replacement: Callable) -> List[torch.fx.subgraph_rewriter.Match] +torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable, torch.fx.graph_module.GraphModule], replacement: Union[Callable, torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match] diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect index 696fcbb08cf12..7c0cccd56cd1d 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseBSC_cpu.expect @@ -1,6979 +1,3583 @@ -########## torch.float32/torch.int32/size=()+(3, 4)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 2., 12.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 3., 13.]], - - [[ 4., 14.]]]), size=(3, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]) - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], dtype=torch.int32) -# _row_indices -tensor([], dtype=torch.int32) -# _values -tensor([], size=(0, 1, 2)) - -########## torch.float32/torch.int32/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), size=(2, 6, 2), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]) - -########## torch.float32/torch.int32/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]) - - -########## torch.float64/torch.int32/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), size=(3, 4), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], dtype=torch.int32) -# _row_indices -tensor([], dtype=torch.int32) -# _values -tensor([], size=(0, 1, 2), dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), size=(2, 6, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), size=(3, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]) - -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0]) -# _row_indices -tensor([], dtype=torch.int64) -# _values -tensor([], size=(0, 1, 2)) - -########## torch.float32/torch.int64/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), size=(2, 6, 2), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]]) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]) - -########## torch.float32/torch.int64/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]) - - -########## torch.float64/torch.int64/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), size=(3, 4), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0]) -# _row_indices -tensor([], dtype=torch.int64) -# _values -tensor([], size=(0, 1, 2), dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), size=(2, 6, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]]) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), size=(2, 3, 9, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]) - -########## torch.float32/torch.int32/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), - nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]) - -########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]]) - - -########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], + [[0.], + [4.]]], + [[[1.], + [4.]], - [[[[ 4.], - [104.]], + [[2.], + [0.]], - [[ 14.], - [114.]], + [[3.], + [0.]]], - [[ 24.], - [124.]]], + [[[1.], + [2.]], - [[[ 5.], - [105.]], + [[0.], + [3.]], - [[ 15.], - [115.]], + [[0.], + [4.]]]], - [[ 25.], - [125.]]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[[[ 5.], - [105.]], + [[0.], + [4.]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[ 6.], - [106.]], + [[2.], + [0.]]], - [[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[1.], + [0.]], + [[2.], + [4.]], + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=torch.int32) +# _values +tensor([[[[[1.], + [3.]], - [[[ 9.], - [109.]], + [[2.], + [0.]], - [[ 19.], - [119.]], + [[0.], + [4.]]], - [[ 29.], - [129.]]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]], + [[[1.], + [2.]], - [[ 29.], - [129.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[1.], + [3.]], - [[ 30.], - [130.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[1.], + [0.]], - [[ 31.], - [131.]]]], + [[2.], + [4.]], + [[3.], + [0.]]]]]) +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[[[ 11.], - [111.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 21.], - [121.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 31.], - [131.]]], + [[ 0., 0., 0.], + [20., 21., 22.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[ 12.], - [112.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 22.], - [122.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], dtype=torch.int32) +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 32.], - [132.]]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[[[ 12.], - [112.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 22.], - [122.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 32.], - [132.]]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]) - [[[ 13.], - [113.]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 23.], - [123.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 33.], - [133.]]]]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], + [[0.], + [4.]]], + [[[1.], + [4.]], - [[[[[[ 13.], - [113.]], + [[2.], + [0.]], - [[ 23.], - [123.]], + [[3.], + [0.]]], - [[ 33.], - [133.]]], + [[[1.], + [2.]], - [[[ 14.], - [114.]], + [[0.], + [3.]], - [[ 24.], - [124.]], + [[0.], + [4.]]]], - [[ 34.], - [134.]]]], + [[[[0.], + [2.]], - [[[[ 14.], - [114.]], + [[1.], + [3.]], - [[ 24.], - [124.]], + [[0.], + [4.]]], - [[ 34.], - [134.]]], + [[[1.], + [3.]], - [[[ 15.], - [115.]], + [[0.], + [4.]], - [[ 25.], - [125.]], + [[2.], + [0.]]], - [[ 35.], - [135.]]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[[ 15.], - [115.]], + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 25.], - [125.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 35.], - [135.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=torch.int32) +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 16.], - [116.]], + [[0.], + [4.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[ 16.], - [116.]], - [[ 26.], - [126.]], + [[[1.], + [2.]], - [[ 36.], - [136.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[[ 17.], - [117.]], + [[[1.], + [3.]], - [[ 27.], - [127.]], + [[0.], + [4.]], - [[ 37.], - [137.]]], + [[2.], + [0.]]], - [[[ 18.], - [118.]], + [[[1.], + [0.]], - [[ 28.], - [128.]], + [[2.], + [4.]], - [[ 38.], - [138.]]]], + [[3.], + [0.]]]]], dtype=torch.float64) +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[[[ 18.], - [118.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 28.], - [128.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 38.], - [138.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 19.], - [119.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], dtype=torch.int32) +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 29.], - [129.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 39.], - [139.]]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[[ 19.], - [119.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 29.], - [129.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], dtype=torch.float64) - [[ 39.], - [139.]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[ 20.], - [120.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 30.], - [130.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], - [[ 40.], - [140.]]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[[ 20.], - [120.]], + [[[1.], + [4.]], - [[ 30.], - [130.]], + [[2.], + [0.]], - [[ 40.], - [140.]]], + [[3.], + [0.]]], - [[[ 21.], - [121.]], + [[[1.], + [2.]], - [[ 31.], - [131.]], + [[0.], + [3.]], - [[ 41.], - [141.]]]]], + [[0.], + [4.]]]], + [[[[0.], + [2.]], - [[[[[ 21.], - [121.]], + [[1.], + [3.]], - [[ 31.], - [131.]], + [[0.], + [4.]]], - [[ 41.], - [141.]]], + [[[1.], + [3.]], - [[[ 22.], - [122.]], + [[0.], + [4.]], - [[ 32.], - [132.]], + [[2.], + [0.]]], - [[ 42.], - [142.]]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[[ 22.], - [122.]], + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 32.], - [132.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 42.], - [142.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]) +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 23.], - [123.]], + [[0.], + [4.]]], - [[ 33.], - [133.]], - [[ 43.], - [143.]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[1.], + [2.]], - [[ 43.], - [143.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]], - [[ 44.], - [144.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]], + [[[1.], + [3.]], - [[ 44.], - [144.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[ 25.], - [125.]], - [[ 35.], - [135.]], + [[[1.], + [0.]], - [[ 45.], - [145.]]]]]]], dtype=torch.float64) + [[2.], + [4.]], + [[3.], + [0.]]]]]) -########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[ 0., 0., 0.], + [20., 21., 22.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, layout=torch.sparse_bsc) # _ccol_indices -tensor([0, 2, 4]) +tensor([0, 4, 7]) # _row_indices -tensor([0, 1, 0, 2]) +tensor([0, 1, 2, 3, 0, 2, 3]) # _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]) - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]) - -########## torch.float32/torch.int64/size=()+(9, 4)+(4, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[[1.], + [4.]], + [[2.], + [0.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[3.], + [0.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[1.], + [2.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [3.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[0.], + [4.]]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[[0.], + [2.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[2.], + [0.]]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), - nnz=4, layout=torch.sparse_bsc) + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]) +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [4.]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[[1.], + [4.]], + [[2.], + [0.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[3.], + [0.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[1.], + [2.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [3.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[0.], + [4.]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[2.], + [0.]]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], + [[[1.], + [0.]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], + [[2.], + [4.]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]) + [[3.], + [0.]]]]], dtype=torch.float64) -########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 11.], - [111.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 21.], - [121.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7]) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3]) +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[[ 2.], - [102.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 12.], - [112.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 22.], - [122.]]], + [[10., 11., 12.], + [15., 16., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], dtype=torch.float64) - [[[ 3.], - [103.]], - [[ 13.], - [113.]], +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 23.], - [123.]]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]], + [[[2., 3., 4., 5.]], - [[ 23.], - [123.]]], + [[0., 0., 0., 0.]]], - [[[ 4.], - [104.]], + [[[0., 0., 0., 0.]], - [[ 14.], - [114.]], + [[4., 5., 6., 7.]]]], - [[ 24.], - [124.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 4.], - [104.]], + [[4., 5., 6., 7.]]], - [[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[3., 4., 5., 6.]], - [[ 25.], - [125.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 5.], - [105.]], + [[2., 3., 4., 5.]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]]]], + [[4., 5., 6., 7.]]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[[[0., 0., 0., 0.]], - [[ 26.], - [126.]]], + [[2., 3., 4., 5.]]], - [[[ 7.], - [107.]], + [[[1., 2., 3., 4.]], - [[ 17.], - [117.]], + [[3., 4., 5., 6.]]], - [[ 27.], - [127.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 7.], - [107.]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], + [[[0., 0., 0., 0.]], - [[ 28.], - [128.]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[ 8.], - [108.]], + [[0., 0., 0., 0.]]]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]], + [[[[1., 2., 3., 4.]], - [[[ 9.], - [109.]], + [[0., 0., 0., 0.]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], + [[[3., 4., 5., 6.]], - [[[[[ 9.], - [109.]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 19.], - [119.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 29.], - [129.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[ 10.], - [110.]], + [[4., 5., 6., 7.]]]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[[1., 2., 3., 4.]], - [[[ 11.], - [111.]], + [[4., 5., 6., 7.]]], - [[ 21.], - [121.]], - [[ 31.], - [131.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 11.], - [111.]], + [[[3., 4., 5., 6.]], - [[ 21.], - [121.]], + [[0., 0., 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 12.], - [112.]], + [[[[1., 2., 3., 4.]], - [[ 22.], - [122.]], + [[2., 3., 4., 5.]]], - [[ 32.], - [132.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 12.], - [112.]], - [[ 22.], - [122.]], + [[[0., 0., 0., 0.]], - [[ 32.], - [132.]]], + [[4., 5., 6., 7.]]]]], - [[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], + [[[0., 0., 0., 0.]], - [[ 33.], - [133.]]], + [[4., 5., 6., 7.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 14.], - [114.]], + [[4., 5., 6., 7.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[ 15.], - [115.]], + [[[2., 3., 4., 5.]], - [[ 25.], - [125.]], + [[4., 5., 6., 7.]]], - [[ 35.], - [135.]]], + [[[3., 4., 5., 6.]], - [[[ 16.], - [116.]], + [[0., 0., 0., 0.]]]]]]) - [[ 26.], - [126.]], +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 36.], - [136.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 26.], - [126.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 36.], - [136.]]], - [[[ 17.], - [117.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 27.], - [127.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 37.], - [137.]]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 17.], - [117.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 28.], - [128.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 38.], - [138.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 18.], - [118.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 29.], - [129.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 39.], - [139.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 19.], - [119.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 30.], - [130.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 40.], - [140.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[ 21.], - [121.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 31.], - [131.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 41.], - [141.]]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 22.], - [122.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 32.], - [132.]], - [[ 42.], - [142.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 22.], - [122.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 32.], - [132.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 42.], - [142.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 23.], - [123.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 33.], - [133.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 43.], - [143.]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[[[ 23.], - [123.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 33.], - [133.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 43.], - [143.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 44.], - [144.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 24.], - [124.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 34.], - [134.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 44.], - [144.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[ 25.], - [125.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 45.], - [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[[[ 1.], - [101.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 11.], - [111.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 2.], - [102.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 12.], - [112.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 22.], - [122.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 2.], - [102.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 3.], - [103.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 13.], - [113.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 23.], - [123.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 3.], - [103.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 13.], - [113.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 23.], - [123.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 4.], - [104.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 14.], - [114.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 24.], - [124.]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 4.], - [104.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 14.], - [114.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 24.], - [124.]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 5.], - [105.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 25.], - [125.]]]]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]) +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[[[[ 5.], - [105.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 15.], - [115.]], + [[3., 4., 5., 6.]]], - [[ 25.], - [125.]]], + [[[2., 3., 4., 5.]], - [[[ 6.], - [106.]], + [[0., 0., 0., 0.]]], - [[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[[1., 2., 3., 4.]], - [[ 26.], - [126.]]], + [[4., 5., 6., 7.]]], - [[[ 7.], - [107.]], + [[[2., 3., 4., 5.]], - [[ 17.], - [117.]], + [[0., 0., 0., 0.]]], - [[ 27.], - [127.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[ 7.], - [107.]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], + [[[0., 0., 0., 0.]], - [[ 28.], - [128.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 8.], - [108.]], + [[4., 5., 6., 7.]]]]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]], - [[[ 9.], - [109.]], + [[[[[0., 0., 0., 0.]], - [[ 19.], - [119.]], + [[2., 3., 4., 5.]]], - [[ 29.], - [129.]]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[[ 9.], - [109.]], + [[[0., 0., 0., 0.]], - [[ 19.], - [119.]], + [[4., 5., 6., 7.]]]], - [[ 29.], - [129.]]], - [[[ 10.], - [110.]], + [[[[1., 2., 3., 4.]], - [[ 20.], - [120.]], + [[3., 4., 5., 6.]]], - [[ 30.], - [130.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]], + [[0., 0., 0., 0.]]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[[1., 2., 3., 4.]], - [[ 31.], - [131.]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], - [[[[ 11.], - [111.]], + [[4., 5., 6., 7.]]], - [[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[ 12.], - [112.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 22.], - [122.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[ 32.], - [132.]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], - [[[[ 12.], - [112.]], + [[0., 0., 0., 0.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], + [[[3., 4., 5., 6.]], - [[ 33.], - [133.]]], + [[0., 0., 0., 0.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]]], + [[2., 3., 4., 5.]]], + [[[0., 0., 0., 0.]], - [[[[ 14.], - [114.]], + [[3., 4., 5., 6.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]], + [[[1., 2., 3., 4.]], - [[ 35.], - [135.]]], + [[3., 4., 5., 6.]]], - [[[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]], + [[4., 5., 6., 7.]]]], - [[ 36.], - [136.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], + [[[2., 3., 4., 5.]], - [[ 37.], - [137.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 17.], - [117.]], + [[0., 0., 0., 0.]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[ 18.], - [118.]], - [[ 28.], - [128.]], + [[[3., 4., 5., 6.]], - [[ 38.], - [138.]]]], + [[0., 0., 0., 0.]]]]]], dtype=torch.float64) +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 18.], - [118.]], - [[ 28.], - [128.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 38.], - [138.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 39.], - [139.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 40.], - [140.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[[ 21.], - [121.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 31.], - [131.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 41.], - [141.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 42.], - [142.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 22.], - [122.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 32.], - [132.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 42.], - [142.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 43.], - [143.]]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[[ 23.], - [123.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 33.], - [133.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 43.], - [143.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 44.], - [144.]]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 24.], - [124.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 44.], - [144.]]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[ 25.], - [125.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 35.], - [135.]], - [[ 45.], - [145.]]]]]]]) + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], -########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), size=(6, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], dtype=torch.float64) -########## torch.float64/torch.int64/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], dtype=torch.float64) - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), size=(9, 4, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], + [[3., 4., 5., 6.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[2., 3., 4., 5.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[0., 0., 0., 0.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[0., 0., 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[4., 5., 6., 7.]]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[[1., 2., 3., 4.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[4., 5., 6., 7.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[2., 3., 4., 5.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[0., 0., 0., 0.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[3., 4., 5., 6.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[0., 0., 0., 0.]]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[[1., 2., 3., 4.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[2., 3., 4., 5.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[0., 0., 0., 0.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[3., 4., 5., 6.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[[0., 0., 0., 0.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[4., 5., 6., 7.]]]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], dtype=torch.float64) + [[[0., 0., 0., 0.]], -########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[4., 5., 6., 7.]]]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]]], + [[3., 4., 5., 6.]]], - [[[ 2.], - [102.]], + [[[0., 0., 0., 0.]], - [[ 12.], - [112.]], + [[4., 5., 6., 7.]]], - [[ 22.], - [122.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]], + [[[2., 3., 4., 5.]], - [[ 23.], - [123.]]]], + [[4., 5., 6., 7.]]], + [[[3., 4., 5., 6.]], - [[[[ 3.], - [103.]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 13.], - [113.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 23.], - [123.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]) +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 4.], - [104.]], - [[ 14.], - [114.]], + [[[2., 3., 4., 5.]], - [[ 24.], - [124.]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[ 4.], - [104.]], + [[4., 5., 6., 7.]]]], - [[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[[1., 2., 3., 4.]], - [[[ 5.], - [105.]], + [[4., 5., 6., 7.]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], + [[[3., 4., 5., 6.]], - [[[[[ 5.], - [105.]], + [[0., 0., 0., 0.]]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[1., 2., 3., 4.]], - [[[ 6.], - [106.]], + [[2., 3., 4., 5.]]], - [[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 6.], - [106.]], + [[[0., 0., 0., 0.]], - [[ 16.], - [116.]], + [[4., 5., 6., 7.]]]]], - [[ 26.], - [126.]]], - [[[ 7.], - [107.]], - [[ 17.], - [117.]], + [[[[[0., 0., 0., 0.]], - [[ 27.], - [127.]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[ 7.], - [107.]], + [[3., 4., 5., 6.]]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 8.], - [108.]], + [[[0., 0., 0., 0.]], - [[ 18.], - [118.]], + [[4., 5., 6., 7.]]], - [[ 28.], - [128.]]], + [[[2., 3., 4., 5.]], - [[[ 9.], - [109.]], + [[0., 0., 0., 0.]]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[[ 9.], - [109.]], + [[[2., 3., 4., 5.]], - [[ 19.], - [119.]], + [[4., 5., 6., 7.]]], - [[ 29.], - [129.]]], + [[[3., 4., 5., 6.]], - [[[ 10.], - [110.]], + [[0., 0., 0., 0.]]]]]]) - [[ 20.], - [120.]], +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 30.], - [130.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 10.], - [110.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 30.], - [130.]]], - [[[ 11.], - [111.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 21.], - [121.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 31.], - [131.]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 11.], - [111.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 12.], - [112.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 32.], - [132.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 32.], - [132.]]], - [[[ 13.], - [113.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 14.], - [114.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[ 15.], - [115.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 16.], - [116.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7]) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3]) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 26.], - [126.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 36.], - [136.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 17.], - [117.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 27.], - [127.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 37.], - [137.]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[[[ 17.], - [117.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 19.], - [119.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[ 20.], - [120.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 20.], - [120.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 41.], - [141.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]], - [[ 42.], - [142.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 22.], - [122.]], - [[ 32.], - [132.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 42.], - [142.]]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]) - [[[ 23.], - [123.]], - [[ 33.], - [133.]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 43.], - [143.]]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[2., 3., 4., 5.]], - [[ 43.], - [143.]]], + [[0., 0., 0., 0.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]], + [[4., 5., 6., 7.]]]], - [[ 44.], - [144.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 24.], - [124.]], + [[4., 5., 6., 7.]]], - [[ 34.], - [134.]], - [[ 44.], - [144.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[ 25.], - [125.]], - [[ 35.], - [135.]], + [[[3., 4., 5., 6.]], - [[ 45.], - [145.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[0., 0., 0., 0.]]]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]]], + [[2., 3., 4., 5.]]], - [[[ 2.], - [102.]], + [[[0., 0., 0., 0.]], - [[ 12.], - [112.]], + [[3., 4., 5., 6.]]], - [[ 22.], - [122.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[[0., 0., 0., 0.]], - [[[ 3.], - [103.]], + [[2., 3., 4., 5.]]], - [[ 13.], - [113.]], - [[ 23.], - [123.]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 3.], - [103.]], + [[[0., 0., 0., 0.]], - [[ 13.], - [113.]], + [[4., 5., 6., 7.]]]], - [[ 23.], - [123.]]], - [[[ 4.], - [104.]], + [[[[1., 2., 3., 4.]], - [[ 14.], - [114.]], + [[3., 4., 5., 6.]]], - [[ 24.], - [124.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[ 4.], - [104.]], - [[ 14.], - [114.]], + [[[2., 3., 4., 5.]], - [[ 24.], - [124.]]], + [[0., 0., 0., 0.]]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[[1., 2., 3., 4.]], - [[ 25.], - [125.]]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[3., 4., 5., 6.]], - [[ 25.], - [125.]]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[[ 6.], - [106.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]]], + [[4., 5., 6., 7.]]]], - [[[ 7.], - [107.]], - [[ 17.], - [117.]], + [[[[1., 2., 3., 4.]], - [[ 27.], - [127.]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[ 7.], - [107.]], + [[0., 0., 0., 0.]]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[[ 8.], - [108.]], + [[[0., 0., 0., 0.]], - [[ 18.], - [118.]], + [[3., 4., 5., 6.]]], - [[ 28.], - [128.]]], + [[[0., 0., 0., 0.]], - [[[ 9.], - [109.]], + [[4., 5., 6., 7.]]]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]], + [[[1., 2., 3., 4.]], - [[ 29.], - [129.]]], + [[3., 4., 5., 6.]]], - [[[ 10.], - [110.]], + [[[0., 0., 0., 0.]], - [[ 20.], - [120.]], + [[4., 5., 6., 7.]]]], - [[ 30.], - [130.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 10.], - [110.]], + [[3., 4., 5., 6.]]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[2., 3., 4., 5.]], - [[ 31.], - [131.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]], + [[0., 0., 0., 0.]]], - [[ 31.], - [131.]]], + [[[2., 3., 4., 5.]], - [[[ 12.], - [112.]], + [[4., 5., 6., 7.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]], dtype=torch.float64) +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[[[ 13.], - [113.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 14.], - [114.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 15.], - [115.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 25.], - [125.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 35.], - [135.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 36.], - [136.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 16.], - [116.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 26.], - [126.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 36.], - [136.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 37.], - [137.]]]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[[ 17.], - [117.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 27.], - [127.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[ 37.], - [137.]]], - [[[ 18.], - [118.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 38.], - [138.]]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 18.], - [118.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 28.], - [128.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7]) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3]) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 38.], - [138.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 39.], - [139.]]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 39.], - [139.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 40.], - [140.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 40.], - [140.]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 41.], - [141.]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[[[ 21.], - [121.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 41.], - [141.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 22.], - [122.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 32.], - [132.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 42.], - [142.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 22.], - [122.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 32.], - [132.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 42.], - [142.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 23.], - [123.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 43.], - [143.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 23.], - [123.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 33.], - [133.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 43.], - [143.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 24.], - [124.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 34.], - [134.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 44.], - [144.]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 24.], - [124.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 44.], - [144.]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 25.], - [125.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 45.], - [145.]]]]]]], dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect index 267056b76e678..8fe3223332bb5 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseBSR_cpu.expect @@ -1,6945 +1,3583 @@ -########## torch.float32/torch.int32/size=()+(4, 3)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[2.], - [3.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[3.], - [4.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[4.], - [5.]]]), size=(4, 3), nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]]) - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0], dtype=torch.int32) -# _col_indices -tensor([], dtype=torch.int32) -# _values -tensor([], size=(0, 2, 1)) - -########## torch.float32/torch.int32/size=(2,)+(2, 6)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], - - - [[[ 5., 15.]], - - [[ 6., 16.]], - - [[ 7., 17.]], - - [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], - - - [[[ 5., 15.]], - - [[ 6., 16.]], + [[2.], + [0.]], - [[ 7., 17.]], + [[0.], + [4.]]], - [[ 8., 18.]]]]) -########## torch.float32/torch.int32/size=(2, 3)+(4, 9)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[[1.], + [4.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[2.], + [0.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[3.], + [0.]]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [2.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[0.], + [3.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[0.], + [4.]]]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [4.]]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[[1.], + [3.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[17., 27., 37.], - [18., 28., 38.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[[1.], + [0.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[2.], + [4.]], - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[0.], + [4.]]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[[1.], + [4.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[[1.], + [2.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[0.], + [3.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[0.], + [4.]]]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[[[0.], + [2.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[1.], + [3.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[0.], + [4.]]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [3.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [4.]], - [[18., 28., 38.], - [19., 29., 39.]], - - [[19., 29., 39.], - [20., 30., 40.]], - - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]) - - -########## torch.float64/torch.int32/size=()+(4, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]]), size=(4, 3), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[1.], - [2.]], + [[2.], + [0.]]], - [[2.], - [3.]], - [[3.], - [4.]], + [[[1.], + [0.]], - [[4.], - [5.]]], dtype=torch.float64) + [[2.], + [4.]], -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0], dtype=torch.int32) -# _col_indices -tensor([], dtype=torch.int32) -# _values -tensor([], size=(0, 2, 1), dtype=torch.float64) + [[3.], + [0.]]]]]) -########## torch.float64/torch.int32/size=(2,)+(2, 6)+() ########## +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 4., 14.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[[ 5., 15.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 6., 16.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 7., 17.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) +tensor([0, 2, 3, 5, 7], dtype=torch.int32) # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) +tensor([0, 1, 0, 0, 1, 0, 1], dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 2., 12.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 3., 13.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 4., 14.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 5., 15.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 6., 16.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]) - [[ 7., 17.]], - [[ 8., 18.]]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(4, 9)+() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[2.], + [0.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[0.], + [4.]]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [4.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[2.], + [0.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[3.], + [0.]]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[1.], + [2.]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[[13., 23., 33.], - [14., 24., 34.]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[[[0.], + [2.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[17., 27., 37.], - [18., 28., 38.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[[1.], + [3.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[0.], + [4.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[2.], + [0.]]], - [[[21., 31., 41.], - [22., 32., 42.]], + [[[1.], + [0.]], - [[22., 32., 42.], - [23., 33., 43.]], + [[2.], + [4.]], - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], dtype=torch.int32) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - - [[ 6., 16., 26.], - [ 7., 17., 27.]], - - [[ 7., 17., 27.], - [ 8., 18., 28.]], - - [[ 8., 18., 28.], - [ 9., 19., 29.]]], - - - [[[ 9., 19., 29.], - [10., 20., 30.]], - - [[10., 20., 30.], - [11., 21., 31.]], - - [[11., 21., 31.], - [12., 22., 32.]], - - [[12., 22., 32.], - [13., 23., 33.]]]], - - +tensor([[[[[1.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[2.], + [0.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[0.], + [4.]]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [4.]], + [[2.], + [0.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[3.], + [0.]]], - [[18., 28., 38.], - [19., 29., 39.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[[1.], + [2.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[21., 31., 41.], - [22., 32., 42.]], - [[22., 32., 42.], - [23., 33., 43.]], - [[23., 33., 43.], - [24., 34., 44.]], + [[[[0.], + [2.]], - [[24., 34., 44.], - [25., 35., 45.]]]]], dtype=torch.float64) + [[1.], + [3.]], + [[0.], + [4.]]], -########## torch.float32/torch.int64/size=()+(4, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], - [[2.], - [3.]], + [[[1.], + [3.]], - [[3.], - [4.]], + [[0.], + [4.]], - [[4.], - [5.]]]), size=(4, 3), nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[1.], - [2.]], + [[2.], + [0.]]], - [[2.], - [3.]], - [[3.], - [4.]], + [[[1.], + [0.]], - [[4.], - [5.]]]) + [[2.], + [4.]], -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0]) -# _col_indices -tensor([], dtype=torch.int64) -# _values -tensor([], size=(0, 2, 1)) + [[3.], + [0.]]]]], dtype=torch.float64) -########## torch.float32/torch.int64/size=(2,)+(2, 6)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], - - [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, - layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]]) +tensor([0, 2, 3, 5, 7], dtype=torch.int32) # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) +tensor([0, 1, 0, 0, 1, 0, 1], dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], dtype=torch.float64) - [[ 8., 18.]]]]) -########## torch.float32/torch.int64/size=(2, 3)+(4, 9)+() ########## +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[ 2., 12., 22.], - [ 3., 13., 23.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 3., 13., 23.], - [ 4., 14., 24.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[[1.], + [4.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[2.], + [0.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[3.], + [0.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[[1.], + [2.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[0.], + [3.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[0.], + [4.]]]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[1.], + [3.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[0.], + [4.]]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[2.], + [0.]]], - [[18., 28., 38.], - [19., 29., 39.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[[1.], + [0.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[2.], + [4.]], - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) + [[0, 3], + [0, 3], + [0, 3]]]) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - - [[ 6., 16., 26.], - [ 7., 17., 27.]], - - [[ 7., 17., 27.], - [ 8., 18., 28.]], - - [[ 8., 18., 28.], - [ 9., 19., 29.]]], - - - [[[ 9., 19., 29.], - [10., 20., 30.]], - - [[10., 20., 30.], - [11., 21., 31.]], - - [[11., 21., 31.], - [12., 22., 32.]], +tensor([[[[[1.], + [3.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[[1.], + [4.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[2.], + [0.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[3.], + [0.]]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [2.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [3.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[0.], + [4.]]]], - [[19., 29., 39.], - [20., 30., 40.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[[[0.], + [2.]], - [[[21., 31., 41.], - [22., 32., 42.]], + [[1.], + [3.]], - [[22., 32., 42.], - [23., 33., 43.]], + [[0.], + [4.]]], - [[23., 33., 43.], - [24., 34., 44.]], - [[24., 34., 44.], - [25., 35., 45.]]]]]) + [[[1.], + [3.]], + [[0.], + [4.]], -########## torch.float64/torch.int64/size=()+(4, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]]), size=(4, 3), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[1.], - [2.]], + [[2.], + [0.]]], - [[2.], - [3.]], - [[3.], - [4.]], + [[[1.], + [0.]], - [[4.], - [5.]]], dtype=torch.float64) + [[2.], + [4.]], -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0]) -# _col_indices -tensor([], dtype=torch.int64) -# _values -tensor([], size=(0, 2, 1), dtype=torch.float64) + [[3.], + [0.]]]]]) -########## torch.float64/torch.int64/size=(2,)+(2, 6)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], - - [[ 8., 18.]]]]), size=(2, 2, 6), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]]) +tensor([0, 2, 3, 5, 7]) # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) +tensor([0, 1, 0, 0, 1, 0, 1]) # _values -tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]) - [[ 8., 18.]]]], dtype=torch.float64) -########## torch.float64/torch.int64/size=(2, 3)+(4, 9)+() ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[2.], + [0.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[0.], + [4.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[[1.], + [4.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[2.], + [0.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[3.], + [0.]]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[1.], + [2.]], + [[0.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [4.]]]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [4.]]], - [[18., 28., 38.], - [19., 29., 39.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[[1.], + [3.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[21., 31., 41.], - [22., 32., 42.]], - [[22., 32., 42.], - [23., 33., 43.]], + [[[1.], + [0.]], - [[23., 33., 43.], - [24., 34., 44.]], + [[2.], + [4.]], - [[24., 34., 44.], - [25., 35., 45.]]]]]), size=(2, 3, 4, 9), nnz=4, + [[3.], + [0.]]]]]), size=(2, 3, 2, 3), nnz=3, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) + [[0, 3], + [0, 3], + [0, 3]]]) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], +tensor([[[[[1.], + [3.]], - [[ 2., 12., 22.], - [ 3., 13., 23.]], + [[2.], + [0.]], - [[ 3., 13., 23.], - [ 4., 14., 24.]], + [[0.], + [4.]]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[[1.], + [4.]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[2.], + [0.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[3.], + [0.]]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [2.]], + [[0.], + [3.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[0.], + [4.]]]], - [[10., 20., 30.], - [11., 21., 31.]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[13., 23., 33.], - [14., 24., 34.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[[1.], + [3.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[0.], + [4.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[2.], + [0.]]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[[1.], + [0.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[2.], + [4.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[3.], + [0.]]]]], dtype=torch.float64) - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 103.], - [ 13., 113.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[[ 2., 102.], - [ 12., 112.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 3., 103.], - [ 13., 113.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 4., 104.], - [ 14., 114.]]], + [[ 0., 0., 0.], + [20., 21., 22.]], - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, - layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([0, 2, 3, 5, 7]) # _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([0, 1, 0, 0, 1, 0, 1]) # _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[[ 3., 103.], - [ 13., 113.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 104.], - [ 14., 114.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 5., 105.], - [ 15., 115.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 4., 104.], - [ 14., 114.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 5., 105.], - [ 15., 115.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], dtype=torch.float64) - [[ 6., 106.], - [ 16., 116.]]]]) -########## torch.float32/torch.int32/size=()+(4, 9)+(4, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[[2., 3., 4., 5.]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[4., 5., 6., 7.]]]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[[1., 2., 3., 4.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[4., 5., 6., 7.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[3., 4., 5., 6.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[0., 0., 0., 0.]]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[[1., 2., 3., 4.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[2., 3., 4., 5.]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[[0., 0., 0., 0.]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[4., 5., 6., 7.]]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), - nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[[[[0., 0., 0., 0.]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[3., 4., 5., 6.]]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[0., 0., 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[4., 5., 6., 7.]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[[2., 3., 4., 5.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0., 0., 0., 0.]]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[[2., 3., 4., 5.]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[4., 5., 6., 7.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[[3., 4., 5., 6.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]) - -########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[ 11.], - [111.]]], + [[0, 3], + [0, 3], + [0, 3]]], dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[[ 2.], - [102.]], + [[3., 4., 5., 6.]]], - [[ 12.], - [112.]]], + [[[2., 3., 4., 5.]], - [[[ 3.], - [103.]], + [[0., 0., 0., 0.]]], - [[ 13.], - [113.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]]], + [[[[1., 2., 3., 4.]], - [[[ 3.], - [103.]], + [[4., 5., 6., 7.]]], - [[ 13.], - [113.]]], + [[[2., 3., 4., 5.]], - [[[ 4.], - [104.]], + [[0., 0., 0., 0.]]], - [[ 14.], - [114.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[[1., 2., 3., 4.]], - [[[ 4.], - [104.]], + [[2., 3., 4., 5.]]], - [[ 14.], - [114.]]], + [[[0., 0., 0., 0.]], - [[[ 5.], - [105.]], + [[3., 4., 5., 6.]]], - [[ 15.], - [115.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[[ 4.], - [104.]], - [[ 14.], - [114.]]], - [[[ 5.], - [105.]], + [[[[[0., 0., 0., 0.]], - [[ 15.], - [115.]]], + [[2., 3., 4., 5.]]], - [[[ 6.], - [106.]], + [[[1., 2., 3., 4.]], - [[ 16.], - [116.]]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[[ 5.], - [105.]], - [[ 15.], - [115.]]], + [[[[1., 2., 3., 4.]], - [[[ 6.], - [106.]], + [[3., 4., 5., 6.]]], - [[ 16.], - [116.]]], + [[[0., 0., 0., 0.]], - [[[ 7.], - [107.]], + [[4., 5., 6., 7.]]], - [[ 17.], - [117.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]]], + [[[[1., 2., 3., 4.]], - [[[ 7.], - [107.]], + [[0., 0., 0., 0.]]], - [[ 17.], - [117.]]], + [[[2., 3., 4., 5.]], - [[[ 8.], - [108.]], + [[4., 5., 6., 7.]]], - [[ 18.], - [118.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]) - [[[[ 7.], - [107.]], +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 17.], - [117.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 9.], - [109.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 19.], - [119.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[ 8.], - [108.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 18.], - [118.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 9.], - [109.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 19.], - [119.]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 20.], - [120.]]], - [[[ 11.], - [111.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 10.], - [110.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 20.], - [120.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[ 12.], - [112.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 22.], - [122.]]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[ 11.], - [111.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 21.], - [121.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 23.], - [123.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 12.], - [112.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 22.], - [122.]]], - [[[ 13.], - [113.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]]]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1], dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 14.], - [114.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 24.], - [124.]]], - [[[ 15.], - [115.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 25.], - [125.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 14.], - [114.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 24.], - [124.]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[[ 16.], - [116.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 26.], - [126.]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 15.], - [115.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 25.], - [125.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[ 17.], - [117.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 27.], - [127.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 26.], - [126.]]], - [[[ 17.], - [117.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 27.], - [127.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 18.], - [118.]], - [[ 28.], - [128.]]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 28.], - [128.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 19.], - [119.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 29.], - [129.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 18.], - [118.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[ 20.], - [120.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]) - [[ 30.], - [130.]]]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[[ 19.], - [119.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 29.], - [129.]]], + [[3., 4., 5., 6.]]], - [[[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]], + [[0., 0., 0., 0.]]], - [[[ 21.], - [121.]], + [[[0., 0., 0., 0.]], - [[ 31.], - [131.]]]], + [[4., 5., 6., 7.]]]], - [[[[ 20.], - [120.]], + [[[[1., 2., 3., 4.]], - [[ 30.], - [130.]]], + [[4., 5., 6., 7.]]], - [[[ 21.], - [121.]], + [[[2., 3., 4., 5.]], - [[ 31.], - [131.]]], + [[0., 0., 0., 0.]]], - [[[ 22.], - [122.]], + [[[3., 4., 5., 6.]], - [[ 32.], - [132.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 21.], - [121.]], + [[2., 3., 4., 5.]]], - [[ 31.], - [131.]]], + [[[0., 0., 0., 0.]], - [[[ 22.], - [122.]], + [[3., 4., 5., 6.]]], - [[ 32.], - [132.]]], + [[[0., 0., 0., 0.]], - [[[ 23.], - [123.]], + [[4., 5., 6., 7.]]]]], - [[ 33.], - [133.]]]], - [[[[ 22.], - [122.]], + [[[[[0., 0., 0., 0.]], - [[ 32.], - [132.]]], + [[2., 3., 4., 5.]]], - [[[ 23.], - [123.]], + [[[1., 2., 3., 4.]], - [[ 33.], - [133.]]], + [[3., 4., 5., 6.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]]]], + [[4., 5., 6., 7.]]]], - [[[[ 23.], - [123.]], + [[[[1., 2., 3., 4.]], - [[ 33.], - [133.]]], + [[3., 4., 5., 6.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]]], + [[4., 5., 6., 7.]]], - [[[ 25.], - [125.]], + [[[2., 3., 4., 5.]], - [[ 35.], - [135.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]], + [[0., 0., 0., 0.]]], - [[[ 25.], - [125.]], + [[[2., 3., 4., 5.]], - [[ 35.], - [135.]]], + [[4., 5., 6., 7.]]], - [[[ 26.], - [126.]], + [[[3., 4., 5., 6.]], - [[ 36.], - [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - layout=torch.sparse_bsr) + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], dtype=torch.int32) # _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 2.], - [102.]], + [[[2., 3., 4., 5.]], - [[ 12.], - [112.]]], + [[0., 0., 0., 0.]]], - [[[ 3.], - [103.]], + [[[0., 0., 0., 0.]], - [[ 13.], - [113.]]], + [[4., 5., 6., 7.]]]], - [[[ 4.], - [104.]], - [[ 14.], - [114.]]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], - [[[[ 3.], - [103.]], + [[[2., 3., 4., 5.]], - [[ 13.], - [113.]]], + [[0., 0., 0., 0.]]], - [[[ 4.], - [104.]], + [[[3., 4., 5., 6.]], - [[ 14.], - [114.]]], + [[0., 0., 0., 0.]]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[[ 4.], - [104.]], + [[[0., 0., 0., 0.]], - [[ 14.], - [114.]]], + [[3., 4., 5., 6.]]], - [[[ 5.], - [105.]], + [[[0., 0., 0., 0.]], - [[ 15.], - [115.]]], + [[4., 5., 6., 7.]]]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[[[ 5.], - [105.]], + [[[1., 2., 3., 4.]], - [[ 15.], - [115.]]], + [[3., 4., 5., 6.]]], - [[[ 6.], - [106.]], + [[[0., 0., 0., 0.]], - [[ 16.], - [116.]]], + [[4., 5., 6., 7.]]]], - [[[ 7.], - [107.]], - [[ 17.], - [117.]]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 6.], - [106.]], + [[[0., 0., 0., 0.]], - [[ 16.], - [116.]]], + [[4., 5., 6., 7.]]], - [[[ 7.], - [107.]], + [[[2., 3., 4., 5.]], - [[ 17.], - [117.]]], + [[0., 0., 0., 0.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[ 7.], - [107.]], + [[[2., 3., 4., 5.]], - [[ 17.], - [117.]]], + [[4., 5., 6., 7.]]], - [[[ 8.], - [108.]], + [[[3., 4., 5., 6.]], - [[ 18.], - [118.]]], + [[0., 0., 0., 0.]]]]]], dtype=torch.float64) +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 9.], - [109.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 19.], - [119.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 11.], - [111.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 21.], - [121.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 20.], - [120.]]], - [[[ 11.], - [111.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 21.], - [121.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 11.], - [111.]], - [[ 21.], - [121.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 12.], - [112.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 22.], - [122.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 13.], - [113.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 23.], - [123.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[ 13.], - [113.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 23.], - [123.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[ 14.], - [114.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 24.], - [124.]]]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 23.], - [123.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 14.], - [114.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 24.], - [124.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[ 15.], - [115.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7], dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1], dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 25.], - [125.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[ 16.], - [116.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 26.], - [126.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 17.], - [117.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 27.], - [127.]]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 16.], - [116.]], - [[ 26.], - [126.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[ 17.], - [117.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 27.], - [127.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 17.], - [117.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 27.], - [127.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 18.], - [118.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 28.], - [128.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[[ 18.], - [118.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 28.], - [128.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[ 19.], - [119.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 29.], - [129.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[ 20.], - [120.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 30.], - [130.]]], - [[[ 21.], - [121.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 20.], - [120.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 30.], - [130.]]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], dtype=torch.float64) - [[[ 21.], - [121.]], - [[ 31.], - [131.]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[ 22.], - [122.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 32.], - [132.]]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[[1., 2., 3., 4.]], - [[[ 23.], - [123.]], + [[4., 5., 6., 7.]]], - [[ 33.], - [133.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[1., 2., 3., 4.]], - [[[ 24.], - [124.]], + [[2., 3., 4., 5.]]], - [[ 34.], - [134.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]]], - [[[ 25.], - [125.]], + [[[[[0., 0., 0., 0.]], - [[ 35.], - [135.]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[ 24.], - [124.]], + [[3., 4., 5., 6.]]], - [[ 34.], - [134.]]], + [[[0., 0., 0., 0.]], - [[[ 25.], - [125.]], + [[4., 5., 6., 7.]]]], - [[ 35.], - [135.]]], - [[[ 26.], - [126.]], + [[[[1., 2., 3., 4.]], - [[ 36.], - [136.]]]]]]]) + [[3., 4., 5., 6.]]], -########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], + [[[0., 0., 0., 0.]], - [[ 2., 102.], - [ 12., 112.]], + [[4., 5., 6., 7.]]], - [[ 3., 103.], - [ 13., 113.]]], + [[[2., 3., 4., 5.]], - [[[ 2., 102.], - [ 12., 112.]], + [[0., 0., 0., 0.]]]], - [[ 3., 103.], - [ 13., 113.]], - [[ 4., 104.], - [ 14., 114.]]], + [[[[1., 2., 3., 4.]], - [[[ 3., 103.], - [ 13., 113.]], + [[0., 0., 0., 0.]]], - [[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]], + [[[3., 4., 5., 6.]], - [[ 6., 106.], - [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + layout=torch.sparse_bsr) # _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], +tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]) +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]) +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 3., 103.], - [ 13., 113.]], - [[ 4., 104.], - [ 14., 114.]], + [[[2., 3., 4., 5.]], - [[ 5., 105.], - [ 15., 115.]]], + [[0., 0., 0., 0.]]], - [[[ 4., 104.], - [ 14., 114.]], + [[[0., 0., 0., 0.]], - [[ 5., 105.], - [ 15., 115.]], + [[4., 5., 6., 7.]]]], - [[ 6., 106.], - [ 16., 116.]]]], dtype=torch.float64) -########## torch.float64/torch.int32/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], + [[[[1., 2., 3., 4.]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[4., 5., 6., 7.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[2., 3., 4., 5.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[0., 0., 0., 0.]]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[[0., 0., 0., 0.]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[4., 5., 6., 7.]]]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[[[0., 0., 0., 0.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[2., 3., 4., 5.]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[[0., 0., 0., 0.]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[4., 5., 6., 7.]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], + [[[[1., 2., 3., 4.]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[3., 4., 5., 6.]]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[[2., 3., 4., 5.]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[0., 0., 0., 0.]]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[[1., 2., 3., 4.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[0., 0., 0., 0.]]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[2., 3., 4., 5.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[4., 5., 6., 7.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]) +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], dtype=torch.float64) -########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 11.], - [111.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 2.], - [102.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 12.], - [112.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[[ 2.], - [102.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 12.], - [112.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[ 3.], - [103.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 13.], - [113.]]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 4.], - [104.]], - [[ 14.], - [114.]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[[ 4.], - [104.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[ 14.], - [114.]]], - [[[ 5.], - [105.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 4.], - [104.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 14.], - [114.]]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 6.], - [106.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 16.], - [116.]]]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[[[ 5.], - [105.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, + layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7]) +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1]) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 7.], - [107.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 17.], - [117.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[ 6.], - [106.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 16.], - [116.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 7.], - [107.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 17.], - [117.]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[[ 7.], - [107.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 17.], - [117.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 8.], - [108.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 18.], - [118.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 9.], - [109.]], - [[ 19.], - [119.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 8.], - [108.]], - [[ 18.], - [118.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 9.], - [109.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 19.], - [119.]]], - [[[ 10.], - [110.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 20.], - [120.]]]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[[ 9.], - [109.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 19.], - [119.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 11.], - [111.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 21.], - [121.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 10.], - [110.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 20.], - [120.]]], - [[[ 11.], - [111.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]) - [[[[ 11.], - [111.]], - [[ 21.], - [121.]]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[ 12.], - [112.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 22.], - [122.]]], + [[3., 4., 5., 6.]]], - [[[ 13.], - [113.]], + [[[2., 3., 4., 5.]], - [[ 23.], - [123.]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[ 12.], - [112.]], + [[4., 5., 6., 7.]]]], - [[ 22.], - [122.]]], - [[[ 13.], - [113.]], + [[[[1., 2., 3., 4.]], - [[ 23.], - [123.]]], + [[4., 5., 6., 7.]]], - [[[ 14.], - [114.]], + [[[2., 3., 4., 5.]], - [[ 24.], - [124.]]]]]], + [[0., 0., 0., 0.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[[[0., 0., 0., 0.]], - [[[ 15.], - [115.]], + [[2., 3., 4., 5.]]], - [[ 25.], - [125.]]], + [[[1., 2., 3., 4.]], - [[[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[1., 2., 3., 4.]], - [[[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]]], + [[[0., 0., 0., 0.]], - [[[ 17.], - [117.]], + [[4., 5., 6., 7.]]], - [[ 27.], - [127.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[ 16.], - [116.]], - [[ 26.], - [126.]]], + [[[[1., 2., 3., 4.]], - [[[ 17.], - [117.]], + [[0., 0., 0., 0.]]], - [[ 27.], - [127.]]], + [[[2., 3., 4., 5.]], - [[[ 18.], - [118.]], + [[4., 5., 6., 7.]]], - [[ 28.], - [128.]]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]), size=(2, 3, 2, 3, 4), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]) +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[[[ 17.], - [117.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[ 27.], - [127.]]], + [[3., 4., 5., 6.]]], - [[[ 18.], - [118.]], + [[[2., 3., 4., 5.]], - [[ 28.], - [128.]]], + [[0., 0., 0., 0.]]], - [[[ 19.], - [119.]], + [[[0., 0., 0., 0.]], - [[ 29.], - [129.]]]], + [[4., 5., 6., 7.]]]], - [[[[ 18.], - [118.]], + [[[[1., 2., 3., 4.]], - [[ 28.], - [128.]]], + [[4., 5., 6., 7.]]], - [[[ 19.], - [119.]], + [[[2., 3., 4., 5.]], - [[ 29.], - [129.]]], + [[0., 0., 0., 0.]]], - [[[ 20.], - [120.]], + [[[3., 4., 5., 6.]], - [[ 30.], - [130.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 19.], - [119.]], + [[[[1., 2., 3., 4.]], - [[ 29.], - [129.]]], + [[2., 3., 4., 5.]]], - [[[ 20.], - [120.]], + [[[0., 0., 0., 0.]], - [[ 30.], - [130.]]], + [[3., 4., 5., 6.]]], - [[[ 21.], - [121.]], + [[[0., 0., 0., 0.]], - [[ 31.], - [131.]]]], + [[4., 5., 6., 7.]]]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[[ 21.], - [121.]], + [[[[1., 2., 3., 4.]], - [[ 31.], - [131.]]], + [[3., 4., 5., 6.]]], - [[[ 22.], - [122.]], + [[[0., 0., 0., 0.]], - [[ 32.], - [132.]]], + [[4., 5., 6., 7.]]], - [[[ 23.], - [123.]], + [[[2., 3., 4., 5.]], - [[ 33.], - [133.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 22.], - [122.]], + [[[[1., 2., 3., 4.]], - [[ 32.], - [132.]]], + [[0., 0., 0., 0.]]], - [[[ 23.], - [123.]], + [[[2., 3., 4., 5.]], - [[ 33.], - [133.]]], + [[4., 5., 6., 7.]]], - [[[ 24.], - [124.]], + [[[3., 4., 5., 6.]], - [[ 34.], - [134.]]]], + [[0., 0., 0., 0.]]]]]], dtype=torch.float64) +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 33.], - [133.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 34.], - [134.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[ 25.], - [125.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 35.], - [135.]]], - [[[ 26.], - [126.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 36.], - [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) -# _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 2.], - [102.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 12.], - [112.]]], - [[[ 3.], - [103.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 13.], - [113.]]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 2.], - [102.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 12.], - [112.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[ 4.], - [104.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 14.], - [114.]]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[ 3.], - [103.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 13.], - [113.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 4.], - [104.]], - [[ 14.], - [114.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 5.], - [105.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 15.], - [115.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 4.], - [104.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 14.], - [114.]]], - [[[ 5.], - [105.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]) - -########## torch.float32/torch.int64/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), - nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]) - -########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]]) - - -########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), size=(6, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), size=(4, 9, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]]), size=(2, 3, 6, 6, 2, 1), nnz=4, + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), size=(8, 6, 4, 2), nnz=7, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) +tensor([0, 2, 3, 5, 7]) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) +tensor([0, 1, 0, 0, 1, 0, 1]) # _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 29.], - [129.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 20.], - [120.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 30.], - [130.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 22.], - [122.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 32.], - [132.]]]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 33.], - [133.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 22.], - [122.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 32.], - [132.]]], - [[[ 23.], - [123.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 33.], - [133.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 25.], - [125.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 35.], - [135.]]]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 25.], - [125.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 35.], - [135.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 26.], - [126.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 36.], - [136.]]]]]]], dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect index 15e9bb56a85c7..70c00eb95db6a 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseCSC_cpu.expect @@ -1,1411 +1,1653 @@ -########## torch.float32/torch.int32/size=()+(3, 2)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.]) - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0], dtype=torch.int32) -# _row_indices -tensor([], dtype=torch.int32) -# _values -tensor([]) - -########## torch.float32/torch.int32/size=(2,)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]) - -########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]) + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]) - -########## torch.float64/torch.int32/size=()+(3, 2)+() ########## +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), size=(8, 6), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + dtype=torch.int32) # _values -tensor([1., 2., 3., 4.], dtype=torch.float64) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.]) -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, dtype=torch.float64, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0], dtype=torch.int32) -# _row_indices -tensor([], dtype=torch.int32) -# _values -tensor([], dtype=torch.float64) -########## torch.float64/torch.int32/size=(2,)+(3, 2)+() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], dtype=torch.float64) +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], dtype=torch.float64) -########## torch.float32/torch.int64/size=()+(3, 2)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 8, 11, 14, 19, 24], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2]) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + dtype=torch.int32) # _values -tensor([1., 2., 3., 4.]) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], dtype=torch.float64) -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## + +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], + + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([0]) +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], + + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]) # _row_indices -tensor([], dtype=torch.int64) +tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]) # _values -tensor([]) +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]) -########## torch.float32/torch.int64/size=(2,)+(3, 2)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, - layout=torch.sparse_csc) +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), size=(8, 6), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]]) +tensor([ 0, 3, 8, 11, 14, 19, 24]) # _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7]) # _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.]) + -########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+() ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, - layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), size=(2, 3, 2, 3), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]) +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], dtype=torch.float64) -########## torch.float64/torch.int64/size=()+(3, 2)+() ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([1., 2., 3., 4.], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, dtype=torch.float64, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0]) +tensor([ 0, 3, 8, 11, 14, 19, 24]) # _row_indices -tensor([], dtype=torch.int64) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7]) # _values -tensor([], dtype=torch.float64) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], dtype=torch.float64) -########## torch.float64/torch.int64/size=(2,)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]]) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], dtype=torch.float64) -########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), size=(2, 3, 2, 3, 4), nnz=4, + layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(3, 2)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(3, 2, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], dtype=torch.int32) # _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]) - -########## torch.float32/torch.int32/size=()+(3, 2)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]) + +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + dtype=torch.int32) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]) - -########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]) - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, - layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), size=(2, 3, 2, 3, 4), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]) - - -########## torch.float64/torch.int32/size=()+(3, 2)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(3, 2, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(3, 2)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + dtype=torch.int32) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], dtype=torch.float64) - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), size=(2, 3, 2, 3, 4), nnz=4, + layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(3, 2)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]) + +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(3, 2, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]) - -########## torch.float32/torch.int64/size=()+(3, 2)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 8, 11, 14, 19, 24]) # _row_indices -tensor([0, 1, 0, 2]) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7]) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]) - -########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]) - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, - layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), size=(2, 3, 2, 3, 4), nnz=4, + dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]) - - -########## torch.float64/torch.int64/size=()+(3, 2)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(3, 2, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4]) -# _row_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(3, 2)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), size=(3, 2, 4, 2), nnz=4, + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 8, 11, 14, 19, 24]) # _row_indices -tensor([0, 1, 0, 2]) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7]) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), size=(2, 3, 3, 2, 2, 1), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]) -# _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[17.], - [18.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[18.], - [19.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[19.], - [20.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[20.], - [21.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[[21.], - [22.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[22.], - [23.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[23.], - [24.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[24.], - [25.]]]]], dtype=torch.float64) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect b/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect index 3ab2e1135aa55..f95a8a0819953 100644 --- a/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect +++ b/test/expect/TestSparseCompressedCPU.test_print_SparseCSR_cpu.expect @@ -1,48 +1,3 @@ -########## torch.float32/torch.int32/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.]) - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, - layout=torch.sparse_csr) -# _crow_indices -tensor([0], dtype=torch.int32) -# _col_indices -tensor([], dtype=torch.int32) -# _values -tensor([]) - -########## torch.float32/torch.int32/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]) - ########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], @@ -52,20 +7,20 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -76,7 +31,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -84,59 +39,31 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]) +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]) -########## torch.float64/torch.int32/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.], dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, dtype=torch.float64, - layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), size=(8, 6), nnz=24, layout=torch.sparse_csr) # _crow_indices -tensor([0], dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], dtype=torch.int32) # _col_indices -tensor([], dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + dtype=torch.int32) # _values -tensor([], dtype=torch.float64) +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]) -########## torch.float64/torch.int32/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], dtype=torch.float64) ########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -147,20 +74,20 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -171,7 +98,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -179,59 +106,32 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], dtype=torch.float64) +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], dtype=torch.float64) -########## torch.float32/torch.int64/size=()+(2, 3)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 2]) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + dtype=torch.int32) # _values -tensor([1., 2., 3., 4.]) +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], dtype=torch.float64) -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, - layout=torch.sparse_csr) -# _crow_indices -tensor([0]) -# _col_indices -tensor([], dtype=torch.int64) -# _values -tensor([]) - -########## torch.float32/torch.int64/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]]) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]) ########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -242,20 +142,20 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -266,7 +166,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]]) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -274,59 +174,30 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]]) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]) +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]) -########## torch.float64/torch.int64/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), size=(2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([1., 2., 3., 4.], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), size=(0, 0), nnz=0, dtype=torch.float64, - layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), size=(8, 6), nnz=24, layout=torch.sparse_csr) # _crow_indices -tensor([0]) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]) # _col_indices -tensor([], dtype=torch.int64) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5]) # _values -tensor([], dtype=torch.float64) +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]) -########## torch.float64/torch.int64/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), size=(2, 2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]]) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], dtype=torch.float64) ########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -337,20 +208,20 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), size=(2, 3, 2, 3), nnz=4, + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -361,7 +232,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]]) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -369,84 +240,33 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]]) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], dtype=torch.float64) +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], dtype=torch.float64) -########## torch.float32/torch.int32/size=()+(2, 3)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(2, 3, 2), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]) - -########## torch.float32/torch.int32/size=()+(2, 3)+(4, 2) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], - [ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]) # _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5]) # _values -tensor([[[ 1., 11.], - [ 2., 12.], - [ 3., 13.], - [ 4., 14.]], +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], dtype=torch.float64) - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]) - -########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -455,90 +275,43 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -549,7 +322,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -557,108 +330,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]) - - -########## torch.float64/torch.int32/size=()+(2, 3)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]) + +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(2, 3)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -677,12 +384,113 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, + layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 2], dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + dtype=torch.int32) # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -702,9 +510,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], dtype=torch.float64) + [ 7., 17.]], -########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]) + + +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -713,90 +622,43 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -807,7 +669,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -815,108 +677,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(2, 3)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(2, 3, 2), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]) - -########## torch.float32/torch.int64/size=()+(2, 3)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -935,12 +731,113 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, - layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, + dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 2]) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + dtype=torch.int32) # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -960,9 +857,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]) + [ 7., 17.]], -########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], dtype=torch.float64) + + +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -971,90 +969,43 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -1065,7 +1016,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]]) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -1073,108 +1024,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]]) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]) - - -########## torch.float64/torch.int64/size=()+(2, 3)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), size=(2, 3, 2), nnz=4, dtype=torch.float64, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4]) -# _col_indices -tensor([0, 1, 0, 2]) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(2, 3)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]) + +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -1193,12 +1078,112 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), size=(2, 3, 4, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, + layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4]) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]) # _col_indices -tensor([0, 1, 0, 2]) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5]) # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -1218,9 +1203,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], dtype=torch.float64) + [ 7., 17.]], -########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]) + + +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -1229,90 +1315,43 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), size=(2, 3, 2, 3, 2, 1), nnz=4, + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], @@ -1323,7 +1362,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]]) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -1331,81 +1370,284 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]]) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[18.], - [19.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[19.], - [20.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[20.], - [21.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), size=(8, 6, 4, 2), nnz=24, + dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]) +# _col_indices +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5]) +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[21.], - [22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[22.], - [23.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[23.], - [24.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[24.], - [25.]]]]], dtype=torch.float64) + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect index 46bdb44b2a983..9e563794f07bb 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSC_cuda.expect @@ -1,6981 +1,3583 @@ -########## torch.float32/torch.int32/size=()+(3, 4)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 2., 12.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 3., 13.]], - - [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([], device='cuda:0', dtype=torch.int32) -# _values -tensor([], device='cuda:0', size=(0, 1, 2)) - -########## torch.float32/torch.int32/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], device='cuda:0') - -########## torch.float32/torch.int32/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), - nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([], device='cuda:0', dtype=torch.int32) -# _values -tensor([], device='cuda:0', size=(0, 1, 2), dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], device='cuda:0') - -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], device='cuda:0') -# _row_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0', size=(0, 1, 2)) - -########## torch.float32/torch.int64/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], device='cuda:0') - -########## torch.float32/torch.int64/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), - nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], device='cuda:0') - - -########## torch.float64/torch.int64/size=()+(3, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]]), device='cuda:0', size=(3, 4), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 1, 2)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0], device='cuda:0') -# _row_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0', size=(0, 1, 2), dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2,)+(6, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]]), device='cuda:0', size=(2, 6, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], - - - [[[5.], - [6.]], - - [[6.], - [7.]], - - [[7.], - [8.]], - - [[8.], - [9.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(9, 4)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]]), device='cuda:0', size=(2, 3, 9, 4), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[ 1., 11.], - [ 2., 12.], - [ 3., 13.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.]]], - - - [[[ 5., 15.], - [ 6., 16.], - [ 7., 17.]], - - [[ 6., 16.], - [ 7., 17.], - [ 8., 18.]], - - [[ 7., 17.], - [ 8., 18.], - [ 9., 19.]], - - [[ 8., 18.], - [ 9., 19.], - [10., 20.]]], - - - [[[ 9., 19.], - [10., 20.], - [11., 21.]], - - [[10., 20.], - [11., 21.], - [12., 22.]], - - [[11., 21.], - [12., 22.], - [13., 23.]], - - [[12., 22.], - [13., 23.], - [14., 24.]]]], - - - - [[[[13., 23.], - [14., 24.], - [15., 25.]], - - [[14., 24.], - [15., 25.], - [16., 26.]], - - [[15., 25.], - [16., 26.], - [17., 27.]], - - [[16., 26.], - [17., 27.], - [18., 28.]]], - - - [[[17., 27.], - [18., 28.], - [19., 29.]], - - [[18., 28.], - [19., 29.], - [20., 30.]], - - [[19., 29.], - [20., 30.], - [21., 31.]], - - [[20., 30.], - [21., 31.], - [22., 32.]]], - - - [[[21., 31.], - [22., 32.], - [23., 33.]], - - [[22., 32.], - [23., 33.], - [24., 34.]], - - [[23., 33.], - [24., 34.], - [25., 35.]], - - [[24., 34.], - [25., 35.], - [26., 36.]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', - size=(9, 4, 4, 2), nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], device='cuda:0') - -########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]], - - [[ 44.], - [144.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], - - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], - - - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], - - - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', - size=(9, 4, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]], - - [[ 36.], - [136.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]], - - [[ 37.], - [137.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]], - - [[ 38.], - [138.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]], - - [[ 39.], - [139.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]], - - [[ 40.], - [140.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]], - - [[ 41.], - [141.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]], - - [[ 42.], - [142.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]], - - [[ 43.], - [143.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 24.], - [124.]], + [[0.], + [4.]]], - [[ 34.], - [134.]], - [[ 44.], - [144.]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]], + [[[1.], + [2.]], - [[ 44.], - [144.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 25.], - [125.]], - - [[ 35.], - [135.]], - - [[ 45.], - [145.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]], - - [[ 24.], - [124.]]]], + [[[[0.], + [2.]], - [[[[ 4.], - [104.]], + [[1.], + [3.]], - [[ 14.], - [114.]], + [[0.], + [4.]]], - [[ 24.], - [124.]]], + [[[1.], + [3.]], - [[[ 5.], - [105.]], + [[0.], + [4.]], - [[ 15.], - [115.]], + [[2.], + [0.]]], - [[ 25.], - [125.]]]]], + [[[1.], + [0.]], + [[2.], + [4.]], + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]], - - [[ 28.], - [128.]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[1.], + [3.]], - [[[ 9.], - [109.]], + [[2.], + [0.]], - [[ 19.], - [119.]], + [[0.], + [4.]]], - [[ 29.], - [129.]]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]], + [[[1.], + [2.]], - [[ 29.], - [129.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[1.], + [3.]], - [[ 30.], - [130.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[1.], + [0.]], - [[ 31.], - [131.]]]], + [[2.], + [4.]], + [[3.], + [0.]]]]], device='cuda:0') +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[[[ 11.], - [111.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 21.], - [121.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 31.], - [131.]]], + [[ 0., 0., 0.], + [20., 21., 22.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[ 12.], - [112.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 22.], - [122.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 32.], - [132.]]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[[[ 12.], - [112.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 22.], - [122.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 32.], - [132.]]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0') - [[[ 13.], - [113.]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 23.], - [123.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 33.], - [133.]]]]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], + [[0.], + [4.]]], + [[[1.], + [4.]], - [[[[[[ 13.], - [113.]], + [[2.], + [0.]], - [[ 23.], - [123.]], + [[3.], + [0.]]], - [[ 33.], - [133.]]], + [[[1.], + [2.]], - [[[ 14.], - [114.]], + [[0.], + [3.]], - [[ 24.], - [124.]], + [[0.], + [4.]]]], - [[ 34.], - [134.]]]], + [[[[0.], + [2.]], - [[[[ 14.], - [114.]], + [[1.], + [3.]], - [[ 24.], - [124.]], + [[0.], + [4.]]], - [[ 34.], - [134.]]], + [[[1.], + [3.]], - [[[ 15.], - [115.]], + [[0.], + [4.]], - [[ 25.], - [125.]], + [[2.], + [0.]]], - [[ 35.], - [135.]]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[[ 15.], - [115.]], + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 25.], - [125.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 35.], - [135.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 16.], - [116.]], + [[0.], + [4.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[ 16.], - [116.]], - [[ 26.], - [126.]], + [[[1.], + [2.]], - [[ 36.], - [136.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[[ 17.], - [117.]], + [[[1.], + [3.]], - [[ 27.], - [127.]], + [[0.], + [4.]], - [[ 37.], - [137.]]], + [[2.], + [0.]]], - [[[ 18.], - [118.]], + [[[1.], + [0.]], - [[ 28.], - [128.]], + [[2.], + [4.]], - [[ 38.], - [138.]]]], + [[3.], + [0.]]]]], device='cuda:0', dtype=torch.float64) +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[[[ 18.], - [118.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 28.], - [128.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 38.], - [138.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 19.], - [119.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 29.], - [129.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 39.], - [139.]]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[[ 19.], - [119.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 29.], - [129.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0', dtype=torch.float64) - [[ 39.], - [139.]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[ 20.], - [120.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 30.], - [130.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], - [[ 40.], - [140.]]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[[ 20.], - [120.]], + [[[1.], + [4.]], - [[ 30.], - [130.]], + [[2.], + [0.]], - [[ 40.], - [140.]]], + [[3.], + [0.]]], - [[[ 21.], - [121.]], + [[[1.], + [2.]], - [[ 31.], - [131.]], + [[0.], + [3.]], - [[ 41.], - [141.]]]]], + [[0.], + [4.]]]], + [[[[0.], + [2.]], - [[[[[ 21.], - [121.]], + [[1.], + [3.]], - [[ 31.], - [131.]], + [[0.], + [4.]]], - [[ 41.], - [141.]]], + [[[1.], + [3.]], - [[[ 22.], - [122.]], + [[0.], + [4.]], - [[ 32.], - [132.]], + [[2.], + [0.]]], - [[ 42.], - [142.]]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[[ 22.], - [122.]], + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 32.], - [132.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0') +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 42.], - [142.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0') +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 23.], - [123.]], + [[0.], + [4.]]], - [[ 33.], - [133.]], - [[ 43.], - [143.]]]], + [[[1.], + [4.]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[1.], + [2.]], - [[ 43.], - [143.]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]], - [[ 44.], - [144.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]], + [[[1.], + [3.]], - [[ 44.], - [144.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[ 25.], - [125.]], - [[ 35.], - [135.]], + [[[1.], + [0.]], - [[ 45.], - [145.]]]]]]], device='cuda:0', dtype=torch.float64) + [[2.], + [4.]], + [[3.], + [0.]]]]], device='cuda:0') -########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - - - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, layout=torch.sparse_bsc) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + layout=torch.sparse_bsc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') +tensor([0, 4, 7], device='cuda:0') # _row_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0') # _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[10., 11., 12.], + [15., 16., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0') - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], device='cuda:0') - -########## torch.float32/torch.int64/size=()+(9, 4)+(4, 2) ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[1.], + [3.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[2.], + [0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[0.], + [4.]]], + [[[1.], + [4.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[2.], + [0.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[3.], + [0.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[1.], + [2.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[0.], + [3.]], + [[0.], + [4.]]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[0.], + [4.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[1.], + [3.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [4.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[2.], + [0.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[[1.], + [0.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[2.], + [4.]], - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], - - - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', - size=(9, 4, 4, 2), nnz=4, layout=torch.sparse_bsc) + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0') +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0') +# _values +tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [4.]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[[1.], + [4.]], + [[2.], + [0.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[3.], + [0.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[1.], + [2.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0.], + [3.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[0.], + [4.]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[2.], + [0.]]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], + [[[1.], + [0.]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], + [[2.], + [4.]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], device='cuda:0') + [[3.], + [0.]]]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 11.], - [111.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 21.], - [121.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[[ 2.], - [102.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 12.], - [112.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 22.], - [122.]]]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0') +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0') +# _values +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[[ 2.], - [102.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 12.], - [112.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 22.], - [122.]]], + [[10., 11., 12.], + [15., 16., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0', dtype=torch.float64) - [[[ 3.], - [103.]], - [[ 13.], - [113.]], +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 23.], - [123.]]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]], + [[[2., 3., 4., 5.]], - [[ 23.], - [123.]]], + [[0., 0., 0., 0.]]], - [[[ 4.], - [104.]], + [[[0., 0., 0., 0.]], - [[ 14.], - [114.]], + [[4., 5., 6., 7.]]]], - [[ 24.], - [124.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 4.], - [104.]], + [[4., 5., 6., 7.]]], - [[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[3., 4., 5., 6.]], - [[ 25.], - [125.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 5.], - [105.]], + [[2., 3., 4., 5.]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]]]], + [[4., 5., 6., 7.]]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[[[0., 0., 0., 0.]], - [[ 26.], - [126.]]], + [[2., 3., 4., 5.]]], - [[[ 7.], - [107.]], + [[[1., 2., 3., 4.]], - [[ 17.], - [117.]], + [[3., 4., 5., 6.]]], - [[ 27.], - [127.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 7.], - [107.]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], + [[[0., 0., 0., 0.]], - [[ 28.], - [128.]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[ 8.], - [108.]], + [[0., 0., 0., 0.]]]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]], + [[[[1., 2., 3., 4.]], - [[[ 9.], - [109.]], + [[0., 0., 0., 0.]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], + [[[3., 4., 5., 6.]], - [[[[[ 9.], - [109.]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 19.], - [119.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 29.], - [129.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[ 10.], - [110.]], + [[4., 5., 6., 7.]]]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[[1., 2., 3., 4.]], - [[[ 11.], - [111.]], + [[4., 5., 6., 7.]]], - [[ 21.], - [121.]], - [[ 31.], - [131.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 11.], - [111.]], + [[[3., 4., 5., 6.]], - [[ 21.], - [121.]], + [[0., 0., 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 12.], - [112.]], + [[[[1., 2., 3., 4.]], - [[ 22.], - [122.]], + [[2., 3., 4., 5.]]], - [[ 32.], - [132.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 12.], - [112.]], - [[ 22.], - [122.]], + [[[0., 0., 0., 0.]], - [[ 32.], - [132.]]], + [[4., 5., 6., 7.]]]]], - [[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], + [[[0., 0., 0., 0.]], - [[ 33.], - [133.]]], + [[4., 5., 6., 7.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 14.], - [114.]], + [[4., 5., 6., 7.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[ 15.], - [115.]], + [[[2., 3., 4., 5.]], - [[ 25.], - [125.]], + [[4., 5., 6., 7.]]], - [[ 35.], - [135.]]], + [[[3., 4., 5., 6.]], - [[[ 16.], - [116.]], + [[0., 0., 0., 0.]]]]]], device='cuda:0') - [[ 26.], - [126.]], +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 36.], - [136.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 26.], - [126.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 36.], - [136.]]], - [[[ 17.], - [117.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 27.], - [127.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 37.], - [137.]]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 17.], - [117.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 28.], - [128.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 38.], - [138.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 18.], - [118.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 29.], - [129.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 39.], - [139.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 19.], - [119.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 30.], - [130.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 40.], - [140.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[ 21.], - [121.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 31.], - [131.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 41.], - [141.]]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 22.], - [122.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 32.], - [132.]], - [[ 42.], - [142.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 22.], - [122.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 32.], - [132.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 42.], - [142.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 23.], - [123.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 33.], - [133.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 43.], - [143.]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[[[ 23.], - [123.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 33.], - [133.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 43.], - [143.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 44.], - [144.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 24.], - [124.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 34.], - [134.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 44.], - [144.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[ 25.], - [125.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 45.], - [145.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[[[ 1.], - [101.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 11.], - [111.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 2.], - [102.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 12.], - [112.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 22.], - [122.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 2.], - [102.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 3.], - [103.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 13.], - [113.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 23.], - [123.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 3.], - [103.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 13.], - [113.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 23.], - [123.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 4.], - [104.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 14.], - [114.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 24.], - [124.]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 4.], - [104.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 14.], - [114.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 24.], - [124.]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 5.], - [105.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 25.], - [125.]]]]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0') +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[[[[ 5.], - [105.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 15.], - [115.]], + [[3., 4., 5., 6.]]], - [[ 25.], - [125.]]], + [[[2., 3., 4., 5.]], - [[[ 6.], - [106.]], + [[0., 0., 0., 0.]]], - [[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[[1., 2., 3., 4.]], - [[ 26.], - [126.]]], + [[4., 5., 6., 7.]]], - [[[ 7.], - [107.]], + [[[2., 3., 4., 5.]], - [[ 17.], - [117.]], + [[0., 0., 0., 0.]]], - [[ 27.], - [127.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[ 7.], - [107.]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], + [[[0., 0., 0., 0.]], - [[ 28.], - [128.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 8.], - [108.]], + [[4., 5., 6., 7.]]]]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]], - [[[ 9.], - [109.]], + [[[[[0., 0., 0., 0.]], - [[ 19.], - [119.]], + [[2., 3., 4., 5.]]], - [[ 29.], - [129.]]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[[ 9.], - [109.]], + [[[0., 0., 0., 0.]], - [[ 19.], - [119.]], + [[4., 5., 6., 7.]]]], - [[ 29.], - [129.]]], - [[[ 10.], - [110.]], + [[[[1., 2., 3., 4.]], - [[ 20.], - [120.]], + [[3., 4., 5., 6.]]], - [[ 30.], - [130.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[ 10.], - [110.]], - [[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]], + [[0., 0., 0., 0.]]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[[1., 2., 3., 4.]], - [[ 31.], - [131.]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], - [[[[ 11.], - [111.]], + [[4., 5., 6., 7.]]], - [[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[ 12.], - [112.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 22.], - [122.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[ 32.], - [132.]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], - [[[[ 12.], - [112.]], + [[0., 0., 0., 0.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], + [[[3., 4., 5., 6.]], - [[ 33.], - [133.]]], + [[0., 0., 0., 0.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]]], + [[2., 3., 4., 5.]]], + [[[0., 0., 0., 0.]], - [[[[ 14.], - [114.]], + [[3., 4., 5., 6.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]], + [[[1., 2., 3., 4.]], - [[ 35.], - [135.]]], + [[3., 4., 5., 6.]]], - [[[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]], + [[4., 5., 6., 7.]]]], - [[ 36.], - [136.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], + [[[2., 3., 4., 5.]], - [[ 37.], - [137.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 17.], - [117.]], + [[0., 0., 0., 0.]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[ 18.], - [118.]], - [[ 28.], - [128.]], + [[[3., 4., 5., 6.]], - [[ 38.], - [138.]]]], + [[0., 0., 0., 0.]]]]]], device='cuda:0', dtype=torch.float64) +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 18.], - [118.]], - [[ 28.], - [128.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 38.], - [138.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 39.], - [139.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 40.], - [140.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[[ 21.], - [121.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 31.], - [131.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 41.], - [141.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 42.], - [142.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 22.], - [122.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 32.], - [132.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 42.], - [142.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 43.], - [143.]]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[[ 23.], - [123.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 33.], - [133.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 43.], - [143.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 44.], - [144.]]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 24.], - [124.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 44.], - [144.]]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0', dtype=torch.int32) +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[ 25.], - [125.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 35.], - [135.]], - [[ 45.], - [145.]]]]]]], device='cuda:0') + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], -########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[ 1., 101.], - [ 11., 111.], - [ 21., 121.]], - [[ 2., 102.], - [ 12., 112.], - [ 22., 122.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 2., 102.], - [ 12., 112.], - [ 22., 122.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 3., 103.], - [ 13., 113.], - [ 23., 123.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 3., 103.], - [ 13., 113.], - [ 23., 123.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 4., 104.], - [ 14., 114.], - [ 24., 124.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 4., 104.], - [ 14., 114.], - [ 24., 124.]], - [[ 5., 105.], - [ 15., 115.], - [ 25., 125.]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/size=()+(9, 4)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0', dtype=torch.float64) - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]]), device='cuda:0', - size=(9, 4, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]]], + [[3., 4., 5., 6.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[2., 3., 4., 5.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[0., 0., 0., 0.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[0., 0., 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]]], + [[4., 5., 6., 7.]]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[[1., 2., 3., 4.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]]], + [[4., 5., 6., 7.]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[2., 3., 4., 5.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[0., 0., 0., 0.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[3., 4., 5., 6.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]]], + [[0., 0., 0., 0.]]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[[1., 2., 3., 4.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]]], + [[2., 3., 4., 5.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[0., 0., 0., 0.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[3., 4., 5., 6.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[[0., 0., 0., 0.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]]], + [[4., 5., 6., 7.]]]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[6.0000e+00, 1.0060e+03], - [1.0600e+02, 1.1060e+03], - [2.0600e+02, 1.2060e+03], - [3.0600e+02, 1.3060e+03]], - [[1.6000e+01, 1.0160e+03], - [1.1600e+02, 1.1160e+03], - [2.1600e+02, 1.2160e+03], - [3.1600e+02, 1.3160e+03]]]]], device='cuda:0', dtype=torch.float64) + [[[0., 0., 0., 0.]], -########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[4., 5., 6., 7.]]]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]]], + [[3., 4., 5., 6.]]], - [[[ 2.], - [102.]], + [[[0., 0., 0., 0.]], - [[ 12.], - [112.]], + [[4., 5., 6., 7.]]], - [[ 22.], - [122.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]], + [[[2., 3., 4., 5.]], - [[ 23.], - [123.]]]], + [[4., 5., 6., 7.]]], + [[[3., 4., 5., 6.]], - [[[[ 3.], - [103.]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 13.], - [113.]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0') +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[ 23.], - [123.]]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0') +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 4.], - [104.]], - [[ 14.], - [114.]], + [[[2., 3., 4., 5.]], - [[ 24.], - [124.]]]], + [[0., 0., 0., 0.]]], + [[[0., 0., 0., 0.]], - [[[[ 4.], - [104.]], + [[4., 5., 6., 7.]]]], - [[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[[1., 2., 3., 4.]], - [[[ 5.], - [105.]], + [[4., 5., 6., 7.]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], + [[[3., 4., 5., 6.]], - [[[[[ 5.], - [105.]], + [[0., 0., 0., 0.]]]], - [[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[1., 2., 3., 4.]], - [[[ 6.], - [106.]], + [[2., 3., 4., 5.]]], - [[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 6.], - [106.]], + [[[0., 0., 0., 0.]], - [[ 16.], - [116.]], + [[4., 5., 6., 7.]]]]], - [[ 26.], - [126.]]], - [[[ 7.], - [107.]], - [[ 17.], - [117.]], + [[[[[0., 0., 0., 0.]], - [[ 27.], - [127.]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[ 7.], - [107.]], + [[3., 4., 5., 6.]]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 8.], - [108.]], + [[[0., 0., 0., 0.]], - [[ 18.], - [118.]], + [[4., 5., 6., 7.]]], - [[ 28.], - [128.]]], + [[[2., 3., 4., 5.]], - [[[ 9.], - [109.]], + [[0., 0., 0., 0.]]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[[[ 9.], - [109.]], + [[[2., 3., 4., 5.]], - [[ 19.], - [119.]], + [[4., 5., 6., 7.]]], - [[ 29.], - [129.]]], + [[[3., 4., 5., 6.]], - [[[ 10.], - [110.]], + [[0., 0., 0., 0.]]]]]], device='cuda:0') - [[ 20.], - [120.]], +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 30.], - [130.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 10.], - [110.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 20.], - [120.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 30.], - [130.]]], - [[[ 11.], - [111.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 21.], - [121.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 31.], - [131.]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 11.], - [111.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 12.], - [112.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 32.], - [132.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 22.], - [122.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 32.], - [132.]]], - [[[ 13.], - [113.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 14.], - [114.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[ 15.], - [115.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]], - [[ 35.], - [135.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 16.], - [116.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 26.], - [126.]], - [[ 36.], - [136.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0') +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0') +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 26.], - [126.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 36.], - [136.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 17.], - [117.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 27.], - [127.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 37.], - [137.]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[[[ 17.], - [117.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 27.], - [127.]], - [[ 37.], - [137.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 28.], - [128.]], - [[ 38.], - [138.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 19.], - [119.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 29.], - [129.]], - [[ 39.], - [139.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[ 20.], - [120.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 20.], - [120.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 30.], - [130.]], - [[ 40.], - [140.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]], - [[ 41.], - [141.]]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 41.], - [141.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]], - [[ 42.], - [142.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 22.], - [122.]], - [[ 32.], - [132.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 42.], - [142.]]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0') - [[[ 23.], - [123.]], - [[ 33.], - [133.]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], - [[ 43.], - [143.]]]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]]), + row_indices=tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]], + [[[2., 3., 4., 5.]], - [[ 43.], - [143.]]], + [[0., 0., 0., 0.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]], + [[4., 5., 6., 7.]]]], - [[ 44.], - [144.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 24.], - [124.]], + [[4., 5., 6., 7.]]], - [[ 34.], - [134.]], - [[ 44.], - [144.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[ 25.], - [125.]], - [[ 35.], - [135.]], + [[[3., 4., 5., 6.]], - [[ 45.], - [145.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[0., 0., 0., 0.]]]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]]], + [[2., 3., 4., 5.]]], - [[[ 2.], - [102.]], + [[[0., 0., 0., 0.]], - [[ 12.], - [112.]], + [[3., 4., 5., 6.]]], - [[ 22.], - [122.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[[0., 0., 0., 0.]], - [[[ 3.], - [103.]], + [[2., 3., 4., 5.]]], - [[ 13.], - [113.]], - [[ 23.], - [123.]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 3.], - [103.]], + [[[0., 0., 0., 0.]], - [[ 13.], - [113.]], + [[4., 5., 6., 7.]]]], - [[ 23.], - [123.]]], - [[[ 4.], - [104.]], + [[[[1., 2., 3., 4.]], - [[ 14.], - [114.]], + [[3., 4., 5., 6.]]], - [[ 24.], - [124.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[ 4.], - [104.]], - [[ 14.], - [114.]], + [[[2., 3., 4., 5.]], - [[ 24.], - [124.]]], + [[0., 0., 0., 0.]]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[[1., 2., 3., 4.]], - [[ 25.], - [125.]]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[[[ 5.], - [105.]], - [[ 15.], - [115.]], + [[[3., 4., 5., 6.]], - [[ 25.], - [125.]]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([[[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]], + [[0, 1, 2, 3], + [0, 1, 2, 3], + [0, 1, 2, 3]]], device='cuda:0') +# _row_indices +tensor([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], - [[[ 6.], - [106.]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]]], device='cuda:0') +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[ 16.], - [116.]], + [[3., 4., 5., 6.]]], - [[ 26.], - [126.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]], + [[[0., 0., 0., 0.]], - [[ 26.], - [126.]]], + [[4., 5., 6., 7.]]]], - [[[ 7.], - [107.]], - [[ 17.], - [117.]], + [[[[1., 2., 3., 4.]], - [[ 27.], - [127.]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[ 7.], - [107.]], + [[0., 0., 0., 0.]]], - [[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]], - [[ 28.], - [128.]]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[[ 8.], - [108.]], + [[[0., 0., 0., 0.]], - [[ 18.], - [118.]], + [[3., 4., 5., 6.]]], - [[ 28.], - [128.]]], + [[[0., 0., 0., 0.]], - [[[ 9.], - [109.]], + [[4., 5., 6., 7.]]]]], - [[ 19.], - [119.]], - [[ 29.], - [129.]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]], + [[[1., 2., 3., 4.]], - [[ 29.], - [129.]]], + [[3., 4., 5., 6.]]], - [[[ 10.], - [110.]], + [[[0., 0., 0., 0.]], - [[ 20.], - [120.]], + [[4., 5., 6., 7.]]]], - [[ 30.], - [130.]]]], + [[[[1., 2., 3., 4.]], - [[[[ 10.], - [110.]], + [[3., 4., 5., 6.]]], - [[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]], + [[[2., 3., 4., 5.]], - [[ 31.], - [131.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 11.], - [111.]], + [[[[1., 2., 3., 4.]], - [[ 21.], - [121.]], + [[0., 0., 0., 0.]]], - [[ 31.], - [131.]]], + [[[2., 3., 4., 5.]], - [[[ 12.], - [112.]], + [[4., 5., 6., 7.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]], device='cuda:0', dtype=torch.float64) +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(ccol_indices=tensor([0, 4, 7]), + row_indices=tensor([0, 1, 2, 3, 0, 2, 3]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[[ 12.], - [112.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 23.], - [123.]], - [[ 33.], - [133.]]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[[[ 13.], - [113.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 14.], - [114.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 15.], - [115.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 25.], - [125.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 35.], - [135.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 36.], - [136.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 16.], - [116.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 26.], - [126.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 36.], - [136.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 17.], - [117.]], - [[ 27.], - [127.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 37.], - [137.]]]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[[ 17.], - [117.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 27.], - [127.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[ 37.], - [137.]]], - [[[ 18.], - [118.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 38.], - [138.]]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 18.], - [118.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 28.], - [128.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, dtype=torch.float64, layout=torch.sparse_bsc) +# _ccol_indices +tensor([0, 4, 7], device='cuda:0') +# _row_indices +tensor([0, 1, 2, 3, 0, 2, 3], device='cuda:0') +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 38.], - [138.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 39.], - [139.]]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 39.], - [139.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 40.], - [140.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 40.], - [140.]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 41.], - [141.]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[[[ 21.], - [121.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 41.], - [141.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 22.], - [122.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 32.], - [132.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 42.], - [142.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 22.], - [122.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 32.], - [132.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 42.], - [142.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 23.], - [123.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 33.], - [133.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 43.], - [143.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 23.], - [123.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 33.], - [133.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 43.], - [143.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 24.], - [124.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 34.], - [134.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 44.], - [144.]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 24.], - [124.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[ 44.], - [144.]]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 25.], - [125.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 45.], - [145.]]]]]]], device='cuda:0', dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect index 0dd1aff7d4dc2..66bc7fa9885e4 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseBSR_cuda.expect @@ -1,6949 +1,3583 @@ -########## torch.float32/torch.int32/size=()+(4, 3)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[2.], - [3.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[3.], - [4.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[4.], - [5.]]]), device='cuda:0', size=(4, 3), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([], device='cuda:0', dtype=torch.int32) -# _values -tensor([], device='cuda:0', size=(0, 2, 1)) - -########## torch.float32/torch.int32/size=(2,)+(2, 6)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], - - - [[[ 5., 15.]], - - [[ 6., 16.]], - - [[ 7., 17.]], - - [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[ 1., 11.]], - - [[ 2., 12.]], - - [[ 3., 13.]], - - [[ 4., 14.]]], - - - [[[ 5., 15.]], + [[2.], + [0.]], - [[ 6., 16.]], + [[0.], + [4.]]], - [[ 7., 17.]], - [[ 8., 18.]]]], device='cuda:0') + [[[1.], + [4.]], -########## torch.float32/torch.int32/size=(2, 3)+(4, 9)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[2.], + [0.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[3.], + [0.]]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [2.]], + [[0.], + [3.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[0.], + [4.]]]], - [[10., 20., 30.], - [11., 21., 31.]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], + [[0.], + [4.]]], - [[[[13., 23., 33.], - [14., 24., 34.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[[1.], + [3.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[0.], + [4.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[2.], + [0.]]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[[1.], + [0.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[2.], + [4.]], - [[19., 29., 39.], - [20., 30., 40.]], - - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), device='cuda:0', - size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], +tensor([[[[[1.], + [3.]], - [[ 3., 13., 23.], - [ 4., 14., 24.]], + [[2.], + [0.]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[0.], + [4.]]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[[1.], + [4.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[2.], + [0.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[3.], + [0.]]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [2.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[0.], + [3.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[0.], + [4.]]]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [4.]]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[[1.], + [3.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[17., 27., 37.], - [18., 28., 38.]], - - [[18., 28., 38.], - [19., 29., 39.]], - - [[19., 29., 39.], - [20., 30., 40.]], - - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(4, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]]), device='cuda:0', size=(4, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[1.], - [2.]], - - [[2.], - [3.]], - [[3.], - [4.]], + [[[1.], + [0.]], - [[4.], - [5.]]], device='cuda:0', dtype=torch.float64) + [[2.], + [4.]], -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([], device='cuda:0', dtype=torch.int32) -# _values -tensor([], device='cuda:0', size=(0, 2, 1), dtype=torch.float64) + [[3.], + [0.]]]]], device='cuda:0') -########## torch.float64/torch.int32/size=(2,)+(2, 6)+() ########## +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 2., 12.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 3., 13.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 4., 14.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 5., 15.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 6., 16.]], - - [[ 7., 17.]], - - [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) +tensor([0, 2, 3, 5, 7], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 2., 12.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 3., 13.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 4., 14.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 5., 15.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 6., 16.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0') - [[ 7., 17.]], - [[ 8., 18.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(4, 9)+() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[ 3., 13., 23.], - [ 4., 14., 24.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], + [[2.], + [0.]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[0.], + [4.]]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[[1.], + [4.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[2.], + [0.]], + [[3.], + [0.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[[1.], + [2.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[0.], + [3.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[0.], + [4.]]]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[[[0.], + [2.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[1.], + [3.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[0.], + [4.]]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [3.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [4.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[2.], + [0.]]], - [[19., 29., 39.], - [20., 30., 40.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[[1.], + [0.]], + [[2.], + [4.]], - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), device='cuda:0', - size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - - [[ 6., 16., 26.], - [ 7., 17., 27.]], - - [[ 7., 17., 27.], - [ 8., 18., 28.]], - - [[ 8., 18., 28.], - [ 9., 19., 29.]]], - - - [[[ 9., 19., 29.], - [10., 20., 30.]], - - [[10., 20., 30.], - [11., 21., 31.]], - - [[11., 21., 31.], - [12., 22., 32.]], +tensor([[[[[1.], + [3.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[[1.], + [4.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[2.], + [0.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[3.], + [0.]]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [2.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [3.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[0.], + [4.]]]], - [[19., 29., 39.], - [20., 30., 40.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[[[0.], + [2.]], - [[[21., 31., 41.], - [22., 32., 42.]], + [[1.], + [3.]], - [[22., 32., 42.], - [23., 33., 43.]], + [[0.], + [4.]]], - [[23., 33., 43.], - [24., 34., 44.]], - [[24., 34., 44.], - [25., 35., 45.]]]]], device='cuda:0', dtype=torch.float64) + [[[1.], + [3.]], + [[0.], + [4.]], -########## torch.float32/torch.int64/size=()+(4, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]]), device='cuda:0', size=(4, 3), nnz=4, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[1.], - [2.]], + [[2.], + [0.]]], - [[2.], - [3.]], - [[3.], - [4.]], + [[[1.], + [0.]], - [[4.], - [5.]]], device='cuda:0') + [[2.], + [4.]], -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_bsr) -# _crow_indices -tensor([0], device='cuda:0') -# _col_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0', size=(0, 2, 1)) + [[3.], + [0.]]]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/size=(2,)+(2, 6)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], - - [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, - layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') +tensor([0, 2, 3, 5, 7], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[ 1., 11.]], - - [[ 2., 12.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 3., 13.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 4., 14.]]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[[ 5., 15.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 6., 16.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 7., 17.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0', dtype=torch.float64) - [[ 8., 18.]]]], device='cuda:0') -########## torch.float32/torch.int64/size=(2, 3)+(4, 9)+() ########## +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[2.], + [0.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[0.], + [4.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[[1.], + [4.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[2.], + [0.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[3.], + [0.]]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[1.], + [2.]], + [[0.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [4.]]]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[0.], + [4.]]], - [[18., 28., 38.], - [19., 29., 39.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[[1.], + [3.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[21., 31., 41.], - [22., 32., 42.]], - [[22., 32., 42.], - [23., 33., 43.]], + [[[1.], + [0.]], - [[23., 33., 43.], - [24., 34., 44.]], + [[2.], + [4.]], - [[24., 34., 44.], - [25., 35., 45.]]]]]), device='cuda:0', - size=(2, 3, 4, 9), nnz=4, layout=torch.sparse_bsr) + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0') # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], - - [[ 4., 14., 24.], - [ 5., 15., 25.]]], - - - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - - [[ 6., 16., 26.], - [ 7., 17., 27.]], - - [[ 7., 17., 27.], - [ 8., 18., 28.]], +tensor([[[[[1.], + [3.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[[1.], + [4.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[2.], + [0.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[3.], + [0.]]], + [[[1.], + [2.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [3.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[0.], + [4.]]]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[[0.], + [2.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[1.], + [3.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[0.], + [4.]]], - [[19., 29., 39.], - [20., 30., 40.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[21., 31., 41.], - [22., 32., 42.]], + [[2.], + [0.]]], - [[22., 32., 42.], - [23., 33., 43.]], - [[23., 33., 43.], - [24., 34., 44.]], + [[[1.], + [0.]], - [[24., 34., 44.], - [25., 35., 45.]]]]], device='cuda:0') + [[2.], + [4.]], + [[3.], + [0.]]]]], device='cuda:0') -########## torch.float64/torch.int64/size=()+(4, 3)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[1.], - [2.]], +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[2.], - [3.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[3.], - [4.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[4.], - [5.]]]), device='cuda:0', size=(4, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[1.], - [2.]], - - [[2.], - [3.]], - - [[3.], - [4.]], - - [[4.], - [5.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0, 2, 1)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0], device='cuda:0') -# _col_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0', size=(0, 2, 1), dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2,)+(2, 6)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[[[ 1., 11.]], - - [[ 2., 12.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 3., 13.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[ 4., 14.]]], + [[ 0., 0., 0.], + [20., 21., 22.]], - - [[[ 5., 15.]], - - [[ 6., 16.]], - - [[ 7., 17.]], - - [[ 8., 18.]]]]), device='cuda:0', size=(2, 2, 6), nnz=4, - dtype=torch.float64, layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + layout=torch.sparse_bsr) # _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') +tensor([0, 2, 3, 5, 7], device='cuda:0') # _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0') # _values -tensor([[[[ 1., 11.]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[ 2., 12.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 3., 13.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 4., 14.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 5., 15.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 6., 16.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0') - [[ 7., 17.]], - [[ 8., 18.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(4, 9)+() ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 2., 12., 22.], - [ 3., 13., 23.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[1.], + [3.]], - [[ 3., 13., 23.], - [ 4., 14., 24.]], + [[2.], + [0.]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[0.], + [4.]]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], + [[[1.], + [4.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[2.], + [0.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[3.], + [0.]]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[[1.], + [2.]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[0.], + [3.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[0.], + [4.]]]], - [[11., 21., 31.], - [12., 22., 32.]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], + [[1.], + [3.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[0.], + [4.]]], - [[14., 24., 34.], - [15., 25., 35.]], - [[15., 25., 35.], - [16., 26., 36.]], + [[[1.], + [3.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[0.], + [4.]], + [[2.], + [0.]]], - [[[17., 27., 37.], - [18., 28., 38.]], - [[18., 28., 38.], - [19., 29., 39.]], + [[[1.], + [0.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[2.], + [4.]], - [[20., 30., 40.], - [21., 31., 41.]]], - - - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]]), device='cuda:0', - size=(2, 3, 4, 9), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + [[3.], + [0.]]]]]), device='cuda:0', size=(2, 3, 2, 3), nnz=3, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0') # _values -tensor([[[[[ 1., 11., 21.], - [ 2., 12., 22.]], - - [[ 2., 12., 22.], - [ 3., 13., 23.]], - - [[ 3., 13., 23.], - [ 4., 14., 24.]], +tensor([[[[[1.], + [3.]], - [[ 4., 14., 24.], - [ 5., 15., 25.]]], + [[2.], + [0.]], + [[0.], + [4.]]], - [[[ 5., 15., 25.], - [ 6., 16., 26.]], - [[ 6., 16., 26.], - [ 7., 17., 27.]], + [[[1.], + [4.]], - [[ 7., 17., 27.], - [ 8., 18., 28.]], + [[2.], + [0.]], - [[ 8., 18., 28.], - [ 9., 19., 29.]]], + [[3.], + [0.]]], - [[[ 9., 19., 29.], - [10., 20., 30.]], + [[[1.], + [2.]], - [[10., 20., 30.], - [11., 21., 31.]], + [[0.], + [3.]], - [[11., 21., 31.], - [12., 22., 32.]], + [[0.], + [4.]]]], - [[12., 22., 32.], - [13., 23., 33.]]]], + [[[[0.], + [2.]], - [[[[13., 23., 33.], - [14., 24., 34.]], + [[1.], + [3.]], - [[14., 24., 34.], - [15., 25., 35.]], + [[0.], + [4.]]], - [[15., 25., 35.], - [16., 26., 36.]], - [[16., 26., 36.], - [17., 27., 37.]]], + [[[1.], + [3.]], + [[0.], + [4.]], - [[[17., 27., 37.], - [18., 28., 38.]], + [[2.], + [0.]]], - [[18., 28., 38.], - [19., 29., 39.]], - [[19., 29., 39.], - [20., 30., 40.]], + [[[1.], + [0.]], - [[20., 30., 40.], - [21., 31., 41.]]], + [[2.], + [4.]], + [[3.], + [0.]]]]], device='cuda:0', dtype=torch.float64) - [[[21., 31., 41.], - [22., 32., 42.]], - - [[22., 32., 42.], - [23., 33., 43.]], - - [[23., 33., 43.], - [24., 34., 44.]], - - [[24., 34., 44.], - [25., 35., 45.]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(6, 6)+(2,) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], - [[[ 2., 102.], - [ 12., 112.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[ 3., 103.], - [ 13., 113.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 4., 104.], - [ 14., 114.]]], + [[ 0., 9., 0.], + [13., 0., 14.]], + [[10., 11., 12.], + [15., 16., 17.]], - [[[ 3., 103.], - [ 13., 113.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, layout=torch.sparse_bsr) + [[ 0., 18., 19.], + [ 0., 23., 24.]]]), device='cuda:0', size=(8, 6), nnz=7, + dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([0, 2, 3, 5, 7], device='cuda:0') # _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0') # _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], +tensor([[[ 0., 1., 0.], + [ 0., 4., 0.]], + [[ 2., 0., 3.], + [ 0., 5., 0.]], - [[[ 2., 102.], - [ 12., 112.]], + [[ 6., 7., 8.], + [ 0., 0., 0.]], - [[ 3., 103.], - [ 13., 113.]], + [[ 0., 9., 0.], + [13., 0., 14.]], - [[ 4., 104.], - [ 14., 114.]]], + [[10., 11., 12.], + [15., 16., 17.]], + [[ 0., 0., 0.], + [20., 21., 22.]], - [[[ 3., 103.], - [ 13., 113.]], + [[ 0., 18., 19.], + [ 0., 23., 24.]]], device='cuda:0', dtype=torch.float64) - [[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(4, 9)+(4, 2) ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[0., 0., 0., 0.]]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[2., 3., 4., 5.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[0., 0., 0., 0.]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[[3., 4., 5., 6.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[0., 0., 0., 0.]]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', - size=(4, 9, 4, 2), nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], + [[[0., 0., 0., 0.]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[3., 4., 5., 6.]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[0., 0., 0., 0.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[4., 5., 6., 7.]]]]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[[[0., 0., 0., 0.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[2., 3., 4., 5.]]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[1., 2., 3., 4.]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[3., 4., 5., 6.]]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[[[1., 2., 3., 4.]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[3., 4., 5., 6.]]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[[0., 0., 0., 0.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[4., 5., 6., 7.]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], + [[[2., 3., 4., 5.]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], device='cuda:0') + [[4., 5., 6., 7.]]], -########## torch.float32/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [[[3., 4., 5., 6.]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[ 11.], - [111.]]], + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[[1., 2., 3., 4.]], - [[[ 2.], - [102.]], + [[3., 4., 5., 6.]]], - [[ 12.], - [112.]]], + [[[2., 3., 4., 5.]], - [[[ 3.], - [103.]], + [[0., 0., 0., 0.]]], - [[ 13.], - [113.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]]], + [[[[1., 2., 3., 4.]], - [[[ 3.], - [103.]], + [[4., 5., 6., 7.]]], - [[ 13.], - [113.]]], + [[[2., 3., 4., 5.]], - [[[ 4.], - [104.]], + [[0., 0., 0., 0.]]], - [[ 14.], - [114.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[[1., 2., 3., 4.]], - [[[ 4.], - [104.]], + [[2., 3., 4., 5.]]], - [[ 14.], - [114.]]], + [[[0., 0., 0., 0.]], - [[[ 5.], - [105.]], + [[3., 4., 5., 6.]]], - [[ 15.], - [115.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[[ 4.], - [104.]], - [[ 14.], - [114.]]], - [[[ 5.], - [105.]], + [[[[[0., 0., 0., 0.]], - [[ 15.], - [115.]]], + [[2., 3., 4., 5.]]], - [[[ 6.], - [106.]], + [[[1., 2., 3., 4.]], - [[ 16.], - [116.]]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[[[ 5.], - [105.]], - [[ 15.], - [115.]]], + [[[[1., 2., 3., 4.]], - [[[ 6.], - [106.]], + [[3., 4., 5., 6.]]], - [[ 16.], - [116.]]], + [[[0., 0., 0., 0.]], - [[[ 7.], - [107.]], + [[4., 5., 6., 7.]]], - [[ 17.], - [117.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]]], - [[[[ 6.], - [106.]], - [[ 16.], - [116.]]], + [[[[1., 2., 3., 4.]], - [[[ 7.], - [107.]], + [[0., 0., 0., 0.]]], - [[ 17.], - [117.]]], + [[[2., 3., 4., 5.]], - [[[ 8.], - [108.]], + [[4., 5., 6., 7.]]], - [[ 18.], - [118.]]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]], device='cuda:0') - [[[[ 7.], - [107.]], +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 17.], - [117.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 9.], - [109.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 19.], - [119.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[[ 8.], - [108.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 18.], - [118.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 9.], - [109.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 19.], - [119.]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]]]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[[[ 9.], - [109.]], - [[ 19.], - [119.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 20.], - [120.]]], - [[[ 11.], - [111.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 21.], - [121.]]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 10.], - [110.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 20.], - [120.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[ 12.], - [112.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 22.], - [122.]]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[[ 11.], - [111.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 21.], - [121.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 23.], - [123.]]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[[ 12.], - [112.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 22.], - [122.]]], - [[[ 13.], - [113.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 23.], - [123.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]]]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[ 14.], - [114.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 24.], - [124.]]], - [[[ 15.], - [115.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 25.], - [125.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 14.], - [114.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 24.], - [124.]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[[ 16.], - [116.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 26.], - [126.]]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 15.], - [115.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 25.], - [125.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[ 17.], - [117.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 27.], - [127.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 16.], - [116.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 26.], - [126.]]], - [[[ 17.], - [117.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 27.], - [127.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 18.], - [118.]], - [[ 28.], - [128.]]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 28.], - [128.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 19.], - [119.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 29.], - [129.]]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 18.], - [118.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[ 20.], - [120.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0') - [[ 30.], - [130.]]]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[[ 19.], - [119.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 29.], - [129.]]], + [[3., 4., 5., 6.]]], - [[[ 20.], - [120.]], + [[[2., 3., 4., 5.]], - [[ 30.], - [130.]]], + [[0., 0., 0., 0.]]], - [[[ 21.], - [121.]], + [[[0., 0., 0., 0.]], - [[ 31.], - [131.]]]], + [[4., 5., 6., 7.]]]], - [[[[ 20.], - [120.]], + [[[[1., 2., 3., 4.]], - [[ 30.], - [130.]]], + [[4., 5., 6., 7.]]], - [[[ 21.], - [121.]], + [[[2., 3., 4., 5.]], - [[ 31.], - [131.]]], + [[0., 0., 0., 0.]]], - [[[ 22.], - [122.]], + [[[3., 4., 5., 6.]], - [[ 32.], - [132.]]]]], + [[0., 0., 0., 0.]]]], + [[[[1., 2., 3., 4.]], - [[[[[ 21.], - [121.]], + [[2., 3., 4., 5.]]], - [[ 31.], - [131.]]], + [[[0., 0., 0., 0.]], - [[[ 22.], - [122.]], + [[3., 4., 5., 6.]]], - [[ 32.], - [132.]]], + [[[0., 0., 0., 0.]], - [[[ 23.], - [123.]], + [[4., 5., 6., 7.]]]]], - [[ 33.], - [133.]]]], - [[[[ 22.], - [122.]], + [[[[[0., 0., 0., 0.]], - [[ 32.], - [132.]]], + [[2., 3., 4., 5.]]], - [[[ 23.], - [123.]], + [[[1., 2., 3., 4.]], - [[ 33.], - [133.]]], + [[3., 4., 5., 6.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]]]], + [[4., 5., 6., 7.]]]], - [[[[ 23.], - [123.]], + [[[[1., 2., 3., 4.]], - [[ 33.], - [133.]]], + [[3., 4., 5., 6.]]], - [[[ 24.], - [124.]], + [[[0., 0., 0., 0.]], - [[ 34.], - [134.]]], + [[4., 5., 6., 7.]]], - [[[ 25.], - [125.]], + [[[2., 3., 4., 5.]], - [[ 35.], - [135.]]]], + [[0., 0., 0., 0.]]]], - [[[[ 24.], - [124.]], + [[[[1., 2., 3., 4.]], - [[ 34.], - [134.]]], + [[0., 0., 0., 0.]]], - [[[ 25.], - [125.]], + [[[2., 3., 4., 5.]], - [[ 35.], - [135.]]], + [[4., 5., 6., 7.]]], - [[[ 26.], - [126.]], + [[[3., 4., 5., 6.]], - [[ 36.], - [136.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsr) + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], +tensor([[[[[[1., 2., 3., 4.]], - [[ 13.], - [113.]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], - [[[[ 2.], - [102.]], + [[0., 0., 0., 0.]]], - [[ 12.], - [112.]]], + [[[0., 0., 0., 0.]], - [[[ 3.], - [103.]], + [[4., 5., 6., 7.]]]], - [[ 13.], - [113.]]], - [[[ 4.], - [104.]], + [[[[1., 2., 3., 4.]], - [[ 14.], - [114.]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[ 3.], - [103.]], + [[0., 0., 0., 0.]]], - [[ 13.], - [113.]]], + [[[3., 4., 5., 6.]], - [[[ 4.], - [104.]], + [[0., 0., 0., 0.]]]], - [[ 14.], - [114.]]], - [[[ 5.], - [105.]], + [[[[1., 2., 3., 4.]], - [[ 15.], - [115.]]]], + [[2., 3., 4., 5.]]], + [[[0., 0., 0., 0.]], - [[[[ 4.], - [104.]], + [[3., 4., 5., 6.]]], - [[ 14.], - [114.]]], + [[[0., 0., 0., 0.]], - [[[ 5.], - [105.]], + [[4., 5., 6., 7.]]]]], - [[ 15.], - [115.]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]]]]], + [[[[[0., 0., 0., 0.]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[[ 5.], - [105.]], + [[3., 4., 5., 6.]]], - [[ 15.], - [115.]]], + [[[0., 0., 0., 0.]], - [[[ 6.], - [106.]], + [[4., 5., 6., 7.]]]], - [[ 16.], - [116.]]], - [[[ 7.], - [107.]], + [[[[1., 2., 3., 4.]], - [[ 17.], - [117.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 6.], - [106.]], + [[4., 5., 6., 7.]]], - [[ 16.], - [116.]]], + [[[2., 3., 4., 5.]], - [[[ 7.], - [107.]], + [[0., 0., 0., 0.]]]], - [[ 17.], - [117.]]], - [[[ 8.], - [108.]], + [[[[1., 2., 3., 4.]], - [[ 18.], - [118.]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], - [[[[ 7.], - [107.]], + [[4., 5., 6., 7.]]], - [[ 17.], - [117.]]], + [[[3., 4., 5., 6.]], - [[[ 8.], - [108.]], + [[0., 0., 0., 0.]]]]]], device='cuda:0', dtype=torch.float64) - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 20.], - [120.]]]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 9.], - [109.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 19.], - [119.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 10.], - [110.]], - [[ 20.], - [120.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 11.], - [111.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 21.], - [121.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 20.], - [120.]]], - [[[ 11.], - [111.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 21.], - [121.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 11.], - [111.]], - [[ 21.], - [121.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 12.], - [112.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 22.], - [122.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 13.], - [113.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 23.], - [123.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[[ 12.], - [112.]], - [[ 22.], - [122.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[ 13.], - [113.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 23.], - [123.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[ 14.], - [114.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 24.], - [124.]]]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[[ 13.], - [113.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 23.], - [123.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 14.], - [114.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 24.], - [124.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[ 15.], - [115.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7], device='cuda:0', dtype=torch.int32) +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0', dtype=torch.int32) +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 25.], - [125.]]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 16.], - [116.]], - [[ 26.], - [126.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 15.], - [115.]], - [[ 25.], - [125.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[ 16.], - [116.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 26.], - [126.]]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 17.], - [117.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 27.], - [127.]]]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 16.], - [116.]], - [[ 26.], - [126.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[ 17.], - [117.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 27.], - [127.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 18.], - [118.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 28.], - [128.]]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 17.], - [117.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 27.], - [127.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 18.], - [118.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 28.], - [128.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[[ 18.], - [118.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 28.], - [128.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[ 19.], - [119.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 29.], - [129.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[[ 19.], - [119.]], - [[ 29.], - [129.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[ 20.], - [120.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 30.], - [130.]]], - [[[ 21.], - [121.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]]]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 20.], - [120.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 30.], - [130.]]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0', dtype=torch.float64) - [[[ 21.], - [121.]], - [[ 31.], - [131.]]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[[ 22.], - [122.]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], - [[ 32.], - [132.]]]]], + [[3., 4., 5., 6.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[[1., 2., 3., 4.]], - [[[ 23.], - [123.]], + [[4., 5., 6., 7.]]], - [[ 33.], - [133.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[1., 2., 3., 4.]], - [[[ 24.], - [124.]], + [[2., 3., 4., 5.]]], - [[ 34.], - [134.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]]], - [[[ 25.], - [125.]], + [[[[[0., 0., 0., 0.]], - [[ 35.], - [135.]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[ 24.], - [124.]], + [[3., 4., 5., 6.]]], - [[ 34.], - [134.]]], + [[[0., 0., 0., 0.]], - [[[ 25.], - [125.]], + [[4., 5., 6., 7.]]]], - [[ 35.], - [135.]]], - [[[ 26.], - [126.]], + [[[[1., 2., 3., 4.]], - [[ 36.], - [136.]]]]]]], device='cuda:0') + [[3., 4., 5., 6.]]], -########## torch.float64/torch.int32/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], + [[[0., 0., 0., 0.]], - [[ 2., 102.], - [ 12., 112.]], + [[4., 5., 6., 7.]]], - [[ 3., 103.], - [ 13., 113.]]], + [[[2., 3., 4., 5.]], - [[[ 2., 102.], - [ 12., 112.]], + [[0., 0., 0., 0.]]]], - [[ 3., 103.], - [ 13., 113.]], - [[ 4., 104.], - [ 14., 114.]]], + [[[[1., 2., 3., 4.]], - [[[ 3., 103.], - [ 13., 113.]], + [[0., 0., 0., 0.]]], - [[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]], + [[[3., 4., 5., 6.]], - [[ 6., 106.], - [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, layout=torch.sparse_bsr) # _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 5., 105.], - [ 15., 115.]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0') +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 4., 104.], - [ 14., 114.]], - [[ 5., 105.], - [ 15., 115.]], + [[[2., 3., 4., 5.]], - [[ 6., 106.], - [ 16., 116.]]]], device='cuda:0', dtype=torch.float64) + [[0., 0., 0., 0.]]], -########## torch.float64/torch.int32/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], + [[[0., 0., 0., 0.]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[4., 5., 6., 7.]]]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[[[1., 2., 3., 4.]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[0., 0., 0., 0.]]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[0., 0., 0., 0.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[3., 4., 5., 6.]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[[0., 0., 0., 0.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[4., 5., 6., 7.]]]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[[[0., 0., 0., 0.]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[2., 3., 4., 5.]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], + [[[0., 0., 0., 0.]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', - size=(4, 9, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], + [[4., 5., 6., 7.]]]], - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], + [[[[1., 2., 3., 4.]], - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[3., 4., 5., 6.]]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], + [[[2., 3., 4., 5.]], - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], + [[0., 0., 0., 0.]]]], - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], + [[[[1., 2., 3., 4.]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[0., 0., 0., 0.]]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], + [[[3., 4., 5., 6.]], - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], + [[0., 0., 0., 0.]]]]]], device='cuda:0') +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 11.], - [111.]]], - [[[ 2.], - [102.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 12.], - [112.]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 2.], - [102.]], - [[ 12.], - [112.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 3.], - [103.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[ 13.], - [113.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 4.], - [104.]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[ 14.], - [114.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[[ 4.], - [104.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[ 14.], - [114.]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[[ 5.], - [105.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[ 15.], - [115.]]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 4.], - [104.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 14.], - [114.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[[ 6.], - [106.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[ 16.], - [116.]]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[[[ 5.], - [105.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 15.], - [115.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 6.], - [106.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 16.], - [116.]]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, layout=torch.sparse_bsr) +# _crow_indices +tensor([0, 2, 3, 5, 7], device='cuda:0') +# _col_indices +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0') +# _values +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[ 7.], - [107.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 17.], - [117.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[[[ 6.], - [106.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 16.], - [116.]]], - [[[ 7.], - [107.]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[ 17.], - [117.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[ 8.], - [108.]], - [[ 18.], - [118.]]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 7.], - [107.]], - [[ 17.], - [117.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[[ 8.], - [108.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[ 18.], - [118.]]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[[ 9.], - [109.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 19.], - [119.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 8.], - [108.]], - [[ 18.], - [118.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[ 9.], - [109.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 19.], - [119.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], - [[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 20.], - [120.]]]]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[[ 9.], - [109.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 19.], - [119.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 10.], - [110.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 20.], - [120.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 11.], - [111.]], - [[ 21.], - [121.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 10.], - [110.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 20.], - [120.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 11.], - [111.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 21.], - [121.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 12.], - [112.]], - [[ 22.], - [122.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[ 11.], - [111.]], - [[ 21.], - [121.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[[ 12.], - [112.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0') - [[ 22.], - [122.]]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(crow_indices=tensor([[[0, 3], + [0, 3], + [0, 3]], - [[[ 13.], - [113.]], + [[0, 3], + [0, 3], + [0, 3]]]), + col_indices=tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 23.], - [123.]]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]]), + values=tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[ 12.], - [112.]], + [[[2., 3., 4., 5.]], - [[ 22.], - [122.]]], + [[0., 0., 0., 0.]]], - [[[ 13.], - [113.]], + [[[0., 0., 0., 0.]], - [[ 23.], - [123.]]], + [[4., 5., 6., 7.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]]]]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[[[ 13.], - [113.]], - [[ 23.], - [123.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]], - [[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[[1., 2., 3., 4.]], - [[[ 15.], - [115.]], + [[2., 3., 4., 5.]]], - [[ 25.], - [125.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 14.], - [114.]], - [[ 24.], - [124.]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]]]], - [[[ 15.], - [115.]], - [[ 25.], - [125.]]], - [[[ 16.], - [116.]], + [[[[[0., 0., 0., 0.]], - [[ 26.], - [126.]]]], + [[2., 3., 4., 5.]]], + [[[1., 2., 3., 4.]], - [[[[ 15.], - [115.]], + [[3., 4., 5., 6.]]], - [[ 25.], - [125.]]], + [[[0., 0., 0., 0.]], - [[[ 16.], - [116.]], + [[4., 5., 6., 7.]]]], - [[ 26.], - [126.]]], - [[[ 17.], - [117.]], + [[[[1., 2., 3., 4.]], - [[ 27.], - [127.]]]], + [[3., 4., 5., 6.]]], + [[[0., 0., 0., 0.]], - [[[[ 16.], - [116.]], + [[4., 5., 6., 7.]]], - [[ 26.], - [126.]]], + [[[2., 3., 4., 5.]], - [[[ 17.], - [117.]], + [[0., 0., 0., 0.]]]], - [[ 27.], - [127.]]], - [[[ 18.], - [118.]], + [[[[1., 2., 3., 4.]], - [[ 28.], - [128.]]]]], + [[0., 0., 0., 0.]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[[[ 17.], - [117.]], - [[ 27.], - [127.]]], + [[[3., 4., 5., 6.]], + [[0., 0., 0., 0.]]]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=3, dtype=torch.float64, layout=torch.sparse_bsr) +# _crow_indices +tensor([[[0, 3], + [0, 3], + [0, 3]], - [[[ 18.], - [118.]], + [[0, 3], + [0, 3], + [0, 3]]], device='cuda:0') +# _col_indices +tensor([[[0, 1, 2], + [0, 1, 2], + [0, 1, 2]], - [[ 28.], - [128.]]], + [[0, 1, 2], + [0, 1, 2], + [0, 1, 2]]], device='cuda:0') +# _values +tensor([[[[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 18.], - [118.]], + [[[0., 0., 0., 0.]], - [[ 28.], - [128.]]], + [[4., 5., 6., 7.]]]], - [[[ 19.], - [119.]], - [[ 29.], - [129.]]], + [[[[1., 2., 3., 4.]], + [[4., 5., 6., 7.]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]]]], + [[[2., 3., 4., 5.]], + [[0., 0., 0., 0.]]], - [[[[ 19.], - [119.]], + [[[3., 4., 5., 6.]], - [[ 29.], - [129.]]], + [[0., 0., 0., 0.]]]], - [[[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[[1., 2., 3., 4.]], + [[2., 3., 4., 5.]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]]]], + [[[0., 0., 0., 0.]], + [[3., 4., 5., 6.]]], - [[[[ 20.], - [120.]], + [[[0., 0., 0., 0.]], - [[ 30.], - [130.]]], + [[4., 5., 6., 7.]]]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]]], + [[[[[0., 0., 0., 0.]], - [[[ 22.], - [122.]], + [[2., 3., 4., 5.]]], - [[ 32.], - [132.]]]]], + [[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[[[ 21.], - [121.]], + [[[0., 0., 0., 0.]], - [[ 31.], - [131.]]], + [[4., 5., 6., 7.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[[1., 2., 3., 4.]], + [[3., 4., 5., 6.]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]]]], + [[[0., 0., 0., 0.]], + [[4., 5., 6., 7.]]], - [[[[ 22.], - [122.]], + [[[2., 3., 4., 5.]], - [[ 32.], - [132.]]], + [[0., 0., 0., 0.]]]], - [[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[1., 2., 3., 4.]], + [[0., 0., 0., 0.]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]]]], + [[[2., 3., 4., 5.]], + [[4., 5., 6., 7.]]], - [[[[ 23.], - [123.]], + [[[3., 4., 5., 6.]], - [[ 33.], - [133.]]], + [[0., 0., 0., 0.]]]]]], device='cuda:0', dtype=torch.float64) +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([0, 2, 3, 5, 7]), + col_indices=tensor([0, 1, 0, 0, 1, 0, 1]), + values=tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[ 34.], - [134.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 25.], - [125.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 35.], - [135.]]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 25.], - [125.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[ 35.], - [135.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 26.], - [126.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[ 36.], - [136.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsr) -# _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[[[[[[ 1.], - [101.]], - [[ 11.], - [111.]]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], - [[[ 2.], - [102.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], - [[ 12.], - [112.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 3.], - [103.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 13.], - [113.]]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[[ 2.], - [102.]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 12.], - [112.]]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 3.], - [103.]], - [[ 13.], - [113.]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 4.], - [104.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 14.], - [114.]]]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[ 3.], - [103.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[ 13.], - [113.]]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 4.], - [104.]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[ 14.], - [114.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[ 5.], - [105.]], - [[ 15.], - [115.]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 4.], - [104.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 14.], - [114.]]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[[ 5.], - [105.]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], - [[ 15.], - [115.]]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[ 6.], - [106.]], - [[ 16.], - [116.]]]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[[[[ 5.], - [105.]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 15.], - [115.]]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]], device='cuda:0') - -########## torch.float32/torch.int64/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', - size=(4, 9, 4, 2), nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], device='cuda:0') - -########## torch.float32/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, layout=torch.sparse_bsr) -# _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]], device='cuda:0') - - -########## torch.float64/torch.int64/size=()+(6, 6)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]]), device='cuda:0', size=(6, 6, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[ 1., 101.], - [ 11., 111.]], - - [[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]]], - - - [[[ 2., 102.], - [ 12., 112.]], - - [[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]]], - - - [[[ 3., 103.], - [ 13., 113.]], - - [[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]]], - - - [[[ 4., 104.], - [ 14., 114.]], - - [[ 5., 105.], - [ 15., 115.]], - - [[ 6., 106.], - [ 16., 116.]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(4, 9)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]]), device='cuda:0', - size=(4, 9, 4, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_bsr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[[[[1.0000e+00, 1.0010e+03], - [1.0100e+02, 1.1010e+03], - [2.0100e+02, 1.2010e+03], - [3.0100e+02, 1.3010e+03]], - - [[1.1000e+01, 1.0110e+03], - [1.1100e+02, 1.1110e+03], - [2.1100e+02, 1.2110e+03], - [3.1100e+02, 1.3110e+03]], - - [[2.1000e+01, 1.0210e+03], - [1.2100e+02, 1.1210e+03], - [2.2100e+02, 1.2210e+03], - [3.2100e+02, 1.3210e+03]]], - - - [[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]]], - - - - [[[[2.0000e+00, 1.0020e+03], - [1.0200e+02, 1.1020e+03], - [2.0200e+02, 1.2020e+03], - [3.0200e+02, 1.3020e+03]], - - [[1.2000e+01, 1.0120e+03], - [1.1200e+02, 1.1120e+03], - [2.1200e+02, 1.2120e+03], - [3.1200e+02, 1.3120e+03]], - - [[2.2000e+01, 1.0220e+03], - [1.2200e+02, 1.1220e+03], - [2.2200e+02, 1.2220e+03], - [3.2200e+02, 1.3220e+03]]], - - - [[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]]], - - - - [[[[3.0000e+00, 1.0030e+03], - [1.0300e+02, 1.1030e+03], - [2.0300e+02, 1.2030e+03], - [3.0300e+02, 1.3030e+03]], - - [[1.3000e+01, 1.0130e+03], - [1.1300e+02, 1.1130e+03], - [2.1300e+02, 1.2130e+03], - [3.1300e+02, 1.3130e+03]], - - [[2.3000e+01, 1.0230e+03], - [1.2300e+02, 1.1230e+03], - [2.2300e+02, 1.2230e+03], - [3.2300e+02, 1.3230e+03]]], - - - [[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]]], - - - - [[[[4.0000e+00, 1.0040e+03], - [1.0400e+02, 1.1040e+03], - [2.0400e+02, 1.2040e+03], - [3.0400e+02, 1.3040e+03]], - - [[1.4000e+01, 1.0140e+03], - [1.1400e+02, 1.1140e+03], - [2.1400e+02, 1.2140e+03], - [3.1400e+02, 1.3140e+03]], - - [[2.4000e+01, 1.0240e+03], - [1.2400e+02, 1.1240e+03], - [2.2400e+02, 1.2240e+03], - [3.2400e+02, 1.3240e+03]]], - - - [[[5.0000e+00, 1.0050e+03], - [1.0500e+02, 1.1050e+03], - [2.0500e+02, 1.2050e+03], - [3.0500e+02, 1.3050e+03]], - - [[1.5000e+01, 1.0150e+03], - [1.1500e+02, 1.1150e+03], - [2.1500e+02, 1.2150e+03], - [3.1500e+02, 1.3150e+03]], - - [[2.5000e+01, 1.0250e+03], - [1.2500e+02, 1.1250e+03], - [2.2500e+02, 1.2250e+03], - [3.2500e+02, 1.3250e+03]]]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(6, 6)+(2, 1) ########## -# sparse tensor -tensor(crow_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - - - - [[[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]]], - - - - [[[[ 20.], - [120.]], - - [[ 30.], - [130.]]], - - - [[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]]]], - - - - - [[[[[ 21.], - [121.]], - - [[ 31.], - [131.]]], - - - [[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]]], - - - - [[[[ 22.], - [122.]], - - [[ 32.], - [132.]]], - - - [[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]]], - - - - [[[[ 23.], - [123.]], - - [[ 33.], - [133.]]], - - - [[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]]], - - - - [[[[ 24.], - [124.]], - - [[ 34.], - [134.]]], - - - [[[ 25.], - [125.]], - - [[ 35.], - [135.]]], - - - [[[ 26.], - [126.]], - - [[ 36.], - [136.]]]]]]]), device='cuda:0', - size=(2, 3, 6, 6, 2, 1), nnz=4, dtype=torch.float64, - layout=torch.sparse_bsr) + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=7, dtype=torch.float64, layout=torch.sparse_bsr) # _crow_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') +tensor([0, 2, 3, 5, 7], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') +tensor([0, 1, 0, 0, 1, 0, 1], device='cuda:0') # _values -tensor([[[[[[[ 1.], - [101.]], - - [[ 11.], - [111.]]], - - - [[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]]], - - - - [[[[ 2.], - [102.]], - - [[ 12.], - [112.]]], - - - [[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]]], - - - - [[[[ 3.], - [103.]], - - [[ 13.], - [113.]]], - - - [[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]]], - - - - [[[[ 4.], - [104.]], - - [[ 14.], - [114.]]], - - - [[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]]]], - - - - - [[[[[ 5.], - [105.]], - - [[ 15.], - [115.]]], - - - [[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]]], - - - - [[[[ 6.], - [106.]], - - [[ 16.], - [116.]]], - - - [[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]]], - - - - [[[[ 7.], - [107.]], - - [[ 17.], - [117.]]], - - - [[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]]], - - - - [[[[ 8.], - [108.]], - - [[ 18.], - [118.]]], - - - [[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]]]], - - - - - [[[[[ 9.], - [109.]], - - [[ 19.], - [119.]]], - - - [[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]]], - - - - [[[[ 10.], - [110.]], - - [[ 20.], - [120.]]], - - - [[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]]], - - - - [[[[ 11.], - [111.]], - - [[ 21.], - [121.]]], - - - [[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]]], - - - - [[[[ 12.], - [112.]], - - [[ 22.], - [122.]]], - - - [[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]]]]], - - - - - - [[[[[[ 13.], - [113.]], - - [[ 23.], - [123.]]], - - - [[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]]], - - - - [[[[ 14.], - [114.]], - - [[ 24.], - [124.]]], - - - [[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]]], - - - - [[[[ 15.], - [115.]], - - [[ 25.], - [125.]]], - - - [[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]]], - - - - [[[[ 16.], - [116.]], - - [[ 26.], - [126.]]], - - - [[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]]]], - - - - - [[[[[ 17.], - [117.]], - - [[ 27.], - [127.]]], - - - [[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]]], - - - - [[[[ 18.], - [118.]], - - [[ 28.], - [128.]]], - - - [[[ 19.], - [119.]], - - [[ 29.], - [129.]]], - - - [[[ 20.], - [120.]], - - [[ 30.], - [130.]]]], - +tensor([[[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[[ 19.], - [119.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 29.], - [129.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 20.], - [120.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[ 30.], - [130.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 21.], - [121.]], - [[ 31.], - [131.]]]], + [[[[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]]], - [[[[ 20.], - [120.]], - [[ 30.], - [130.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[ 31.], - [131.]]], - [[[ 22.], - [122.]], + [[[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], - [[ 32.], - [132.]]]]], + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[[ 21.], - [121.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 31.], - [131.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]]], - [[[ 22.], - [122.]], - [[ 32.], - [132.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], - [[[ 23.], - [123.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[ 33.], - [133.]]]], + [[[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[[ 22.], - [122.]], + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]]]], - [[ 32.], - [132.]]], - [[[ 23.], - [123.]], + [[[[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[ 33.], - [133.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]]], - [[[ 24.], - [124.]], - [[ 34.], - [134.]]]], + [[[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]]]], - [[[[ 23.], - [123.]], - [[ 33.], - [133.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 24.], - [124.]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[ 34.], - [134.]]], + [[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]]], - [[[ 25.], - [125.]], + [[[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], - [[ 35.], - [135.]]]], + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]]]], - [[[[ 24.], - [124.]], - [[ 34.], - [134.]]], + [[[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[ 25.], - [125.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]]], - [[ 35.], - [135.]]], + [[[ 0., 0.], + [ 0., 0.], + [ 0., 0.], + [ 0., 0.]], - [[[ 26.], - [126.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[ 36.], - [136.]]]]]]], device='cuda:0', dtype=torch.float64) + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect index 64435343b7cb6..65efcec63319b 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSC_cuda.expect @@ -1,1411 +1,1661 @@ -########## torch.float32/torch.int32/size=()+(3, 2)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, - layout=torch.sparse_csc) +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], + + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], + + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0', dtype=torch.int32) # _values -tensor([1., 2., 3., 4.], device='cuda:0') +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), device='cuda:0', size=(8, 6), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([], device='cuda:0', dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0') +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], device='cuda:0') -########## torch.float32/torch.int32/size=(2,)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), - nnz=4, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+() ########## +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0') + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), device='cuda:0', size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([], device='cuda:0', dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', dtype=torch.float64) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], device='cuda:0', + dtype=torch.float64) -########## torch.float64/torch.int32/size=(2,)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+() ########## +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0') # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], device='cuda:0') -########## torch.float32/torch.int64/size=()+(3, 2)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), device='cuda:0', size=(8, 6), nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0') # _row_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0') # _values -tensor([1., 2., 3., 4.], device='cuda:0') +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], device='cuda:0') -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0], device='cuda:0') -# _row_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0') - -########## torch.float32/torch.int64/size=(2,)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), - nnz=4, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0') -########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+() ########## +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 3, 2), nnz=4, layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0') # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0') +tensor([[[1., 3., 2., 4.], + [1., 4., 2., 3.], + [1., 2., 3., 4.]], + [[2., 1., 3., 4.], + [1., 3., 4., 2.], + [1., 2., 4., 3.]]], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/size=()+(3, 2)+() ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(3, 2), nnz=4, +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., + 2., 10., 15., 5., 11., 16., 18., 23., 3., 12., 17., + 19., 24.]), device='cuda:0', size=(8, 6), nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0') # _row_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0') # _values -tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) +tensor([ 6., 13., 20., 1., 4., 7., 9., 21., 8., 14., 22., 2., 10., 15., + 5., 11., 16., 18., 23., 3., 12., 17., 19., 24.], device='cuda:0', + dtype=torch.float64) -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([0]), - row_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([0], device='cuda:0') -# _row_indices -tensor([], device='cuda:0', dtype=torch.int64) -# _values -tensor([], device='cuda:0', dtype=torch.float64) -########## torch.float64/torch.int64/size=(2,)+(3, 2)+() ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor -tensor(ccol_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - row_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 3, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _row_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+() ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 3, 2), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int32/size=()+(3, 2)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(3, 2)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), - nnz=4, layout=torch.sparse_csc) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0') - -########## torch.float32/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0') - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), - nnz=4, layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0', dtype=torch.int32) # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(3, 2)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(3, 2)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0', dtype=torch.int32) # _row_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0', dtype=torch.float64) - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0', dtype=torch.int32) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0') # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(3, 2)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, - layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0') - -########## torch.float32/torch.int64/size=()+(3, 2)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), - nnz=4, layout=torch.sparse_csc) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0') # _row_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0') # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0') - -########## torch.float32/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[[[13.], - [14.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[14.], - [15.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[15.], - [16.]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], - [[16.], - [17.]]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[[17.], - [18.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[18.], - [19.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[19.], - [20.]], + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], - [[20.], - [21.]]], + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], - [[[21.], - [22.]], + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0') - [[22.], - [23.]], - [[23.], - [24.]], +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## +# sparse tensor +tensor(ccol_indices=tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), - nnz=4, layout=torch.sparse_csc) + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]]), + row_indices=tensor([[[0, 1, 0, 1], + [0, 1, 0, 0], + [0, 1, 1, 1]], + + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]]), + values=tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], +tensor([[[0, 2, 3, 4], + [0, 2, 3, 4], + [0, 2, 3, 4]], - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') + [[0, 1, 3, 4], + [0, 2, 3, 4], + [0, 1, 3, 4]]], device='cuda:0') # _row_indices tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], + [0, 1, 0, 0], + [0, 1, 1, 1]], - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') + [[1, 0, 1, 1], + [0, 1, 1, 0], + [0, 0, 1, 0]]], device='cuda:0') # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0') - - -########## torch.float64/torch.int64/size=()+(3, 2)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [3., 4., 5., 6.], + [2., 3., 4., 5.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [4., 5., 6., 7.], + [2., 3., 4., 5.], + [3., 4., 5., 6.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[2., 3., 4., 5.], + [1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [3., 4., 5., 6.], + [4., 5., 6., 7.], + [2., 3., 4., 5.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [4., 5., 6., 7.], + [3., 4., 5., 6.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(3, 2, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([0, 2, 4], device='cuda:0') -# _row_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(3, 2)+(4, 2) ########## -# sparse tensor -tensor(ccol_indices=tensor([0, 2, 4]), - row_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], +tensor(ccol_indices=tensor([ 0, 3, 8, 11, 14, 19, 24]), + row_indices=tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, + 7, 0, 4, 5, 6, 7]), + values=tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], [ 3., 13.], [ 4., 14.], [ 5., 15.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + [[ 3., 13.], [ 4., 14.], [ 5., 15.], [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(3, 2, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, dtype=torch.float64, layout=torch.sparse_csc) # _ccol_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 8, 11, 14, 19, 24], device='cuda:0') # _row_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([2, 5, 7, 0, 1, 2, 4, 7, 2, 5, 7, 0, 4, 5, 1, 4, 5, 6, 7, 0, 4, 5, 6, 7], + device='cuda:0') # _values -tensor([[[ 1., 11.], +tensor([[[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[ 1., 11.], [ 2., 12.], [ 3., 13.], [ 4., 14.]], - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=(2, 3)+(3, 2)+(2, 1) ########## -# sparse tensor -tensor(ccol_indices=tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]]), - row_indices=tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 3, 2, 2, 1), - nnz=4, dtype=torch.float64, layout=torch.sparse_csc) -# _ccol_indices -tensor([[[0, 2, 4], - [0, 3, 4], - [0, 1, 4]], - - [[0, 1, 4], - [0, 2, 4], - [0, 3, 4]]], device='cuda:0') -# _row_indices -tensor([[[0, 1, 0, 1], - [0, 1, 2, 0], - [0, 0, 1, 2]], - - [[1, 0, 1, 2], - [0, 2, 0, 1], - [0, 1, 2, 1]]], device='cuda:0') -# _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], + [ 7., 17.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[[17.], - [18.]], + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], - [[18.], - [19.]], + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], - [[19.], - [20.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], - [[20.], - [21.]]], + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], - [[[21.], - [22.]], + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], - [[22.], - [23.]], + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], - [[23.], - [24.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[24.], - [25.]]]]], device='cuda:0', dtype=torch.float64) + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect index ddb5272c79cab..a02ee510ff8a5 100644 --- a/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect +++ b/test/expect/TestSparseCompressedCUDA.test_print_SparseCSR_cuda.expect @@ -1,48 +1,3 @@ -########## torch.float32/torch.int32/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(0, 0)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_csr) -# _crow_indices -tensor([0], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([], device='cuda:0', dtype=torch.int32) -# _values -tensor([], device='cuda:0') - -########## torch.float32/torch.int32/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), - nnz=4, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0') - ########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], @@ -52,21 +7,21 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -76,7 +31,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -84,59 +39,33 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], device='cuda:0') -########## torch.float64/torch.int32/size=()+(0, 0)+() ########## +########## torch.float32/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), device='cuda:0', size=(8, 6), nnz=24, + layout=torch.sparse_csr) # _crow_indices -tensor([0], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0', + dtype=torch.int32) # _col_indices -tensor([], device='cuda:0', dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0', dtype=torch.float64) +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], device='cuda:0') -########## torch.float64/torch.int32/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) ########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -147,21 +76,21 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -171,7 +100,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -179,59 +108,34 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([1., 2., 3., 4.], device='cuda:0') + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int64/size=()+(0, 0)+() ########## +########## torch.float64/torch.int32/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), device='cuda:0', size=(8, 6), nnz=24, + dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0], device='cuda:0') +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0', + dtype=torch.int32) # _col_indices -tensor([], device='cuda:0', dtype=torch.int64) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0', dtype=torch.int32) # _values -tensor([], device='cuda:0') +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], device='cuda:0', + dtype=torch.float64) -########## torch.float32/torch.int64/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), - nnz=4, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0') ########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -242,21 +146,21 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 2, 3), nnz=4, layout=torch.sparse_csr) + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -266,7 +170,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -274,59 +178,32 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0') + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], device='cuda:0') - -########## torch.float64/torch.int64/size=()+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 3), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(0, 0)+() ########## +########## torch.float32/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0]), - col_indices=tensor([], size=(0,)), - values=tensor([], size=(0,)), device='cuda:0', size=(0, 0), nnz=0, - dtype=torch.float64, layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), device='cuda:0', size=(8, 6), nnz=24, + layout=torch.sparse_csr) # _crow_indices -tensor([0], device='cuda:0') +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0') # _col_indices -tensor([], device='cuda:0', dtype=torch.int64) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0') # _values -tensor([], device='cuda:0', dtype=torch.float64) +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], device='cuda:0') -########## torch.float64/torch.int64/size=(2,)+(2, 3)+() ########## -# sparse tensor -tensor(crow_indices=tensor([[0, 2, 4], - [0, 3, 4]]), - col_indices=tensor([[0, 1, 0, 1], - [0, 1, 2, 0]]), - values=tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]]), device='cuda:0', size=(2, 2, 3), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([[0, 2, 4], - [0, 3, 4]], device='cuda:0') -# _col_indices -tensor([[0, 1, 0, 1], - [0, 1, 2, 0]], device='cuda:0') -# _values -tensor([[1., 2., 3., 4.], - [5., 6., 7., 8.]], device='cuda:0', dtype=torch.float64) ########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+() ########## # sparse tensor @@ -337,21 +214,21 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]]), device='cuda:0', - size=(2, 3, 2, 3), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + values=tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]]), device='cuda:0', size=(2, 3, 2, 3), + nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -361,7 +238,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -369,84 +246,35 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 11., 12.]], - - [[13., 14., 15., 16.], - [17., 18., 19., 20.], - [21., 22., 23., 24.]]], device='cuda:0', dtype=torch.float64) +tensor([[[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]], + [[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]], device='cuda:0', dtype=torch.float64) -########## torch.float32/torch.int32/size=()+(2, 3)+(2,) ########## +########## torch.float64/torch.int64/size=()+(8, 6)+() ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0') - -########## torch.float32/torch.int32/size=()+(2, 3)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[[ 1., 11.], - [ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], - - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), - nnz=4, layout=torch.sparse_csr) +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., + 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., + 23., 24.]), device='cuda:0', size=(8, 6), nnz=24, + dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0') # _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0') # _values -tensor([[[ 1., 11.], - [ 2., 12.], - [ 3., 13.], - [ 4., 14.]], - - [[ 2., 12.], - [ 3., 13.], - [ 4., 14.], - [ 5., 15.]], +tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., + 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.], device='cuda:0', + dtype=torch.float64) - [[ 3., 13.], - [ 4., 14.], - [ 5., 15.], - [ 6., 16.]], - [[ 4., 14.], - [ 5., 15.], - [ 6., 16.], - [ 7., 17.]]], device='cuda:0') - -########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## +########## torch.float32/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -455,91 +283,44 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), - nnz=4, layout=torch.sparse_csr) + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -549,7 +330,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -557,108 +338,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0') - - -########## torch.float64/torch.int32/size=()+(2, 3)+(2,) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], device='cuda:0') + +########## torch.float32/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int32/size=()+(2, 3)+(4, 2) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -677,12 +392,114 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], device='cuda:0', dtype=torch.int32) +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0', + dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 2], device='cuda:0', dtype=torch.int32) +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0', dtype=torch.int32) # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -702,9 +519,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + [ 7., 17.]], -########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0') + + +########## torch.float64/torch.int32/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -713,91 +631,44 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -807,7 +678,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0', dtype=torch.int32) # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -815,108 +686,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0', dtype=torch.int32) # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0', dtype=torch.float64) - - -########## torch.float32/torch.int64/size=()+(2, 3)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, - layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0') - -########## torch.float32/torch.int64/size=()+(2, 3)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int32/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -935,12 +740,114 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), - nnz=4, layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0', + dtype=torch.int32) # _col_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0', dtype=torch.int32) # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -960,9 +867,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0') + [ 7., 17.]], -########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0', dtype=torch.float64) + + +########## torch.float32/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -971,91 +979,44 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), - nnz=4, layout=torch.sparse_csr) + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -1065,7 +1026,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -1073,108 +1034,42 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]], device='cuda:0') - - -########## torch.float64/torch.int64/size=()+(2, 3)+(2,) ########## -# sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), - values=tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]]), device='cuda:0', size=(2, 3, 2), nnz=4, - dtype=torch.float64, layout=torch.sparse_csr) -# _crow_indices -tensor([0, 2, 4], device='cuda:0') -# _col_indices -tensor([0, 1, 0, 2], device='cuda:0') -# _values -tensor([[1., 2.], - [2., 3.], - [3., 4.], - [4., 5.]], device='cuda:0', dtype=torch.float64) - -########## torch.float64/torch.int64/size=()+(2, 3)+(4, 2) ########## +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], device='cuda:0') + +########## torch.float32/torch.int64/size=()+(8, 6)+(4, 2) ########## # sparse tensor -tensor(crow_indices=tensor([0, 2, 4]), - col_indices=tensor([0, 1, 0, 2]), +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), values=tensor([[[ 1., 11.], [ 2., 12.], [ 3., 13.], @@ -1193,12 +1088,113 @@ tensor(crow_indices=tensor([0, 2, 4]), [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]]), device='cuda:0', size=(2, 3, 4, 2), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + [ 7., 17.]], + + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, layout=torch.sparse_csr) # _crow_indices -tensor([0, 2, 4], device='cuda:0') +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0') # _col_indices -tensor([0, 1, 0, 2], device='cuda:0') +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0') # _values tensor([[[ 1., 11.], [ 2., 12.], @@ -1218,9 +1214,110 @@ tensor([[[ 1., 11.], [[ 4., 14.], [ 5., 15.], [ 6., 16.], - [ 7., 17.]]], device='cuda:0', dtype=torch.float64) + [ 7., 17.]], -########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(2, 1) ########## + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0') + + +########## torch.float64/torch.int64/size=(2, 3)+(2, 3)+(4,) ########## # sparse tensor tensor(crow_indices=tensor([[[0, 2, 4], [0, 3, 4], @@ -1229,91 +1326,44 @@ tensor(crow_indices=tensor([[[0, 2, 4], [[0, 1, 4], [0, 2, 4], [0, 3, 4]]]), - col_indices=tensor([[[0, 1, 0, 1], + col_indices=tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]]), - values=tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], - - [[18.], - [19.]], - - [[19.], - [20.]], - - [[20.], - [21.]]], - - - [[[21.], - [22.]], - - [[22.], - [23.]], - - [[23.], - [24.]], - - [[24.], - [25.]]]]]), device='cuda:0', size=(2, 3, 2, 3, 2, 1), - nnz=4, dtype=torch.float64, layout=torch.sparse_csr) + values=tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]]), device='cuda:0', + size=(2, 3, 2, 3, 4), nnz=4, dtype=torch.float64, layout=torch.sparse_csr) # _crow_indices tensor([[[0, 2, 4], [0, 3, 4], @@ -1323,7 +1373,7 @@ tensor([[[0, 2, 4], [0, 2, 4], [0, 3, 4]]], device='cuda:0') # _col_indices -tensor([[[0, 1, 0, 1], +tensor([[[0, 1, 0, 2], [0, 1, 2, 0], [0, 0, 1, 2]], @@ -1331,81 +1381,285 @@ tensor([[[0, 1, 0, 1], [0, 2, 0, 1], [0, 1, 2, 1]]], device='cuda:0') # _values -tensor([[[[[ 1.], - [ 2.]], - - [[ 2.], - [ 3.]], - - [[ 3.], - [ 4.]], - - [[ 4.], - [ 5.]]], - - - [[[ 5.], - [ 6.]], - - [[ 6.], - [ 7.]], - - [[ 7.], - [ 8.]], - - [[ 8.], - [ 9.]]], - - - [[[ 9.], - [10.]], - - [[10.], - [11.]], - - [[11.], - [12.]], - - [[12.], - [13.]]]], - - - - [[[[13.], - [14.]], - - [[14.], - [15.]], - - [[15.], - [16.]], - - [[16.], - [17.]]], - - - [[[17.], - [18.]], +tensor([[[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]], + + + [[[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]], + + [[1., 2., 3., 4.], + [2., 3., 4., 5.], + [3., 4., 5., 6.], + [4., 5., 6., 7.]]]], device='cuda:0', dtype=torch.float64) + +########## torch.float64/torch.int64/size=()+(8, 6)+(4, 2) ########## +# sparse tensor +tensor(crow_indices=tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24]), + col_indices=tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, + 5, 0, 1, 2, 4, 5]), + values=tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[18.], - [19.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[19.], - [20.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[20.], - [21.]]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]]), device='cuda:0', size=(8, 6, 4, 2), + nnz=24, dtype=torch.float64, layout=torch.sparse_csr) +# _crow_indices +tensor([ 0, 3, 5, 8, 8, 12, 17, 19, 24], device='cuda:0') +# _col_indices +tensor([1, 3, 5, 1, 4, 0, 1, 2, 1, 3, 4, 5, 0, 2, 3, 4, 5, 4, 5, 0, 1, 2, 4, 5], + device='cuda:0') +# _values +tensor([[[ 1., 11.], + [ 2., 12.], + [ 3., 13.], + [ 4., 14.]], - [[[21.], - [22.]], + [[ 2., 12.], + [ 3., 13.], + [ 4., 14.], + [ 5., 15.]], - [[22.], - [23.]], + [[ 3., 13.], + [ 4., 14.], + [ 5., 15.], + [ 6., 16.]], - [[23.], - [24.]], + [[ 4., 14.], + [ 5., 15.], + [ 6., 16.], + [ 7., 17.]], - [[24.], - [25.]]]]], device='cuda:0', dtype=torch.float64) + [[ 5., 15.], + [ 6., 16.], + [ 7., 17.], + [ 8., 18.]], + + [[ 6., 16.], + [ 7., 17.], + [ 8., 18.], + [ 9., 19.]], + + [[ 7., 17.], + [ 8., 18.], + [ 9., 19.], + [10., 20.]], + + [[ 8., 18.], + [ 9., 19.], + [10., 20.], + [11., 21.]], + + [[ 9., 19.], + [10., 20.], + [11., 21.], + [12., 22.]], + + [[10., 20.], + [11., 21.], + [12., 22.], + [13., 23.]], + + [[11., 21.], + [12., 22.], + [13., 23.], + [14., 24.]], + + [[12., 22.], + [13., 23.], + [14., 24.], + [15., 25.]], + + [[13., 23.], + [14., 24.], + [15., 25.], + [16., 26.]], + + [[14., 24.], + [15., 25.], + [16., 26.], + [17., 27.]], + + [[15., 25.], + [16., 26.], + [17., 27.], + [18., 28.]], + + [[16., 26.], + [17., 27.], + [18., 28.], + [19., 29.]], + + [[17., 27.], + [18., 28.], + [19., 29.], + [20., 30.]], + + [[18., 28.], + [19., 29.], + [20., 30.], + [21., 31.]], + + [[19., 29.], + [20., 30.], + [21., 31.], + [22., 32.]], + + [[20., 30.], + [21., 31.], + [22., 32.], + [23., 33.]], + + [[21., 31.], + [22., 32.], + [23., 33.], + [24., 34.]], + + [[22., 32.], + [23., 33.], + [24., 34.], + [25., 35.]], + + [[23., 33.], + [24., 34.], + [25., 35.], + [26., 36.]], + + [[24., 34.], + [25., 35.], + [26., 36.], + [27., 37.]]], device='cuda:0', dtype=torch.float64) diff --git a/test/expect/TestTensorBoard.test_image_with_3_channel_batched.expect b/test/expect/TestTensorBoard.test_image_with_3_channel_batched.expect index 2895ff76fdb8f..bc63fadcd04d0 100644 --- a/test/expect/TestTensorBoard.test_image_with_3_channel_batched.expect +++ b/test/expect/TestTensorBoard.test_image_with_3_channel_batched.expect @@ -4,6 +4,6 @@ value { height: 8 width: 16 colorspace: 3 - encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\002\000\000\000\177\024\350\300\000\000\000+IDATx\234cd8\320\360\037\033pww\307*\316\362\343\307\217\037\330$~\374\370\361\037\233\004\013\016\365\377q\211\217H\r\000d\305y\224,\220Z\033\000\000\000\000IEND\256B`\202" + encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\002\000\000\000\177\024\350\300\000\000\000\034IDATx\234cd\370\377\377?\003\t\200\211$\325\014\014\014L$\252\037\231\032\000\355.\004\014i.\207\035\000\000\000\000IEND\256B`\202" } } diff --git a/test/expect/TestTensorBoard.test_image_with_boxes.expect b/test/expect/TestTensorBoard.test_image_with_boxes.expect index 4364b4841ef1d..1c28992dfa67c 100644 --- a/test/expect/TestTensorBoard.test_image_with_boxes.expect +++ b/test/expect/TestTensorBoard.test_image_with_boxes.expect @@ -4,6 +4,6 @@ value { height: 32 width: 32 colorspace: 3 - encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000sIDATx\234\355\323=\n\300 \014\005\340\027p\250\267p\324\373\332\373\345\020vn\007\367>0\204b\311\233\305/\344G\000\334\236\021Uu\005R\000\377\007\244\224\342\013||\007\2655\330BfP\215\337S`>:{_l\020\335\242\tX6-\000\032r\007G\316\000\2561\226\201\244\252/\005V\357\026\271\003\033\0149\000\232\270\003+\260\301\220\003\240y\000T\221\324V\250_v\320\000\000\000\000IEND\256B`\202" + encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000 \000\000\000 \010\002\000\000\000\374\030\355\243\000\000\000CIDATx\234cd\370\377\377?\003\r\001\023MMg```\242\261\371\243\026\214Z@\005\300B@\236\221\221B\013\006\334\007\020@Ai2\364#y\324\202Q\013F-\030\265`\324\202Q\013\206\207\005\0008\302\006@\2475\013\321\000\000\000\000IEND\256B`\202" } } diff --git a/test/expect/TestTensorBoard.test_image_with_one_channel.expect b/test/expect/TestTensorBoard.test_image_with_one_channel.expect index 7b43f507fc2d2..c37098115c1f6 100644 --- a/test/expect/TestTensorBoard.test_image_with_one_channel.expect +++ b/test/expect/TestTensorBoard.test_image_with_one_channel.expect @@ -4,6 +4,6 @@ value { height: 8 width: 8 colorspace: 3 - encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\031IDATx\234cd``\370\217\r0\376\370\361\003\253\004\313\240\224\000\000;\267\273\313%\020=\255\000\000\000\000IEND\256B`\202" + encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\035IDATx\234cd``\370\377\377?\003\006`\302*\312\300\300\300\204Ut\240%\000R\364\006\n\'\250a\364\000\000\000\000IEND\256B`\202" } } diff --git a/test/expect/TestTensorBoard.test_image_with_one_channel_batched.expect b/test/expect/TestTensorBoard.test_image_with_one_channel_batched.expect index e16187d04cb8e..8bd3a721b29f7 100644 --- a/test/expect/TestTensorBoard.test_image_with_one_channel_batched.expect +++ b/test/expect/TestTensorBoard.test_image_with_one_channel_batched.expect @@ -4,6 +4,6 @@ value { height: 8 width: 16 colorspace: 3 - encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\002\000\000\000\177\024\350\300\000\000\000(IDATx\234cd``\370\217\r\034?~\034\2538\313\217\037?~\374\370\201)\201U\020\252\001\253\304\250\006$\000\000\230\346y\315\204l;t\000\000\000\000IEND\256B`\202" + encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\020\000\000\000\010\010\002\000\000\000\177\024\350\300\000\000\000\036IDATx\234cd``\370\377\377?\003\321\200\211$\325\014\014\014L$\251\036\251\032\000\215\270\006\nS2\367\330\000\000\000\000IEND\256B`\202" } } diff --git a/test/expect/TestTensorBoard.test_image_without_channel.expect b/test/expect/TestTensorBoard.test_image_without_channel.expect index 7b43f507fc2d2..c37098115c1f6 100644 --- a/test/expect/TestTensorBoard.test_image_without_channel.expect +++ b/test/expect/TestTensorBoard.test_image_without_channel.expect @@ -4,6 +4,6 @@ value { height: 8 width: 8 colorspace: 3 - encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\031IDATx\234cd``\370\217\r0\376\370\361\003\253\004\313\240\224\000\000;\267\273\313%\020=\255\000\000\000\000IEND\256B`\202" + encoded_image_string: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\010\000\000\000\010\010\002\000\000\000Km)\334\000\000\000\035IDATx\234cd``\370\377\377?\003\006`\302*\312\300\300\300\204Ut\240%\000R\364\006\n\'\250a\364\000\000\000\000IEND\256B`\202" } } diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index f10fd14393580..853f5206969b3 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -95,6 +95,7 @@ ("aten::_linalg_inv_out_helper", datetime.date(2022, 10, 1)), ("aten::col2im_backward", datetime.date(2022, 12, 1)), ("aten::im2col_backward", datetime.date(2022, 12, 1)), + ("aten::diag_backward", datetime.date(2022, 12, 1)), ("aten::solve", datetime.date(9999, 1, 1)), ("aten::solve.solution", datetime.date(9999, 1, 1)), ("aten::_solve_helper", datetime.date(9999, 1, 1)), @@ -290,7 +291,36 @@ ("aten::nested_to_padded_tensor", datetime.date(2022, 10, 1)), ("aten::nested_tensor", datetime.date(2022, 10, 15)), ("aten::_nested_tensor_layer_norm", datetime.date(2022, 10, 15)), + ("aten::_torch_cuda_cu_linker_symbol_op", datetime.date(2022, 11, 1)), + ("aten::upsample_linear1d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_bicubic2d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_trilinear3d", datetime.date(2022, 12, 15)), + ("aten::upsample_bilinear2d", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest3d", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest2d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_bilinear2d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_trilinear3d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest2d", datetime.date(2022, 12, 15)), + ("aten::upsample_bicubic2d", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest1d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest3d_backward", datetime.date(2022, 12, 15)), + ("aten::upsample_linear1d", datetime.date(2022, 12, 15)), + ("aten::upsample_nearest1d", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact3d", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact3d_backward", datetime.date(2022, 12, 15)), + ("aten::_upsample_bilinear2d_aa", datetime.date(2022, 12, 15)), + ("aten::_upsample_bilinear2d_aa_backward", datetime.date(2022, 12, 15)), + ("aten::_upsample_bicubic2d_aa", datetime.date(2022, 12, 15)), + ("aten::_upsample_bicubic2d_aa_backward", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact1d", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)), + ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)), + ("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)), + ("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)), + ("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)), + ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 1d7356b6ca7e5..41607bd62297c 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -14,6 +14,7 @@ import os import unittest from torch.testing._internal.common_device_type import toleranceOverride +from torch.testing._internal.autograd_function_db import autograd_function_db from collections import namedtuple IS_FBCODE = os.getenv('FUNCTORCH_TEST_FBCODE') == '1' @@ -222,10 +223,11 @@ def clone_if_tensor(x): return x.clone() return x - -def compute_quantities_for_vmap_test( +# Helper function to compare output of `vmap` against the +# `for-loop` version. +def _compute_quantities_for_vmap_test( op, orig_batched_args, orig_kwarg_values, in_dims, - out_dim=0, batch_size=2, compute_loop_out=True, + out_dim, batch_size, compute_loop_out=True, clone_inputs=False): def maybe_clone_inputs(): @@ -236,10 +238,12 @@ def maybe_clone_inputs(): return orig_batched_args, orig_kwarg_values batched_args, kwarg_values = maybe_clone_inputs() + if compute_loop_out: loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values) else: loop_out = None + # Used for debugging the resulting operations # from functorch import make_fx # def f(a): @@ -248,7 +252,6 @@ def maybe_clone_inputs(): # print(in_dims, [arg.shape for arg in batched_args], kwarg_values) batched_args, kwarg_values = maybe_clone_inputs() batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values) - yield (loop_out, batched_out) # Tests case where we dispatch to a batching rule with no bdims # This should be handled by autogenerated plumbing. For vmap support @@ -262,24 +265,52 @@ def f(dummy, *args, **kwargs): return op(*args, **kwargs) dummy = torch.ones(batch_size, 1) - expected = pytree.tree_map(add_bdim_if_tensor, batched_out) + vmapvmap_expected = pytree.tree_map(add_bdim_if_tensor, batched_out) inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims) outer_in_dims = (0,) + in_dims batched_args, kwarg_values = maybe_clone_inputs() - output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **kwarg_values) - yield (expected, output) + vmapvmap_output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **kwarg_values) + + yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected) + + +# Function with more friendly return types +# compared to `_compute_quantities_for_vmap_test` +def compute_quantities_for_vmap_test( + op, orig_batched_args, orig_kwarg_values, in_dims, + out_dim=0, batch_size=2, compute_loop_out=True, + clone_inputs=False): + for quantities in _compute_quantities_for_vmap_test(op, orig_batched_args, orig_kwarg_values, in_dims, + out_dim, batch_size, compute_loop_out, clone_inputs): + yield (quantities[0], quantities[1]) + yield (quantities[2], quantities[3]) def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True): out_dim = 0 batch_size = 2 + def make_batched(t): + if isinstance(t, torch.Tensor): + shape = list(t.shape) + shape.insert(out_dim, batch_size) + return t.expand(*shape) + return t + + # Inputs generated by `generate_vmap_inputs` just copy/expand the unbatched inputs + # over the batched dimension. Thus we can compute the expected value once and just + # expand it based on the `out_dim` and `batch_size`. + expected_unbatched = op(*arg_values, **kwarg_values) + expected_batched = pytree.tree_map(make_batched, expected_unbatched) generator = generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training) for batched_args, in_dims, kwarg_values in generator: - for quantities in compute_quantities_for_vmap_test( - op, batched_args, kwarg_values, in_dims, out_dim, batch_size, compute_loop_out): - yield quantities + for quantities in _compute_quantities_for_vmap_test( + op, batched_args, kwarg_values, in_dims, out_dim, batch_size, + compute_loop_out=False): + assert quantities[1] is None + yield (quantities[0], expected_batched) + yield (quantities[2], quantities[3]) def opinfo_in_dict(opinfo, d): @@ -321,7 +352,7 @@ def skip(op_name, variant_name='', *, device_type=None, dtypes=None): def skipOps(test_case_name, base_test_name, to_skip): - all_opinfos = op_db + additional_op_db + all_opinfos = op_db + additional_op_db + autograd_function_db for decorate_meta in to_skip: matching_opinfos = [o for o in all_opinfos if o.name == decorate_meta.op_name and diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py index e52f317087b4c..3f4f74b9224de 100644 --- a/test/functorch/discover_coverage.py +++ b/test/functorch/discover_coverage.py @@ -3,7 +3,7 @@ from torch.testing._internal.common_methods_invocations import op_db from functorch_additional_op_db import additional_op_db from enum import Enum -import functorch._src.top_operators_github_usage as top_ops +import torch._functorch.top_operators_github_usage as top_ops import pprint import unittest import enum diff --git a/test/functorch/functorch_additional_op_db.py b/test/functorch/functorch_additional_op_db.py index b090121d21807..9352924d5004f 100644 --- a/test/functorch/functorch_additional_op_db.py +++ b/test/functorch/functorch_additional_op_db.py @@ -4,8 +4,7 @@ import torch -from torch.testing import \ - (floating_types, floating_types_and, all_types_and_complex_and) +from torch.testing._internal.common_dtype import floating_types, floating_types_and, all_types_and_complex_and from torch.testing._internal.common_utils import make_tensor from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput, DecorateInfo @@ -446,7 +445,8 @@ def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_conversion, skips=( # autograd tests don't handle operators that change dtype - DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), # RuntimeError: attribute lookup is not defined on builtin DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), @@ -512,7 +512,8 @@ def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_conversion, skips=( # autograd tests don't handle operators that change dtype - DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), # RuntimeError: attribute lookup is not defined on builtin DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), @@ -525,7 +526,8 @@ def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_conversion, skips=( # autograd tests don't handle operators that change dtype - DecorateInfo(unittest.expectedFailure, 'TestGradients'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients'), + DecorateInfo(unittest.expectedFailure, 'TestBwdGradients'), DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), # RuntimeError: attribute lookup is not defined on builtin DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index e9a46b0882e2e..5eeca3ffc4cac 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6,6 +6,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Union, Callable, List, Any, Optional, Dict from unittest.mock import patch from torch.testing._internal.common_utils import TestCase, run_tests, IS_ARM64, IS_WINDOWS import torch @@ -21,7 +22,7 @@ grad, vjp, vmap, jacrev, make_fx ) -from functorch._src.aot_autograd import aot_module_simplified +from torch._functorch.aot_autograd import aot_module_simplified from functorch.compile import ( nnc_jit, compiled_function, compiled_module, min_cut_rematerialization_partition, aot_function, aot_module, @@ -37,8 +38,9 @@ skip, skipOps, ) -from torch._subclasses.fake_tensor import DynamicOutputShapeException +from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node +from torch.fx.experimental.symbolic_shapes import ShapeEnv USE_TORCHVISION = False try: @@ -231,8 +233,8 @@ def f(x): self.assertEqual(grads, grads2) -def _outs_and_grads(fn, inps): - outs = fn(*inps) +def _outs_and_grads(fn, graph_inps, inps): + outs = fn(*graph_inps) for out in pytree.tree_flatten(outs)[0]: if isinstance(out, torch.Tensor) and out.requires_grad: out.sum().backward(retain_graph=True) @@ -243,14 +245,61 @@ def _outs_and_grads(fn, inps): class TestAOTAutograd(AOTTestCase): - def verify_aot_autograd(self, f, inp): + # test_mutation will: + # - Ensure that inputs are non-leaves, so our graphs can mutate them + # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) + def verify_aot_autograd( + self, + f, + inp: Union[Callable, List[Any]], + *, + test_mutation: bool = False, + return_fw_graph: bool = False, + decompositions: Optional[Dict] = None, + ): + # Some tests pass in a callable for inp, to generate the inputs + # (useful if we want to generate complicated aliasing inputs) + if isinstance(inp, Callable): + inp_callable = inp + # The callable should return a tuple of f_inputs, f_graph_inputs + # (The idea is that we might want to compile a function with the graph inputs, + # but test autograd backprop all the way through the actual inputs) + inp_copy, graph_inps_copy = inp_callable() + inp, graph_inps = inp_callable() + else: + inp_copy = [] + # Our input clones need to mimic when inputs are duplicates of one another + dupes_map = {} + for i, x in enumerate(inp): + if x in dupes_map: + x_dupe_idx = dupes_map[x] + inp_copy.append(inp_copy[x_dupe_idx]) + else: + dupes_map[x] = i + x_copy = x.clone().detach().requires_grad_(x.requires_grad) + if x.requires_grad and not x.is_leaf: + x_copy = x_copy.clone() + inp_copy.append(x_copy) + + if test_mutation: + # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves + graph_inps = [x.add(1) for x in inp] + graph_inps_copy = [x.add(1) for x in inp_copy] + else: + graph_inps = inp + graph_inps_copy = inp_copy + + # Create a copy of inputs, so we can test input mutation correctness. + + fw_graph_cell = [None] if isinstance(f, nn.Module): - compiled_f = aot_module(f, nop) + compiled_f = aot_module( + f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=nop, decompositions=decompositions) else: - compiled_f = aot_function(f, nop) - ref_out, ref_grad = _outs_and_grads(f, inp) - test_out, test_grad = _outs_and_grads(compiled_f, inp) - self.assertEqual(ref_out, test_out) + compiled_f = aot_function( + f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=nop, decompositions=decompositions) + ref_out, ref_grad = _outs_and_grads(f, graph_inps, inp) + test_out, test_grad = _outs_and_grads(compiled_f, graph_inps_copy, inp_copy) self.assertEqual(ref_grad, test_grad) if isinstance(ref_out, torch.Tensor): @@ -259,6 +308,24 @@ def verify_aot_autograd(self, f, inp): for ref_o, test_o in zip(ref_out, test_out): if isinstance(ref_o, torch.Tensor): self.assertEqual(ref_o.requires_grad, test_o.requires_grad) + self.assertEqual(ref_o.is_leaf, test_o.is_leaf) + if ref_o.requires_grad: + # _is_view() should probably unconditionally be the same, + # but in practice I don't think this matters for tensors that don't require grad + self.assertEqual(ref_o._is_view(), test_o._is_view()) + self.assertEqual(ref_o, test_o) + if test_mutation: + # This tests that autograd meta is set properly on the output we can + # mutate it. + ref_o.mul_(2) + test_o.mul_(2) + self.assertEqual(ref_o, test_o) + for ref_i, test_i in zip(inp, inp_copy): + if isinstance(ref_i, torch.Tensor): + self.assertEqual(ref_i.requires_grad, test_i.requires_grad) + self.assertEqual(ref_i, test_i) + if return_fw_graph: + return fw_graph_cell[0] def test_single_output(self): def f(a, b): @@ -278,6 +345,539 @@ def f(a, b): inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) + def test_input_mutation_simple(self): + def f(a): + a.mul_(2) + return a * 3 + inp = [torch.ones(3, 3, requires_grad=True)] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + # Things to note: + # - the extra clone is because we need to pass the pre-mutated input to grad(), + # but autograd operates above functionalization so we need to manually clone. + # Hopefully backends can optimize this easily. + # - The extra return arg is because the compiled forward returns (mutated inputs + outputs) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None + mul_1 = torch.ops.aten.mul.Tensor(mul, 3) + return [mul, mul_1]""") + + def test_input_mutation_is_output(self): + def f(a): + a.mul_(2) + return a + inp = [torch.ones(3, 3, requires_grad=True)] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None + return [mul]""") + + def test_input_mutation_multiple(self): + def f(a, b, c): + a.mul_(2) + c.mul_(2) + return a + b + c + + inp = [ + torch.ones(3, 3, requires_grad=True), + torch.ones(3, 3, requires_grad=True), + torch.ones(3, 3, requires_grad=True), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None + mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None + mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None + add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None + add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None + return [mul, mul_1, add_1]""") + + def test_input_mutation_metadata(self): + def f(a, b): + a.transpose_(1, 0) + return a + b + inp = [ + torch.ones(3, 3, requires_grad=True), + torch.ones(3, 3, requires_grad=True), + ] + + self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + + def test_input_mutation_metadata2(self): + def f(a): + a.transpose_(1, 0) + a.mul_(2) + return a + 1 + inp = [torch.ones(3, 3, requires_grad=True)] + + self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + + def test_input_mutation_resize_smaller(self): + def f(a, b): + a.resize_(2, 2) + return a + b + # tenors that require gradients cannot be resized, so only test requires_grad=False case + inp = [ + torch.ones(3, 3), + torch.ones(2, 2, requires_grad=True), + ] + + self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + + def test_input_mutation_batchnorm(self): + def f(inpt, weight, bias, running_mean, running_var): + # This is additionally a good test, because the input tensors that we mutate + # are *also* saved for backwards. + # This tests that what we save for the backward is actually cloned inputs, + # and not the original inputs that got mutated. + return torch._native_batch_norm_legit(inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5) + inp = [ + torch.ones(2, 5, 5, 5, requires_grad=True), + torch.ones(5, requires_grad=True), + torch.ones(5, requires_grad=True), + torch.ones(5), + torch.ones(5), + ] + + from torch._decomp import get_decompositions + # This simulates what inductor does (running the fw + bw decompositions) + decompositions = get_decompositions([ + torch.ops.aten._native_batch_norm_legit_functional, + torch.ops.aten.native_batch_norm_backward, + ]) + self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True, decompositions=decompositions) + + def test_input_output_view_simple(self): + def f(a): + return a.view(-1) + inp = [ + torch.ones(2, 2, requires_grad=True).add(1), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + return [4, 1, 0]""") + + def test_input_output_view_mutate_multiple(self): + def f(a, b, c): + a.mul_(2) + c.mul_(3) + return b.view(2, 2), c.view(2, 2) + inp = [ + torch.ones(2, 2, requires_grad=True).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + # The original function returned two outputs, both of which aliased inputs. + # We expect two outputs in the functional graph, a_updated and c_updated. + # The actual aliased outputs themselves aren't in the compiled forward graph; + # Instead, they're generated outside of the graph. + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None + mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None + mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None + return [mul, mul_1, 2, 2, 2, 1, 0, 2, 2, 2, 1, 0]""") + + def test_input_output_view_metadata_mutate_multiple(self): + def f(a, b, c): + b.mul_(3) + c.t_() + return a.view(2, 2), b.view(2, 2), c.view(2, 2) + inp = [ + torch.ones(2, 2, requires_grad=True).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + # Important thing to check here: of the three inputs: + # Only the b.mul_(3) should show up in the graph (we functionalize it and return it). + # Everything else that does not show up in the graph includes: + # - The metadata mutation on c (we do it outside the graph) + # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3): + clone = torch.ops.aten.clone.default(primals_2); primals_2 = None + mul = torch.ops.aten.mul.Tensor(clone, 3); clone = None + return [mul, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0, 2, 2, 2, 1, 0, 2, 2, 1, 2, 0]""") + + def test_input_mutation_and_output_view(self): + def f(a): + a.add_(1) + return a.view(-1) + inp = [ + torch.ones(2, 2, requires_grad=True).add(1), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + # Here, total # of outputs is 1 because: + # - num_mutated_inps = 1 (a_updated) + # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + add = torch.ops.aten.add.Tensor(clone, 1); clone = None + return [add, 4, 1, 0]""") + + + def test_input_mutation_output_view_multiple(self): + def f(a, b, c, d): + b.transpose_(1, 0) + c.add_(1) + return d + 1, b.diagonal(), a + c + inp = [ + torch.arange(4, requires_grad=True, dtype=torch.float32).view(2, 2).add(1), + torch.arange(4, requires_grad=True, dtype=torch.float32).view(2, 2).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + torch.ones(2, 2, requires_grad=True).add(1), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3, primals_4): + clone = torch.ops.aten.clone.default(primals_3); primals_3 = None + add = torch.ops.aten.add.Tensor(clone, 1); clone = None + add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None + add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None + return [add, add_1, add_2, 2, 2, 1, 2, 0, 2, 3, 0]""") + + + def test_input_data_and_metadata_mutation(self): + def f(a): + a.t_() + a[0].mul_(2) + return a.view(a.shape) + inp = [torch.ones(3, 3, requires_grad=True)] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + t = torch.ops.aten.t.default(clone) + select = torch.ops.aten.select.int(t, 0, 0); t = None + mul = torch.ops.aten.mul.Tensor(select, 2); select = None + t_1 = torch.ops.aten.t.default(clone); clone = None + select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None + t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None + t_3 = torch.ops.aten.t.default(t_2); t_2 = None + return [t_3, 3, 3, 1, 3, 0]""") + + def test_view_and_inplace_view(self): + def f(a, b): + a.t_() + return b.view(b.shape), a.view(a.shape) + inp = [ + torch.ones(3, 3, requires_grad=True), + torch.ones(3, 3, requires_grad=True) + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + return [3, 3, 1, 3, 0, 3, 3, 3, 1, 0, 3, 3, 1, 3, 0]""") + + def test_view_detach(self): + def f(a): + tmp = a.detach() + a.mul_(2) + return a, tmp + inp = [ + torch.ones(3, 3, requires_grad=True), + ] + + self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + + def test_input_inplace_requires_grad_true(self): + def f(a, b): + a.requires_grad_(True) + return a.mul(3), b.mul(4) + inp = [ + # First inp doesnt require grad, but we switch it on + torch.ones(3, 3, requires_grad=False), + torch.ones(3, 3, requires_grad=True), + ] + + fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None + mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None + return [mul, mul_1]""") + + def test_input_data_and_metadata_mutation_aliases_other_input(self): + # a and b are aliased + def f(a, b): + a.t_() + b.mul_(2) + return a.mul(3) + + def inp_callable(): + base = torch.ones(2, 2, requires_grad=True) + # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. + x = base.add(1) + inp1 = x.view(-1) + inp2 = x.view(-1) + return [base], [inp1, inp2] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + # Important parts of the graph: + # - the compiled graph takes in a base, and we generate a and b (the views) off of the base + # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs + # - We re-generate the views *after* the clone, to preserve view relationships. + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) + mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None + as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None + t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None + mul_1 = torch.ops.aten.mul.Tensor(t_1, 3); t_1 = None + return [mul, mul_1, 4, 1, 0]""") + + def test_input_mutation_aliases_other_input(self): + def f(a, b): + a.add_(1) + return a + b + + def inp_callable(): + base = torch.ones(2, 2, requires_grad=True) + # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. + x = base.add(1) + inp1 = x[0] + inp2 = x[1] + return [base], [inp1, inp2] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + # Important parts of the graph: + # - the compiled graph takes in a base, and we generate a and b (the views) off of the base + # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs + # - We re-generate the views *after* the clone, to preserve view relationships. + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) + add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None + as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2); as_strided_scatter = None + add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None + return [add, add_1]""") + + def test_input_mutation_aliases_other_input2(self): + def f(a, b): + a.add_(1) + return a + b + + def inp_callable(): + base = torch.ones(2, 2, requires_grad=True) + x = base.add(1) + inp1 = x[0] + # Here, one of the aliased inputs is the base itself + inp2 = x + return [base], [inp1, inp2] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) + add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = None + as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0); as_strided_scatter = None + add_1 = torch.ops.aten.add.Tensor(add, as_strided_4); as_strided_4 = None + return [add, add_1]""") + + def test_input_mutation_aliases_and_output_alias(self): + def f(a, b): + # Here, we need to take care:that because and b are aliased + # (1) since a and b are aliased, we generate a view off of "updated b" + # (2) We're returning a view, which doesn't show up in the graph + a.add_(1) + return b.view(b.shape) + + def inp_callable(): + base = torch.ones(2, 2, requires_grad=True) + x = base.add(1) + return [base], [x.view(-1), x.view(-1)] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None + add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None + return [add, 4, 1, 0]""") + + def test_input_aliased_with_mutation_output_alias(self): + def f(a, b, c): + # a and c alias + c.mul_(2) + # The main thing we're testing here is that + # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward + # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases. + # The original fw takes in 3 args, but the compiled fw takes in only 2 args. + return b.add(1), c.view(-1) + + def inp_callable(): + base1 = torch.ones(2, 2, requires_grad=True) + base2 = torch.ones(2, 2, requires_grad=True) + x = base1.add(1) + y = base2.add(1) + return [base1, base2], [x.view(-1), y, x.view(-1)] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0); clone = None + mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None + add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None + return [mul, add, 4, 1, 0]""") + + def test_input_metadata_mutation_aliases(self): + def f(a, b): + # a and b alias, and we do a metadata mutation on a + # Since we're not mutating data, then b isn't affected at all. + # We expect aot autograd to not bother with constructing a synthetic base. + a.t_() + return a + b + + def inp_callable(): + base = torch.ones(2, 2, requires_grad=True) + x = base.add(1) + return [base], [x.view(-1), x.view(-1)] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + view = torch.ops.aten.view.default(primals_1, [4]); primals_1 = None + t = torch.ops.aten.t.default(view); view = None + add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None + return [add, 4, 1, 0]""") + + def test_input_mutation_aliases_and_none_require_gradients(self): + def f(a, b, c): + # a and b alias, but neither require gradients (so they don't have a _base) + # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())` + a.mul_(2) + return b + 1, c + 1 + + def inp_callable(): + base = torch.ones(2, 2) + c_arg = torch.ones(2, 2, requires_grad=True) + x = base.add(1) + return [base, c_arg], [x.view(-1), x.view(-1), c_arg] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) + mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None + as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None + add = torch.ops.aten.add.Tensor(as_strided_2, 1); as_strided_2 = None + add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None + return [mul, add, add_1]""") + + def test_input_mutation_aliases_bases_out_of_order(self): + # This tests our calling convention: if b and d are aliased, then the outer calling convention + # that we send to the compiled forward becomes: + # (b_d_base, a, c) + # Importantly, even though a and c alias in our test, neither inputs are mutated, + # So we don't need to do the base construction / deconstruction + def f(a, b, c, d): + b.add_(1) + d.t_() + return a + c + d, b.view(-1) + + def inp_callable(): + base1 = torch.ones(2, 2, requires_grad=True) + base2 = torch.ones(2, 2, requires_grad=True) + x1 = base1.add(1) + x2 = base2.add(1) + # a and c alias, b and d alias + return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + # 3 graph inputs: (b_d_base, a, c) + # 2 returns: (b_updated, a+c+d) + # (there are 2 original fw outs, but one is a view of b so it's not part of the graph) + # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it) + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2, primals_3): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) + add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None + add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = None + as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None + t_1 = torch.ops.aten.t.default(as_strided_4); as_strided_4 = None + add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = t_1 = None + return [add, add_2, 4, 1, 0, 4, 1, 0]""") + + # Mondo test that tests a combination of: + # input is mutated, that aliases another input (so we make a synthetic base) + # an output is an alias of another output + # an output is an alias of an intermediate + def test_input_mutation_alias_everything(self): + # a and c are aliased + def f(a, b, c): + c.mul_(2) # mutates c + b.t_() # metadata mutate b + tmp = a + c + # TODO: this test doesn't test "alias of an intermediate" yet, + # delete this line later and get that to be tested + return tmp, b.t(), a + out1 = tmp.view(-1) + out2 = b.t() + out3 = out1.unsqueeze(0) + # out1 and out3 are aliases of an intermediate, and alias each other! + # out2 aliases an input, so we don't return it + return out1, out2, out3 + + def inp_callable(): + base1 = torch.ones(2, 2, requires_grad=True) + base2 = torch.ones(2, 2, requires_grad=True) + # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. + base1_ = base1.add(1) + base2_ = base2.add(1) + a = base1_.view(-1) + b = base2_ + c = base1_.view(-1) + return [base1, base2], [a, b, c] + + fw_graph = self.verify_aot_autograd(f, inp_callable, test_mutation=True, return_fw_graph=True) + # Expected: + # - 2 inputs in the forward: synthetic_base_a_c, b + # - 1 output in the forward: "tmp" + # out2 is an alias of an input, and will be generated off of b outside of the compiled fn + # out1 and out3 are aliases of tmp, that we generate outside of the compiled function + self.assertExpectedInline(fw_graph.code.strip(), """\ +def forward(self, primals_1, primals_2): + clone = torch.ops.aten.clone.default(primals_1); primals_1 = None + as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) + mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = None + as_strided_4 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0); as_strided_scatter = None + add = torch.ops.aten.add.Tensor(as_strided_4, mul); as_strided_4 = None + return [mul, add, 2, 2, 1, 2, 0, 2, 2, 2, 1, 0]""") + def test_no_grad_input_output(self): def f(a, b): return a.cos(), b.cos(), a * b @@ -287,12 +887,26 @@ def f(a, b): inps = [i() for i in inps] self.verify_aot_autograd(f, inps) - def test_some_outputs_dont_require_grad(self): + def test_some_output_requires_grad_input_doesnt(self): + def f(a, b): + a_view = a.view(-1) + a_view.requires_grad_(True) + return a_view + inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)] + self.verify_aot_autograd(f, inp) + + def test_some_outputs_dont_require_grad_view(self): def f(a, b): return a.detach(), b inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp) + def test_some_outputs_dont_require_grad_non_view(self): + def f(a, b): + return a.add(1).detach(), b + inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] + self.verify_aot_autograd(f, inp) + def test_inner_grad(self): def foo(x): y = torch.exp(x) @@ -340,8 +954,12 @@ def f(x): for k in x: new_d[k] = x[k] * 2 return new_d - inp = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}] - self.verify_aot_autograd(f, inp) + + def inp_callable(): + inps = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}] + return inps, inps + + self.verify_aot_autograd(f, inp_callable) def test_module(self): mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) @@ -373,7 +991,8 @@ def f(a, b, c): inp = [torch.randn(5, requires_grad=True) for _ in range(3)] f(*inp).sum().backward() - def test_compilation_context(self): + @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) + def test_compilation_context(self, counter): def f(x): return x.sin().sin() count = [] @@ -388,7 +1007,7 @@ def compiler(fx_g, _): f = aot_function(f, compiler) f(torch.randn(5)) out.sum().backward() - self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)]) + self.assertEqual(count, [(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]) def test_dupe_arg(self): def f(x, y): @@ -397,6 +1016,89 @@ def f(x, y): x = torch.randn(3, 3, requires_grad=True) self.verify_aot_autograd(f, [x, x]) + def test_dupe_arg_torture(self): + def f(x, y): + x.t_() + y.t_() + return x + y + + x = torch.randn(3, 3, requires_grad=True).clone() + self.verify_aot_autograd(f, [x, x]) + + @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) + @patch("torch._functorch.config.debug_assert", True) + def test_invalid_dupe_left_bias(self, counter): + # This test checks that, just because only the first + # argument did a metadata mutation, we still correctly + # switch to strategy 2 (deduplicate) + # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447 + class F(torch.nn.Module): + def forward(self, x, y): + x.t_() + return (x + y,) + + x = torch.randn(3, 3, requires_grad=True).clone() + y = torch.randn(3, 3, requires_grad=True) + self.verify_aot_autograd(F(), [x, x]) + + fxx = aot_module_simplified(F(), (x, x), nop) + self.assertExpectedRaisesInline( + AssertionError, lambda: fxx(x, y), + """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 + ) + + @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) + @patch("torch._functorch.config.debug_assert", True) + def test_invalid_dupe(self, counter): + class F(torch.nn.Module): + def forward(self, x, y): + x.t_() + y.t_() + return (x + y,) + + x = torch.randn(3, 3, requires_grad=True).clone() + y = torch.randn(3, 3, requires_grad=True).clone() + + fxy = aot_module_simplified(F(), (x, y), nop) + fxy(x, y) + fxy(x, x) # is ok! + + fxx = aot_module_simplified(F(), (x, x), nop) + fxx(x, x) + self.assertExpectedRaisesInline( + AssertionError, lambda: fxx(x, y), + """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 + ) + + @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) + @patch("torch._functorch.config.debug_assert", True) + def test_invalid_requires_grad(self, counter): + class F(torch.nn.Module): + def forward(self, x, y): + return (x + y,) + + x = torch.randn(3, 3, requires_grad=True) + y = torch.randn(3, 3, requires_grad=True) + z = torch.randn(3, 3, requires_grad=False) + + # Non-mutating please! + def compare(m1, m2, inps): + r1, g1 = _outs_and_grads(m1, inps, inps) + r2, g2 = _outs_and_grads(m2, inps, inps) + self.assertEqual(r1, r2) + self.assertEqual(g1, g2) + + fxy = aot_module_simplified(F(), (x, y), nop) + compare(F(), fxy, (x, y)) + compare(F(), fxy, (x, z)) + + fxz = aot_module_simplified(F(), (x, z), nop) + compare(F(), fxz, (x, z)) + self.assertExpectedRaisesInline( + AssertionError, lambda: fxz(x, y), + """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 + ) + def test_resize_input(self): def f(x, y): y.resize_(4) @@ -416,6 +1118,21 @@ def f(x, y): self.assertEqual(ref_out, test_out) + def test_custom_autograd(self): + class CustomFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.clone() + + @staticmethod + def backward(ctx, grad_output): + return grad_output + 1 + + def f(x): + return CustomFn.apply(x) + + self.verify_aot_autograd(f, [torch.randn(3)]) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_autocast_disable_guard(self): guard = torch._C._DisableAutocast() @@ -671,15 +1388,20 @@ def f(a, b, c, d): (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() bw_graph = bw_graph_cell[0] - self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) - self.assertEqual(get_num_ins_outs(bw_graph), (13, 4)) + # 12 outs because: + # - 5 original outputs -> 4 graph outputs (the 3rd output is an input alias, gets moved outside) + # - 8 saved outputs for backward: 5 tensors, 3 symints + self.assertEqual(get_num_ins_outs(fw_graph), (4, 12)) + self.assertEqual(get_num_ins_outs(bw_graph), (12, 4)) _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, # # TODO(whc)- are the saved-tensors/saved-symints correct here? # i just made the test pass based on what default partition did - [False, True, True, False, False] + [False] * 5 + [True] * 3, + # Of the 5 original forward outputs, the 4th (c) is an input, + # which won't show up in the compiled forward graph + [False, True, True, False] + [False] * 4 + [True] * 4, [is_sym_node(n) for n in fw_graph_out_nodes] ) @@ -727,14 +1449,14 @@ def f(a, b, c, d): (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() bw_graph = bw_graph_cell[0] - self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) - self.assertEqual(get_num_ins_outs(bw_graph), (13, 4)) + self.assertEqual(get_num_ins_outs(fw_graph), (4, 12)) + self.assertEqual(get_num_ins_outs(bw_graph), (12, 4)) _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, # then 4 tensors (transposes of matricies used for mm) are saved # finally 4 symints are saved - [False, True, True, False, False] + [False] * 4 + [True] * 4, + [False, True, True, False] + [False] * 4 + [True] * 4, [is_sym_node(n) for n in fw_graph_out_nodes] ) @@ -887,15 +1609,52 @@ def forward(self, x, y): ref = mod(*inputs) ref[0].sum().backward() - aot_mod = aot_module_simplified(mod, nop) - aot_mod.zero_grad() - res = aot_mod(*cloned_inputs) + compiled_f = aot_module_simplified(mod, cloned_inputs, nop) + mod.zero_grad() + res = compiled_f(*cloned_inputs) res[0].sum().backward() assert torch.allclose(ref[0], res[0]) assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) + def test_aot_module_simplified_dynamic(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(20, 30) + + def forward(self, x, y): + return (self.linear(x) + y, ) + + mod = MockModule() + + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + + x = torch.randn(128, 20, requires_grad=True) + y = torch.randn(128, 30, requires_grad=True) + + inputs = [x, y] + fake_inputs = [fake_mode.from_tensor(x) for x in inputs] + compiled_f = aot_module_simplified(mod, fake_inputs, nop) + + ref = mod(*inputs) + ref[0].sum().backward() + + cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] + res = compiled_f(*cloned_inputs) + res[0].sum().backward() + + self.assertExpectedInline(shape_env.format_guards(), """\ + - Eq(s3, 20) + - Eq(s9, 30)""") + + assert torch.allclose(ref[0], res[0]) + assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) + assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) + + def test_aot_module_simplified_preserves_stack_trace(self): class MockModule(torch.nn.Module): def __init__(self): @@ -927,14 +1686,95 @@ def assert_compiler(gm: torch.fx.GraphModule, _): assert 'test_aotdispatch.py' in node.stack_trace return gm.forward # return a python callable - aot_mod = aot_module_simplified(mod, fw_compiler=assert_compiler, bw_compiler=assert_compiler) - x = torch.randn(128, 20, requires_grad=True) y = torch.randn(128, 30, requires_grad=True) inputs = [x, y] - res = aot_mod(*inputs) + + compiled_f = aot_module_simplified(mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler) + res = compiled_f(*inputs) res[0].sum().backward() + def _test_aot_module_simplified_fake_tensor_gm_raises(self, debug): + class MockModule(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.y = y + + def forward(self, x): + z = self.linear(x) + z = z + self.y + z = z.relu() + return (z, ) + + + real_x = torch.randn(4) + real_y = torch.randn(4) + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + fake_y = fake_mode.from_tensor(real_y) + + tracer = torch.fx.Tracer() + tracer.record_stack_traces = True + + # This test uses tracing to lift the fake_y into a constant buffer, + # so we have a contrived trace example. + # For a traceless example closer to how dynamo would call us, see + # test_aot_module_deepcopy_fake_tensor_gm_raises below. + graph = tracer.trace(MockModule(fake_y)) + mod_fake = torch.fx.GraphModule(tracer.root, graph) + + if debug: + inner_message = "FAKE TENSOR CREATION TRACEBACK:" + else: + inner_message = "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." + + message = f"""Unexpected fake buffer y {inner_message}""" + + with self.assertRaisesRegex( + AssertionError, message + ): + aot_module_simplified(mod_fake, (real_x,), nop) + + # Counterfactual to ensure that the raise is only due to real vs fake + # Run the same exact thing except with a real buffer. + graph = tracer.trace(MockModule(real_y)) + mod_real = torch.fx.GraphModule(tracer.root, graph) + aot_module_simplified(MockModule(real_y), (real_x,), nop) + + @patch("torch._subclasses.fake_tensor.FakeTensorConfig.debug", True) + def test_aot_module_simplified_fake_tensor_gm_raises_debug_enabled(self): + self._test_aot_module_simplified_fake_tensor_gm_raises(debug=True) + + @patch("torch._subclasses.fake_tensor.FakeTensorConfig.debug", False) + def test_aot_module_simplified_fake_tensor_gm_raises_no_debug_enabled(self): + self._test_aot_module_simplified_fake_tensor_gm_raises(debug=False) + + def test_aot_module_deepcopy_fake_tensor_gm_raises(self): + class MockModule(torch.nn.Module): + def __init__(self, y): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + self.linear.bias = torch.nn.Parameter(torch.ones(4)) + + def forward(self, x): + z = self.linear(x) + z = z.relu() + return (z, ) + + + real_x = torch.randn(4) + real_y = torch.randn(4) + + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() + mod_fake = torch._dynamo.utils.deepcopy_to_fake_tensor(MockModule(real_y), fake_mode) + + with self.assertRaisesRegex( + AssertionError, + """Unexpected fake param linear.weight""" + ): + aot_module_simplified(mod_fake, (real_x,), nop) + + # entries in here don't work and need to be fixed. # Each one of these is a bug (or needs to be investigated) aot_autograd_failures = { @@ -952,13 +1792,16 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('linalg.eig'), xfail('scatter_reduce', 'prod'), - # non-deterministic skip('as_strided_scatter'), + skip('as_strided', 'partial_views'), # flaky # Too annoying to generate random inputs xfail('cholesky'), xfail('linalg.cholesky'), + # Given input size: (s0xs1x2). Calculated output size: ... + skip('max_pool2d_with_indices_backward'), + # Misc xfail('to_sparse'), xfail('corrcoef'), @@ -975,40 +1818,27 @@ def assert_compiler(gm: torch.fx.GraphModule, _): symbolic_aot_autograd_failures = { xfail('__rmatmul__', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('addcdiv', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('amax', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('amin', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('as_strided', ''), # Tensor-likes are not close! xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition - xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cartesian_prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides - xfail('cdouble'), # RuntimeError: aten.view_as_real.default - couldn't find symbolic meta function/decomposition - xfail('cfloat'), # RuntimeError: aten.view_as_real.default - couldn't find symbolic meta function/decomposition xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cholesky_inverse', ''), # could not find kernel xfail('cholesky_solve', ''), # could not find kernel - xfail('chunk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default - xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition - xfail('diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('diagonal', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('diagonal_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('einsum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.fft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.fft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.fftn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1029,8 +1859,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('fft.rfft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.rfft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.rfftn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('fmax', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition - xfail('fmin', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1039,8 +1867,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('lerp', ''), # aten.lerp.Scalar - couldn't find symbolic meta function/decomposition - xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta functio... + xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default xfail('linalg.cond', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition @@ -1074,7 +1901,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('linalg.tensorinv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('linalg.vector_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition @@ -1088,36 +1914,19 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('masked_fill', ''), # could not find kernel - xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ... xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... - xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... - xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... - xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... - xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec... - xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('median', ''), # could not find kernel xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides - xfail('min', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec... xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('msort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('mvlgamma', 'mvlgamma_p_1'), # aten.digamma_.default - couldn't find symbolic meta function/decom... - xfail('mvlgamma', 'mvlgamma_p_3'), # aten.digamma_.default - couldn't find symbolic meta function/decom... - xfail('mvlgamma', 'mvlgamma_p_5'), # aten.digamma_.default - couldn't find symbolic meta function/decom... - xfail('nanmedian', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional._scaled_dot_product_attention', ''), # Cannot call sizes() on tensor with symbolic ... xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d_backward.default - couldn't ... xfail('nn.functional.adaptive_max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1127,25 +1936,16 @@ def assert_compiler(gm: torch.fx.GraphModule, _): skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for torch.Tensor: z2, z3 = z1.split(2) z2.add_(tmp) return x - self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + # See Note [Fix vmap slice_scatter] + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True) # Ensure functionalize works with List[Optional[Tensor]] arguments. # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 @@ -3241,6 +3444,7 @@ def f(x: torch.Tensor) -> torch.Tensor: self.assertEqual(out1, out2) self.assertEqual(inpt1, inpt2) + @unittest.skipIf(IS_FBCODE, 'fails in fbcode') def test_vmap_functionalize_jvp(self, device): def f(x: torch.Tensor) -> torch.Tensor: @@ -3429,6 +3633,7 @@ def forward(self, a_1, b_1) -> torch.Tensor: return index """) + @unittest.skipIf(IS_FBCODE, 'fails in fbcode') def test_functionalize_optional_tensorlist2(self, device): def f(a, b) -> torch.Tensor: @@ -3471,6 +3676,123 @@ def forward(self, x_1): """) +def construct_sum_pyop(): + mysum = PyOperator("mysum") + + @mysum.py_impl(torch._C._functorch.TransformType.Vmap) + def mysum_batch_rule(interpreter, x, dim): + if not torch._C._functorch.is_batchedtensor(x): + with interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test the dispatch + return mysum(x, dim) + + bdim = torch._C._functorch.maybe_get_bdim(x) + value = torch._C._functorch.get_unwrapped(x) + + with interpreter.lower(): + value = value.movedim(bdim, 0) + result = mysum(value, dim + 1) + + return torch._C._functorch._add_batch_dim(result, 0, interpreter.level()) + + @mysum.py_impl(torch._C._functorch.TransformType.Grad) + def mysum_grad_rule(interpreter, x, dim): + level = interpreter.level() + + class MySum(torch.autograd.function._SingleLevelFunction): + @staticmethod + def forward(ctx, x, dim): + ctx.x_shape = x.shape + ctx.dim = dim + x = torch._C._functorch._unwrap_for_grad(x, level) + with torch.enable_grad(), interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test the dispatch + y = mysum(x, dim) + + y = torch._C._functorch._wrap_for_grad(y, level) + return y + + @staticmethod + def backward(ctx, gy): + return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None + + with enable_autograd_function(): + return MySum.apply(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCPU) + def mysum_autograd_cpu(x, dim): + return torch.sum(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCUDA) + def mysum_autograd_cuda(x, dim): + return torch.sum(x, dim) + + return mysum + +sum_pyop = construct_sum_pyop() + +class TestPyOperatorInteraction(TestCase): + + def test_basic_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = sum_pyop(x, 1) + self.assertEqual(result, torch.sum(x, 1)) + + def test_vmap_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = vmap(sum_pyop, (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 1)) + + result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 2)) + + def test_grad_sum(self, device): + x = torch.randn(3, device=device) + gx = grad(sum_pyop)(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_grad_grad_sum(self, device): + x = torch.randn(3, requires_grad=True, device=device) + + def f(x): + # higher order grad. Requires a non-linearity + return sum_pyop(x.sin(), 0) + + def grad_f_sum(x): + return grad(f)(x).sum() + + ggx = grad(grad_f_sum)(x) + self.assertEqual(ggx, -x.sin()) + + def test_vmap_grad_sum(self, device): + x = torch.randn(2, 3, device=device) + gx = vmap(grad(sum_pyop), (0, None))(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_no_grad_outside_grad(self, device): + x = torch.randn(3, device=device, requires_grad=True) + with torch.no_grad(): + y = grad(sum_pyop)(x, 0) + self.assertEqual(y, torch.ones_like(x)) + self.assertFalse(y.requires_grad) + + def test_no_grad_inside_grad(self, device): + def f(x): + with torch.no_grad(): + shift = sum_pyop(x ** 2, 0) + return sum_pyop(x ** 2, 0) - shift + + x = torch.randn(3, device=device) + y = grad(f)(x) + self.assertEqual(y, 2 * x) + y = grad(lambda x: grad(f)(x).sum())(x) + self.assertEqual(y, torch.full_like(x, 2)) + + x = torch.randn(3, device=device, requires_grad=True) + y = grad(f)(x) + z, = torch.autograd.grad(y.sum(), x) + self.assertEqual(z, torch.full_like(x, 2)) + only_for = ("cpu", "cuda") instantiate_device_type_tests( @@ -3513,11 +3835,21 @@ def forward(self, x_1): globals(), only_for=only_for, ) +instantiate_device_type_tests( + TestPyOperatorInteraction, + globals(), + only_for=only_for, +) instantiate_device_type_tests( TestFunctionalize, globals(), only_for=only_for, ) +instantiate_device_type_tests( + TestAutogradFunction, + globals(), + only_for=only_for, +) instantiate_parametrized_tests( TestMakeFunctional, ) diff --git a/test/functorch/test_memory_efficient_fusion.py b/test/functorch/test_memory_efficient_fusion.py index b0f18f06b8295..e12da51004504 100644 --- a/test/functorch/test_memory_efficient_fusion.py +++ b/test/functorch/test_memory_efficient_fusion.py @@ -6,7 +6,7 @@ from functorch import make_fx from torch.nn import functional as F from functorch.compile import memory_efficient_fusion -from functorch._src.compile_utils import fx_graph_cse +from torch._functorch.compile_utils import fx_graph_cse from torch.testing._internal.common_utils import TestCase, run_tests import inspect import random diff --git a/test/functorch/test_minifier.py b/test/functorch/test_minifier.py index 49af42795592d..7ed13921d9077 100644 --- a/test/functorch/test_minifier.py +++ b/test/functorch/test_minifier.py @@ -2,7 +2,7 @@ import torch from functorch.compile import minifier -from functorch._src.compile_utils import get_placeholders, get_outputs +from torch._functorch.compile_utils import get_placeholders, get_outputs from functorch import make_fx from torch.testing._internal.common_utils import TestCase, run_tests diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 5dfe76b3e2877..75721a4e9f759 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -9,7 +9,8 @@ import itertools import unittest -from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_ARM64, parametrize, TEST_WITH_ASAN +from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \ + IS_ARM64, IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like, IS_WINDOWS import torch from torch import Tensor import functools @@ -35,12 +36,16 @@ is_valid_inplace_sample_input, loop2, ) +from torch.testing._internal.autograd_function_db import ( + autograd_function_db +) +from torch.autograd.function import _set_autograd_function_extension_enabled from torch.testing._internal.opinfo.core import SampleInput from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map from functorch import grad, vjp, vmap, jacrev, jacfwd import torch.autograd.forward_ad as fwAD -from functorch._src.eager_transforms import _as_tuple, jvp +from torch._functorch.eager_transforms import _as_tuple, jvp aten = torch.ops.aten @@ -290,6 +295,7 @@ def is_inplace(op, variant): vjp_fail = { xfail('tensor_split'), # data_ptr composite compliance + xfail('NumpyExpMarkDirtyAutogradFunction'), # https://github.com/pytorch/pytorch/issues/90225 } aliasing_ops = { @@ -337,12 +343,12 @@ def is_inplace(op, variant): @unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant") class TestOperators(TestCase): + @_set_autograd_function_extension_enabled() @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 - @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_grad', vjp_fail.union({ xfail('linalg.eig'), # diagonal_scatter does not support complex xfail('chalf', '', device_type='cpu'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' - skip('as_strided_scatter', ''), # silent incorrectness; seems flaky xfail('sparse.sampled_addmm', ''), # RuntimeError: Sparse CSR tensors do not have strides })) @opsToleranceOverride('TestOperators', 'test_grad', ( @@ -400,17 +406,37 @@ def wrapped_fn(*args, **kwargs): skip('nn.functional.max_unpool1d'), # fails everywhere except on mac skip('nn.functional.max_unpool2d'), # fails everywhere except on windows skip('nn.functional.max_unpool3d'), # fails everywhere except on mac - xfail("native_batch_norm"), + xfail("native_batch_norm"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents + xfail("_native_batch_norm_legit"), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents + + xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'), + + xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented - xfail('nn.functional.rrelu') # in-place test errors out with no formula implemented + # --- Non-Contiguous Failures! --- + # This is expected to fail as the operator + # expects last dim to have stride=1 + xfail('view_as_complex'), + # BUG + # AssertionError: Tensor-likes are not close! + xfail('as_strided'), + xfail('as_strided', 'partial_views'), + decorate('linalg.det', 'singular', + decorator=unittest.skipIf(IS_MACOS and IS_X86, "Fails on x86 MacOS CI")), })) @opsToleranceOverride('TestOperators', 'test_jvp', ( tol1('nn.functional.conv_transpose3d', {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type='cuda'), + tol1('linalg.tensorsolve', + {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'), tol1('nn.functional.binary_cross_entropy_with_logits', {torch.float32: tol(atol=4e-04, rtol=4e-04)}), tol1('nn.functional.batch_norm', {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1('nn.functional.conv2d', + {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1('pca_lowrank', + {torch.float32: tol(atol=5e-05, rtol=5e-05)}), )) def test_jvp(self, device, dtype, op): # TODO: get rid of vjp_decomp when we add decomposition support to @@ -434,28 +460,38 @@ def test_jvp(self, device, dtype, op): inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None for sample in samples: - args = (sample.input,) + sample.args - kwargs = sample.kwargs if outplace_variant: - self.jvp_opinfo_test(outplace_variant, args, kwargs, + self.jvp_opinfo_test(outplace_variant, sample, sample.output_process_fn_grad, clone_inputs=False, fixme_ref_jvp_local=fixme_ref_jvp_local) if is_valid_inplace_sample_input(sample, op, inplace_variant): - self.jvp_opinfo_test(inplace_variant, args, kwargs, + self.jvp_opinfo_test(inplace_variant, sample, sample.output_process_fn_grad, clone_inputs=True, fixme_ref_jvp_local=fixme_ref_jvp_local) - def jvp_opinfo_test(self, fn, args, kwargs, output_process_fn, + + def jvp_opinfo_test(self, fn, sample, output_process_fn, clone_inputs, fixme_ref_jvp_local): # NB: we used requires_grad=True to determine where the primals are, # but don't need that information otherwise - fn, primals = normalize_op_input_output2( + args = (sample.input,) + sample.args + kwargs = sample.kwargs + contig_fn, primals = normalize_op_input_output2( fn, args, kwargs, output_process_fn, requires_grad=True) orig_primals = tree_map(lambda x: x.detach(), primals) orig_tangents = tree_map(lambda x: torch.randn_like(x), primals) + noncontig_sample = sample.noncontiguous() + noncontig_args = (noncontig_sample.input,) + noncontig_sample.args + noncontig_kwargs = sample.kwargs + noncontig_fn, primals = normalize_op_input_output2( + fn, noncontig_args, noncontig_kwargs, + output_process_fn, requires_grad=True) + noncontig_primals = tree_map(lambda x: x.detach(), primals) + noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents) + def maybe_clone_inputs(): if clone_inputs: primals = tree_map(torch.clone, orig_primals) @@ -465,24 +501,56 @@ def maybe_clone_inputs(): primals, tangents = maybe_clone_inputs() expected_primal_outs, expected_tangent_outs = \ - fixme_ref_jvp_local(fn, primals, tangents) + fixme_ref_jvp_local(contig_fn, primals, tangents) primals, tangents = maybe_clone_inputs() - primal_outs, tangent_outs = jvp(fn, primals, tangents) + primal_outs, tangent_outs = jvp(contig_fn, primals, tangents) + + noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn, + noncontig_primals, + noncontig_tangents) self.assertEqual(primal_outs, expected_primal_outs) self.assertEqual(tangent_outs, expected_tangent_outs) - @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) + self.assertEqual(noncontig_primal_outs, expected_primal_outs) + self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) + + @_set_autograd_function_extension_enabled() + @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_vjp', vjp_fail.union({ - skip('as_strided_scatter', ''), # silent incorrectness; also might be flaky xfail('sparse.sampled_addmm', ''), + + # ---- Non-Contiguous Failures ---- + # This is expected to fail as the operator + # expects last dim to have stride=1 + xfail('view_as_complex'), + # RuntimeError: query: last dimension must be contiguous + # NOTE: This passes on Windows! + decorate('nn.functional._scaled_dot_product_attention', + decorator=unittest.skipIf(not IS_WINDOWS, "expects contiguous inputs")), + # BUG + # AssertionError: Tensor-likes are not close! + xfail('as_strided'), + xfail('as_strided_scatter'), + xfail('_softmax_backward_data', device_type='cpu'), + xfail('as_strided', 'partial_views'), })) @opsToleranceOverride('TestOperators', 'test_vjp', ( tol1('nn.functional.conv_transpose3d', {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'), tol1('nn.functional.binary_cross_entropy_with_logits', {torch.float32: tol(atol=1e-04, rtol=1e-04)}), + tol1('__rmatmul__', + {torch.float32: tol(atol=1e-05, rtol=1e-05)}), + tol1('matmul', + {torch.float32: tol(atol=1e-05, rtol=1e-05)}), + tol2('linalg.pinv', 'hermitian', + {torch.float32: tol(atol=1e-05, rtol=1e-05)}), + tol1('linalg.tensorsolve', + {torch.float32: tol(atol=1e-05, rtol=1e-05)}), + tol1('svd_lowrank', + {torch.float32: tol(atol=1e-04, rtol=1e-04)}), )) def test_vjp(self, device, dtype, op): if not op.supports_autograd: @@ -499,14 +567,22 @@ def _test(_op, inplace=False): result = fn(*primals) cotangents = tree_map(lambda x: torch.randn_like(x), result) + noncontig_fn, noncontig_primals = normalize_op_input_output(_op, sample.noncontiguous()) + noncontig_cotangents = tree_map(lambda x: noncontiguous_like(x), cotangents) + out, vjp_fn = vjp(fn, *primals) self.assertEqual(out, result) result_vjps = vjp_fn(cotangents) + out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals) + self.assertEqual(out_noncontig, result) + noncontig_result_vjps = vjp_fn(noncontig_cotangents) + _, vjp_fn = ref_vjp(fn, *primals) expected_vjps = vjp_fn(cotangents) self.assertEqual(result_vjps, expected_vjps) + self.assertEqual(noncontig_result_vjps, expected_vjps) _test(op) for a_op in op.aliases: @@ -516,13 +592,15 @@ def f(inp, *args, **kwargs): return op.inplace_variant(inp.clone(), *args, **kwargs) _test(f, inplace=True) - @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) + @_set_autograd_function_extension_enabled() + @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_vjpvjp', vjp_fail.union({ skip('nn.functional.max_unpool1d'), # silent incorrectness; Flaky skip('nn.functional.max_unpool2d'), # silent incorrectness; Flaky xfail('nn.functional.ctc_loss'), # Not Implemented xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other' xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), # AssertionError: Tensor-likes are not close! # Mismatched elements: 1 / 15 (6.7%) # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) @@ -587,6 +665,7 @@ def fn(inp, *args, **kwargs): skip("atleast_3d"), # Takes too long skip("ormqr"), # Takes too long xfail("as_strided"), # incorrect output + xfail("as_strided", "partial_views"), # incorrect output xfail("as_strided_scatter"), # incorrect output skip("bernoulli"), # calls random op xfail("bfloat16"), # rank 4 tensor for channels_last @@ -613,10 +692,11 @@ def fn(inp, *args, **kwargs): skip("nn.functional.dropout"), # calls random op skip("nn.functional.dropout2d"), # calls random op skip("nn.functional.dropout3d"), # calls random op + skip("nn.functional.alpha_dropout"), # calls random op skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. @@ -654,6 +734,7 @@ def fn(inp, *args, **kwargs): # view doesn't work on sparse xfail("to_sparse"), xfail("native_batch_norm"), + xfail("_native_batch_norm_legit"), })) @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @@ -665,6 +746,9 @@ def fn(inp, *args, **kwargs): tol1('svd', {torch.float32: tol(atol=1e-03, rtol=5e-04)}), )) + @skipOps('TestOperators', 'test_vmapvjpvjp', { + xfail('as_strided', 'partial_views'), + }) def test_vmapvjpvjp(self, device, dtype, op): # Since, we test `vjpvjp` independently, # for this test, we just verify that vmap @@ -720,6 +804,7 @@ def vjp_of_vjp(*args_and_cotangents): skip('nn.functional.dropout'), # randomness skip('nn.functional.dropout2d'), # randomness skip('nn.functional.dropout3d', ''), # randomness + skip('nn.functional.alpha_dropout'), # randomness skip('nn.functional._scaled_dot_product_attention'), # randomness xfail('as_strided'), # as_strided is too wild for us to support, wontfix xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset @@ -731,12 +816,14 @@ def vjp_of_vjp(*args_and_cotangents): xfail('svd_lowrank', ''), # randomness xfail('to_sparse', ''), # non-dense output skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format + xfail('as_strided', 'partial_views'), # ---------------------------------------------------------------------- # ---------------------------- BUGS ------------------------------------ # All of the following are bugs and need to be fixed skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule skip("native_batch_norm"), + skip("_native_batch_norm_legit"), xfail('__getitem__', ''), # dynamic error xfail('linalg.eig'), # Uses aten::allclose xfail('nanquantile', device_type='cpu'), # checks q via a .item() call @@ -779,7 +866,9 @@ def vjp_of_vjp(*args_and_cotangents): tol1('linalg.householder_product', {torch.float32: tol(atol=1e-04, rtol=1e-04)}), )) - @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail) + @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail.union({ + xfail('as_strided', 'partial_views'), + })) def test_vmapvjp(self, device, dtype, op): if not op.supports_autograd: self.skipTest("Skipped! Autograd not supported.") @@ -809,6 +898,7 @@ def test_vmapvjp(self, device, dtype, op): skip('nn.functional.dropout2d', ''), skip('nn.functional.dropout3d', ''), skip('nn.functional._scaled_dot_product_attention'), # randomness + skip('nn.functional.alpha_dropout'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), skip('nn.functional.feature_alpha_dropout', 'with_train'), xfail('nn.functional.fractional_max_pool2d'), # Cannot access data pointer of Tensor that doesn't have storage @@ -823,8 +913,10 @@ def test_vmapvjp(self, device, dtype, op): # ---------------------------- BUGS ------------------------------------ # The following are bugs that we should fix decorate('nn.functional.conv2d', decorator=unittest.skipIf(IS_ARM64, "Fails on M1")), + decorate('linalg.det', 'singular', decorator=unittest.skipIf(IS_MACOS, "Fails on x86 MacOS CI")), skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda xfail('masked.mean'), # silent incorrectness (nan difference) + xfail('as_strided', 'partial_views'), # Tensor-likes are not close! xfail('nn.functional.soft_margin_loss', ''), # soft_margin_loss_backward does not support forward-ad xfail('tensor_split'), # data_ptr composite compliance @@ -850,6 +942,7 @@ def test_vmapvjp(self, device, dtype, op): xfail('nn.functional.batch_norm'), xfail('nn.functional.batch_norm', 'without_cudnn'), xfail("native_batch_norm"), + xfail("_native_batch_norm_legit"), # ---------------------------------------------------------------------- } @@ -1052,6 +1145,8 @@ def test(): xfail('segment_reduce', 'lengths'), xfail('sparse.sampled_addmm', ''), xfail("native_batch_norm"), + xfail("_native_batch_norm_legit"), + xfail("native_dropout_backward"), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: @@ -1089,6 +1184,8 @@ def test(): skip('nn.functional.rrelu'), # randomness skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), + skip('nn.functional.alpha_dropout'), # randomness skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format skip('to_sparse', ''), # non-dense output skip('ormqr', ''), # takes too long @@ -1109,7 +1206,6 @@ def test(): xfail('svd_lowrank', ''), xfail('pca_lowrank', ''), xfail('clamp'), - xfail('cross'), # The defaults of this op are *very* weird. No wonder it doesn't work # something weird happening with channels_last xfail('bfloat16'), xfail('double'), @@ -1121,6 +1217,8 @@ def test(): xfail('as_strided_scatter', ''), xfail('sparse.sampled_addmm', ''), xfail("native_batch_norm"), + xfail("_native_batch_norm_legit"), + xfail('as_strided', 'partial_views'), })) def test_vjpvmap(self, device, dtype, op): # NB: there is no vjpvmap_has_batch_rule test because that is almost @@ -1199,6 +1297,7 @@ def get_vjp(cotangents, *primals): xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp xfail('nn.functional.embedding_bag', ''), # NYI: forward-AD for _embedding_bag xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d + xfail('grid_sampler_2d', ''), # NYI: forward AD for grid_sampler_2d xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer @@ -1210,13 +1309,15 @@ def get_vjp(cotangents, *primals): xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides - skip('as_strided_scatter', ''), # seems flaky xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce xfail('index_reduce', ''), # NYI: forward-AD for index_reduce xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce + xfail('native_dropout_backward'), # NYI + })) @opsToleranceOverride('TestOperators', 'test_jvpvjp', ( tol1('masked.prod', @@ -1300,6 +1401,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): # Potential bugs/errors xfail('as_strided'), # AssertionError: Tensor-likes are not close! + xfail('as_strided', 'partial_views'), # AssertionError: Tensor-likes are not close! xfail('as_strided_scatter'), # AssertionError: Tensor-likes are not close! xfail('bernoulli'), # calls random op xfail('bfloat16'), # required rank 4 tensor to use channels_last format @@ -1329,13 +1431,15 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.dropout2d'), # calls random op xfail('nn.functional.dropout3d'), # calls random op xfail('nn.functional.dropout'), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition + xfail('nn.functional.alpha_dropout'), # calls randomn op xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op xfail('nn.functional.fractional_max_pool2d'), # calls random op xfail('nn.functional.fractional_max_pool3d'), # calls random op xfail('nn.functional.gaussian_nll_loss'), # data depenedant flow xfail('nn.functional.grid_sample'), # Forward AD not implemented and no decomposition + xfail('grid_sampler_2d'), # Forward AD not implemented and no decomposition xfail('nn.functional.hardsigmoid'), # Forward AD not implemented and no decomposition xfail('nn.functional.hinge_embedding_loss'), # vmap: inplace into a regular tensor xfail('nn.functional.huber_loss'), # Forward AD not implemented and no decomposition @@ -1373,6 +1477,8 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): # input while the running_mean or running_var, which will be updated in # place, were not batched. xfail("native_batch_norm"), + xfail("_native_batch_norm_legit"), + xfail('native_dropout_backward',) })) @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @@ -1612,7 +1718,6 @@ def fn(input, weight, bias): skip('linalg.multi_dot', '', device_type='cpu'), skip('sparse.sampled_addmm', ''), skip('native_layer_norm', '', device_type='cpu'), - xfail('as_strided_scatter', ''), }) @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', ( tol1('linalg.householder_product', diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 2ee0bc8537604..b07928d128b6e 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -49,7 +49,7 @@ from functorch import vmap, grad, grad_and_value, jvp, vjp, jacfwd from functorch.experimental import chunk_vmap from torch._C._functorch import reshape_dim_into, reshape_dim_outof -from functorch._src.make_functional import functional_init_with_buffers +from torch._functorch.make_functional import functional_init_with_buffers FALLBACK_REGEX = 'There is a performance drop' @@ -3219,6 +3219,7 @@ def test(): xfail('nn.functional.rrelu'), # randomness xfail('nn.functional.dropout2d', ''), # randomness xfail('nn.functional.dropout3d', ''), # randomness + xfail('nn.functional.alpha_dropout', ''), # randomness xfail('nn.functional.feature_alpha_dropout', 'with_train'), # randomness xfail('as_strided'), # Our test runner can't handle this; manual test exists skip('new_empty_strided'), # empty tensor data is garbage so it's hard to make comparisons with it @@ -3229,15 +3230,15 @@ def test(): xfail('linspace', ''), # test runner can't handle factory functions xfail('arange', ''), # test runner can't handle factory functions xfail('logspace', ''), # test runner can't handle factory functions + xfail('scalar_tensor'), # test runner can't handle factory functions xfail('empty', ''), # test runner can't handle factory functions xfail('ones', ''), # test runner can't handle factory functions xfail('zeros', ''), # test runner can't handle factory functions + xfail('full', ''), # test runner can't handle factory functions xfail('eye', ''), # non-tensor input xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse - xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work - xfail('svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test - xfail('linalg.svd', device_type='cuda'), # not unique, see test_linalg_svd for manual test + skip('_softmax_backward_data'), skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format # ---------------------------------------------------------------------- @@ -3292,7 +3293,15 @@ def test(): )) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({ + # RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var, + # which will be updated in place, were not batched. xfail('native_batch_norm'), + xfail('_native_batch_norm_legit'), + xfail('tril'), # Exception not raised on error input + xfail('triu'), # Exception not raised on error input + # The error inputs are vectors, that pass when batched as they are treated as a matrix + xfail('trace'), + xfail('as_strided', 'partial_views'), })) def test_vmap_exhaustive(self, device, dtype, op): # needs to be fixed @@ -3308,10 +3317,14 @@ def test_vmap_exhaustive(self, device, dtype, op): )) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({ + xfail('as_strided', 'partial_views'), skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format xfail('complex'), xfail('copysign'), + # Batch norm got a batched tensor as input while the running_mean or running_var, + # which will be updated in place, were not batched. xfail('native_batch_norm'), + xfail('_native_batch_norm_legit'), xfail('histogram'), xfail('index_fill'), xfail('nansum'), @@ -3329,7 +3342,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('masked_scatter'), xfail('masked_select'), xfail('nanquantile'), - xfail('narrow_copy'), # hit the vmap fallback which is currently disabled xfail('ormqr'), xfail('put'), xfail('quantile'), @@ -3339,6 +3351,8 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('tensor_split'), xfail('to_sparse'), xfail('vdot'), + xfail('tril'), # Exception not raised on error input + xfail('triu'), # Exception not raised on error input xfail('__getitem__', ''), xfail('all'), xfail('any'), @@ -3349,6 +3363,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('resize_'), xfail('view_as_complex'), xfail('matrix_exp'), + xfail('trace'), # Does not support batched tensors xfail('bucketize'), xfail('fft.ihfft2'), xfail('fft.ihfftn'), @@ -3377,6 +3392,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('bernoulli', ''), xfail('linalg.lu_factor', ''), xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('native_dropout_backward'), xfail('nn.functional.kl_div', ''), xfail('multinomial', ''), xfail('column_stack', ''), @@ -3450,6 +3466,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('equal', ''), xfail('linalg.lu', ''), skip('linalg.ldl_solve', ''), + skip('_softmax_backward_data'), })) def test_op_has_batch_rule(self, device, dtype, op): # needs to be fixed @@ -3868,6 +3885,11 @@ def f(e_): skip('linalg.multi_dot'), # accepts list of tensor inputs, has its own special test xfail('linalg.vander'), xfail('linalg.vecdot'), + # throws in vmap on CUDA + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2) + # https://github.com/pytorch/pytorch/runs/8110653462?check_suite_focus=true + # but it passes locally + skip('linalg.matrix_norm', ''), skip('linalg.ldl_solve', ''), }) def test_vmap_linalg_failure_1D_input(self, device, dtype, op): @@ -3897,6 +3919,47 @@ def test_vmap_multi_dot_failure_1D_input(self): with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"): return vmap(torch.linalg.multi_dot)(inputs) + def test_vmap_escaped_error(self): + escaped = None + + def f(x): + nonlocal escaped + escaped = x + return x ** 2 + + x = torch.randn([3, 3, 3, 3, 3]) + vmap(f)(x) + + common_message = r"your tensor may have escaped from inside a function being vmapped.*{0}.*" + + # Note: These are not a complete set of tests for all possible functions calling 'vmap_check_escaped' + + with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_plumbing")): + escaped.sin() + + with self.assertRaisesRegex(RuntimeError, common_message.format("boxed_tensor_inputs_batch_rule")): + escaped.sin_() + + with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_inplace_plumbing")): + escaped.mul_(1) + + with self.assertRaisesRegex(RuntimeError, common_message.format("binary_cross_entropy_plumbing")): + torch.nn.functional.binary_cross_entropy(escaped, torch.zeros([3, 3, 3, 3])) + + with self.assertRaisesRegex(RuntimeError, common_message.format("boxed_existing_bdim_all_batch_rule")): + torch.nn.functional.adaptive_max_pool2d(escaped, output_size=(1, 1)) + + with self.assertRaisesRegex(RuntimeError, common_message.format("boxed_reduction_batch_rule")): + escaped.argmin() + + a = torch.zeros([4, 4, 4, 4]) + b = torch.zeros([4, 4, 4, 4], dtype=torch.long) + with self.assertRaisesRegex(RuntimeError, common_message.format("boxed_all_tensors_have_optional_bdim")): + torch.ops.aten.adaptive_max_pool2d_backward(escaped, a, b) + + vmap(f)(torch.tensor([[0, 0], [0, 0]], dtype=torch.int)) + with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_plumbing_no_returns")): + torch.ops.aten._linalg_check_errors(escaped, 'linalg.inv', is_matrix=False) class TestRandomness(TestCase): def _reset_random(self, generator, orig_state, use_generator, seed): diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py index 4ae552a44bd3c..cdf2cca13671c 100644 --- a/test/functorch/xfail_suggester.py +++ b/test/functorch/xfail_suggester.py @@ -69,7 +69,7 @@ def parse_namespace(base): 'linalg_': 'linalg', '_masked_': '_masked', 'sparse_': 'sparse', - 'speical_': 'special', + 'special_': 'special', } for heading in mappings.keys(): if base.startswith(heading): diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py index 9c59abce4da61..407e707db8797 100644 --- a/test/fx/test_common_passes.py +++ b/test/fx/test_common_passes.py @@ -73,10 +73,15 @@ def MutationMetadata(x): if torch.cuda.is_available(): Devices.append("cuda") + +def name_fn(common_pass, f, device): + """Names parameterized test cases.""" + return f'{type(common_pass()).__name__}_{f.__name__}_{device}' + @instantiate_parametrized_tests class TestCommonPass(TestCase): - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn) def test_correctness(self, common_pass, f, device): inp = torch.randn(10, device=device) @@ -94,7 +99,7 @@ def test_correctness(self, common_pass, f, device): self.assertEqual(result, expected) - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn) def test_correctness_factory(self, common_pass, f, device): inp = torch.randn(10, device=device) traced_m = make_fx(f)(inp, device) diff --git a/test/fx/test_fx_param_shape_control_flow.py b/test/fx/test_fx_param_shape_control_flow.py index e9af35d604577..04db468a7e631 100644 --- a/test/fx/test_fx_param_shape_control_flow.py +++ b/test/fx/test_fx_param_shape_control_flow.py @@ -91,26 +91,26 @@ def verify_mm_relu_mods(self, mm_only_mod, relu_mod): performs both mm and relu ops in cascade """ x = torch.randn(10, 5) - torch.testing.assert_allclose(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())) + torch.testing.assert_close(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())) tracer = torch.fx.Tracer(param_shapes_constant=True) traced_graph = tracer.trace(mm_only_mod) # verify the graph module calculates the same result graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph) - torch.testing.assert_allclose(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())) + torch.testing.assert_close(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())) # Make a new module with different parameter shape to go down the different # code path x = torch.randn(10, 15) - torch.testing.assert_allclose(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) + torch.testing.assert_close(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) tracer2 = torch.fx.Tracer(param_shapes_constant=True) traced_graph2 = tracer2.trace(relu_mod) # verify the graph module calculates the same result graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2) - torch.testing.assert_allclose(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) + torch.testing.assert_close(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))) graph1_node_targets = [n.target for n in traced_graph.nodes] diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py index 947c80d66dcee..7a7039979bebe 100644 --- a/test/fx/test_pass_infra.py +++ b/test/fx/test_pass_infra.py @@ -169,5 +169,5 @@ def pass_fail(graph_module): pm = PassManager(passes=[replace_add_with_mul_pass, replace_mul_with_div_pass, pass_fail]) # Comment out this line to see the actual error message - with self.assertRaises(RuntimeError): + with self.assertRaisesRegex(Exception, "pass_fail"): pm(traced_m) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index ac3498458d600..4568eaa33bd61 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c): self.assertEqual(repalcement_node_found, 2) - def test_replace_pattern_with_filter(self): + def test_replace_pattern_with_filters(self): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -833,10 +833,31 @@ def num_repalcement_node_found(traced): # match with filter, should find 1 match traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern_with_filter( + matches = subgraph_rewriter.replace_pattern_with_filters( traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, - second_input_is_scalar) + [second_input_is_scalar]) self.assertEqual(len(matches), 1) self.assertEqual(num_repalcement_node_found(traced), 1) + + def test_matching_pattern_with_list_type_arg(self): + class M(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._reshape_alias_copy.default(x, [1, 2], [3, 4]) + + def pattern(x, arg0, arg1): + return torch.ops.aten._reshape_alias_copy.default(x, arg0, arg1) + + def replacement(x, arg0, arg1): + return torch.ops.aten._reshape_alias_copy.default(x, arg1, arg0) + + traced = symbolic_trace(M()) + matches = subgraph_rewriter.replace_pattern(traced, pattern, replacement) + + self.assertEqual(len(matches), 1) + + self.assertExpectedInline(traced.code.strip(), """\ +def forward(self, x): + _reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None + return _reshape_alias_copy_default_1""") # noqa: B950 diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py new file mode 100644 index 0000000000000..18c5e5f33cade --- /dev/null +++ b/test/inductor/test_minifier.py @@ -0,0 +1,213 @@ +# Owner(s): ["module: inductor"] +import functools +import textwrap +import unittest + +import torch +import torch._dynamo +import torch._inductor.utils +from torch._dynamo.test_minifier_common import MinifierTestBase +from torch.testing._internal.common_utils import IS_MACOS + +_HAS_TRITON = torch._inductor.utils.has_triton() +requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") + +CPP_COMPILE_ERROR = """\ +def cpp_compile_error(x): + return "compile error!" +""" + +CPP_RUNTIME_ERROR = """\ +def cpp_runtime_error(x): + return f"{x}; throw 1" +""" + +CPP_ACCURACY_ERROR = """\ +def cpp_accuracy_error(x): + return f"{x} + decltype({x})(1)" +""" + +TRITON_COMPILE_ERROR = """\ +def triton_compile_error(x): + return "compile error!" +""" + +# NOTE: there is currently not an easy way to cause a triton runtime error. +TRITON_RUNTIME_ERROR = """\ +def triton_runtime_error(x): + return f"{x}; assert?" +""" + +TRITON_ACCURACY_ERROR = """\ +def triton_accuracy_error(x): + return f"{x} + 1" +""" + + +class MinifierTests(MinifierTestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Generates code that patches CppOverrides/TritonOverrides. + def _gen_codegen_fn_patch_code(self, old_fn_name, new_fn_code, device): + new_fn_name = self._get_fn_name(new_fn_code) + if new_fn_name is not None: + patch_code = f"""\ +import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen +overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} +vec_overrides = codegen.{"CppVecOverrides" if device == "cpu" else "TritonOverrides"} +{new_fn_code} +overrides.{old_fn_name} = staticmethod({new_fn_name}) +vec_overrides.{old_fn_name} = staticmethod({new_fn_name}) +""" + return f"""\ +{patch_code} +isolate_fails_code_str = \"\"\"\\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +\"\"\" +""" + + # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) + def _test_after_aot(self, device, backend_code, repro_level): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", repro_level, patch_code + ) + return ( + (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), + (test_proc.returncode, repro_proc.returncode), + ) + + def test_after_aot_cpu_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) + self.assertIn("CppCompileError", tb1) + self.assertIn("CppCompileError", tb2) + + def test_after_aot_cpu_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + @requires_cuda() + def test_after_aot_cuda_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) + self.assertIn("SyntaxError", tb1) + self.assertIn("SyntaxError", tb2) + + @requires_cuda() + def test_after_aot_cuda_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + # Test that runtime errors after aot can be repro'd (CPU only for now) + def _test_after_aot_runtime_error(self, device, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", 3, patch_code + ) + + self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) + + self.assertEqual(test_proc.returncode, repro_proc.returncode) + self.assertNotEqual(test_proc.returncode, 0) + + def test_after_aot_cpu_runtime_error(self): + self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_error(self): + self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) + + # Ensure that inductor codegen patches pass when relu is not present. + def _test_after_aot_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_aot_cpu_compile_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) + + def test_after_aot_cpu_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) + + def test_after_aot_cpu_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) + + @requires_cuda() + def test_after_aot_cuda_compile_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) + + @requires_cuda() + def test_after_aot_cuda_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + # skip CI tests on mac since CPU inductor does not seem to work due to C++ compile errors + if not IS_MACOS: + run_tests() diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py new file mode 100644 index 0000000000000..d473ff4b74495 --- /dev/null +++ b/test/inductor/test_perf.py @@ -0,0 +1,434 @@ +# Owner(s): ["module: inductor"] +import contextlib +from unittest.mock import patch + +import torch._dynamo +import torch._inductor.config as config +from torch._dynamo.optimizations.backends import register_backend +from torch._inductor import metrics +from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch.testing._internal.common_utils import ( + TEST_WITH_ROCM, + TestCase as TorchTestCase, +) +from torch.testing._internal.inductor_utils import HAS_CUDA + +aten = torch.ops.aten + + +@register_backend +def count_bytes_inductor(gm, example_inputs): + return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) + + +@torch._dynamo.optimize("count_bytes_inductor") +def f(x): + return torch.cat([x, x.cos()]) + + +def count_numel(f, *args): + """ + Assumes all inputs are fp32 + """ + metrics.reset() + torch._dynamo.optimize("count_bytes_inductor")(f)(*args) + print(metrics.nodes_num_elem) + return str(metrics.num_bytes_accessed // 4) + + +DEVICE = "cuda" + + +def T(*size, dtype=torch.float32, device=DEVICE): + return torch.randn(size, dtype=dtype, device=device) + + +def TI(*size, mx=10, dtype=torch.int32, device=DEVICE): + return torch.randint(0, mx, size, dtype=dtype, device=device) + + +class TestCase(TorchTestCase): + device = DEVICE + pass + + +class NumBytesMetricTests(TestCase): + """ + Primarily used for sanity testing that the num_bytes_accessed metrics is correct. + """ + + def test_pointwise(self): + def f(x): + return x.cos() + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """20""") + + def f(x, y): + return x + y + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + def f(x, y): + return x + y + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def f(x): + return x + x + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """20""") + + def f(x): + return x + x.t() + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def f(a, b, c): + return a.cos(), b.sin() + c.sin() + + inp = (T(10), T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """50""") + + def test_reduction(self): + def f(x): + return x.sum(dim=1) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def f(x): + return x.sum(dim=0) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def test_extern(self): + def f(x): + return torch.mm(x, x) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def f(a, b): + return torch.mm(a, b) + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def f(x): + x = x.cos() + x = torch.mm(x, x) + x = x.cos() + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """600""") + + def f(x): + a = x.cos() + b = x.sin() + x = torch.mm(a, b) + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """600""") + + def test_cat(self): + def f(a, b): + return torch.cat([a.sin(), b.sin()]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a, b): + return torch.cat([a, b]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a, b): + return torch.cat([a.cos(), b]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a): + return torch.cat([a.cos(), a.sin()]) + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + def test_index(self): + def f(a, b): + return a[b] + + inp = (T(10), TI(10, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + +class FusionTests(TestCase): + """ + Tests that things can be fused into a single kernel + """ + + def test_horizontal_reduction_pointwise(self): + def f(a): + b = a.sum(dim=1) + c = a.cos() + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def test_horizontal_reduction_reduction(self): + def f(a): + b = a.sum(dim=1) + c = a.amax(dim=1) + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_reduction_pointwise2(self): + def f(a, b): + c = a.sum(dim=1) + b = b.cos() + return b + c + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_reduction_outer_pointwise(self): + def f(a, b): + c = a.sum(dim=0) + b = b.cos() + return b + c + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_sum_pw_broadcast(self): + def f(a, b): + a = a.sum(dim=1, keepdim=True) + b = b.cos() + return a * b + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def test_vertical_sum_pw(self): + def f(a): + a = a.cos() + a = a.sum(dim=1) + return a.cos() + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def test_norm_chain(self): + def f(a): + b = a.sum(dim=1, keepdim=True) + a = a * b + b = a.sum(dim=1, keepdim=True) + a = a * b + b = a.sum(dim=1, keepdim=True) + a = a * b + return a + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_softmax_inner(self): + def f(a): + return torch.softmax(a, dim=1) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_layer_norm(self): + # TODO: Suboptimal! We shouldn't need to save normalization stats. + mod = torch.nn.LayerNorm(10, device=self.device) + + def f(x): + return mod(x) + + inp = (T(10, 10),) + with torch.no_grad(): + self.assertExpectedInline(count_numel(f, *inp), """220""") + + def test_double_softmax(self): + def f(x): + x = torch.softmax(x, dim=1) + x = torch.softmax(x, dim=1) + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_softmax_backward(self): + def f(grad_out, out): + return aten._softmax_backward_data(grad_out, out, 1, torch.float32) + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def test_neighbor(self): + def f(a, b): + return ((a - b) ** 2).sum(dim=-1).amax(dim=1) + + inp = (T(10, 1, 4), T(1, 10, 4)) + self.assertExpectedInline(count_numel(f, *inp), """90""") + + def test_factory_reduction(self): + def f(): + a = torch.ones(10, device=self.device) + b = torch.ones(10, 10, device=self.device) + return (a + b).sum(dim=-1) + + inp = () + self.assertExpectedInline(count_numel(f, *inp), """10""") + + def test_index_pointwise(self): + def f(a, b): + return a[b].cos() + + inp = (T(10, 10), TI(20, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """320""") + + def test_index_reduction(self): + def f(a, b): + return a[b].cos().sum(dim=1) + + inp = (T(10, 10), TI(20, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """140""") + + +class SchedulerFusionTests(TestCase): + """ + Testing the fusion group creation heuristic (i.e. cases where we can't fuse + everything into a single kernel) + Disables inductor rematerialization for easier reasoning of tests. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0)) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + super().tearDownClass() + + def test_fusion_choice1(self): + # Doesn't matter where we break fusion group here + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c.cos() + return d + e + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """700""") + + def test_fusion_choice2(self): + # We should materialize e (it's smaller!) + # [c, e]: 210, [f]: 210, [d]: 200 + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c.sum(dim=1) + f = d + e + return f + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """620""") + + def test_fusion_choice3(self): + # We should materialize e. + # [c, e]: 300, [f]: 300, [d]: 200 + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c + a + f = d + e + return f, e + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """800""") + + +class TilingTests(TestCase): + def test_tiling_simple(self): + def f(a, b): + return a + b.t() + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def f(a, b): + return a.t() + b + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def test_tiling_three(self): + def f(a, b, c): + return a + b.permute(1, 2, 0) + c.permute(2, 0, 1) + + inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """4000""") + + +# Test cases where we don't do the right thing yet. +class WouldBeNiceIfItWorked: + def test_horizontal(self): + def f(a): + b = a.sum(dim=0) + c = a.cos() + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + # TODO: We aren't fusing outer dim softmaxes + def test_softmax_outer(self): + def f(a): + return torch.softmax(a, dim=0) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + # TODO: The greedy fusion strategy results in suboptimal grouping + @patch.object(config, "realize_bytes_threshold", 0) + def test_fusion_choice4(self): + def f(a, b, b2): + c = a + b + d = torch.mm(c, c) + e = c + b + b2 + f = d + e + b2 + return f, e + + inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """1000""") + + # TODO: We materialize the intermediate if we don't unroll the reduction + def test_neighbor(self): + def f(a, b): + return ((a - b) ** 2).sum(dim=-1).amax(dim=1) + + inp = (T(10, 1, 8), T(1, 10, 8)) + self.assertExpectedInline(count_numel(f, *inp), """170""") + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if HAS_CUDA and not TEST_WITH_ROCM: + run_tests(needs="filelock") diff --git a/test/inductor/test_smoke.py b/test/inductor/test_smoke.py new file mode 100644 index 0000000000000..89079723bc224 --- /dev/null +++ b/test/inductor/test_smoke.py @@ -0,0 +1,62 @@ +# Owner(s): ["module: inductor"] +import logging + +import torch +import torch._dynamo as torchdynamo +import torch._inductor.config as torchinductor_config +from torch.testing._internal.common_utils import IS_LINUX, TestCase + + +class MLP(torch.nn.Module): + def __init__(self): + super(MLP, self).__init__() + self.l1 = torch.nn.Linear(1, 6) + self.l2 = torch.nn.Linear(6, 1) + + def forward(self, x=None): + x = torch.relu(self.l1(x)) + x = torch.relu(self.l2(x)) + return x + + +def _test_f(x): + return x * x + + +class SmokeTest(TestCase): + def test_mlp(self): + torchdynamo.config.log_level = logging.INFO + torchdynamo.config.verbose = True + torchinductor_config.debug = True + + mlp = torch.compile(MLP().cuda()) + for _ in range(3): + mlp(torch.randn(1, device="cuda")) + + torchdynamo.config.verbose = False + torchinductor_config.debug = False + + def test_compile_decorator(self): + @torch.compile + def foo(x): + return torch.sin(x) + x.min() + + @torch.compile(mode="reduce-overhead") + def bar(x): + return x * x + + for _ in range(3): + foo(torch.full((3, 4), 0.7, device="cuda")) + bar(torch.rand((2, 2), device="cuda")) + + def test_compile_invalid_options(self): + with self.assertRaises(RuntimeError): + opt_f = torch.compile(_test_f, mode="ha") + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if IS_LINUX and torch.cuda.is_available(): + if torch.cuda.get_device_properties(0).major > 5: + run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index df5a7fb0a21de..5ea874b0fdb80 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3,20 +3,27 @@ import dataclasses import functools import importlib +import itertools import os import random import sys +import typing import unittest import weakref +from typing import Any, Callable from unittest.mock import patch +import numpy as np + import torch import torch._dynamo from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, TEST_WITH_ROCM, @@ -34,11 +41,21 @@ import torch._inductor.config from functorch.compile import config as functorch_config from torch._decomp import get_decompositions - from torch._inductor import config + from torch._inductor import codecache, config, metrics + from torch._inductor.codegen.cpp import cexpr, CppOverrides, CppVecOverrides + from torch._inductor.codegen.triton import texpr from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing + from torch._inductor.overrides import ( + linear_permute_fusion, + linear_transpose, + permute_linear_fusion, + permute_matmul_fusion, + transpose_linear, + transpose_matmul, + ) from torch._inductor.sizevars import SizeVarAllocator - from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed + from torch._inductor.utils import has_torchvision_roi_align, timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -48,33 +65,55 @@ sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - - -HAS_CPU = False -try: - from subprocess import CalledProcessError - - from torch._inductor.codecache import CppCodeCache + raise unittest.SkipTest("requires sympy/functorch/filelock") from e - CppCodeCache.load("") - HAS_CPU = True -except ( - CalledProcessError, - OSError, - torch._inductor.exc.InvalidCxxCompiler, - torch._inductor.exc.CppCompileError, -): - pass +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten - -HAS_CUDA = has_triton() requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow +# For OneDNN bf16 path, OneDNN requires the cpu has intel avx512 with avx512bw, +# avx512vl, and avx512dq at least. So we will skip the test case if one processor +# is not meet the requirement. +@functools.lru_cache(maxsize=None) +def has_bf16_support(): + import sys + + if sys.platform != "linux": + return False + with open("/proc/cpuinfo", encoding="ascii") as f: + lines = f.read() + return all(word in lines for word in ["avx512bw", "avx512vl", "avx512dq"]) + + +unary_list = [ + torch.nn.ReLU(), + torch.nn.Sigmoid(), + torch.nn.Tanh(), + torch.nn.Hardswish(), + torch.nn.LeakyReLU(0.1, inplace=False), + torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), + torch.nn.GELU(approximate="none"), + torch.nn.GELU(approximate="tanh"), + torch.nn.ReLU6(), + torch.nn.SiLU(), +] + + +binary_list = [ + lambda x, y: torch.add(x, y), # call_function + lambda x, y: torch.add(y, x), # call_function + lambda x, y: x.add(y), # call_method + lambda x, y: x.add_(y), # call_method + lambda x, y: torch.sub(x, y), # call_function + lambda x, y: x.sub(y), # call_method + lambda x, y: x.sub_(y), # call_method +] + + def requires_decomp(fn): """Decorator to disable test if a decomp is missing""" @@ -90,6 +129,29 @@ def maybe_test(*args, **kwargs): return wrap_test +PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] + + +def chain_passes(*passes: PassFunc) -> PassFunc: + def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module) + return module + + return parent_pass + + +def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: + return sum( + [ + 1 if (n.op == "call_function" and n.target == target_op) else 0 + for n in module.graph.nodes + ] + ) + + class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -166,8 +228,15 @@ def gather_leaf_tensors(args, kwargs): ) +def clone_preserve_strides(x): + if not isinstance(x, torch.Tensor): + return x + buffer = torch.as_strided(x, (x.storage().size(),), (1,), 0).clone() + out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + return out + + @patch.object(torch._inductor.config.triton, "cudagraphs", False) -@patch("torch._dynamo.config.raise_on_backend_error", True) def check_model( self: TestCase, model, @@ -187,9 +256,10 @@ def check_model( kwargs = kwargs or {} torch._dynamo.reset() - ref_inputs = example_inputs + ref_inputs = [clone_preserve_strides(x) for x in example_inputs] ref_kwargs = kwargs has_lowp_args = False + original_lowp_dtype = torch.half if reference_in_float: # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg @@ -203,9 +273,15 @@ def upcast_fn(x): else: return x + def get_original_lowp_dtype(example_inputs): + dtypes = [x.dtype for x in example_inputs if isinstance(x, torch.Tensor)] + dtype_set = set(dtypes) + return dtype_set.pop() if len(dtype_set) == 1 else torch.half + ref_inputs = list(map(upcast_fn, example_inputs)) ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()} if has_lowp_args: + original_lowp_dtype = get_original_lowp_dtype(example_inputs) if hasattr(model, "to"): model = model.to(torch.float) @@ -215,7 +291,7 @@ def upcast_fn(x): # downcast the model back if needed if reference_in_float and has_lowp_args: if hasattr(model, "to"): - model = model.to(torch.half) + model = model.to(original_lowp_dtype) torch._inductor.metrics.reset() @@ -261,6 +337,16 @@ def run(*ex, **kwargs): equal_nan=True, exact_dtype=exact_dtype, ) + # In case of input mutations, check that inputs are the same + self.assertEqual( + ref_inputs, + example_inputs, + atol=atol, + rtol=rtol, + equal_nan=True, + # our testing sometimes uses higher precision inputs for the reference + exact_dtype=False, + ) else: for correct_val, actual_val in zip(correct_flat, actual_flat): if isinstance(correct_val, torch.Tensor): @@ -410,13 +496,6 @@ def populate(cls): cls.gen_template(name1, name2) -class SweepInputsCpuTest(SweepInputs2, TestCase): - gen = InputGen(10, "cpu") - - -SweepInputsCpuTest.populate() - - class TestIndexingSimplification(TorchTestCase): def test_indexing_simplification(self): sizevars = SizeVarAllocator() @@ -442,7 +521,12 @@ def test_indexing_simplification(self): self.assertEqual( sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3 ) - + # if there are negative terms in ModularIndexing base, we cannot replace it with IndexingDiv + expr = ModularIndexing(i1 - 15, 1, 64) + self.assertEqual( + sizevars.simplify_with_ranges(expr, var_ranges), + ModularIndexing(i1 - 15, 1, 64), + ) # small terms should be kept if the rest is not guaranteed to be divisible self.assertEqual( sizevars.simplify_with_ranges(IndexingDiv(r3 + i2 + i1, 32), var_ranges), @@ -472,6 +556,14 @@ def test_indexing_simplification(self): ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3) ) + # if there are negative terms, we cannot optimize away zero terms due to https://github.com/openai/triton/issues/619 + self.assertEqual( + ModularIndexing(-i0 + i1 * 20, 2, 10), ModularIndexing(-i0 + i1 * 20, 2, 10) + ) + self.assertEqual( + ModularIndexing(-15 + i1 * 20, 2, 10), ModularIndexing(-15 + i1 * 20, 2, 10) + ) + # Constant fold from divisor into base self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10)) self.assertEqual(IndexingDiv(i0 * 4, 2), i0 * 2) @@ -612,11 +704,22 @@ def fn(a): self.common(fn, [torch.linspace(-10, 10, 41)]) + def test_sgn_extremal(self): + def fn(a): + return (torch.sgn(a),) + + self.common(fn, [torch.tensor([np.nan, np.inf, -np.inf, 0])]) + def test_max_min(self): def fn(a, b): return (torch.maximum(a, b), torch.minimum(a, b)) self.common(fn, (torch.randn(8), torch.randn(8))) + t1 = torch.randn(8) + t1[0] = float("nan") + t2 = torch.randn(8) + t2[1] = float("nan") + self.common(fn, (t1, t2)) def test_horizonal_fusion1(self): def fn(a, b, c): @@ -655,6 +758,44 @@ def fn(sa, ct, p): ) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + def test_forced_buffer_realize(self): + # Test torch._test_inductor_realize forces a buffer to be realized + def fn(a): + b = torch._test_inductor_realize(a * 2) + return (b * 2,) + + self.common(fn, (torch.randn(10),)) + self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 2) + + def test_scheduler_vertical_fusion1(self): + realize = torch._test_inductor_realize + + def fn(sa, ct, p): + # From torchbench.pyhpc_equation_of_state + v17 = -3.087032500374211e-7 + v18 = -1.988366587925593e-8 + v19 = -1.061519070296458e-11 + v20 = 1.550932729220080e-10 + t15 = realize(v19 * ct) + t19 = realize(v17 + ct * (v18 + t15) + v20 * sa) + t20 = realize(1.0 / t19) + t128 = realize(t19 * p) + return t20 + t128 + + self.common( + fn, + ( + torch.randn(204, 204, 26), + torch.randn(204, 204, 26), + torch.randn(26), + ), + ) + self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5) + self.assertEqual( + torch._inductor.metrics.generated_kernel_count, + 1 if self.device == "cuda" else 3, + ) + def test_sum1(self): def fn(a, b): return ((a + b).sum(-1),) @@ -721,6 +862,17 @@ def fn(a): self.common(fn, (torch.full((4,), float("-inf")),)) + def test_reduction4(self): + if self.device == "cpu": + raise unittest.SkipTest("Non-deterministic CPU results") + + def fn(a): + return (a.argmax(-1), a.argmin(-1)) + + inputs = (torch.ones(128), torch.ones(4, 4, 1)) + for i in inputs: + self.common(fn, (i,)) + @patch.object(config, "dynamic_shapes", False) def test_unroll_small_reduction(self): def fn(x): @@ -759,6 +911,11 @@ def fn(a): self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),))) def test_expanded_reduction(self): + if self.device == "cpu": + raise unittest.SkipTest( + "https://github.com/pytorch/torchdynamo/issues/1697" + ) + def fn(x, y): z = x * y return z.sum((0, 1)) @@ -1086,6 +1243,45 @@ def fn(a, b): self.common(fn, (1024, 100)) + def test_div_zero_dim(self): + def fn(a, b): + return ( + aten.div(a, b, rounding_mode=None), + aten.div(a, b, rounding_mode="floor"), + aten.div(a, b, rounding_mode="trunc"), + a / b, + a // b, + ) + + for dtype in (torch.float32, torch.int64): + self.common( + fn, + ( + make_tensor(10, device="cpu", dtype=dtype), + make_tensor((), device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + self.common( + fn, + ( + make_tensor((), device="cpu", dtype=dtype), + make_tensor(10, device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + + def test_div_prim(self): + def fn(a, b): + return (torch.ops.prims.div(a, b),) + + for dtype in (torch.float32, torch.int64): + self.common( + fn, + ( + make_tensor(100, device="cpu", dtype=dtype), + make_tensor(100, device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + def test_both_scalars(self): def fn(a, b): return ( @@ -1279,6 +1475,361 @@ def fn(a, b): check_lowp=False, ) + # For gpu path, there has a accurcy issue, + @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test") + def test_conv_bn_fuse(self): + input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)} + conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} + bn_modules = { + 1: torch.nn.BatchNorm1d, + 2: torch.nn.BatchNorm2d, + 3: torch.nn.BatchNorm3d, + } + options = itertools.product( + [1, 2, 3], + [True, False], + [1, 3], + [1, 2], + [1, 4], + ) + + for ( + dim, + bias, + kernel_size, + dilation, + groups, + ) in options: + oC = 32 * groups + iC = 3 * groups + x_shape = (1, iC) + input_shapes[dim] + mod = torch.nn.Sequential( + conv_modules[dim]( + iC, + oC, + kernel_size=kernel_size, + dilation=dilation, + groups=groups, + bias=bias, + ), + bn_modules[dim](oC), + ).eval() + test_memory_format = [torch.contiguous_format] + # TODO: GPU path doesn't support channels_last now. + if not HAS_CUDA and dim > 1: + channels_last = ( + torch.channels_last if dim == 2 else torch.channels_last_3d + ) + test_memory_format.append(channels_last) + for memory_format in test_memory_format: + v = torch.randn(x_shape, dtype=torch.float32).to( + memory_format=memory_format + ) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + # For gpu path, there has a accurcy issue, + @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test") + def test_conv_functional_bn_fuse(self): + # Define a BatchNorm using functional BN. + class BatchNorm(torch.nn.BatchNorm2d): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super(BatchNorm, self).__init__( + num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats, + **factory_kwargs, + ) + + def forward(self, x): + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked + ) + else: # use exponential moving average + exponential_average_factor = self.momentum + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and ( + self.running_var is None + ) + x = F.batch_norm( + x, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean + if not self.training or self.track_running_stats + else None, + self.running_var + if not self.training or self.track_running_stats + else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) + return x + + v = torch.randn(1, 3, 556, 56, dtype=torch.float32) + mod = torch.nn.Sequential( + torch.nn.Conv2d( + 3, + 64, + kernel_size=3, + dilation=1, + groups=1, + bias=True, + ), + BatchNorm(64), + ).eval() + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test") + def test_conv2d_packed(self): + x_shape = (1, 3, 56, 56) + mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval() + v = torch.randn(x_shape, dtype=torch.float32) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + # For gpu path, there has a accurcy issue, + # see https://github.com/pytorch/pytorch/issues/87745. + @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test") + def test_conv2d_unary(self): + test_memory_format = [torch.contiguous_format, torch.channels_last] + options = itertools.product( + unary_list, + [True, False], + [1, 3], + [1, 2], + [1, 4], + ["same", 0], + test_memory_format, + ) + + for ( + unary_fn, + bias, + kernel_size, + dilation, + groups, + padding, + memory_format, + ) in options: + oC = 32 * groups + iC = 3 * groups + x_shape = (1, iC, 112, 112) + mod = torch.nn.Sequential( + torch.nn.Conv2d( + iC, + oC, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ), + unary_fn, + ).eval() + + # TODO: add bf16 test for cpu path? + # TODO: this test fails when requires_grad=False + v = ( + torch.randn(x_shape, dtype=torch.float32, requires_grad=True) + .add(1) + .to(memory_format=memory_format) + ) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + # For gpu path, there has a accurcy issue, + # see https://github.com/pytorch/pytorch/issues/87745. + @unittest.skipIf(HAS_CUDA, "only support cpu conv2d binary test") + def test_conv2d_binary(self): + class M(torch.nn.Module): + def __init__( + self, + binary_fn, + in_channels, + out_channels, + dilation, + groups, + padding, + bias, + has_relu, + **kwargs, + ): + super(M, self).__init__() + self.conv1 = torch.nn.Conv2d( + in_channels, + out_channels, + dilation=dilation, + groups=groups, + padding=padding, + bias=bias, + **kwargs, + ) + self.conv2 = torch.nn.Sequential( + torch.nn.Conv2d( + in_channels, + out_channels, + dilation=dilation, + groups=groups, + padding=padding, + bias=bias, + **kwargs, + ) + ) + self.binary_fn = binary_fn + self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity() + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + return self.relu(self.binary_fn(x1, x2)) + + test_memory_format = [torch.contiguous_format, torch.channels_last] + options = itertools.product( + binary_list, + [True, False], + [True, False], + [1, 3], + [1, 2], + [1, 4], + ["same", 0], + test_memory_format, + ) + + for ( + binary_fn, + has_relu, + bias, + kernel_size, + dilation, + groups, + padding, + memory_format, + ) in options: + oC = 32 * groups + iC = 3 * groups + x_shape = (1, iC, 112, 112) + mod = M( + binary_fn, + iC, + oC, + dilation, + groups, + padding, + bias, + has_relu, + kernel_size=kernel_size, + ).eval() + mod = mod.to(memory_format=memory_format) + # TODO: add bf16 test + v = torch.randn(x_shape, dtype=torch.float32).to( + memory_format=memory_format + ) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + def test_linear_packed(self): + options = itertools.product([[2, 3, 10], [2, 10]], [True, False]) + for input_shape, bias in options: + mod = torch.nn.Sequential( + torch.nn.Linear(input_shape[-1], 30, bias=bias) + ).eval() + + v = torch.randn(input_shape) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + def test_linear_unary(self): + options = itertools.product(unary_list, [[2, 3, 10], [2, 10]], [True, False]) + dtype = torch.bfloat16 + if has_bf16_support(): + for eltwise_fn, input_shape, bias in options: + mod = torch.nn.Sequential( + torch.nn.Linear(input_shape[-1], 30, bias=bias), eltwise_fn + ).eval() + + # only fuse for linear when the dtype is bf16 + mod = mod.to(dtype) + v = torch.randn(input_shape).to(dtype) + with torch.no_grad(): + self.common( + mod, + (v,), + ) + + def test_linear_binary(self): + class M(torch.nn.Module): + def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs): + super(M, self).__init__() + self.linear = torch.nn.Linear( + in_channels, out_channels, bias=bias, **kwargs + ) + self.eltwise = eltwise_fn + + def forward(self, x, y): + x = self.linear(x) + x = self.eltwise(x, y) + return x + + options = itertools.product(binary_list, [[2, 3, 10], [2, 10]], [True, False]) + dtype = torch.bfloat16 + out_feature = 30 + if has_bf16_support(): + for binary_ops, input_shape, bias in options: + mod = M(binary_ops, input_shape[-1], out_feature, bias).eval() + + # only fuse for linear when the dtype is bf16 + mod = mod.to(dtype) + v = torch.randn(input_shape).to(dtype) + other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype) + with torch.no_grad(): + self.common(mod, (v, other), atol=2e-3, rtol=0.016) + def test_gather1(self): def fn(a, b): return ( @@ -1470,6 +2021,7 @@ def fn(x, w, b): check_lowp=False, ) + @unittest.skipIf(HAS_CUDA, "only support cpu channels_last") def test_conv2d_channels_last(self): m = torch.nn.Sequential( torch.nn.Conv2d(3, 3, 1, 1), @@ -1553,6 +2105,7 @@ def fn(x): self.common( fn, (torch.randn(2, 4, 16, 16),), + check_lowp=False, ) # lowering to avg_pool2d case @@ -1567,6 +2120,19 @@ def fn(x): (torch.randn(2, 4, 6, 6),), ) + def test_adaptive_avg_pool2d2(self): + # Big kernel size, use fallback + def fn(x): + return aten._adaptive_avg_pool2d(x, (4, 4)) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (torch.randn(2, 4, 21, 21),), + check_lowp=False, + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_max_pool2d1(self): def fn(x): return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) @@ -1614,6 +2180,18 @@ def fn(x): (torch.randn([16, 64, 55, 55]),), ) + def test_max_pool2d6(self): + # Too big kernel size, use fallback + def fn(x): + return aten.max_pool2d_with_indices(x, [13, 13], []) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (torch.randn([16, 64, 55, 55]),), + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_avg_pool2d1(self): def fn(x): return aten.avg_pool2d(x, [3, 3], [2, 2]) @@ -1668,6 +2246,18 @@ def fn(x): (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), ) + def test_avg_pool2d7(self): + # Large kernel size, use fallback + def fn(x): + return aten.avg_pool2d(x, [13, 13], [1, 1], [0, 0]) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),), + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_alexnet_prefix(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( @@ -1926,6 +2516,92 @@ def test_layer_norm(self): if self.device != "cpu": self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + def test_transpose_add(self): + def fn(a, b): + return a.t() + b + + self.common( + fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False + ) + if self.device != "cpu": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_softmax_one_kernel(self): + def fn(x): + dim = 1 + x_max = torch.amax(x, dim, keepdim=True) + unnormalized = torch.exp(x * x_max) + result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) + return result + + self.common(fn, (torch.randn([16, 32]),), check_lowp=False) + if self.device != "cpu": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_cauchy(self): + def fn(x, y): + return torch.sum(1 / (torch.unsqueeze(x, -1) - y)) + + self.common( + fn, + ( + torch.randn(32), + torch.randn(32), + ), + # Absolute difference: 0.0003662109375 (up to 0.0001 allowed) + # Relative difference: 1.8804297408767818e-05 (up to 1e-05 allowed) + atol=5 * 1e-4, + rtol=5 * 1e-5, + check_lowp=False, + ) + if self.device != "cpu": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_gather_scatter(self): + def fn(node_feat, edge_index): + src_node_feat = node_feat[edge_index[0]] + dst_node_feat = node_feat[edge_index[1]] + edge_feat = src_node_feat - dst_node_feat + 1 + new_node_feat = torch.zeros_like(node_feat) + new_node_feat.scatter_add_( + 0, edge_index[1].unsqueeze(-1).expand_as(edge_feat), edge_feat + ) + return new_node_feat + + num_nodes = 16 + num_features = 32 + node_feat = torch.randn(num_nodes, num_features) + edge_index = torch.randint(0, num_nodes, size=(2, num_nodes * 5)) + self.common( + fn, + ( + node_feat, + edge_index, + ), + check_lowp=False, + ) + if self.device != "cpu": + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + + @patch.object(torch._inductor.config, "max_fusion_size", 1) + def test_no_mega_fusion_during_lowering(self): + n = 50 + + def fn(*args): + x = args[0] + for i in range(n): + x = torch.add(x, args[i]) + return x + + self.common( + fn, + [torch.randn(64) for _ in range(n)], + check_lowp=False, + ) + print("-->", torch._inductor.metrics.generated_kernel_count) + if self.device != "cpu": + self.assertTrue(torch._inductor.metrics.generated_kernel_count > 1) + def test_move_arange(self): def fn(x): return torch.arange(len(x), device="cpu").to(x.device) + x @@ -2038,6 +2714,17 @@ def fn(x): rtol=3e-05, ) + def test_pow3(self): + # power of 0.5 is special-cased, arbitrary power would still produce triton codegen error + def fn(x): + z = torch.tensor(0.123, device=self.device) + w = z + x + return torch.pow(w, 0.5) + + opt = torch._dynamo.optimize("inductor")(fn) + input = torch.rand(()) + self.assertTrue(same(opt(input), fn(input))) + def test_glu(self): def fn(x): return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) @@ -2141,16 +2828,44 @@ def fn(x): (torch.randn([64]),), ) - def test_flip(self): + def test_expm1(self): def fn(x): - return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2 + return torch.expm1(x), torch.expm1(x) * 2 - self.common( - fn, - (torch.randn([1, 2, 6, 6]),), - ) + for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64): + self.common( + fn, + (torch.randn([64]).to(dtype=dtype),), + ) + self.common( + fn, + (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),), + ) - def test_signbit(self): + def test_log1p(self): + def fn(x): + return torch.log1p(x), torch.log1p(x) * 2 + + for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64): + self.common( + fn, + (torch.randn([64]).to(dtype=dtype),), + ) + self.common( + fn, + (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),), + ) + + def test_flip(self): + def fn(x): + return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2 + + self.common( + fn, + (torch.randn([1, 2, 6, 6]),), + ) + + def test_signbit(self): def fn(x): return torch.signbit(x), ~torch.signbit(-x) & 1 @@ -2166,6 +2881,25 @@ def fn(a, b): shape = [1, 2, 6, 6] self.common(fn, (torch.randn(shape), torch.randn(shape))) + def test_fmod_zero_dim(self): + def fn(a, b): + return (torch.fmod(a, b),) + + self.common( + fn, + ( + make_tensor(10, device="cpu", dtype=torch.float32), + make_tensor((), device="cpu", dtype=torch.float32), + ), + ) + self.common( + fn, + ( + make_tensor((), device="cpu", dtype=torch.float32), + make_tensor(10, device="cpu", dtype=torch.float32), + ), + ) + def test_log2(self): def fn(x): return torch.log2(x), torch.log2(x + 1) - 2 @@ -2323,6 +3057,19 @@ def fn(a, b): ), ) + def test_index3(self): + def fn(x, ia, ib): + return (x[:, ia, None, ib, 0],) + + self.common( + fn, + ( + torch.randn(3, 4, 4, 4, 3), + torch.tensor([0, 2, 1], dtype=torch.int64), + torch.tensor([0, 2, 1], dtype=torch.int64), + ), + ) + def test_index_select(self): def fn(a, b): return ( @@ -2340,8 +3087,6 @@ def fn(a, b): ), ) - # https://github.com/pytorch/torchdynamo/issues/467 - @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) def test_cudnn_rnn(self): if self.device == "cpu": raise unittest.SkipTest("requires CUDA") @@ -2447,10 +3192,10 @@ def fn(a): def test_upsample_nearest2d(self): def fn(a): return ( - aten.upsample_nearest2d(a, [74, 76], None), - aten.upsample_nearest2d(a, [70, 75], None), - aten.upsample_nearest2d(a, [45, 74], None), - aten.upsample_nearest2d(a, [36, 39], None), + aten.upsample_nearest2d(a, [74, 76]), + aten.upsample_nearest2d(a, [70, 75]), + aten.upsample_nearest2d(a, [45, 74]), + aten.upsample_nearest2d(a, [36, 39]), aten.upsample_nearest2d(a, None, [2.0, 2.0]), ) @@ -2469,25 +3214,15 @@ def fn(a): self.common(fn, (torch.randn([2, 4, 37, 38, 39]),)) def test_upsample_nearest2d_backward(self): - func = torch.ops.aten.upsample_nearest2d_backward.vec + func = torch.ops.aten.upsample_nearest2d_backward def fn(a): return ( - func( - a, output_size=[6, 12], input_size=[3, 3, 3, 6], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 4, 5], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 2, 8], scale_factors=None - ), - func( - a, output_size=[6, 12], input_size=[3, 3, 4, 7], scale_factors=None - ), + func(a, output_size=[6, 12], input_size=[3, 3, 3, 6]), + func(a, output_size=[6, 12], input_size=[3, 3, 4, 5]), + func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]), + func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]), + func(a, output_size=[6, 12], input_size=[3, 3, 4, 7]), ) self.common(fn, (torch.randn([3, 3, 6, 12]),)) @@ -2686,6 +3421,16 @@ def fn(x, y): out_eager = (inputs[0] + inputs[1].float()).add_(inputs[1]).mul_(inputs[1]) self.assertTrue(same(out, out_eager)) + @patch.object(config.triton, "ordered_kernel_names", True) + @patch.object(config.triton, "descriptive_kernel_names", False) + def test_kernel_names(self): + @torch._dynamo.optimize("inductor") + def fn(x): + return 2 * x + + inputs = (rand_strided((8,), (1,), device=self.device),) + self.assertTrue(same(fn(*inputs), 2 * inputs[0])) + @patch.object(config.triton, "cudagraphs", True) def test_strided_inputs(self): @torch._dynamo.optimize("inductor") @@ -2730,7 +3475,9 @@ def fn(a): c = a + 2 return b, c - arg1 = torch.randn([1, 64], device=self.device) + # NOTE: this test fails when none of the inputs require grad. + # That seems like an inductor bug. + arg1 = torch.randn([1, 64], device=self.device).requires_grad_(True).add(1) arg2 = arg1.clone() correct1 = fn(arg1) opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) @@ -2819,7 +3566,7 @@ def fn(in_ptr0, in_ptr1, in_ptr2): ), ) - @unittest.skipIf(not has_torchvision_roi_align(), "requirs torchvision") + @unittest.skipIf(not has_torchvision_roi_align(), "requires torchvision") def test_roi_align(self): def fn(a, b): return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False) @@ -2856,6 +3603,17 @@ def fn(x): ], ) + def test_isinf2(self): + def fn(x): + y = torch.tensor( + [1, float("inf"), 2, float("-inf"), float("nan")], device=self.device + ) + return x == y + + self.common( + fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),) + ) + def test_any(self): def fn(x): return ( @@ -3088,6 +3846,18 @@ def fn(x): self.common(fn, [torch.randn(64, 64)]) + def test_as_strided_scatter(self): + def fn(a, b): + return aten.as_strided_scatter( + a * 8 + 10, + b * 2 - 4, + size=(a.shape[0], a.shape[1] // 2), + stride=(a.shape[1], 2), + storage_offset=0, + ) + + self.common(fn, [torch.randn(10, 1024), torch.randn(10, 512)]) + def test_select_scatter(self): def fn(x, a, b): return ( @@ -3146,6 +3916,9 @@ def fn(a, dim, index, b): ) def test_scatter2(self): + if self.device == "cuda": + raise unittest.SkipTest("unstable on sm86") + def fn(a, dim, index, b): return aten.scatter.reduce(a, dim, index, b, reduce="add") @@ -3260,6 +4033,11 @@ def fn(a, dim, index, b): # issue #1150 def test_dense_mask_index(self): + if self.device == "cpu": + raise unittest.SkipTest( + "https://github.com/pytorch/torchdynamo/issues/1697" + ) + def fn(x, y): y = torch.ops.aten.select.int(y, 0, 2) z = x * y @@ -3279,13 +4057,24 @@ def test_dropout(self): torch.manual_seed(1234) @torch._dynamo.optimize("inductor") - def fn(a): - return torch.nn.functional.dropout(a, 0.5, True) + def fn1(a): + return torch.nn.functional.dropout(a) x = torch.ones(1000, device=self.device, dtype=torch.float32) - result = fn(x) - self.assertTrue(400 < result.nonzero().shape[0] < 600) - self.assertTrue(0.9 < result.mean().item() < 1.1) + result1 = fn1(x) + self.assertTrue(400 < result1.nonzero().shape[0] < 600) + self.assertTrue(0.9 < result1.mean().item() < 1.1) + + random.seed(1234) + torch.manual_seed(1234) + + @torch._dynamo.optimize("inductor") + def fn2(a): + return torch.nn.functional.dropout(a, 0.5, True) + + result2 = fn2(x) + self.assertTrue(400 < result2.nonzero().shape[0] < 600) + self.assertTrue(0.9 < result2.mean().item() < 1.1) def test_dropout_deterministic(self): @torch._dynamo.optimize("inductor") @@ -3425,6 +4214,60 @@ def fn(a, b, c): ], ) + # From https://github.com/pytorch/torchdynamo/issues/1352 + def test_max_pool2d_with_indices_backward4(self): + def fn(a, b, c): + return aten.max_pool2d_with_indices_backward( + a, b, [5, 5], [1, 1], [2, 2], [1, 1], False, c + ) + + torch._inductor.metrics.generated_kernel_count = 0 + x = torch.randn([2, 64, 3, 4]) + result, indices = aten.max_pool2d_with_indices( + x, + [5, 5], + [1, 1], + 2, + 1, + False, + ) + self.common( + fn, + [ + torch.randn_like(result), + x, + indices, + ], + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_max_pool2d_with_indices_backward5(self): + # Window size is too big. Should fallback + def fn(a, b, c): + return aten.max_pool2d_with_indices_backward( + a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c + ) + + torch._inductor.metrics.generated_kernel_count = 0 + x = torch.randn([2, 64, 20, 20]) + result, indices = aten.max_pool2d_with_indices( + x, + [13, 13], + [1, 1], + 2, + 1, + False, + ) + self.common( + fn, + [ + torch.randn_like(result), + x, + indices, + ], + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_avg_pool2d_backward(self): def fn(a, b): return aten.avg_pool2d_backward( @@ -3480,6 +4323,7 @@ def fn(a, b): None, ) + torch._inductor.metrics.generated_kernel_count = 0 self.common( fn, [ @@ -3487,6 +4331,31 @@ def fn(a, b): torch.randn([1, 2016, 21, 21]), ], ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_avg_pool2d_backward4(self): + def fn(a, b): + return aten.avg_pool2d_backward( + a, + b, + [13, 13], + [1, 1], + [0, 0], + True, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 16, 12, 12]), + torch.randn([1, 16, 24, 24]), + ], + check_lowp=False, + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) def test_mm_views(self): def fn(a, b): @@ -3642,6 +4511,78 @@ def fn(x): rtol=0.5, ) + def test_conv_backward(self): + def fn(rank4_inps, rank3_inps, rank5_inps): + + out1 = aten.convolution_backward( + *rank4_inps, + [C], + [1, 1], + [0, 0], + [1, 1], + False, + [0, 0], + 1, + [True, True, True], + ) + out2 = aten.convolution_backward( + *rank4_inps, + [C], + [1, 1], + [0, 0], + [1, 1], + False, + [0, 0], + 1, + [True, False, False], + ) + out3 = aten.convolution_backward( + *rank3_inps, + [C], + [1], + [0], + [1], + False, + [0], + 1, + [True, True, True], + ) + out4 = aten.convolution_backward( + *rank5_inps, + [C], + [1, 1, 1], + [0, 0, 0], + [1, 1, 1], + False, + [0, 0, 0], + 1, + [True, True, True], + ) + return (out1, out2, out3, out4) + + B = 3 + C = 4 + H = 5 + grad_out = torch.randn(B, C, H - 2, H - 2, H - 2) + inp = torch.randn(B, C, H, H, H) + weight = torch.randn(C, C, 3, 3, 3) + + def shrink_rank(x, rank): + res = x + while res.dim() > rank: + res = torch.select(res, -1, 0) + return res.contiguous() + + rank4_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] + rank3_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] + rank5_inps = [shrink_rank(x, 5) for x in [grad_out, inp, weight]] + + with torch.backends.cudnn.flags(allow_tf32=False): + self.common( + fn, + [rank4_inps, rank3_inps, rank5_inps], + ) + @unittest.skip( """ FIXME: In the case of having equally max/min elements, our implementation returns @@ -3740,7 +4681,10 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): ((1, 88, 40, 40), (140800, 1600, 40, 1), torch.float32), ((3,), (1,), torch.float32), ] - args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] + args = [ + rand_strided(shape, stride, dtype).requires_grad_(True).add(1) + for shape, stride, dtype in args + ] self.common(forward, args) def test_misaligned_address_issue1(self): @@ -3801,35 +4745,50 @@ def forward(x): ] self.common(forward, args) - @requires_cuda() - def test_unspec_inputs(self): - def fn(x, y): - return x + y, x * y, x / y + def test_zero_dim_reductions(self): + for kd in [True, False]: + inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) + failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] + for fo in failed_ops: + with self.assertRaisesRegex( + IndexError, "Expected reduction dim 1 to have non-zero size" + ): + mod = make_fx(fo)(*inps0) + _ = compile_fx_inner(mod, inps0) - opt = torch._dynamo.optimize("inductor")(fn) + pass_ops = [ + lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] + ] + for po in pass_ops: + compiled = torch._dynamo.optimize("inductor")(po) + expected = po(*inps0) + actual = compiled(*inps0) - inputs = ( - rand_strided((2, 3), (3, 1), device="cuda"), - rand_strided((), (), device="cpu"), - ) - self.assertTrue(same(opt(*inputs), fn(*inputs))) - inputs = (inputs[1], inputs[0]) - self.assertTrue(same(opt(*inputs), fn(*inputs))) + self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) @requires_cuda() - def test_unspec_inputs_fp16(self): + def test_unspec_inputs(self): def fn(x, y): return x + y, x * y, x / y opt = torch._dynamo.optimize("inductor")(fn) + dtypes = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.int32, + torch.int64, + ] - inputs = ( - rand_strided((2, 3), (3, 1), dtype=torch.float16, device="cuda"), - rand_strided((), (), dtype=torch.float16, device="cpu"), - ) - self.assertTrue(same(opt(*inputs), fn(*inputs))) - inputs = (inputs[1], inputs[0]) - self.assertTrue(same(opt(*inputs), fn(*inputs))) + for d in dtypes: + inputs = ( + rand_strided((2, 3), (3, 1), dtype=torch.float32, device="cuda"), + rand_strided((), (), dtype=d, device="cpu"), + ) + self.assertTrue(same(opt(*inputs), fn(*inputs))) + inputs = (inputs[1], inputs[0]) + self.assertTrue(same(opt(*inputs), fn(*inputs))) @patch.object(config.triton, "mm", "aten") def test_list_clearing(self): @@ -3888,9 +4847,72 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): else: self.assertEqual(len(inps), 0) + def test_dtype_mismatch_issue(self): + def fn(x): + attn = torch.nn.functional.pad(x, [0, 1]) + return attn.softmax(dim=-1) + + x = torch.rand(128, 32, 63) + res_ref = fn(x) + res = torch._dynamo.optimize("inductor")(fn)(x) + self.assertEqual(res, res_ref) + + @unittest.skipIf(HAS_CUDA, "histogramdd only supports cpu") + def test_kwargs(self): + def fn(x, y): + return torch.histogramdd( + x, + bins=[3, 3], + weight=y, + ) + + self.common( + fn, + [torch.randn((4, 2)), torch.randn((4))], + ) + + @patch.object(config, "profiler_mark_wrapper_call", True) + def test_profiler_mark_wrapper_call(self): + from torch.profiler import profile + + @torch._dynamo.optimize("inductor", nopython=True) + def fn(a, b): + return a + b + + a = torch.rand((100,)) + b = torch.rand((100,)) + with profile() as prof: + fn(a, b) + assert "inductor_wrapper_call" in ( + e.name for e in prof.profiler.function_events + ) + + @patch.object(config, "cpp_wrapper", True) + @unittest.skipIf(HAS_CUDA, "cpp_wrapper only supports cpu") + def test_cpp_wrapper(self): + device = "cpu" + for name in [ + "test_as_strided", # buffer reuse + "test_cat", # alias + "test_profiler_mark_wrapper_call", # TODO: fallback to default wrapper for now + "test_relu", # multiple inputs + "test_silu", # single input, single output + "test_transpose", # multiple outputs, buffer clear + ]: + test_name = f"{name}_{device}" + assert hasattr(self, test_name), "undefined function" + func = getattr(self, test_name) + assert callable(func), "not a callable" + func() + if HAS_CPU: + class SweepInputsCpuTest(SweepInputs2, TestCase): + gen = InputGen(10, "cpu") + + SweepInputsCpuTest.populate() + class CpuTests(TestCase): common = check_model device = "cpu" @@ -3898,6 +4920,45 @@ class CpuTests(TestCase): CommonTemplate.install(CpuTests, "cpu") class CPUReproTests(TestCase): + def test_conv_stride_constraints(self): + for fmt in [torch.channels_last, torch.contiguous_format]: + # TorchDispatch doesn't work in our cuda invocation for some reason + m = torch.nn.Conv2d(5, 6, [3, 3]) + + def fn(inp, weight): + return ( + F.conv2d( + inp, weight, None, m.stride, m.padding, m.dilation, m.groups + ), + ) + + inp = torch.randn([2, 5, 16, 16]) + inps = [inp, m.weight.to(memory_format=fmt)] + fn_fx = make_fx(fn)(*inps) + fn_compiled = compile_fx_inner(fn_fx, inps) + test_self = self + conv_seen = False + + class RecordFunctions(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs if kwargs else {} + if func == torch.ops.aten.convolution.default: + test_self.assertTrue( + args[0].is_contiguous(memory_format=fmt) + ) + test_self.assertTrue( + args[1].is_contiguous(memory_format=fmt) + ) + nonlocal conv_seen + conv_seen = True + + return func(*args, **kwargs) + + with RecordFunctions(): + out = fn_compiled(inps) + + self.assertTrue(conv_seen) + def test_inplace_squeeze_needed(self): mod = torch.nn.Sequential( torch.nn.Linear(10, 10), @@ -3911,7 +4972,11 @@ def fn(x): v = torch.randn(10) result = fn(v) - assert same(result, mod(v)) + # TODO: OMP parallel reduction order is not deterministic. + # Hence, the accurarcy might vary up and down. For short term, + # we increase the tolerance and will fix it later by using + # aten parallel. + assert same(result, mod(v), tol=5e-1) def test_inplace_add_alpha(self): def fn(x, y): @@ -3980,6 +5045,255 @@ def test_complex_memory_overlap(self): self.assertFalse(complex_memory_overlap(gathered)) self.assertFalse(complex_memory_overlap(gathered.t())) + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch.object(config, "dynamic_shapes", True) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + @patch.object(functorch_config, "use_dynamic_shapes", True) + def test_vec_dynamic_shapes(self): + def fn(x): + return torch.softmax(x, -1) + + value = torch.randn((2, 10)) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value) + + real_out = fn(value) + compiled_out = opt_fn(value) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count < 1 + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_simd(self): + vec_avx512 = codecache.supported_vec_isa_list[0] + vec_avx2 = codecache.supported_vec_isa_list[1] + self.assertTrue(vec_avx512.bit_width() == 512) + self.assertTrue(vec_avx2.bit_width() == 256) + self.assertTrue(vec_avx512.nelements() == 16) + self.assertTrue(vec_avx2.nelements() == 8) + self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) + self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) + + with patch.object(config.cpp, "simdlen", None): + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with patch.object(config.cpp, "simdlen", 0): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 1): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 257): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 513): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 512): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_avx512) + + with patch.object(config.cpp, "simdlen", 256): + isa_list = codecache.valid_vec_isa_list() + if vec_avx2 in isa_list: + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_avx2) + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_masked_fill_softmax(self): + def fn(value, mask): + mask = mask.to(torch.bool) + x = torch.masked_fill(value, mask, -33.0) + return torch.softmax(x, -1) + + value = torch.randn((2, 17)) + mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value, mask) + + real_out = fn(value, mask) + compiled_out = opt_fn(value, mask) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count >= 1 + + def test_cpu_vec_cosim(self): + cpp_vec_op_list = [] + cpp_op_list = [] + + for k, v in CppVecOverrides.__dict__.items(): + if isinstance(v, staticmethod): + cpp_vec_op_list.append(k) + for k, v in CppOverrides.__dict__.items(): + if isinstance(v, staticmethod): + cpp_op_list.append(k) + + self.assertEqual(cpp_op_list.sort(), cpp_vec_op_list.sort()) + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_erf_cpu_only(self): + def fn(x): + return (torch.erf(x),) + + x = torch.randn((2, 9)) + x[0, 0] = torch.nan + x[1, -1] = torch.nan + + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x) + compiled = compile_fx_inner(traced, [x]) + assert same(fn(x)[0], compiled([x])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 1 + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_sign_cpu_only(self): + def fn(x): + return (torch.sign(x),) + + x = torch.randn((2, 9)) + x[0, 0] = torch.nan + x[1, -1] = torch.nan + + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x) + compiled = compile_fx_inner(traced, [x]) + assert same(fn(x)[0], compiled([x])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 1 + + # Currently, we enabled AVX2 and AVX512 for vectorization. If the platform is not + # supported, the vectorization will not work and skip this test case. For ARM or + # other platforms support, we just need to add the ISA info to the supported_vector_isa + # and include proper aten vectorization head file. + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_vec_kernel_cpu_only(self): + def fn(x1, x2): + # Current, there are some limitations as follows. + # rsqrt: + # assert [both a fallback and a decomp for same kernel: aten.rsqrt.default] + # round: + # couldn't find symbolic meta function/decomposition + # fmod/logical_and/logic_or: + # vec kernel has not support to_type + x = torch.abs(x1) + x = torch.sin(x) + x = torch.neg(x) + x = torch.square(x) + x = torch.sigmoid(x) + x = torch.relu(x) + x = torch.cos(x) + x = torch.exp(x) + x = torch.sqrt(x) + x = torch.add(x, x1) + x = torch.sub(x, x2) + x = torch.mul(x, x1) + x = torch.div(x, x1) + x = torch.pow(x, 10) + x = torch.log(x) + x = torch.floor(x) + x = torch.ceil(x) + x = torch.trunc(x) + x = torch.lgamma(x) + x = torch.fmod(x, x2) + x = torch.sign(x) + res = x + x2 + return (res,) + + x1 = torch.randn((10, 20)) + x2 = torch.randn((10, 20)) + + with patch.object(config.cpp, "simdlen", 1): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, [x1, x2]) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 0 + + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, [x1, x2]) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 1 + + torch._dynamo.reset() + metrics.reset() + x1 = x1.permute(1, 0) + x2 = torch.randn((20, 10)) + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, [x1, x2]) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 1 + + torch._dynamo.reset() + metrics.reset() + x1 = torch.randn((10, 7)) + x2 = torch.randn((10, 7)) + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, ([x1, x2])) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 1 + + @unittest.skipIf( + sys.platform != "linux", "cpp kernel profile only support linux now" + ) + @patch("torch.cuda.is_available", lambda: False) + @patch.object(config.cpp, "enable_kernel_profile", True) + def test_cpp_kernel_profile(self): + from torch.profiler import profile + + @torch._dynamo.optimize("inductor", nopython=True) + def fn(a, b): + return a + b + + a = torch.rand((100,)) + b = torch.rand((100,)) + with profile() as prof: + fn(a, b) + + kernel_profile_events = [] + for e in prof.profiler.function_events: + if "kernel_cpp_0" in e.name: + kernel_profile_events.append(e.name) + assert len(kernel_profile_events) > 0 + if HAS_CUDA: import triton @@ -4002,9 +5316,119 @@ def fn(a): fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],) ) + def test_linear_permute_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + a0 = torch.nn.functional.linear(input, self.weight, self.bias) + b0 = a0.permute(0, 2, 1) + return b0 + + m, k, n = 16, 8, 4 + trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, m, k) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_linear_transpose = count_call_function(traced, linear_transpose) + self.assertEqual(num_linear, 0) + self.assertEqual(num_linear_transpose, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + @patch.object(config.triton, "autotune", True) + def test_inplace_add_alpha_autotune(self): + def fn(x, y): + aten.add_.Tensor(x, y, alpha=0.55) + return (x,) + + x1 = torch.zeros(2, 3, 4, 10, device="cuda") + x2 = torch.zeros(2, 3, 4, 10, device="cuda") + x3 = torch.zeros(2, 3, 4, 10, device="cuda") + y = torch.randn(2, 3, 4, 10, device="cuda").to( + memory_format=torch.channels_last + ) + fn_fx = make_fx(fn)(x1, y) + fn_compiled = compile_fx_inner(fn_fx, [x1, y]) + fn(x2, y) + fn_compiled([x3, y]) + assert same(x2, x3) + + @patch.object(config.triton, "autotune", True) + def test_inplace_buffer_autotune(self): + def foo(x, y, z): + a = x @ y + return a.unsqueeze(0).unsqueeze(0) + z + + x = torch.zeros(5, 5, device="cuda") + y = torch.zeros(5, 5, device="cuda") + z = torch.zeros(1, 1, 5, 5, device="cuda").to( + memory_format=torch.channels_last + ) + self.common( + foo, + (x, y, z), + check_lowp=False, + ) + + def test_permute_linear_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.nn.functional.linear(input1, self.weight, self.bias) + return output + + m, k, n = 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, k, m) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_transpose_linear = count_call_function(traced, transpose_linear) + self.assertEqual(num_linear, 0) + self.assertEqual(num_transpose_linear, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_bmm_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, batch: int, k: int, n: int): + super().__init__() + self.other = torch.randn(batch, k, n) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.bmm(input1, self.other) + return output + + batch, m, k, n = 6, 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) + module = TestModule(batch, k, n).eval() + input = torch.randn(batch, k, m) + traced = trace_func(module, [input]) + num_bmm = count_call_function(traced, torch.bmm) + num_transpose_matmul = count_call_function(traced, transpose_matmul) + self.assertEqual(num_bmm, 0) + self.assertEqual(num_transpose_matmul, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + CommonTemplate.install(CudaTests, "cuda") class CudaReproTests(TestCase): + common = check_model_cuda + def test_index_put_issue(self): def forward( self, @@ -4041,6 +5465,30 @@ def forward( compiled = compile_fx_inner(mod, inps) compiled(inps) + @requires_cuda() + def test_input_channels_last(self): + m = torch.nn.Sequential( + torch.nn.Conv2d(3, 3, 1, 1), + ToTuple(), + ).cuda() + inp = ( + torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() + ) + + self.common( + m, + (inp,), + check_lowp=False, + ) + + @torch._dynamo.optimize() + def foo(m, inp): + return m(inp) + + self.assertTrue( + foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last) + ) + # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527 @requires_cuda() def test_unspec_inputs_interop(self): @@ -4165,6 +5613,23 @@ def forward(self, x): for param in model_opt.parameters(): param.add_(1.0) + # https://github.com/pytorch/torchdynamo/issues/1850 + def test_inductor_output_aliases_intermediate(self): + def foo(x): + out = x + x + return out.t() + + foo_opt = torch._dynamo.optimize("inductor")(foo) + + inpt = torch.randn(10, 10, device="cuda", requires_grad=True) + # TODO: this is broken, fix later + # out = foo_opt(inpt) + # out.add_(2) + + out_ref = foo(inpt) + out_ref.add_(2) + # self.assertEqual(out_ref, out) + def test_accuracy_issue1(self): class Repro(torch.nn.Module): def __init__(self): @@ -4219,6 +5684,7 @@ def decorator(fn): meta=meta, configs=configs, save_cache_hook=False, + mutated_arg_names=["in_out_ptr0"], ) return decorator @@ -4279,6 +5745,146 @@ def forward(pred_objectness_logits_3_: torch.Tensor): result = forward(*args) assert same(result, torch.sort(args[0], descending=True, dim=1)[0]) + @requires_cuda() + def test_scalar_triton_index(self): + # The indirect indexing via a scalar like below used to lead to + # bad triton code that made triton segfault when compiling. + # See https://github.com/pytorch/torchdynamo/issues/1515 + def fn(a): + zero = torch.zeros((16,), device=a.device, dtype=torch.int64) + return (a[zero],) + + a = torch.randn((8,), dtype=torch.float32, device="cuda") + + fn_optimized = torch._dynamo.optimize("inductor")(fn) + assert same(fn(a), fn_optimized(a)) + + @requires_cuda() + def test_indirect_indexing_dense_mask(self): + def fn(x, y): + ne = torch.ops.aten.ne.Scalar(x, 1) + sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1]) + sub = torch.ops.aten.sub.Tensor(sum_1, 1) + unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1) + gather = torch.ops.aten.gather.default(x, 1, unsqueeze) + squeeze = torch.ops.aten.squeeze.default(gather) + out = torch.ops.aten.multiply(y, squeeze) + return (out,) + + a = torch.zeros((1, 128), dtype=torch.int64, device="cuda") + b = torch.zeros((1, 128), dtype=torch.int64, device="cuda") + + fn_optimized = torch._dynamo.optimize("inductor")(fn) + assert same(fn(a, b), fn_optimized(a, b)) + + class TritonCodeGenTests(TestCase): + from torch._inductor.triton_ops.autotune import CachingAutotuner + + class NoOpCompilerBackend: + def __init__(self): + self.example_args = None + self.model = None + + def noop_backend( + self, + model_: torch.fx.GraphModule, + example_inputs_: typing.List[torch.Tensor], + ): + """ + The Noop backend does not compile the fx graph it is given. + Instead, it transforms the fx graph so that its functions are + aten operations. It then saves this graph. + """ + from torch._functorch.aot_autograd import Interpreter + from torch._inductor.decomposition import select_decomp_table + from torch._subclasses import FakeTensorMode + + fake_mode = FakeTensorMode() + + def interpret(*args, **kwargs): + return Interpreter(model_).run(*args[0:], **kwargs) + + fake_flat_tensor_args = [ + fake_mode.from_tensor(x) for x in example_inputs_ + ] + fw_module = make_fx(interpret, select_decomp_table())( + *fake_flat_tensor_args + ) + self.model = fw_module + self.example_args = fake_flat_tensor_args + return lambda x: example_inputs_ + + def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.virtualized import V + + cxt = TritonCodeGenTests.NoOpCompilerBackend() + torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args) + graph = GraphLowering(cxt.model) + graph.num_static_inputs = 0 + kernels = [] + with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): + graph.run(*(cxt.example_args)) + mod = graph.compile_to_module() + + for val in mod.__dict__.values(): + if isinstance( + val, torch._inductor.triton_ops.autotune.CachingAutotuner + ): + kernels.append(val) + + return kernels + + def test_divisibile_by_16_covers_numel_args(self): + torch._dynamo.reset() + + def fn(a: torch.Tensor) -> torch.Tensor: + return torch.sum(a) + + kernels = self.get_kernels(fn, [torch.randn([256, 256], device="cuda")]) + self.assertTrue(len(kernels) == 2, "SUM should result in two kernels") + + # kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of + # size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is + # at slot 3 should be in the divisible by 16 descriptor + arguments_that_are_divisible_by_16_in_kernel0 = ( + kernels[0].meta["configs"][0].divisible_by_16 + ) + self.assertEqual(arguments_that_are_divisible_by_16_in_kernel0, (0, 1, 3)) + + # kernel1 reduces from 8 elements to a single scalar. + arguments_that_are_divisible_by_16_in_kernel1 = ( + kernels[1].meta["configs"][0].divisible_by_16 + ) + self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1)) + torch._dynamo.reset() + + +class ExprPrinterTests(TestCase): + def test_print_pow(self): + s1 = sympy.Symbol("foo", integer=True) + s2 = sympy.Symbol("bar", integer=True) + s3 = sympy.Symbol("baz", integer=True) + + cases = ( + # expr, result + # Test exprs. + ( + s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), + "((-1)*(1/(((-1) + (2*foo))))) + (foo*(1/(((-1) + (2*foo)))))", + ), + (s1 / (s2 - s3), "foo*(1/((bar + ((-1)*baz))))"), + # Test Pow directly. + (sympy.Pow(s1 + s2, 0), "1"), # note: simplified before _print_Pow + (sympy.Pow(s1 + s2, -3), "1/((bar + foo)*(bar + foo)*(bar + foo))"), + (sympy.Pow(s1 + s2, 2), "(bar + foo)*(bar + foo)"), + ) + + for expr, result in cases: + self.assertEqual(cexpr(expr), result) + self.assertEqual(texpr(expr), result) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index e0638341eaa2c..10e6cf1783ef0 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -16,20 +16,22 @@ onlyNativeDeviceTypes, OpDTypes, ops, + skipCPUIf, + skipCUDAIf, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( dtype_abbrs, run_tests, skipCUDAMemoryLeakCheckIf, + skipIfCrossRef, + skipIfTorchDynamo, suppress_warnings, - TEST_WITH_ROCM, TestCase, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA try: - from torch._inductor.utils import has_triton - try: from .test_torchinductor import check_model, check_model_cuda except ImportError: @@ -120,26 +122,20 @@ def process(device_type): inductor_skips["cpu"] = { "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault + "linalg.ldl_factor": {f32, f64}, # flaky "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky } inductor_skips["cuda"] = { - # flaky - "__rdiv__": {b8, f16, f32, f64, i32, i64}, - "masked.prod": {f16, f32, f64}, - "linalg.vander": {f32, f64}, - "sparse.sampled_addmm": {f32, f64}, - "broadcast_tensors": {f16, f32, f64}, - "dsplit": {f16, f32, f64}, # Jiterator kernel is not expected to work with inductor "jiterator_2inputs_2outputs": {b8, f16, f32, f64, i32, i64}, "jiterator_4inputs_with_extra_args": {b8, f16, f32, f64, i32, i64}, "jiterator_binary": {b8, f16, f32, f64, i32, i64}, "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64}, "jiterator_unary": {b8, f16, f32, f64, i32, i64}, - # Disabled on migration to core - "linalg.pinv.singular": {f32, f64}, - "linalg.householder_product": {f32}, + # flaky + "native_batch_norm": {f16, f32, f64}, + "_native_batch_norm_legit": {f16, f32, f64}, } inductor_expected_failures_single_sample = defaultdict(dict) @@ -152,10 +148,15 @@ def process(device_type): "__getitem__": {b8, f16, f32, f64, i32, i64}, "addr": {f16}, "allclose": {f16, f32, f64}, + "amax": {f16}, + "amin": {f16}, "angle": {f16, f32, f64}, "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -165,17 +166,16 @@ def process(device_type): "corrcoef": {f32, f64, i32, i64}, "cov": {f32, f64, i32, i64}, "equal": {b8, f16, f32, f64, i32, i64}, - "erf": {b8, f64}, "fft.fft": {f32, f64}, "fft.fft2": {b8, f32, f64, i32, i64}, "fft.fftn": {b8, f32, f64, i32, i64}, "fft.hfft": {b8, f32, f64, i32, i64}, "fft.hfft2": {b8, f32, f64, i32, i64}, "fft.hfftn": {b8, f32, f64, i32, i64}, - "fft.ifft": {b8, f16, f32, f64, i32, i64}, + "fft.ifft": {f16, f32, f64}, "fft.ifft2": {b8, f32, f64, i32, i64}, "fft.ifftn": {b8, f32, f64, i32, i64}, - "fft.ihfft": {b8, f16, f32, f64, i32, i64}, + "fft.ihfft": {f16, f32, f64}, "fft.ihfft2": {f32, f64}, "fft.ihfftn": {f32, f64}, "fft.irfft": {b8, f32, f64, i32, i64}, @@ -185,41 +185,35 @@ def process(device_type): "fft.rfft2": {f32, f64}, "fft.rfftn": {f32, f64}, "index_add": {f16}, - "index_put": {f16, f32, f64}, "index_reduce": {f16, f32, f64}, "istft": {f32, f64}, - "linalg.cholesky": {f32, f64}, - "linalg.cholesky_ex": {f32, f64}, "linalg.eig": {f32, f64}, "linalg.eigh": {f32, f64}, "linalg.eigvals": {f32, f64}, "linalg.eigvalsh": {f32, f64}, - "linalg.ldl_factor": {f32, f64}, "linalg.lstsq": {f32, f64}, "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, - "linalg.lu_solve": {f32, f64}, - "lu_solve": {f32, f64}, - "lu_unpack": {f32, f64}, - "logdet": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "masked.norm": {f16}, + "masked.normalize": {f16}, + "masked.var": {f16}, "masked_fill": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, "max.reduction_no_dim": {f16}, - "max.reduction_with_dim": {b8, f16}, + "max.reduction_with_dim": {b8}, "min.reduction_no_dim": {f16}, - "min.reduction_with_dim": {b8, f16}, + "min.reduction_with_dim": {b8}, "multinomial": {f32, f64}, "nan_to_num": {f16}, "nanquantile": {f32, f64}, "nn.functional.avg_pool1d": {i64}, - "nn.functional.avg_pool2d": {i64}, - "nn.functional.adaptive_avg_pool2d": {f16}, + "nn.functional.avg_pool2d": {i64, f64}, + "nn.functional.adaptive_avg_pool2d": {f16, f64}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": {f32, f64}, - "nn.functional.gelu": {f64}, "nn.functional.local_response_norm": {i64}, "nn.functional.one_hot": {i64}, "nn.functional.pairwise_distance": {f16}, @@ -229,19 +223,17 @@ def process(device_type): "normal": {f16, f32, f64}, "normal.number_mean": {f16, f32, f64}, "pca_lowrank": {f32, f64}, - "pinverse": {f32, f64}, "polar": {f32, f64}, "quantile": {f32, f64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "scatter_add": {f16}, "scatter_reduce.sum": {f16}, "scatter_reduce.prod": {f16, f32, f64}, "segment_reduce.lengths": {f16, f32, f64}, - "segment_reduce.offsets": {f16, f32, f64}, - "sgn": {f16, f32, f64}, "sparse.sampled_addmm": {f32, f64}, "stft": {f32, f64}, "svd_lowrank": {f32, f64}, @@ -255,7 +247,7 @@ def process(device_type): "unique_consecutive": {b8, f32, f64, i32, i64}, "var": {f16}, "var_mean": {f16}, - "view_as_complex": {f16, f32, f64}, + "view_as_complex": {f16}, } @@ -265,12 +257,16 @@ def process(device_type): "mH": {b8, f16, f32, f64, i32, i64}, "mT": {b8, f16, f32, f64, i32, i64}, "__getitem__": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {b8, f16, f32, f64, i32, i64}, "allclose": {f16, f32, f64}, "angle": {f32, f64}, "argwhere": {b8, f16, f32, f64, i32, i64}, "baddbmm": {f16}, "bernoulli": {f16, f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -284,10 +280,10 @@ def process(device_type): "fft.hfft": {b8, f16, f32, f64, i32, i64}, "fft.hfft2": {b8, f16, f32, f64, i32, i64}, "fft.hfftn": {b8, f16, f32, f64, i32, i64}, - "fft.ifft": {b8, f16, f32, f64, i32, i64}, + "fft.ifft": {f16, f32, f64}, "fft.ifft2": {b8, f16, f32, f64, i32, i64}, "fft.ifftn": {b8, f16, f32, f64, i32, i64}, - "fft.ihfft": {b8, f16, f32, f64, i32, i64}, + "fft.ihfft": {f16, f32, f64}, "fft.ihfft2": {f16, f32, f64}, "fft.ihfftn": {f16, f32, f64}, "fft.irfft": {b8, f16, f32, f64, i32, i64}, @@ -296,32 +292,28 @@ def process(device_type): "fft.rfft": {f16, f32, f64}, "fft.rfft2": {f16, f32, f64}, "fft.rfftn": {f16, f32, f64}, - "index_put": {f16, f32, f64}, "index_reduce": {f16, f32, f64}, "istft": {f32, f64}, - "linalg.cholesky": {f32, f64}, - "linalg.cholesky_ex": {f32, f64}, "linalg.eig": {f32, f64}, "linalg.eigh": {f32, f64}, "linalg.eigvals": {f32, f64}, "linalg.eigvalsh": {f32, f64}, - "linalg.ldl_factor": {f32, f64}, "linalg.lstsq": {f32, f64}, "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, - "linalg.pinv.hermitian": {f32, f64}, - "lu_unpack": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "masked.argmax": {f16, f32, f64, i32}, "masked.argmin": {f16, f32, f64, i32}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, - "max.reduction_with_dim": {b8, i32, i64}, - "min.reduction_with_dim": {b8, i32, i64}, + "max.reduction_with_dim": {b8}, + "min.reduction_with_dim": {b8}, "multinomial": {f16, f32, f64}, "nn.functional.adaptive_avg_pool2d": {f16}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.grid_sample": {f16}, + "grid_sampler_2d": {f16}, "nn.functional.gaussian_nll_loss": {f16, f32, f64}, "nn.functional.one_hot": {i64}, "nn.functional.rrelu": {f16, f32, f64}, @@ -330,18 +322,18 @@ def process(device_type): "normal": {f16, f32, f64}, "normal.number_mean": {f16, f32, f64}, "pca_lowrank": {f32, f64}, - "pinverse": {f32, f64}, "polar": {f32, f64}, "pow": {i32, i64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "round.decimals_3": {f16}, "scatter_reduce.prod": {f16, f32, f64}, "segment_reduce.lengths": {f16, f32, f64}, - "segment_reduce.offsets": {f16, f32, f64}, - "sgn": {f16, f32, f64}, + "sparse.sampled_addmm": {f32, f64}, + "std_mean.unbiased": {f16}, "stft": {f32, f64}, "svd_lowrank": {f32, f64}, "tensor_split": {b8, f16, f32, f64, i32, i64}, @@ -350,11 +342,7 @@ def process(device_type): "uniform": {f16, f32, f64}, "unique": {b8, f16, f32, f64, i32, i64}, "unique_consecutive": {b8, f16, f32, f64, i32, i64}, - "view_as_complex": {f16, f32, f64}, # AssertionError: Tensor-likes are not close! - "erf": {b8, f64}, - "nn.functional.gelu": {f64}, - "nn.functional.conv_transpose3d": {f16}, "nn.functional.triplet_margin_loss": {f16}, } @@ -364,12 +352,8 @@ def process(device_type): "asin": {f16}, "cumprod": {f16}, "linalg.vector_norm": {f64, f64}, - "linalg.householder_product": {f32}, - "linalg.lu": {f32, f64}, "kron": {f16}, "nanquantile": {f32, f64}, - "native_batch_norm": {f16, f32, f64}, - "native_layer_norm": {f16, f32, f64}, "nn.functional._scaled_dot_product_attention": {f16}, "nn.functional.avg_pool2d": {f16, f32, f64}, "nn.functional.batch_norm.without_cudnn": {f16}, @@ -424,6 +408,7 @@ def wrapper_set_seed(op, *args, **kwargs): "randn": {"assert_equal": False}, ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, + ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, "gradient": {"check_gradient": False}, # segfault on check_gradient # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error "linalg.solve_triangular": {"check_gradient": False}, @@ -435,12 +420,15 @@ def wrapper_set_seed(op, *args, **kwargs): inductor_all_samples = { "softmax.with_dtype", "index_add", - "index_put", "index_copy", "scatter_reduce.sum", "select_scatter", "squeeze", "unsqueeze", + "sum", + "amax", + "amin", + "all", } @@ -453,6 +441,10 @@ class TestInductorOpInfo(TestCase): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently + @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") + @skipIfTorchDynamo("Test uses dynamo already") + @skipIfCrossRef @_ops(op_db[START:END]) @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True) def test_comprehensive(self, device, dtype, op): @@ -538,7 +530,6 @@ def fn(*args, **kwargs): "check_gradient": requires_grad, } adjusted_kwargs.update(overridden_kwargs) - self.check_model_cuda( fn, args, @@ -597,6 +588,4 @@ def fn(*args, **kwargs): instantiate_device_type_tests(TestInductorOpInfo, globals()) if __name__ == "__main__": - torch._dynamo.config.raise_on_assertion_error = True - if has_triton() and not TEST_WITH_ROCM: - run_tests() + run_tests() diff --git a/test/jit/test_async.py b/test/jit/test_async.py index d3769cd452d64..f8a1baea67133 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: jit"] -import io import os import sys @@ -420,20 +419,6 @@ def fn(x): self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0) self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2) - def test_trace_fork_wait_inline_onnx(self): - def fork_body(x): - return torch.neg(x), torch.neg(x) - - class MyMod(torch.nn.Module): - def forward(self, x): - fut = torch.jit._fork(fork_body, x) - val = torch.jit._wait(fut) - return val[1] - - # smoke test for ONNX export - f = io.BytesIO() - torch.onnx.export(MyMod(), (torch.rand(3, 4),), f) - def test_trace_fork_wait_list_modulecalls(self): def add_one(input): return input + torch.ones(input.size()) diff --git a/test/jit/test_dtype_analysis.py b/test/jit/test_dtype_analysis.py index af1a7f3b24f28..783a1b935d9b7 100644 --- a/test/jit/test_dtype_analysis.py +++ b/test/jit/test_dtype_analysis.py @@ -128,9 +128,9 @@ def assert_dtype_equal(self, fn, in_shapes, in_dtypes): inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)] try: self.assert_dtype_equal_custom_args(fn, inputs) - except Exception: + except Exception as e: fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}" - raise AssertionError(fail_text) + raise AssertionError(fail_text) from e def assert_dtype_equal_custom_args(self, fn, args): try: diff --git a/test/jit/test_hooks.py b/test/jit/test_hooks.py index 109a5e3f1b716..2963837a638a6 100644 --- a/test/jit/test_hooks.py +++ b/test/jit/test_hooks.py @@ -229,7 +229,7 @@ def pre_hook(self, input: Tuple[str]) -> Tuple[str]: with self.assertRaisesRegex( RuntimeError, - "This error occured while scripting the forward pre-hook 'pre_hook'", + "This error occurred while scripting the forward pre-hook 'pre_hook'", ): torch.jit.script(m) diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py index db37af81993f3..8a5d4ea5f4a7a 100644 --- a/test/jit/test_misc.py +++ b/test/jit/test_misc.py @@ -361,3 +361,22 @@ def test_parse_ir_single_element_tensor_negative(self): ret = func() self.assertTrue(ret.numel() == 1) self.assertTrue(len(ret.size()) == 1) + + + def test_script_many_decorators(self): + def no_op_decorator(f): + return f + + @no_op_decorator + @no_op_decorator + @no_op_decorator + @no_op_decorator + @no_op_decorator + def foo(x, dim: int): + return x.unsqueeze(dim) + + x = torch.randn(1,) + expected = foo(x, 0) + scripted = torch.jit.script(foo) + actual = scripted(x, 0) + torch.testing.assert_close(expected, actual) diff --git a/test/jit/test_python_bindings.py b/test/jit/test_python_bindings.py index 37c2ef7f85af7..51c5e0383b2ca 100644 --- a/test/jit/test_python_bindings.py +++ b/test/jit/test_python_bindings.py @@ -84,6 +84,11 @@ def test_graph_create(self): with self.assertRaises(ValueError): gr.create("prim::Constant", [None]) + def test_add_input(self): + gr = torch._C.Graph() + foo_value = gr.addInput("foo") + assert foo_value in gr.inputs() + def test_canonicalize(self): ir = """ graph(%p207 : Tensor, diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index 1c4e359662bda..3e3cb3ffed73a 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -319,7 +319,12 @@ def forward(self, x, y): mod = torch.jit.script(CatMod(**inp.kwargs).eval()) args = inp.input - self.assertTrue(len(args) == 2) + + # This test is hard-coded only to work with two sample inputs + # but the OpInfo may have more/less + if len(args) != 2: + continue + out_size = mod(*args).size() inps = list(mod.graph.inputs()) inps[1].setType(inps[1].type().with_sizes(args[0].size())) diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 50fdec94b9fc0..b36003a2b9209 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -1124,14 +1124,6 @@ def foo(x, w): # With `check_trace=True` it will run with `@torch.no_grad()` and break assert. torch.jit.trace(foo, (x, w), check_trace=False) - def test_trace_detach_onnx_erase(self): - class Mod(torch.nn.Module): - def forward(self, x, w): - return torch.matmul(x, w).detach() - - torch.onnx.export_to_pretty_string( - Mod(), (torch.rand(3, 4), torch.rand(4, 5))) - def test_trace_slice_full_dim(self): def foo(x): return x[0:5, 0] + 1.0 diff --git a/test/jit/xnnpack/test_xnnpack_delegate.py b/test/jit/xnnpack/test_xnnpack_delegate.py index 8c759cb01ccf6..c54d9ba1b0881 100644 --- a/test/jit/xnnpack/test_xnnpack_delegate.py +++ b/test/jit/xnnpack/test_xnnpack_delegate.py @@ -8,6 +8,34 @@ torch.ops.load_library("//caffe2:xnnpack_backend") class TestXNNPackBackend(unittest.TestCase): + def test_xnnpack_constant_data(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self._constant = torch.ones(4, 4, 4) + + def forward(self, x): + return x + self._constant + + scripted_module = torch.jit.script(Module()) + + lowered_module = torch._C._jit_to_backend( + "xnnpack", + scripted_module, + { + "forward": { + "inputs" : [torch.randn(4, 4, 4)], + "outputs": [torch.randn(4, 4, 4)] + } + } + ) + + for i in range(0, 20): + sample_input = torch.randn(4, 4, 4) + actual_output = scripted_module(sample_input) + expected_output = lowered_module(sample_input) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) + def test_xnnpack_lowering(self): class Module(torch.nn.Module): def __init__(self): @@ -67,3 +95,98 @@ def forward(self, x): } ) lowered(torch.zeros(1)) + + def test_xnnpack_backend_add(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = x + y + z = z + x + z = z + x + return z + + add_module = AddModule() + sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3)) + sample_output = torch.zeros(1, 512, 512, 3) + + add_module = torch.jit.script(add_module) + expected_output = add_module(sample_inputs[0], sample_inputs[1]) + + lowered_add_module = torch._C._jit_to_backend( + "xnnpack", + add_module, + { + "forward": { + "inputs" : [sample_inputs[0].clone(), sample_inputs[1].clone()], + "outputs": [sample_output] + } + } + ) + + actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1]) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) + + def test_xnnpack_broadcasting(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y + + add_module = AddModule() + sample_inputs = (torch.rand(5, 1, 4, 1), torch.rand(3, 1, 1)) + sample_output = torch.zeros(5, 3, 4, 1) + + add_module = torch.jit.script(add_module) + expected_output = add_module(sample_inputs[0], sample_inputs[1]) + + lowered_add_module = torch._C._jit_to_backend( + "xnnpack", + add_module, + { + "forward": { + "inputs" : [sample_inputs[0], sample_inputs[1]], + "outputs": [sample_output] + } + } + ) + + actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1]) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) + + def test_xnnpack_unsupported(self): + class AddSpliceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = x + y[:, :, 1, :] + return z + + sample_inputs = (torch.rand(1, 512, 512, 3), torch.rand(1, 512, 512, 3)) + sample_output = torch.zeros(1, 512, 512, 3) + + error_msg = ( + "the module contains the following unsupported ops:\n" + "aten::select\n" + "aten::slice\n" + ) + + add_module = torch.jit.script(AddSpliceModule()) + with self.assertRaisesRegex( + RuntimeError, + error_msg, + ): + _ = torch._C._jit_to_backend( + "xnnpack", + add_module, + { + "forward": { + "inputs" : [sample_inputs[0], sample_inputs[1]], + "outputs": [sample_output] + } + } + ) diff --git a/test/jit_hooks/CMakeLists.txt b/test/jit_hooks/CMakeLists.txt index 546a3040f49bc..91d5a2bf4e01c 100644 --- a/test/jit_hooks/CMakeLists.txt +++ b/test/jit_hooks/CMakeLists.txt @@ -9,5 +9,5 @@ endif() find_package(Torch REQUIRED) add_executable(test_jit_hooks test_jit_hooks.cpp) -set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 14) +set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 17) target_link_libraries(test_jit_hooks "${TORCH_LIBRARIES}") diff --git a/test/lazy/test_debug_util.py b/test/lazy/test_debug_util.py new file mode 100644 index 0000000000000..df201d54737f1 --- /dev/null +++ b/test/lazy/test_debug_util.py @@ -0,0 +1,44 @@ +# Owner(s): ["oncall: jit"] + +import os +import re +import tempfile +import torch.nn as nn +import unittest + +import torch._lazy +import torch._lazy.ts_backend +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase + +torch._lazy.ts_backend.init() + + +@unittest.skipIf(IS_WINDOWS, "To be fixed") +class DebugUtilTest(TestCase): + def _run_linear(self): + device = "lazy" + model = nn.Linear(5, 5).to(device) + output = model(torch.randn(1, 5).to(device)) + torch._lazy.mark_step() + + + def test_get_python_frames(self): + # We only care about the first "Python Stacktrace" part of the saved + # graph. However, we cannot save the whole stack for comparison given + # it depends on a lot of things. + partial_graph = (r"Python Stacktrace:.*" + r"mark_step \(.*/_lazy/__init__.py:[0-9]+\).*" + r"_run_linear \(.*lazy/test_debug_util.py:[0-9]+\).*" + r"test_get_python_frames \(.*lazy/test_debug_util.py:[0-9]+\)") + + with tempfile.NamedTemporaryFile(mode="r+", encoding="utf-8") as graph_file: + os.environ["LTC_SAVE_TENSORS_FILE"] = graph_file.name + self._run_linear() + file = graph_file.read() + if re.search(partial_graph, file, re.DOTALL) is None: + print(file) + self.assertTrue(False) + + +if __name__ == "__main__": + run_tests() diff --git a/test/lazy/test_extract_compiled_graph.py b/test/lazy/test_extract_compiled_graph.py index f4152d0af68bf..0d916952be3b5 100644 --- a/test/lazy/test_extract_compiled_graph.py +++ b/test/lazy/test_extract_compiled_graph.py @@ -141,7 +141,7 @@ def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10): raise e # reraise the exception exception_message = str(e) if not re.search(exception_msg_pattern, exception_message): - raise RuntimeError(f"Expection message does not match the required pattern: {exception_message}") + raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}") from e else: # We are done for the test case that expects an exception return diff --git a/test/lazy/test_reuse_ir.py b/test/lazy/test_reuse_ir.py index 2d19fe1a5b539..f7024e9519cca 100644 --- a/test/lazy/test_reuse_ir.py +++ b/test/lazy/test_reuse_ir.py @@ -111,6 +111,7 @@ def testBatchNorm(self): # BatchNorm2d does extra checks on dimensions which SymInts don't support yet # so we call `torch.ops.aten.native_batch_norm` to bypass the checks. z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5) + z_legit, _, _ = torch.ops.aten._native_batch_norm_legit(x, weight, bias, True, 0.1, 1e-5) device = "lazy" x_lazy = x.detach().clone().to(device=device) @@ -118,12 +119,15 @@ def testBatchNorm(self): bias_lazy = bias.detach().clone().to(device=device) for i in range(10): z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5) + z_legit_lazy, _, _ = torch.ops.aten._native_batch_norm_legit(x_lazy, weight_lazy, bias_lazy, True, 0.1, 1e-5) torch._lazy.mark_step() torch.testing.assert_close(z.cpu(), z_lazy.cpu()) + torch.testing.assert_close(z_legit.cpu(), z_legit_lazy.cpu()) assert metrics.counter_value("IrNodeReused_torch::lazy::NativeBatchNorm") >= 7 metrics.reset() torch._lazy.ir_cache.reset() + if __name__ == '__main__': run_tests() diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index f5974ec9f6c2c..092ba3d0388d0 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -59,6 +59,7 @@ def init_lists(): # but run functionalized versions of the composite kernels in core. # This means that we don't expect the ops to show directly in the LTC metrics. FUNCTIONAL_DECOMPOSE_LIST = set([ + 'diag_embed', 'block_diag', 'new_empty_strided', 'narrow_copy', @@ -70,20 +71,28 @@ def init_lists(): 'linalg_pinv.atol_rtol_tensor', 'logsumexp', ]) + # For some ops, we don't support all variants. Here we use formatted_name + # to uniquely identify the variant. + SKIP_VARIANT_LIST = set([ + 'norm_nuc', + 'min_reduction_with_dim' + ]) return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST, - HAS_SYMINT_SUFFIX) + HAS_SYMINT_SUFFIX, + SKIP_VARIANT_LIST) (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST, - HAS_SYMINT_SUFFIX) = init_lists() + HAS_SYMINT_SUFFIX, + SKIP_VARIANT_LIST) = init_lists() torch.manual_seed(42) @@ -165,6 +174,7 @@ class TestLazyOpInfo(TestCase): if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST and op.name not in FUNCTIONAL_DECOMPOSE_LIST + and op.formatted_name not in SKIP_VARIANT_LIST ], allowed_dtypes=(torch.float,)) def test_dispatched_to_lazy(self, device, dtype, op): def get_name(op): diff --git a/test/mobile/custom_build/CMakeLists.txt b/test/mobile/custom_build/CMakeLists.txt index 521569176c307..426371f4d2965 100644 --- a/test/mobile/custom_build/CMakeLists.txt +++ b/test/mobile/custom_build/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.1) project(custom_build_project) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") # Find torch library find_package(Torch REQUIRED) diff --git a/test/mobile/model_test/README.md b/test/mobile/model_test/README.md index 49b21051c655a..7e99e6763fee8 100644 --- a/test/mobile/model_test/README.md +++ b/test/mobile/model_test/README.md @@ -55,7 +55,7 @@ NOTE: currently Android simulator test does not generate on-the-fly models. Only ## Diagnose failed test If the simulator test is falling, that means the current change will potentially break a production model. So be careful. The detailed error message can be found in test log. If the change has to be made, make sure it doesn't break existing production models, and update the failed test model as appropriate (see the next section). -You can also run these tests locally, please see the insturction in android and ios folder. Remember to generate on-the-fly test models if you want to test it locally (but don't commit these models with _temp suffix). +You can also run these tests locally, please see the instruction in android and ios folder. Remember to generate on-the-fly test models if you want to test it locally (but don't commit these models with _temp suffix). ``` python test/mobile/model_test/gen_test_model.py ios-test ``` diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 638ac37eb88b3..9089977b77f12 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -241,7 +241,7 @@ def forward(self): script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) with self.assertRaisesRegex(RuntimeError, - r"^Returining a list or dictionary with pytorch class type " + r"^Returning a list or dictionary with pytorch class type " r"is not supported in mobile module " r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " r"Workaround\: instead of using pytorch class as their element type\, " @@ -264,7 +264,7 @@ def forward(self): script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) with self.assertRaisesRegex(RuntimeError, - r"^Returining a list or dictionary with pytorch class type " + r"^Returning a list or dictionary with pytorch class type " r"is not supported in mobile module " r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " r"Workaround\: instead of using pytorch class as their element type\, " diff --git a/test/mobile/test_lite_script_type.py b/test/mobile/test_lite_script_type.py index 9a778fb5a7fd9..44eb6d4778e8b 100644 --- a/test/mobile/test_lite_script_type.py +++ b/test/mobile/test_lite_script_type.py @@ -4,6 +4,7 @@ import torch.utils.bundled_inputs import io from typing import Dict, List, NamedTuple +import unittest from torch.jit.mobile import _load_for_lite_interpreter from torch.testing._internal.common_utils import TestCase, run_tests @@ -28,12 +29,13 @@ def forward(self, a: torch.Tensor): buffer.seek(0) mobile_module = _load_for_lite_interpreter(buffer) # Error here mobile_module_result = mobile_module(sample_input).a - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_result, mobile_module_result ) + @unittest.skip("T137512434") def test_typing_dict_with_namedtuple(self): class Foo(NamedTuple): id: torch.Tensor @@ -91,7 +93,7 @@ def forward(self, a: torch.Tensor): buffer_mobile.seek(0) mobile_module = _load_for_lite_interpreter(buffer_mobile) mobile_module_result = mobile_module(sample_input) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_result, mobile_module_result ) @@ -117,7 +119,7 @@ def forward(self, a: torch.Tensor): buffer_mobile.seek(0) mobile_module = _load_for_lite_interpreter(buffer_mobile) mobile_module_result = mobile_module(sample_input) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_result, mobile_module_result ) @@ -136,7 +138,7 @@ def forward(self, a: torch.Tensor): buffer_mobile.seek(0) mobile_module = _load_for_lite_interpreter(buffer_mobile) mobile_module_result = mobile_module(sample_input) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_result, mobile_module_result ) @@ -166,7 +168,7 @@ def forward(self, a: torch.Tensor): buffer_mobile.seek(0) mobile_module = _load_for_lite_interpreter(buffer_mobile) mobile_module_result = mobile_module(sample_input) - torch.testing.assert_allclose( + torch.testing.assert_close( script_module_result.baz.di, mobile_module_result.baz.di ) diff --git a/test/mobile/test_quantize_fx_lite_script_module.py b/test/mobile/test_quantize_fx_lite_script_module.py index 44beeef818c33..ebc96d17697bd 100644 --- a/test/mobile/test_quantize_fx_lite_script_module.py +++ b/test/mobile/test_quantize_fx_lite_script_module.py @@ -47,7 +47,11 @@ def forward(self, indices): for qconfig, node in configs: qconfig_dict = {"": qconfig} - m = prepare_fx(model, qconfig_dict) + m = prepare_fx( + model, + qconfig_dict, + example_inputs=torch.randint(low=0, high=10, size=(20,)), + ) m = convert_fx(m) self._compare_script_and_mobile(m, input=indices) @@ -65,7 +69,7 @@ def forward(self, x): m = M().eval() qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=torch.randn(1, 1, 1, 1)) data = torch.randn(1, 1, 1, 1) m = convert_fx(m) # first conv is quantized, second conv is not quantized @@ -84,7 +88,11 @@ def test_submodule(self): "": torch.ao.quantization.get_default_qconfig("qnnpack"), **config, } - model = prepare_fx(model, qconfig_dict) + model = prepare_fx( + model, + qconfig_dict, + example_inputs=torch.randn(5, 5), + ) quant = convert_fx(model) x = torch.randn(5, 5) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py new file mode 100644 index 0000000000000..a30a276439754 --- /dev/null +++ b/test/nn/test_convolution.py @@ -0,0 +1,2480 @@ +# Owner(s): ["module: nn"] +import math +import unittest +import itertools +import warnings +from itertools import product + +import torch + +import torch.autograd.forward_ad as fwAD +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.functional as F +from torch.testing._internal.common_dtype import floating_types_and, floating_and_complex_types_and +from torch.testing._internal.common_utils import run_tests, \ + skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_SCIPY, TEST_WITH_ROCM, \ + download_file, parametrize as parametrize_test, subtest, \ + instantiate_parametrized_tests, set_default_dtype +from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN +from torch.testing._internal.common_nn import NNTestCase, _test_module_empty_input +from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ + dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ + skipCUDAIfRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \ + onlyNativeDeviceTypes, largeTensorTest, skipMeta, \ + disableMkldnn, skipCPUIfNoMkldnn, disablecuDNN, skipCUDAIfMiopen, skipCUDAIfNoMiopen + +from torch.testing import make_tensor +from torch.testing._internal.common_utils import gradcheck, gradgradcheck, \ + GRADCHECK_NONDET_TOL +from torch.testing._internal.common_utils import dtype2prec_DONTUSE +from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 + +AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() + + +if TEST_SCIPY: + import scipy.signal + import scipy.ndimage + +class TestConvolutionNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + def test_conv_backcompat(self): + from torch.serialization import SourceChangeWarning + + # This file was generated by running on PyTorch 1.0.1 on Python 2: + # + # import torch + # from torch import nn + # m = nn.Conv2d(1, 1, 1) + # torch.save(m, 'legacy_conv2d.pt') + # + # NB: This Pickle also contains some Unicode data! + path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') + with warnings.catch_warnings(): + warnings.simplefilter('ignore', SourceChangeWarning) + m = torch.load(path, encoding='utf-8') + input = torch.randn((1, 1, 1, 1), dtype=torch.float) + self.assertEqual(m(input).size(), (1, 1, 1, 1)) + + def test_invalid_conv1d(self): + for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype) + input = torch.randn(1, 3, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, + r'Calculated padded input size per channel: \(4\). ' + + r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'): + module(input) + + # Negative stride check + module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype) + input = torch.randn(1, 3, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): + module(input) + + def test_mismatch_shape_conv2d(self): + for dtype in (torch.float, torch.cfloat): + x = torch.randn(1, 10, 1, 28, 28, dtype=dtype) + w = torch.randn(6, 1, 5, 5, dtype=dtype) + + with self.assertRaisesRegex(RuntimeError, + r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' + + r'input of size: \[1, 10, 1, 28, 28\]'): + + F.conv2d(x, w) + + def test_conv2d_discontiguous_weight(self): + for dtype in (torch.float, torch.cfloat): + # Test for https://github.com/pytorch/pytorch/issues/55781 + x = torch.ones(64, 16, 16, 16, dtype=dtype) + weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2).to(dtype)[:, :, :, ::2] + self.assertFalse(weight.is_contiguous()) + y = torch.nn.functional.conv2d(x, weight, None) + if torch.backends.mkldnn.is_available(): + # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used + with torch.backends.mkldnn.flags(enabled=False): + y_ = torch.nn.functional.conv2d(x, weight, None) + self.assertEqual(y, y_) + self.assertEqual(y.sum(), 4186112.) + + def test_invalid_conv2d(self): + for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: + module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) + input = torch.empty(1, 1, 4, 4).to(dtype) + self.assertRaises(RuntimeError, lambda: module(input)) + + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) + input = torch.randn(1, 3, 1, 1) + with self.assertRaisesRegex(RuntimeError, + r'Calculated padded input size per channel: \(1 x 1\). ' + + r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'): + module(input) + + # Negative stride check + module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype) + input = torch.randn(1, 3, 4, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): + module(input) + + # Zero stride check + module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype) + input = torch.randn(1, 3, 4, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): + module(input) + + def test_invalid_conv3d(self): + for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: + module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) + input = torch.empty(1, 1, 4, 4, 4).to(dtype) + self.assertRaises(RuntimeError, lambda: module(input)) + + # Negative stride check + module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2) + input = torch.empty(1, 1, 4, 4, 4) + with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): + module(input) + + def test_conv_invalid_groups(self): + with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): + torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0) + with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): + torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1) + with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): + torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2) + + def test_Conv1d_module_same_padding(self): + # Compare module against functional: without strides/dilation, asymmetric padding + x = torch.rand(1, 1, 20) + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same') + expect = F.conv1d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # Test dilation, symmetric padding + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same', dilation=2) + expect = F.conv1d(x, module.weight, module.bias, padding='same', dilation=2) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, + padding='same', padding_mode='replicate') + x_padded = F.pad(x, [4, 5], mode='replicate') + expect = F.conv1d(x_padded, module.weight, module.bias, padding='valid') + self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test connstruction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') + + # Test connstruction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + + def test_Conv2d_module_same_padding(self): + # Compare module against functional: + # without strides/dilation, both symmetric and asymmetric padding + x = torch.rand(1, 1, 9, 20) + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 10), + padding='same') + expect = F.conv2d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # with dilation, symmetric padding + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), + padding='same', dilation=(1, 2)) + expect = F.conv2d(x, module.weight, module.bias, padding='same', dilation=(1, 2)) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), + padding='same', padding_mode='reflect') + x_padded = F.pad(x, [1, 2, 1, 1], mode='reflect') + expect = F.conv2d(x_padded, module.weight, module.bias, padding='valid') + self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test connstruction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') + + # Test connstruction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 3)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(4, 1)) + + def test_Conv3d_module_same_padding(self): + # Compare module against functional: + x = torch.rand(1, 1, 4, 4, 4) + # without dilation, both symmetric and asymmetric padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same') + expect = F.conv3d(x, module.weight, module.bias, padding='same') + self.assertEqual(expect, module(x)) + + # with dilation, both symmetric and asymmetric padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same', dilation=(3, 2, 1)) + expect = F.conv3d(x, module.weight, module.bias, padding='same', dilation=(3, 2, 1)) + self.assertEqual(expect, module(x)) + + # Test non-zero padding_mode, requiring explicit padding + module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), + padding='same', padding_mode='circular') + x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode='circular') + expect = F.conv3d(x_padded, module.weight, module.bias, padding='valid') + self.assertEqual(expect, module(x)) + self.assertEqual(x.size(), expect.size()) + + # Test connstruction with invalid padding string raises + with self.assertRaisesRegex(ValueError, 'Invalid padding string'): + module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') + + # Test connstruction with same padding and strides raises + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 1, 3)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 4, 1)) + with self.assertRaisesRegex(ValueError, "padding='same'"): + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(5, 1, 1)) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_thnn_conv_strided_padded_dilated(self): + for convfn, dims, transposed in ( + (torch.nn.functional.conv2d, 2, False), + (torch.nn.functional.conv_transpose2d, 2, True), + (torch.nn.functional.conv3d, 3, False), + (torch.nn.functional.conv_transpose3d, 3, True)): + for stride, padding, dilation in ( + (2, 0, 1), (1, 1, 1), (2, 1, 1), (1, 0, 2)): + kwargs = {"stride": stride, "padding": padding, "dilation": dilation} + inp_shape = (1, 2) + dims * (4,) + weight_shape = (2, 2) + dims * (1,) + inputs = torch.randn(inp_shape, dtype=torch.double, device="cuda", requires_grad=True) + weight = torch.randn(weight_shape, dtype=torch.double, device="cuda", requires_grad=True) + bias = torch.randn(2, dtype=torch.double, device="cuda", requires_grad=True) + with torch.backends.cudnn.flags(enabled=False): + res = convfn(inputs, weight, bias, **kwargs) + res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs) + self.assertEqual(res, res_cpu) + with torch.backends.cudnn.flags(enabled=False): + torch.autograd.gradcheck( + lambda x, w, b: convfn(x, w, b, **kwargs), + (inputs, weight, bias) + ) + torch.autograd.gradcheck( + lambda x, w, b: convfn(x, w, b, **kwargs), + (inputs.cpu(), weight.cpu(), bias.cpu()) + ) + + def test_Conv2d_inconsistent_types(self): + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float) + weights = torch.randn(1, 1, 3, 3, dtype=torch.double) + # inconsistent types should raise an exception + self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) + # but it should work with the same type + nn.functional.conv2d(inputs.float(), weights.float()) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") + weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") + bias = torch.randn(1, dtype=torch.double, device="cuda") + + with torch.backends.cudnn.flags(enabled=False): + # inconsistent types should raise an exception + self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) + self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias)) + + # but it should work with the same type + nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) + + def test_Conv2d_1x1(self): + in_channels = 2 + out_channels = 2 + mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double) + input = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double) + for enabled in (False, True): + with torch.backends.mkldnn.flags(enabled=enabled): + gradcheck(F.conv2d, (input, mod.weight)) + + def test_Conv2d_OneDNN(self): + def run_once(group_val=24, dilation=1): + ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32) + weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32) + op = torch.nn.Conv2d( + in_channels=group_val, + out_channels=group_val, + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + dilation=[dilation, dilation], + groups=group_val, + bias=False, + padding_mode='zeros' + ) + + op.weight.data = weights + res = op(ifm) + grad_in = torch.ones(res.shape, dtype=torch.float32) + res.backward(grad_in) + return op.weight.grad + + for gorup_val in (24, 48, 23, 25): + for dilation in (1, 2): + with torch.backends.mkldnn.flags(enabled=False): + without_onednn = run_once(gorup_val, dilation) + + with torch.backends.mkldnn.flags(enabled=True): + with_onednn = run_once(gorup_val, dilation) + + self.assertEqual(without_onednn, with_onednn) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') + def test_cudnn_non_contiguous(self): + x = torch.randn(192, 16, 50).cuda() + x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1) + m = torch.nn.Conv1d( + in_channels=16, + out_channels=32, + kernel_size=2, + bias=True).cuda() + result = m(x) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') + def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self): + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") + weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") + bias = torch.randn(1, dtype=torch.double, device="cuda") + + with torch.backends.cudnn.flags(enabled=True): + # inconsistent types should raise an exception + self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) + self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias)) + + # but it should work with the same type + nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) + + def test_Conv2d_missing_argument(self): + c = nn.Conv2d(3, 3, 3) + self.assertRaises(TypeError, lambda: c(None)) + + def test_Conv2d_backward_twice(self): + input = torch.randn(2, 3, 5, 5) + c = nn.Conv2d(3, 3, 3) + o1 = c(input) + o1.sum().backward() + self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', + lambda: o1.sum().backward()) + + + def test_conv_modules_raise_error_on_incorrect_input_size(self): + for dtype in [torch.bfloat16, torch.double, torch.float]: + modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype), + nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype), + nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)] + + invalid_input_dims = [(1, 4), (1, 4), + (2, 5), (2, 5), + (3, 6), (3, 6)] + + for invalid_dims, module in zip(invalid_input_dims, modules): + for dims in invalid_dims: + input = torch.empty(torch.Size((3, ) * dims)) + self.assertRaises(RuntimeError, lambda: module(input)) + + def test_conv_shapecheck(self): + def test(should_raise, module, input_size, dtype): + input = torch.empty(3, *input_size).to(dtype) + if should_raise: + self.assertRaises(RuntimeError, lambda: module(input)) + else: + # just run it to ensure no exception raised. + module(input) + + for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: + # Conv1d + test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype) + test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype) + + # Conv2d + test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype) + test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype) + test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype) + + # Conv3D + test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype) + test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype) + test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), (1, 2, 2, 2), dtype) + + def test_ConvTranspose2d_output_size(self): + m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) + i = torch.randn(2, 3, 6, 6) + for h in range(15, 22): + for w in range(15, 22): + if 18 <= h <= 20 and 18 <= w <= 20: + output = m(i, output_size=(h, w)) + self.assertEqual(output.size()[2:], (h, w)) + else: + self.assertRaises(ValueError, lambda: m(i, (h, w))) + + def test_ConvTranspose2d_output_size_downsample_upsample(self): + b, c, hid_c = 2, 3, 2 + for h in range(13, 24): + for w in range(13, 17): + for k in range(2, 5): + for d in range(1, 5): + for s in range(1, 4): + for p in range(3): + conv = nn.Conv2d( + in_channels=c, + out_channels=hid_c, + kernel_size=k, + stride=s, + padding=p, + dilation=d, + ) + + t_conv = nn.ConvTranspose2d( + in_channels=hid_c, + out_channels=c, + kernel_size=k, + stride=s, + padding=p, + dilation=d, + ) + + i = torch.randn(b, c, h, w) + + out = t_conv(conv(i), output_size=i.shape) + + self.assertEqual(out.size()[2:], i.size()[2:]) + + def test_ConvTranspose3d_correct_output_size(self): + # Check that ConvTranspose3d can take a 5d output_size. + m = nn.ConvTranspose3d(2, 2, 2) + i = torch.rand(1, 2, 1, 1, 1) + out = m(i, output_size=(1, 2, 2, 2, 2)) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_ConvTranspose2d_half_cublas_gemm(self): + with torch.backends.cudnn.flags(enabled=False): + inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half) + deconv = nn.ConvTranspose2d( + 1, 1, 3, stride=2, padding=1, output_padding=1).cuda().half() + output = deconv(inputs) + output.mean().backward() + + # For https://github.com/pytorch/pytorch/pull/1273 + # Almost identical to the above `test_Conv2d_naive_groups` + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_Conv2d_groups_nobias(self): + dev_dtypes = [("cpu", torch.float)] + if TEST_CUDA: + dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] + if AMPERE_OR_ROCM: + dev_dtypes += [("cuda", torch.bfloat16)] + for device, dtype in dev_dtypes: + m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype) + i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) + output = m(i) + grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) + output.backward(grad_output) + + m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) + m1.weight.data.copy_(m.weight.data[:2]) + i1 = i.data[:, :2].contiguous().requires_grad_(True) + output1 = m1(i1) + output1.backward(grad_output[:, :2].contiguous()) + + m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) + m2.weight.data.copy_(m.weight.data[2:]) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) + output2 = m2(i2) + output2.backward(grad_output[:, 2:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1)) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), + atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0) + + # Almost identical to the above `test_Conv2d_naive_groups` + # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16 + # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 + # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_Conv2d_groups_nobias_v2(self): + torch.manual_seed(123) + dev_dtypes = [("cpu", torch.float)] + if TEST_CUDA: + dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] + if AMPERE_OR_ROCM: + dev_dtypes += [("cuda", torch.bfloat16)] + for device, dtype in dev_dtypes: + m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype) + i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) + output = m(i) + grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype) + output.backward(grad_output) + + m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) + m1.weight.data.copy_(m.weight.data[:8]) + i1 = i.data[:, :2].contiguous().requires_grad_(True) + output1 = m1(i1) + output1.backward(grad_output[:, :8].contiguous()) + + m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) + m2.weight.data.copy_(m.weight.data[8:]) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) + output2 = m2(i2) + output2.backward(grad_output[:, 8:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1)) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), + atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0) + + # CPU-only test for group conv3d fast implementation using bmm + # See: https://github.com/pytorch/pytorch/pull/36355 + def test_Conv3d_groups_nobias(self): + torch.manual_seed(123) + m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float) + i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True) + output = m(i) + grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) + output.backward(grad_output) + + m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) + m1.weight.data.copy_(m.weight.data[:8]) + i1 = i.data[:, :2].contiguous().requires_grad_(True) + output1 = m1(i1) + output1.backward(grad_output[:, :8].contiguous()) + + m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) + m2.weight.data.copy_(m.weight.data[8:]) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) + output2 = m2(i2) + output2.backward(grad_output[:, 8:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1)) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[torch.float], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), + atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float]) + + def test_Conv3d_groups_wbias(self): + torch.manual_seed(123) + m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float) + i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True) + output = m(i) + grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) + output.backward(grad_output) + + m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) + m1.weight.data.copy_(m.weight.data[:8]) + m1.bias.data.copy_(m.bias.data[:8]) + i1 = i.data[:, :2].contiguous().requires_grad_(True) + output1 = m1(i1) + output1.backward(grad_output[:, :8].contiguous()) + + m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) + m2.weight.data.copy_(m.weight.data[8:]) + m2.bias.data.copy_(m.bias.data[8:]) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) + output2 = m2(i2) + output2.backward(grad_output[:, 8:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1)) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[torch.float], + rtol=dtype2prec_DONTUSE[torch.float]) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), + atol=dtype2prec_DONTUSE[torch.float], + rtol=dtype2prec_DONTUSE[torch.float]) + self.assertEqual(m.bias.grad.data, + torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), + atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float]) + + def test_conv_tbc(self): + with set_default_dtype(torch.double): + inp = torch.randn(9, 4, 5, requires_grad=True) + weight = torch.randn(3, 5, 6, requires_grad=True) + bias = torch.randn(6, requires_grad=True) + + gradcheck(lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3)) + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") + @skipIfRocmVersionLessThan((4, 3)) + @skipIfNotMiopenSuggestNHWC + def test_grouped_conv_cudnn_nhwc_support(self): + # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version + input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) + weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) + out = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4) + input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) + out_transpose = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4) + + @unittest.expectedFailure + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDNN, "needs cudnn") + def test_conv_cudnn_memory_layout_dominance(self): + # desired behavior here is to have the memory_layout of conv.weight to + # dominante the layout of output. + # which is not the same as current behavior, we'll fix this in + # following up PRs and remove the `expectedFailure` tag + input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True) + conv = nn.Conv2d(8, 4, 3).cuda().float() + + out = conv(input) + self.assertTrue(out.is_contiguous()) + + input = input.contiguous(memory_format=torch.channels_last) + out = conv(input) + self.assertTrue(out.is_contiguous()) + + conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last) + out = conv(input) + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + + input = input.contiguous() + out = conv(input) + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + def test_cudnn_noncontiguous_weight(self): + # Noncontiguous weights must be contiguous() before being + # passed to cuDNN + input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3) + weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2) + weights2 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2).contiguous() + self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2), + F.conv1d(input, weights2, bias=None, stride=2, dilation=2)) + + + def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'): + for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: + for batch, stride, padding, chan_in, chan_out, dilation in \ + product([1, 2], [1, 2], [0, 1, 2], [2], [3], [1]): + + for has_bias in [True, False]: + input_shape = [batch, chan_in] + weight_shape = [chan_out, chan_in] + for _ in range(dim): + input_shape.append(inp_size) + weight_shape.append(kern) + + input = torch.randn(input_shape, requires_grad=True) + weight = torch.randn(weight_shape, requires_grad=True) + if has_bias: + bias = torch.randn([chan_out], requires_grad=True) + output = func_forward(input, weight, stride=stride, padding=padding, dilation=dilation, bias=bias) + + gradient_o = torch.randn(output.shape) + gradient_w = torch.autograd.grad(output, input if (gradient == 'input') else weight, gradient_o) + + self.assertEqual(gradient_w[0], + func_backward( + input_shape if (gradient == 'input') else input, + weight_shape if (gradient == 'weight') else weight, + gradient_o, + stride=stride, + padding=padding, + dilation=dilation)) + + def test_grad_conv1d_input(self): + self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, 'input') + + def test_grad_conv1d_weight(self): + self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, 'weight') + + def test_grad_conv2d_input(self): + self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, 'input') + + def test_grad_conv2d_weight(self): + self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, 'weight') + + def test_grad_conv3d_input(self): + self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, 'input') + + def test_grad_conv3d_weight(self): + self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, 'weight') + + @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable") + def test_nnpack_conv(self): + for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: + for batch, stride, padding, chan_in, chan_out in \ + product([1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]): + + for has_bias in [True, False]: + input_shape = [batch, chan_in] + weight_shape = [chan_out, chan_in] + for _ in range(2): + input_shape.append(inp_size) + weight_shape.append(kern) + + input = torch.randn(input_shape, requires_grad=True, dtype=torch.float) + weight = torch.randn(weight_shape, requires_grad=True, dtype=torch.float) + if has_bias: + bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float) + output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias) + output_expected = torch.nn.functional.conv2d(input, weight, stride=stride, padding=padding, bias=bias) + self.assertEqual(output, output_expected, atol=3e-4, rtol=0) + + gradient_o = torch.randn(output.shape, dtype=torch.float) + + grads = torch.autograd.grad(output, [input, weight], gradient_o) + grads_expected = torch.autograd.grad(output_expected, [input, weight], gradient_o) + for gr, gr_expected in zip(grads, grads_expected): + self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0) + + def test_conv_padding_mode(self): + with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): + nn.Conv2d(3, 3, 3, padding_mode="xyz") + + with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): + nn.Conv2d(3, 3, 3, padding_mode=3) + + with self.assertRaisesRegex(ValueError, "Only \"zeros\" "): + nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect") + + + def test_functional_grad_conv(self): + # Conv 1D + input = torch.randn(1, 1, 5, requires_grad=True) + weight = torch.randn(1, 1, 3, requires_grad=True) + output = F.conv1d(input, weight, dilation=2) + grad_output = torch.randn(output.shape) + + grad_input_autograd, grad_weight_autograd = torch.autograd.grad(output, (input, weight), grad_output) + + grad_input_functional = torch.nn.grad.conv1d_input(input.shape, weight, grad_output, dilation=2) + self.assertEqual(grad_input_functional, grad_input_autograd) + + grad_weight_functional = torch.nn.grad.conv1d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + # Conv 2D + input = torch.randn(1, 1, 5, 5, requires_grad=True) + weight = torch.randn(1, 1, 3, 3, requires_grad=True) + output = F.conv2d(input, weight, dilation=2) + grad_output = torch.randn(output.shape) + + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + + grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, dilation=2) + self.assertEqual(grad_input_functional, grad_input_autograd) + + grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + # Conv 3D + input = torch.randn(1, 1, 5, 5, 5, requires_grad=True) + weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True) + output = F.conv3d(input, weight, dilation=2) + grad_output = torch.randn(output.shape) + + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + + grad_input_functional = torch.nn.grad.conv3d_input(input.shape, weight, grad_output, dilation=2) + self.assertEqual(grad_input_functional, grad_input_autograd) + + grad_weight_functional = torch.nn.grad.conv3d_weight(input, weight.shape, grad_output, dilation=2) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + def test_functional_grad_conv2d(self): + BATCH_SIZE = 4 + IN_CH = 8 + OUT_CH = 16 + SPATIAL = 32 + + def _test_conv2d(stride, kernel_size, groups, dilation): + padding = kernel_size // 2 + + input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True) + + weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True) + + output = F.conv2d(input, weight, + stride=stride, padding=padding, dilation=dilation, groups=groups) + + grad_output = torch.randn(output.shape) + + (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) + + grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, + stride=stride, padding=padding, dilation=dilation, groups=groups) + self.assertEqual(grad_input_functional, grad_input_autograd) + + grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, + stride=stride, padding=padding, dilation=dilation, groups=groups) + self.assertEqual(grad_weight_functional, grad_weight_autograd) + + strides = [1, 2] + kernel_sizes = [1, 3, 5] + groups = [1, 2, 4] + dilates = [1, 2] + + for s, k, g, d in product(strides, kernel_sizes, groups, dilates): + _test_conv2d(s, k, g, d) + + +class TestConvolutionNNDeviceType(NNTestCase): + def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size, + inp_size, dilation, no_weight, groups=1, use_cuda=False, + use_bias=True, dtype=torch.double): + if use_cuda: + device = torch.device("cuda") + else: + device = torch.device("cpu") + + x = torch.randn(batch_size, chan_in, inp_size, inp_size, device=device, + dtype=dtype, requires_grad=True) + weight = torch.randn(chan_out, chan_in // groups, kern, kern, device=device, + dtype=dtype, requires_grad=not no_weight) + if use_bias: + bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True) + else: + bias = None + + def func(*inputs): + if use_bias: + lx, lweight, lbias = inputs + else: + lx, lweight = inputs + lbias = None + # We disable cudnn during forward to avoid finite difference imprecision issues + with cudnn.flags(enabled=False): + out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups) + return out + + if use_bias: + inputs = x, weight, bias + else: + inputs = x, weight + + dummy_out = func(*inputs) + grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True) + + # Issue #15353: test mkldnn double backward, don't run gradgradcheck due + # to imprecision issues + if dtype == torch.float: + g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True) + return g.requires_grad + + return gradgradcheck(func, inputs, (grad_y,)) + + @onlyCUDA + @skipCUDAIfNoCudnn + @dtypes(*floating_and_complex_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) + def test_Conv2d_deterministic_cudnn(self, device, dtype): + inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True) + with cudnn.flags(enabled=True, benchmark=True, deterministic=True): + conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) + conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) + conv2.bias.data.copy_(conv1.bias.data) + conv2.weight.data.copy_(conv1.weight.data) + out1 = conv1(inputs) + out2 = conv2(inputs) + self.assertEqual(out1, out2, atol=0.0, rtol=0) + y = torch.randn(out1.size(), device=device, dtype=dtype) + out1.backward(y) + out2.backward(y) + self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0) + self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0) + + + @onlyCUDA + @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) + def test_Conv2d_large_workspace(self, device, dtype): + # These sizes require huge cuDNN workspaces. Make sure we choose a + # reasonable algorithm that does not run out of memory + sizes = [ + (1, 256, 109, 175), + (1, 256, 80, 128), + (1, 256, 120, 192), + ] + + def run_test(benchmark): + with torch.backends.cudnn.flags(benchmark=benchmark): + conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype) + for size in sizes: + x = torch.randn(size, device=device, dtype=dtype) + out = conv(x.detach().clone().requires_grad_()) + out.backward(torch.ones_like(out)) + + run_test(benchmark=False) + run_test(benchmark=True) + + + @onlyCUDA + @dtypes(torch.half, torch.float) + def test_ConvTranspose2d_large_output_padding(self, device, dtype): + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device=device, dtype=dtype) + net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device=device, dtype=dtype) + net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device=device, dtype=dtype) + x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True) + x = net1(x) + x = net2(x) + x = net3(x) + x.backward(torch.randn_like(x)) + torch.cuda.synchronize() + + + @onlyCUDA + @tf32_on_and_off(0.01) + @dtypes(torch.float, torch.double, torch.half) + # Very similar to test_Conv2d_naive_groups but with special care to handle + # the number of groups == number of input channels + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_Conv2d_depthwise_naive_groups(self, device, dtype): + for depth_multiplier in [1, 2]: + m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) + i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() + output = m(i) + grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) / 2 + output.backward(grad_output) + + offset = 1 * depth_multiplier + + m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) + m1.weight.data = m.weight.data[:offset].clone() + m1.bias.data = m.bias.data[:offset].clone() + i1 = i.detach()[:, :1].clone().requires_grad_() + output1 = m1(i1) + output1.backward(grad_output[:, :offset].contiguous()) + + m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) + m2.weight.data.copy_(m.weight.data[offset:]) + m2.bias.data.copy_(m.bias.data[offset:]) + i2 = i.detach()[:, 1:].clone().requires_grad_() + output2 = m2(i2) + output2.backward(grad_output[:, offset:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.bias.grad.data, + torch.cat([m1.bias.grad.data, + m2.bias.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, + m2.weight.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + + @onlyCUDA + @dtypes(torch.float, torch.double, torch.half) + @tf32_on_and_off(0.005) + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_Conv3d_depthwise_naive_groups(self, device, dtype): + for depth_multiplier in [1, 2]: + m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) + i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() + output = m(i) + grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype) / 2 + output.backward(grad_output) + + offset = 1 * depth_multiplier + + m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) + m1.weight.data = m.weight.data[:offset].clone() + m1.bias.data = m.bias.data[:offset].clone() + i1 = i.detach()[:, :1].clone().requires_grad_() + output1 = m1(i1) + output1.backward(grad_output[:, :offset].contiguous()) + + m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) + m2.weight.data.copy_(m.weight.data[offset:]) + m2.bias.data.copy_(m.bias.data[offset:]) + i2 = i.detach()[:, 1:].clone().requires_grad_() + output2 = m2(i2) + output2.backward(grad_output[:, offset:].contiguous()) + is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6) + atol, rtol = (3e-4, 3e-2) if dtype == torch.float32 and is_cuda_sm86 else (dtype2prec_DONTUSE[dtype], 0) + + self.assertEqual(output, torch.cat([output1, output2], 1), + atol=atol, rtol=rtol) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.bias.grad.data, + torch.cat([m1.bias.grad.data, + m2.bias.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, + m2.weight.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + + + @onlyCUDA + @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) + def test_noncontig_conv_grad(self, device, dtype): + # FIXME: remove after adding non-contiguous grad tests for all modules + module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype) + input = torch.randn(2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True) + output = module(input) + + grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1] + assert not grad.is_contiguous() + output.backward(grad, retain_graph=True) + self.assertIsNotNone(input.grad) + result = input.grad.data.clone() + input.grad.data.zero_() + + output.backward(grad.contiguous()) + self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) + + @onlyCUDA + @dtypes(torch.double) + def test_conv_double_backward(self, device, dtype): + with torch.backends.cudnn.flags(deterministic=True): + # Double backward only runs with DoubleTensor due to precision reason + batch_size = 1 + for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]: + for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations): + no_weight = stride == 2 + result = self.run_conv_double_back_test(kern, stride, + padding, chan_in, chan_out, + batch_size, inp_size, dilation, + no_weight, use_cuda=True, dtype=dtype) + self.assertTrue(result, + "Conv double backward test failed with parameters:" + + "\nkern: " + str(kern) + + "\nstride: " + str(stride) + + "\npadding: " + str(padding) + + "\nchan_in: " + str(chan_in) + + "\nchan_out: " + str(chan_out) + + "\nbatch_size: " + str(batch_size) + + "\ninp_size: " + str(inp_size) + + "\ndilation: " + str(dilation)) + + + def test_conv_double_backward_no_bias(self): + kern = 3 + stride = 2 + chan_in, chan_out = 2, 4 + batch_size = 2 + inp_size = 5 + padding = 1 + dilation = 1 + no_weight = False + use_bias = True + result = self.run_conv_double_back_test(kern, stride, + padding, chan_in, chan_out, + batch_size, inp_size, dilation, + no_weight, use_bias=use_bias) + self.assertTrue(result, + "Conv double backward test failed with parameters:" + + "\nkern: " + str(kern) + + "\nstride: " + str(stride) + + "\npadding: " + str(padding) + + "\nchan_in: " + str(chan_in) + + "\nchan_out: " + str(chan_out) + + "\nbatch_size: " + str(batch_size) + + "\ninp_size: " + str(inp_size) + + "\ndilation: " + str(dilation)) + + + def test_conv_double_backward_groups(self): + kern = 3 + stride = 1 + padding = 2 + chan_in, chan_out = 2, 4 + batch_size = 2 + inp_size = 6 + dilation = 1 + no_weight = False + groups = 2 + result = self.run_conv_double_back_test(kern, stride, + padding, chan_in * groups, chan_out * groups, + batch_size, inp_size, dilation, + no_weight, groups=groups) + self.assertTrue(result, + "Conv double backward test failed with parameters:" + + "\nkern: " + str(kern) + + "\nstride: " + str(stride) + + "\npadding: " + str(padding) + + "\nchan_in: " + str(chan_in) + + "\nchan_out: " + str(chan_out) + + "\nbatch_size: " + str(batch_size) + + "\ninp_size: " + str(inp_size) + + "\ndilation: " + str(dilation) + + "\ngroups: " + str(groups)) + + + def test_conv_double_backward_stride(self): + batch_size = 2 + + # Cannot provide ggW when stride is > 1 + for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]: + for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations): + no_weight = False + self.run_conv_double_back_test(kern, stride, + padding, chan_in, chan_out, + batch_size, inp_size, dilation, + no_weight) + + @dtypes(torch.float, torch.cfloat) + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_conv1d_same_padding(self, device, dtype): + # Test padding='same' outputs the correct shape + test_args = [ + # in_size + range(50, 55), + # kernel_size + [1, 2, 3, 8], + # dilation + range(1, 4), + # stride + [1], + ] + for in_size, k_size, dilation, stride in itertools.product(*test_args): + x = torch.rand(1, 1, in_size, device=device, dtype=dtype) + y = torch.rand(1, 1, k_size, device=device, dtype=dtype) + z = F.conv1d(x, y, padding='same', dilation=dilation, stride=stride) + self.assertEqual(z.size(2), int(math.ceil(in_size / stride))) + + # Compare F.conv1d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 12, device=device, dtype=dtype) + y = torch.rand(1, 1, 3, device=device, dtype=dtype) + expect = F.conv1d(x, y, padding=1) + actual = F.conv1d(x, y, padding='same') + self.assertEqual(expect, actual) + + # With dilation + x = torch.rand(1, 1, 12, device=device, dtype=dtype) + y = torch.rand(1, 1, 4, device=device, dtype=dtype) + expect = F.conv1d(x, y, padding=3, dilation=2) + actual = F.conv1d(x, y, padding='same', dilation=2) + self.assertEqual(expect, actual) + + # Dilation with asymmetric padding + expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:] + actual = F.conv1d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual) + + @dtypes(torch.float, torch.cfloat) + def test_conv2d_same_padding(self, device, dtype): + if dtype is torch.cfloat: + rtol, atol = 2e-6, 2e-6 + else: + rtol, atol = None, None + # Compare F.conv2d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype) + y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype) + expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] + actual = F.conv2d(x, y, padding='same') + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + # With dilation + y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) + expect = F.conv2d(x, y, padding=(2, 3), dilation=2) + actual = F.conv2d(x, y, padding='same', dilation=2) + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + # Dilation with asymmetric padding + y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype) + expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] + actual = F.conv2d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + @dtypes(torch.float, torch.cfloat) + def test_conv3d_same_padding(self, device, dtype): + if dtype is torch.cfloat: + rtol, atol = 2e-6, 2e-6 + else: + rtol, atol = None, None + # Compare F.conv3d padding='same' output against manual padding + # Without strides/dilation + x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype) + y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype) + expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :] + actual = F.conv3d(x, y, padding='same') + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + # With dilation + expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) + actual = F.conv3d(x, y, padding='same', dilation=2) + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + # Dilation with asymmetric padding + y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype) + expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:] + actual = F.conv3d(x, y, padding='same', dilation=3) + self.assertEqual(expect, actual, rtol=rtol, atol=atol) + + @dtypes(torch.float, torch.cfloat) + def test_conv1d_valid_padding(self, device, dtype): + # Test F.conv1d padding='valid' is the same as no padding + x = torch.rand(1, 1, 10, device=device, dtype=dtype) + y = torch.rand(1, 1, 4, device=device, dtype=dtype) + expect = F.conv1d(x, y) + actual = F.conv1d(x, y, padding='valid') + self.assertEqual(expect, actual) + + @dtypes(torch.float, torch.cfloat) + def test_conv2d_valid_padding(self, device, dtype): + # Test F.conv2d padding='valid' is the same as no padding + x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype) + y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype) + expect = F.conv2d(x, y) + actual = F.conv2d(x, y, padding='valid') + self.assertEqual(expect, actual) + + @dtypes(torch.float, torch.cfloat) + def test_conv3d_valid_padding(self, device, dtype): + # Test F.conv3d padding='valid' is the same as no padding + x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device) + y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device) + expect = F.conv3d(x, y) + actual = F.conv3d(x, y, padding='valid') + self.assertEqual(expect, actual) + + @dtypes(torch.float, torch.cfloat) + def test_conv1d_same_padding_backward(self, device, dtype): + # Test F.conv1d gradients work with padding='same' + x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True) + y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) + + # Symmetric padding + z = F.conv1d(x, y, padding=3, dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv1d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + + # Asymmetric padding + z = F.conv1d(x, y, padding=2)[..., 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv1d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + @dtypes(torch.float, torch.cfloat) + def test_conv2d_same_padding_backward(self, device, dtype): + # Test F.conv2d gradients work with padding='same' + x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True) + y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True) + + # Symmetric padding + z = F.conv2d(x, y, padding=(3, 4), dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv2d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + + # Asymmetric padding + y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True) + z = F.conv2d(x, y, padding=2)[..., 1:, 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv2d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + @dtypes(torch.double, torch.cdouble) + def test_conv3d_same_padding_backward(self, device, dtype): + check_forward_ad = torch.device(device).type != 'xla' + + # Test F.conv3d gradients work with padding='same' + x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True) + y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True) + + # Symmetric padding + z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv3d(x, y, padding='same', dilation=2) + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + x.grad, y.grad = None, None + + gradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), + check_forward_ad=check_forward_ad, nondet_tol=1e-5) + if torch.device(device).type != 'cuda': + # https://github.com/pytorch/pytorch/issues/70702 + gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), + check_fwd_over_rev=True) + + # Asymmetric padding + y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True) + z = F.conv3d(x, y, padding=2)[..., 1:, 1:] + z.sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + z = F.conv3d(x, y, padding='same') + z.sum().backward() + self.assertEqual(gx_expect, x.grad) + self.assertEqual(gy_expect, y.grad) + + gradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), + check_forward_ad=check_forward_ad, nondet_tol=1e-5) + if torch.device(device).type != 'cuda': + # https://github.com/pytorch/pytorch/issues/70702 + gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), + check_fwd_over_rev=True) + + @dtypes(torch.float, torch.cfloat) + def test_conv1d_valid_padding_backward(self, device, dtype): + # Test F.conv1d gradients work with padding='valid' + x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True) + y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) + F.conv1d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv1d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + + @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") + @dtypes(torch.float, torch.cfloat) + @parametrize_test("mode", ('valid', 'same')) + def test_conv1d_vs_scipy(self, device, dtype, mode): + t = make_tensor((1, 10), device=device, dtype=dtype) + feat_dim = t.shape[1] + weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype) + weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype) + + def _test(t, weight, mode): + # SciPy expects two 1-D inputs. + t_a = t.view(-1).cpu().numpy() + w_a = weight.view(-1).cpu().numpy() + expected = scipy.signal.convolve(t_a, w_a, mode=mode) + + kwargs = {'padding': mode} + if mode == 'same': + # `same` padding in PyTorch conv1d is different + # from SciPy + p = weight.shape[2] // 2 + t = torch.nn.functional.pad(t, (p, p)) + # We have already taken care of padding + kwargs.pop("padding") + + # second input is flipped in SciPy's convolve + weight_flipped = torch.flip(weight, (2,)) + actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0) + if mode == 'same': + actual = actual[:feat_dim] + + self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5) + + # Global dtype for this test suite is torch.double + # This leads to change in type-promotion + # and conv1d outputs `complex128` for `complex64` input. + with set_default_dtype(torch.float): + _test(t, weight_even, mode) + _test(t, weight_odd, mode) + + @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") + @dtypes(torch.float, torch.cfloat) + @parametrize_test("mode", ('valid', 'same')) + def test_conv2d_vs_scipy(self, device, dtype, mode): + t = make_tensor((1, 5, 10), device=device, dtype=dtype) + weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype) + weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype) + + def _test(t, weight, mode): + # SciPy expects two 2-D inputs. + t_a = t.squeeze(0).cpu().numpy() + w_a = weight.squeeze(0).squeeze(0).cpu().numpy() + expected = scipy.signal.convolve2d(t_a, w_a, mode=mode) + + kwargs = {'padding': mode} + if mode == 'same': + # `same` padding in PyTorch conv2d is different + # from SciPy + left_right_pad = weight.shape[3] // 2 + top_bottom_pad = weight.shape[2] // 2 + p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad) + t = torch.nn.functional.pad(t, p) + # We have already taken care of padding + kwargs.pop("padding") + + # second input is flipped in SciPy's convolve2d + weight_flipped = torch.flip(weight, (2, 3)) + actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0) + if mode == 'same': + actual = actual[:5, :10] + + self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) + + # Global dtype for this test suite is torch.double + # This leads to change in type-promotion + # and conv1d outputs `complex128` for `complex64` input. + with set_default_dtype(torch.float): + _test(t, weight_even, mode) + _test(t, weight_odd, mode) + + @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") + @dtypes(torch.float, torch.cfloat) + @parametrize_test("mode", ('valid', 'same')) + def test_conv3d_vs_scipy(self, device, dtype, mode): + t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype) + weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype) + weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype) + + def _test(t, weight, mode): + # SciPy expects two 3-D inputs. + t_a = t.squeeze(0).cpu().numpy() + w_a = weight.squeeze(0).squeeze(0).cpu().numpy() + expected = scipy.signal.convolve(t_a, w_a, mode=mode) + + kwargs = {'padding': mode} + if mode == 'same': + # `same` padding in PyTorch conv3d is different + # from SciPy + left_right_pad = weight.shape[4] // 2 + top_bottom_pad = weight.shape[3] // 2 + front_back_pad = weight.shape[2] // 2 + p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad, + front_back_pad, front_back_pad) + t = torch.nn.functional.pad(t, p) + # We have already taken care of padding + kwargs.pop("padding") + + # second input is flipped in SciPy's convolve + weight_flipped = torch.flip(weight, (2, 3, 4)) + actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0) + if mode == 'same': + actual = actual[:5, :5, :10] + + if tf32_is_not_fp32() and (dtype == torch.float or dtype == torch.complex64): + self.assertEqual(actual, expected, atol=0.05, rtol=0.05) + else: + self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) + + # Global dtype for this test suite is torch.double + # This leads to change in type-promotion + # and conv1d outputs `complex128` for `complex64` input. + with set_default_dtype(torch.float): + _test(t, weight_even, mode) + _test(t, weight_odd, mode) + + @dtypes(torch.float, torch.complex64) + def test_conv2d_valid_padding_backward(self, device, dtype): + # Test F.conv2d gradients work with padding='valid' + x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True) + y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True) + F.conv2d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv2d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + + @dtypes(torch.double, torch.cdouble) + def test_conv3d_valid_padding_backward(self, device, dtype): + check_forward_ad = torch.device(device).type != 'xla' + + # Test F.conv3d gradients work with padding='valid' + x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True) + y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True) + F.conv3d(x, y, padding=0).sum().backward() + gx_expect, gy_expect = x.grad, y.grad + x.grad, y.grad = None, None + + F.conv3d(x, y, padding='valid').sum().backward() + gx_actual, gy_actual = x.grad, y.grad + self.assertEqual(gx_expect, gx_actual) + self.assertEqual(gy_expect, gy_actual) + + gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad) + gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad) + + @parametrize_test("N", range(2, 4), name_fn=lambda N: 'ConvTranspose{}d'.format(N)) + def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): + # For inputs with no batch dim, verify output is the correct shape when output_size is set. + # See https://github.com/pytorch/pytorch/issues/75889 + inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device) + output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200) + ConvTransposeNd = getattr(nn, 'ConvTranspose{}d'.format(N)) + m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device) + output = m(inp, output_size=output_size) + self.assertEqual(output.shape, output_size) + + @skipMeta + @parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [ + # === slow === + subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d'), + subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_transposed'), + subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated'), + subtest(((2, 6, 7), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated_transposed'), + subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d'), + subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_transposed'), + subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated'), + subtest(((2, 6, 7, 8), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated_transposed'), + subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Slow3d), + decorators=[onlyCPU, disableMkldnn], name='slow3d_cpu'), + # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead + subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), + decorators=[onlyCUDA, disablecuDNN], name='slow3d_cuda'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), + # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), + subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), + decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), + # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), + subtest(((0, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch1d'), + subtest(((2, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel1d'), + subtest(((0, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel1d'), + subtest(((0, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch2d'), + subtest(((2, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel2d'), + subtest(((0, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel2d'), + subtest(((0, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch3d'), + subtest(((2, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel3d'), + subtest(((0, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), + decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel3d'), + # === cuda === + # Note that disablecuDNN disables miopen as well. + subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), + decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise1d'), + subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), + decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise2d'), + subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise3d), + decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise3d'), + # === cudnn === + subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), + decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d'), + subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), + decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d'), + subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), + decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d'), + subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), + decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d_transposed'), + subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), + decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d_transposed'), + # FIXME: RuntimeError: CUDA out of memory. + # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), + # decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), + # === miopen === + subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d'), + subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d'), + subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d'), + subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d_transposed'), + subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d_transposed'), + subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d_transposed'), + subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise1d'), + subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise2d'), + subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), + decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise3d'), + # === mkldnn === + subtest(((2, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d'), + subtest(((2, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d'), + subtest(((2, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d'), + # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775. + subtest(((2, 6, 7), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn1d_transposed'), + subtest(((2, 6, 7, 8), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn2d_transposed'), + subtest(((2, 6, 7, 8, 9), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn3d_transposed'), + subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d_cpu_input'), + subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d_cpu_input'), + subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d_cpu_input'), + subtest(((0, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch1d'), + subtest(((2, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel1d'), + subtest(((0, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel1d'), + subtest(((0, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch2d'), + subtest(((2, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel2d'), + subtest(((0, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel2d'), + subtest(((0, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch3d'), + subtest(((2, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel3d'), + subtest(((0, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), + decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel3d'), + # Note: Tests for mobile backends are not currently supported. This comprises + # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these + # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1. + ]) + # Test with both bias and no bias. + @parametrize_test("has_bias", [False, True]) + # Test with both stride=1 and stride>1 cases. + @parametrize_test("strided", [False, True]) + # Test with both contiguous and non-contiguous inputs. + @parametrize_test("contiguous", [False, True]) + def test_conv_backend( + self, device, input_shape, has_bias, strided, contiguous, transposed, dilated, groups, + layout, backend_expected): + # Build up inputs. + dtype = torch.float32 + C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3 + x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True) + weight = torch.randn(C_in if transposed else C_out, + C_out // groups if transposed else C_in // groups, + *[kernel_size for _ in range(dim)], + device=device, dtype=dtype, requires_grad=True) + bias = torch.randn(C_out, device=device, dtype=dtype, requires_grad=True) if has_bias else None + + def _make_noncontiguous(inp): + if inp is None: + return None + old_requires_grad = inp.requires_grad + inp = torch.repeat_interleave(inp, 2, dim=-1) + inp = inp[..., ::2].detach().requires_grad_(old_requires_grad) + return inp + + if not contiguous: + x = _make_noncontiguous(x) + weight = _make_noncontiguous(weight) + bias = _make_noncontiguous(bias) + + if layout is torch._mkldnn: + x = x.to_mkldnn() + # Note that weight and bias are not supported as mkldnn tensors during training. + + stride = (2,) * dim if strided else (1,) * dim + padding = (0,) * dim + dilation = (2,) * dim if dilated else (1,) * dim + output_padding = (0,) * dim + inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] + + # Ensure correct backend is selected. + backend_actual = torch._C._select_conv_backend(*inputs) + self.assertEqual(backend_actual, backend_expected) + + # Ensure backward call succeeds. + convolution = torch.ops.aten.convolution + output = convolution(*inputs) + grad_output = torch.randn(output.shape, device=device, dtype=dtype) + if not contiguous: + grad_output = _make_noncontiguous(grad_output) + if layout is torch._mkldnn: + grad_output = grad_output.to_mkldnn() + output.backward(grad_output) + + # mkldnn doesn't support gradcheck :( + if layout is torch._mkldnn: + return + + if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails + # Forward AD and forward-over-reverse AD smoke test in float32 + # TODO: remove this if we introduce per-op gradient tests for float32 + with fwAD.dual_level(): + dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i) for i in inputs] + # Forward AD + output = convolution(*dual_inputs) + # Forward over reverse AD + grad_output_d = fwAD.make_dual(torch.rand_like(output), torch.rand_like(output)) + if has_bias: + torch.autograd.grad(output, [x, weight, bias], grad_output_d) + else: + torch.autograd.grad(output, [x, weight], grad_output_d) + + # Convert to float64 for gradcheck. + x = x.to(torch.float64).detach().requires_grad_(True) + weight = weight.to(torch.float64).detach().requires_grad_(True) + if bias is not None: + bias = bias.to(torch.float64).detach().requires_grad_(True) + inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] + + # Set some backend-specific validation settings. + gradcheck_nondet_tol = 0.0 + if torch.backends.cudnn.is_available(): + # cuDNN introduces non-determinism + gradcheck_nondet_tol = GRADCHECK_NONDET_TOL + + self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) + + # double backward doesn't support bias gradients + if bias is not None: + bias.requires_grad_(False) + self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) + + + @onlyCPU + def test_conv_contiguous_for_oneDNN(self): + # See https://github.com/pytorch/pytorch/issues/80837. + for dtype in [torch.float, torch.bfloat16]: + conv = nn.Conv2d( + 1, + 128, + kernel_size=(5, 2), + stride=(2, 1), + padding=(0, 1), + dilation=(1, 1), + groups=1, + bias=True, + padding_mode='zeros').to(dtype=dtype) + + x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype) + x = torch.transpose(x, 1, 4) + x2 = x[..., 0] + inputs = [x2, conv.weight, conv.bias, (2, 1), (0, 1), (1, 1), False, (0, 1), 1] + if torch.backends.mkldnn.is_available(): + y = conv(x2) + # Disable MKLDNN explicitly + with torch.backends.mkldnn.flags(enabled=False): + y_ = conv(x2) + self.assertEqual(y, y_) + + @onlyCPU + def test_conv_ic1_channels_last_for_oneDNN(self): + # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path. + for dtype in [torch.float, torch.bfloat16]: + conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False) + conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype) + x = torch.rand(2, 1, 100, 100).to(dtype=dtype) + if torch.backends.mkldnn.is_available(): + y = conv(x) + # Disable MKLDNN explicitly + with torch.backends.mkldnn.flags(enabled=False): + y_ = conv(x) + self.assertEqual(y, y_) + + @dtypes(torch.float, torch.cfloat) + def test_conv_empty_channel(self, device, dtype): + in_channels = 0 + mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device) + inp = torch.randn(2, 0, 15, device=device, dtype=dtype) + _test_module_empty_input(self, mod, inp, check_size=False) + + with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): + inp = torch.randn(2, 1, 0, device=device, dtype=dtype) + mod(inp) + + mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) + inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype) + _test_module_empty_input(self, mod, inp, check_size=False) + + with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): + inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype) + mod(inp) + + mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) + inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype) + _test_module_empty_input(self, mod, inp, check_size=False) + + with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): + inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype) + mod(inp) + + def test_group_conv_empty(self, device): + mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) + inp = torch.randn(0, 4, 4, 4, device=device) + _test_module_empty_input(self, mod, inp, check_size=False) + if self.device_type == 'cuda' and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + _test_module_empty_input(self, mod, inp, check_size=False) + + def test_group_convTranspose_empty(self, device): + mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) + inp = torch.randn(0, 4, 4, 4, device=device) + _test_module_empty_input(self, mod, inp, check_size=False) + if self.device_type == 'cuda' and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + _test_module_empty_input(self, mod, inp, check_size=False) + + def test_convTranspose_empty(self, device): + mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(device) + inp = torch.randn(0, 4, 4, 4, device=device) + _test_module_empty_input(self, mod, inp, check_size=False) + if self.device_type == 'cuda' and self.has_cudnn(): + with torch.backends.cudnn.flags(enabled=False): + _test_module_empty_input(self, mod, inp, check_size=False) + + @onlyCUDA + @largeTensorTest('12GB') + def test_conv_large_nosplit(self, device): + # Here we just test the convolution correctly route to the fallback implementation + # that is, it does not crash. The correctness of fallback implementation should be + # covered in other tests + dtype = torch.half if self.device_type == 'cuda' else torch.float + conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype) + input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device) + conv1(input_large) + conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype) + input_large = torch.randn(1, 1, 2048, 1024 , dtype=dtype, device=device) + conv2(input_large) + + def test_conv_noncontig_weights(self, device): + for dim in (1, 2, 3): + for grouped in (False, True): + nc = 3 + groups = 3 if grouped else 1 + w = torch.randn([3] * dim, device=device) + w = w.expand([nc, int(nc / groups)] + list(w.shape)) + w = w.detach().requires_grad_() + x = torch.randn([1, nc] + ([5] * dim), device=device, requires_grad=True) + y = getattr(F, 'conv{}d'.format(dim))(x, w, groups=groups) + y.sum().backward() + y = getattr(F, 'conv_transpose{}d'.format(dim))(x, w, groups=groups) + y.sum().backward() + + def test_conv_noncontig_weights_and_bias(self, device): + # need floats to exercise https://github.com/pytorch/pytorch/issues/16018 + for bias in [True, False]: + conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=bias).to(device, torch.float) + + input_nc = torch.randn((1, 3, 224, 224, 2), device=device, dtype=torch.float)[:, :, :, :, 1] + input_c = input_nc.contiguous() + + weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[:, :, :, :, 1] + conv1.weight = nn.Parameter(weight_nc) + weight_c = conv1.weight.contiguous() + + if bias: + bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1] + conv1.bias = nn.Parameter(bias_nc) + bias_c = conv1.bias.contiguous() + + out1 = conv1(input_nc) + conv1.weight = nn.Parameter(weight_c) + if bias: + conv1.bias = nn.Parameter(bias_c) + out2 = conv1(input_c) + self.assertEqual(out1, out2) + + @onlyCUDA + @largeTensorTest('12GB') + def test_conv_transposed_large(self, device): + dtype = torch.half if self.device_type == 'cuda' else torch.float + conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) + input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device) + # forward + ret = conv(input_large) + maxdiff0 = (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024))).abs_().max().item() + maxdiff1 = (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024))).abs_().max().item() + maxdiff2 = (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024))).abs_().max().item() + maxdiff3 = (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024))).abs_().max().item() + if self.device_type == 'cuda': + # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0 + self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5) + self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5) + self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5) + self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5) + else: + self.assertEqual(maxdiff0, 0) + self.assertEqual(maxdiff1, 0) + self.assertEqual(maxdiff2, 0) + self.assertEqual(maxdiff3, 0) + + @onlyCUDA + @skipCUDAIfRocm + @largeTensorTest('12GB') + def test_conv_large(self, device): + dtype = torch.half if self.device_type == 'cuda' else torch.float + conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype) + input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device) + # forward + ret = conv(input_large) + self.assertEqual(ret[:2048], conv(input_large[:2048])) + self.assertEqual(ret[2048:4096], conv(input_large[2048:4096])) + self.assertEqual(ret[4096:], conv(input_large[4096:])) + + # backward + conv.zero_grad() + # When computing the backward, we are using the `max(dim=1)`` to create + # some sparsity. Without this sparsity, the rounding error would be + # too large (as large as 1e-5) to satisfy the creterion (1e-6) of `assertEqual` + ret.view(4097, -1).max(dim=1).values.sum().backward() + del ret + grad1 = conv.weight.grad.detach().clone() + conv.zero_grad() + conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward() + conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward() + conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward() + grad2 = conv.weight.grad.detach().clone() + # gradients are at the order of hundreds, we need to scale it to + # the order of one so that we can compare + scale = 1 / grad2.abs().mean() + grad1 = grad1 * scale + grad2 = grad2 * scale + self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) + + @onlyCUDA + @skipCUDAIfNoCudnn + def test_contig_wrong_stride_cudnn(self, device): + # x has to have batch_size 1 to test contiguous checks + x = torch.randn(1, 16, 5, 5, device=device) + stride = list(x.stride()) + stride[0] = 20 + # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 + x.set_(x.storage(), 0, x.size(), stride) + self.assertTrue(x.is_contiguous()) + F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device)) + F.conv2d(x, torch.randn(1, 16, 1, 1, device=device)) + + @onlyCUDA + def test_Conv2d_size_1_kernel(self, device): + x_cpu = torch.randn(2, 3, 5, 5) + conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1) + y_cpu = conv_cpu(x_cpu) + y = torch.rand_like(y_cpu) + y_cpu.backward(y) + + with cudnn.flags(enabled=False): + conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device) + conv_cuda.bias.data.copy_(conv_cpu.bias.data) + conv_cuda.weight.data.copy_(conv_cpu.weight.data) + y_cuda = conv_cuda(x_cpu.to(device)) + y_cuda.backward(y.to(device)) + + self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) + + @onlyCUDA + def test_ConvTranspose2d_size_1_kernel(self, device): + x_cpu = torch.randn(2, 3, 5, 5) + conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1) + y_cpu = conv_cpu(x_cpu) + y = torch.rand_like(y_cpu) + y_cpu.backward(y) + + with cudnn.flags(enabled=False): + conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device) + conv_cuda.bias.data.copy_(conv_cpu.bias.data) + conv_cuda.weight.data.copy_(conv_cpu.weight.data) + y_cuda = conv_cuda(x_cpu.to(device)) + y_cuda.backward(y.to(device)) + + self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) + + @onlyCUDA + def test_ConvTranspose3d_size_1_kernel(self, device): + with set_default_dtype(torch.double): + x_cpu = torch.randn(2, 3, 3, 5, 5) + conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1) + y_cpu = conv_cpu(x_cpu) + y = torch.rand_like(y_cpu) + y_cpu.backward(y) + + with cudnn.flags(enabled=False): + conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device) + conv_cuda.bias.data.copy_(conv_cpu.bias.data) + conv_cuda.weight.data.copy_(conv_cpu.weight.data) + y_cuda = conv_cuda(x_cpu.to(device)) + y_cuda.backward(y.to(device)) + + self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) + self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) + + @dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) + @dtypes(torch.float) + @torch.backends.cudnn.flags(enabled=True, benchmark=False) + def test_Conv2d_naive_groups(self, device, dtype): + # Check that grouped convolutions matches two half convolutions + m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) + i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) + output = m(i) + grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) + output.backward(grad_output) + + m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) + m1.weight.data.copy_(m.weight.data[:2]) + m1.bias.data.copy_(m.bias.data[:2]) + i1 = i.data[:, :2].contiguous().requires_grad_(True) + output1 = m1(i1) + output1.backward(grad_output[:, :2].contiguous()) + + m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) + m2.weight.data.copy_(m.weight.data[2:]) + m2.bias.data.copy_(m.bias.data[2:]) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) + output2 = m2(i2) + output2.backward(grad_output[:, 2:].contiguous()) + + self.assertEqual(output, torch.cat([output1, output2], 1)) + self.assertEqual(i.grad.data, + torch.cat([i1.grad.data, i2.grad.data], 1), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.bias.grad.data, + torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(m.weight.grad.data, + torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), + atol=dtype2prec_DONTUSE[dtype], rtol=0) + + @dtypes(torch.double, torch.cdouble) + def test_Conv2d_backward_depthwise(self, device, dtype): + x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True) + weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True) + + def conv2d_depthwise(x, weight): + return torch.nn.functional.conv2d( + x, weight, bias=None, stride=(1, 10), groups=2) + + for cudnn_enabled in [False, True]: + with torch.backends.cudnn.flags(enabled=cudnn_enabled): + torch.autograd.gradcheck(conv2d_depthwise, (x, weight)) + + @onlyCPU + @dtypes(torch.float, torch.double) + def test_conv_thnn_nhwc(self, device, dtype): + def helper(n, c, h, w, out_channels, kernel_size, dilation, groups, input_format, weight_format): + input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ + .to(memory_format=input_format) + input.requires_grad_() + conv = nn.Conv2d(c, out_channels, kernel_size, dilation=dilation, groups=groups)\ + .to(device='cpu', dtype=dtype, memory_format=weight_format) + for p in conv.parameters(): + p.data = torch.randint_like(p, -3, 3) + + ref_input = input.detach().clone().contiguous().requires_grad_() + ref_conv = nn.Conv2d(c, out_channels, kernel_size, dilation=dilation, groups=groups) + # load_state_dict will restore the stride & memory_layout on ref_conv.weight. + ref_conv.load_state_dict(conv.state_dict()) + ref_conv = ref_conv.to(device='cpu', dtype=dtype, memory_format=torch.contiguous_format) + + out = conv(input) + ref_out = ref_conv(ref_input) + + grad = torch.randint_like(out, -3, 3) + ref_grad = grad.detach().clone().contiguous() + + out.backward(grad) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(ref_out.is_contiguous()) + self.assertEqual(out, ref_out, exact_dtype=False) + self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) + self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) + self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) + + with torch.backends.mkldnn.flags(enabled=False): + formats = [[torch.channels_last, torch.channels_last], + [torch.channels_last, torch.contiguous_format], + [torch.contiguous_format, torch.channels_last]] + for input_format, weight_format in formats: + # non-dilated conv: thnn_conv2d normal path (with im2col) + helper(2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, + input_format=input_format, weight_format=weight_format) + helper(2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, + input_format=input_format, weight_format=weight_format) + # test when input chanels is 1 and not converted to channels last + helper(2, 1, 10, 10, out_channels=8, kernel_size=3, dilation=1, groups=1, + input_format=torch.contiguous_format, weight_format=torch.channels_last) + # non-dilated conv: thnn_conv2d fast path (skip im2col) + helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, + input_format=input_format, weight_format=weight_format) + # ic == oc == 1 here, so need to stick input to CL to activate channels last + helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=16, + input_format=torch.channels_last, weight_format=weight_format) + # dilated conv: slow_conv_dilated2d + helper(2, 8, 11, 13, out_channels=16, kernel_size=3, dilation=2, groups=1, + input_format=input_format, weight_format=weight_format) + helper(2, 16, 11, 13, out_channels=32, kernel_size=3, dilation=2, groups=16, + input_format=input_format, weight_format=weight_format) + + @onlyCUDA + @skipCUDAIfRocmVersionLessThan((4, 3)) + @skipCUDAIfNotMiopenSuggestNHWC + @skipCUDAIfCudnnVersionLessThan(7603) + @dtypes(torch.half, torch.float, torch.cfloat) + def test_conv_cudnn_nhwc(self, device, dtype): + def helper(n, c, h, w, out_channels, kernel_size, groups): + input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ + .to(memory_format=torch.channels_last) + input.requires_grad_() + conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)\ + .to(device='cuda', dtype=dtype, memory_format=torch.channels_last) + for p in conv.parameters(): + p.data = torch.randint_like(p, -3, 3) + + # use FP64 channels-first conv as reference + ref_input = input.detach().clone().contiguous().double().requires_grad_() + ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups) + # load_state_dict will restore the stride & memory_layout on ref_conv.weight. + ref_conv.load_state_dict(conv.state_dict()) + ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) + + out = conv(input) + ref_out = ref_conv(ref_input) + + grad = torch.randint_like(out, -3, 3) + ref_grad = grad.detach().clone().double().contiguous() + + out.backward(grad) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last)) + + self.assertTrue(ref_out.is_contiguous()) + self.assertTrue(ref_input.grad.is_contiguous()) + self.assertTrue(ref_conv.weight.grad.is_contiguous()) + + self.assertEqual(out, ref_out, exact_dtype=False) + self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) + self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) + self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) + + helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1) + helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8) + helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1) + helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) + + @onlyCUDA + @skipCUDAIfRocm + @skipCUDAIfCudnnVersionLessThan(8005) + @dtypes(torch.half, torch.float) + def test_conv_cudnn_ndhwc(self, device, dtype): + def helper(n, c, d, h, w, out_channels, kernel_size, groups): + input = torch.randint(-2, 2, (n, c, d, h, w), dtype=dtype, device=device)\ + .to(memory_format=torch.channels_last_3d) + input.requires_grad_() + conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)\ + .to(device='cuda', dtype=dtype, memory_format=torch.channels_last_3d) + for p in conv.parameters(): + p.data = torch.randint_like(p, -2, 2) + + # use FP64 channels-first conv as reference + ref_input = input.detach().clone().contiguous().double().requires_grad_() + ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups) + # load_state_dict will restore the stride & memory_layout on ref_conv.weight. + ref_conv.load_state_dict(conv.state_dict()) + ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) + + out = conv(input) + ref_out = ref_conv(ref_input) + + grad = torch.randint_like(out, -2, 2) + ref_grad = grad.detach().clone().double().contiguous() + + out.backward(grad) + ref_out.backward(ref_grad) + + self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d)) + self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last_3d)) + self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)) + + self.assertTrue(ref_out.is_contiguous()) + self.assertTrue(ref_input.grad.is_contiguous()) + self.assertTrue(ref_conv.weight.grad.is_contiguous()) + + self.assertEqual(out, ref_out, exact_dtype=False) + self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) + self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) + self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) + + helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1) + helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8) + helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1) + helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16) + + def _run_conv(self, layer, device, inp, grad, ref_conv, ref_input, ref_out, + input_format, weight_format, grad_format, output_format): + conv = layer(inp.size(1), grad.size(1), + ref_conv.weight.size(2)).float().to(device) + # load_state_dict will restore the stride & memory_layout on ref_conv.weight. + conv.load_state_dict(ref_conv.state_dict()) + weight_data = conv.weight.detach().clone().contiguous(memory_format=weight_format) + conv.weight.data = weight_data.resize_(weight_data.size(), memory_format=weight_format) + input = inp.clone().contiguous(memory_format=input_format) + input.resize_(input.size(), memory_format=input_format) + input = input.requires_grad_() + grad = grad.contiguous(memory_format=grad_format) + grad.resize_(grad.size(), memory_format=grad_format) + out = conv(input) + out.backward(grad) + self.assertTrue(out.is_contiguous(memory_format=output_format)) + self.assertEqual(out, ref_out) + self.assertEqual(conv.weight.grad, ref_conv.weight.grad) + self.assertEqual(conv.bias.grad, ref_conv.bias.grad) + self.assertEqual(input.grad, ref_input.grad) + + def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): + data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device) + ref_input = data.clone().contiguous().requires_grad_(True) + ref_conv = layer(c, k, filter_size).float().to(device) + ref_out = ref_conv(ref_input) + grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda") + ref_out.backward(grad) + + for w_f in [torch.contiguous_format, torch.channels_last]: + for g_f in [torch.contiguous_format, torch.channels_last]: + for input_format in [torch.contiguous_format, torch.channels_last]: + output_format = torch.contiguous_format + # Older versions of CudNN have Channels Last support disabled + if torch.backends.cudnn.version() >= 7603: + if input_format == torch.channels_last: + output_format = torch.channels_last + # This is because we have N111 weight that cannot handle + # the ambiguous memory_format + if w_f == torch.channels_last: + if layer == nn.Conv2d and filter_size * c != 1: + output_format = torch.channels_last + if layer == nn.ConvTranspose2d and filter_size * k != 1: + output_format = torch.channels_last + self._run_conv(layer, device, data, grad, ref_conv, ref_input, + ref_out, input_format, w_f, g_f, output_format) + + @onlyCUDA + @skipCUDAIfRocmVersionLessThan((4, 3)) + @skipCUDAIfNotMiopenSuggestNHWC + @skipCUDAIfCudnnVersionLessThan(7603) + @tf32_on_and_off(0.05) + def test_conv_cudnn_mismatch_memory_format(self, device): + configs = [ + [4, 2, 8, 8, 4, 2], + [4, 1, 8, 8, 4, 2], + [1, 1, 8, 8, 4, 2], + [4, 2, 2, 8, 4, 1], + [4, 2, 1, 8, 4, 1], + [4, 2, 8, 8, 4, 1], + [4, 1, 8, 8, 4, 1], + ] + for n, c, h, w, k, filter_size in configs: + self._test_conv_cudnn_nhwc_nchw(nn.Conv2d, n, c, h, w, k, filter_size, device) + self._test_conv_cudnn_nhwc_nchw(nn.ConvTranspose2d, n, c, h, w, k, filter_size, device) + + # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4 + # returning CUDNN_STATUS_BAD_PARAM + # Disabling that specific test for now [see issue # 33918] + @onlyCUDA + @skipCUDAIfNoCudnn + @dtypes(torch.float, torch.double) + def test_conv_cudnn_nhwc_support(self, device, dtype): + input = torch.randn((1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True) + weight = torch.randn((8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True) + weight = weight.to(memory_format=torch.channels_last) + o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) + self.assertTrue(o.is_contiguous(memory_format=torch.channels_last)) + o.sum().backward() + + # Test that faster algorithms used for inference produce the same results + # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176 + @onlyCPU + @dtypes(torch.float) + def test_conv2d_no_grad(self, device, dtype): + for batch in [1, 2, 3]: + for groups in [1, 2, 4]: + input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device) + m = nn.Conv2d(groups, 8, kernel_size=(3, 3), groups=groups, dtype=dtype, device=device) + with torch.no_grad(): + output_ng = m(input) + output = m(input) + self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5) + + @onlyCUDA + @skipCUDAIfNoCudnn + @dtypes(torch.float, torch.float16) + @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) + def test_cudnn_convolution_relu(self, device, dtype): + for batch, groups, image_size, kernel_size, memory_format in \ + product((1, 2, 3), + (1, 2, 4), + ((1, 1), (8, 8)), + ((1, 1), (3, 3)), + (torch.channels_last, torch.contiguous_format)): + if image_size[0] < kernel_size[0]: + continue + inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) + w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) + conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) + inp = inp.to(memory_format=memory_format) + w = w.to(memory_format=memory_format) + if torch.version.hip: + cudnn_out = torch.miopen_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) + else: + cudnn_out = torch.cudnn_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) + self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) + if tf32_is_not_fp32() and dtype == torch.float: + self.assertEqual(conv2d_out.relu(), cudnn_out, atol=2e-4, rtol=0.006) + else: + self.assertEqual(conv2d_out.relu(), cudnn_out) + + @onlyCUDA + @skipCUDAIfNoCudnn + @dtypes(torch.float, torch.float16) + @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) + def test_cudnn_convolution_add_relu(self, device, dtype): + for batch, groups, image_size, kernel_size, memory_format in \ + product((1, 2, 3), + (1, 2, 4), + ((1, 1), (8, 8)), + ((1, 1), (3, 3)), + (torch.channels_last, torch.contiguous_format)): + if image_size[0] < kernel_size[0]: + continue + inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) + w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) + conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) + alpha = 2.0 + z = torch.randn_like(conv2d_out) + + inp = inp.to(memory_format=memory_format) + w = w.to(memory_format=memory_format) + z = z.to(memory_format=memory_format) + if torch.version.hip: + cudnn_out = torch.miopen_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) + else: + cudnn_out = torch.cudnn_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) + + self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) + if tf32_is_not_fp32() and dtype == torch.float: + self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out, atol=3e-4, rtol=0.006) + else: + self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) + + @onlyCUDA + @skipCUDAIfRocm + @skipCUDAIfCudnnVersionLessThan(7603) + def test_convert_conv2d_weight_memory_format(self, device): + input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) + model = nn.Sequential( + nn.Conv2d(8, 4, 3), + nn.BatchNorm2d(4)).to(device).float() + for memory_format in [torch.channels_last, torch.contiguous_format]: + model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) + out = model(input) + self.assertTrue(out.is_contiguous(memory_format=memory_format)) + + model = nn.Sequential( + nn.ConvTranspose2d(8, 4, 3), + nn.BatchNorm2d(4)).to(device).float() + for memory_format in [torch.channels_last, torch.contiguous_format]: + model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) + out = model(input) + self.assertTrue(out.is_contiguous(memory_format=memory_format)) + + def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): + # Test that _convolution_double_backward() outputs the correct grad shapes + # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a + # specific case that was uncovered during the convolution consolidation effort. + # The test can be safely deleted if _convolution_double_backward() is removed. + + input = torch.randn(2, 3, 6, device=device) + weight = torch.randn(3, 3, 3, device=device) + bias = torch.randn(3, device=device) + stride = (2,) + padding = (1,) + dilation = (1,) + transposed = False + output_padding = (0,) + groups = 1 + output = torch.ops.aten.convolution(input, weight, bias, stride, padding, dilation, transposed, + output_padding, groups) + + ggI = torch.randn(input.shape, device=device) + ggW = torch.randn(weight.shape, device=device) + ggB = torch.randn(bias.shape, device=device) + gO = torch.randn(output.shape, device=device) + output_mask = [True, True, True] + grad_grad_output, grad_input, grad_weight = torch.ops.aten._convolution_double_backward( + ggI, ggW, ggB, gO, weight, input, stride, padding, dilation, transposed, + output_padding, groups, output_mask) + + # Make sure the correct shapes are computed. + self.assertEqual(grad_grad_output.shape, gO.shape) + self.assertEqual(grad_input.shape, input.shape) + self.assertEqual(grad_weight.shape, weight.shape) + + @onlyCUDA + @largeTensorTest('40GB') + @largeTensorTest('24GB', 'cpu') + def test_conv3d_64bit_indexing(self, device): + x = torch.rand(1, 32, 512, 512, 256) + m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False) + yref = m(x) + y = m.to(device=device)(x.to(device=device)) + self.assertEqual(yref, y) + +instantiate_device_type_tests(TestConvolutionNNDeviceType, globals()) +instantiate_parametrized_tests(TestConvolutionNN) + +if __name__ == '__main__': + run_tests() diff --git a/test/nn/test_dropout.py b/test/nn/test_dropout.py index fa2b0baea5549..150e5f57df7c6 100644 --- a/test/nn/test_dropout.py +++ b/test/nn/test_dropout.py @@ -15,6 +15,9 @@ import torch.nn as nn class TestDropoutNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + def _test_alpha_dropout(self, cls, input): mean = input.mean() std = input.std() diff --git a/test/nn/test_init.py b/test/nn/test_init.py new file mode 100644 index 0000000000000..9e72c1040a55a --- /dev/null +++ b/test/nn/test_init.py @@ -0,0 +1,420 @@ +# Owner(s): ["module: nn"] +import random +import unittest +import math +import string +from functools import reduce +from operator import mul + +from torch.testing._internal.common_utils import TestCase, TEST_SCIPY, skipIfNoLapack +import torch +import torch.nn.init as init +import torch.nn.functional as F + +if TEST_SCIPY: + from scipy import stats + +class TestNNInit(TestCase): + def setUp(self): + super(TestNNInit, self).setUp() + random.seed(123) + + def _is_normal(self, tensor, mean, std): + samples = tensor.view(-1).tolist() + p_value = stats.kstest(samples, 'norm', args=(mean, std))[1] + return p_value > 0.0001 + + def _is_trunc_normal(self, tensor, mean, std, a, b): + # scipy's trunc norm is suited for data drawn from N(0, 1), + # so we need to transform our data to test it using scipy. + z_samples = (tensor.view(-1) - mean) / std + z_samples = z_samples.tolist() + a0 = (a - mean) / std + b0 = (b - mean) / std + p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1] + return p_value > 0.0001 + + def _is_uniform(self, tensor, a, b): + samples = tensor.view(-1).tolist() + p_value = stats.kstest(samples, 'uniform', args=(a, (b - a)))[1] + return p_value > 0.0001 + + def _create_random_nd_tensor(self, dims, size_min, size_max): + size = [random.randint(size_min, size_max) for _ in range(dims)] + tensor = torch.zeros(size) + return tensor + + def _random_float(self, a, b): + return (b - a) * random.random() + a + + def test_calculate_gain_linear(self): + for fn in ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose2d', 'conv_transpose2d', 'conv_transpose3d']: + gain = init.calculate_gain(fn) + self.assertEqual(gain, 1) + + def test_calculate_gain_nonlinear(self): + for fn in ['sigmoid', 'tanh', 'relu', 'leaky_relu']: + gain = init.calculate_gain(fn) + if fn == 'sigmoid': + self.assertEqual(gain, 1) + elif fn == 'tanh': # 5 / 3 + self.assertEqual(gain, 1.6666666666666667) + elif fn == 'relu': # sqrt(2) + self.assertEqual(gain, 1.4142135623730951) + elif fn == 'leaky_relu': # sqrt(2 / 1 + slope^2)) + self.assertEqual(gain, 1.4141428569978354) + elif fn == 'selu': + self.assertEqual(gain, 0.75) + + def test_calculate_gain_leaky_relu(self): + for param in [None, 0, 0.01, 10]: + gain = init.calculate_gain('leaky_relu', param) + if param is None: # Default slope is 0.01 + self.assertEqual(gain, 1.4141428569978354) + elif param == 0: # No slope = same gain as normal ReLU + self.assertEqual(gain, 1.4142135623730951) + elif param == 0.01: + self.assertEqual(gain, 1.4141428569978354) + elif param == 10: + self.assertEqual(gain, 0.14071950894605836) + + def test_calculate_gain_leaky_relu_only_accepts_numbers(self): + for param in [True, [1], {'a': 'b'}]: + with self.assertRaises(ValueError): + init.calculate_gain('leaky_relu', param) + + def test_calculate_gain_only_accepts_valid_nonlinearities(self): + for n in [2, 5, 25]: + # Generate random strings of lengths that definitely aren't supported + random_string = ''.join([random.choice(string.ascii_lowercase) for i in range(n)]) + with self.assertRaises(ValueError): + init.calculate_gain(random_string) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_uniform(self): + for dims in [1, 2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) + a = self._random_float(-3, 3) + b = a + self._random_float(1, 5) + init.uniform_(input_tensor, a=a, b=b) + assert self._is_uniform(input_tensor, a, b) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_normal(self): + for dims in [1, 2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) + mean = self._random_float(-3, 3) + std = self._random_float(1, 5) + init.normal_(input_tensor, mean=mean, std=std) + + assert self._is_normal(input_tensor, mean, std) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_trunc_normal(self): + for dims in [1, 2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) + mean = self._random_float(-3, 3) + std = self._random_float(.01, 1) + a = self._random_float(mean - 2 * std, mean) + b = self._random_float(mean, mean + 2 * std) + init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b) + + assert self._is_trunc_normal(input_tensor, mean, std, a, b) + + def test_constant(self): + for dims in [1, 2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) + val = self._random_float(1, 10) + init.constant_(input_tensor, val) + + self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) + + def test_ones_and_zeros(self): + for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]): + for dims in [1, 2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) + init_fn_(input_tensor) + + self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) + + def test_eye(self): + input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5) + init.eye_(input_tensor) + + # Check every single element + for i in range(input_tensor.size(0)): + for j in range(input_tensor.size(1)): + if i == j: + assert input_tensor[i][j] == 1 + else: + assert input_tensor[i][j] == 0 + + def test_eye_only_works_on_2d_inputs(self): + for dims in [1, 3]: + with self.assertRaises(ValueError): + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) + init.eye_(tensor) + + def test_dirac_properties(self): + for dims in [3, 4, 5]: + for groups in [1, 2, 3]: + # prepare random tensor with random sizes, but fits groups + a, c, d, e = (random.randint(1, 5) for _ in range(4)) + b = random.randint(1, 5 * groups) # same range as a*groups but all range allowed + # make sure first dim divides by groups + input_tensor = torch.randn((a * groups, b, c, d, e)[:dims]) + + init.dirac_(input_tensor, groups) + + c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1) + min_d = min(c_out, c_in) + # Check number of nonzeros is equivalent to smallest dim (for each group) + assert torch.nonzero(input_tensor).size(0) == min_d * groups + # Check sum of values (can have precision issues, hence assertEqual) is also equivalent + self.assertEqual(input_tensor.sum(), min_d * groups) + + + def test_dirac_identity(self): + for groups in [1, 3]: + batch, in_c, out_c, size, kernel_size = 8, 3, 9, 5, 3 # in_c, out_c must divide by groups + eff_out_c = out_c // groups + + # Test 1D + input_var = torch.randn(batch, in_c, size) + filter_var = torch.zeros(eff_out_c, in_c, kernel_size) + filter_var = torch.cat([filter_var] * groups) + init.dirac_(filter_var, groups) + output_var = F.conv1d(input_var, filter_var) + input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero + for g in range(groups): + # Assert in_c outputs are preserved (per each group) + self.assertEqual(input_tensor[:, :, 1:-1], + output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :]) + # Assert extra outputs are 0 + assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :]).numel() == 0 + + # Test 2D + input_var = torch.randn(batch, in_c, size, size) + filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size) + filter_var = torch.cat([filter_var] * groups) + init.dirac_(filter_var, groups) + output_var = F.conv2d(input_var, filter_var) + input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero + for g in range(groups): + # Assert in_c outputs are preserved (per each group) + self.assertEqual(input_tensor[:, :, 1:-1, 1:-1], + output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :, :]) + # Assert extra outputs are 0 + assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :, :]).numel() == 0 + + # Test 3D + input_var = torch.randn(batch, in_c, size, size, size) + filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size, kernel_size) + filter_var = torch.cat([filter_var] * groups) + init.dirac_(filter_var, groups) + output_var = F.conv3d(input_var, filter_var) + input_tensor, output_tensor = input_var.data, output_var.data + for g in range(groups): + # Assert in_c outputs are preserved (per each group) + self.assertEqual(input_tensor[:, :, 1:-1, 1:-1, 1:-1], + output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :, :, :]) + # Assert extra outputs are 0 + assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :, :, :]).numel() == 0 + + def test_dirac_only_works_on_3_4_5d_inputs(self): + for dims in [1, 2, 6]: + with self.assertRaises(ValueError): + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) + init.dirac_(tensor) + + def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self): + for dims in [0, 1]: + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) + with self.assertRaises(ValueError): + init.xavier_uniform_(tensor) + + def test_xavier_normal_errors_on_inputs_smaller_than_2d(self): + for dims in [0, 1]: + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) + with self.assertRaises(ValueError): + init.xavier_normal_(tensor) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_xavier_uniform(self): + for use_gain in [True, False]: + for dims in [2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) + gain = 1 + + if use_gain: + gain = self._random_float(0.1, 2) + init.xavier_uniform_(input_tensor, gain=gain) + else: + init.xavier_uniform_(input_tensor) + + fan_in = input_tensor.size(1) + fan_out = input_tensor.size(0) + if input_tensor.dim() > 2: + fan_in *= input_tensor[0, 0].numel() + fan_out *= input_tensor[0, 0].numel() + + expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) + bounds = expected_std * math.sqrt(3) + assert self._is_uniform(input_tensor, -bounds, bounds) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_xavier_normal(self): + for use_gain in [True, False]: + for dims in [2, 4]: + input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) + gain = 1 + + if use_gain: + gain = self._random_float(0.1, 2) + init.xavier_normal_(input_tensor, gain=gain) + else: + init.xavier_normal_(input_tensor) + + fan_in = input_tensor.size(1) + fan_out = input_tensor.size(0) + if input_tensor.dim() > 2: + fan_in *= input_tensor[0, 0].numel() + fan_out *= input_tensor[0, 0].numel() + + expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) + assert self._is_normal(input_tensor, 0, expected_std) + + def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self): + for dims in [0, 1]: + with self.assertRaises(ValueError): + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) + init.kaiming_uniform_(tensor) + + def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self): + for dims in [0, 1]: + with self.assertRaises(ValueError): + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) + init.kaiming_normal_(tensor) + + def test_kaiming_uniform_warning_on_0element_tensor(self): + tensor = torch.empty(0, 1) + with self.assertWarnsRegex(UserWarning, "Initializing zero-element tensors is a no-op"): + _ = init.kaiming_uniform_(tensor) + + def test_kaiming_normal_warning_on_0element_tensor(self): + tensor = torch.empty(0, 1) + with self.assertWarnsRegex(UserWarning, "Initializing zero-element tensors is a no-op"): + _ = init.kaiming_normal_(tensor) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_kaiming_uniform(self): + for use_a in [True, False]: + for dims in [2, 4]: + for mode in ['fan_in', 'fan_out']: + input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) + if use_a: + a = self._random_float(0.1, 2) + init.kaiming_uniform_(input_tensor, a=a, mode=mode) + else: + a = 0 + init.kaiming_uniform_(input_tensor, mode=mode) + + fan_in = input_tensor.size(1) + fan_out = input_tensor.size(0) + if input_tensor.dim() > 2: + fan_in *= input_tensor[0, 0].numel() + fan_out *= input_tensor[0, 0].numel() + + if mode == 'fan_in': + n = fan_in + else: + n = fan_out + + expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) + bounds = expected_std * math.sqrt(3.0) + assert self._is_uniform(input_tensor, -bounds, bounds) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_kaiming_normal(self): + for use_a in [True, False]: + for dims in [2, 4]: + for mode in ['fan_in', 'fan_out']: + input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) + if use_a: + a = self._random_float(0.1, 2) + init.kaiming_normal_(input_tensor, a=a, mode=mode) + else: + a = 0 + init.kaiming_normal_(input_tensor, mode=mode) + + fan_in = input_tensor.size(1) + fan_out = input_tensor.size(0) + if input_tensor.dim() > 2: + fan_in *= input_tensor[0, 0].numel() + fan_out *= input_tensor[0, 0].numel() + + if mode == 'fan_in': + n = fan_in + else: + n = fan_out + + expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) + assert self._is_normal(input_tensor, 0, expected_std) + + def test_sparse_only_works_on_2d_inputs(self): + for dims in [1, 3]: + with self.assertRaises(ValueError): + sparsity = self._random_float(0.1, 0.9) + tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) + init.sparse_(tensor, sparsity) + + @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") + def test_sparse_default_std(self): + for use_random_std in [True, False]: + input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35) + rows, cols = input_tensor.size(0), input_tensor.size(1) + sparsity = self._random_float(0.1, 0.2) + + std = 0.01 # default std + if use_random_std: + std = self._random_float(0.01, 0.2) + init.sparse_(input_tensor, sparsity=sparsity, std=std) + else: + init.sparse_(input_tensor, sparsity=sparsity) + + for col_idx in range(input_tensor.size(1)): + column = input_tensor[:, col_idx] + assert column[column == 0].nelement() >= math.ceil(sparsity * rows) + + assert self._is_normal(input_tensor[input_tensor != 0], 0, std) + + @skipIfNoLapack + def test_orthogonal(self): + for use_gain in [True, False]: + for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]: + input_tensor = torch.zeros(tensor_size) + gain = 1.0 + + if use_gain: + gain = self._random_float(0.1, 2) + init.orthogonal_(input_tensor, gain=gain) + else: + init.orthogonal_(input_tensor) + + rows, cols = tensor_size[0], reduce(mul, tensor_size[1:]) + flattened_tensor = input_tensor.view(rows, cols) + if rows > cols: + self.assertEqual(torch.mm(flattened_tensor.t(), flattened_tensor), + torch.eye(cols) * gain ** 2, atol=1e-6, rtol=0) + else: + self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()), + torch.eye(rows) * gain ** 2, atol=1e-6, rtol=0) + + def test_deprecation(self): + x = torch.randn(3, 3) + + def fn(): + init.normal(x) + + with self.assertWarnsRegex(UserWarning, 'deprecated', msg='methods not suffixed with underscore should be deprecated'): + fn() diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py new file mode 100644 index 0000000000000..889966e006c1a --- /dev/null +++ b/test/nn/test_module_hooks.py @@ -0,0 +1,1334 @@ +# Owner(s): ["module: nn"] +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + skipIfTorchDynamo, + IS_WINDOWS +) +from torch.testing._internal.common_nn import NNTestCase, _create_basic_net + +import torch +import torch.nn as nn + +from functools import partial +from typing import Any, Dict, List, Tuple +import gc +import unittest +from copy import deepcopy +from tempfile import NamedTemporaryFile +import weakref +import pickle +from collections import OrderedDict +import math + + +class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.seq1 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) + self.seq2 = nn.Sequential(*[nn.Linear(10, 10) for _ in range(2)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.seq2(self.seq1(x)) + + +class ToyModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = Net() + self.net2 = Net() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net2(self.net1(x)) + + +def forward_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + inp: Tuple[torch.Tensor], + out: torch.Tensor, +) -> None: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(inp), 1) + + +def forward_pre_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + inp: Tuple[torch.Tensor], +) -> None: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(inp), 1) + + +def full_backward_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + grad_input: Tuple[torch.Tensor], + grad_output: Tuple[torch.Tensor], +) -> None: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(grad_input), 1) + self.assertEqual(len(grad_output), 1) + + +def full_backward_pre_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + grad_input: Tuple[torch.Tensor], +) -> None: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(grad_input), 1) + + +class KwargModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.net1 = Net() + self.net2 = Net() + + def forward( + self, x: torch.Tensor, bias: torch.Tensor = None + ) -> torch.Tensor: + if bias is not None: + x = x + bias + return x + + def internal_forward_hook( + self, + module: nn.Module, + args: Tuple[torch.Tensor], + kwargs: Dict[str, Any], + out: torch.Tensor, + ): + return out + kwargs["bias"] + + +def kwarg_forward_pre_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + args: Tuple[torch.Tensor], + kwargs: Dict[str, Any], +) -> Tuple[Any, Any]: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(args), 1) + kwargs["bias"] = 2 * kwargs["bias"] + return args, kwargs + + +def kwarg_forward_hook( + self: TestCase, + fired_hooks: List[int], + expected_module: nn.Module, + hook_id: int, + module: nn.Module, + args: Tuple[torch.Tensor], + kwargs: Dict[str, Any], + out: torch.Tensor, +) -> Any: + fired_hooks.append(hook_id) + self.assertEqual(id(module), id(expected_module)) + self.assertEqual(len(args), 1) + + out = out + kwargs["bias"] + return out + + +class TestModuleHooks(TestCase): + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_forward_hooks(self): + fired_hooks: List[int] = [] + model = ToyModel() + x = torch.randn(10, 10) + hook = partial(forward_hook, self, fired_hooks, model.net1.seq2) + model.net1.seq2.register_forward_hook(partial(hook, 0)) + model.net1.seq2.register_forward_hook(partial(hook, 1), prepend=True) + model.net1.seq2.register_forward_hook(partial(hook, 2)) + model.net1.seq2.register_forward_hook(partial(hook, 3)) + model.net1.seq2.register_forward_hook(partial(hook, 4), prepend=True) + expected = [4, 1, 0, 2, 3] + + self.assertEqual(fired_hooks, []) + out = model(x) + self.assertEqual(fired_hooks, expected) + out.sum().backward() + self.assertEqual(fired_hooks, expected) + model(x).sum().backward() + self.assertEqual(fired_hooks, expected + expected) + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_forward_pre_hooks(self): + fired_hooks: List[int] = [] + model = ToyModel() + x = torch.randn(10, 10) + hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1) + model.net2.seq1.register_forward_pre_hook( + partial(hook, 0), prepend=True + ) + model.net2.seq1.register_forward_pre_hook(partial(hook, 1)) + model.net2.seq1.register_forward_pre_hook(partial(hook, 2)) + model.net2.seq1.register_forward_pre_hook(partial(hook, 3)) + model.net2.seq1.register_forward_pre_hook( + partial(hook, 4), prepend=True + ) + expected = [4, 0, 1, 2, 3] + + self.assertEqual(fired_hooks, []) + out = model(x) + self.assertEqual(fired_hooks, expected) + out.sum().backward() + self.assertEqual(fired_hooks, expected) + model(x).sum().backward() + self.assertEqual(fired_hooks, expected + expected) + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_full_backward_hooks(self): + fired_hooks: List[int] = [] + model = ToyModel() + x = torch.randn(10, 10) + hook = partial(full_backward_hook, self, fired_hooks, model.net1) + model.net1.register_full_backward_hook(partial(hook, 0)) + model.net1.register_full_backward_hook(partial(hook, 1)) + model.net1.register_full_backward_hook(partial(hook, 2)) + model.net1.register_full_backward_hook(partial(hook, 3), prepend=True) + model.net1.register_full_backward_hook(partial(hook, 4), prepend=True) + expected = [4, 3, 0, 1, 2] + + self.assertEqual(fired_hooks, []) + out = model(x) + self.assertEqual(fired_hooks, []) + out.sum().backward() + self.assertEqual(fired_hooks, expected) + model(x).sum().backward() + self.assertEqual(fired_hooks, expected + expected) + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_full_backward_pre_hooks(self): + fired_hooks: List[int] = [] + model = ToyModel() + x = torch.randn(10, 10) + hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1) + model.net1.register_full_backward_pre_hook( + partial(hook, 0), prepend=True + ) + model.net1.register_full_backward_pre_hook( + partial(hook, 1), prepend=True + ) + model.net1.register_full_backward_pre_hook(partial(hook, 2)) + model.net1.register_full_backward_pre_hook(partial(hook, 3)) + model.net1.register_full_backward_pre_hook(partial(hook, 4)) + expected = [1, 0, 2, 3, 4] + + self.assertEqual(fired_hooks, []) + out = model(x) + self.assertEqual(fired_hooks, []) + out.sum().backward() + self.assertEqual(fired_hooks, expected) + model(x).sum().backward() + self.assertEqual(fired_hooks, expected + expected) + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_mixed_hooks(self): + fired_hooks: List[int] = [] + model = ToyModel() + x = torch.randn(10, 10) + model.register_forward_pre_hook( + partial(forward_pre_hook, self, fired_hooks, model, 0) + ) + model.register_forward_hook( + partial(forward_hook, self, fired_hooks, model, 1) + ) + model.register_full_backward_pre_hook( + partial(full_backward_pre_hook, self, fired_hooks, model, 2) + ) + model.register_full_backward_hook( + partial(full_backward_hook, self, fired_hooks, model, 3) + ) + + self.assertEqual(fired_hooks, []) + out = model(x) + self.assertEqual(fired_hooks, [0, 1]) + out.sum().backward() + self.assertEqual(fired_hooks, [0, 1, 2, 3]) + model(x).sum().backward() + self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3]) + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_kwarg_hooks(self): + # 1. test forward pre hook + fired_hooks: List[int] = [] + x: torch.Tensor = torch.ones(10, 10) + bias: torch.Tensor = torch.ones(10, 10) + model = KwargModel() + model.register_forward_pre_hook( + partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), + with_kwargs=True, + ) + + # forward-pre: bias' = bias * 2 + # So, out = x + bias * 2 + self.assertEqual(fired_hooks, []) + out = model(x, bias=bias) + self.assertEqual(fired_hooks, [0]) + self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) + + # 2. test forward pre and forward hooks + fired_hooks: List[int] = [] + x: torch.Tensor = torch.ones(10, 10) + bias: torch.Tensor = torch.ones(10, 10) + model = KwargModel() + model.register_forward_hook( + partial(kwarg_forward_hook, self, fired_hooks, model, 1), + with_kwargs=True, + ) + model.register_forward_pre_hook( + partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), + with_kwargs=True, + ) + + # forward-pre: bias' = bias * 2 + # forward: out = x + bias' + # forward-post: out = out + bias' + # So, out = x + bias * 4 + self.assertEqual(fired_hooks, []) + out = model(x, bias=bias) + self.assertEqual(fired_hooks, [0, 1]) + self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5) + + # 3. test nn.Module member method as forward-post hook + x: torch.Tensor = torch.ones(10, 10) + bias: torch.Tensor = torch.ones(10, 10) + model = KwargModel() + model.register_forward_hook( + model.internal_forward_hook, with_kwargs=True + ) + + # forward: out = x + bias + # forward-post: out = out + bias + # So, out = x + bias * 2 + out = model(x, bias=bias) + self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) + + + @skipIfTorchDynamo("Dynamo does not yet capture hooks") + def test_remove_kwarg_hooks(self): + # test forward pre and forward hooks + fired_hooks: List[int] = [] + x: torch.Tensor = torch.ones(10, 10) + bias: torch.Tensor = torch.ones(10, 10) + model = KwargModel() + forward_hook_handle = model.register_forward_hook( + partial(kwarg_forward_hook, self, fired_hooks, model, 1), + with_kwargs=True, + ) + forward_pre_hook_handle = model.register_forward_pre_hook( + partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0), + with_kwargs=True, + ) + + # forward-pre: bias' = bias * 2 + # forward: out = x + bias' + # forward-post: out = out + bias' + # So, out = x + bias * 4 + self.assertEqual(fired_hooks, []) + out = model(x, bias=bias) + self.assertEqual(fired_hooks, [0, 1]) + self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5) + + # forward-pre: bias' = bias * 2 + # forward: out = x + bias' + # So, out = x + bias * 2 + forward_hook_handle.remove() + out = model(x, bias=bias) + self.assertEqual(fired_hooks, [0, 1, 0]) + self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) + self.assertFalse( + forward_hook_handle.id in model._forward_hooks_with_kwargs + ) + + # forward: out = x + bias + # So, out = x + bias + forward_pre_hook_handle.remove() + out = model(x, bias=bias) + self.assertEqual(fired_hooks, [0, 1, 0]) + self.assertEqual(out, x + bias, rtol=0, atol=1e-5) + self.assertFalse( + forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs + ) + + +def _hook_to_pickle(*args, **kwargs): + pass + +class TestStateDictHooks(TestCase): + + def test_load_state_dict_pre_hook(self): + + m = nn.Linear(10, 10) + m_state_dict = m.state_dict() + + m_load = nn.Linear(10, 10) + + hook_called = 0 + + def hook_without_module(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self.assertEqual(m_state_dict, state_dict) + nonlocal hook_called + hook_called += 1 + + def hook_with_module(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + self.assertEqual(m_state_dict, state_dict) + self.assertTrue(m_load is module) + nonlocal hook_called + hook_called += 1 + + hook_called = 0 + m_load._register_load_state_dict_pre_hook(hook_without_module) + m_load.load_state_dict(m_state_dict) + self.assertEqual(1, hook_called) + + hook_called = 0 + m_load._register_load_state_dict_pre_hook(hook_with_module, True) + m_load.load_state_dict(m_state_dict) + self.assertEqual(2, hook_called) + + def test_no_extra_ref_to_module(self): + try: + gc.disable() + m = nn.Linear(10, 10) + + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + weak_m = weakref.ref(m) + del m + + self.assertEqual(weak_m(), None) + finally: + gc.enable() + + def test_pickled_hook(self): + m = nn.Linear(10, 10) + m._register_load_state_dict_pre_hook(_hook_to_pickle, True) + pickle.loads(pickle.dumps(m)) + + def test_load_state_dict_module_pre_hook(self): + hook_called = 0 + + # Test with module instance method as hook + class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.foo = torch.nn.Parameter(torch.rand(10)) + + def my_pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + assert [] == error_msgs + assert [] == unexpected_keys + assert [] == missing_keys + assert strict + nonlocal hook_called + hook_called += 1 + + def my_pre_load_hook_with_module( + self, + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + assert [] == error_msgs + assert [] == unexpected_keys + assert [] == missing_keys + assert strict + assert self is module + nonlocal hook_called + hook_called += 1 + + # Test that hooks registered on a submodule are also called + # appropriately, i.e. with the submodule as module argument in + # my_pre_load_hook_with_module. + class MyModuleContainer(nn.Module): + def __init__(self, mod): + super().__init__() + self.mod = mod + + for ctor in [MyModuleContainer, lambda x: x]: + m = ctor(MyModule()) + state_dict = m.state_dict() + if isinstance(m, MyModuleContainer): + mod = m.mod + else: + mod = m + + hook_called = 0 + mod._register_load_state_dict_pre_hook( + mod.my_pre_load_hook + ) + m.load_state_dict(state_dict) + self.assertEqual(1, hook_called) + + hook_called = 0 + mod._register_load_state_dict_pre_hook( + mod.my_pre_load_hook_with_module, True + ) + m.load_state_dict(state_dict) + self.assertEqual(2, hook_called) + + def test_load_state_dict_post_hook(self): + hook_called = 0 + + class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.foo = torch.nn.Parameter(torch.rand(10)) + + def my_post_load_hook(self, module, incompatible_keys): + assert module is self + nonlocal hook_called + incompatible_keys.missing_keys.append("foo") + incompatible_keys.unexpected_keys.append("bar") + hook_called += 1 + + nested = MyModule() + wrapped = nn.ModuleList([nested]) + handle = nested.register_load_state_dict_post_hook( + nested.my_post_load_hook, + ) + # Hook must be called even if it is wrapped + ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) + self.assertEqual(hook_called, 1) + # Ensure that the hook modified missing_keys and unexpected_keys + missing = ret.missing_keys + unexpected = ret.unexpected_keys + self.assertEqual(missing, ["foo"]) + self.assertEqual(unexpected, ["bar"]) + # When called with strict=True, the error raised should mention the + # missing and unexpected keys the hook added. + with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"): + wrapped.load_state_dict(wrapped.state_dict(), strict=True) + self.assertEqual(hook_called, 2) + # Removing the hook via handle.remove() should cause it not to + # fire anymore. + handle.remove() + # Hook did not run so it should not have added any keys + ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) + self.assertEqual(ret.missing_keys, []) + self.assertEqual(ret.unexpected_keys, []) + # hook_called should not have been incremented + self.assertEqual(hook_called, 2) + + def load_hook_clear_incompatible(module, incompatible_keys): + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + nested.register_load_state_dict_post_hook(load_hook_clear_incompatible) + state_dict = wrapped.state_dict() + state_dict["extra"] = torch.ones(1) + # load state_dict with strict=True should not throw. + ret = wrapped.load_state_dict(state_dict, strict=True) + # explicitly ensure that the post hook clearned out incompatible_keys + self.assertEqual([], ret.missing_keys) + self.assertEqual([], ret.unexpected_keys) + + @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + def test_load_state_dict_post_hook_backward_compatibility(self): + def my_post_load_hook(mod, _): + nonlocal called + called = True + + for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]: + called = False + sd = deepcopy(m.state_dict()) + self.assertTrue(hasattr(m, '_load_state_dict_post_hooks')) + # Simulate an older model that did not have this attr + delattr(m, '_load_state_dict_post_hooks') + # Save and load, and ensure that load_state_dict works (without proper + # BC we would run into errors because this attribute would be expected). + # In particular, Softmax runs into the issue described here: + # https://github.com/pytorch/pytorch/issues/77280 + with NamedTemporaryFile() as f: + # Note that torch.save / torch.load is not recommended to save/load + # modules. + torch.save(m, f.name) + m = torch.load(f.name) + m.load_state_dict(sd) + self.assertFalse(called) + + # Ensure hooks can be registered and called. + m.register_load_state_dict_post_hook(my_post_load_hook) + m.load_state_dict(sd) + self.assertTrue(called) + + +class TestModuleGlobalHooks(TestCase): + + def tearDown(self): + nn.modules.module._global_backward_hooks = OrderedDict() + nn.modules.module._global_forward_hooks = OrderedDict() + nn.modules.module._global_forward_pre_hooks = OrderedDict() + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_module_global_hooks(self): + module = nn.Sigmoid + + module_1 = module() + module_2 = module() + module_3 = module() + + input = torch.ones(5, 5, requires_grad=True) + + counter = { + 'forwards': 0, + 'backwards': 0 + } + + def fw_hook(inc, h_module, input, output): + self.assertIsInstance(input, tuple) + self.assertTrue(isinstance(output, torch.Tensor)) + self.assertTrue(isinstance(h_module, module)) + self.assertEqual(input[0], torch.ones(5, 5)) + self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e))) + counter['forwards'] += inc + + def bw_hook(inc, h_module, grad_input, grad_output): + self.assertIsInstance(grad_input, tuple) + self.assertIsInstance(grad_output, tuple) + self.assertTrue(isinstance(h_module, module)) + self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) + counter['backwards'] += inc + + test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args)) + + module_1(input) + module_2(input) + module_3(input) + self.assertEqual(counter['forwards'], 3) + self.assertEqual(counter['backwards'], 0) + + test_bwd = nn.modules.module.register_module_backward_hook( + lambda *args: bw_hook(1, *args)) + + output_1 = module_1(input) + output_2 = module_2(input) + output_3 = module_3(input) + self.assertEqual(counter['forwards'], 6) + self.assertEqual(counter['backwards'], 0) + + output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) + output_2.backward(torch.ones(5, 5) * 2, retain_graph=False) + output_3.backward(torch.ones(5, 5) * 2, retain_graph=False) + self.assertEqual(counter['forwards'], 6) + self.assertEqual(counter['backwards'], 3) + + output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) + self.assertEqual(counter['forwards'], 6) + self.assertEqual(counter['backwards'], 4) + + test2_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(2, *args)) + + output = module_1(input) + output = module_2(input) + output = module_3(input) + self.assertEqual(counter['forwards'], 15) + self.assertEqual(counter['backwards'], 4) + + test2_bwd = nn.modules.module.register_module_backward_hook(lambda *args: bw_hook(2, *args)) + + module_1(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 18) + self.assertEqual(counter['backwards'], 7) + + test2_bwd.remove() + + module_2(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 21) + self.assertEqual(counter['backwards'], 8) + + test2_fwd.remove() + + module_3(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 22) + self.assertEqual(counter['backwards'], 9) + + test_fwd.remove() + test_bwd.remove() + + def test_module_global_hook_invalid_outputs(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + + def bw_fail1(self, grad_input, grad_output): + return grad_input[:-1] + + def bw_fail2(self, grad_input, grad_output): + return grad_input + (torch.randn(2, 2),) + + with nn.modules.module.register_module_backward_hook(bw_fail1): + with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): + module(input).sum().backward() + + with nn.modules.module.register_module_backward_hook(bw_fail2): + with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): + module(input).sum().backward() + + @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/847") + def test_module_backward_global_hook_writeable(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + sig_x = torch.sigmoid(input) + + def bw_hook(module, grad_input, grad_output): + for grad in grad_input: + self.assertTrue(isinstance(grad, torch.Tensor)) + for grad in grad_output: + self.assertTrue(isinstance(grad, torch.Tensor)) + return tuple(gi * 2 for gi in grad_input) + + nn.modules.module.register_module_backward_hook(bw_hook) + module(input).backward(torch.ones(5, 5)) + expected_grad = sig_x * (1 - sig_x) * 2 + self.assertEqual(input.grad, expected_grad) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_module_global_forward_preforward_hook_writeable(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + sig_x = torch.sigmoid(input) + + def forward_pre_hook(m, input): + return torch.nn.functional.relu(input[0]) + + def forward_hook(m, input, output): + return -output + + nn.modules.module.register_module_forward_pre_hook(forward_pre_hook) + nn.modules.module.register_module_forward_hook(forward_hook) + output = module(input) + expected_res = -torch.sigmoid(torch.nn.functional.relu(input)) + self.assertEqual(output, expected_res) + output.backward(torch.ones(5, 5) * 2, retain_graph=True) + mask = (input > 0) + expected_grad = -sig_x * (1 - sig_x) * 2 * mask + self.assertEqual(input.grad, expected_grad) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_module_forward_preforward_hook_removable(self): + """ + This test is to test when multiple pre-forward hook functions can be + registered successfully and used correctly, if the handle can be removable + during the pre-forward hook function call. + """ + module = nn.Sigmoid() + + def removable_hook(m, input): + nonlocal handle + handle.remove() + return input + + def removable_hook_2(m, input): + nonlocal handle_2 + handle_2.remove() + return input + + handle = module.register_forward_pre_hook(removable_hook) + handle_2 = module.register_forward_pre_hook(removable_hook_2) + + # make sure hook register is successful + self.assertEqual(len(handle.hooks_dict_ref()), 2) + self.assertEqual(len(handle_2.hooks_dict_ref()), 2) + + input = torch.randn(2, 2) + output = module(input) + self.assertEqual(torch.sigmoid(input), output) + + # make sure hook removal is successful + self.assertFalse(handle.id in handle.hooks_dict_ref()) + self.assertFalse(handle_2.id in handle.hooks_dict_ref()) + self.assertEqual(len(handle.hooks_dict_ref()), 0) + self.assertEqual(len(handle_2.hooks_dict_ref()), 0) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_module_forward_forward_hook_removable(self): + """ + This test is to test when multiple forward hook functions can be registered + successfully and used correctly, if the handle can be removable during the + forward hook function call. + """ + module = nn.Sigmoid() + + def removable_hook(m, input, output): + nonlocal handle + handle.remove() + return output + + def removable_hook_2(m, input, output): + nonlocal handle_2 + handle_2.remove() + return output + + handle = module.register_forward_hook(removable_hook) + handle_2 = module.register_forward_hook(removable_hook_2) + + # make sure hook register is successful + self.assertEqual(len(handle.hooks_dict_ref()), 2) + self.assertEqual(len(handle_2.hooks_dict_ref()), 2) + + input = torch.randn(2, 2) + output = module(input) + self.assertEqual(torch.sigmoid(input), output) + + # make sure hook removal is successful + self.assertFalse(handle.id in handle.hooks_dict_ref()) + self.assertFalse(handle_2.id in handle.hooks_dict_ref()) + self.assertEqual(len(handle.hooks_dict_ref()), 0) + self.assertEqual(len(handle_2.hooks_dict_ref()), 0) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_global_and_local_hooks_order(self): + module = nn.Sigmoid() + + global_forward_pre_called = False + local_forward_pre_called = False + global_forward_called = False + local_forward_called = False + global_backward_called = False + local_backward_called = False + + def global_forward_pre_hook(m, input): + nonlocal global_forward_pre_called + self.assertTrue(not local_forward_pre_called) + global_forward_pre_called = True + return input + + def local_forward_pre_hook(m, input): + nonlocal local_forward_pre_called + self.assertTrue(global_forward_pre_called) + local_forward_pre_called = True + return input + + def global_forward_hook(m, input, output): + nonlocal global_forward_called + self.assertTrue(not local_forward_called) + global_forward_called = True + return output + + def local_forward_hook(m, input, output): + nonlocal local_forward_called + self.assertTrue(global_forward_called) + local_forward_called = True + return output + + def global_backward_hook(m, input, output): + nonlocal global_backward_called + self.assertTrue(not local_backward_called) + global_backward_called = True + return input + + def local_backward_hook(m, input, output): + nonlocal local_backward_called + self.assertTrue(global_backward_called) + local_backward_called = True + return input + + input = torch.randn(5, 5, requires_grad=True) + nn.modules.module.register_module_forward_pre_hook(global_forward_pre_hook) + module.register_forward_pre_hook(local_forward_pre_hook) + nn.modules.module.register_module_forward_hook(global_forward_hook) + module.register_forward_hook(local_forward_hook) + nn.modules.module.register_module_backward_hook(global_backward_hook) + module.register_backward_hook(local_backward_hook) + + output = module(input) + self.assertTrue(local_forward_called and local_forward_pre_called and global_forward_called and global_forward_pre_called) + + output.backward(torch.ones(5, 5), retain_graph=True) + self.assertTrue(local_backward_called and global_backward_called) + + +class TestModuleHookNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + def _test_hooks(self, backward_register_fn): + module = nn.Sigmoid() + input = torch.ones(5, 5, requires_grad=True) + + counter = { + 'forwards': 0, + 'backwards': 0 + } + + def fw_hook(inc, h_module, input, output): + self.assertIsInstance(input, tuple) + self.assertTrue(isinstance(output, torch.Tensor)) + self.assertTrue(h_module is module) + self.assertEqual(input[0], torch.ones(5, 5)) + self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e))) + counter['forwards'] += inc + + def bw_hook(inc, h_module, grad_input, grad_output): + self.assertIsInstance(grad_input, tuple) + self.assertIsInstance(grad_output, tuple) + self.assertTrue(h_module is module) + self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) + counter['backwards'] += inc + + # backward_pre_hook expects callback with only `module` and `grad_output` + # as arguments. + def bw_pre_hook(inc, h_module, grad_output): + self.assertIsInstance(grad_output, tuple) + self.assertTrue(h_module is module) + self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) + counter['backwards'] += inc + + test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args)) + + module(input) + module(input) + self.assertEqual(counter['forwards'], 2) + self.assertEqual(counter['backwards'], 0) + + bw_hook_fn = bw_pre_hook if backward_register_fn == 'register_full_backward_pre_hook' else bw_hook + test_bwd = getattr(module, backward_register_fn)( + lambda *args: bw_hook_fn(1, *args)) + + output = module(input) + self.assertEqual(counter['forwards'], 3) + self.assertEqual(counter['backwards'], 0) + + output.backward(torch.ones(5, 5) * 2, retain_graph=True) + self.assertEqual(counter['forwards'], 3) + self.assertEqual(counter['backwards'], 1) + + output.backward(torch.ones(5, 5) * 2, retain_graph=True) + self.assertEqual(counter['forwards'], 3) + self.assertEqual(counter['backwards'], 2) + + test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args)) + + output = module(input) + self.assertEqual(counter['forwards'], 6) + self.assertEqual(counter['backwards'], 2) + + test2_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook_fn(2, *args)) + + module(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 9) + self.assertEqual(counter['backwards'], 5) + + test2_bwd.remove() + + module(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 12) + self.assertEqual(counter['backwards'], 6) + + test2_fwd.remove() + + module(input).backward(torch.ones(5, 5) * 2) + self.assertEqual(counter['forwards'], 13) + self.assertEqual(counter['backwards'], 7) + + test_fwd.remove() + test_bwd.remove() + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_hooks(self): + self._test_hooks("register_backward_hook") + self._test_hooks("register_full_backward_hook") + self._test_hooks("register_full_backward_pre_hook") + + def test_hook_cpp(self): + bn = nn.BatchNorm1d(5) + + def hook(module, grad_inputs, grad_outputs): + self.assertEqual(len(grad_inputs), 1) + self.assertEqual(len(grad_outputs), 1) + self.assertEqual(module, bn) + + bn.register_full_backward_hook(hook) + output = bn(torch.randn(5, 5, requires_grad=True)) + output.sum().backward() + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_backward_hooks_interaction(self): + # Test to make sure that the grad_outputs + # updated by full_backward_pre_hook are received by + # the full_backward_hook + module = torch.nn.Sigmoid() + + cnt = {'backward_cnt': 0} + + def bw_pre_hook(m, grad_output): + cnt['backward_cnt'] += 1 + return (grad_output[0] * 0.5, ) + + def bw_hook(m, grad_in, grad_output): + self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0]) + cnt['backward_cnt'] += 1 + return grad_output + + module.register_full_backward_pre_hook(bw_pre_hook) + module.register_full_backward_hook(bw_hook) + + t = torch.ones(1, 2, requires_grad=True) + module(t).sum().backward() + self.assertEqual(cnt['backward_cnt'], 2) + + def test_hook_invalid_outputs(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + + def bw_fail1(self, grad_input, grad_output): + return grad_input[:-1] + + def bw_fail2(self, grad_input, grad_output): + return grad_input + (torch.randn(2, 2),) + + with module.register_backward_hook(bw_fail1): + with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): + module(input).sum().backward() + + with module.register_backward_hook(bw_fail2): + with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): + module(input).sum().backward() + + def bw_pre_fail1(self, grad_output): + return () + + def bw_pre_fail2(self, grad_output): + return grad_output + (torch.randn(2, 2),) + + with module.register_full_backward_pre_hook(bw_pre_fail1): + with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): + module(input).sum().backward() + + with module.register_full_backward_pre_hook(bw_pre_fail2): + with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): + module(input).sum().backward() + + def test_hook_requires_grad(self): + test_self = self + + class MyModule(nn.Module): + def forward(self, arg1, arg2, arg3): + test_self.assertTrue(arg1.requires_grad) + test_self.assertFalse(arg2.requires_grad) + test_self.assertTrue(arg3.requires_grad) + return arg1.sum() + arg2.sum() + arg3.sum() + + inp = torch.rand(2, requires_grad=True) + mod = MyModule() + + mod(inp, inp.detach(), inp) + # Ensure that requires grad is properly propagated + mod.register_full_backward_hook(lambda mod, gI, gO: None) + mod(inp, inp.detach(), inp) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_hook_no_requires_grad(self): + mod = nn.Linear(2, 3) + + inp = torch.rand(1, 2) + + return_val = "None" + hook_called = [0] + + def hook(mod, grad_input, grad_output): + hook_called[0] += 1 + for gI in grad_input: + self.assertIsNone(gI) + for gO in grad_output: + self.assertEqual(gO.size(), (1, 3)) + + if return_val == "grad_input": + return grad_input + elif return_val == "invalid": + # If the inputs were requiring gradients, this would be + # a valid return + return inp + elif return_val == "None": + return None + else: + raise RuntimeError("Invalid return_val string") + + mod.register_full_backward_hook(hook) + + # This should run and trigger the hook properly + mod(inp).sum().backward() + self.assertEqual(hook_called[0], 1) + + return_val = "grad_input" + + mod(inp).sum().backward() + self.assertEqual(hook_called[0], 2) + + return_val = "invalid" + with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"): + mod(inp).sum().backward() + + def test_hook_last_arg_requires_grad(self): + mod = nn.L1Loss() + inp = torch.rand(1, requires_grad=True) + mod.register_full_backward_hook(lambda m, gI, gO: None) + + try: + mod(inp.detach(), inp) + except Exception as ex: + self.fail("Unexpected exception: %s" % ex) + + def test_hook_extra_input(self): + class MyModule(nn.Module): + def forward(self, non_tensor, tensor): + return tensor.clone(), non_tensor + + inp = torch.rand(2, requires_grad=True) + mod = MyModule() + + def hook(mod, grad_input, grad_output): + self.assertIsNone(grad_input[0]) + self.assertIsInstance(grad_input[1], torch.Tensor) + + self.assertIsInstance(grad_output[0], torch.Tensor) + self.assertIsNone(grad_output[1]) + + mod.register_full_backward_hook(hook) + out, _ = mod(True, inp) + out.sum().backward() + + def test_hook_inplace(self): + class MyModule(nn.Module): + def forward(self, inp, do_inplace): + self.inp = inp + if do_inplace: + inp += 1 + return inp.clone() + + hook_called = [0] + + def hook(mod, grad_input, grad_output): + hook_called[0] += 1 + + def hook_pre(mod, grad_output): + hook_called[0] += 1 + + inp = torch.rand(10, requires_grad=True) + mod = MyModule() + for hook_fn, register_fn in [(hook, mod.register_full_backward_hook), + (hook_pre, mod.register_full_backward_pre_hook)]: + hook_called[0] = 0 + with register_fn(hook_fn): + # No inplace should work + mod(inp, False).sum().backward() + self.assertEqual(hook_called[0], 1) + + # Input inplace error should throw an error + with self.assertRaisesRegex(RuntimeError, "Output 0 of BackwardHookFunctionBackward is " + "a view and is being modified inplace."): + mod(inp.clone(), True) + + # Input inplace error should throw an error if we try to re-use the view after they have + # been modified + local_inp = inp.clone() + out = mod(local_inp, False) + local_inp[0] *= 1 + with self.assertRaisesRegex(RuntimeError, "Output 0 of BackwardHookFunctionBackward is " + "a view and its base or another view"): + # Any operation involving the view will fail here + mod.inp + 2 + + # Output inplace error should throw an error + out = mod(inp, False) + with self.assertRaisesRegex(RuntimeError, "BackwardHookFunctionBackward is a view " + "and is being modified inplace."): + out += 1 + + def test_hook_non_full_warning(self): + def noop(*args): + pass + + a = torch.rand(2, requires_grad=True) + b = torch.rand(2, requires_grad=True) + + # Check invalid input container + class MyModule(nn.Module): + def forward(self, l): + return l[0].clone(), l[1].clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "does not take as input a single Tensor or a tuple of Tensors"): + m([a, b]) + + # Check invalid output container + class MyModule(nn.Module): + def forward(self, a, b): + return [a.clone(), b.clone()] + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "does not return a single Tensor or a tuple of Tensors"): + m(a, b) + + # Check invalid output from different Nodes + class MyModule(nn.Module): + def forward(self, a, b): + return a.clone(), b.clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "outputs are generated by different autograd Nodes"): + m(a, b) + + # Check invalid forward with multiple Nodes + class MyModule(nn.Module): + def forward(self, a): + return a.clone().clone() + + m = MyModule() + m.register_backward_hook(noop) + + with self.assertWarnsRegex(UserWarning, "the forward contains multiple autograd Nodes"): + m(a) + + def test_hook_backward_size(self): + # Make module with multiple operations in forward + # And different size for input and outputs + class MyModule(nn.Module): + def forward(self, arg1, arg2): + tmp = arg1.sum() * arg2 + tmp = tmp + arg2.sum() * arg1.sum() + tmp = tmp.sum().view(1) + tmp = tmp.expand(8).contiguous() + return tmp + + module = MyModule() + inp1 = torch.randn(5, 5, requires_grad=True) + inp2 = torch.randn(10, 10, requires_grad=True) + + def bw_hook(module, grad_input, grad_output): + self.assertEqual(len(grad_input), 2) + self.assertEqual(grad_input[0].size(), torch.Size([5, 5])) + self.assertEqual(grad_input[1].size(), torch.Size([10, 10])) + self.assertEqual(len(grad_output), 1) + self.assertEqual(grad_output[0].size(), torch.Size([8])) + + with module.register_full_backward_hook(bw_hook): + module(inp1, inp2).sum().backward() + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_hook_backward_writeable(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + sig_x = torch.nn.functional.sigmoid(input) + + def bw_hook(module, grad_input, grad_output): + for grad in grad_input: + self.assertTrue(isinstance(grad, torch.Tensor)) + for grad in grad_output: + self.assertTrue(isinstance(grad, torch.Tensor)) + return tuple(gi * 2 for gi in grad_input) + + module.register_backward_hook(bw_hook) + module(input).backward(torch.ones(5, 5)) + expected_grad = sig_x * (1 - sig_x) * 2 + self.assertEqual(input.grad, expected_grad) + + @skipIfTorchDynamo("TorchDynamo does not work well with hooks") + def test_hook_forward_preforward_writable(self): + module = nn.Sigmoid() + input = torch.randn(5, 5, requires_grad=True) + sig_x = torch.nn.functional.sigmoid(input) + + def forward_pre_hook(m, input): + return torch.nn.functional.relu(input[0]) + + def forward_hook(m, input, output): + return -output + + module.register_forward_pre_hook(forward_pre_hook) + module.register_forward_hook(forward_hook) + output = module(input) + expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input)) + self.assertEqual(output, expected_res) + output.backward(torch.ones(5, 5) * 2, retain_graph=True) + mask = (input > 0) + expected_grad = -sig_x * (1 - sig_x) * 2 * mask + self.assertEqual(input.grad, expected_grad) + + def test_hook_buffer_registration(self): + for return_buffer in (True, False): + def buffer_registration_hook(module, name, buffer): + buffer.registered = True + if return_buffer: + return buffer + handle = torch.nn.modules.module.register_module_buffer_registration_hook( + buffer_registration_hook + ) + try: + l, n, s = _create_basic_net() + for b in s.buffers(): + self.assertTrue(getattr(b, "registered", False)) + finally: + handle.remove() + + def test_hook_submodule_registration(self): + for return_submodule in (True, False): + def module_registration_hook(module, name, submodule): + module.registered = True + submodule.registered = True + if return_submodule: + return submodule + handle = torch.nn.modules.module.register_module_module_registration_hook( + module_registration_hook + ) + try: + l, n, s = _create_basic_net() + for m in s.modules(): + self.assertTrue(getattr(m, "registered", False)) + finally: + handle.remove() + + def test_hook_parameter_registration(self): + for return_parameter in (True, False): + def parameter_registration_hook(module, name, parameter): + parameter.registered = True + if return_parameter: + return parameter + handle = torch.nn.modules.module.register_module_parameter_registration_hook( + parameter_registration_hook + ) + try: + l, n, s = _create_basic_net() + for p in s.parameters(): + self.assertTrue(getattr(p, "registered", False)) + finally: + handle.remove() + + +if __name__ == "__main__": + run_tests() diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py new file mode 100644 index 0000000000000..9c622ffe6e897 --- /dev/null +++ b/test/nn/test_multihead_attention.py @@ -0,0 +1,689 @@ +# Owner(s): ["module: nn"] +import contextlib +import random +import unittest +import unittest.mock as mock + +from torch.nn import MultiheadAttention +from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ + onlyCUDA +from torch.testing._internal.common_nn import NNTestCase +from torch.testing._internal.common_utils import run_tests, \ + TEST_NUMPY, TEST_WITH_CROSSREF, \ + parametrize as parametrize_test, instantiate_parametrized_tests +import torch.nn as nn +import torch + +if TEST_NUMPY: + import numpy as np + + +# WARNING: If you add a new top-level test case to this file, you MUST +# update test/run_test.py to list it, otherwise it will NOT be run in +# CI. + +class TestMultiheadAttentionNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + @unittest.skipIf(not TEST_NUMPY, "numpy not found") + @parametrize_test("average_attn_weights", [True, False]) + def test_multihead_attention(self, average_attn_weights): + def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None, + average_attn_weights=average_attn_weights): + """ Numpy-based reference implementation of scaled dot attention + for testing""" + + QKT = _batchmatmul( + Q, + np.transpose(K, axes=[0, 1, 3, 2]) + / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) + ) + b1, b2, s1, s2 = QKT.shape + if unseen_mask is not None or key_padding_mask is not None: + # assert s1 == s2 + for i in range(b1): + for j in range(b2): + for m in range(s1): + for n in range(s2): + if unseen_mask is not None and unseen_mask[m][n] == 0: + QKT[i, j, m, n] = -np.inf + if key_padding_mask is not None and key_padding_mask[i][n]: + QKT[i, j, m, n] = -np.inf + + reference = _softmax(QKT) + ref_attn_weight = reference + if average_attn_weights: + ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2 + reference = _batchmatmul(reference, V) + return reference, ref_attn_weight + + def _batchmatmul(a, b): # batchmatmul over 4 dim matrix + """ Numpy-based batch matrix multiply over 4 dim matrix""" + assert a.shape[0] == b.shape[0] + assert a.shape[1] == b.shape[1] + retval = np.zeros( + (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 + ) + for i in range(a.shape[0]): + for j in range(a.shape[1]): + retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) + return retval + + def _softmax(x): # softmax over 4 dim matrix + """ Numpy-based reference softmax over 4 dim matrix""" + np.seterr(invalid='ignore') + output = np.zeros(x.shape, dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + x_curr = x[i, j, k, :] + e_x = np.exp(x_curr - np.amax(x_curr)) + output[i, j, k, :] = e_x / np.sum(e_x) + return output + + def _split_heads_ref(X, dims, nheads, d_head): + X_split = np.reshape(X, dims[:2] + [nheads, d_head]) + X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) + reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head]) + return reference + + def _combine_heads_ref(X, dims, nheads, d_head): + X_transposed = np.transpose(X, [0, 2, 1, 3]) + reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) + return reference + + def _fc(X, X_weight, X_bias): + X_fc_b = X_bias.detach().numpy() + X_fc_w = X_weight.detach().numpy() + return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b + + def _create_src_lengths_mask(batch_size, src_lengths): + """ + Generate boolean mask to prevent attention beyond the end of source + Inputs: + batch_size : int + src_lengths : [batch_size] of sentence lengths + Outputs: + [batch_size, max_src_len] + """ + max_srclen = src_lengths.max() + src_indices = torch.arange(0, max_srclen).unsqueeze(0).to(src_lengths) + src_indices = src_indices.expand(batch_size, max_srclen) + src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) + # returns [batch_size, max_seq_len] + return (src_indices < src_lengths).int().detach() + + def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False, + saved_kv=False, same_embed_dim=False, + average_attn_weights=average_attn_weights): + for _ in range(100): + batch_sz, seq_len = [random.randint(2, 10) for r in range(2)] + d_head = random.randint(3, 10) + nheads = random.randint(2, 5) * 2 + d_model = d_head * nheads + if same_embed_dim: + kv_dim = d_model + else: + kv_dim = random.randint(5, 20) + dims = [batch_sz, seq_len, kv_dim] + + saved_k = None + saved_k_tensor = None + saved_v = None + saved_v_tensor = None + if saved_kv: + saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head) + saved_k_tensor = torch.from_numpy(saved_k).to(torch.get_default_dtype()) + saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head) + saved_v_tensor = torch.from_numpy(saved_v).to(torch.get_default_dtype()) + + key_padding_mask = None + key_padding_mask_tensor = None + if add_key_padding_mask: + seq_mask = np.random.randint(0, 2, (1, seq_len)) + key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1) + key_padding_mask_tensor = torch.from_numpy(key_padding_mask) + decoder_state = np.random.rand(batch_sz, d_model) + K = np.random.rand(*dims) + V = K + Q = np.expand_dims(decoder_state, 1) + attn_mask = np.random.randint(0, 2, size=(1, seq_len)) + attn_mask_tensor = torch.from_numpy(attn_mask).float() + attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf')) + attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0')) + attn_mask_tensor = attn_mask_tensor.double() + + decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype()) + source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1) + + multihead_attn_module = MultiheadAttention(d_model, nheads, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + kdim=kv_dim, vdim=kv_dim) + + if add_bias_kv: + bias_k = multihead_attn_module.bias_k.detach().numpy() + bias_v = multihead_attn_module.bias_v.detach().numpy() + else: + bias_k = None + bias_v = None + + _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1) + _V = source_hid_tensor + _K = source_hid_tensor + + if multihead_attn_module._qkv_same_embed_dim: + result, result_weight = torch.nn.functional.multi_head_attention_forward( + _Q, _K, _V, + d_model, nheads, + multihead_attn_module.in_proj_weight, multihead_attn_module.in_proj_bias, + multihead_attn_module.bias_k, multihead_attn_module.bias_v, + multihead_attn_module.add_zero_attn, multihead_attn_module.dropout, + multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias, + multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor, + static_k=saved_k_tensor, static_v=saved_v_tensor, + average_attn_weights=average_attn_weights) + else: + result, result_weight = torch.nn.functional.multi_head_attention_forward( + _Q, _K, _V, + d_model, nheads, + None, multihead_attn_module.in_proj_bias, + multihead_attn_module.bias_k, multihead_attn_module.bias_v, + multihead_attn_module.add_zero_attn, multihead_attn_module.dropout, + multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias, + multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor, + True, multihead_attn_module.q_proj_weight, + multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight, + static_k=saved_k_tensor, static_v=saved_v_tensor, + average_attn_weights=average_attn_weights) + + result = result.squeeze(0).detach().numpy() + + if multihead_attn_module._qkv_same_embed_dim: + q_proj_weight = multihead_attn_module.in_proj_weight[:d_model] + k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)] + v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):] + else: + q_proj_weight = multihead_attn_module.q_proj_weight + k_proj_weight = multihead_attn_module.k_proj_weight + v_proj_weight = multihead_attn_module.v_proj_weight + + Q_fc = _fc(Q, q_proj_weight, multihead_attn_module.in_proj_bias[:d_model]) + K_fc = _fc(K, k_proj_weight, multihead_attn_module.in_proj_bias[d_model:(d_model * 2)]) + V_fc = _fc(V, v_proj_weight, multihead_attn_module.in_proj_bias[(d_model * 2):]) + + if add_bias_kv: + K_fc = np.concatenate((K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1) + V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1) + if attn_mask is not None: + attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) + if key_padding_mask is not None: + key_padding_mask = np.concatenate( + (key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1) + dims[1] += 1 + Q_split = _split_heads_ref( + Q_fc, [batch_sz, 1, d_model], nheads, d_head + ) + + if saved_k is not None: + K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head]) + else: + K_split = _split_heads_ref(K_fc, dims, nheads, d_head) + + if saved_v is not None: + V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head]) + else: + V_split = _split_heads_ref(V_fc, dims, nheads, d_head) + + if add_zero_attn: + dims[1] += 1 + K_split = np.concatenate( + (K_split, np.zeros([K_split.shape[0], K_split.shape[1], 1, K_split.shape[3]])), axis=2) + V_split = np.concatenate( + (V_split, np.zeros([V_split.shape[0], V_split.shape[1], 1, V_split.shape[3]])), axis=2) + + if attn_mask is not None: + attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) + + if key_padding_mask is not None: + key_padding_mask = np.concatenate( + (key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1) + attn_heads, ref_attn_weight = _scaled_dot_attn_ref( + Q=Q_split, + K=K_split, + V=V_split, + dims=Q_split.shape, + unseen_mask=attn_mask, + key_padding_mask=key_padding_mask + ) + combined_attn_heads = _combine_heads_ref( + X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head + ) + + reference = _fc(combined_attn_heads, multihead_attn_module.out_proj.weight, + multihead_attn_module.out_proj.bias) + reference = np.squeeze(reference, axis=1) + + # result = reference + self.assertEqual(tuple(result.shape), (batch_sz, d_model)) + np.testing.assert_allclose(result, reference, atol=1e-5) + + # result_weight = ref_attn_weight + result_weight = result_weight.detach().numpy() + self.assertEqual(tuple(result_weight.shape), tuple(ref_attn_weight.shape)) + np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5) + + def test_multihead_attn_add_bias_kv(): + _multihead_attn_test_helper(add_bias_kv=True) + + def test_multihead_attn_add_zero_attn(): + _multihead_attn_test_helper(add_zero_attn=True) + + def test_multihead_attn_no_masking(): + _multihead_attn_test_helper() + + def test_multihead_attn_key_padding_mask(): + _multihead_attn_test_helper(add_key_padding_mask=True) + + def test_multihead_attn_saved_kv(): + _multihead_attn_test_helper(saved_kv=True) + + def test_multihead_attn_add_bias_kv_zero_attn(): + _multihead_attn_test_helper(add_key_padding_mask=True, add_bias_kv=True, + add_zero_attn=True) + + def test_multihead_attn_all_arguments1(): + _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True, saved_kv=True) + + def test_multihead_attn_all_arguments2(): + _multihead_attn_test_helper(add_key_padding_mask=True, add_bias_kv=True, + add_zero_attn=True, saved_kv=True) + + def test_multihead_attn_all_arguments3(): + _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True, + saved_kv=True, same_embed_dim=True) + + test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn + test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv + test_multihead_attn_no_masking() # Test MultiheadAttention without masking + test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths + test_multihead_attn_saved_kv() # Test MultiheadAttention with static kv. + test_multihead_attn_add_bias_kv_zero_attn() # Test MultiheadAttention with bias_kv and zero_attn. + test_multihead_attn_all_arguments1() # Test MultiheadAttention with all the argument. + with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."): + test_multihead_attn_all_arguments2() # Test MultiheadAttention with all the argument. + test_multihead_attn_all_arguments3() # Test MultiheadAttention with all the argument. + + def test_multihead_attn_3d_attn_mask(self): + embed_dim = 8 + num_heads = 4 + batch_size = 8 + src_len = 3 + tgt_len = 2 + + query = torch.rand(batch_size, tgt_len, embed_dim) # [N, T, D] + key = torch.rand(batch_size, src_len, embed_dim) # [N, S, D] + value = key # [N, S, D] + attn_mask = torch.randint(0, 2, (batch_size, tgt_len, src_len)).float() # [N, T, S] + attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) + + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads) + + # Generate 3D results + attn_mask_3d = torch.repeat_interleave(attn_mask, num_heads, dim=0) # [N * H, T, S] + output_3d = mta_model(query.transpose(0, 1), key.transpose( + 0, 1), value.transpose(0, 1), attn_mask=attn_mask_3d)[0] + output_3d = output_3d.transpose(0, 1) # [N, T, D] + + for i in range(0, batch_size): + output_2d = mta_model(query[i].unsqueeze(0).transpose(0, 1), + key[i].unsqueeze(0).transpose(0, 1), + value[i].unsqueeze(0).transpose(0, 1), + attn_mask=attn_mask[i])[0] + + # output_2d in shape of [T, 1, D] + self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d) + + def test_multihead_attn_no_bias(self): + embed_dim = 8 + num_heads = 4 + mha = torch.nn.MultiheadAttention(embed_dim, num_heads, bias=False) + + # Verify that bias=False applies to both in and out projection layers. + self.assertIsNone(mha.in_proj_bias) + self.assertIsNone(mha.out_proj.bias) + + def _test_multihead_attn_invalid_shape_impl(self, mha): + # Batched (3D) query cases + query = torch.randn(4, 4, 4) + key = torch.randn(4, 4, 4) + value = torch.randn(4, 4, 4) + + msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" + # 3D query, 2D key and 3D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, torch.randn(4, 4), value) + + msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" + # 3D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, torch.randn(4, 4)) + + msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" + # 3D query, 3D key, 3D value and 1D key_padding_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, key_padding_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) + + msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" + # 3D query, 3D key, 3D value and 1D attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) + + # Unbatched (2D) query cases + query = torch.randn(4, 4) + key = torch.randn(4, 4) + value = torch.randn(4, 4) + + msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" + # 2D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, torch.randn(4, 4, 4), value) + + msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" + # 2D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, torch.randn(4, 4, 4)) + + msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" + # 2D query, 2D key, 2D value and 1D key_padding_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, key_padding_mask=torch.tensor([[False, False, True, True] * 2], dtype=torch.bool)) + + msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" + # 2D query, 2D key, 2D value and 1D attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) + + msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" + # 2D query, 2D key, 2D value and 3D incorrect attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool)) + + def test_multihead_attn_invalid_shape(self): + mha = torch.nn.MultiheadAttention(4, 4) + self._test_multihead_attn_invalid_shape_impl(mha) + # Give the test a chance to hit the fast path. (Right now, it + # won't, but gating may be less restricted in the future.) + with torch.no_grad(): + self._test_multihead_attn_invalid_shape_impl(mha.eval()) + + @torch.no_grad() + def test_multihead_attn_fast_path_invalid_shape(self): + mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() + + # Batched (3D) query cases + query = torch.randn(4, 4, 4) + key = torch.randn(4, 4, 4) + value = torch.randn(4, 4, 4) + + # Currently, this case will just go to the slow path and get + # the usual message because it fails the requirement to be + # batched. + msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" + # 3D query, 2D key and 3D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, torch.randn(3, 3), value, need_weights=False) + + # Currently, this case will just go to the slow path and get + # the usual message because it fails the requirement to be + # batched. + msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" + # 3D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, torch.randn(3, 3), need_weights=False) + + msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" + # 3D query, 3D key, 3D value and 1D key_padding_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, key_padding_mask=torch.tensor( + [False, True, True], dtype=torch.bool), need_weights=False) + + msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" + # 3D query, 3D key, 3D value and 1D attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False) + + # Unbatched (2D) query cases + # NOTE: error messages are the same as regular path because the fast path doesn't support 2D. + query = torch.randn(4, 4) + key = torch.randn(4, 4) + value = torch.randn(4, 4) + + msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" + # 2D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, torch.randn(4, 4, 4), value) + + msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" + # 2D query, 3D key and 2D value + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, torch.randn(4, 4, 4)) + + msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" + # 2D query, 2D key, 2D value and 1D key_padding_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, key_padding_mask=torch.tensor([[False, False, True, True] * 2], dtype=torch.bool)) + + msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" + # 2D query, 2D key, 2D value and 1D attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) + + msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" + # 2D query, 2D key, 2D value and 3D incorrect attn_mask + with self.assertRaisesRegex(AssertionError, msg): + mha(query, key, value, attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool)) + + def test_multihead_attn_nested_tensor_outside_fast_path(self): + mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() + nt = torch.nested.nested_tensor([torch.randn(4, 4)]) + # One tested platform (linux-bionic-py3.7-clang) has a torch_function for one + # or more of these. Take advantage of that to test the torch_function bailout. + has_torch_func = torch.overrides.has_torch_function( + (nt, mha.in_proj_weight, mha.in_proj_bias, mha.out_proj.weight, mha.out_proj.bias)) + if has_torch_func: + msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function" + else: + msg = ("MultiheadAttention does not support NestedTensor outside of its fast path.*grad is " + + "enabled and.*or biases requires_grad") + with self.assertRaisesRegex(AssertionError, msg): + mha(nt, nt, nt) + + if has_torch_func: + # Just give up, they're all going to fail with the same message. + return + + with torch.no_grad(): + mha(nt, nt, nt) + with torch.inference_mode(): + mha(nt, nt, nt) + nt = torch.nested.nested_tensor([torch.randn(4, 4, requires_grad=False)]) + nt.requires_grad = False + with self.assertRaisesRegex(AssertionError, msg): + mha(nt, nt, nt) + mha.in_proj_weight.requires_grad = False + mha.in_proj_bias.requires_grad = False + mha.out_proj.weight.requires_grad = False + mha.out_proj.bias.requires_grad = False + mha(nt, nt, nt) + + +class TestMultiheadAttentionNNDeviceType(NNTestCase): + def test_multihead_self_attn_two_masks_fast_path(self, device): + """ + Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path + when both attention mask (mask type 0) and key padding mask (mask type 1) are provided + """ + with torch.no_grad(): + embed_dim = 14 + num_heads = 7 + batch_size = 8 + src_len = 5 + + query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) + # Create masks of two different types + attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) + key_padding_mask = torch.randint(0, 2, (batch_size, src_len)).bool().to(device) + + # We'll need expanded versions of the masks for masking out the outputs below + attn_mask_expanded = attn_mask.reshape(1, 1, src_len, src_len) \ + .expand(batch_size, num_heads, src_len, src_len) + key_padding_mask_expanded = key_padding_mask.reshape(batch_size, 1, 1, src_len) \ + .expand(batch_size, num_heads, src_len, src_len) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + + # Compute attention on the fast path + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, device=device) + mta_model.training = False + result_fast_path, _ = mta_model(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + + # Compute attention on the slow path + result_ref, _ = torch.nn.functional.multi_head_attention_forward(query.transpose(0, 1), + key.transpose(0, 1), + value.transpose(0, 1), + embed_dim, num_heads, + mta_model.in_proj_weight, + mta_model.in_proj_bias, + mta_model.bias_k, mta_model.bias_v, + mta_model.add_zero_attn, + mta_model.dropout, + mta_model.out_proj.weight, + mta_model.out_proj.bias, + training=mta_model.training, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=attn_mask, + use_separate_proj_weight=False, + q_proj_weight=mta_model.q_proj_weight, + k_proj_weight=mta_model.k_proj_weight, + v_proj_weight=mta_model.v_proj_weight, + average_attn_weights=False, + ) + result_ref = result_ref.transpose(0, 1) # Convert to batch-first + + # Rows which are completely masked out are nan, we need to exclude them from comparison + mask_out = merged_mask[:, 0, :, :].all(-1, keepdim=True).expand(batch_size, src_len, embed_dim) + result_fast_path_masked = result_fast_path.masked_fill(mask_out, 0) + result_ref_masked = result_ref.masked_fill(mask_out, 0) + + self.assertEqual(result_fast_path_masked, result_ref_masked) + + @torch.no_grad() + @unittest.skipIf(TEST_WITH_CROSSREF, 'CrossRef turns on TorchFunctionMode, and so disables fastpath.') + def test_multihead_self_attn_two_masks_fast_path_mock(self, device): + """ + Multihead self-attention should take fast path when both attention mask (mask type 0) + and key padding mask (mask type 1) are provided at the same time on CPU and CUDA + """ + if device not in ['cpu', 'cuda']: + self.skipTest("Fastpath only runs on CPU and CUDA.") + with torch.autocast(device_type=device, enabled=False): + embed_dim = 14 + num_heads = 7 + batch_size = 8 + src_len = 5 + + query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) + # Create masks of two different types + attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) + key_padding_mask = torch.randint(0, 2, (batch_size, src_len)).bool().to(device) + + with mock.patch('torch._native_multi_head_attention') as fastpath_mock: + # Compute attention on the fast path + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, device=device).eval() + mta_model.training = False + mta_model(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + # If mock was called, fastpath was taken + self.assertTrue(fastpath_mock.called) + + @onlyCUDA + @dtypes(torch.half, torch.float, torch.double) + def test_multihead_attention_dtype(self, device, dtype): + embed_dim = 128 + num_heads = 8 + sl = 10 + bs = 8 + model = nn.MultiheadAttention(embed_dim, num_heads).cuda().to(dtype) + q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) + k = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) + v = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) + out = model(q, k, v) + self.assertEqual(q.size(), out[0].size()) + self.assertEqual(dtype, out[0].dtype) + + @onlyCUDA + @dtypes(torch.half, torch.float, torch.double) + def test_multihead_attention_dtype_batch_first(self, device, dtype): + embed_dim = 128 + num_heads = 8 + sl = 10 + bs = 8 + # With batch_first=True, we have the possibility of hitting + # the native fast path if we call .eval() and enable inference + # mode. Test both paths. + for training in (True, False): + model = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda().to(dtype) + if not training: + model = model.eval() + cm = torch.no_grad() + else: + cm = contextlib.nullcontext() + with cm: + q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) + k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) + v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) + # fast path currently doesn't support weights + out = model(q, k, v, need_weights=False) + self.assertEqual(q.size(), out[0].size()) + self.assertEqual(dtype, out[0].dtype) + + @dtypes(torch.double) + @torch.no_grad() + def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(self, device, dtype): + mha = torch.nn.MultiheadAttention(4, 4, batch_first=True, dtype=dtype, device=device).eval() + mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half).to(device)) + query = torch.randn(4, 4, 4, dtype=dtype, device=device) + mha(query, query, query) + + @dtypes(torch.double) + @torch.no_grad() + def test_multihead_attn_fast_path_small_test(self, device, dtype): + mha = torch.nn.MultiheadAttention(4, 4, batch_first=True, dtype=dtype, device=device).eval() + query = torch.randn(4, 4, 4, dtype=dtype, device=device) + mha(query, query, query) + + @dtypes(torch.double) + @torch.no_grad() + def test_multihead_attn_in_proj_bias_none(self, device, dtype): + mha = torch.nn.MultiheadAttention(2, 2, bias=False, dtype=dtype, device=device) + query = torch.rand(2, 2, 2, dtype=dtype, device=device) + mha(query, query, query) + + @dtypes(torch.double) + @torch.no_grad() + def test_multihead_attn_in_proj_weight_none(self, device, dtype): + # Setting kdim == vdim == 2 means that vdim != embed_dim + # will cause the logic to use per-input project weights, thereby + # forcing self.in_proj_weight = None + mha = torch.nn.MultiheadAttention(4, 4, vdim=2, kdim=2, dtype=dtype, device=device) + query = torch.rand(4, 4, 4, dtype=dtype, device=device) + key = torch.rand(4, 4, 2, dtype=dtype, device=device) + mha(query, key, key) + + +instantiate_device_type_tests(TestMultiheadAttentionNNDeviceType, globals()) +instantiate_parametrized_tests(TestMultiheadAttentionNN) + +if __name__ == '__main__': + run_tests() diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py new file mode 100644 index 0000000000000..0ba361d310d3f --- /dev/null +++ b/test/nn/test_parametrization.py @@ -0,0 +1,1525 @@ +# Owner(s): ["module: nn"] +from copy import deepcopy +from itertools import product + +import torch + +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +import torch.nn.utils.parametrize as parametrize +from torch.nn import Parameter +from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \ + TemporaryFileName, instantiate_parametrized_tests, set_default_dtype +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_nn import NNTestCase +from torch.testing._internal.common_utils import gradcheck + + +class TestNNParametrization(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + # torch/nn/utils/parametrize + @skipIfNoLapack + def test_register_and_remove_parametrization(self): + r"""Test that it is possible to add a few parametrizations + on a parameter or a buffer and that removing them restores the initial state + It also tests that backpropagating through them works as expected + """ + # Define a couple matrix parametrizations + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + class Orthogonal(nn.Module): + def forward(self, X): + # Cayley map + # If X is skew-symmetric it returns an orthogonal matrix + Id = torch.eye(X.size(0), device=X.device) + # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous + # and autograd raises a performance warning. + # This happens when we remove the parametrization with leave_parametrized=True, + # which does a set_ with a non-contiguous tensor while the gradient is contiguous + return torch.linalg.solve(Id + X, Id - X).contiguous() + + class Resize(nn.Module): + def forward(self, X): + return X[[0]] + + class NoResize(nn.Module): + def forward(self, X): + return X + + # Define a couple vector parametrizations + class FirstZero(nn.Module): + def forward(self, x): + return torch.cat([x.new_zeros(1), x[1:]]) + + class LastZero(nn.Module): + def forward(self, x): + return torch.cat([x[:-1], x.new_zeros(1)]) + + model = nn.Linear(8, 8) + initial_weight_id = id(model.weight) + initial_bias_id = id(model.bias) + initial_model = deepcopy(model) + + # Test unsafe flag + with self.assertRaisesRegex(ValueError, "Registering a parametrization may not change the shape of the tensor"): + parametrize.register_parametrization(model, "weight", Resize()) # default unsafe = False + model(torch.ones(8, 8)) + + # One parametrization with unsafe=True + parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + A = model.weight + self.assertTrue(A.shape[0] == 1) + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(model.__class__, nn.Linear) + + # Two parametrizations with unsafe=True + parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) + parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + A = model.weight + self.assertTrue(A.shape[0] == 1) + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(model.__class__, nn.Linear) + + # Test unsafe flag doesn't change expected behavior + parametrize.register_parametrization(model, "weight", Skew(), unsafe=True) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + # Result should be skew-symmetric + A = model.weight + self.assertEqual(A, -A.T) + # Remove and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(model.__class__, nn.Linear) + + # Test one parametrization + parametrize.register_parametrization(model, "weight", Skew()) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + # Result should be skew-symmetric + A = model.weight + self.assertEqual(A, -A.T) + # Remove and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(model.__class__, nn.Linear) + + # Test two parametrizations at the same time and removing them + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + # Result should be orthogonal + X = model.weight + Id = torch.eye(X.size(0), device=X.device) + self.assertEqual(X.T @ X, Id) + # Structure tests + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertIn("weight", model.parametrizations) + self.assertNotIn("weight", model._parameters) + # Remove + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertEqual(model.weight, initial_model.weight) + self.assertEqual(id(model.weight), initial_weight_id) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + + # Add everything + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + parametrize.register_parametrization(model, "bias", FirstZero()) + parametrize.register_parametrization(model, "bias", LastZero()) + + # Basic tests + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertTrue(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.) + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happpened + # Should not throw + + sgd = torch.optim.SGD(model.parameters(), lr=0.01) + + weight_copy = model.weight.clone() + bias_copy = model.bias.clone() + sgd.zero_grad() + (model.weight.T @ model.bias).sum().backward() + sgd.step() + self.assertNotEqual(model.weight, weight_copy) + self.assertNotEqual(model.bias, bias_copy) + + # Remove first parametrization. + # Check that the model is still parametrized and so is the second parameter + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized + self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed + self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized + self.assertEqual(model.bias[0].item(), 0.) # Still parametrized + self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized + self.assertNotEqual(model.weight, initial_model.weight) # Has been updated + self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened + # Should not throw + weight_copy = model.weight.clone() + bias_copy = model.bias.clone() + sgd.zero_grad() + (model.weight.T @ model.bias).sum().backward() + sgd.step() + self.assertNotEqual(model.weight, weight_copy) + self.assertNotEqual(model.bias, bias_copy) + + # Remove the second parametrization. + # Check that the module is not parametrized + parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) + self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized + self.assertNotEqual(model.bias, initial_model.bias) # Has been updated + self.assertNotEqual(model.bias[0].item(), 0.) # Not parametrized + self.assertNotEqual(model.bias[-1].item(), 0.) # Not parametrized + self.assertEqual(id(model.bias), initial_bias_id) # Keeps the same id + self.assertFalse(hasattr(model, "parametrizations")) # Not parametrized the module + self.assertEqual(model.__class__, nn.Linear) # Resores the previous class + self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happeed + + # Should not throw things are updated + weight_copy = model.weight.clone() + bias_copy = model.bias.clone() + sgd.zero_grad() + (model.weight.T @ model.bias).sum().backward() + sgd.step() + self.assertNotEqual(model.weight, weight_copy) + self.assertNotEqual(model.bias, bias_copy) + + # Test leave_parametrized=True + for _ in range(2): + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) + # We didn't change the dtype nor had multiple inputs, so the id should be the same + self.assertEqual(id(model.weight), initial_weight_id) + self.assertEqual(id(model.bias), initial_bias_id) + + # Should not throw. Things are updated + weight_copy = model.weight.clone() + bias_copy = model.bias.clone() + sgd.zero_grad() + (model.weight.T @ model.bias).sum().backward() + sgd.step() + self.assertNotEqual(model.weight, weight_copy) + self.assertNotEqual(model.bias, bias_copy) + + def test_register_and_remove_nested_parametrization(self): + r"""Test that it is possible to nest the parametrizations + meaning that the original param is parametrized again + """ + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + model = nn.Linear(8, 8) + # Add top level parametrization + parametrize.register_parametrization(model, "weight", Skew()) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + # Result should be skew-symmetric + A = model.weight + self.assertEqual(A, -A.T) + + # Add nested parametrization + param_mod = model.parametrizations.weight + self.assertFalse(hasattr(param_mod, "parametrizations")) + self.assertFalse(parametrize.is_parametrized(param_mod)) + self.assertFalse(parametrize.is_parametrized(param_mod, "original")) + + parametrize.register_parametrization(param_mod, "original", Skew()) + self.assertTrue(hasattr(param_mod, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(param_mod)) + self.assertTrue(parametrize.is_parametrized(param_mod, "original")) + self.assertNotIn("original", param_mod._parameters) + # Result should be skew-symmetric + A = param_mod.original + self.assertEqual(A, -A.T) + + # Remove nested param and check consistency + parametrize.remove_parametrizations(param_mod, "original", leave_parametrized=False) + self.assertFalse(hasattr(param_mod, "parametrizations")) + self.assertEqual(param_mod.__class__, parametrize.ParametrizationList) + + # Remove top level and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + + def test_register_and_remove_buffer_parametrization(self): + r"""Test that it is possible to add and remove parametrizations on buffers""" + # Define a couple vector parametrizations + class FirstZero(nn.Module): + def forward(self, x): + return torch.cat([x.new_zeros(1), x[1:]]) + + class LastZero(nn.Module): + def forward(self, x): + return torch.cat([x[:-1], x.new_zeros(1)]) + + model = nn.Linear(8, 8) + + # Instantiate parametrizations on buffers. It should work as expected + delattr(model, "bias") + model.register_buffer("bias", torch.ones(8)) + parametrize.register_parametrization(model, "bias", FirstZero()) + parametrize.register_parametrization(model, "bias", LastZero()) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.) + self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) + self.assertEqual(len(list(model.parameters())), 1) + + # Remove parametrizations on buffers. It should work as expected + parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) + self.assertFalse(parametrize.is_parametrized(model)) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertEqual(model.bias[0].item(), 0.) + self.assertEqual(model.bias[-1].item(), 0.) + self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) + self.assertEqual(len(list(model.parameters())), 1) + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_serialization_parametrization(self): + r"""Test that it is possible to serialize a parametrized model via state_dict""" + # A stateful parametrization + class Orthogonal(nn.Module): + def __init__(self, n): + super().__init__() + self.register_buffer("id", torch.eye(n)) + self.register_buffer("B", torch.empty(n, n)) + init.orthogonal_(self.B) + + def forward(self, X): + A = X.triu(1) + A = A - A.T + return self.B @ torch.linalg.solve(self.id + A, self.id - A) + + def get_model(): + model = torch.nn.Sequential( + torch.nn.Linear(5, 5), + torch.nn.ReLU(), + torch.nn.Linear(5, 1), + ) + + parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) + return model + + model = get_model() + + prev_weight = model[0].weight + prev_B = model[0].parametrizations.weight[0].B + + new_model = get_model() + with TemporaryFileName() as fname: + torch.save(model.state_dict(), fname) + new_model.load_state_dict(torch.load(fname)) + + # Integrity tests + self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) + self.assertEqual(prev_weight, new_model[0].weight) + self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) + + # Trying to save the whole parametrized model raises + with self.assertRaisesRegex(RuntimeError, "state_dict"): + with TemporaryFileName() as fname: + torch.save(model, fname) + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_initialization_parametrization(self): + r"""Test that it is possible to initialize a parametrization when it + implements a `right_inverse` method + """ + class Skew(nn.Module): + def forward(self, X): + A = X.triu(1) + return A - A.T + + def is_skew(self, A): + return torch.allclose(A, -A.T, atol=1e-6) + + def right_inverse(self, X): + if not self.is_skew(X): + raise ValueError("The matrix is not skew-symmetric.") + return X.triu(1) + + # Implements a Cayley map where right_inverse is not quite the inverse of forward + class Orthogonal(nn.Module): + def __init__(self, n): + super().__init__() + self.register_buffer("B", torch.eye(n)) + + def forward(self, X): + Id = torch.eye(X.size(0)) + return self.B @ torch.linalg.solve(Id + X, Id - X) + + def is_orthogonal(self, X): + Id = torch.eye(X.size(0)) + return torch.allclose(X.T @ X, Id, atol=1e-4) + + def right_inverse(self, X): + if not self.is_orthogonal(X): + raise ValueError("The input is not orthogonal.") + # cayley(0) == Id, so B @ cayley(0) == B + self.B = X + return torch.zeros_like(X) + + N = 5 + model = nn.Linear(N, N) + # Register the skew-symmetric constraint. The result is now skew-symmetric + skew = Skew() + # Make the weight skew-symmetric before registering the parametrization + with torch.no_grad(): + model.weight.set_(skew(model.weight)) + parametrize.register_parametrization(model, "weight", skew) + X = torch.rand(N, N) + # X is not skew-symmetric, so it throws an error + with self.assertRaises(ValueError): + model.weight = X + # Make X skew-symmetric + X = X - X.T + model.weight = X + self.assertEqual(model.parametrizations.weight.original, X.triu(1)) + self.assertEqual(model.weight, X) + + # Having several parametrizations registered should work in the same way + parametrize.register_parametrization(model, "weight", Orthogonal(N)) + # Register now the Cayley map. The result is now orthogonal + X = torch.rand(N, N) + # X is not orthogonal, so it throws an error + with self.assertRaises(ValueError): + model.weight = X + init.orthogonal_(X) + model.weight = X + self.assertEqual(model.weight, X) + self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) + + def test_errors_unparametrized_tensor_parametrization(self): + # Test errors when registering a parametrization on an unparametrized tensor + module = nn.Linear(3, 4) + weight_init = module.weight.clone() + + class Identity(nn.Module): + def forward(self, x): + return x + + # Register a parametrization on a non-existing parameter throws + with self.assertRaisesRegex(ValueError, "does not have a parameter"): + parametrize.register_parametrization(module, "foo", Identity()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Removing parametrizations from an unparametrized tensor throws + with self.assertRaisesRegex(ValueError, "does not have a parametrization"): + parametrize.remove_parametrizations(module, "bias") + self.assertFalse(parametrize.is_parametrized(module)) + + # A correct parametrization with several outputs + class Sum(nn.Module): + def forward(self, x, y): + return x + y + + def right_inverse(self, z): + return z, torch.zeros_like(z) + + parametrize.register_parametrization(module, "weight", Sum()) + # Cannot remove a parametrization with several outputs with `leave_parametrized=False` + with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): + parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) + + # A parametrization with an incorrect number of outputs + class WrongNumberParams(nn.Module): + def forward(self, x, y, z): + return x + y + z + + def right_inverse(self, w): + return w, torch.zeros_like(w) + + # Makes param(*param.right_inverse(X)) fail + with self.assertRaisesRegex(TypeError, "positional argument"): + parametrize.register_parametrization(module, "weight", WrongNumberParams()) + self.assertFalse(parametrize.is_parametrized(module)) + + # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor] + class WrongRightInverse(Identity): + def right_inverse(self, z): + return None + + # right_inverse should return a Tensor or a Sequence[Tensor] + with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"): + parametrize.register_parametrization(module, "weight", WrongRightInverse()) + self.assertFalse(parametrize.is_parametrized(module)) + + # If it's a sequence, it must to be a sequence of tensors + class WrongRightInverseSequence(nn.Module): + def forward(self, x, y): + return x + + def right_inverse(self, z): + return None, z + + with self.assertRaisesRegex(ValueError, "of the sequence with type"): + parametrize.register_parametrization(module, "weight", WrongRightInverseSequence()) + self.assertFalse(parametrize.is_parametrized(module)) + + # A parametrization from one tensor to one tensor that changes the dtype + class ChangeDtypeInverse(nn.Module): + def forward(self, x): + return x.float() + + def right_inverse(self, w): + return w.bool() + + # For parametrizations that return one tensor, right_inverse may not change the dtype + with self.assertRaisesRegex(ValueError, "outputs one tensor, it may not change the dtype"): + parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Doesn't return a tensor + class NotTensor(nn.Module): + def forward(self, x): + return 2 + + # Forward must return a tensor + with self.assertRaisesRegex(ValueError, "must return a tensor"): + parametrize.register_parametrization(module, "weight", NotTensor()) + self.assertFalse(parametrize.is_parametrized(module)) + + # A parametrization from one tensor to one tensor that changes the dtype + class ChangeDtype(nn.Module): + def forward(self, x): + return x.bool() + + # forward should not change the initial dtype + with self.assertRaisesRegex(ValueError, "may not change the dtype"): + parametrize.register_parametrization(module, "weight", ChangeDtype()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Change shape + class ChangeShape(nn.Module): + def forward(self, x): + return x[:-1] + + # forward should not change the original shape + with self.assertRaisesRegex(ValueError, "may not change the shape"): + parametrize.register_parametrization(module, "weight", ChangeShape()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Many to one that changes dtype + class ChangeDtypeMulti(nn.Module): + def forward(self, x, y): + return (x + y).bool() + + def right_inverse(self, w): + return w, w + 1 + + # forward should not change the original shape even for parametrizations with many inputs + with self.assertRaisesRegex(ValueError, "may not change the dtype"): + parametrize.register_parametrization(module, "weight", ChangeDtypeMulti()) + self.assertFalse(parametrize.is_parametrized(module)) + + # Returning a sequence of size one, although weird, it's correct + class SequenceLen1(nn.Module): + def forward(self, x): + return x + + def right_inverse(self, w): + return (w,) + + parametrize.register_parametrization(module, "weight", SequenceLen1()) + self.assertTrue(hasattr(module.parametrizations.weight, "original0")) + self.assertFalse(hasattr(module.parametrizations.weight, "original1")) + _ = module.weight # Does not throw + self.assertTrue(parametrize.is_parametrized(module)) + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) + + # None of the operations above should have altered the weight + self.assertFalse(parametrize.is_parametrized(module)) + self.assertEqual(module.weight, weight_init) + + def test_errors_parametrized_tensor_parametrization(self): + # Test errors when registering a parametrization on a parametrized tensor + + class Identity(nn.Module): + def forward(self, x): + return x + + module = nn.Linear(3, 4) + parametrize.register_parametrization(module, "weight", Identity()) + + # Has to return a tensor + class WrongReturn(nn.Module): + def forward(self, x): + return x, x + + with self.assertRaisesRegex(ValueError, "must return a tensor"): + parametrize.register_parametrization(module, "weight", WrongReturn()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # Cannot change dtype + class ChangeDtype(nn.Module): + def forward(self, x): + return x.bool() + + with self.assertRaisesRegex(ValueError, "may not change the dtype"): + parametrize.register_parametrization(module, "weight", ChangeDtype()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # Cannot change shape + class ChangeShape(nn.Module): + def forward(self, x): + return x[:-1] + + with self.assertRaisesRegex(ValueError, "may not change the shape"): + parametrize.register_parametrization(module, "weight", ChangeShape()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # The following checks are mostly due to bugs in the code of the parametrization + + # right_inverse has to return a tensor + class WrongReturnInverse(Identity): + def right_inverse(self, x): + return x, x + + with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"): + parametrize.register_parametrization(module, "weight", WrongReturnInverse()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # Cannot change dtype + class ChangeDtypeInverse(Identity): + def right_inverse(self, x): + return x.bool() + + with self.assertRaisesRegex(ValueError, "must have the same dtype"): + parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # Cannot change shape + class ChangeShapeInverse(Identity): + def right_inverse(self, x): + return x[:-1] + + with self.assertRaisesRegex(ValueError, "must have the same shape"): + parametrize.register_parametrization(module, "weight", ChangeShapeInverse()) + self.assertTrue(parametrize.is_parametrized(module)) + self.assertEqual(len(module.parametrizations.weight), 1) + self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_multiple_inputs_parametrization(self): + # A parametrization with several outputs + class RankOne(nn.Module): + def forward(self, x, y): + # Form a rank-1 matrix from a pair of vectors + return x.unsqueeze(-1) @ y.unsqueeze(-2) + + def right_inverse(self, Y): + # We project the given matrix onto the rank 1 matrices + U, S, Vh = torch.linalg.svd(Y, full_matrices=False) + # S is ordered in a decreasing way. + s0_sqrt = S[0].sqrt().unsqueeze(-1) + return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + + # Simple parametrisation + class Double(nn.Module): + def forward(self, x): + return 2.0 * x + + def right_inverse(self, w): + return 0.5 * w + + model = nn.Linear(3, 3) + # Test one parametrization + parametrize.register_parametrization(model, "weight", RankOne()) + self.assertTrue(hasattr(model, "parametrizations")) + self.assertTrue(parametrize.is_parametrized(model)) + self.assertTrue(parametrize.is_parametrized(model, "weight")) + self.assertTrue(hasattr(model.parametrizations.weight, "original0")) + self.assertIn("original0", model.parametrizations.weight._parameters) + self.assertTrue(hasattr(model.parametrizations.weight, "original1")) + self.assertIn("original1", model.parametrizations.weight._parameters) + self.assertFalse(parametrize.is_parametrized(model, "bias")) + self.assertNotIn("weight", model._parameters) + # Result should be rank 1 + self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) + + with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): + # Cannot remove a parametrization with multiple inputs and not leave it parametrized + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + # Remove parametrization and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + self.assertFalse(parametrize.is_parametrized(model)) + self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) + self.assertIn("weight", model._parameters) + + # Registering parametrizations with one input on top of one with multiple inputs should work + init_weight = model.weight.clone() + parametrize.register_parametrization(model, "weight", RankOne()) + # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix + self.assertEqual(init_weight, model.weight) + parametrize.register_parametrization(model, "weight", Double()) + # The matrix now is twice the initial matrix + self.assertEqual(2.0 * init_weight, model.weight) + # Multiplying by a scalar does not change the rank + self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) + + # The model has now three parameters + self.assertEqual(len(list(model.parameters())), 3) + + sgd = torch.optim.SGD(model.parameters(), lr=0.1) + + # Test backward. Should not throw + for _ in range(2): + sgd.zero_grad() + loss = (model.weight.T @ model.bias).sum() + loss.backward() + sgd.step() + + # Same drill as before, removing should work as expected + with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): + # Cannot remove a parametrization with multiple inputs and not leave it parametrized + parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) + # Remove parametrization and check consistency + parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) + self.assertFalse(hasattr(model, "parametrizations")) + self.assertEqual(model.__class__, nn.Linear) + self.assertFalse(parametrize.is_parametrized(model)) + self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) + self.assertIn("weight", model._parameters) + + # The model has now two parameters + self.assertEqual(len(list(model.parameters())), 2) + + # Test backward. Should not throw + sgd = torch.optim.SGD(model.parameters(), lr=0.1) + for _ in range(2): + sgd.zero_grad() + loss = (model.weight.T @ model.bias).sum() + loss.backward() + sgd.step() + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_caching_parametrization(self): + r"""Test the caching system of a parametrization""" + # Define a couple matrix parametrizations + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + class Orthogonal(nn.Module): + def forward(self, X): + Id = torch.eye(X.size(0), device=X.device) + return torch.linalg.solve(Id + X, Id - X) + + model = nn.Linear(5, 5) + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + + # Test that the caching system works + with parametrize.cached(): + X = model.weight + Y = model.weight + self.assertEqual(id(X), id(Y)) + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_caching_parametrization_with_transfer_parametrizations_and_params(self): + r"""Test that transferring parametrizations doesn't cause issues with caching""" + class Skew(nn.Module): + def forward(self, X): + X = X.tril(-1) + return X - X.T + + class Orthogonal(nn.Module): + def forward(self, X): + Id = torch.eye(X.size(0), device=X.device) + return torch.linalg.solve(Id + X, Id - X) + + model = nn.Linear(5, 5) + parametrize.register_parametrization(model, "weight", Skew()) + parametrize.register_parametrization(model, "weight", Orthogonal()) + + to_model = nn.Linear(5, 5) + parametrize.transfer_parametrizations_and_params(model, to_model) + + with parametrize.cached(): + X = model.weight + Y = model.weight + self.assertEqual(id(X), id(Y)) + + A = to_model.weight + B = to_model.weight + self.assertEqual(id(A), id(B)) + + # test that the results are distinct objects for each module + self.assertNotEqual(id(A), id(X)) + + def test_parametrization_same_training_mode(self): + r"""Test training mode updated on parametrization registration""" + class Identity(nn.Module): + def forward(self, X): + return X + + module = nn.Linear(4, 4) + module.eval() + parametrize.register_parametrization(module, "weight", Identity()) + self.assertFalse(module.parametrizations.weight[0].training) + module.train() + parametrize.register_parametrization(module, "weight", Identity().eval()) + self.assertTrue(module.parametrizations.weight[0].training) + self.assertTrue(module.parametrizations.weight[1].training) + + def test_type_before_parametrizations(self): + r"""Test that type_before_parametrizations always retrieves original type""" + + class Identity(nn.Module): + def forward(self, X): + return X + + model = nn.Linear(5, 5) + original_type = type(model) + self.assertTrue( + parametrize.type_before_parametrizations(model) == original_type + ) + parametrize.register_parametrization(model, "weight", Identity()) + self.assertTrue( + parametrize.type_before_parametrizations(model) == original_type + ) + + def test_deepcopy_after_parametrization(self): + r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" + + class AddOne(nn.Module): + def forward(self, x): + return x + 1.0 + + class ModelWithoutDeepcopy(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True) + self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True) + self.attr = [1.0, 2.0, 3.0, 4.0] + + class ActualModel(ModelWithoutDeepcopy): + # Emulate custom implementation of the deepcopying. + def __deepcopy__(self, memo): + result = self.__new__(self.__class__) + memo[id(self)] = result + result.__dict__ = deepcopy(self.__dict__, memo) + return result + + def check_deepcopy(m1: nn.Module, m2: nn.Module): + w1 = m1.parametrizations.weight.original + w2 = m2.parametrizations.weight.original + b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias + b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias + # Weights, biases and attributes should be equal but they must be different objects. + self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys()) + self.assertIsNot(m1, m2) + self.assertEqual(w1, w2) + self.assertIsNot(w1, w2) + self.assertEqual(b1, b2) + self.assertIsNot(b1, b2) + self.assertEqual(m1.attr, m2.attr) + self.assertIsNot(m1.attr, m2.attr) + + for model in (ModelWithoutDeepcopy(), ActualModel()): + # General check that we are able to create deepcopy. + parametrize.register_parametrization(model, "weight", AddOne()) + check_deepcopy(model, deepcopy(model)) + # Check that this works on models with several parametrized tensors. + parametrize.register_parametrization(model, "bias", AddOne()) + check_deepcopy(model, deepcopy(model)) + # Check that this works on models where tensors have more than one parametrization. + parametrize.register_parametrization(model, "weight", AddOne()) + check_deepcopy(model, deepcopy(model)) + + def test_transfer_parametrizations_and_params(self): + r"""Test that all parametrizations and their associated parameters are transferred.""" + + class AddOne(nn.Module): + def forward(self, x): + return x + 1.0 + + class Double(nn.Module): + def forward(self, x): + return 2.0 * x + + def right_inverse(self, x): + return 0.5 * x + + class MinusOne(nn.Module): + def forward(self, x): + return x - 1.0 + + model = nn.Linear(5, 5) + parametrize.register_parametrization(model, "weight", AddOne()) + parametrize.register_parametrization(model, "weight", Double()) + parametrize.register_parametrization(model, "weight", MinusOne()) + hold_weight = model.weight + + to_model = torch.ao.nn.qat.Linear( + 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() + ) + parametrize.transfer_parametrizations_and_params(model, to_model) + + # checks that final and original value are correct and the to_model is parametrized + self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) + self.assertEqual(model.weight, to_model.weight) + self.assertEqual( + model.parametrizations.weight.original, + to_model.parametrizations.weight.original, + ) + + # check that the transfer didn't affect the original value + self.assertEqual(hold_weight, model.weight) + + # testing that changes to one set of parametrizations do not affect the other + parametrize.remove_parametrizations(to_model, "weight") + self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) + self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight")) + + # also test that parameters that don't exist in to_model get transferred + model.test_param = Parameter(torch.randn(5, 5)) + + self.assertTrue(not hasattr(to_model, "test_param")) + parametrize.register_parametrization(model, "test_param", Double()) + hold_test_param = model.test_param + parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") + + # check that previously missing params got transferred correctly + self.assertEqual(model.test_param, to_model.test_param) + self.assertEqual( + model.parametrizations.test_param.original, + to_model.parametrizations.test_param.original, + ) + + # check that the new transfer didn't change the value for the from_module + self.assertEqual(hold_test_param, model.test_param) + + def test_transfer_parametrizations_and_params_right_inverse(self): + r"""Test that all parametrizations and their associated parameters are transferred.""" + + class Double(nn.Module): + def forward(self, x): + return 2.0 * x + + def right_inverse(self, x): + return 0.5 * x + + model = nn.Linear(5, 5) + parametrize.register_parametrization(model, "weight", Double()) + hold_weight = model.weight + + to_model = torch.ao.nn.qat.Linear( + 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() + ) + parametrize.transfer_parametrizations_and_params(model, to_model) + + # check that transfer occurs successfully + self.assertEqual(model.weight, to_model.weight) + self.assertEqual( + model.parametrizations.weight.original, + to_model.parametrizations.weight.original, + ) + + # check that transfer doesn't affect the from_model weight + self.assertEqual(hold_weight, model.weight) + + def test_transfer_parametrizations_and_params_single_param(self): + r"""Test that all parametrizations and their associated parameters are transferred.""" + + class AddOne(nn.Module): + def forward(self, x): + return x + 1.0 + + class Double(nn.Module): + def forward(self, x): + return 2.0 * x + + class MinusOne(nn.Module): + def forward(self, x): + return x - 1.0 + + model = nn.Linear(5, 5, bias=True) + parametrize.register_parametrization(model, "weight", AddOne()) + parametrize.register_parametrization(model, "weight", Double()) + parametrize.register_parametrization(model, "weight", MinusOne()) + parametrize.register_parametrization(model, "bias", AddOne()) + parametrize.register_parametrization(model, "bias", Double()) + parametrize.register_parametrization(model, "bias", MinusOne()) + + to_model = torch.ao.nn.qat.Linear( + 5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig() + ) + parametrize.transfer_parametrizations_and_params(model, to_model, "weight") + + # check that weight and only weight was transferred + self.assertEqual(model.weight, to_model.weight) + self.assertEqual( + model.parametrizations.weight.original, + to_model.parametrizations.weight.original, + ) + self.assertTrue("bias" not in to_model.parametrizations) + + # FIXME: Rewrite this test using functions not depending on LAPACK + # and remove the `@skipIfNoLapack` (see #70995) + @skipIfNoLapack + def test_transfer_parametrizations_and_params_many_to_one(self): + # A parametrization with several outputs + class RankOne(nn.Module): + def forward(self, x, y): + # Form a rank-1 matrix from a pair of vectors + return x.unsqueeze(-1) @ y.unsqueeze(-2) + + def right_inverse(self, Y): + # We project the given matrix onto the rank 1 matrices + U, S, Vh = torch.linalg.svd(Y, full_matrices=False) + # S is ordered in a decreasing way. + s0_sqrt = S[0].sqrt().unsqueeze(-1) + return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + + class Double(nn.Module): + def forward(self, x): + return 2.0 * x + + model = nn.Linear(3, 3) + parametrize.register_parametrization(model, "weight", RankOne()) + parametrize.register_parametrization(model, "weight", Double()) + hold_weight = model.weight + + to_model = torch.ao.nn.qat.Linear( + 3, 3, qconfig=torch.ao.quantization.get_default_qconfig() + ) + + parametrize.transfer_parametrizations_and_params(model, to_model) + + # checks that final and original value are correct and the to_model is parametrized + self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) + self.assertEqual(model.weight, to_model.weight) + self.assertEqual( + model.parametrizations.weight.original0, + to_model.parametrizations.weight.original0, + ) + self.assertEqual( + model.parametrizations.weight.original1, + to_model.parametrizations.weight.original1, + ) + + # check that the transfer didn't affect the original value + self.assertEqual(hold_weight, model.weight) + + # testing that changes to one set of parametrizations do not affect the other + model.test_param = Parameter(torch.randn(3, 3)) + + self.assertTrue(not hasattr(to_model, "test_param")) + parametrize.register_parametrization(model, "test_param", RankOne()) + hold_test_param = model.test_param + parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") + + # also check that previously missing params got transferred correctly + self.assertEqual(model.test_param, to_model.test_param) + self.assertEqual( + model.parametrizations.test_param.original0, + to_model.parametrizations.test_param.original0, + ) + self.assertEqual( + model.parametrizations.test_param.original1, + to_model.parametrizations.test_param.original1, + ) + + # check that the new transfer didn't change the value for the from_module + self.assertEqual(hold_test_param, model.test_param) + + def test_new_spectral_norm(self): + with set_default_dtype(torch.double): + input = torch.randn(3, 5) + m = nn.Linear(5, 7) + m = torch.nn.utils.parametrizations.spectral_norm(m) + spectral_norm_m = m.parametrizations.weight[0] + + self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)])) + + # .parametrizations.weight.original should be trainable + self.assertTrue(hasattr(m.parametrizations.weight, 'original')) + self.assertTrue('original' in m.parametrizations.weight._parameters) + + # u should be just a reused buffer + self.assertTrue(hasattr(spectral_norm_m, '_u')) + self.assertTrue('_u' in spectral_norm_m._buffers) + self.assertTrue('_v' in spectral_norm_m._buffers) + + # weight should be a plain attribute, not counted as a buffer or a param + self.assertIsNotNone(m.weight) + self.assertFalse('weight' in m._buffers) + self.assertFalse('weight' in m._parameters) + + # it should also be sharing storage as `weight_orig` + # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage()) + self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size()) + self.assertEqual(m.parametrizations.weight.original.stride(), m.weight.stride()) + + m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') + + # spectral_norm is the only parametrization + self.assertFalse(hasattr(m, 'parametrizations')) + self.assertTrue('weight' in m._parameters) + + # We can register spectral_norm multiple times on the same parameter + # and on multiple parameters in the same module + m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') + m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') + m = torch.nn.utils.parametrizations.spectral_norm(m, 'bias') + + # If we remove the parametrization on bias, weight is still parametrized + # Removing a parametrization runs forward in eval mode if leave_parametrized=True + m = torch.nn.utils.parametrize.remove_parametrizations(m, 'bias') + self.assertTrue('bias' in m._parameters) + self.assertTrue(hasattr(m, 'parametrizations')) + self.assertFalse('weight' in m._parameters) + + m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') + # Neither weight and bias are parametrized + self.assertFalse(hasattr(m, 'parametrizations')) + self.assertTrue('weight' in m._parameters) + self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m)) + + # test correctness in training/eval modes and cpu/multi-gpu settings + for apply_dp in (True, False): + if apply_dp: + if not TEST_MULTIGPU: + continue + device = torch.device('cuda:0') + + def maybe_wrap(m): + return torch.nn.DataParallel(m, [0, 1]) + else: + device = torch.device('cpu') + + def maybe_wrap(m): + return m + + for requires_grad in (True, False): + def get_modules(): + m = nn.Linear(3, 4).to(device) + m.weight.requires_grad_(requires_grad) + m = torch.nn.utils.parametrizations.spectral_norm(m) + wrapped_m = maybe_wrap(m) + spectral_norm_m = m.parametrizations.weight[0] + return m, wrapped_m, spectral_norm_m + + input = torch.randn(2, 3, device=device) + + m, wrapped_m, spectral_norm_m = get_modules() + + self.assertTrue(hasattr(spectral_norm_m, '_u')) + u0 = spectral_norm_m._u.clone() + v0 = spectral_norm_m._v.clone() + + # TEST TRAINING BEHAVIOR + + # We perform GD first to modify the initial matrix + opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1) + + opt.zero_grad() + wrapped_m(input).sum().backward() + opt.step() + + out = wrapped_m(input) + if requires_grad: + # run forward again and assert that u and v are updated + self.assertNotEqual(u0, spectral_norm_m._u) + self.assertNotEqual(v0, spectral_norm_m._v) + + # assert that backprop reaches original weight + # can't use gradcheck because the function changes as we + # activate through it in training mode + if requires_grad: + torch.autograd.grad(out.sum(), m.parametrizations.weight.original) + + # test backward works with multiple forwards + # it uses training mode so we need to reset `u` and `v` vectors + # to same value at beginning for finite difference test to pass + saved_u = spectral_norm_m._u.clone() + saved_v = spectral_norm_m._v.clone() + + def fn(input): + spectral_norm_m._u.data.copy_(saved_u) + spectral_norm_m._v.data.copy_(saved_v) + out0 = wrapped_m(input) + out1 = wrapped_m(input) + return out0 + out1 + + # Make sure we can compute gradients wrt to all the parameters in the case + # of double forward + fn(input.clone().requires_grad_()).sum().backward() + gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False) + + # test removing + # spectral norm module needs to be in eval mode if we'd like to + # avoid doing another power iteration + m, wrapped_m, _ = get_modules() + pre_remove_out = wrapped_m(input) + m.eval() + m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') + self.assertEqual(wrapped_m(input), pre_remove_out) + + torch.nn.utils.parametrizations.spectral_norm(m) + for _ in range(3): + pre_remove_out = wrapped_m(input) + m.eval() + m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') + self.assertEqual(wrapped_m(input), pre_remove_out) + + # TEST EVAL BEHAVIOR + m, wrapped_m, spectral_norm_m = get_modules() + wrapped_m(input) + last_train_out = wrapped_m(input) + last_train_u = spectral_norm_m._u.clone() + last_train_v = spectral_norm_m._v.clone() + wrapped_m.zero_grad() + wrapped_m.eval() + + eval_out0 = wrapped_m(input) + # assert eval gives same result as last training iteration + self.assertEqual(eval_out0, last_train_out) + # assert doing more iteartion in eval don't change things + self.assertEqual(eval_out0, wrapped_m(input)) + self.assertEqual(last_train_u, spectral_norm_m._u) + self.assertEqual(last_train_v, spectral_norm_m._v) + + # FIXME: the code below is flaky when executed with DataParallel + # see https://github.com/pytorch/pytorch/issues/13818 + if apply_dp: + continue + + # test backward works with multiple forwards in mixed training + # and eval modes + # it uses training mode so we need to reset `u` and `v` vectors + # to same value at beginning for finite difference test to pass + saved_u = spectral_norm_m._u.clone() + saved_v = spectral_norm_m._v.clone() + + def fn(input): + spectral_norm_m._u.data.copy_(saved_u) + spectral_norm_m._v.data.copy_(saved_v) + wrapped_m.train() + out0 = wrapped_m(input) + wrapped_m.eval() + out1 = wrapped_m(input) + wrapped_m.train() + out2 = wrapped_m(input) + wrapped_m.eval() + out3 = wrapped_m(input) + return out0 + out1 + out2 + out3 + + gradcheck(fn, (input.clone().requires_grad_(),)) + + # assert that backprop reaches weight_orig in eval + if requires_grad: + def fn(weight): + return wrapped_m(input) + + gradcheck(fn, (m.parametrizations.weight.original,)) + + def test_new_spectral_norm_load_state_dict(self): + for activate_times in (0, 3): + inp = torch.randn(2, 3) + m = nn.Linear(3, 5) + snm = torch.nn.utils.parametrizations.spectral_norm(m) + snm.train() + + for _ in range(activate_times): + snm(inp) + + state_dict = deepcopy(snm.state_dict()) + self.assertEqual({ + 'parametrizations.weight.original', + 'bias', + 'parametrizations.weight.0._v', + 'parametrizations.weight.0._u' + }, set(state_dict.keys())) + + # test that non-strict loading works + non_strict_state_dict = deepcopy(state_dict) + non_strict_state_dict['nonsense'] = 'nonsense' + with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'): + snm.load_state_dict(non_strict_state_dict, strict=True) + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict['parametrizations.weight.original'] + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict['parametrizations.weight.0._u'] + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict['parametrizations.weight.0._v'] + snm.load_state_dict(non_strict_state_dict, strict=False) + non_strict_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict._metadata['parametrizations.weight.0'] # remove metadata info + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict['weight'] # remove W buffer + snm.load_state_dict(non_strict_state_dict, strict=False) + del non_strict_state_dict['bias'] + snm.load_state_dict(non_strict_state_dict, strict=False) + + # normal state_dict + + # test that re-wrapping does not matter + m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') + snm = torch.nn.utils.parametrizations.spectral_norm(m) + + snm.load_state_dict(state_dict) + with torch.no_grad(): + snm.eval() + out0_eval = snm(inp) + snm.train() + out1_train = snm(inp) + out2_train = snm(inp) + snm.eval() + out3_eval = snm(inp) + + # test that re-wrapping does not matter + m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') + snm = torch.nn.utils.parametrizations.spectral_norm(m) + + # Test normal loading + snm.load_state_dict(state_dict) + with torch.no_grad(): + snm.eval() + self.assertEqual(out0_eval, snm(inp)) + snm.train() + self.assertEqual(out1_train, snm(inp)) + self.assertEqual(out2_train, snm(inp)) + snm.eval() + self.assertEqual(out3_eval, snm(inp)) + + def test_new_spectral_norm_dim(self): + inp = torch.randn(2, 3, 10, 12) + m = nn.ConvTranspose2d(3, 4, (5, 6)) + m = torch.nn.utils.parametrizations.spectral_norm(m) + snm = m.parametrizations.weight[0] + # this should not run into incompatible shapes + x = m(inp) + # check that u refers to the same dimension + self.assertEqual(snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape) + + def test_new_spectral_norm_forward(self): + input = torch.randn(3, 5) + m = nn.Linear(5, 7) + m = torch.nn.utils.parametrizations.spectral_norm(m) + snm = m.parametrizations.weight[0] + # naive forward + _weight = m.parametrizations.weight.original + _bias, _v = m.bias, snm._v + _weight_mat = _weight.view(_weight.size(0), -1) + _u = torch.mv(_weight_mat, _v) + _u = F.normalize(_u, dim=0, eps=1e-12) + _v = torch.mv(_weight_mat.t(), _u) + _v = F.normalize(_v, dim=0, eps=1e-12) + _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v)) + out_hat = torch.nn.functional.linear(input, _weight, _bias) + expect_out = m(input) + self.assertEqual(expect_out, out_hat) + + @skipIfNoLapack + def test_orthogonal_parametrization(self): + # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization) + + def assert_is_orthogonal(X): + n, k = X.size(-2), X.size(-1) + if n < k: + X = X.mT + n, k = k, n + Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k) + eps = 10 * n * torch.finfo(X.dtype).eps + torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.) + + def assert_weight_allclose_Q(weight, W): + # Test that weight is equal to the Q part of the QR decomposition of W + # (or of its transpose if the matrix is wide) + wide_matrix = W.size(-2) < W.size(-1) + if wide_matrix: + W = W.mT + Q, R = torch.linalg.qr(W) + Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) + if wide_matrix: + Q = Q.mT + torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.) + + for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)), # square/ tall / wide + (torch.float32, torch.complex64), + (True, False)): + # Conv2d does not support complex yet + if not use_linear: + continue + + if use_linear: + input = torch.randn(3, shape[0], dtype=dtype) + else: + input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype) + + for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"), + (False, True)): + # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False + # See Note [right_inverse expm cayley] + can_initialize = use_trivialization or parametrization == "householder" + + # We generate them every time to always start with fresh weights + if use_linear: + m = nn.Linear(*shape, dtype=dtype) + else: + m = nn.Conv2d(2, 3, shape, dtype=dtype) + + # We do not support householder for complex inputs + # See Note [Householder complex] + w_init = m.weight.clone() + if parametrization == "householder" and m.weight.is_complex(): + msg = "householder parametrization does not support complex tensors" + with self.assertRaisesRegex(ValueError, msg): + torch.nn.utils.parametrizations.orthogonal(m, + "weight", + parametrization, + use_trivialization=use_trivialization) + continue + + wide_matrix = w_init.size(-2) < w_init.size(-1) + torch.nn.utils.parametrizations.orthogonal(m, + "weight", + parametrization, + use_trivialization=use_trivialization) + # Forwards works as expected + self.assertEqual(w_init.shape, m.weight.shape) + assert_is_orthogonal(m.weight) + if can_initialize: + assert_weight_allclose_Q(m.weight, w_init) + + # Intializing with a given orthogonal matrix works + X = torch.randn_like(m.weight) + if wide_matrix: + X = X.mT + w_new = torch.linalg.qr(X).Q + if wide_matrix: + w_new = w_new.mT + if can_initialize: + m.weight = w_new + torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.) + else: + msg = "assign to the matrix exponential or the Cayley parametrization" + with self.assertRaisesRegex(NotImplementedError, msg): + m.weight = w_new + + # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix + w_new = torch.randn_like(m.weight) + if can_initialize: + m.weight = w_new + assert_weight_allclose_Q(m.weight, w_new) + else: + msg = "assign to the matrix exponential or the Cayley parametrization" + with self.assertRaisesRegex(NotImplementedError, msg): + m.weight = w_new + + opt = torch.optim.SGD(m.parameters(), lr=0.1) + for _ in range(2): + opt.zero_grad() + m(input).norm().backward() + grad = m.parametrizations.weight.original.grad + self.assertIsNotNone(grad) + # We do not update the upper triangular part of the matrix if tall tril if wide + if grad.size(-2) >= grad.size(-1): + zeros_grad = grad.triu(1) + else: + zeros_grad = grad.tril(-1) + self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad)) + # The gradient in the diagonal can only be imaginary because a skew-Hermitian + # matrix has imaginary diagonal + diag_grad = grad.diagonal(dim1=-2, dim2=-1) + if grad.is_complex(): + diag_grad = diag_grad.real + self.assertEqual(diag_grad, torch.zeros_like(diag_grad)) + opt.step() + assert_is_orthogonal(m.weight) + + @skipIfNoLapack + def test_orthogonal_errors(self): + m = nn.Linear(3, 4) + with self.assertRaisesRegex(ValueError, "has to be one of"): + torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo") + + with self.assertRaisesRegex(ValueError, "Expected a matrix"): + torch.nn.utils.parametrizations.orthogonal(m, "bias") + + torch.nn.utils.parametrizations.orthogonal(m, "weight") + with self.assertRaisesRegex(ValueError, "matrices of shape"): + m.weight = torch.randn(5, 5) + torch.nn.utils.parametrize.remove_parametrizations(m, "weight") + + +instantiate_parametrized_tests(TestNNParametrization) + +if __name__ == '__main__': + run_tests() diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index 073269a7c5539..3826b1dff70c3 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -130,6 +130,9 @@ def test_avg_pool3d_ceil_mode(self): class TestPoolingNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + def test_adaptive_pooling_input_size(self): for numel in (2, 3): for pool_type in ('Max', 'Avg'): @@ -410,6 +413,30 @@ def test_FractionalMaxPool3d_zero_out_size(self, device): out = mod(inp) self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device)) + @onlyNativeDeviceTypes + def test_FractionalMaxPool2d_zero_samples(self, device): + samples = torch.rand([0, 16, 2], device=device) + mod = nn.FractionalMaxPool2d([2, 2], output_size=[1, 1], _random_samples=samples) + inp = torch.randn([0, 16, 32, 32], device=device) + out = mod(inp) + self.assertEqual(out, torch.empty((0, 16, 1, 1), device=device)) + + inp1 = torch.randn([1, 16, 32, 32], device=device) + with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"): + out1 = mod(inp1) + + @onlyNativeDeviceTypes + def test_FractionalMaxPool3d_zero_samples(self, device): + samples = torch.rand([0, 16, 3], device=device) + mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[1, 1, 1], _random_samples=samples) + inp = torch.randn([0, 16, 50, 32, 32], device=device) + out = mod(inp) + self.assertEqual(out, torch.empty((0, 16, 1, 1, 1), device=device)) + + inp1 = torch.randn([1, 16, 50, 32, 32], device=device) + with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"): + out1 = mod(inp1) + @onlyNativeDeviceTypes def test_MaxPool_zero_batch_dim(self, device): inp = torch.randn(0, 16, 50, device=device) diff --git a/test/nn/test_pruning.py b/test/nn/test_pruning.py new file mode 100644 index 0000000000000..bd2db02d056fc --- /dev/null +++ b/test/nn/test_pruning.py @@ -0,0 +1,939 @@ +# Owner(s): ["module: nn"] +import unittest +import unittest.mock as mock +import pickle + +import torch + +import torch.nn as nn +import torch.nn.utils.prune as prune +from torch.testing._internal.common_utils import TEST_NUMPY, TemporaryFileName, \ + instantiate_parametrized_tests, run_tests +from torch.testing._internal.common_nn import NNTestCase + +class TestPruningNN(NNTestCase): + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + # torch/nn/utils/prune.py + @unittest.skipIf(not TEST_NUMPY, "numpy not found") + def test_validate_pruning_amount_init(self): + r"""Test the first util function that validates the pruning + amount requested by the user the moment the pruning method + is initialized. This test checks that the expected errors are + raised whenever the amount is invalid. + The original function runs basic type checking + value range checks. + It doesn't check the validity of the pruning amount with + respect to the size of the tensor to prune. That's left to + `_validate_pruning_amount`, tested below. + """ + # neither float not int should raise TypeError + with self.assertRaises(TypeError): + prune._validate_pruning_amount_init(amount="I'm a string") + + # float not in [0, 1] should raise ValueError + with self.assertRaises(ValueError): + prune._validate_pruning_amount_init(amount=1.1) + with self.assertRaises(ValueError): + prune._validate_pruning_amount_init(amount=20.) + + # negative int should raise ValueError + with self.assertRaises(ValueError): + prune._validate_pruning_amount_init(amount=-10) + + # all these should pass without errors because they're valid amounts + prune._validate_pruning_amount_init(amount=0.34) + prune._validate_pruning_amount_init(amount=1500) + prune._validate_pruning_amount_init(amount=0) + prune._validate_pruning_amount_init(amount=0.) + prune._validate_pruning_amount_init(amount=1) + prune._validate_pruning_amount_init(amount=1.) + self.assertTrue(True) + + @unittest.skipIf(not TEST_NUMPY, "numpy not found") + def test_validate_pruning_amount(self): + r"""Tests the second util function that validates the pruning + amount requested by the user, this time with respect to the size + of the tensor to prune. The rationale is that if the pruning amount, + converted to absolute value of units to prune, is larger than + the number of units in the tensor, then we expect the util function + to raise a value error. + """ + # if amount is int and amount > tensor_size, raise ValueError + with self.assertRaises(ValueError): + prune._validate_pruning_amount(amount=20, tensor_size=19) + + # amount is a float so this should not raise an error + prune._validate_pruning_amount(amount=0.3, tensor_size=0) + + # this is okay + prune._validate_pruning_amount(amount=19, tensor_size=20) + prune._validate_pruning_amount(amount=0, tensor_size=0) + prune._validate_pruning_amount(amount=1, tensor_size=1) + self.assertTrue(True) + + @unittest.skipIf(not TEST_NUMPY, "numpy not found") + def test_compute_nparams_to_prune(self): + r"""Test that requested pruning `amount` gets translated into the + correct absolute number of units to prune. + """ + self.assertEqual( + prune._compute_nparams_toprune(amount=0, tensor_size=15), + 0 + ) + self.assertEqual( + prune._compute_nparams_toprune(amount=10, tensor_size=15), + 10 + ) + # if 1 is int, means 1 unit + self.assertEqual( + prune._compute_nparams_toprune(amount=1, tensor_size=15), + 1 + ) + # if 1. is float, means 100% of units + self.assertEqual( + prune._compute_nparams_toprune(amount=1., tensor_size=15), + 15 + ) + self.assertEqual( + prune._compute_nparams_toprune(amount=0.4, tensor_size=17), + 7 + ) + + def test_random_pruning_sizes(self): + r"""Test that the new parameters and buffers created by the pruning + method have the same size as the input tensor to prune. These, in + fact, correspond to the pruned version of the tensor itself, its + mask, and its original copy, so the size must match. + """ + # fixturize test + # TODO: add other modules + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + original_tensor = getattr(m, name) + + prune.random_unstructured(m, name=name, amount=0.1) + # mask has the same size as tensor being pruned + self.assertEqual( + original_tensor.size(), + getattr(m, name + '_mask').size() + ) + # 'orig' tensor has the same size as the original tensor + self.assertEqual( + original_tensor.size(), + getattr(m, name + '_orig').size() + ) + # new tensor has the same size as the original tensor + self.assertEqual( + original_tensor.size(), + getattr(m, name).size() + ) + + def test_random_pruning_orig(self): + r"""Test that original tensor is correctly stored in 'orig' + after pruning is applied. Important to make sure we don't + lose info about the original unpruned parameter. + """ + # fixturize test + # TODO: add other modules + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + + # tensor prior to pruning + original_tensor = getattr(m, name) + prune.random_unstructured(m, name=name, amount=0.1) + self.assertEqual( + original_tensor, + getattr(m, name + '_orig') + ) + + def test_random_pruning_new_weight(self): + r"""Test that module.name now contains a pruned version of + the original tensor obtained from multiplying it by the mask. + """ + # fixturize test + # TODO: add other modules + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + # tensor prior to pruning + original_tensor = getattr(m, name) + prune.random_unstructured(m, name=name, amount=0.1) + # weight = weight_orig * weight_mask + self.assertEqual( + getattr(m, name), + getattr(m, name + '_orig') + * getattr(m, name + '_mask').to( + dtype=original_tensor.dtype + ), + ) + + def test_identity_pruning(self): + r"""Test that a mask of 1s does not change forward or backward. + """ + input_ = torch.ones(1, 5) + m = nn.Linear(5, 2) + y_prepruning = m(input_) # output prior to pruning + + # compute grad pre-pruning and check it's equal to all ones + y_prepruning.sum().backward() + old_grad_weight = m.weight.grad.clone() # don't grab pointer! + self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) + old_grad_bias = m.bias.grad.clone() + self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) + + # remove grads + m.zero_grad() + + # force the mask to be made of all 1s + prune.identity(m, name="weight") + + # with mask of 1s, output should be identical to no mask + y_postpruning = m(input_) + self.assertEqual(y_prepruning, y_postpruning) + + # with mask of 1s, grad should be identical to no mask + y_postpruning.sum().backward() + self.assertEqual(old_grad_weight, m.weight_orig.grad) + self.assertEqual(old_grad_bias, m.bias.grad) + + # calling forward twice in a row shouldn't change output + y1 = m(input_) + y2 = m(input_) + self.assertEqual(y1, y2) + + def test_random_pruning_0perc(self): + r"""Test that a mask of 1s does not change forward or backward. + """ + input_ = torch.ones(1, 5) + m = nn.Linear(5, 2) + y_prepruning = m(input_) # output prior to pruning + + # compute grad pre-pruning and check it's equal to all ones + y_prepruning.sum().backward() + old_grad_weight = m.weight.grad.clone() # don't grab pointer! + self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) + old_grad_bias = m.bias.grad.clone() + self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) + + # remove grads + m.zero_grad() + + # force the mask to be made of all 1s + with mock.patch( + "torch.nn.utils.prune.RandomUnstructured.compute_mask" + ) as compute_mask: + compute_mask.return_value = torch.ones_like(m.weight) + prune.random_unstructured(m, name='weight', amount=0.9) # amount won't count + + # with mask of 1s, output should be identical to no mask + y_postpruning = m(input_) + self.assertEqual(y_prepruning, y_postpruning) + + # with mask of 1s, grad should be identical to no mask + y_postpruning.sum().backward() + self.assertEqual(old_grad_weight, m.weight_orig.grad) + self.assertEqual(old_grad_bias, m.bias.grad) + + # calling forward twice in a row shouldn't change output + y1 = m(input_) + y2 = m(input_) + self.assertEqual(y1, y2) + + def test_random_pruning(self): + input_ = torch.ones(1, 5) + m = nn.Linear(5, 2) + + # define custom mask to assign with mock + mask = torch.ones_like(m.weight) + mask[1, 0] = 0 + mask[0, 3] = 0 + + # check grad is zero for masked weights + with mock.patch( + "torch.nn.utils.prune.RandomUnstructured.compute_mask" + ) as compute_mask: + compute_mask.return_value = mask + prune.random_unstructured(m, name='weight', amount=0.9) + + y_postpruning = m(input_) + y_postpruning.sum().backward() + # weight_orig is the parameter, so it's the tensor that will accumulate the grad + self.assertEqual(m.weight_orig.grad, mask) # all 1s, except for masked units + self.assertEqual(m.bias.grad, torch.ones_like(m.bias)) + + # make sure that weight_orig update doesn't modify [1, 0] and [0, 3] + old_weight_orig = m.weight_orig.clone() + # update weights + learning_rate = 1. + for p in m.parameters(): + p.data.sub_(p.grad.data * learning_rate) + # since these are pruned, they should not be updated + self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0]) + self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3]) + + def test_random_pruning_forward(self): + r"""check forward with mask (by hand). + """ + input_ = torch.ones(1, 5) + m = nn.Linear(5, 2) + + # define custom mask to assign with mock + mask = torch.zeros_like(m.weight) + mask[1, 0] = 1 + mask[0, 3] = 1 + + with mock.patch( + "torch.nn.utils.prune.RandomUnstructured.compute_mask" + ) as compute_mask: + compute_mask.return_value = mask + prune.random_unstructured(m, name='weight', amount=0.9) + + yhat = m(input_) + self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0]) + self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1]) + + def test_remove_pruning_forward(self): + r"""Remove pruning and check forward is unchanged from previous + pruned state. + """ + input_ = torch.ones(1, 5) + m = nn.Linear(5, 2) + + # define custom mask to assign with mock + mask = torch.ones_like(m.weight) + mask[1, 0] = 0 + mask[0, 3] = 0 + + # check grad is zero for masked weights + with mock.patch( + "torch.nn.utils.prune.RandomUnstructured.compute_mask" + ) as compute_mask: + compute_mask.return_value = mask + prune.random_unstructured(m, name='weight', amount=0.9) + + y_postpruning = m(input_) + + prune.remove(m, 'weight') + + y_postremoval = m(input_) + self.assertEqual(y_postpruning, y_postremoval) + + def test_pruning_id_consistency(self): + r"""Test that pruning doesn't change the id of the parameters, which + would otherwise introduce issues with pre-existing optimizers that + point to old parameters. + """ + m = nn.Linear(5, 2, bias=False) + + tensor_id = id(list(m.parameters())[0]) + + prune.random_unstructured(m, name="weight", amount=0.9) + self.assertEqual(tensor_id, id(list(m.parameters())[0])) + + prune.remove(m, "weight") + self.assertEqual(tensor_id, id(list(m.parameters())[0])) + + def test_random_pruning_pickle(self): + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + prune.random_unstructured(m, name=name, amount=0.1) + m_new = pickle.loads(pickle.dumps(m)) + self.assertIsInstance(m_new, type(m)) + + def test_multiple_pruning_calls(self): + # if you call pruning twice, the hook becomes a PruningContainer + m = nn.Conv3d(2, 2, 2) + prune.l1_unstructured(m, name='weight', amount=0.1) + weight_mask0 = m.weight_mask # save it for later sanity check + + # prune again + prune.ln_structured(m, name='weight', amount=0.3, n=2, dim=0) + hook = next(iter(m._forward_pre_hooks.values())) + self.assertIsInstance( + hook, + torch.nn.utils.prune.PruningContainer + ) + # check that container._tensor_name is correctly set no matter how + # many pruning methods are in the container + self.assertEqual(hook._tensor_name, 'weight') + + # check that the pruning container has the right length + # equal to the number of pruning iters + self.assertEqual(len(hook), 2) # m.weight has been pruned twice + + # check that the entries of the pruning container are of the expected + # type and in the expected order + self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured) + self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured) + + # check that all entries that are 0 in the 1st mask are 0 in the + # 2nd mask too + self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0)) + + # prune again + prune.ln_structured(m, name='weight', amount=0.1, n=float('inf'), dim=1) + # check that container._tensor_name is correctly set no matter how + # many pruning methods are in the container + hook = next(iter(m._forward_pre_hooks.values())) + self.assertEqual(hook._tensor_name, 'weight') + + def test_pruning_container(self): + # create an empty container + container = prune.PruningContainer() + container._tensor_name = 'test' + self.assertEqual(len(container), 0) + + p = prune.L1Unstructured(amount=2) + p._tensor_name = 'test' + + # test adding a pruning method to a container + container.add_pruning_method(p) + + # test error raised if tensor name is different + q = prune.L1Unstructured(amount=2) + q._tensor_name = 'another_test' + with self.assertRaises(ValueError): + container.add_pruning_method(q) + + # test that adding a non-pruning method object to a pruning container + # raises a TypeError + with self.assertRaises(TypeError): + container.add_pruning_method(10) + with self.assertRaises(TypeError): + container.add_pruning_method('ugh') + + def test_pruning_container_compute_mask(self): + r"""Test `compute_mask` of pruning container with a known `t` and + `default_mask`. Indirectly checks that Ln structured pruning is + acting on the right axis. + """ + # create an empty container + container = prune.PruningContainer() + container._tensor_name = 'test' + + # 1) test unstructured pruning + # create a new pruning method + p = prune.L1Unstructured(amount=2) + p._tensor_name = 'test' + # add the pruning method to the container + container.add_pruning_method(p) + + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32) + computed_mask = container.compute_mask(t, default_mask) + self.assertEqual(expected_mask, computed_mask) + + # 2) test structured pruning + q = prune.LnStructured(amount=1, n=2, dim=0) + q._tensor_name = 'test' + container.add_pruning_method(q) + # since we are pruning the lowest magnitude one of the two rows, the + # outcome of the calculation should be this: + expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32) + computed_mask = container.compute_mask(t, default_mask) + self.assertEqual(expected_mask, computed_mask) + + # 2) test structured pruning, along another axis + r = prune.LnStructured(amount=1, n=2, dim=1) + r._tensor_name = 'test' + container.add_pruning_method(r) + # since we are pruning the lowest magnitude of the four columns, the + # outcome of the calculation should be this: + expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32) + computed_mask = container.compute_mask(t, default_mask) + self.assertEqual(expected_mask, computed_mask) + + def test_l1_unstructured_pruning(self): + r"""Test that l1 unstructured pruning actually removes the lowest + entries by l1 norm (by hand). It also checks that applying l1 + unstructured pruning more than once respects the previous mask. + """ + m = nn.Linear(4, 2) + # modify its weight matrix by hand + m.weight = torch.nn.Parameter( + torch.tensor( + [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 + ) + ) + + prune.l1_unstructured(m, 'weight', amount=2) + expected_weight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], + dtype=m.weight.dtype) + self.assertEqual(expected_weight, m.weight) + + # check that pruning again removes the next two smallest entries + prune.l1_unstructured(m, 'weight', amount=2) + expected_weight = torch.tensor([[0, 0, 3, 4], [-4, -3, 0, 0]], + dtype=m.weight.dtype) + self.assertEqual(expected_weight, m.weight) + + def test_l1_unstructured_pruning_with_importance_scores(self): + r"""Test that l1 unstructured pruning actually removes the lowest + entries of importance scores and not the parameter by l1 norm (by hand). + It also checks that applying l1 unstructured pruning more than once + respects the previous mask. + """ + m = nn.Linear(4, 2) + # modify its weight matrix by hand + m.weight = torch.nn.Parameter( + torch.tensor( + [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 + ) + ) + importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], + dtype=m.weight.dtype) + self.assertEqual(expected_weight, m.weight) + + # check that pruning again removes two entries of m.weight that are colocated with + # the next two smallest absolute values of importance scores. + prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) + expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]], + dtype=m.weight.dtype) + self.assertEqual(expected_weight, m.weight) + + def test_unstructured_pruning_same_magnitude(self): + r"""Since it may happen that the tensor to prune has entries with the + same exact magnitude, it is important to check that pruning happens + consistenly based on the bottom % of weights, and not by threshold, + which would instead kill off *all* units with magnitude = threshold. + """ + AMOUNT = 0.2 + p = prune.L1Unstructured(amount=AMOUNT) + # create a random tensors with entries in {-2, 0, 2} + t = 2 * torch.randint(low=-1, high=2, size=(10, 7)) + nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement()) + + computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) + nparams_pruned = torch.sum(computed_mask == 0) + self.assertEqual(nparams_toprune, nparams_pruned) + + def test_random_structured_pruning_amount(self): + AMOUNT = 0.6 + AXIS = 2 + p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) + t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to( + dtype=torch.float32 + ) + nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS]) + + computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) + # check that 1 column is fully prune, the others are left untouched + remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS] + per_column_sums = sorted( + torch.sum(computed_mask == 0, axis=remaining_axes) + ) + assert per_column_sums == [0, 20] + + def test_ln_structured_pruning(self): + r"""Check Ln structured pruning by hand. + """ + m = nn.Conv2d(3, 1, 2) + m.weight.data = torch.tensor( + [[[[1., 2.], [1., 2.5]], + [[0.5, 1.], [0.1, 0.1]], + [[-3., -5.], [0.1, -1.]]]] + ) + # expected effect of pruning 1 of the 3 channels by L2-norm + expected_mask_axis1 = torch.ones_like(m.weight) + expected_mask_axis1[:, 1] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=2, dim=1) + self.assertEqual(expected_mask_axis1, m.weight_mask) + + # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm + expected_mask_axis3 = expected_mask_axis1 + expected_mask_axis3[:, :, :, 0] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) + self.assertEqual(expected_mask_axis3, m.weight_mask) + + def test_ln_structured_pruning_importance_scores(self): + r"""Check Ln structured pruning by hand. + """ + m = nn.Conv2d(3, 1, 2) + m.weight.data = torch.tensor( + [[[[1., 2.], [1., 2.5]], + [[0.5, 1.], [0.1, 0.1]], + [[-3., -5.], [0.1, -1.]]]] + ) + importance_scores = torch.tensor( + [[[[10., 1.], [10., 1.]], + [[30., 3.], [30., 3.]], + [[-20., -2.], [-20., -2.]]]] + ) + # expected effect of pruning 1 of the 3 channels by L2-norm + expected_mask_axis1 = torch.ones_like(m.weight) + expected_mask_axis1[:, 0] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis1, m.weight_mask) + + # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm + expected_mask_axis3 = expected_mask_axis1 + expected_mask_axis3[:, :, :, 1] = 0. + + prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) + self.assertEqual(expected_mask_axis3, m.weight_mask) + + def test_remove_pruning(self): + r"""`prune.remove` removes the hook and the reparametrization + and makes the pruning final in the original parameter. + """ + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + # first prune + prune.random_unstructured(m, name, amount=0.5) + self.assertIn(name + "_orig", dict(m.named_parameters())) + self.assertIn(name + "_mask", dict(m.named_buffers())) + self.assertNotIn(name, dict(m.named_parameters())) + self.assertTrue(hasattr(m, name)) + pruned_t = getattr(m, name) + + # then remove pruning + prune.remove(m, name) + self.assertIn(name, dict(m.named_parameters())) + self.assertNotIn(name + "_orig", dict(m.named_parameters())) + self.assertNotIn(name + "_mask", dict(m.named_buffers())) + final_t = getattr(m, name) + + self.assertEqual(pruned_t, final_t) + + def test_remove_pruning_exception(self): + r"""Removing from an unpruned tensor throws an assertion error + """ + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + # check that the module isn't pruned + self.assertFalse(prune.is_pruned(m)) + # since it isn't pruned, pruning can't be removed from it + with self.assertRaises(ValueError): + prune.remove(m, name) + + + def test_global_pruning(self): + r"""Test that global l1 unstructured pruning over 2 parameters removes + the `amount=4` smallest global weights across the 2 parameters. + """ + m = nn.Linear(4, 2) + n = nn.Linear(3, 1) + # modify the weight matrices by hand + m.weight = torch.nn.Parameter( + torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( + dtype=torch.float32) + ) + n.weight = torch.nn.Parameter( + torch.tensor([[0, 0.1, -2]]).to( + dtype=torch.float32) + ) + + params_to_prune = ( + (m, 'weight'), + (n, 'weight'), + ) + + # prune the 4 smallest weights globally by L1 magnitude + prune.global_unstructured( + params_to_prune, + pruning_method=prune.L1Unstructured, + amount=4 + ) + + expected_mweight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], + dtype=m.weight.dtype) + self.assertEqual(expected_mweight, m.weight) + + expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) + self.assertEqual(expected_nweight, n.weight) + + def test_global_pruning_importance_scores(self): + r"""Test that global l1 unstructured pruning over 2 parameters removes + the `amount=4` smallest global weights across the 2 parameters. + """ + m = nn.Linear(4, 2) + n = nn.Linear(3, 1) + # modify the weight matrices by hand + m.weight = torch.nn.Parameter( + torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( + dtype=torch.float32) + ) + m_importance_scores = torch.tensor( + [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 + ) + n.weight = torch.nn.Parameter( + torch.tensor([[0, 0.1, -2]]).to( + dtype=torch.float32) + ) + n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) + + params_to_prune = ( + (m, 'weight'), + (n, 'weight'), + ) + importance_scores = { + (m, 'weight'): m_importance_scores, + (n, 'weight'): n_importance_scores, + } + + # prune the 4 smallest weights globally by L1 magnitude + prune.global_unstructured( + params_to_prune, + pruning_method=prune.L1Unstructured, + amount=4, + importance_scores=importance_scores, + ) + + expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], + dtype=m.weight.dtype) + self.assertEqual(expected_m_weight, m.weight) + + expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) + self.assertEqual(expected_n_weight, n.weight) + + def test_custom_from_mask_pruning(self): + r"""Test that the CustomFromMask is capable of receiving + as input at instantiation time a custom mask, and combining it with + the previous default mask to generate the correct final mask. + """ + # new mask + mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]]) + # old mask + default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]) + + # some tensor (not actually used) + t = torch.rand_like(mask.to(dtype=torch.float32)) + + p = prune.CustomFromMask(mask=mask) + + computed_mask = p.compute_mask(t, default_mask) + expected_mask = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype) + + self.assertEqual(computed_mask, expected_mask) + + def test_pruning_rollback(self): + r"""Test that if something fails when the we try to compute the mask, + then the model isn't left in some intermediate half-pruned state. + The try/except statement in `apply` should handle rolling back + to the previous state before pruning began. + """ + modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] + names = ['weight', 'bias'] + + for m in modules: + for name in names: + with self.subTest(m=m, name=name): + + with mock.patch( + "torch.nn.utils.prune.L1Unstructured.compute_mask" + ) as compute_mask: + compute_mask.side_effect = Exception('HA!') + with self.assertRaises(Exception): + prune.l1_unstructured(m, name=name, amount=0.9) + + self.assertTrue( + name in dict(m.named_parameters()) + ) + self.assertFalse( + name + '_mask' in dict(m.named_buffers()) + ) + self.assertFalse( + name + '_orig' in dict(m.named_parameters()) + ) + + def test_pruning_serialization_model(self): + # create a model + model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + # check that everything looks normal before pruning + self.assertNotIn('0.weight_orig', model.state_dict()) + self.assertNotIn('0.weight_mask', model.state_dict()) + self.assertIn('0.weight', model.state_dict()) + + # prune one of its parameters + prune.l1_unstructured(module=model[0], name='weight', amount=0.9) + + # check that the original weight and the new mask are present + self.assertIn('0.weight_orig', model.state_dict()) + self.assertIn('0.weight_mask', model.state_dict()) + self.assertNotIn('0.weight', model.state_dict()) + self.assertTrue(hasattr(model[0], 'weight')) + + pruned_weight = model[0].weight + + with TemporaryFileName() as fname: + torch.save(model, fname) + new_model = torch.load(fname) + + # check that the original weight and the new mask are present + self.assertIn('0.weight_orig', new_model.state_dict()) + self.assertIn('0.weight_mask', new_model.state_dict()) + self.assertNotIn('0.weight', new_model.state_dict()) + self.assertTrue(hasattr(new_model[0], 'weight')) + + self.assertEqual(pruned_weight, new_model[0].weight) + + def test_pruning_serialization_state_dict(self): + # create a model + model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + # check that everything looks normal before pruning + self.assertNotIn('0.weight_orig', model.state_dict()) + self.assertNotIn('0.weight_mask', model.state_dict()) + self.assertIn('0.weight', model.state_dict()) + + # prune one of its parameters + prune.l1_unstructured(module=model[0], name='weight', amount=0.9) + + # check that the original weight and the new mask are present + self.assertIn('0.weight_orig', model.state_dict()) + self.assertIn('0.weight_mask', model.state_dict()) + self.assertNotIn('0.weight', model.state_dict()) + self.assertTrue(hasattr(model[0], 'weight')) + + pruned_weight = model[0].weight + + # make pruning permanent and restore parameter names as in base + # architecture + prune.remove(module=model[0], name='weight') + + # check that the original weight and the new mask are no longer present + self.assertNotIn('0.weight_orig', model.state_dict()) + self.assertNotIn('0.weight_mask', model.state_dict()) + self.assertIn('0.weight', model.state_dict()) + + # save the state dict of model and reload it into new_model + new_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + with TemporaryFileName() as fname: + torch.save(model.state_dict(), fname) + new_model.load_state_dict(torch.load(fname)) + + # check that the original weight and the new mask are not present in + # new_model either. + self.assertNotIn('0.weight_orig', new_model.state_dict()) + self.assertNotIn('0.weight_mask', new_model.state_dict()) + self.assertIn('0.weight', new_model.state_dict()) + + self.assertEqual(pruned_weight, new_model[0].weight) + + def test_prune(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) + pruned_tensor = p.prune(t, default_mask) + self.assertEqual(t * expected_mask, pruned_tensor) + + def test_prune_importance_scores(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + importance_scores = torch.tensor( + [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] + ).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) + pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor) + + def test_prune_importance_scores_mimic_default(self): + # create a new pruning method + p = prune.L1Unstructured(amount=2) + # create tensor to be pruned + t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) + # create prior mask by hand + default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) + # since we are pruning the two lowest magnitude units, the outcome of + # the calculation should be this: + expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) + pruned_tensor_without_importance_scores = p.prune(t, default_mask) + pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) + self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) + self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) + + def test_rnn_pruning(self): + l = torch.nn.LSTM(32, 32) + # This Module has 4 parameters called: + # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' + + # Pruning one of them causes one of the weights to become a tensor + prune.l1_unstructured(l, 'weight_ih_l0', 0.5) + assert ( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) + == 3 + ) + + # Removing the pruning reparametrization restores the Parameter + prune.remove(l, 'weight_ih_l0') + assert ( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) + == 4 + ) + + # Make sure that, upon removal of the reparametrization, the + # `._parameters` and `.named_parameters` contain the right params. + # Specifically, the original weight ('weight_ih_l0') should be placed + # back in the parameters, while the reparametrization component + # ('weight_ih_l0_orig') should be removed. + assert 'weight_ih_l0' in l._parameters + assert l._parameters['weight_ih_l0'] is not None + assert 'weight_ih_l0_orig' not in l._parameters + assert 'weight_ih_l0' in dict(l.named_parameters()) + assert dict(l.named_parameters())['weight_ih_l0'] is not None + assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) + +instantiate_parametrized_tests(TestPruningNN) + +if __name__ == '__main__': + run_tests() diff --git a/test/onnx/expect/TestOperators.test_avg_pool2d.expect b/test/onnx/expect/TestOperators.test_avg_pool2d.expect index d551ff38f809b..c5f8ba6b85781 100644 --- a/test/onnx/expect/TestOperators.test_avg_pool2d.expect +++ b/test/onnx/expect/TestOperators.test_avg_pool2d.expect @@ -1,6 +1,6 @@ ir_version: 7 producer_name: "pytorch" -producer_version: "CURRENT_VERSION" +producer_version: "2.0.0" graph { node { output: "onnx::Pad_1" @@ -33,11 +33,6 @@ graph { output: "3" name: "AveragePool_2" op_type: "AveragePool" - attribute { - name: "ceil_mode" - i: 0 - type: INT - } attribute { name: "kernel_shape" ints: 3 diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index fbe79216d0879..49402204e9d27 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -3,6 +3,7 @@ import contextlib import dataclasses import io +import typing import unittest from typing import AbstractSet, Tuple @@ -18,7 +19,7 @@ def _assert_has_diagnostics( rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): sarif_log = engine.sarif_log() - unseen_pairs = {(rule.id, level.value) for rule, level in rule_level_pairs} + unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs} actual_results = [] for run in sarif_log.runs: if run.results is None: @@ -110,23 +111,15 @@ class TestOnnxDiagnostics(common_utils.TestCase): def setUp(self): engine = diagnostics.engine engine.clear() + self._sample_rule = diagnostics.rules.missing_custom_symbolic_function super().setUp() - def test_assert_diagnostic_raises_when_diagnostic_not_found(self): - with self.assertRaises(AssertionError): - with assert_diagnostic( - self, - diagnostics.engine, - diagnostics.rules.node_missing_onnx_shape_inference, - diagnostics.levels.WARNING, - ): - pass - - def test_cpp_diagnose_emits_warning(self): + def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp( + self, + ) -> diagnostics.ExportDiagnostic: class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): - ctx.save_for_backward(x, y) return x + y @staticmethod @@ -137,6 +130,30 @@ class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) + # trigger warning for missing shape inference. + rule = diagnostics.rules.node_missing_onnx_shape_inference + torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + + context = diagnostics.engine.contexts[-1] + for diagnostic in context.diagnostics: + if ( + diagnostic.rule == rule + and diagnostic.level == diagnostics.levels.WARNING + ): + return typing.cast(diagnostics.ExportDiagnostic, diagnostic) + raise AssertionError("No diagnostic found.") + + def test_assert_diagnostic_raises_when_diagnostic_not_found(self): + with self.assertRaises(AssertionError): + with assert_diagnostic( + self, + diagnostics.engine, + diagnostics.rules.node_missing_onnx_shape_inference, + diagnostics.levels.WARNING, + ): + pass + + def test_cpp_diagnose_emits_warning(self): with assert_diagnostic( self, diagnostics.engine, @@ -144,7 +161,7 @@ def forward(self, x): diagnostics.levels.WARNING, ): # trigger warning for missing shape inference. - torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() def test_py_diagnose_emits_error(self): class M(torch.nn.Module): @@ -168,15 +185,44 @@ def forward(self, x): def test_diagnostics_engine_records_diagnosis_reported_outside_of_export( self, ): - sample_rule = diagnostics.rules.missing_custom_symbolic_function sample_level = diagnostics.levels.ERROR with assert_diagnostic( self, diagnostics.engine, - sample_rule, + self._sample_rule, sample_level, ): - diagnostics.context.diagnose(sample_rule, sample_level, ("foo",)) + diagnostics.context.diagnose(self._sample_rule, sample_level) + + def test_diagnostics_records_python_call_stack(self): + diagnostic = diagnostics.ExportDiagnostic( + self._sample_rule, diagnostics.levels.NOTE + ) + stack = diagnostic.python_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame = stack.frames[0] + assert frame.location.snippet is not None # for mypy + self.assertIn("self._sample_rule", frame.location.snippet) + assert frame.location.uri is not None # for mypy + self.assertIn("test_diagnostics.py", frame.location.uri) + + def test_diagnostics_records_cpp_call_stack(self): + diagnostic = ( + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() + ) + stack = diagnostic.cpp_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame_messages = [frame.location.message for frame in stack.frames] + # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx) + # after node-level shape type inference and processed symbolic_fn output type + self.assertTrue( + any( + isinstance(message, str) and "torch::jit::NodeToONNX" in message + for message in frame_messages + ) + ) @dataclasses.dataclass @@ -196,31 +242,17 @@ class TestDiagnosticsInfra(common_utils.TestCase): def setUp(self): self.engine = infra.DiagnosticEngine() self.rules = _RuleCollectionForTest() - self.diagnostic_tool = infra.DiagnosticTool("test_tool", "1.0.0", self.rules) with contextlib.ExitStack() as stack: self.context = stack.enter_context( - self.engine.create_diagnostic_context(self.diagnostic_tool) + self.engine.create_diagnostic_context("test", "1.0.0") ) self.addCleanup(stack.pop_all().close) return super().setUp() - def test_diagnose_raises_value_error_when_rule_not_supported(self): - rule_id = "0" - rule_name = "nonexistent-rule" - with self.assertRaisesRegex( - ValueError, - f"Rule '{rule_id}:{rule_name}' is not supported by this tool " - f"'{self.diagnostic_tool.name} {self.diagnostic_tool.version}'.", - ): - self.context.diagnose( - infra.Rule(id=rule_id, name=rule_name, message_default_template=""), - infra.Level.WARNING, - ) - def test_diagnostics_engine_records_diagnosis_reported_in_nested_contexts( self, ): - with self.engine.create_diagnostic_context(self.diagnostic_tool) as context: + with self.engine.create_diagnostic_context("inner_test", "1.0.1") as context: context.diagnose(self.rules.rule_without_message_args, infra.Level.WARNING) sarif_log = self.engine.sarif_log() self.assertEqual(len(sarif_log.runs), 2) @@ -250,9 +282,7 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self): ) with self.engine.create_diagnostic_context( - tool=infra.DiagnosticTool( - name="custom_tool", version="1.0", rules=custom_rules - ) + "custom_rules", "1.0" ) as diagnostic_context: with assert_all_diagnostics( self, @@ -269,20 +299,6 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self): custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined] ) - def test_diagnostic_tool_raises_type_error_when_diagnostic_type_is_invalid( - self, - ): - with self.assertRaisesRegex( - TypeError, - "Expected diagnostic_type to be a subclass of Diagnostic, but got", - ): - _ = infra.DiagnosticTool( - "custom_tool", - "1.0", - self.rules, - diagnostic_type=int, - ) - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 45f90d4193ce7..6963d16284ce6 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -3,15 +3,13 @@ from __future__ import annotations import os -import random from typing import Any, Mapping, Type -import numpy as np import onnxruntime +import pytorch_test_common import torch from torch.onnx import _constants, verification -from torch.testing._internal import common_utils onnx_model_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -54,13 +52,7 @@ def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]) return f"{cls.__name__}_{suffix}" -def set_rng_seed(seed): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -class _TestONNXRuntime(common_utils.TestCase): +class _TestONNXRuntime(pytorch_test_common.ExportTestCase): opset_version = _constants.ONNX_DEFAULT_OPSET keep_initializers_as_inputs = True # For IR version 3 type export. is_script = False @@ -68,7 +60,7 @@ class _TestONNXRuntime(common_utils.TestCase): check_dtype = True def setUp(self): - set_rng_seed(0) + super().setUp() onnxruntime.set_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 4a44932fb1206..4e443c333f35f 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -2,12 +2,17 @@ import functools import os +import random import sys import unittest from typing import Optional +import numpy as np + import torch from torch.autograd import function +from torch.onnx._internal import diagnostics +from torch.testing._internal import common_utils pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.insert(-1, pytorch_test_dir) @@ -188,3 +193,24 @@ def wrapper(self, *args, **kwargs): def flatten(x): return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x)) + + +def set_rng_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + +class ExportTestCase(common_utils.TestCase): + """Test case for ONNX export. + + Any test case that tests functionalities under torch.onnx should inherit from this class. + """ + + def setUp(self): + super().setUp() + # TODO(#88264): Flaky test failures after changing seed. + set_rng_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + diagnostics.engine.clear() diff --git a/test/onnx/test_autograd_funs.py b/test/onnx/test_autograd_funs.py index 97f0652ecf378..a5498f39d2da7 100644 --- a/test/onnx/test_autograd_funs.py +++ b/test/onnx/test_autograd_funs.py @@ -1,16 +1,16 @@ # Owner(s): ["module: onnx"] -import unittest +import pytorch_test_common import torch - from onnx_test_common import run_model_test from torch.onnx import OperatorExportTypes from torch.onnx._globals import GLOBALS from torch.onnx.utils import _model_to_graph +from torch.testing._internal import common_utils -class TestAutogradFuns(unittest.TestCase): +class TestAutogradFuns(pytorch_test_common.ExportTestCase): opset_version = GLOBALS.export_onnx_opset_version keep_initializers_as_inputs = False onnx_shape_inference = True @@ -209,4 +209,4 @@ def forward(self, input): if __name__ == "__main__": - unittest.main() + common_utils.run_tests() diff --git a/test/onnx/test_custom_ops.py b/test/onnx/test_custom_ops.py index db5ddfd001140..5609b497535e9 100644 --- a/test/onnx/test_custom_ops.py +++ b/test/onnx/test_custom_ops.py @@ -4,6 +4,7 @@ import numpy as np import onnx import onnx_test_common +import pytorch_test_common import torch import torch.utils.cpp_extension from test_pytorch_onnx_caffe2 import do_export @@ -11,7 +12,7 @@ from torch.testing._internal import common_utils -class TestCustomOps(common_utils.TestCase): +class TestCustomOps(pytorch_test_common.ExportTestCase): def test_custom_add(self): op_source = """ #include @@ -38,9 +39,7 @@ def forward(self, a, b): def symbolic_custom_add(g, self, other): return g.op("Add", self, other) - from torch.onnx import register_custom_op_symbolic - - register_custom_op_symbolic( + torch.onnx.register_custom_op_symbolic( "custom_namespace::custom_add", symbolic_custom_add, 9 ) @@ -48,6 +47,9 @@ def symbolic_custom_add(g, self, other): y = torch.randn(2, 3, 4, requires_grad=False) model = CustomAddModel() + # before fixing #51833 this used to give a PyBind error + # with PyTorch 1.10dev ("Unable to cast from non-held to held + # instance (T& to Holder)") onnxir, _ = do_export(model, (x, y), opset_version=11) onnx_model = onnx.ModelProto.FromString(onnxir) prepared = c2.prepare(onnx_model) @@ -55,7 +57,7 @@ def symbolic_custom_add(g, self, other): np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy()) -class TestCustomAutogradFunction(common_utils.TestCase): +class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase): opset_version = 9 keep_initializers_as_inputs = False onnx_shape_inference = True @@ -129,7 +131,7 @@ def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs): onnx_test_common.run_model_test(self, model, input_args=(x,)) -class TestExportAsContribOps(common_utils.TestCase): +class TestExportAsContribOps(pytorch_test_common.ExportTestCase): opset_version = 14 keep_initializers_as_inputs = False onnx_shape_inference = True diff --git a/test/jit/test_export_modes.py b/test/onnx/test_export_modes.py similarity index 64% rename from test/jit/test_export_modes.py rename to test/onnx/test_export_modes.py index dbf10cddc059b..502f31b38b10a 100644 --- a/test/jit/test_export_modes.py +++ b/test/onnx/test_export_modes.py @@ -1,29 +1,27 @@ -# Owner(s): ["oncall: jit"] +# Owner(s): ["module: onnx"] import io import os import shutil import sys import tempfile +import unittest import torch import torch.nn as nn -from torch.onnx import OperatorExportTypes from torch.autograd import Variable +from torch.onnx import OperatorExportTypes # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.jit_utils import JitTestCase -from torch.testing._internal.common_utils import skipIfNoLapack, skipIfCaffe2, skipIfNoCaffe2 +import pytorch_test_common + +from torch.testing._internal import common_utils -if __name__ == '__main__': - raise RuntimeError("This test file is not meant to be run directly, use:\n\n" - "\tpython test/test_jit.py TESTNAME\n\n" - "instead.") # Smoke tests for export methods -class TestExportModes(JitTestCase): +class TestExportModes(pytorch_test_common.ExportTestCase): class MyModel(nn.Module): def __init__(self): super(TestExportModes.MyModel, self).__init__() @@ -35,41 +33,66 @@ def test_protobuf(self): torch_model = TestExportModes.MyModel() fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True) f = io.BytesIO() - torch.onnx._export(torch_model, (fake_input), f, verbose=False, - export_type=torch.onnx.ExportTypes.PROTOBUF_FILE) + torch.onnx._export( + torch_model, + (fake_input), + f, + verbose=False, + export_type=torch.onnx.ExportTypes.PROTOBUF_FILE, + ) def test_zipfile(self): torch_model = TestExportModes.MyModel() fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True) f = io.BytesIO() - torch.onnx._export(torch_model, (fake_input), f, verbose=False, - export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE) + torch.onnx._export( + torch_model, + (fake_input), + f, + verbose=False, + export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE, + ) def test_compressed_zipfile(self): torch_model = TestExportModes.MyModel() fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True) f = io.BytesIO() - torch.onnx._export(torch_model, (fake_input), f, verbose=False, - export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE) + torch.onnx._export( + torch_model, + (fake_input), + f, + verbose=False, + export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE, + ) def test_directory(self): torch_model = TestExportModes.MyModel() fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True) d = tempfile.mkdtemp() - torch.onnx._export(torch_model, (fake_input), d, verbose=False, - export_type=torch.onnx.ExportTypes.DIRECTORY) + torch.onnx._export( + torch_model, + (fake_input), + d, + verbose=False, + export_type=torch.onnx.ExportTypes.DIRECTORY, + ) shutil.rmtree(d) def test_onnx_multiple_return(self): @torch.jit.script def foo(a): return (a, a) + f = io.BytesIO() x = torch.ones(3) - torch.onnx._export(foo, (x,), f) - - @skipIfNoCaffe2 - @skipIfNoLapack + torch.onnx.export(foo, (x,), f) + + # TODO(87318): Can't pass even with Caffe2 + @unittest.skip( + "RuntimeError: ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type" + ) + @common_utils.skipIfNoCaffe2 + @common_utils.skipIfNoLapack def test_caffe2_aten_fallback(self): class ModelWithAtenNotONNXOp(nn.Module): def forward(self, x, y): @@ -80,13 +103,15 @@ def forward(self, x, y): x = torch.rand(3, 4) y = torch.rand(3, 4) torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), (x, y), + ModelWithAtenNotONNXOp(), + (x, y), add_node_names=False, do_constant_folding=False, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) - @skipIfCaffe2 - @skipIfNoLapack + @common_utils.skipIfCaffe2 + @common_utils.skipIfNoLapack def test_aten_fallback(self): class ModelWithAtenNotONNXOp(nn.Module): def forward(self, x, y): @@ -97,12 +122,14 @@ def forward(self, x, y): x = torch.rand(3, 4) y = torch.rand(3, 4) torch.onnx.export_to_pretty_string( - ModelWithAtenNotONNXOp(), (x, y), + ModelWithAtenNotONNXOp(), + (x, y), add_node_names=False, do_constant_folding=False, operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, # support for linalg.qr was added in later op set versions. - opset_version=9) + opset_version=9, + ) # torch.fmod is using to test ONNX_ATEN. # If you plan to remove fmod from aten, or found this test failed. @@ -115,7 +142,13 @@ def forward(self, x, y): x = torch.randn(3, 4, dtype=torch.float32) y = torch.randn(3, 4, dtype=torch.float32) torch.onnx.export_to_pretty_string( - ModelWithAtenFmod(), (x, y), + ModelWithAtenFmod(), + (x, y), add_node_names=False, do_constant_folding=False, - operator_export_type=OperatorExportTypes.ONNX_ATEN) + operator_export_type=OperatorExportTypes.ONNX_ATEN, + ) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py index 7084bd75bace7..15904839957ee 100644 --- a/test/onnx/test_models.py +++ b/test/onnx/test_models.py @@ -2,8 +2,9 @@ import unittest -import torch +import pytorch_test_common +import torch from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2 from model_defs.mnist import MNIST @@ -44,7 +45,7 @@ def toC(x): BATCH_SIZE = 2 -class TestModels(common_utils.TestCase): +class TestModels(pytorch_test_common.ExportTestCase): opset_version = 9 # Caffe2 doesn't support the default. keep_initializers_as_inputs = False diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index c84640e535e11..4b7bdb58ae514 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -8,6 +8,7 @@ import onnx_test_common import parameterized import PIL +import pytorch_test_common import test_models import torch @@ -64,7 +65,7 @@ def exportTest( TestModels = type( "TestModels", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict( test_models.TestModels.__dict__, is_script_test_enabled=False, @@ -77,7 +78,7 @@ def exportTest( # model tests for scripting with new JIT APIs and shape inference TestModels_new_jit_API = type( "TestModels_new_jit_API", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict( TestModels.__dict__, exportTest=exportTest, @@ -393,6 +394,7 @@ def forward(self, images, features: Mapping[str, torch.Tensor]): ) @skipScriptTest() # TODO: #75625 + @skipIfUnsupportedMinOpsetVersion(20) def test_transformer_encoder(self): class MyModule(torch.nn.Module): def __init__(self, ninp, nhead, nhid, dropout, nlayers): diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index dab33bf00b09d..ef79e82ee266a 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -4,6 +4,7 @@ import itertools import onnx +import pytorch_test_common import torch import torch.onnx @@ -70,7 +71,7 @@ def check_onnx_opsets_operator( check_onnx_opset_operator(model, ops[opset_version], opset_version) -class TestONNXOpset(common_utils.TestCase): +class TestONNXOpset(pytorch_test_common.ExportTestCase): def test_opset_fallback(self): class MyModule(Module): def forward(self, x): diff --git a/test/onnx/test_onnxscript_no_runtime.py b/test/onnx/test_onnxscript_no_runtime.py new file mode 100644 index 0000000000000..125e899af9449 --- /dev/null +++ b/test/onnx/test_onnxscript_no_runtime.py @@ -0,0 +1,164 @@ +# Owner(s): ["module: onnx"] + +"""Test the support on onnxscript in PyTorch-ONNX converter.""" +import io +from typing import List + +import onnx +import onnxscript +import torch +from onnxscript.onnx_types import FLOAT +from torch.onnx._internal import jit_utils +from torch.testing._internal import common_utils + + +class TestONNXScriptExport(common_utils.TestCase): + + # opset version is + # 1. local function is supported after opset 15 + # 2. onnx-script requires users to determine opset in local function + opset_version = 15 + + def test_onnxscript_registration_with_multiple_models(self): + + from onnxscript.onnx_opset import opset15 as op + + # 1. Register Selu onnxscript function as custom Op + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + # TODO: onnx/ort doesn't support default values for now + # move this when they do + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=self.opset_version, + ) + + # 2. Register layer_norm onnxscript function as custom Op + @onnxscript.script(custom_opset) + def layer_norm( + X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float + ): + mean = op.ReduceMean(X, axes=axes) + D = X - mean # op.Sub(X, mean) + DD = D * D # op.Mul(D, D) + var = op.ReduceMean(DD, axes=axes) + vareps = var + eps # op.Add(var, eps) + stddev = op.Sqrt(vareps) + invstddev = op.Reciprocal(stddev) + normalized = D * invstddev # op.Mul(D, invstddev) + normalizedw = op.CastLike( + normalized, weight + ) # Type issue if missing this Op + normalizedscaled = normalizedw * weight # op.Mul(normalized, weight) + return normalizedscaled + bias + + @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") + def custom_layer_norm( + g, input, normalized_shape, weight, bias, eps, cudnn_enable + ): + # TODO: move the comprehension into local function once + # it's supported by onnxscript + axes = [-i for i in range(len(normalized_shape), 0, -1)] + return g.onnxscript_op( + layer_norm, input, weight, bias, axes_i=axes, eps_f=eps + ).setType(input.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::layer_norm", + symbolic_fn=custom_layer_norm, + opset_version=self.opset_version, + ) + + # 3. export two models + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model_selu = torch.nn.SELU() + selu_onnx = io.BytesIO() + torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version) + + N, C = 3, 4 + y = torch.randn(N, C) + model_layer_norm = torch.nn.LayerNorm(C) + layer_norm_onnx = io.BytesIO() + torch.onnx.export( + model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version + ) + + # 4. test on models + selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue())) + layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue())) + + self.assertEqual(len(selu_proto.functions), 1) + self.assertEqual(len(layer_norm_proto.functions), 1) + self.assertEqual(selu_proto.functions[0].name, "Selu") + self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm") + + def test_loop_registration(self): + # Control flow is tested for _find_onnxscript_op function in torch/onnx/utils.py, + # which has recursive logic to go through every nodes with subgraph in model proto + class NestedLoopsModel(torch.jit.ScriptModule): + def __init__(self): + super().__init__() + self.selu = torch.nn.SELU() + + @torch.jit.script_method + def forward(self, x): + y = x + for i in range(x.size(3)): + if i == 0: + y = self.selu(x) + else: + y += i + return y + + model = NestedLoopsModel() + inputs = torch.zeros(1, 2, 3, 4) + + from onnxscript.onnx_opset import opset15 as op + + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2) + + @onnxscript.script(custom_opset) + def Selu(X): + alpha = 1.6732632423543772848170429916717 + gamma = 1.0507009873554804934193349852946 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g, X): + # domain of the Op should be aligned with onnx-script + # setType API is required for custom Op to support + # torchscript shape type inference + print("custom_selu is used!") + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=15, + ) + + saved_model = io.BytesIO() + torch.onnx.export( + torch.jit.script(model), inputs, f=saved_model, opset_version=15 + ) + loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue())) + self.assertEqual(len(loop_selu_proto.functions), 1) diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py new file mode 100644 index 0000000000000..e22e76c8315e7 --- /dev/null +++ b/test/onnx/test_onnxscript_runtime.py @@ -0,0 +1,130 @@ +# Owner(s): ["module: onnx"] + +"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime.""" +from typing import List + +import onnx_test_common +import onnxscript +import torch +from onnxscript.onnx_types import FLOAT +from torch.onnx._internal import jit_utils +from torch.testing._internal import common_utils + + +class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime): + + # opset version is + # 1. local function is supported after opset 15 + # 2. onnx-script requires users to determine opset in local function + opset_version = 15 + + def test_selu_from_onnxscript_example(self): + + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model = torch.nn.SELU() + + from onnxscript.onnx_opset import opset15 as op + + # TODO(titaiwang): make an official domain for onnxscript usage + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + # TODO: onnx/ort doesn't support default values for now + # move this when they do + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=self.opset_version, + ) + self.run_test(model, x) + + def test_layer_norm(self): + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + + class N(torch.nn.Module): + def __init__(self, prob): + super().__init__() + self.dropout = torch.nn.Dropout(prob) + + def forward(self, x): + return self.dropout(x) + + class M(torch.nn.Module): + def __init__(self, num_layers): + super().__init__() + self.num_layers = num_layers + self.lns = torch.nn.ModuleList( + [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)] + ) + self.celu1 = torch.nn.CELU(1.0) + self.celu2 = torch.nn.CELU(2.0) + self.dropout = N(0.5) + + def forward(self, x, y, z): + res1 = self.celu1(x) + res2 = self.celu2(y) + for ln in self.lns: + z = ln(z) + return res1 + res2, self.dropout(z) + + model = M(3) + + from onnxscript.onnx_opset import opset15 as op + + custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1) + + @onnxscript.script(custom_opset) + def layer_norm( + X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float + ): + mean = op.ReduceMean(X, axes=axes) + D = X - mean # op.Sub(X, mean) + DD = D * D # op.Mul(D, D) + var = op.ReduceMean(DD, axes=axes) + vareps = var + eps # op.Add(var, eps) + stddev = op.Sqrt(vareps) + invstddev = op.Reciprocal(stddev) + normalized = D * invstddev # op.Mul(D, invstddev) + normalizedw = op.CastLike( + normalized, weight + ) # Type issue if missing this Op + normalizedscaled = normalizedw * weight # op.Mul(normalized, weight) + return normalizedscaled + bias + + @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") + def custom_layer_norm( + g, input, normalized_shape, weight, bias, eps, cudnn_enable + ): + # TODO: move the comprehension into local function once it's supported by onnxscript + axes = [-i for i in range(len(normalized_shape), 0, -1)] + return g.onnxscript_op( + layer_norm, input, weight, bias, axes_i=axes, eps_f=eps + ).setType(input.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::layer_norm", + symbolic_fn=custom_layer_norm, + opset_version=self.opset_version, + ) + + self.run_test(model, (x, y, z)) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 9b743a50d3323..7375cf3fe4d7a 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -31,6 +31,7 @@ ) from torch.autograd import Function, Variable from torch.nn import functional, Module +from torch.onnx._internal import diagnostics from torch.onnx.symbolic_helper import ( _get_tensor_dim_size, _get_tensor_sizes, @@ -71,6 +72,10 @@ def forward(self, *args): class TestOperators(common_utils.TestCase): + def setUp(self): + super().setUp() + diagnostics.engine.clear() + def assertONNX(self, f, args, params=None, **kwargs): if params is None: params = () @@ -649,10 +654,14 @@ def test_repeat_dim_overflow(self): x = torch.randn(1, 2, requires_grad=True) self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_norm_p1(self): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.assertONNX(lambda x: x.norm(p=1, dim=2), (x)) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_norm_p2(self): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.assertONNX(lambda x: x.norm(p=2, dim=2), (x)) @@ -952,6 +961,8 @@ def test_pixel_shuffle(self): lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11 ) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm(self): x = torch.randn(2, 3, 4).float() self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x) diff --git a/test/onnx/test_pytorch_helper.py b/test/onnx/test_pytorch_helper.py index 362841d8bf90f..7d7f3ade7f581 100644 --- a/test/onnx/test_pytorch_helper.py +++ b/test/onnx/test_pytorch_helper.py @@ -4,6 +4,7 @@ import unittest import numpy as np +import pytorch_test_common import torch.nn.init as init import torch.onnx @@ -15,7 +16,7 @@ from torch.testing._internal.common_utils import skipIfNoLapack -class TestCaffe2Backend(common_utils.TestCase): +class TestCaffe2Backend(pytorch_test_common.ExportTestCase): @skipIfNoLapack @unittest.skip("test broken because Lapack was always missing.") def test_helper(self): diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index f069251ee064c..784bd0954b0ad 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -1,5 +1,6 @@ # Owner(s): ["module: onnx"] import onnxruntime +import pytorch_test_common import torch from pytorch_test_common import skipIfNoCuda @@ -171,7 +172,7 @@ def MakeTestCase(opset_version: int) -> type: name = f"TestJITIRToONNX_opset{opset_version}" return type( str(name), - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(_TestJITIRToONNX.__dict__, opset_version=opset_version), ) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 141d3683171f6..78440ac6ecb5b 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -12,6 +12,7 @@ import model_defs.word_language_model as word_language_model import numpy as np import onnx +import pytorch_test_common import torch.onnx import torch.onnx.operators import torch.utils.model_zoo as model_zoo @@ -129,18 +130,10 @@ def do_export(model, inputs, *args, **kwargs): } -class TestCaffe2Backend_opset9(common_utils.TestCase): +class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase): opset_version = 9 embed_params = False - def setUp(self): - # the following should ideally be super().setUp(), https://github.com/pytorch/pytorch/issues/79630 - common_utils.TestCase.setUp(self) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - np.random.seed(seed=0) - def convert_cuda(self, model, input): cuda_model = model.cuda() # input might be nested - we want to move everything to GPU @@ -3198,44 +3191,44 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9 = type( "TestCaffe2BackendEmbed_opset9", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) # opset 7 tests TestCaffe2Backend_opset7 = type( "TestCaffe2Backend_opset7", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=7), ) TestCaffe2BackendEmbed_opset7 = type( "TestCaffe2BackendEmbed_opset7", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7), ) # opset 8 tests TestCaffe2Backend_opset8 = type( "TestCaffe2Backend_opset8", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=8), ) TestCaffe2BackendEmbed_opset8 = type( "TestCaffe2BackendEmbed_opset8", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8), ) # opset 10 tests TestCaffe2Backend_opset10 = type( "TestCaffe2Backend_opset10", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=10), ) TestCaffe2BackendEmbed_opset10 = type( "TestCaffe2BackendEmbed_opset10", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10), ) @@ -3243,7 +3236,7 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9_new_jit_API = type( "TestCaffe2BackendEmbed_opset9_new_jit_API", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) diff --git a/test/onnx/test_pytorch_onnx_caffe2_quantized.py b/test/onnx/test_pytorch_onnx_caffe2_quantized.py index f6466aa0869e5..92079ebbe6d92 100644 --- a/test/onnx/test_pytorch_onnx_caffe2_quantized.py +++ b/test/onnx/test_pytorch_onnx_caffe2_quantized.py @@ -6,13 +6,14 @@ import numpy as np import onnx +import pytorch_test_common import torch.ao.nn.quantized as nnq import torch.nn as nn import torch.onnx from torch.testing._internal import common_utils -class TestQuantizedOps(common_utils.TestCase): +class TestQuantizedOps(pytorch_test_common.ExportTestCase): def generic_test( self, model, sample_inputs, input_names=None, decimal=3, relaxed_check=False ): diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 5f2ce3fa657a1..c741ddd2c41ed 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -7,18 +7,21 @@ import itertools import unittest import unittest.mock +import warnings from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +import numpy as np import onnx import onnx.numpy_helper +import pytorch_test_common import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import symbolic_helper, utils +from torch.onnx import OperatorExportTypes, symbolic_helper, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import registration -from torch.testing._internal import common_utils +from torch.testing._internal import common_quantization, common_utils, jit_utils def export_to_onnx( @@ -32,6 +35,7 @@ def export_to_onnx( mocks: Optional[Iterable] = None, operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX, opset_version: int = GLOBALS.export_onnx_opset_version, + **torch_onnx_export_kwargs, ) -> onnx.ModelProto: """Exports `model(input)` to ONNX and returns it. @@ -44,6 +48,7 @@ def export_to_onnx( mocks: list of mocks to use during export operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)` opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)` + torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments Returns: A valid ONNX model (`onnx.ModelProto`) """ @@ -60,6 +65,7 @@ def export_to_onnx( f, operator_export_type=operator_export_type, opset_version=opset_version, + **torch_onnx_export_kwargs, ) # Validate ONNX graph before returning it @@ -68,7 +74,7 @@ def export_to_onnx( return onnx_model -class TestONNXExport(common_utils.TestCase): +class TestONNXExport(pytorch_test_common.ExportTestCase): def test_fuse_addmm(self): class AddmmModel(torch.nn.Module): def forward(self, x): @@ -76,7 +82,7 @@ def forward(self, x): x = torch.ones(3, 3) f = io.BytesIO() - torch.onnx._export(AddmmModel(), x, f, verbose=False) + torch.onnx.export(AddmmModel(), x, f, verbose=False) def test_onnx_transpose_incomplete_tensor_type(self): # Smoke test to get us into the state where we are attempting to export @@ -163,7 +169,7 @@ def forward(self, x): mte = ModuleToExport() f = io.BytesIO() with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"): - torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False) + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False) def test_onnx_export_script_inline_trace(self): class ModuleToInline(torch.nn.Module): @@ -427,7 +433,11 @@ def forward(self, x): onnx_model = export_to_onnx( MyClip(), torch.randn(3, 4, requires_grad=True), - custom_ops=[common_utils.custom_op("aten::clamp", bad_clamp, 9)], + custom_ops=[ + common_utils.custom_op( + "aten::clamp", bad_clamp, GLOBALS.export_onnx_opset_version + ) + ], operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, ) self.assertAtenOp(onnx_model, "clamp", "Tensor") @@ -541,7 +551,7 @@ def forward(self, x): x = torch.randn(32, 3) f = io.BytesIO() - torch.onnx._export(test_model, (x,), f, do_constant_folding=False) + torch.onnx.export(test_model, (x,), f, do_constant_folding=False) loaded_model = onnx.load_from_string(f.getvalue()) actual_list = [p.name for p in loaded_model.graph.initializer] @@ -777,6 +787,397 @@ def forward(self, x): model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"] ) + def test_dropout_script(self): + + eg = torch.zeros(1, 2, 3, requires_grad=True) + + @jit_utils._trace(eg) + def foo(x): + x = torch.neg(x) + return F.dropout(x) + + class MyDrop(torch.nn.Module): + def forward(self, x): + return foo(x) + + f = io.BytesIO() + with warnings.catch_warnings(record=True): + torch.onnx.export(MyDrop(), (eg,), f, verbose=False) + + def test_pack_padded_pad_packed_trace(self): + from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + T, B, C = 3, 5, 7 + + class PadPackedWrapper(torch.nn.Module): + def __init__(self): + super(PadPackedWrapper, self).__init__() + + def forward(self, x, seq_lens): + x = pack_padded_sequence(x, seq_lens) + x, _ = pad_packed_sequence(x) + return x + + x = np.ones((T, B, C)) + seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32) + # set padding value so we can test equivalence + for b in range(B): + if seq_lens[b] < T: + x[seq_lens[b] :, b, :] = 0 + seq_lens = torch.from_numpy(seq_lens) + x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True) + + m = PadPackedWrapper() + m_traced = torch.jit.trace( + m, + ( + x, + seq_lens, + ), + ) + + y = m(x, seq_lens) + loss = torch.sum(y) + loss.backward() + grad = x.grad.clone() + x.grad.zero_() + + y_traced = m_traced(x, seq_lens) + loss_traced = torch.sum(y_traced) + loss_traced.backward() + grad_traced = x.grad.clone() + + self.assertEqual(y_traced, x) + self.assertEqual(y_traced, y) + self.assertEqual(grad, grad_traced) + + f = io.BytesIO() + torch.onnx.export(m, (x, seq_lens), f, verbose=False) + + # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch. + @common_utils.suppress_warnings + def test_rnn_trace_override(self): + from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + num_layers = 3 + T, B, C = 11, 5, 7 + + class RNNTraceWrapper(torch.nn.Module): + def __init__(self, cell_type): + super(RNNTraceWrapper, self).__init__() + if cell_type == "RNN": + self.rnn = torch.nn.RNN( + input_size=C, hidden_size=C, num_layers=num_layers + ) + elif cell_type == "LSTM": + self.rnn = torch.nn.LSTM( + input_size=C, hidden_size=C, num_layers=num_layers + ) + elif cell_type == "GRU": + self.rnn = torch.nn.GRU( + input_size=C, hidden_size=C, num_layers=num_layers + ) + + def forward(self, x, seq_lens): + x = pack_padded_sequence(x, seq_lens) + x, _ = self.rnn(x) + x, _ = pad_packed_sequence(x) + return x + + for cell_type in ["RNN", "LSTM", "GRU"]: + x = torch.ones(T, B, C, requires_grad=True) + seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32)) + + m = RNNTraceWrapper(cell_type) + m_traced = torch.jit.trace( + m, + ( + x, + seq_lens, + ), + ) + + y = m(x, seq_lens) + loss = torch.sum(y) + loss.backward() + grad = x.grad.clone() + x.grad.zero_() + + y_traced = m_traced(x, seq_lens) + loss_traced = torch.sum(y_traced) + loss_traced.backward() + grad_traced = x.grad.clone() + + self.assertEqual(y_traced, y) + self.assertEqual(grad, grad_traced) + + f = io.BytesIO() + torch.onnx.export(m, (x, seq_lens), f, verbose=False) + + def test_trace_fork_wait_inline_onnx(self): + def fork_body(x): + return torch.neg(x), torch.neg(x) + + class MyMod(torch.nn.Module): + def forward(self, x): + fut = torch.jit._fork(fork_body, x) + val = torch.jit._wait(fut) + return val[1] + + # smoke test for ONNX export + f = io.BytesIO() + torch.onnx.export(MyMod(), (torch.rand(3, 4),), f) + + def test_trace_detach_onnx_erase(self): + class Mod(torch.nn.Module): + def forward(self, x, w): + return torch.matmul(x, w).detach() + + torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) + + @common_utils.skipIfNoCaffe2 + def test_caffe2_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN, + OperatorExportTypes.ONNX_ATEN_FALLBACK, + ): + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfNoCaffe2 + def test_caffe2_onnx_aten_must_not_fallback(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN_FALLBACK, + OperatorExportTypes.ONNX_ATEN, + ): + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + opset_version=10, # or higher + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "Mod" + + @common_utils.skipIfCaffe2 + def test_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfCaffe2 + def test_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "fmod", "Tensor") + + @common_utils.skipIfCaffe2 + def test_onnx_aten_fallback_must_not_fallback(self): + # For BUILD_CAFFE2=0, aten fallback only when not exportable + class ONNXExportable(torch.nn.Module): + def __init__(self): + super(ONNXExportable, self).__init__() + self.quant = torch.quantization.QuantStub() + self.fc1 = torch.nn.Linear(12, 8) + self.fc2 = torch.nn.Linear(8, 4) + self.fc3 = torch.nn.Linear(4, 6) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = x.view((-1, 12)) + h = F.relu(self.fc1(x)) + h = F.relu(self.fc2(h)) + h = F.relu(self.fc3(h)) + h = self.dequant(h) + return h + + dummy_input = torch.randn(12) + f = io.BytesIO() + torch.onnx.export( + ONNXExportable(), + (dummy_input,), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + all_aten_nodes = [ + p + for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten" + ] + self.assertEqual(len(all_aten_nodes), 0) + + def test_cat_with_empty_tensor(self): + class NoopConcat(torch.nn.Module): + def forward(self, x): + return torch.cat((torch.Tensor([]), x)) + + x = torch.randn(4, 5, 6) + # TODO: Parametrize this test for opset_version + for opset_version in {9, 11}: + f = io.BytesIO() + torch.onnx.export(NoopConcat(), (x,), f, opset_version=opset_version) + loaded_model = onnx.load_from_string(f.getvalue()) + self.assertEqual( + len(loaded_model.graph.output[0].type.tensor_type.shape.dim), 3 + ) + for idx, dim in enumerate(x.shape): + self.assertEqual( + loaded_model.graph.output[0] + .type.tensor_type.shape.dim[idx] + .dim_value, + dim, + ) + + +class TestQuantizeEagerONNXExport(common_utils.TestCase): + def _test_lower_graph_impl(self, model, data): + model.qconfig = torch.ao.quantization.default_qconfig + model = torch.ao.quantization.prepare(model) + model = torch.ao.quantization.convert(model) + + _ = model(data) + input_names = ["x"] + + def _export_to_onnx(model, input, input_names): + traced = torch.jit.trace(model, input) + buf = io.BytesIO() + torch.jit.save(traced, buf) + buf.seek(0) + + model = torch.jit.load(buf) + f = io.BytesIO() + torch.onnx.export( + model, + input, + f, + input_names=input_names, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + opset_version=9, + ) + + _export_to_onnx(model, data, input_names) + + @common_quantization.skipIfNoFBGEMM + @common_utils.skipIfNoCaffe2 + def test_lower_graph_linear(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Linear(5, 10, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 2, 5).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @common_quantization.skipIfNoFBGEMM + @common_utils.skipIfNoCaffe2 + def test_lower_graph_conv2d(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Conv2d(3, 5, 2, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @common_quantization.skipIfNoFBGEMM + @unittest.skip( + "onnx opset9 does not support quantize_per_tensor and caffe2 \ + does not support conv3d" + ) + def test_lower_graph_conv3d(self): + model = torch.ao.quantization.QuantWrapper( + torch.nn.Conv3d(3, 5, 2, bias=True) + ).to(dtype=torch.float) + data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32) + data = torch.from_numpy(data_numpy).to(dtype=torch.float) + self._test_lower_graph_impl(model, data) + + @pytorch_test_common.skipIfNoCuda + def test_composed_layer_norm_small_eps_fp16_keep_double(self): + class Net(torch.nn.Module): + def __init__(self, C): + super().__init__() + self.layer_norm = torch.nn.LayerNorm(C, eps=1e-8) + + def forward(self, x): + return self.layer_norm(x) + + N, C = 8, 4 + model = Net(C).cuda().half() + x = torch.randn(N, C).cuda().half() + f = io.BytesIO() + torch.onnx.export(model, x, f, opset_version=14) + onnx_model = onnx.load_from_string(f.getvalue()) + const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"] + self.assertNotEqual(len(const_node), 0) + double_type_count = 0 + for node in const_node: + for a in node.attribute: + # EPS constant should be in double type + if a.name == "value" and a.t.data_type == 11: + double_type_count += 1 + self.assertNotEqual(double_type_count, 0) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 43e8d3579c192..b30056acb09d7 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -1412,8 +1412,26 @@ def test_avgpool_1d_ceil(self): x = torch.randn(1, 1, 7) self.run_test(model, x) - def test_avgpool_2d_ceil(self): - model = torch.nn.AvgPool2d(3, 2, ceil_mode=True) + @common_utils.parametrize( + "padding", + (0, 1), + ) + @common_utils.parametrize( + "ceil_mode", + (True, False), + ) + @common_utils.parametrize( + "count_include_pad", + (True, False), + ) + def test_avgpool_2d(self, padding, ceil_mode, count_include_pad): + model = torch.nn.AvgPool2d( + 3, + 3, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) x = torch.randn(20, 16, 50, 32) self.run_test(model, x) @@ -2625,6 +2643,23 @@ def forward(self, x): x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1) self.run_test(Bernoulli(), x) + def test_bernoulli_p(self): + class Bernoulli_float(torch.nn.Module): + def forward(self, x): + return torch.mul(x, torch.bernoulli(x, 0.2).size(0)) + + class Bernoulli_tensor(torch.nn.Module): + def forward(self, x): + return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0)) + + x = torch.rand(3, 3) + self.run_test(Bernoulli_float(), x) + self.run_test(Bernoulli_tensor(), x) + + x = torch.rand(2, 3, 3, dtype=torch.double) + self.run_test(Bernoulli_float(), x) + self.run_test(Bernoulli_tensor(), x) + @unittest.skip("Bug in ORT, skip test until rel-1.11.") @skipIfUnsupportedMinOpsetVersion(14) def test_reshape_allowzero(self): @@ -3893,6 +3928,52 @@ def forward(self, src, index): index = torch.tensor([[0, 0], [1, 1], [0, 1]], dtype=torch.int64) self.run_test(ScatterModel(), (src, index)) + @skipIfUnsupportedMinOpsetVersion(16) + def test_scatter_add_different_size_index_src(self): + class ScatterModel(torch.nn.Module): + def forward(self, input, indices, src): + return input.scatter_add(0, indices, src) + + src = torch.ones((2, 5)) + input = torch.zeros(3, 5, dtype=src.dtype) + indices = torch.tensor([[0, 1, 2, 0, 0]]) + self.run_test(ScatterModel(), input_args=(input, indices, src)) + + @common_utils.parametrize( + "src, indices", + [ + common_utils.subtest( + [torch.ones((1, 5)), torch.tensor([[0, 1, 2, 0, 0]])], + name="src_indices_dynamic_combination1", + ), + common_utils.subtest( + [torch.ones((2, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])], + name="src_indices_dynamic_combination2", + ), + common_utils.subtest( + [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0, 0], [1, 0, 2, 1, 2]])], + name="src_indices_dynamic_combination3", + ), + common_utils.subtest( + [torch.ones((3, 5)), torch.tensor([[0, 1, 2, 0], [1, 0, 2, 1]])], + name="src_indices_dynamic_combination4", + ), + ], + ) + @skipIfUnsupportedMinOpsetVersion(16) + def test_scatter_add_dynamic_index(self, src, indices): + class ScatterModel(torch.nn.Module): + def forward(self, input, indices, src): + return input.scatter_add(0, indices, src) + + input = torch.zeros(3, 5, dtype=src.dtype) + self.run_test( + ScatterModel(), + input_args=(input, indices, src), + input_names=["input", "indices", "src"], + dynamic_axes={"indices": {0: "a", 1: "b"}, "src": {0: "c", 1: "d"}}, + ) + @skipIfUnsupportedMinOpsetVersion(9) def test_bucketize(self): class BucketModel(torch.nn.Module): @@ -6666,6 +6747,8 @@ def forward(self, x, y): y = torch.tensor(2) self.run_test(FullLikeModel(), (x, y)) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_l1_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6674,6 +6757,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_l2_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6682,6 +6767,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6690,6 +6777,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm_keepdim(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -7437,6 +7526,27 @@ def forward(self, x, pad: List[int]): x = torch.randn(2, 2, 4, 4) self.run_test(Pad(), (x, pad)) + @skipIfUnsupportedMinOpsetVersion(11) + def test_pad_circular(self): + class PadModel(torch.nn.Module): + def forward(self, x): + out = torch.nn.functional.pad(x, (1, 2, 1, 2), mode="circular") + return out + + x = torch.randn(2, 3, 3, 4) + self.run_test(PadModel(), (x)) + + @skipIfUnsupportedMinOpsetVersion(11) + def test_pad_circular_negative(self): + # Test for different pad integer types + class PadModel(torch.nn.Module): + def forward(self, x): + out = torch.nn.functional.pad(x, (-1, -2), mode="circular") + return out + + x = torch.randn(2, 3, 6) + self.run_test(PadModel(), (x)) + @skipIfUnsupportedMaxOpsetVersion(10) @skipScriptTest() # TODO: the logic in symbolic_opset9 doesn't handle script def test_unsupported_pad(self): @@ -8689,6 +8799,28 @@ def forward(self, x, y): y = torch.full_like(x, True) self.run_test(MinimumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(12) + def test_maximum_dtypes(self): + class MaximumModel(torch.nn.Module): + def forward(self, x, y): + return torch.maximum(x, y) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randn((5, 5), dtype=torch.float) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randint(10, (5, 5), dtype=torch.int16) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int16) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int) + y = torch.full_like(x, True) + self.run_test(MaximumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(9) def test_any(self): class M(torch.nn.Module): @@ -9090,7 +9222,7 @@ def forward(self, x, y, cond): ) @skipScriptTest( - skip_before_opset_version=11, reason="dynamic split support addded in 11" + skip_before_opset_version=11, reason="dynamic split support added in 11" ) def test_split_tensor_scalar(self): class SplitModel(torch.nn.Module): @@ -11393,7 +11525,6 @@ def forward(self, x, y): self.run_test(M_ToDeviceDtype(), (x, y)) @skipIfUnsupportedMinOpsetVersion(9) - @skipScriptTest() def test_fill(self): class FillModule(torch.nn.Module): def forward(self, x, filled_value: int): @@ -11403,6 +11534,14 @@ def forward(self, x, filled_value: int): filled_value = 7 self.run_test(FillModule(), (x, filled_value)) + class FillFloatModule(torch.nn.Module): + def forward(self, x, filled_value: float): + return x.fill_(filled_value) + + x = torch.randn((4, 5, 6)) + filled_value = 7.5 + self.run_test(FillFloatModule(), (x, filled_value)) + class FillScalarModule(torch.nn.Module): def forward(self, x): res = x + 2 @@ -11827,6 +11966,20 @@ def test_quantized_conv2d_relu(self): q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) self.run_test(model, q_input) + @skipIfUnsupportedMinOpsetVersion(10) + def test_quantized_conv1d_relu(self): + model = torch.nn.intrinsic.quantized.ConvReLU1d(16, 33, 3, stride=2) + # Manually initialize model weight and bias to random numbers. + # By default all zeros. + q_weight = torch.quantize_per_tensor( + torch.randn(33, 16, 3), 0.5, 0, torch.qint8 + ) + bias = torch.arange(33).to(torch.float) - 16 + model.set_weight_bias(q_weight, bias) + input = torch.randn(3, 16, 32) + q_input = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8) + self.run_test(model, q_input) + @common_utils.parametrize( "function_or_module", [ @@ -12261,6 +12414,17 @@ def test_qat_upsample_nearest2d(self): input = _construct_tensor_for_quantization_test((4, 3, 2, 2)) self.run_test(model, input) + def test_0d_tensor_broadcast(self): + class fn(torch.nn.Module): + def forward(self, x, y): + a = torch.add(x, y) + b = torch.mul(y, y) + return a + b + + x = torch.ones(0) + y = torch.ones(1) + self.run_test(fn(), (x, y), input_names=["x", "y"], output_names=["output"]) + @skipIfUnsupportedMinOpsetVersion(9) def test_convolution_allow_tf32(self): class Module(torch.nn.Module): @@ -12367,8 +12531,6 @@ def forward(self, x) -> Optional[Tensor]: y = None return y - # Skip now to wait more insight on https://github.com/onnx/onnx/issues/4424 - # Model fails on type inference, as it's input/output type mismatch. class LoopNoneInput(torch.nn.Module): def forward(self, x) -> Optional[Tensor]: y: Optional[Tensor] = None @@ -12408,7 +12570,7 @@ def test_optional_output(self, module_class: Type[torch.nn.Module], x_size: int) input_names=["x"], ) exported = onnx.load_from_string(f.getvalue()) - expected_elem_type = torch.onnx.JitScalarType.from_dtype(x.dtype).onnx_type() + expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type() expected_output_type = onnx.helper.make_optional_type_proto( onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,)) ) @@ -12517,6 +12679,59 @@ def forward(self, x): x, ) + @skipScriptTest() + @skipIfUnsupportedMinOpsetVersion(16) + @unittest.skipIf( + not torch.hub._check_module_exists("torch_geometric"), + "torch_geometric not installed.", + ) + def test_sage_conv(self): + from torch_geometric import nn as torch_geometric_nn + + # Input + coords0 = torch.randn(1, 6) + coords1 = torch.randn(1, 6) + coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1) + adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True) + edge_from = adj[0:1, :] + edge_to = adj[1:, :] + inputs = (coords0, coords1, edge_from, edge_to) + + class MySAGEConv(torch.nn.Module): + def __init__(self): + super().__init__() + self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv( + 2, 512, normalize=True + ) + self.bano1 = torch_geometric_nn.BatchNorm(512) + self.relu = torch.nn.ReLU() + self.dense1 = torch.nn.Seq(Lin(512, 1)) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, coords0, coords1, edge_from, edge_to): + adj = torch.cat((edge_from, edge_to), dim=0) + gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1) + x1 = self.SAGEConvBlock1(gra, edge_index=adj) + x = torch.unsqueeze(torch.sum(x1), dim=0) + return x + + input_names = ["coords0", "coords1", "edge_from", "edge_to"] + output_names = ["outputs"] + dynamic_axes = { + "coords0": {0: "batch_size", 1: "features"}, + "coords1": {0: "batch_size", 1: "features"}, + "edge_from": {0: "batch_size", 1: "features"}, + "edge_to": {0: "batch_size", 1: "features"}, + "outputs": {0: "batch_size"}, + } + self.run_test( + MySAGEConv(), + inputs, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + # Cannot export with older opsets because of "ConstantFill" op # ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime # There are still some issues prevent us from enabling script test for these scenarios: diff --git a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py index 193b87af3d284..9695d05b6072f 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime_cuda.py @@ -110,8 +110,8 @@ def forward(self, x): try: from apex import amp - except Exception: - raise unittest.SkipTest("Apex is not available") + except Exception as e: + raise unittest.SkipTest("Apex is not available") from e input = torch.randn(3, 3, device=torch.device("cuda")) model = amp.initialize(LinearModel(), opt_level="O2") self.run_test(model, input) diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index 86258fb1d0ec1..915677279d017 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -1,7 +1,10 @@ # Owner(s): ["module: onnx"] -import numpy as np +import io +import numpy as np +import onnx +import pytorch_test_common import torch from pytorch_test_common import skipIfUnsupportedMinOpsetVersion from torch.onnx import _constants, symbolic_helper @@ -19,7 +22,7 @@ def verify(actual_type): return verify -class TestONNXShapeInference(common_utils.TestCase): +class TestONNXShapeInference(pytorch_test_common.ExportTestCase): def setUp(self): self.opset_version = _constants.ONNX_MAX_OPSET symbolic_helper._set_onnx_shape_inference(True) @@ -283,5 +286,172 @@ def test_reduce_prod_without_axes(self): self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) +class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase): + def setUp(self): + super().setUp() + self.opset_version = _constants.ONNX_MAX_OPSET + + def test_setType_maintains_output_shape_for_single_custom_op(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType(self.type()) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + for dim, rank in zip(dims, x.size()): + self.assertEqual(dim.dim_value, rank) + + def test_no_setType_for_single_custom_op(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_no_settype(g, self): + return g.op("com.microsoft::Inverse", self) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_param")) + + def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes( + self, + ): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType( + self.type().with_dtype(torch.float).with_sizes([None, 3, 3]) + ) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + input_names=["x"], + dynamic_axes={"x": {0: "batch"}}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + # The first axe should be dynamic as we defined when exporting + self.assertTrue(dims[0].HasField("dim_param")) + for i in range(1, len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + self.assertEqual(dims[i].dim_value, x.size()[i]) + + def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x, y, z): + x = torch.inverse(x) + return x + y + z + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType( + self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10]) + ) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 10, 10) + y = torch.randn(2, 3, 10, 10) + z = torch.randn(2, 3, 10, 10) + f = io.BytesIO() + torch.onnx.export( + model, + (x, y, z), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + # To validate the shape of inverse Op, we need to find inverse output name, + # and then use it to identify its value_info for the shape. + output_name = "" + for node in model_proto.graph.node: + if node.op_type == "Inverse": + output_name = node.output[0] + break + assert output_name + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + for value_info in model_value_info: + assert value_info.name + if value_info.name == output_name: + dims = value_info.type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + for dim, rank in zip(dims, x.size()): + self.assertEqual(dim.dim_value, rank) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 26467d54c1c6c..25ee698fd6d0e 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -8,6 +8,7 @@ import onnx import parameterized +import pytorch_test_common import torch import torch.onnx @@ -27,13 +28,7 @@ from verify import verify -class _BaseTestCase(common_utils.TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - +class _BaseTestCase(pytorch_test_common.ExportTestCase): def _model_to_graph( self, model, @@ -64,7 +59,7 @@ def _model_to_graph( @common_utils.instantiate_parametrized_tests -class TestUnconvertibleOps(common_utils.TestCase): +class TestUnconvertibleOps(pytorch_test_common.ExportTestCase): """Unit tests for the `unconvertible_ops` function.""" def setUp(self): @@ -129,6 +124,18 @@ def test_it_returns_empty_list_when_all_ops_convertible( _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12) self.assertEqual(unconvertible_ops, []) + def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self): + class SkipConnectionModule(torch.nn.Module): + def forward(self, x): + out = x + out += x + out = torch.nn.functional.relu(out, inplace=True) + + module = SkipConnectionModule() + x = torch.randn(4, 4) + _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13) + self.assertEqual(unconvertible_ops, []) + @parameterized.parameterized_class( [ @@ -233,7 +240,6 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL2") - self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_reduceL1(self): class NormModule(torch.nn.Module): @@ -251,7 +257,6 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL1") - self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_slice(self): class NarrowModule(torch.nn.Module): @@ -598,8 +603,7 @@ def forward(self, x): params = list(params_dict.values()) self.assertEqual(len(params), 1) weight = params[0] - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(weight, torch.tensor([2, 3, 4, 5, 6])) + self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0])) def test_constant_fold_sub(self): class Module(torch.nn.Module): @@ -630,8 +634,7 @@ def forward(self, x): params = list(params_dict.values()) self.assertEqual(len(params), 1) weight = params[0] - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(weight, torch.tensor([0, -1, -2, -3, -4])) + self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0])) def test_constant_fold_sqrt(self): class Module(torch.nn.Module): diff --git a/test/package/module_a_remapped_path.py b/test/package/module_a_remapped_path.py new file mode 100644 index 0000000000000..793ddd4296885 --- /dev/null +++ b/test/package/module_a_remapped_path.py @@ -0,0 +1 @@ +result = "module_a_remapped_path" diff --git a/test/package/test_misc.py b/test/package/test_misc.py index c29602d8e360b..908e8d29992c3 100644 --- a/test/package/test_misc.py +++ b/test/package/test_misc.py @@ -2,7 +2,9 @@ # Owner(s): ["oncall: package/deploy"] import inspect +import os import platform +import sys from io import BytesIO from pathlib import Path from textwrap import dedent @@ -104,6 +106,60 @@ def test_file_structure(self): import_exclude, ) + def test_loaders_that_remap_files_work_ok(self): + from importlib.abc import MetaPathFinder + from importlib.machinery import SourceFileLoader + from importlib.util import spec_from_loader + + class LoaderThatRemapsModuleA(SourceFileLoader): + def get_filename(self, name): + result = super().get_filename(name) + if name == "module_a": + return os.path.join(os.path.dirname(result), "module_a_remapped_path.py") + else: + return result + + class FinderThatRemapsModuleA(MetaPathFinder): + def find_spec(self, fullname, path, target): + """Try to find the original spec for module_a using all the + remaining meta_path finders.""" + if fullname != "module_a": + return None + spec = None + for finder in sys.meta_path: + if finder is self: + continue + if hasattr(finder, "find_spec"): + spec = finder.find_spec(fullname, path, target=target) + elif hasattr(finder, "load_module"): + spec = spec_from_loader(fullname, finder) + if spec is not None: + break + assert spec is not None and isinstance(spec.loader, SourceFileLoader) + spec.loader = LoaderThatRemapsModuleA(spec.loader.name, spec.loader.path) + return spec + + sys.meta_path.insert(0, FinderThatRemapsModuleA()) + # clear it from sys.modules so that we use the custom finder next time + # it gets imported + sys.modules.pop("module_a", None) + try: + buffer = BytesIO() + with PackageExporter(buffer) as he: + import module_a + + he.intern("**") + he.save_module(module_a.__name__) + + + buffer.seek(0) + hi = PackageImporter(buffer) + self.assertTrue("remapped_path" in hi.get_source("module_a")) + finally: + # pop it again to ensure it does not mess up other tests + sys.modules.pop("module_a", None) + sys.meta_path.pop(0) + def test_python_version(self): """ Tests that the current python version is stored in the package and is available diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py new file mode 100644 index 0000000000000..84442724205ab --- /dev/null +++ b/test/profiler/test_memory_profiler.py @@ -0,0 +1,1559 @@ +# Owner(s): ["oncall: profiler"] +import functools +import gc +import itertools as it +import textwrap +from typing import Callable, Dict, Iterator, List, Optional, Tuple + +import torch +from torch._C._profiler import _EventType, _TensorMetadata +from torch.profiler import _memory_profiler, _utils +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.utils._pytree import tree_flatten + + +profile = functools.partial( + torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True +) + + +@skipIfTorchDynamo("TorchDynamo removes profiler altogether.") +class TestMemoryProfiler(TestCase): + def test_config_check(self) -> None: + with torch.profiler.profile() as prof: + pass + + pattern = r"record_shapes=True, profile_memory=True, with_stack=True" + with self.assertRaisesRegex(ValueError, pattern): + prof._memory_profile() + + with torch.profiler.profile(record_shapes=True, with_stack=True) as prof: + pass + + pattern = r"^profile_memory=True required for memory profiling\.$" + with self.assertRaisesRegex(ValueError, pattern): + prof._memory_profile() + + with profile() as prof: + pass + + self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile) + + +class ScaleLayer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.scale + + +class LazyLinear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + def forward(self, x) -> torch.Tensor: + if getattr(self, "weight", None) is None: + self.weight = torch.nn.Parameter( + torch.empty((self.out_features, self.in_features)) + ) + self.bias = torch.nn.Parameter(torch.empty(self.out_features)) + + return torch.nn.functional.linear(x, self.weight, self.bias) + + +class RecordInputOutputDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): + def __init__(self): + self.results = [] + + def mark_region(self, name: str): + self.results.append((name, (), ())) + + @staticmethod + def flat_ids(args): + flat_args = tree_flatten(args)[0] + return tuple( + (t._cdata, t.storage().data_ptr()) + for t in flat_args + if isinstance(t, torch.Tensor) and t.storage() + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + args = args or [] + kwargs = kwargs or {} + flat_inputs = self.flat_ids(args) + self.flat_ids(kwargs) + out = func(*args, **kwargs) + flat_outputs = self.flat_ids(out) + if ( + flat_inputs or flat_outputs + ) and "_record_function_enter" not in func.name(): + self.results.append((func.name(), flat_inputs, flat_outputs)) + return out + + +@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.") +class TestIdentifyGradients(TestCase): + def gradient_detected( + self, + prof: torch.profiler.profile, + ctx: _EventType, + grad_tensor: torch.Tensor, + parameter: Optional[torch.Tensor] = None, + ) -> None: + + # This is not an exhaustive check, but for the purpose of unit testing + # it is sufficient. + def key_matches_tensor(key, tensor) -> bool: + # Vacuous case. + if tensor is None: + return True + + if key is None: + return False + + return tensor.storage().data_ptr() == key.storage.ptr + + tree = prof.profiler.kineto_results.experimental_event_tree() + for node in _utils.traverse_dfs(tree): + for p_key, p_grad_key in _memory_profiler.extract_gradients(node): + if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor): + if parameter is None: + return True # Don't need to check parameter; we're done. + + elif p_key is not None: + # For a complex workflow a gradient could correspond to + # different parameters at different points in a trace. + # However this will not happen in the relatively simple + # cases tested here, so if `extract_gradients` identifies + # the parameter corresponding to a particular gradient it + # must be the one we expect. + self.assertTrue(key_matches_tensor(p_key, parameter)) + return True + + return False + + def assertGradientDetected(self, name: str, *args, **kwargs) -> None: + self.assertTrue( + self.gradient_detected(*args, **kwargs), + f"Failed to identify gradient `{name}` from profile.", + ) + + def assertOnlyGradients( + self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor] + ) -> None: + allowed_set = {t.storage().data_ptr() for t in tensors} + + tree = prof.profiler.kineto_results.experimental_event_tree() + for node in _utils.traverse_dfs(tree): + for _, p_grad_key in _memory_profiler.extract_gradients(node): + self.assertTrue( + p_grad_key.storage.ptr in allowed_set, + f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}", + ) + + def test_extract_gradients_low_level(self) -> None: + x = torch.ones((1,)) + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + + def check(cold_start: bool): + self.assertEqual(w0.grad is None, cold_start) + self.assertEqual(w1.grad is None, cold_start) + with profile() as prof: + z = x.expand(4) * w0 + (z * w1).sum().backward() + + # Gradient detection through op inspection does not provide a + # reference to the parameter corresponding to the gradient. + self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_module(self) -> None: + model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()) + named_parameters = {name: p for name, p in model.named_parameters()} + self.assertEqual(len(named_parameters), 3) + + def assert_only_gradients(prof: torch.profiler.profile): + gradients = tuple(i.grad for i in named_parameters.values()) + self.assertFalse(any(i is None for i in gradients)) + self.assertOnlyGradients(prof, gradients) + + def check(cold_start: bool): + x = torch.ones((2, 2)) + with profile() as prof: + model(x).sum().backward() + + for name, p in named_parameters.items(): + # The first time we run a module none of the `.grad` fields + # have been initialized. This is fine; in that case we can + # detect everything we need in the profiled section. + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, p.grad, p), + cold_start, + name, + ) + + # Op based detection should still identify the gradients. + self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad) + assert_only_gradients(prof) + + # We can detect gradients even when `.backward()` is not called. + with profile() as prof: + model(torch.ones((2, 2))) + + for name, p in named_parameters.items(): + self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p) + self.assertFalse( + self.gradient_detected(prof, _EventType.TorchOp, p.grad), name + ) + assert_only_gradients(prof) + + check(cold_start=True) + check(cold_start=False) + + def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None: + + x = torch.ones((1,)) + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9) + + def check(cold_start: bool): + self.assertEqual(w0.grad is None, cold_start) + self.assertEqual(w1.grad is None, cold_start) + with profile() as prof: + optimizer.zero_grad(set_to_none=set_to_none) + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + + # Optimizer instrumentation runs late in the step, so we can detect + # gradients for both cold and warm start. + self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0) + self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1) + + self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + with profile() as prof: + for _ in range(2): + optimizer.zero_grad(set_to_none=set_to_none) + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + + # Inspected state is cached, so if we replace gradients (as is the + # case for `set_to_none=True`) our python instrumentation will not + # see them. + # TODO(robieta): Should `.step()` be excluded from caching? + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0), + set_to_none, + ) + + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1), + set_to_none, + ) + + if set_to_none: + with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"): + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_optimizer(self) -> None: + self._test_extract_gradients_from_optimizer(set_to_none=False) + + def test_extract_gradients_from_optimizer_set_to_none(self) -> None: + self._test_extract_gradients_from_optimizer(set_to_none=True) + + def test_extract_gradients_from_module_and_optimizer(self) -> None: + # Module and optimizer are thoroughly tested individually and should be + # additive. Thus we can manage with a lightweight check that they don't + # interact adversely. + model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + with profile() as prof: + model(torch.ones((2, 2))).sum().backward() + optimizer.step() + + self.assertGradientDetected( + "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight + ) + + +@skipIfTorchDynamo("TorchDynamo removes profiler altogether.") +class TestDataFlow(TestCase): + def setUp(self) -> None: + super().setUp() + self.maxDiff = None + + @staticmethod + def formatSchemas( + prof: torch.profiler.profile, indent: int = 12 + ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]: + tree = prof.profiler.kineto_results.experimental_event_tree() + out: List[Tuple[str, Tuple[bool, ...]]] = [] + for node in _utils.traverse_dfs(tree): + if node.tag == _EventType.TorchOp: + e = node.extra_fields + schemas = _memory_profiler.SchemaMatcher.match_schemas(e) + name = node.name + if len(schemas) == 1: + name = f"{name}.{schemas[0].overload_name}" + elif len(schemas) > 1: + name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}" + + out.append((name, _memory_profiler.SchemaMatcher.inputs_are_mutable(e))) + return tuple(out) + + @staticmethod + def _run_and_format_data_flow( + inputs: Dict[str, torch.Tensor], + f: Callable[..., Optional[Dict[str, torch.Tensor]]], + indent: int = 12, + ) -> str: + with profile() as prof: + outputs = f(**inputs) or {} + gc.collect() + + memory_profile = prof._memory_profile() + graph = memory_profile._data_flow_graph + storage_to_id = {key.storage.ptr: key.id for key in graph._active_version} + + lines: List[str] = [] + for name, t in it.chain(inputs.items(), outputs.items()): + lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}") + if t.grad is not None: + grad_id = storage_to_id[t.grad.storage().data_ptr()] + lines.append(f"{name + '.grad:':<9} T{grad_id}") + + if lines: + lines.append("") + + for node in graph.flow_nodes: + destroyed = {k for k, v in node._edges.items() if v.is_deletion} + + inputs: List[str] = [] + for key, (_, v) in node.inputs.items(): + inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})") + + outputs = [f"T{key.id}(v{v})" for key, v in node.outputs.items()] + if inputs or outputs: + event_name = node._event.name.replace("torch::autograd::", "") + lines.append( + f"{event_name:<25} {', '.join(inputs):<15} -> {', '.join(outputs)}" + ) + + return textwrap.indent("\n".join([l.rstrip() for l in lines]), " " * indent) + + def test_match_schemas(self) -> None: + with profile() as prof: + x = torch.ones((1,)).mul(2).add_(2) + _ = torch.sin(x, out=torch.empty_like(x)) + + self.assertEqual( + self.formatSchemas(prof), + ( + ("aten::ones.", (False,) * 5), + ("aten::empty.memory_format", (False,) * 6), + # + # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + ("aten::fill_.Scalar", (True, False)), + ("aten::mul.Tensor", (False, False)), + ("aten::to.dtype", (False,) * 5), + ("aten::_to_copy.", (False,) * 7), + ("aten::empty_strided.", (False,) * 6), + # + # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + ("aten::copy_.", (True, False, False)), + # + # add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + ("aten::add_.Tensor", (True, False, False)), + ("aten::to.dtype", (False,) * 5), + ("aten::_to_copy.", (False,) * 7), + ("aten::empty_strided.", (False,) * 6), + # + # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + ("aten::copy_.", (True, False, False)), + ("aten::empty_like.", (False,) * 6), + ("aten::empty_strided.", (False,) * 6), + # + # sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + ("aten::sin.out", (False, True)), + ), + ) + + def test_match_schemas_backward(self) -> None: + x = torch.ones((1,)) + w = torch.ones((1,), requires_grad=True) + with profile() as prof: + torch.mul(x, w).backward() + + self.assertEqual( + self.formatSchemas(prof), + ( + ("aten::mul.Tensor", (False, False)), + ("aten::ones_like.", (False,) * 6), + ("aten::empty_like.", (False,) * 6), + ("aten::empty_strided.", (False,) * 6), + # + # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + ("aten::fill_.Scalar", (True, False)), + ("autograd::engine::evaluate_function: MulBackward0", ()), + ("MulBackward0", (None,)), + ("aten::mul.Tensor", (False, False)), + ( + "autograd::engine::evaluate_function: torch::autograd::AccumulateGrad", + (), + ), + ("torch::autograd::AccumulateGrad", (None,)), + ("aten::detach.", (False,)), + ("detach", (None,)), + ), + ) + + def test_match_schemas_tensorlist(self) -> None: + x = torch.ones((1,)) + y = torch.ones((1,)) + with profile() as prof: + torch.cat([x, y], axis=0) + + self.assertEqual( + self.formatSchemas(prof), + (("aten::cat.", (False, False)),), + ) + + def test_data_flow_graph_with_annotations(self) -> None: + def f(x, y): + # torch._C._jit_get_schemas_for_operator will reject any name that + # is missing a namespace. (denoted by the presence of "::") We want + # to check that we skip both annotations which have no schema + # (return empty tuple from SchemaMatcher.lookup_schemas) and + # annotations which cannot have schema (return None from + # SchemaMatcher.lookup_schemas). + with torch.profiler.record_function("Namespaced::Annotation"): + with torch.profiler.record_function("My Annotation"): + x.zero_() + y.zero_() + return {"x0": torch.ones_like(x), "y0": torch.zeros_like(y)} + + inputs = {"x": torch.ones((1,)), "y": torch.ones((1,))} + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f), + """\ + x: T0 + y: T1 + x0: T2 + y0: T3 + + aten::zero_ T0(v0) -> T0(v1) + aten::zero_ T1(v0) -> T1(v1) + aten::ones_like T0(v1) -> T2(v0) + aten::zeros_like T1(v1) -> T3(v0)""", + ) + + def test_data_flow_graph_non_op_allocations(self) -> None: + def f(x): + x.mul(2) + + # The python arg parser will convert the python scalar `2` to a Tensor + # to pass to `aten::mul`. As a result there is no op that "owns" the + # allocation. The Tensor deletions also do not happen in an op; they + # are collected as a result of the Python objects going out of scope. + self.assertExpectedInline( + self._run_and_format_data_flow({"x": torch.ones((1,))}, f), + """\ + x: T1 + + [memory] -> T0(v0) + aten::mul T0(v0), T1(v0) -> + [memory] T0(v0*) ->""", + ) + + def test_data_flow_graph_simple(self) -> None: + inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)} + + def f0(x, y): + z = x.mul(y) + return {"z": z.view_as(z)} + + def f1(x, y): + with torch.no_grad(): + return f0(x, y) + + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f0), + """\ + x: T0 + y: T1 + z: T2 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::view_as T2(v0) ->""", + ) + + # Out of place is identical regardless of Autograd. + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f0), + """\ + x: T0 + y: T1 + z: T2 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::view_as T2(v0) ->""", + ) + + def test_data_flow_graph_simple_inplace(self) -> None: + inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)} + + def f0(x, y): + x.mul_(y) + + def f1(x, y): + with torch.no_grad(): + return f0(x, y) + + # When Autograd is enabled a second Tensor `T2` is created to store + # the values of T0(v0) which are needed for backwards. + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f0), + """\ + x: T0 + y: T1 + + aten::mul_ T0(v0), T1(v0) -> T0(v1), T2(v0)""", + ) + + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f1), + """\ + x: T0 + y: T1 + + aten::mul_ T0(v0), T1(v0) -> T0(v1)""", + ) + + def test_data_flow_graph_simple_backward(self) -> None: + inputs = { + "x": torch.ones((1,)), + "w": torch.ones((1,), requires_grad=True), + } + self.assertExpectedInline( + self._run_and_format_data_flow( + inputs, lambda x, w: (x * w).sin().backward() + ), + """\ + x: T0 + w: T1 + w.grad: T7 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::sin T2(v0) -> T3(v0) + aten::ones_like T3(v0) -> T4(v0) + SinBackward0 T2(v0), T4(v0) -> T6(v0) + [memory] T2(v0*) -> + MulBackward0 T0(v0), T6(v0) -> T7(v0) + [memory] T6(v0*) -> + AccumulateGrad T7(v0) -> + [memory] T4(v0*) -> + [memory] T3(v0*) ->""", + ) + + def test_data_flow_graph_complicated(self) -> None: + def f(): + x = torch.ones((25,)) + y = x.mul(2).add_(2) + z = torch.sin(y, out=torch.empty_like(y)) + return {"x": x, "y": y, "z": z} + + # T1 is the `2` in `.mul(2)`. The Python arg parser automatically + # converts Scalar arguments to Tensors. The same is true for `T4` + # and `.add_(2)`. + self.assertExpectedInline( + self._run_and_format_data_flow({}, f), + """\ + x: T0 + y: T3 + z: T6 + + aten::ones -> T0(v0) + [memory] -> T1(v0) + aten::mul T0(v0), T1(v0) -> T3(v0) + [memory] T1(v0*) -> + [memory] -> T4(v0) + aten::add_ T3(v0), T4(v0) -> T3(v1) + [memory] T4(v0*) -> + aten::empty_like T3(v1) -> T6(v0) + aten::sin T3(v1), T6(v0) -> T6(v1)""", + ) + + with profile() as prof: + f() + + # `aten::mul` creates a temporary Tensor (T2), which is why the output + # is has ID three rather than two. + mul_node = prof._memory_profile()._data_flow_graph.flow_nodes[2] + self.assertEqual(mul_node._event.name, "aten::mul") + self.assertEqual(len(mul_node.intermediates), 1) + self.assertEqual(mul_node.intermediates[0].id, 2) + + def test_data_flow_graph_stacked(self) -> None: + inputs = { + "x": torch.ones((25,)), + "w0": torch.ones((1,), requires_grad=True), + "w1": torch.ones((1,), requires_grad=True), + } + + def f(x, w0, w1): + return x.mul(w0).relu().mul(w1).relu().sum() + + def f_fwd(**kwargs): + with torch.no_grad(): + return {"loss": f(**kwargs)} + + def f_fwd_bwd(**kwargs): + loss = f(**kwargs) + loss.backward() + return {"loss": loss} + + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f_fwd), + """\ + x: T0 + w0: T1 + w1: T4 + loss: T7 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + [memory] T3(v0*) -> + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + [memory] T6(v0*) ->""", + ) + + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f_fwd_bwd), + """\ + x: T0 + w0: T1 + w0.grad: T15 + w1: T4 + w1.grad: T12 + loss: T7 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + aten::ones_like T7(v0) -> T8(v0) + SumBackward0 T8(v0) -> + ReluBackward0 T6(v0), T8(v0) -> T9(v0) + [memory] T6(v0*) -> + MulBackward0 T3(v0), T4(v0), T9(v0) -> T10(v0), T11(v0) + aten::sum T10(v0) -> T12(v0) + [memory] T10(v0*) -> + [memory] T9(v0*) -> + AccumulateGrad T12(v0) -> + ReluBackward0 T3(v0), T11(v0) -> T13(v0) + [memory] T11(v0*) -> + [memory] T3(v0*) -> + MulBackward0 T0(v0), T13(v0) -> T14(v0) + aten::sum T14(v0) -> T15(v0) + [memory] T14(v0*) -> + [memory] T13(v0*) -> + AccumulateGrad T15(v0) -> + [memory] T8(v0*) ->""", + ) + + # Second time grads are already initialized. + self.assertExpectedInline( + self._run_and_format_data_flow(inputs, f_fwd_bwd), + """\ + x: T0 + w0: T1 + w0.grad: T17 + w1: T4 + w1.grad: T13 + loss: T7 + + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + aten::ones_like T7(v0) -> T8(v0) + SumBackward0 T8(v0) -> + ReluBackward0 T6(v0), T8(v0) -> T9(v0) + [memory] T6(v0*) -> + MulBackward0 T3(v0), T4(v0), T9(v0) -> T10(v0), T11(v0) + aten::sum T10(v0) -> T12(v0) + [memory] T10(v0*) -> + [memory] T9(v0*) -> + AccumulateGrad T12(v0*), T13(v0) -> T13(v1) + ReluBackward0 T3(v0), T11(v0) -> T14(v0) + [memory] T11(v0*) -> + [memory] T3(v0*) -> + MulBackward0 T0(v0), T14(v0) -> T15(v0) + aten::sum T15(v0) -> T16(v0) + [memory] T15(v0*) -> + [memory] T14(v0*) -> + AccumulateGrad T16(v0*), T17(v0) -> T17(v1) + [memory] T8(v0*) ->""", + ) + + return + + x = torch.ones((25,)) + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + + with profile() as prof_no_grad: + with torch.no_grad(): + x.mul(w0).relu().mul(w1).relu().sum() + + # TODO: one with `.logsumexp(dim=0)` + + self.assertExpectedInline( + self._format_graph(prof_no_grad), + """\ + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + [memory] T3(v0*) -> + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + [memory] T6(v0*) -> + [memory] T7(v0*) ->""", + ) + + with profile() as prof_grad: + loss = x.mul(w0).relu().mul(w1).relu().sum() + loss.backward() + + self.assertExpectedInline( + self._format_graph(prof_grad), + """\ + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + aten::ones_like T7(v0) -> T8(v0) + SumBackward0 T8(v0) -> T8(v1) + ReluBackward0 T6(v0), T8(v1) -> T8(v2), T9(v0) + [memory] T6(v0*) -> + MulBackward0 T3(v0), T4(v0), T9(v0) -> T9(v1), T10(v0), T11(v0) + aten::sum T10(v0) -> T12(v0) + [memory] T10(v0*) -> + [memory] T9(v1*) -> + AccumulateGrad T12(v0) -> T12(v1) + ReluBackward0 T3(v0), T11(v0) -> T11(v1), T13(v0) + [memory] T11(v1*) -> + [memory] T3(v0*) -> + MulBackward0 T0(v0), T13(v0) -> T13(v1), T14(v0) + aten::sum T14(v0) -> T15(v0) + [memory] T14(v0*) -> + [memory] T13(v1*) -> + AccumulateGrad T15(v0) -> T15(v1) + [memory] T8(v2*) ->""", + ) + + # Second time grads are already initialized. + with profile() as prof_grad: + loss = x.mul(w0).relu().mul(w1).relu().sum() + loss.backward() + + self.assertExpectedInline( + self._format_graph(prof_grad), + """\ + aten::mul T0(v0), T1(v0) -> T2(v0) + aten::relu T2(v0) -> T3(v0) + [memory] T2(v0*) -> + aten::mul T3(v0), T4(v0) -> T5(v0) + aten::relu T5(v0) -> T6(v0) + [memory] T5(v0*) -> + aten::sum T6(v0) -> T7(v0) + aten::ones_like T7(v0) -> T8(v0) + SumBackward0 T8(v0) -> T8(v1) + ReluBackward0 T6(v0), T8(v1) -> T8(v2), T9(v0) + [memory] T6(v0*) -> + MulBackward0 T3(v0), T4(v0), T9(v0) -> T9(v1), T10(v0), T11(v0) + aten::sum T10(v0) -> T12(v0) + [memory] T10(v0*) -> + [memory] T9(v1*) -> + AccumulateGrad T12(v0*), T13(v0) -> T13(v1) + ReluBackward0 T3(v0), T11(v0) -> T11(v1), T14(v0) + [memory] T11(v1*) -> + [memory] T3(v0*) -> + MulBackward0 T0(v0), T14(v0) -> T14(v1), T15(v0) + aten::sum T15(v0) -> T16(v0) + [memory] T15(v0*) -> + [memory] T14(v1*) -> + AccumulateGrad T16(v0*), T17(v0) -> T17(v1) + [memory] T8(v2*) ->""", + ) + + +@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.") +class TestMemoryProfilerE2E(TestCase): + @staticmethod + def _lookup_tensor_categories( + t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile + ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]: + storage = t.storage() + if storage is None: + raise ValueError("Cannot look up uninitialized Tensor.") + + snapshot = memory_profile._category_snapshot() + ids = { + key.storage.allocation_id + for key, _ in snapshot + if key.storage.ptr == storage.data_ptr() and key.device == storage.device + } + + return { + (key, version): category + for (key, version), category in memory_profile._category_snapshot().items() + # + # If a Tensor is live we want the most recent ID + if key.storage.allocation_id == max(ids | {-1}) + } + + def _run_and_check_parameters_and_gradients(self, inner_fn, model): + + with profile() as prof: + inner_fn() + + memory_profile = prof._memory_profile() + + def assert_category(t: torch.Tensor, category: _memory_profiler.Category): + self.assertIsNotNone(t) + categories = self._lookup_tensor_categories(t, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue(all(c == category for c in categories.values()), categories) + + for p in model.parameters(): + assert_category(p, _memory_profiler.Category.PARAMETER) + assert_category(p.grad, _memory_profiler.Category.GRADIENT) + + # Rely on internal asserts + _ = memory_profile.timeline + + def _run_and_format_categories(self, fn, indent=12): + """Generate summary of assigned categories for expecttest.""" + + # Use `__torch_dispatch__` to collect ground truth. + with RecordInputOutputDispatchMode() as record_ops, profile() as prof: + fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-"))) + + memory_profile = prof._memory_profile() + ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {} + snapshot = memory_profile._category_snapshot() + + # Build map from observed live Tensors to the memory profiler's + # TensorKey representation. + for op in memory_profile._op_tree.dfs(): + if op.typed[0] == _EventType.TorchOp: + inputs = tree_flatten(op.typed[1].inputs)[0] + for t in (i for i in inputs if isinstance(i, _TensorMetadata)): + key = _memory_profiler.TensorKey.from_tensor(t) + if key: + ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key + + def format_categories(ptr_pair: int): + target_key = ptr_pair_to_key.get(ptr_pair, None) + if target_key is None: + return "???" + + matches = tuple( + (version, category.name if category else "???") + for (key, version), category in snapshot.items() + if key == target_key + ) + assert matches, "Failed to lookup Tensor" + + # Deduplicate version bumps which don't change the category. + categories = [matches[0][1]] + for _, category in matches: + if category != categories[-1]: + categories.append(category) + + return f"{target_key.storage.allocation_id} ({','.join(categories)})" + + out: List[str] = [] + for name, inputs, outputs in record_ops.results: + if inputs or outputs: + # PyTorch ops + inputs_str = ", ".join(format_categories(i) for i in inputs) + outputs_str = ", ".join(format_categories(i) for i in outputs) + out.append(f"{name:<40} {inputs_str:<45} -> {outputs_str}") + + else: + # Marked regions. + out.append(f"\n{name}") + + return textwrap.indent("\n".join(out), " " * indent) + + def test_parameters_and_gradients(self): + model = torch.nn.Sequential( + torch.nn.Linear(2, 2), ScaleLayer(), torch.nn.Linear(2, 1), ScaleLayer() + ) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + def fwd_only(): + _ = model(torch.ones((2, 2))) + + def fwd_bwd_step(): + y = model(torch.ones((2, 2))) + torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward() + optimizer.step() + optimizer.zero_grad() + + # If we profile the first step then gradients will not have been + # created when we call `model.forward`, so if we don't call `.backward` + # then gradients are never created. + with self.assertRaises(AssertionError): + self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model) + + # On the first step we must rely on `AccumulateGrad`, since gradients + # did not exist when `model.forward` was called. + self.assertTrue(all(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + # After one step the python tracer will also flag gradients. + self.assertTrue(not any(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + # The parameter gradients are not used but we still detect them with + # the python tracer. + self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model) + + def test_parameters_and_gradients_set_to_none(self): + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + def fwd_bwd_step(): + for _ in range(3): + # zero grads at the start so gradients are still live to be + # checked. + optimizer.zero_grad(set_to_none=True) + + y = model(torch.ones((2, 2))) + torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward() + optimizer.step() + + fwd_bwd_step() + self.assertTrue(not any(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + optimizer.zero_grad(set_to_none=True) + self.assertTrue(all(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + def test_inputs_fwd(self): + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)) + inputs = [torch.ones((2, 2)) for _ in range(2)] + + with profile() as prof: + # Inputs which were allocated before profiling began + for x in inputs: + _ = model(x) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2)) + inputs.append(x) + _ = model(x) + + memory_profile = prof._memory_profile() + for x in inputs: + categories = self._lookup_tensor_categories(x, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue( + all(i == _memory_profiler.Category.INPUT for i in categories.values()), + categories, + ) + + snapshot = memory_profile._category_snapshot() + self.assertTrue(_memory_profiler.Category.INPUT in snapshot.values()) + + def test_inputs_fwd_lazy(self): + model = torch.nn.Sequential(LazyLinear(2, 2), LazyLinear(2, 1)) + inputs = [torch.ones((2, 2)) for _ in range(2)] + + with profile() as prof: + # Inputs which were allocated before profiling began + for x in inputs: + _ = model(x) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2)) + inputs.append(x) + _ = model(x) + + # For now we can't make any meaningful statements without a backward + # pass. Here we simply ensure that passes don't generate false positive + # category classifications. + memory_profile = prof._memory_profile() + for x in inputs: + categories = self._lookup_tensor_categories(x, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue(all(i is None for i in categories.values()), categories) + + snapshot = memory_profile._category_snapshot() + self.assertFalse(_memory_profiler.Category.INPUT in snapshot.values()) + + def test_inputs_fwd_bwd(self): + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + inputs_targets = [(torch.ones((2, 2)), torch.rand((2, 1))) for _ in range(2)] + + def fwd_bwd_step(x, targets): + y = model(x) + torch.nn.functional.mse_loss(y, targets).backward() + optimizer.step() + optimizer.zero_grad() + + with profile() as prof: + # Inputs which were allocated before profiling began + for x, targets in inputs_targets: + fwd_bwd_step(x, targets) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2)) + targets = torch.rand((2, 1)) + inputs_targets.append((x, targets)) + fwd_bwd_step(x, targets) + + memory_profile = prof._memory_profile() + + def check(t): + categories = self._lookup_tensor_categories(t, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue( + all(i == _memory_profiler.Category.INPUT for i in categories.values()) + ) + + for x, targets in inputs_targets: + check(x) + check(targets) + + def test_lazily_initialized(self) -> None: + model = torch.nn.Sequential( + torch.nn.Linear(2, 2), + torch.nn.ReLU(), + LazyLinear(2, 2), + torch.nn.ReLU(), + torch.nn.Linear(2, 1), + ) + + self.assertEqual(len(list(model.parameters())), 4) + + def inner_fn(): + y = model(torch.ones((2, 2))) + torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer.step() + optimizer.zero_grad() + + self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model) + self.assertEqual(len(list(model.parameters())), 6) + + def test_manual_optimizer_step(self) -> None: + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)) + + def inner_fn(): + y = model(torch.ones((2, 2))) + torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward() + + with torch.no_grad(): + for p in model.parameters(): + grad = p.grad + self.assertIsNotNone(grad) + p.add_(grad, alpha=-0.1) + + self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model) + + def test_categories_e2e_simple_fwd(self) -> None: + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + + def step_fn(_): + x = torch.ones((2, 2)) + y = torch.cat([x * w0, x * w1], dim=1) + + # NOTE: We expect that all unknown categories. This is simply a sanity + # check to ensure that we do not over-label. + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + aten::ones -> 1 (???) + aten::mul.Tensor 1 (???), 2 (???) -> 3 (???) + aten::mul.Tensor 1 (???), 4 (???) -> 5 (???) + aten::cat 3 (???), 5 (???) -> ???""", + ) + + def test_categories_e2e_simple_fwd_bwd(self) -> None: + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + + def step_fn(mark_region): + x = torch.ones((2, 2)) + targets = torch.ones((2, 4)) + + mark_region("Forward & loss") + y = torch.cat([x * w0, x * w1], dim=1) + loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets) + + mark_region("Backward") + loss.backward() + + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + aten::ones -> 1 (INPUT) + aten::ones -> 2 (INPUT) + + -- Forward & loss --------------------------------------------------------------------------------------- + aten::mul.Tensor 1 (INPUT), 3 (INPUT) -> 4 (INPUT) + aten::mul.Tensor 1 (INPUT), 5 (INPUT) -> 6 (INPUT) + aten::cat 4 (INPUT), 6 (INPUT) -> 7 (INPUT) + aten::binary_cross_entropy_with_logits 7 (INPUT), 2 (INPUT) -> 13 (INPUT) + + -- Backward --------------------------------------------------------------------------------------------- + aten::ones_like 13 (INPUT) -> 16 (INPUT) + aten::sigmoid 7 (INPUT) -> 17 (TEMPORARY) + aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY) + aten::mul.Tensor 18 (TEMPORARY), 16 (INPUT) -> 19 (AUTOGRAD_DETAIL) + aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) + aten::view 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> ??? + aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT) + aten::view 25 (GRADIENT) -> 25 (GRADIENT) + aten::detach 25 (GRADIENT) -> 25 (GRADIENT) + aten::detach 25 (GRADIENT) -> ???""", + ) + + def test_categories_e2e_simple_fwd_bwd_step(self) -> None: + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + optimizer = torch.optim.SGD([w0, w1], lr=0.1) + + def step_fn(mark_region): + x = torch.ones((2, 2)) + targets = torch.ones((2, 4)) + + mark_region("Forward & loss") + y = torch.cat([x * w0, x * w1], dim=1) + loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets) + + mark_region("Backward") + loss.backward() + + mark_region("Optimizer") + optimizer.step() + optimizer.zero_grad() + + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + aten::ones -> 1 (INPUT) + aten::ones -> 2 (INPUT) + + -- Forward & loss --------------------------------------------------------------------------------------- + aten::mul.Tensor 1 (INPUT), 3 (PARAMETER) -> 4 (ACTIVATION) + aten::mul.Tensor 1 (INPUT), 5 (PARAMETER) -> 6 (ACTIVATION) + aten::cat 4 (ACTIVATION), 6 (ACTIVATION) -> 7 (ACTIVATION) + aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 13 (ACTIVATION) + + -- Backward --------------------------------------------------------------------------------------------- + aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION) + aten::sigmoid 7 (ACTIVATION) -> 17 (TEMPORARY) + aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY) + aten::mul.Tensor 18 (TEMPORARY), 16 (ACTIVATION) -> 19 (AUTOGRAD_DETAIL) + aten::div_.Scalar 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::slice.Tensor 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 22 (AUTOGRAD_DETAIL) -> 23 (GRADIENT) + aten::view 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> 23 (GRADIENT) + aten::detach 23 (GRADIENT) -> 23 (GRADIENT) + aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 1 (INPUT) -> 24 (AUTOGRAD_DETAIL) + aten::sum.dim_IntList 24 (AUTOGRAD_DETAIL) -> 25 (GRADIENT) + aten::view 25 (GRADIENT) -> 25 (GRADIENT) + aten::detach 25 (GRADIENT) -> 25 (GRADIENT) + aten::detach 25 (GRADIENT) -> 25 (GRADIENT) + + -- Optimizer -------------------------------------------------------------------------------------------- + aten::add_.Tensor 3 (PARAMETER), 25 (GRADIENT) -> 3 (PARAMETER) + aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER) + aten::zero_ 25 (GRADIENT) -> 25 (GRADIENT) + aten::zero_ 23 (GRADIENT) -> 23 (GRADIENT)""", + ) + + def test_categories_e2e_simple_module_fwd(self) -> None: + model = torch.nn.Linear(2, 4, bias=True) + self.assertExpectedInline( + self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))), + """\ + aten::ones -> 1 (INPUT) + aten::t 2 (PARAMETER) -> 2 (PARAMETER) + aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)""", + ) + + def test_categories_e2e_simple_module_fwd_bwd(self) -> None: + model = torch.nn.Linear(2, 1, bias=True) + + def step_fn(mark_region): + mark_region("Forward & loss") + loss = model(torch.ones((2, 2))).sum() + + mark_region("Backward") + loss.backward() + + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + + -- Forward & loss --------------------------------------------------------------------------------------- + aten::ones -> 1 (INPUT) + aten::t 2 (PARAMETER) -> 2 (PARAMETER) + aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION) + aten::sum 4 (ACTIVATION) -> 5 (ACTIVATION) + + -- Backward --------------------------------------------------------------------------------------------- + aten::ones_like 5 (ACTIVATION) -> 6 (ACTIVATION) + aten::expand 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::t 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::mm 6 (ACTIVATION), 1 (INPUT) -> 7 (GRADIENT) + aten::t 7 (GRADIENT) -> 7 (GRADIENT) + aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT) + aten::view 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> ??? + aten::t 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> ???""", + ) + + def test_categories_e2e_simple_module_fwd_bwd_step(self) -> None: + model = torch.nn.Linear(2, 1, bias=True) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + + def step_fn(mark_region): + mark_region("Forward & loss") + loss = model(torch.ones((2, 2))).sum() + + mark_region("Backward") + loss.backward() + + mark_region("Optimizer") + optimizer.step() + optimizer.zero_grad() + + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + + -- Forward & loss --------------------------------------------------------------------------------------- + aten::ones -> 1 (INPUT) + aten::t 2 (PARAMETER) -> 2 (PARAMETER) + aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION) + aten::sum 4 (ACTIVATION) -> 5 (ACTIVATION) + + -- Backward --------------------------------------------------------------------------------------------- + aten::ones_like 5 (ACTIVATION) -> 6 (ACTIVATION) + aten::expand 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::t 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::mm 6 (ACTIVATION), 1 (INPUT) -> 7 (GRADIENT) + aten::t 7 (GRADIENT) -> 7 (GRADIENT) + aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT) + aten::view 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) + aten::detach 9 (GRADIENT) -> 9 (GRADIENT) + aten::t 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) + aten::detach 7 (GRADIENT) -> 7 (GRADIENT) + + -- Optimizer -------------------------------------------------------------------------------------------- + aten::clone 7 (GRADIENT) -> 10 (OPTIMIZER_STATE) + aten::detach 10 (OPTIMIZER_STATE) -> 10 (OPTIMIZER_STATE) + aten::detach 10 (OPTIMIZER_STATE) -> 10 (OPTIMIZER_STATE) + aten::add_.Tensor 2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER) + aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE) + aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE) + aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE) + aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER) + aten::zero_ 7 (GRADIENT) -> 7 (GRADIENT) + aten::zero_ 9 (GRADIENT) -> 9 (GRADIENT)""", + ) + + def test_categories_e2e_sequential_fwd(self) -> None: + model = torch.nn.Sequential( + torch.nn.Linear(2, 4, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(4, 4, bias=False), + torch.nn.Softmax(dim=1), + ) + self.assertExpectedInline( + self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))), + """\ + aten::ones -> 1 (INPUT) + aten::t 2 (PARAMETER) -> 2 (PARAMETER) + aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION) + aten::relu 4 (ACTIVATION) -> 5 (ACTIVATION) + aten::detach 5 (ACTIVATION) -> ??? + aten::t 6 (PARAMETER) -> 6 (PARAMETER) + aten::mm 5 (ACTIVATION), 6 (PARAMETER) -> 7 (ACTIVATION) + aten::_softmax 7 (ACTIVATION) -> 8 (ACTIVATION) + aten::detach 8 (ACTIVATION) -> ???""", + ) + + def test_categories_e2e_sequential_fwd_bwd(self) -> None: + model = torch.nn.Sequential( + torch.nn.Linear(2, 4, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(4, 4, bias=False), + torch.nn.Softmax(dim=1), + ) + + def step_fn(mark_region): + x = torch.ones((2, 2)) + targets = torch.ones((2, 4)) + + mark_region("Forward") + y = model(x) + + mark_region("Loss") + loss = torch.sum((y - targets) ** 2).mean() + + mark_region("Backward") + loss.backward() + + self.assertExpectedInline( + self._run_and_format_categories(step_fn), + """\ + aten::ones -> 1 (INPUT) + aten::ones -> 2 (INPUT) + + -- Forward ---------------------------------------------------------------------------------------------- + aten::t 3 (PARAMETER) -> 3 (PARAMETER) + aten::addmm 4 (PARAMETER), 1 (INPUT), 3 (PARAMETER) -> 5 (ACTIVATION) + aten::relu 5 (ACTIVATION) -> 6 (ACTIVATION) + aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::t 7 (PARAMETER) -> 7 (PARAMETER) + aten::mm 6 (ACTIVATION), 7 (PARAMETER) -> 8 (ACTIVATION) + aten::_softmax 8 (ACTIVATION) -> 9 (ACTIVATION) + aten::detach 9 (ACTIVATION) -> 9 (ACTIVATION) + + -- Loss ------------------------------------------------------------------------------------------------- + aten::sub.Tensor 9 (ACTIVATION), 2 (INPUT) -> 10 (ACTIVATION) + aten::pow.Tensor_Scalar 10 (ACTIVATION) -> 11 (ACTIVATION) + aten::sum 11 (ACTIVATION) -> 12 (ACTIVATION) + aten::mean 12 (ACTIVATION) -> 13 (ACTIVATION) + + -- Backward --------------------------------------------------------------------------------------------- + aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION) + aten::expand 16 (ACTIVATION) -> 16 (ACTIVATION) + aten::div.Scalar 16 (ACTIVATION) -> 19 (AUTOGRAD_DETAIL) + aten::expand 19 (AUTOGRAD_DETAIL) -> 19 (AUTOGRAD_DETAIL) + aten::pow.Tensor_Scalar 10 (ACTIVATION) -> 20 (TEMPORARY) + aten::mul.Scalar 20 (TEMPORARY) -> 23 (TEMPORARY) + aten::mul.Tensor 19 (AUTOGRAD_DETAIL), 23 (TEMPORARY) -> 24 (AUTOGRAD_DETAIL) + aten::detach 9 (ACTIVATION) -> 9 (ACTIVATION) + aten::_softmax_backward_data 24 (AUTOGRAD_DETAIL), 9 (ACTIVATION) -> 25 (AUTOGRAD_DETAIL) + aten::t 25 (AUTOGRAD_DETAIL) -> 25 (AUTOGRAD_DETAIL) + aten::mm 25 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 26 (GRADIENT) + aten::t 26 (GRADIENT) -> 26 (GRADIENT) + aten::t 7 (PARAMETER) -> 7 (PARAMETER) + aten::mm 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL) + aten::t 26 (GRADIENT) -> 26 (GRADIENT) + aten::detach 26 (GRADIENT) -> 26 (GRADIENT) + aten::detach 26 (GRADIENT) -> ??? + aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION) + aten::threshold_backward 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL) + aten::t 28 (AUTOGRAD_DETAIL) -> 28 (AUTOGRAD_DETAIL) + aten::mm 28 (AUTOGRAD_DETAIL), 1 (INPUT) -> 29 (GRADIENT) + aten::t 29 (GRADIENT) -> 29 (GRADIENT) + aten::sum.dim_IntList 28 (AUTOGRAD_DETAIL) -> 30 (GRADIENT) + aten::view 30 (GRADIENT) -> 30 (GRADIENT) + aten::detach 30 (GRADIENT) -> 30 (GRADIENT) + aten::detach 30 (GRADIENT) -> ??? + aten::t 29 (GRADIENT) -> 29 (GRADIENT) + aten::detach 29 (GRADIENT) -> 29 (GRADIENT) + aten::detach 29 (GRADIENT) -> ???""", + ) + + def test_memory_timeline(self) -> None: + model = torch.nn.Sequential( + torch.nn.Linear(64, 512, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(512, 512, bias=False), + torch.nn.Softmax(dim=1), + ) + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + with profile() as prof: + x = torch.ones((1024, 64)) + targets = torch.ones((1024, 512)) + y = model(x) + loss = torch.nn.functional.mse_loss(y, targets) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + memory_profile = prof._memory_profile() + timeline = memory_profile.timeline + times = tuple(t for t, _, _, _ in timeline) + self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times) + self.assertTrue( + all( + (t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) + for t, action, _, _ in timeline + ) + ) + + def category_name(category): + return category.name if category else "???" + + def format_action(action, key, version): + category = memory_profile._categories.get(key, version) + if action == _memory_profiler.Action.INCREMENT_VERSION: + new_category = memory_profile._categories.get(key, version + 1) + if category != new_category: + return f"{category_name(category)} -> {category_name(new_category)}" + return category_name(category) + + def format_size(size: int): + if size < 1024: + return f"{size / 1024:3.1f} kB" + return f"{size // 1024} kB" + + + # We generate sequential IDs for Tensors; however platforms vary + # slightly in the exact computation executed. If this results in + # tensor creation the IDs will be shifted and the unit test will fail. + # (Even though the behavior we're testing is unchanged.) To correct for + # this we assign sequential numbers to the tensors which are actually + # tested, effectively suppressing the extraneous implementation details. + id_map = {} + + def id_for_testing(key): + return id_map.setdefault(key.storage.allocation_id, len(id_map)) + + lines = [ + f"{action.name.lower():<25} {format_action(action, key, version):<25} " + f"{id_for_testing(key):>3}(v{version}) {format_size(size):>15}" + for _, action, (key, version), size in prof._memory_profile().timeline + + # We generally don't care about tiny allocations during memory + # profiling and they add a lot of noise to the unit test. + if size >= 256 + ] + + self.assertExpectedInline( + textwrap.indent("\n".join(lines), " " * 12), + """\ + preexisting PARAMETER 0(v0) 128 kB + preexisting PARAMETER 1(v0) 2 kB + preexisting PARAMETER 2(v0) 1024 kB + create INPUT 3(v0) 256 kB + create INPUT 4(v0) 2048 kB + create ACTIVATION 5(v0) 2048 kB + create ACTIVATION 6(v0) 2048 kB + destroy ACTIVATION 5(v0) 2048 kB + create ACTIVATION 7(v0) 2048 kB + create ACTIVATION 8(v0) 2048 kB + destroy ACTIVATION 7(v0) 2048 kB + create ACTIVATION 9(v0) 2048 kB + create TEMPORARY 10(v0) 2048 kB + destroy TEMPORARY 10(v0) 2048 kB + create AUTOGRAD_DETAIL 11(v0) 2048 kB + create AUTOGRAD_DETAIL 12(v0) 2048 kB + destroy AUTOGRAD_DETAIL 11(v0) 2048 kB + create GRADIENT 13(v0) 1024 kB + create AUTOGRAD_DETAIL 14(v0) 2048 kB + destroy AUTOGRAD_DETAIL 12(v0) 2048 kB + create AUTOGRAD_DETAIL 15(v0) 2048 kB + destroy AUTOGRAD_DETAIL 14(v0) 2048 kB + destroy ACTIVATION 6(v0) 2048 kB + create GRADIENT 16(v0) 128 kB + create GRADIENT 17(v0) 2 kB + destroy AUTOGRAD_DETAIL 15(v0) 2048 kB + create OPTIMIZER_STATE 18(v0) 128 kB + create OPTIMIZER_STATE 19(v0) 128 kB + create OPTIMIZER_STATE 20(v0) 2 kB + create OPTIMIZER_STATE 21(v0) 2 kB + create OPTIMIZER_STATE 22(v0) 1024 kB + create OPTIMIZER_STATE 23(v0) 1024 kB + increment_version OPTIMIZER_STATE 18(v0) 128 kB + increment_version OPTIMIZER_STATE 18(v1) 128 kB + increment_version OPTIMIZER_STATE 19(v0) 128 kB + increment_version OPTIMIZER_STATE 19(v1) 128 kB + create ??? 24(v0) 128 kB + create ??? 25(v0) 128 kB + destroy ??? 24(v0) 128 kB + increment_version ??? 25(v0) 128 kB + increment_version PARAMETER 0(v0) 128 kB + increment_version OPTIMIZER_STATE 20(v0) 2 kB + increment_version OPTIMIZER_STATE 20(v1) 2 kB + increment_version OPTIMIZER_STATE 21(v0) 2 kB + increment_version OPTIMIZER_STATE 21(v1) 2 kB + create ??? 26(v0) 2 kB + create ??? 27(v0) 2 kB + destroy ??? 26(v0) 2 kB + increment_version ??? 27(v0) 2 kB + destroy ??? 25(v1) 128 kB + increment_version PARAMETER 1(v0) 2 kB + increment_version OPTIMIZER_STATE 22(v0) 1024 kB + increment_version OPTIMIZER_STATE 22(v1) 1024 kB + increment_version OPTIMIZER_STATE 23(v0) 1024 kB + increment_version OPTIMIZER_STATE 23(v1) 1024 kB + create ??? 28(v0) 1024 kB + create ??? 29(v0) 1024 kB + destroy ??? 28(v0) 1024 kB + increment_version ??? 29(v0) 1024 kB + destroy ??? 27(v1) 2 kB + increment_version PARAMETER 2(v0) 1024 kB + destroy ??? 29(v1) 1024 kB + increment_version GRADIENT 16(v0) 128 kB + increment_version GRADIENT 17(v0) 2 kB + increment_version GRADIENT 13(v0) 1024 kB""") + + +if __name__ == "__main__": + run_tests() diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 5f3d7621dcfb3..acaa1f9667579 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -6,6 +6,7 @@ import os import re import tempfile +import textwrap import unittest from dataclasses import dataclass, field from typing import List, Optional @@ -33,6 +34,7 @@ record_function, supported_activities, ) +from torch._C._profiler import _TensorMetadata from torch.profiler._pattern_matcher import ( Conv2dBiasFollowedByBatchNorm2dPattern, ExtraCUDACopyPattern, @@ -1278,12 +1280,14 @@ def test_nested_tensor_with_shapes(self): def find_node_with_name(nodes, name): - for node in nodes: + for node in _utils.traverse_dfs(nodes): if node.name == name: return node - result = find_node_with_name(node.children, name) - if result is not None: - return result + +def find_node_with_regex(nodes, pattern): + for node in _utils.traverse_dfs(nodes): + if re.search(pattern, node.name): + return node class SimpleNet(nn.Module): @@ -1303,7 +1307,8 @@ def _get_tensor_fields(self, node, index): self.assertIsInstance( node.extra_fields, torch._C._profiler._ExtraFields_TorchOp) - tensor_info = node.extra_fields.inputs.tensor_metadata[index] + tensor_info = node.extra_fields.inputs[index] + self.assertIsInstance(tensor_info, _TensorMetadata) self.assertIsNotNone(tensor_info.impl_ptr) self.assertIsNotNone(tensor_info.storage_data_ptr) self.assertIsNotNone(tensor_info.id) @@ -1368,6 +1373,304 @@ def get_fields(op_name, index): self.assertEqual(c_id, c_id_new) self.assertEqual(d_id, c_id_new) + @staticmethod + def _format_allocations(profiled_code): + gc.collect() + with profile(profile_memory=True, record_shapes=True) as prof: + profiled_code() + gc.collect() + + root_events = prof.profiler.kineto_results.experimental_event_tree() + events = sorted(_utils.traverse_dfs(root_events), key=lambda x: x.start_time_ns) + allocations = tuple( + event.extra_fields + for event in events + if isinstance(event.extra_fields, torch._C._profiler._ExtraFields_Allocation) + ) + + return textwrap.indent("\n".join( + f"{repr(i.id):>5}{' ' * 6}" + f"{repr(i.allocation_id):>5}{' ' * 6}" + f"{'Allocation' if i.alloc_size > 0 else 'Free'}" + for i in allocations + ), " " * 12) + + def test_tensorimpl_invalidation_set(self) -> None: + def profiled_code(add_empty_set: bool): + x = torch.ones((1,)) + + # Determines if new storage is created before or after the old one + # is destroyed. + if add_empty_set: + x.set_() + + x.set_(torch.ones((1,)).storage()) + x.view_as(x) + + self.assertExpectedInline( + self._format_allocations(lambda: profiled_code(add_empty_set=False)), + """\ + 0 1 Allocation + 0 2 Allocation + 0 1 Free + 0 2 Free""" + ) + + self.assertExpectedInline( + self._format_allocations(lambda: profiled_code(add_empty_set=True)), + """\ + 0 1 Allocation + 0 1 Free + 0 2 Allocation + 0 2 Free""" + ) + + def test_tensorimpl_invalidation_keep_alive(self) -> None: + def profiled_code(add_empty_set: bool): + x = torch.ones((1,)) + x_storages = [x.storage()] + for _ in range(3): + x.set_() + x.set_(torch.ones((1,)).storage()) + + # This keeps the StorageImpls alive and preserves the chain. + # (Despite the `set_()` call.) + x_storages.append(x.storage()) + x.view_as(x) + + # Free storage in a deterministic fashion. + while x_storages: + x_storages.pop() + gc.collect() + + # Determines if new storage is created before or after the old one + # is destroyed. + if add_empty_set: + x.set_() + + for _ in range(3): + x.set_(torch.ones((1,)).storage()) + x.view_as(x) + + del x + gc.collect() + + self.assertExpectedInline( + self._format_allocations(lambda: profiled_code(add_empty_set=False)), + """\ + 0 1 Allocation + 0 2 Allocation + 0 4 Allocation + 0 5 Allocation + 0 4 Free + 0 2 Free + 0 1 Free + 0 6 Allocation + 0 5 Free + 0 7 Allocation + 0 6 Free + 0 8 Allocation + 0 7 Free + 0 8 Free""" + ) + + self.assertExpectedInline( + self._format_allocations(lambda: profiled_code(add_empty_set=True)), + """\ + 0 1 Allocation + 0 2 Allocation + 0 4 Allocation + 0 5 Allocation + 0 4 Free + 0 2 Free + 0 1 Free + 0 5 Free + 0 6 Allocation + 0 7 Allocation + 0 6 Free + 0 8 Allocation + 0 7 Free + 0 8 Free""" + ) + + def test_tensorimpl_invalidation_full(self) -> None: + def profiled_code(): + x = torch.ones((1,)) + x_storages = [x.storage()] + for _ in range(3): + x.set_() + x.set_(torch.ones((1,)).storage()) + x_storages.append(x.storage()) + x.view_as(x) + + # Free storage in a deterministic fashion. + while x_storages: + x_storages.pop() + gc.collect() + + for _ in range(3): + x.set_(torch.ones((1,)).storage()) + + for _ in range(3): + x.set_() + x.set_(torch.ones((1,)).storage()) + + for i in range(4): + x.resize_((1 + i,)) + x.view_as(x) + + self.assertExpectedInline( + self._format_allocations(profiled_code), + """\ + 0 1 Allocation + 0 2 Allocation + 0 4 Allocation + 0 5 Allocation + 0 4 Free + 0 2 Free + 0 1 Free + 0 6 Allocation + 0 5 Free + 0 7 Allocation + 0 6 Free + 0 8 Allocation + 0 7 Free + 0 8 Free + 0 9 Allocation + 0 9 Free + 0 10 Allocation + 0 10 Free + 0 11 Allocation + 0 12 Allocation + 0 11 Free + 0 13 Allocation + 0 12 Free + 0 14 Allocation + 0 13 Free + 0 14 Free""" + ) + + def test_tensorimpl_invalidation_scalar_args(self) -> None: + def profiled_code(): + with torch.no_grad(): + x = torch.ones((1,)) + for _ in range(10): + x.add_(2) + + self.assertExpectedInline( + self._format_allocations(profiled_code), + """\ + 0 1 Allocation + 1 2 Allocation + 2 3 Allocation + 2 3 Free + 1 2 Free + 3 4 Allocation + 4 5 Allocation + 4 5 Free + 3 4 Free + 5 6 Allocation + 6 7 Allocation + 6 7 Free + 5 6 Free + 7 8 Allocation + 8 9 Allocation + 8 9 Free + 7 8 Free + 9 10 Allocation + 10 11 Allocation + 10 11 Free + 9 10 Free + 11 12 Allocation + 12 13 Allocation + 12 13 Free + 11 12 Free + 13 14 Allocation + 14 15 Allocation + 14 15 Free + 13 14 Free + 15 16 Allocation + 16 17 Allocation + 16 17 Free + 15 16 Free + 17 18 Allocation + 18 19 Allocation + 18 19 Free + 17 18 Free + 19 20 Allocation + 20 21 Allocation + 20 21 Free + 19 20 Free + 0 1 Free""") + + + def test_module_and_optimizer_ids(self) -> None: + model = torch.nn.Linear(2, 1, bias=True) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + + def check(cold_start: bool) -> None: + with profile(with_stack=True, profile_memory=True, record_shapes=True) as p: + x = torch.ones((1, 2)) + _ = x.sin() # Mark `x` + model(x).backward() + optimizer.step() + _ = optimizer.state[model.weight]["momentum_buffer"].cos() # Mark weight momentum + _ = model.weight.grad.tan() # Mark weight gradient + + nodes = p.profiler.kineto_results.experimental_event_tree() + + def get_fields(op_name, index): + return self._get_tensor_fields( + find_node_with_name(nodes, op_name), + index) + + # Marked Tensors act as ground truth for python tracer IDs. + _, _, x_id = get_fields("aten::sin", 0) + _, _, weight_momenumtum_id = get_fields("aten::cos", 0) + _, _, weight_grad_id = get_fields("aten::tan", 0) + self.assertNotEqual(x_id, weight_momenumtum_id) + self.assertNotEqual(x_id, weight_grad_id) + self.assertNotEqual(weight_momenumtum_id, weight_grad_id) + + # Use linear op to identify weight ground truth. + linear_op_node = find_node_with_name(nodes, "aten::linear") + self.assertIsNotNone(linear_op_node) + x_metadata, weight_metadata, _ = linear_op_node.extra_fields.inputs + self.assertEqual(x_id, x_metadata.id) + + # Module + linear_module_node = find_node_with_name(nodes, "nn.Module: Linear_0") + self.assertIsNotNone(linear_module_node) + self.assertIsNotNone(linear_module_node.extra_fields.module) + self.assertIsNone(linear_module_node.extra_fields.optimizer) + + linear_parameters = linear_module_node.extra_fields.module.parameters + name, weight, weight_grad = linear_parameters[0] + self.assertEqual(name, "weight") + self.assertEqual(weight.id, weight_metadata.id) + + self.assertEqual(weight_grad is None, cold_start) + if not cold_start: + self.assertEqual(weight_grad.id, weight_grad_id) + + # Optimizer + step_node = find_node_with_regex(nodes, "_optimizer_step_code") + self.assertIsNotNone(step_node) + self.assertIsNone(step_node.extra_fields.module) + self.assertIsNotNone(step_node.extra_fields.optimizer) + optimizer_parameters = step_node.extra_fields.optimizer.parameters + self.assertEqual(len(optimizer_parameters), 2) # Weight and bias + weight, weight_grad, state = optimizer_parameters[0] + self.assertEqual(weight.id, weight_metadata.id) + self.assertEqual(weight_grad.id, weight_grad_id) + self.assertEqual(len(state), 1) + self.assertEqual(state[0][0], "momentum_buffer") + self.assertEqual(state[0][1].id, weight_momenumtum_id) + + # Check that we handle first step (lazy initalization) and steady state. + check(cold_start=True) + check(cold_start=False) + def _test_allocation_ids(self, before_fn, after_fn) -> None: with profile(profile_memory=True, record_shapes=True) as p: # Introduce other operations and allocations to check robustness @@ -1436,6 +1739,43 @@ def test_allocation_ids_with_other_ops(self) -> None: lambda: torch.zeros((1,)).cos() ) + def test_impl_reuse(self) -> None: + repeats = 1_000 + with profile(profile_memory=True, record_shapes=True) as p: + for _ in range(repeats): + torch.ones((1,)) + gc.collect() + + roots = p.profiler.kineto_results.experimental_event_tree() + tensor_impls = tuple( + e.extra_fields.inputs[0].impl_ptr + for e in _utils.traverse_dfs(roots) + if e.name == "aten::fill_" + ) + + self.assertEqual(len(tensor_impls), repeats) + self.assertEqual(len(set(tensor_impls)), repeats) + + def test_allocation_id_uniqueness(self) -> None: + repeats = 1_000 + with profile(profile_memory=True, record_shapes=True) as p: + for _ in range(repeats): + torch.ones((1,)) + gc.collect() + + roots = p.profiler.kineto_results.experimental_event_tree() + id_set = set() + for e in _utils.traverse_dfs(roots): + fields = e.extra_fields + if isinstance(fields, torch._C._profiler._ExtraFields_TorchOp): + id_set |= {t.allocation_id for t in fields.inputs if isinstance(t, _TensorMetadata)} + + elif isinstance(fields, torch._C._profiler._ExtraFields_Allocation): + id_set.add(fields.allocation_id) + + id_set.difference_update([None]) + self.assertEqual(repeats, len(id_set)) + def test_extra_fields(self): with profile(with_stack=True, profile_memory=True) as p: _ = torch.ones((1,)) @@ -1474,18 +1814,14 @@ def test_tensor_properties(self): node.extra_fields, torch._C._profiler._ExtraFields_TorchOp) - self.assertEqual(node.extra_fields.inputs.shapes, [[4, 4], [4, 1], []]) - self.assertEqual(node.extra_fields.inputs.strides, [[12, 3], [1, 1], []]) - - input_info = node.extra_fields.inputs - self.assertEqual(input_info.dtypes, ['float', 'float', 'Scalar']) + def getattr_inputs(name, default): + return [getattr(i, name, default) for i in node.extra_fields.inputs] - layout_info = [x.layout if x else None for x in input_info.tensor_metadata] - self.assertEqual(layout_info, [torch.strided, torch.strided, None]) - device_info = [x.device if x else None for x in input_info.tensor_metadata] - self.assertEqual(device_info, [torch.device("cpu"), torch.device("cpu"), None]) - tensor_dtypes = [x.dtype if x else None for x in input_info.tensor_metadata] - self.assertEqual(tensor_dtypes, [torch.float32, torch.float32, None]) + self.assertEqual(getattr_inputs("sizes", []), [[4, 4], [4, 1], []]) + self.assertEqual(getattr_inputs("strides", []), [[12, 3], [1, 1], []]) + self.assertEqual(getattr_inputs("layout", None), [torch.strided, torch.strided, None]) + self.assertEqual(getattr_inputs("device", None), [torch.device("cpu"), torch.device("cpu"), None]) + self.assertEqual(getattr_inputs("dtype", None), [torch.float32, torch.float32, None]) self.assertEqual(node.extra_fields.scope, torch.profiler.RecordScope.FUNCTION) mul_node = find_node_with_name(nodes, "aten::mul") @@ -1510,19 +1846,13 @@ def test_sparse_tensors(self): node.extra_fields, torch._C._profiler._ExtraFields_TorchOp) - self.assertEqual(node.extra_fields.inputs.shapes, [[2, 3], [2, 3], []]) - self.assertEqual(node.extra_fields.inputs.strides, [[], [], []]) + def getattr_inputs(name, default): + return [getattr(i, name, default) for i in node.extra_fields.inputs] - input_info = node.extra_fields.inputs - - # FIXME: Different systems have different names for int64_t - # below are example names I have found. This is not guaranteed to be exhaustive. - # self.assertIn(input_info.dtypes[0], ["long long", "long int", "long", "__int64"]) - - layout_info = [x.layout if x else None for x in input_info.tensor_metadata] - self.assertEqual(layout_info, [torch.sparse_coo, torch.sparse_coo, None]) - device_info = [x.device if x else None for x in input_info.tensor_metadata] - self.assertEqual(device_info, [torch.device("cpu"), torch.device("cpu"), None]) + self.assertEqual(getattr_inputs("sizes", []), [[2, 3], [2, 3], []]) + self.assertEqual(getattr_inputs("strides", []), [[], [], []]) + self.assertEqual(getattr_inputs("layout", None), [torch.sparse_coo, torch.sparse_coo, None]) + self.assertEqual(getattr_inputs("device", None), [torch.device("cpu"), torch.device("cpu"), None]) @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled") def test_mkldnn_tensors(self): @@ -1539,16 +1869,13 @@ def test_mkldnn_tensors(self): node.extra_fields, torch._C._profiler._ExtraFields_TorchOp) - self.assertEqual(node.extra_fields.inputs.shapes, [[4, 3], [4, 3], []]) - self.assertEqual(node.extra_fields.inputs.strides, [[], [], []]) + def getattr_inputs(name, default): + return [getattr(i, name, default) for i in node.extra_fields.inputs] - input_info = node.extra_fields.inputs - self.assertEqual(input_info.dtypes, ['float', 'float', 'Scalar']) - - layout_info = [x.layout if x else None for x in input_info.tensor_metadata] - self.assertEqual(layout_info, [torch._mkldnn, torch._mkldnn, None]) - device_info = [x.device if x else None for x in input_info.tensor_metadata] - self.assertEqual(device_info, [torch.device("cpu"), torch.device("cpu"), None]) + self.assertEqual(getattr_inputs("sizes", []), [[4, 3], [4, 3], []]) + self.assertEqual(getattr_inputs("strides", []), [[], [], []]) + self.assertEqual(getattr_inputs("layout", None), [torch._mkldnn, torch._mkldnn, None]) + self.assertEqual(getattr_inputs("device", None), [torch.device("cpu"), torch.device("cpu"), None]) def test_scalar_ins(self): x = torch.ones(5, 5) @@ -1561,11 +1888,29 @@ def test_scalar_ins(self): node = find_node_with_name(nodes, "aten::add") self.assertIsNotNone(node) + def getattr_inputs(name, default): + return [getattr(i, name, default) for i in node.extra_fields.inputs] + # The second argument to the add gets promotoed to a zerodim Tensor - input_info = node.extra_fields.inputs - self.assertEqual(input_info.dtypes, ['float', 'double', 'Scalar']) - self.assertEqual(input_info.shapes, [[5, 5], [], []]) - self.assertEqual(input_info.ivalues, [None, None, alpha]) + self.assertEqual(getattr_inputs("dtype", None), [torch.float32, torch.float64, None]) + self.assertEqual(getattr_inputs("sizes", []), [[5, 5], [], []]) + self.assertEqual(node.extra_fields.inputs[2], alpha) + + def test_tensor_lists(self): + x = torch.ones((1,)) + y = torch.ones((1,)) + with profile(with_stack=True, profile_memory=True, record_shapes=True) as p: + _ = torch.stack((x, y)) + + nodes = p.profiler.kineto_results.experimental_event_tree() + node = find_node_with_name(nodes, "aten::stack") + inputs = node.extra_fields.inputs + self.assertEqual(len(inputs), 2) + self.assertIsInstance(inputs[0], list) + self.assertEqual(len(inputs[0]), 2) + self.assertEqual(x.storage().data_ptr(), inputs[0][0].storage_data_ptr) + self.assertEqual(y.storage().data_ptr(), inputs[0][1].storage_data_ptr) + def test_nnmodule_params(self): @@ -1574,7 +1919,7 @@ def flat_out_extrafields(nodes, out=None): out = [] for node in nodes: if isinstance(node.extra_fields, _ExtraFields_PyCall) and node.extra_fields.module: - if node.extra_fields.module.params: + if node.extra_fields.module.parameters: out.append(node.extra_fields.module) flat_out_extrafields(node.children, out) return out @@ -1589,7 +1934,7 @@ def flat_out_extrafields(nodes, out=None): modules = flat_out_extrafields(p.profiler.kineto_results.experimental_event_tree()) self.assertEqual(len(modules), 2, f"Expected two parameter list, but got {len(modules)}") - params = [(n, p.storage_data_ptr, g.storage_data_ptr) for module in modules for (n, p, g) in module.params] + params = [(n, p.storage_data_ptr, g.storage_data_ptr) for module in modules for (n, p, g) in module.parameters] expected = [(name, val.storage().data_ptr(), val.grad.storage().data_ptr()) for name, val in net.fc1._parameters.items()] expected += [(name, val.storage().data_ptr(), val.grad.storage().data_ptr()) for name, val in net.fc2._parameters.items()] self.assertEqual(expected, params, f"{expected} vs. {params}") @@ -1599,29 +1944,37 @@ def _flat_out_extrafields(self, nodes, out=None): out = [] for node in nodes: if (isinstance(node.extra_fields, _ExtraFields_PyCall) and - node.extra_fields.optimizer and node.extra_fields.optimizer.param_addrs): + node.extra_fields.optimizer and node.extra_fields.optimizer.parameters): # avoiding OptInfo duplicates from iterations - addr = node.extra_fields.optimizer.param_addrs[0].storage_data_ptr - if not [o for o in out if addr == o.param_addrs[0].storage_data_ptr]: + addr = node.extra_fields.optimizer.parameters[0][0].storage_data_ptr + if not [o for o in out if addr == o.parameters[0][0].storage_data_ptr]: out.append(node.extra_fields.optimizer) self._flat_out_extrafields(node.children, out) return out def _check_results(self, opt, opts, check_items=False): self.assertEqual(len(opts), 1, f"Expected 1 optimizer: len(opts): {len(opts)}") - self.assertEqual(id(opt), opts[0].self, f"Optimizer addr ({id(opt)}) vs. profiled addr ({opts[0].self})") + self.assertEqual(id(opt), opts[0].self_ptr, f"Optimizer addr ({id(opt)}) vs. profiled addr ({opts[0].self_ptr})") if check_items: self.assertEqual(len(opt.param_groups), len(opts)) for group, opt_ in zip(opt.param_groups, opts): self.assertEqual( [(v.storage().data_ptr()) for v in group.get("params", [])], - [(o.storage_data_ptr) for o in opt_.param_addrs] + [(o.storage_data_ptr) for (o, _, _) in opt_.parameters] ) for opt_ in opts: - self.assertEqual( - [(name, val.storage().data_ptr()) for dic in opt.state.values() for name, val in dic.items()], - [(n, p.storage_data_ptr) for (n, p) in opt_.opt_state] - ) + observed_state = { + p.storage_data_ptr: {name: s.storage_data_ptr for name, s in state} + for (p, _, state) in opt_.parameters + } + + # Make sure the profiler collected all optimizer state and check + # that the address recorded by the profiler is correct. + for parameter, parameter_state in opt.state.items(): + self.assertEqual( + {name: value.storage().data_ptr() for name, value in parameter_state.items()}, + observed_state.get(parameter.storage().data_ptr(), []) + ) def test_optimizer(self): inputs = torch.rand(10) diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index f0097985f2940..d4a31c6456131 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -138,11 +138,6 @@ def flatten(nodes, depth=0, out=None): @staticmethod def fmt_name(name: str) -> str: - # torch::autograd::Node relies on c10::demangle to generate names, and - # Windows demangles to include `struct` in the name. - if IS_WINDOWS: - name = name.replace('struct torch::autograd::AccumulateGrad', 'torch::autograd::AccumulateGrad') - match = re.match(r"^(.*)\.py\(([0-9]+)\): (.*)$", name) if match: filename, _, fn = match.groups() @@ -317,34 +312,18 @@ def test_profiler_experimental_tree_with_record_function(self): self.assertTreesMatch( ProfilerTree.format(p.profiler, 12), """\ - aten::zeros - aten::empty - aten::zero_ Top level Annotation - aten::empty - aten::zeros - aten::empty - aten::zero_ First Annotation - aten::empty aten::ones aten::empty aten::fill_ - aten::zeros - aten::empty - aten::zero_ Second Annotation - aten::empty aten::add aten::to aten::_to_copy aten::empty_strided aten::copy_ - aten::zeros - aten::empty - aten::zero_ Third Annotation - aten::empty aten::ones_like aten::empty_like aten::empty_strided diff --git a/test/quantization/ao_migration/common.py b/test/quantization/ao_migration/common.py index bade3b7ff4d26..50045a39e7ab5 100644 --- a/test/quantization/ao_migration/common.py +++ b/test/quantization/ao_migration/common.py @@ -6,7 +6,8 @@ class AOMigrationTestCase(TestCase): def _test_package_import(self, package_name: str, base: Optional[str] = None, - skip: List[str] = None): + skip: List[str] = None, + new_package_name: Optional[str] = None): r"""Tests the module import by making sure that all the internals match (except the dunder methods). @@ -19,8 +20,10 @@ def _test_package_import(self, package_name: str, base = base or 'quantization' old_base = 'torch.' + base new_base = 'torch.ao.' + base + if new_package_name is None: + new_package_name = package_name old_module = importlib.import_module(f'{old_base}.{package_name}') - new_module = importlib.import_module(f'{new_base}.{package_name}') + new_module = importlib.import_module(f'{new_base}.{new_package_name}') old_module_dir = set(dir(old_module)) new_module_dir = set(dir(new_module)) # Remove magic modules from checking in subsets @@ -36,15 +39,17 @@ def _test_package_import(self, package_name: str, f"{old_module_dir - new_module_dir}" def _test_function_import(self, package_name: str, function_list: List[str], - base: Optional[str] = None): + base: Optional[str] = None, new_package_name: Optional[str] = None): r"""Tests individual function list import by comparing the functions and their hashes.""" if base is None: base = 'quantization' old_base = 'torch.' + base new_base = 'torch.ao.' + base + if new_package_name is None: + new_package_name = package_name old_location = importlib.import_module(f'{old_base}.{package_name}') - new_location = importlib.import_module(f'{new_base}.{package_name}') + new_location = importlib.import_module(f'{new_base}.{new_package_name}') for fn_name in function_list: old_function = getattr(old_location, fn_name) new_function = getattr(new_location, fn_name) diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 89b69d1ef1829..9c246e1b7cd89 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -118,7 +118,7 @@ def test_package_import_quant_type(self): def test_function_import_quant_type(self): function_list = [ 'QuantType', - 'quant_type_to_str', + '_get_quant_type_to_str', ] self._test_function_import('quant_type', function_list) @@ -177,9 +177,9 @@ def test_function_import_qconfig(self): "default_qat_qconfig_v2", "get_default_qconfig", "get_default_qat_qconfig", - "assert_valid_qconfig", + "_assert_valid_qconfig", "QConfigAny", - "add_module_to_qconfig_obs_ctr", + "_add_module_to_qconfig_obs_ctr", "qconfig_equals" ] self._test_function_import('qconfig', function_list) @@ -225,7 +225,7 @@ def test_function_import_fuser_method_mappings(self): "get_fuser_method", ] dict_list = [ - "DEFAULT_OP_LIST_TO_FUSER_METHOD" + "_DEFAULT_OP_LIST_TO_FUSER_METHOD" ] self._test_function_import('fuser_method_mappings', function_list) self._test_dict_import('fuser_method_mappings', dict_list) diff --git a/test/quantization/ao_migration/test_quantization_fx.py b/test/quantization/ao_migration/test_quantization_fx.py index 03d1da6f2cfb4..fed2921cea722 100644 --- a/test/quantization/ao_migration/test_quantization_fx.py +++ b/test/quantization/ao_migration/test_quantization_fx.py @@ -11,8 +11,6 @@ def test_function_import_quantize_fx(self): '_check_is_graph_module', '_swap_ff_with_fxff', '_fuse_fx', - 'Scope', - 'ScopeContextManager', 'QuantizationTracer', '_prepare_fx', '_prepare_standalone_module_fx', @@ -26,7 +24,10 @@ def test_function_import_quantize_fx(self): self._test_function_import('quantize_fx', function_list) def test_package_import_fx(self): - self._test_package_import('fx') + self._test_package_import('fx', skip=[ + 'fusion_patterns', + 'quantization_patterns', + ]) def test_function_import_fx(self): function_list = [ @@ -99,7 +100,10 @@ def test_function_import_fx_equalize(self): self._test_function_import('fx._equalize', function_list) def test_package_import_fx_quantization_patterns(self): - self._test_package_import('fx.quantization_patterns') + self._test_package_import( + 'fx.quantization_patterns', + new_package_name='fx.quantize_handler', + ) def test_function_import_fx_quantization_patterns(self): function_list = [ @@ -118,7 +122,11 @@ def test_function_import_fx_quantization_patterns(self): 'GeneralTensorShapeOpQuantizeHandler', 'StandaloneModuleQuantizeHandler' ] - self._test_function_import('fx.quantization_patterns', function_list) + self._test_function_import( + 'fx.quantization_patterns', + function_list, + new_package_name='fx.quantize_handler', + ) def test_package_import_fx_match_utils(self): self._test_package_import('fx.match_utils') @@ -158,14 +166,21 @@ def test_function_import_fx_fuse(self): self._test_function_import('fx.fuse', function_list) def test_package_import_fx_fusion_patterns(self): - self._test_package_import('fx.fusion_patterns') + self._test_package_import( + 'fx.fusion_patterns', + new_package_name='fx.fuse_handler', + ) def test_function_import_fx_fusion_patterns(self): function_list = [ 'FuseHandler', 'DefaultFuseHandler' ] - self._test_function_import('fx.fusion_patterns', function_list) + self._test_function_import( + 'fx.fusion_patterns', + function_list, + new_package_name='fx.fuse_handler', + ) # we removed matching test for torch.quantization.fx.quantization_types # old: torch.quantization.fx.quantization_types @@ -177,22 +192,15 @@ def test_package_import_fx_utils(self): def test_function_import_fx_utils(self): function_list = [ - 'graph_pretty_str', - 'get_per_tensor_qparams', - 'quantize_node', 'get_custom_module_class_keys', 'get_linear_prepack_op_for_dtype', 'get_qconv_prepack_op', - 'get_qconv_op', 'get_new_attr_name_with_prefix', 'graph_module_from_producer_nodes', 'assert_and_get_unique_device', 'create_getattr_from_value', - 'create_qparam_nodes', 'all_node_args_have_no_tensors', - 'node_return_type_is_int', 'get_non_observable_arg_indexes_and_types', - 'is_get_tensor_info_node', 'maybe_get_next_module' ] self._test_function_import('fx.utils', function_list) diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py index e1e7067d4135b..6cf8f3d5e5c61 100644 --- a/test/quantization/core/test_backend_config.py +++ b/test/quantization/core/test_backend_config.py @@ -13,10 +13,8 @@ DTypeWithConstraints, ObservationType, ) -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 -from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter -from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer +from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2 +from torch.ao.quantization.fx.quantize_handler import _default_root_node_getter class TestBackendConfig(QuantizationTestCase): @@ -106,7 +104,7 @@ def test_dtype_config_to_dict(self): # BackendPatternConfig # ====================== - _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU) + _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU) _num_tensor_args_to_observation_type = { 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, @@ -118,7 +116,6 @@ def test_dtype_config_to_dict(self): "input": 1, "weight": 2, } - _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer) def _extra_inputs_getter(self, p): return (torch.rand(3, 3),) @@ -141,9 +138,7 @@ def _get_backend_op_config2(self): ._set_extra_inputs_getter(self._extra_inputs_getter) \ ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \ ._set_input_type_to_index(self._input_type_to_index) \ - ._set_input_output_observed(False) \ - ._set_overwrite_output_fake_quantize(self._fake_quantize) \ - ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) + ._set_input_output_observed(False) def _get_backend_pattern_config_dict1(self): return { @@ -167,8 +162,6 @@ def _get_backend_pattern_config_dict2(self): "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type, "input_type_to_index": self._input_type_to_index, "input_output_observed": False, - "overwrite_output_fake_quantize": self._fake_quantize, - "overwrite_output_observer": default_fixed_qparams_range_0to1_observer } def test_backend_op_config_set_observation_type(self): @@ -246,18 +239,6 @@ def test_backend_op_config_set_input_output_observed(self): conf._set_input_output_observed(False) self.assertEqual(conf._input_output_observed, False) - def test_backend_op_config_set_overwrite_output_fake_quantize(self): - conf = BackendPatternConfig(torch.sigmoid) - self.assertTrue(conf._overwrite_output_fake_quantize is None) - conf._set_overwrite_output_fake_quantize(self._fake_quantize) - self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize) - - def test_backend_op_config_set_overwrite_output_observer(self): - conf = BackendPatternConfig(torch.sigmoid) - self.assertTrue(conf._overwrite_output_observer is None) - conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) - self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) - def test_backend_op_config_from_dict(self): conf_dict1 = self._get_backend_pattern_config_dict1() conf1 = BackendPatternConfig.from_dict(conf_dict1) @@ -273,8 +254,6 @@ def test_backend_op_config_from_dict(self): self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0) self.assertEqual(len(conf1._input_type_to_index), 0) self.assertTrue(conf1._input_output_observed is None) - self.assertTrue(conf1._overwrite_output_fake_quantize is None) - self.assertTrue(conf1._overwrite_output_observer is None) # Test temporary/internal keys conf_dict2 = self._get_backend_pattern_config_dict2() conf2 = BackendPatternConfig.from_dict(conf_dict2) @@ -290,8 +269,6 @@ def test_backend_op_config_from_dict(self): self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) self.assertEqual(conf2._input_type_to_index, self._input_type_to_index) self.assertEqual(conf2._input_output_observed, False) - self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize) - self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) def test_backend_op_config_to_dict(self): conf1 = self._get_backend_op_config1() diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 5964de70b8e39..780f1ebb6cd57 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -1036,6 +1036,33 @@ def test_prelu(self): self.assertEqual(qy_ref, qy, msg="PReLU module API failed") + def test_channel_shuffle(self): + """Tests the correctness of the ChannelShuffle module. + """ + x_scale = 10.0 / 256 + x_zero_point = 1 + y_scale = x_scale + y_zero_point = x_zero_point + + dims = (1, 4, 4, 8) + groups = 2 + + X = (torch.randn(dims, dtype=torch.float) - 0.5) * 10 + qX = torch.quantize_per_tensor(X, x_scale, x_zero_point, dtype=torch.quint8) + dqX = qX.dequantize() + + float_mod = torch.nn.ChannelShuffle(groups).float() + dqY_ref = float_mod(dqX) + qY_ref = torch.quantize_per_tensor( + dqY_ref, y_scale, y_zero_point, dtype=torch.quint8) + + quant_mod = torch.nn.ChannelShuffle(groups) + qY = quant_mod(qX) + + self.assertEqual(qY_ref.int_repr().numpy(), qY.int_repr().numpy(), + msg="ChannelShuffle module API failed, qY_ref\n{} vs qY\n{}" + .format(qY_ref, qY)) + class TestDynamicQuantizedModule(QuantizationTestCase): def _test_qconv_impl(self, q_mod, dq_mod, dim, dtype, bias): in_channels = 3 diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 79297e073f047..c91a2bf547280 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -21,9 +21,9 @@ import torch.testing._internal.hypothesis_utils as hu hu.assert_deadline_disabled() -from torch.testing._internal.common_utils import TestCase, skipIfSlowGradcheckEnv +from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2 -from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK +from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ override_quantized_engine, supported_qengines, override_qengines, _snr from torch.testing._internal.common_quantized import ( @@ -130,7 +130,6 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type): X_scale = 1e-10 return X, X_scale, X_zero_point -@skipIfSlowGradcheckEnv class TestQuantizedOps(TestCase): """Helper function to test quantized activation functions.""" @@ -1811,9 +1810,9 @@ def test_adaptive_avg_pool2d_nhwc(self): for name, op in ops_under_test.items(): X_hat = op(qX, output_size=output_size) self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, - msg=error_message.format(name, X_ref, X_hat.int_repr())) + self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, + msg=error_message.format(name, X_ref, X_hat.int_repr()), + exact_dtype=False) self.assertEqual(scale, X_hat.q_scale(), msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) self.assertEqual(zero_point, X_hat.q_zero_point(), @@ -1887,10 +1886,9 @@ def test_adaptive_avg_pool(self): devices = ["cpu", "cuda"] if (dim == 2 and torch.cuda.is_available()) else ["cpu"] for device in devices: qX_hat = op(qX.to(device=device), output_size=output_size) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType( + self.assertEqual( X_ref, qX_hat.int_repr(), atol=1.0, - rtol=0, msg=error_message.format(name, X_ref, qX_hat)) + rtol=0, msg=error_message.format(name, X_ref, qX_hat), exact_dtype=False) self.assertEqual( scale, qX_hat.q_scale(), msg=error_message.format(name + '.scale', scale, @@ -1962,9 +1960,9 @@ def test_adaptive_avg_pool3d_ndhwc(self): for name, op in ops_under_test.items(): X_hat = op(qX, output_size=output_size) self.assertTrue(X_hat.stride() != sorted(X_hat.stride())) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, - msg=error_message.format(name, X_ref, X_hat.int_repr())) + self.assertEqual(X_ref, X_hat.int_repr(), atol=1.0, rtol=0, + msg=error_message.format(name, X_ref, X_hat.int_repr()), + exact_dtype=False) self.assertEqual(scale, X_hat.q_scale(), msg=error_message.format(name + '.scale', scale, X_hat.q_scale())) self.assertEqual(zero_point, X_hat.q_zero_point(), @@ -2109,10 +2107,10 @@ def test_interpolate(self, X, size, mode, scale_factor, align_corners, nhwc_layo for name, op in ops_under_test.items(): qX_hat = op(qX, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, - msg="{} results are off: qX_hat={} X_ref={}" - .format(name, qX_hat.int_repr(), X_ref)) + self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, + msg="{} results are off: qX_hat={} X_ref={}" + .format(name, qX_hat.int_repr(), X_ref), + exact_dtype=False) self.assertEqual(scale, qX_hat.q_scale(), msg=error_message.format(name + '.scale', scale, qX_hat.q_scale())) self.assertEqual(zero_point, qX_hat.q_zero_point(), @@ -2164,10 +2162,9 @@ def test_interpolate3d(self, X, size, mode, scale_factor, align_corners, nhwc_la for name, op in ops_under_test.items(): qX_hat = op(qX, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, - msg="{} results are off: qX_hat={}, X_ref={}" - .format(name, qX_hat.int_repr(), X_ref)) + self.assertEqual(X_ref, qX_hat.int_repr(), atol=1.0, rtol=0, + msg="{} results are off: qX_hat={}, X_ref={}" + .format(name, qX_hat.int_repr(), X_ref), exact_dtype=False) self.assertEqual(scale, qX_hat.q_scale(), msg=error_message.format(name + '.scale', scale, qX_hat.q_scale())) self.assertEqual(zero_point, qX_hat.q_zero_point(), @@ -3529,17 +3526,8 @@ def test_dynamic_convtranspose3d(self): class TestQuantizedLinear(TestCase): - """Tests the correctness of the quantized linear and linear_relu op.""" - @given(batch_size=st.integers(1, 4), - input_channels=st.integers(16, 32), - output_channels=st.integers(4, 8), - use_bias=st.booleans(), - use_relu=st.booleans(), - use_multi_dim_input=st.booleans(), - use_channelwise=st.booleans()) - @override_qengines - def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, - use_relu, use_multi_dim_input, use_channelwise): + def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias, + post_op, use_multi_dim_input, use_channelwise, **post_op_kwargs): decimal_val = 4 dtypes = [torch.quint8] if torch.backends.quantized.engine == 'qnnpack': @@ -3561,8 +3549,10 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, nptype = np_dtype[dtype] qlinear_prepack = torch.ops.quantized.linear_prepack - if use_relu: + if post_op == 'relu': qlinear = torch.ops.quantized.linear_relu + elif post_op == 'leaky_relu': + qlinear = torch.ops.quantized.linear_leaky_relu else: qlinear = torch.ops.quantized.linear if use_multi_dim_input: @@ -3637,7 +3627,7 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with # Y_scale * 255 (max for uint8). - Y_scale = 125.1234 + Y_scale = 12.34 Y_zp = 5 # Weight prepacking operator for quantized Linear float_bias = b if use_bias else None @@ -3645,13 +3635,13 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, if use_multi_dim_input: X_q = X_q.view(3, int(batch_size / 3), input_channels) # Quantized Linear operator with prepacked weight - Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp) - if not use_channelwise: + Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp, **post_op_kwargs) + if not use_channelwise and post_op in ('none', 'relu'): # Test the per-tensor quantization only # Reference quantized Linear operator Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0, W_scales[0], W_zps[0], b_q0, Y_scale, Y_zp, dtype=nptype) - if use_relu: + if post_op == 'relu': Y_q_ref[Y_q_ref < Y_zp] = Y_zp if use_multi_dim_input: Y_q_ref = np.reshape( @@ -3664,14 +3654,168 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, X_fp32 = X_q.dequantize().to(dtype=torch.float) b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) - if use_relu: + if post_op == 'relu': Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 + elif post_op == 'leaky_relu': + Y_fp32_ref = F.leaky_relu(Y_fp32_ref, **post_op_kwargs) Y_q_ref2 = torch.quantize_per_tensor( Y_fp32_ref, Y_scale, Y_zp, dtype) # Assert equal np.testing.assert_array_almost_equal( Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=decimal_val) + """Tests the correctness of the quantized linear op.""" + @override_qengines + def test_qlinear(self): + batch_size_list = [1, 4] + input_channels_list = [16, 32] + output_channels_list = [4, 8] + use_bias_list = [True, False] + use_multi_dim_input_list = [True, False] + use_channelwise_list = [True, False] + post_op = 'none' + cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, + use_bias_list, use_multi_dim_input_list, use_channelwise_list) + for batch_size, input_channels, output_channels, use_bias,\ + use_multi_dim_input, use_channelwise in cases: + self._test_qlinear_impl(batch_size, input_channels, output_channels, + use_bias, post_op, use_multi_dim_input, use_channelwise) + + """Tests the correctness of the quantized linear_relu op.""" + @override_qengines + def test_qlinear_relu(self): + batch_size_list = [1, 4] + input_channels_list = [16, 32] + output_channels_list = [4, 8] + use_bias_list = [True, False] + use_multi_dim_input_list = [True, False] + use_channelwise_list = [True, False] + post_op = 'relu' + cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, + use_bias_list, use_multi_dim_input_list, use_channelwise_list) + for batch_size, input_channels, output_channels, use_bias,\ + use_multi_dim_input, use_channelwise in cases: + self._test_qlinear_impl(batch_size, input_channels, output_channels, + use_bias, post_op, use_multi_dim_input, use_channelwise) + + @given(batch_size=st.integers(1, 4), + input_channels=st.integers(16, 32), + output_channels=st.integers(4, 8), + use_bias=st.booleans(), + use_relu=st.booleans(), + use_multi_dim_input=st.booleans(), + use_channelwise=st.booleans()) + @skipIfNoFBGEMM + def test_qlinear_with_input_q_dq_qweight_dq_output_fp32( + self, batch_size, input_channels, output_channels, use_bias, + use_relu, use_multi_dim_input, use_channelwise): + decimal_val = 4 + dtypes = [torch.quint8] + for dtype in dtypes: + # No support for channelwise in xnnpack (int8) + # ONEDNN does not support qint8 + if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()): + return + + nptype = np_dtype[dtype] + qlinear_prepack = torch.ops.quantized.linear_prepack + if use_relu: + qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_relu_output_fp32 + else: + qlinear = torch.ops.quantized.linear_with_input_q_dq_qweight_dq_output_fp32 + if use_multi_dim_input: + batch_size *= 3 # Test the multi-dim input tensor + X_scale = 1.5 + X_zp = 5 + X_value_min = -128 if dtype == torch.qint8 else 0 + X_value_max = 127 if dtype == torch.qint8 else 255 + X_q0 = np.round( + np.random.rand(batch_size, input_channels) * + (X_value_max - X_value_min) + + X_value_min + ).astype(nptype) + + W_scales = np.random.rand(output_channels) + # xnnpack forces W_zp to 0 when using symmetric quantization + # ONEDNN only supports symmetric quantization of weight + if dtype == torch.qint8 or qengine_is_onednn(): + W_zps = np.zeros(output_channels).astype(np.int) + else: + W_zps = np.round(np.random.rand(output_channels) * 100 - 50).astype(np.int) + # when using symmetric quantization + # special restriction for xnnpack fully connected op weight + # [-127, 127] instead of [-128, 127] + W_value_min = -127 if dtype == torch.qint8 else -128 + W_value_max = 127 + W_q0 = np.round( + np.random.rand(output_channels, input_channels) + * (W_value_max - W_value_min) + + W_value_min + ).astype(np.int8) # weight is always int8_t + b_value_min = -10 + b_value_max = 10 + b_q0 = np.round( + np.random.rand(output_channels) * + (b_value_max - b_value_min) + b_value_min + ).astype(np.int32) if use_bias else None + if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'): + avoid_vpmaddubsw_overflow_linear( + batch_size, + input_channels, + output_channels, + X_q0, + X_value_min, + X_value_max, + W_q0, + W_value_min, + W_value_max, + ) + X = torch.from_numpy(_dequantize( + X_q0, X_scale, X_zp)).to(dtype=torch.float) + X_q = torch.quantize_per_tensor( + X, scale=X_scale, zero_point=X_zp, dtype=dtype) + if use_channelwise: + W = torch.from_numpy(_dequantize(W_q0, W_scales.reshape( + (-1, 1)), W_zps.reshape((-1, 1)))).to(dtype=torch.float) + W_q = torch.quantize_per_channel(W, scales=torch.from_numpy(W_scales), + zero_points=torch.from_numpy(W_zps), axis=0, dtype=torch.qint8) + b = torch.from_numpy(_dequantize( + b_q0, X_scale * W_scales, 0)).to(dtype=torch.float) if use_bias else None + b_q = torch.quantize_per_channel(b, scales=torch.from_numpy(X_scale * W_scales), + zero_points=torch.zeros(output_channels, dtype=torch.long), + axis=0, dtype=torch.qint32) if use_bias else None + else: + W = torch.from_numpy(_dequantize( + W_q0, W_scales[0], W_zps[0])).to(dtype=torch.float) + W_q = torch.quantize_per_tensor(W, scale=W_scales[0], zero_point=( + W_zps[0].astype(int).item()), dtype=torch.qint8) + b = torch.from_numpy(_dequantize( + b_q0, X_scale * (W_scales[0].item()), 0)).to(dtype=torch.float) if use_bias else None + b_q = torch.quantize_per_tensor( + b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) if use_bias else None + # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with + # Y_scale * 255 (max for uint8). + Y_scale = 125.1234 + Y_zp = 5 + # Weight prepacking operator for quantized Linear + float_bias = b if use_bias else None + W_prepack = qlinear_prepack(W_q, float_bias) + if use_multi_dim_input: + X = X.view(3, int(batch_size / 3), input_channels) + X_q = X_q.view(3, int(batch_size / 3), input_channels) + # Quantized Linear operator with prepacked weight + Y_q_dq = qlinear(X, X_scale, X_zp, W_prepack) + # Test both per-tensor and per-channel quantization + # Reference quantized result from PyTorch Linear operator + W_fp32 = W_q.dequantize().to(dtype=torch.float) + X_fp32 = X_q.dequantize().to(dtype=torch.float) + b_fp32 = b_q.dequantize().to(dtype=torch.float) if use_bias else None + Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) + if use_relu: + Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 + decimal_val = 1 + np.testing.assert_array_almost_equal(Y_fp32_ref.numpy(), Y_q_dq.numpy(), decimal=decimal_val) + @given(batch_size=st.integers(1, 4), # in cudnn v. 8.4.0, there is a limitation that input channels # should be a multiple of 4 for int8 tensors. in cudnn v.8.3.3 @@ -3814,6 +3958,26 @@ def test_qlinear_unpack(self, W, use_channelwise): np.testing.assert_equal( W_q.q_zero_point(), W_q_origin.q_zero_point()) + @skipIfNoONEDNN + def test_qlinear_leaky_relu(self): + with override_quantized_engine('onednn'): + batch_size_list = [1, 4] + input_channels_list = [16, 32] + output_channels_list = [4, 8] + use_bias_list = [True, False] + use_multi_dim_input_list = [True, False] + use_channelwise_list = [True, False] + negative_slopes_list = [0.01, 0.05] + post_op = 'leaky_relu' + cases = itertools.product(batch_size_list, input_channels_list, output_channels_list, + use_bias_list, use_multi_dim_input_list, + use_channelwise_list, negative_slopes_list) + for batch_size, input_channels, output_channels, use_bias,\ + use_multi_dim_input, use_channelwise, neg_slope in cases: + self._test_qlinear_impl(batch_size, input_channels, output_channels, + use_bias, post_op, use_multi_dim_input, + use_channelwise, negative_slope=neg_slope) + @unittest.skipIf(IS_MACOS, "Known test failure on Mac.") class TestQuantizedEmbeddingOps(TestCase): @@ -5455,6 +5619,35 @@ def test_qconv3d_unpack( (stride_d, stride_h, stride_w), (pad_d, pad_h, pad_w), (o_pad, o_pad, o_pad), channelwise) + def test_conv_reorder_issue_onednn(self): + """ Ensure reorder failure issue in conv is fixed for onednn backend. + Onednn backend used to encounter reorder failure + when running conv with dynamic input shapes. + Solved by https://github.com/pytorch/pytorch/pull/86876 + """ + if 'onednn' not in supported_qengines: + return + with override_quantized_engine('onednn'): + bs = 1 + ic, oc = 128, 512 + kh, kw = 1, 1 + ih, iw = 28, 28 + bias = None + strides, paddings, dilates, groups = (1, 1), (0, 0), (1, 1), 1 + w = torch.randn((oc, ic, kh, kw)) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.qint8) + x = torch.randn((bs, ic, ih, iw)) + qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) + w_packed = torch.ops.quantized.conv2d_prepack( + qw, bias, strides, paddings, dilates, groups + ) + torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0) + ih, iw = 5, 4 + x = torch.randn((bs, ic, ih, iw)) + qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8) + # The following should pass when input shape is changed + torch.ops.quantized.conv2d(qx, w_packed, output_scale=1.0, output_zero_point=0) + class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), channels=st.integers(1, 64), diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index 28eddd7cd974d..98e21ab30f097 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -70,7 +70,7 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False): return [scale.astype(np.float32), int(nudged_zero_point)] # Note we explicitly cast variables to np.float32 in a couple of places to avoid -# the default casting in Python often resuling in double precision and to make +# the default casting in Python often resulting in double precision and to make # sure we're doing the same numerics as C++ code. def param_search_greedy(x, bit_rate, n_bins=200, ratio=0.16): xmin, xmax = np.min(x), np.max(x) @@ -443,8 +443,7 @@ def _test_per_channel_qtensor_creation(self, device): for dtype, zero_points in itertools.product([torch.qint8, torch.quint8], [zero_points_float, zero_points_int]): q = torch._empty_per_channel_affine_quantized( [numel], scales=scales, zero_points=zero_points, axis=ch_axis, dtype=dtype, device=device) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(scales, q.q_per_channel_scales()) + self.assertEqual(scales, q.q_per_channel_scales(), exact_dtype=False) self.assertEqual(zero_points, q.q_per_channel_zero_points()) self.assertEqual(ch_axis, q.q_per_channel_axis()) @@ -453,8 +452,7 @@ def _test_per_channel_qtensor_creation(self, device): int_tensor = torch.randint(0, 100, size=(numel,), dtype=torch.uint8, device=device) q = torch._make_per_channel_quantized_tensor(int_tensor, scales, zero_points, ch_axis) self.assertEqual(int_tensor, q.int_repr()) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(scales, q.q_per_channel_scales()) + self.assertEqual(scales, q.q_per_channel_scales(), exact_dtype=False) self.assertEqual(zero_points, q.q_per_channel_zero_points()) self.assertEqual(ch_axis, q.q_per_channel_axis()) @@ -809,8 +807,7 @@ def test_qtensor_per_channel_permute(self): self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride())))) self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride())))) self.assertEqual(qr.int_repr(), qlast.int_repr()) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(scales, qlast.q_per_channel_scales()) + self.assertEqual(scales.to(dtype=torch.float64), qlast.q_per_channel_scales()) self.assertEqual(zero_points, qlast.q_per_channel_zero_points()) self.assertEqual(1, qlast.q_per_channel_axis()) self.assertEqual(qlast.dequantize(), qr.dequantize()) @@ -1461,7 +1458,112 @@ def test_bfp16_quantize(self): X = torch.randn(5 , 10) quantized_X = X.to(torch.bfloat16) dedequantized_X = quantized_X.to(torch.float32) - torch.testing.assert_allclose(X, dedequantized_X, rtol=1e-4, atol=5e-3) + torch.testing.assert_close(X, dedequantized_X, rtol=1e-4, atol=5e-3) + + def test_decomposed_quantize_per_tensor(self): + # register the ops + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + test_cases = [ + (torch.quint8, torch.uint8, 0, 255), + (torch.qint8, torch.int8, -128, 127), + (torch.qint32, torch.int32, -2**31, 2**31 - 1), + ] + for qdtype, dtype, quant_min, quant_max in test_cases: + scale, zero_point = _calculate_dynamic_qparams(X, qdtype) + quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype) + quantized_decomposed_X = \ + torch.ops.quantized_decomposed.quantize_per_tensor( + X, scale, zero_point, quant_min, quant_max, dtype) + self.assertEqual(quantized_decomposed_X.dtype, dtype) + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + + def test_decomposed_dequantize_per_tensor(self): + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + test_cases = [ + (torch.quint8, torch.uint8, 0, 255), + (torch.qint8, torch.int8, -128, 127), + (torch.qint32, torch.int32, -2**31, 2**31 - 1), + ] + + for qdtype, dtype, quant_min, quant_max in test_cases: + scale, zero_point = _calculate_dynamic_qparams(X, qdtype) + quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype) + dequantized_X = torch.dequantize(quantized_X) + + quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor( + X, scale, zero_point, quant_min, quant_max, dtype) + dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor( + quantized_decomposed_X, scale, zero_point, quant_min, quant_max, dtype + ) + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + self.assertEqual(dequantized_X, dequantized_decomposed_X) + + def test_decomposed_dynamic_quant_pattern(self): + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + dtype = torch.uint8 + qdtype = torch.quint8 + scale, zero_point = torch._choose_qparams_per_tensor(X, False) + quant_min, quant_max = 0, 255 + + quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype) + dequantized_X = torch.dequantize(quantized_X) + + # Now try decomposed pattern + (scale_decomposed, zero_point_decomposed) = torch.ops.quantized_decomposed.choose_qparams.tensor( + X, quant_min, quant_max, dtype) + quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor.tensor( + X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype) + + dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor( + quantized_decomposed_X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype + ) + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + self.assertEqual(dequantized_X, dequantized_decomposed_X) + + def test_decomposed_quantize_per_channel(self): + # register the ops + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + qdtype = torch.quint8 + dtype = torch.uint8 + scales = torch.randn(5,) + zero_points = torch.randint(0, 100, (5,)) + quant_min, quant_max = 0, 255 + axis = 0 + + quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype) + quantized_decomposed_X = \ + torch.ops.quantized_decomposed.quantize_per_channel( + X, scales, zero_points, axis, quant_min, quant_max, dtype) + self.assertEqual(quantized_decomposed_X.dtype, dtype) + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + + def test_decomposed_dequantize_per_channel(self): + # register the ops + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + qdtype = torch.quint8 + dtype = torch.uint8 + scales = torch.randn(5,) + zero_points = torch.randint(0, 100, (5,)) + quant_min, quant_max = 0, 255 + axis = 0 + + quantized_X = torch.quantize_per_channel(X, scales, zero_points, axis, qdtype) + dequantized_X = torch.dequantize(quantized_X) + + quantized_decomposed_X = \ + torch.ops.quantized_decomposed.quantize_per_channel( + X, scales, zero_points, axis, quant_min, quant_max, dtype) + dequantized_decomposed_X = \ + torch.ops.quantized_decomposed.dequantize_per_channel( + quantized_decomposed_X, scales, zero_points, axis, quant_min, quant_max, dtype) + + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + self.assertEqual(dequantized_X, dequantized_decomposed_X) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" diff --git a/test/quantization/core/test_top_level_apis.py b/test/quantization/core/test_top_level_apis.py index 7343a16040d25..f76db1cd4139b 100644 --- a/test/quantization/core/test_top_level_apis.py +++ b/test/quantization/core/test_top_level_apis.py @@ -59,3 +59,35 @@ def test_fake_quants(self) -> None: for observer in self.fake_quants: obs = self._get_observer_ins(observer) obs.forward(t) + + +class TestQConfig(TestCase): + + REDUCE_RANGE_DICT = { + 'fbgemm': (True, False), + 'qnnpack': (False, False), + 'onednn': (False, False), + 'x86': (True, False), + } + + def test_reduce_range_qat(self) -> None: + for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items(): + for version in range(2): + qconfig = torch.ao.quantization.get_default_qat_qconfig(backend, version) + + fake_quantize_activ = qconfig.activation() + self.assertEqual(fake_quantize_activ.activation_post_process.reduce_range, reduce_ranges[0]) + + fake_quantize_weight = qconfig.weight() + self.assertEqual(fake_quantize_weight.activation_post_process.reduce_range, reduce_ranges[1]) + + def test_reduce_range(self) -> None: + for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items(): + for version in range(1): + qconfig = torch.ao.quantization.get_default_qconfig(backend, version) + + fake_quantize_activ = qconfig.activation() + self.assertEqual(fake_quantize_activ.reduce_range, reduce_ranges[0]) + + fake_quantize_weight = qconfig.weight() + self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1]) diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 7194872f4e5e9..6ac8bed90ca3f 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -1011,11 +1011,11 @@ def test_fused_obs_fq_module(self, device): ) # Compare params with reference - torch.testing.assert_allclose(out, out_ref) - torch.testing.assert_allclose( + torch.testing.assert_close(out, out_ref) + torch.testing.assert_close( running_min_op, mod.activation_post_process.min_val ) - torch.testing.assert_allclose( + torch.testing.assert_close( running_max_op, mod.activation_post_process.max_val ) @@ -1066,11 +1066,11 @@ def test_fused_obs_fq_moving_avg_module(self, device): ) # Compare params with reference - torch.testing.assert_allclose(out, out_ref) - torch.testing.assert_allclose( + torch.testing.assert_close(out, out_ref) + torch.testing.assert_close( running_min_op, mod.activation_post_process.min_val ) - torch.testing.assert_allclose( + torch.testing.assert_close( running_max_op, mod.activation_post_process.max_val ) @@ -1095,12 +1095,12 @@ def test_compare_fused_obs_fq_oss_module(self, device): x = torch.randn(5, 5, device=device) out = mod(x) out_ref = mod_ref(x) - torch.testing.assert_allclose(out, out_ref) - torch.testing.assert_allclose( + torch.testing.assert_close(out, out_ref) + torch.testing.assert_close( mod_ref.activation_post_process.min_val, mod.activation_post_process.min_val, ) - torch.testing.assert_allclose( + torch.testing.assert_close( mod_ref.activation_post_process.max_val, mod.activation_post_process.max_val, ) @@ -1151,20 +1151,20 @@ def test_fused_mod_per_channel(self): False, ) # Compare params with reference - torch.testing.assert_allclose(out, out_ref) + torch.testing.assert_close(out, out_ref) if mod.observer_enabled[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( running_min_op, mod.activation_post_process.min_val ) - torch.testing.assert_allclose( + torch.testing.assert_close( running_max_op, mod.activation_post_process.max_val ) if mod.fake_quant_enabled: - torch.testing.assert_allclose(scale, mod.scale) - torch.testing.assert_allclose(zero_point, mod.zero_point) + torch.testing.assert_close(scale, mod.scale) + torch.testing.assert_close(zero_point, mod.zero_point) - torch.testing.assert_allclose(mod.state_dict()['activation_post_process.min_val'], running_min_op) - torch.testing.assert_allclose(mod.state_dict()['activation_post_process.max_val'], running_max_op) + torch.testing.assert_close(mod.state_dict()['activation_post_process.min_val'], running_min_op) + torch.testing.assert_close(mod.state_dict()['activation_post_process.max_val'], running_max_op) def test_fused_mod_reduce_range(self): obs = FusedMovingAvgObsFakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=True) diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index b459b5865bfaa..a0687d88fa57d 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -1083,7 +1083,7 @@ def test_fused_obs_fake_quant_moving_avg(self, device, symmetric_quant) -> None: self.assertEqual(in_running_min_ref, in_running_min_op) self.assertEqual(in_running_max_ref, in_running_max_op) - torch.testing.assert_allclose(out, x_in) + torch.testing.assert_close(out, x_in) # Test empty input works x = torch.empty(0, 5, device=device) @@ -1176,7 +1176,7 @@ def test_fused_obs_fake_quant_moving_avg_per_channel(self, device, symmetric_qua x_in = x self.assertEqual(in_running_min_ref, in_running_min_op) self.assertEqual(in_running_max_ref, in_running_max_op) - torch.testing.assert_allclose(out, x_in) + torch.testing.assert_close(out, x_in) @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),) @settings(deadline=None) @@ -1218,7 +1218,7 @@ def test_fused_obs_fake_quant_backward_op(self, device) -> None: False, ) # verify the output matches - torch.testing.assert_allclose(out, x_fake_quant) + torch.testing.assert_close(out, x_fake_quant) # verify the gradient matches expectation of fake_quant op dout = torch.rand_like(x, dtype=torch.float).to(device) @@ -1264,7 +1264,7 @@ def test_fused_backward_op_fake_quant_off(self, device) -> None: False, ) # verify the output matches - torch.testing.assert_allclose(out, x) + torch.testing.assert_close(out, x) # verify the gradient matches expectation of fake_quant op dout = torch.rand_like(x, dtype=torch.float).to(device) diff --git a/test/quantization/eager/test_fuse_eager.py b/test/quantization/eager/test_fuse_eager.py index 9f120b889c2e5..1ebc4bfd094eb 100644 --- a/test/quantization/eager/test_fuse_eager.py +++ b/test/quantization/eager/test_fuse_eager.py @@ -28,6 +28,7 @@ ModelForLinearBNFusion, ModelForFusionWithBias, ModelForConvTransposeBNFusion, + SingleLayerLinearModel, test_only_eval_fn, test_only_train_fn, skipIfNoFBGEMM, @@ -363,6 +364,17 @@ def test_fusion_convtranspose_bn_eval(self): self.assertEqual(golden, model(inp2)) + def test_fuse_function_customization(self): + dummy_model = SingleLayerLinearModel().train() + dummy_model.eval() + + # A custom fuse funct + def custom_fuse_func(module, is_qat, add_fuser_mapping): + return [torch.nn.Identity()] + + dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func) + self.assertEqual(type(dummy_model.fc1), nn.Identity) + def test_forward_hooks_preserved(self): r"""Test case that checks whether forward pre hooks of the first module and post forward hooks of the last module in modules list passed to fusion function preserved. diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py index 7d87cc520ba04..e0ad793df68a8 100644 --- a/test/quantization/eager/test_quantize_eager_ptq.py +++ b/test/quantization/eager/test_quantize_eager_ptq.py @@ -19,8 +19,10 @@ float16_dynamic_qconfig, float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit, + FixedQParamsObserver, PerChannelMinMaxObserver, default_dynamic_quant_observer, + default_weight_observer, QConfig, ) @@ -61,8 +63,6 @@ supported_qengines, override_qengines, ) -from torch.testing._internal.jit_utils import JitTestCase -from torch.testing._internal.common_utils import skipIfNoCaffe2 from hypothesis import given from hypothesis import strategies as st @@ -71,8 +71,6 @@ # Standard library from typing import Tuple -import io -import unittest import numpy as np class TestQuantizeEagerOps(QuantizationTestCase): @@ -1026,6 +1024,48 @@ def test_quantwrapper_attaches_qconfig_to_dequant(self): mq = torch.ao.quantization.convert(mp) self.assertTrue(isinstance(mq[0].dequant, nnq.DeQuantize)) + def test_activations_in_non_leaf_module_list(self): + """ + Ensure activations like `nn.Sigmoid` and `nn.Tanh` are properly handled in + `non_leaf_module_list`. + """ + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.quant = QuantStub() + self.sigmoid = torch.nn.Sigmoid() + self.hardsigmoid = torch.nn.Hardsigmoid() + self.softmax = torch.nn.Softmax() + self.tanh = torch.nn.Tanh() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self.sigmoid(x) + x = self.hardsigmoid(x) + x = self.softmax(x) + x = self.tanh(x) + x = self.dequant(x) + return x + + qconfig = QConfig( + activation=FixedQParamsObserver.with_args(scale=123.0, zero_point=0), + weight=default_weight_observer + ) + m = MyModel() + m.qconfig = qconfig + m = prepare(m, observer_non_leaf_module_list=[ + torch.nn.Sigmoid, + torch.nn.Hardsigmoid, + torch.nn.Softmax, + torch.nn.Tanh, + ]) + + # Should use the observer specified in the QConfig instead of the default (FixedQParamsFakeQuantize) + self.assertTrue(isinstance(m.sigmoid.activation_post_process, FixedQParamsObserver)) + self.assertTrue(isinstance(m.hardsigmoid.activation_post_process, FixedQParamsObserver)) + self.assertTrue(isinstance(m.softmax.activation_post_process, FixedQParamsObserver)) + self.assertTrue(isinstance(m.tanh.activation_post_process, FixedQParamsObserver)) @skipIfNoFBGEMM class TestQuantizeEagerPTQDynamic(QuantizationTestCase): @@ -1443,53 +1483,6 @@ def forward(self, indices, offsets, linear_in): self.assertTrue('QuantizedEmbedding' in str(q_model)) self.assertTrue('DynamicQuantizedLinear' in str(q_model)) -class TestQuantizeEagerONNXExport(JitTestCase): - def _test_lower_graph_impl(self, model, data): - model.qconfig = torch.ao.quantization.default_qconfig - model = torch.ao.quantization.prepare(model) - model = torch.ao.quantization.convert(model) - - outputs = model(data) - input_names = ["x"] - - def export_to_onnx(model, input, input_names): - traced = torch.jit.trace(model, input) - buf = io.BytesIO() - torch.jit.save(traced, buf) - buf.seek(0) - - model = torch.jit.load(buf) - f = io.BytesIO() - torch.onnx.export(model, input, f, input_names=input_names, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - opset_version=9) - onnx_model = export_to_onnx(model, data, input_names) - - @skipIfNoFBGEMM - @skipIfNoCaffe2 - def test_lower_graph_linear(self): - model = torch.ao.quantization.QuantWrapper(torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float) - data_numpy = np.random.rand(1, 2, 5).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - - @skipIfNoFBGEMM - @skipIfNoCaffe2 - def test_lower_graph_conv2d(self): - model = torch.ao.quantization.QuantWrapper(torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float) - data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - - @skipIfNoFBGEMM - @unittest.skip("onnx opset9 does not support quantize_per_tensor and caffe2 \ - does not support conv3d") - def test_lower_graph_conv3d(self): - model = torch.ao.quantization.QuantWrapper(torch.nn.Conv3d(3, 5, 2, bias=True)).to(dtype=torch.float) - data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32) - data = torch.from_numpy(data_numpy).to(dtype=torch.float) - self._test_lower_graph_impl(model, data) - if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index bc118a82062d9..44911b6d9e11a 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -594,6 +594,7 @@ def forward(self, x): eps = 1e-5 self.assertTrue(torch.abs(mq.quant.scale * 2 - res.q_scale()) < eps) + @override_qengines def test_qat_embedding_bag_errors(self): default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 27fe772d2e228..eb7dcdfac3556 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -31,6 +31,7 @@ LSTMwithHiddenDynamicModel, SparseNNModel, skip_if_no_torchvision, + TwoLayerLinearModel ) from torch.ao.quantization.quantization_mappings import ( get_default_static_quant_module_mappings, @@ -40,7 +41,7 @@ from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_quantization import NodeSpec as ns from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns -import torch.ao.quantization.fx.quantization_patterns as qp +import torch.ao.quantization.fx.quantize_handler as qh from torch.ao.ns.fx.pattern_utils import ( get_type_a_related_to_b, ) @@ -82,8 +83,9 @@ loggers_set_enabled, loggers_set_save_activations, ) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping from torch.ao.quantization.backend_config import get_native_backend_config -from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers # Note: these models are not for use outside of this file. While it's good @@ -297,7 +299,7 @@ def get_all_quant_patterns(): all_quant_patterns = get_default_quant_patterns() # some of the patterns are moved to (native) backend_config_dict so we need to # add them back here - for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items(): + for pattern, quantize_handler in _get_pattern_to_quantize_handlers(get_native_backend_config()).items(): all_quant_patterns[pattern] = quantize_handler return all_quant_patterns @@ -667,21 +669,21 @@ def _op_is_unmatchable(op): base_op = pattern qhandler_cls_all_ops_quantizeable = [ - qp.CatQuantizeHandler, - qp.ConvReluQuantizeHandler, - qp.LinearReLUQuantizeHandler, - qp.BatchNormQuantizeHandler, - qp.EmbeddingQuantizeHandler, - qp.RNNDynamicQuantizeHandler, + qh.CatQuantizeHandler, + qh.ConvReluQuantizeHandler, + qh.LinearReLUQuantizeHandler, + qh.BatchNormQuantizeHandler, + qh.EmbeddingQuantizeHandler, + qh.RNNDynamicQuantizeHandler, ] qhandler_cls_quant_op_same_signature = [ - qp.FixedQParamsOpQuantizeHandler, - qp.CopyNodeQuantizeHandler, - qp.GeneralTensorShapeOpQuantizeHandler, + qh.FixedQParamsOpQuantizeHandler, + qh.CopyNodeQuantizeHandler, + qh.GeneralTensorShapeOpQuantizeHandler, ] - if qhandler_cls == qp.BinaryOpQuantizeHandler: + if qhandler_cls == qh.BinaryOpQuantizeHandler: # these ops do not have quantized equivalents ops_to_skip = [ torch.bmm, @@ -695,11 +697,11 @@ def _op_is_unmatchable(op): self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") - elif qhandler_cls == qp.RNNDynamicQuantizeHandler: + elif qhandler_cls == qh.RNNDynamicQuantizeHandler: # TODO(future PR): add support for all classes in # RNNDynamicQuantizeHandler pass - elif qhandler_cls == qp.DefaultNodeQuantizeHandler: + elif qhandler_cls == qh.DefaultNodeQuantizeHandler: self.assertTrue( _op_in_base_sets_of_related_ops(base_op), f"{base_op} not in sets of related ops") @@ -1604,23 +1606,23 @@ def test_op_io_dtype_coverage(self): if ( qhandler_cls in ( - qp.BinaryOpQuantizeHandler, - qp.RNNDynamicQuantizeHandler, + qh.BinaryOpQuantizeHandler, + qh.RNNDynamicQuantizeHandler, ) ): # TODO(future PR): implement shadowing for binary ops # TODO(future PR): implement shadowing for RNN ops continue - elif qhandler_cls == qp.CatQuantizeHandler: + elif qhandler_cls == qh.CatQuantizeHandler: self.assertTrue( base_op in FUNS_IO_TYPE_FP32_OR_INT8, f"missing IO type handling for {base_op}") elif ( qhandler_cls in ( - qp.ConvReluQuantizeHandler, - qp.LinearReLUQuantizeHandler, - qp.BatchNormQuantizeHandler, - qp.DefaultNodeQuantizeHandler, + qh.ConvReluQuantizeHandler, + qh.LinearReLUQuantizeHandler, + qh.BatchNormQuantizeHandler, + qh.DefaultNodeQuantizeHandler, ) ): self.assertTrue( @@ -1628,9 +1630,9 @@ def test_op_io_dtype_coverage(self): f"missing IO type handling for {base_op}") elif ( qhandler_cls in ( - qp.FixedQParamsOpQuantizeHandler, - qp.CopyNodeQuantizeHandler, - qp.GeneralTensorShapeOpQuantizeHandler, + qh.FixedQParamsOpQuantizeHandler, + qh.CopyNodeQuantizeHandler, + qh.GeneralTensorShapeOpQuantizeHandler, ) ): if ( @@ -1648,7 +1650,7 @@ def test_op_io_dtype_coverage(self): # version, so it does not fit into the cases above. (base_op is torch.nn.Softmax), f"missing IO type handling for {base_op}") - elif qhandler_cls == qp.EmbeddingQuantizeHandler: + elif qhandler_cls == qh.EmbeddingQuantizeHandler: # embedding shadowing is not implemented, for now continue else: @@ -2096,6 +2098,7 @@ def _test_impl(self, m, example_input, qconfig_mappings): results = extract_results_n_shadows_model(msq) print_comparisons_n_shadows_model(results) + return msq def test_linear_mod(self): class M(nn.Module): @@ -2110,9 +2113,8 @@ def forward(self, x): m = M().eval() example_input = (torch.randn(2, 2),) - qconfig_mappings = [ - QConfigMapping().set_global(torch.quantization.default_qconfig), - ] + qconfig_mappings = \ + QConfigMultiMapping().set_global([torch.quantization.default_qconfig]) self._test_impl(m, example_input, qconfig_mappings) def test_linear_relu_mod(self): @@ -2132,10 +2134,12 @@ def forward(self, x): m = M().eval() example_input = (torch.randn(2, 2),) - qconfig_mappings = [ - QConfigMapping().set_global(torch.quantization.default_qconfig), - QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig), - ] + qconfig_mappings = ( + QConfigMultiMapping().set_global([ + torch.quantization.default_qconfig, + torch.quantization.default_dynamic_qconfig + ]) + ) self._test_impl(m, example_input, qconfig_mappings) def test_conv_bn_relu_mod(self): @@ -2154,10 +2158,12 @@ def forward(self, x): m = M().eval() example_input = (torch.randn(32, 1, 16, 16),) - qconfig_mappings = [ - QConfigMapping().set_global(torch.quantization.default_qconfig), - QConfigMapping().set_global(torch.quantization.default_per_channel_qconfig), - ] + + qconfig_mappings = QConfigMultiMapping() \ + .set_global([ + torch.quantization.default_qconfig, + torch.quantization.default_per_channel_qconfig + ]) self._test_impl(m, example_input, qconfig_mappings) def test_functions(self): @@ -2194,10 +2200,8 @@ def forward(self, x): m = M().eval() example_input = (torch.randn(2, 2),) - qconfig_mappings = [ - QConfigMapping().set_global(torch.quantization.default_qconfig), - # QConfigMapping().set_global(torch.quantization.default_per_channel_qconfig), - ] + qconfig_mappings = QConfigMultiMapping() \ + .set_global([torch.quantization.default_qconfig]) self._test_impl(m, example_input, qconfig_mappings) def test_partial_qconfig_mapping(self): @@ -2220,19 +2224,17 @@ def forward(self, x): example_input = (torch.randn(2, 2),) qconfig = torch.ao.quantization.default_qconfig - qconfig_mappings = [ - QConfigMapping().set_global(None) - .set_object_type(F.linear, qconfig) - .set_object_type(F.relu, qconfig), - ] + qconfig_mappings = QConfigMultiMapping() \ + .set_object_type(F.linear, [qconfig]) \ + .set_object_type(F.relu, [qconfig]) self._test_impl(m, example_input, qconfig_mappings) def test_logger_enabled_and_save_activations_flags(self): m = nn.Sequential(nn.Linear(1, 1)).eval() example_input = (torch.randn(1, 1),) - qconfig_mappings = [ - QConfigMapping().set_global(torch.quantization.default_qconfig), - ] + + qconfig_mappings = QConfigMultiMapping() \ + .set_global([torch.quantization.default_qconfig]) backend_config = get_native_backend_config() msp = prepare_n_shadows_model( @@ -2281,11 +2283,267 @@ def test_mobilenet_v2(self): pretrained=False, quantize=False).eval() example_input = (torch.randn(1, 3, 224, 224),) - qconfig_mappings = [ + qconfig_mappings = QConfigMultiMapping() \ + .set_global([torch.quantization.default_qconfig, torch.quantization.default_dynamic_qconfig]) + + self._test_impl(m, example_input, qconfig_mappings) + + def test_qconfig_multi_mapping_deduplication(self): + # check that insertion deduplicates qconfigs + qconfig_multi_mapping = QConfigMultiMapping().set_global( + [torch.quantization.default_qconfig, torch.quantization.default_qconfig] + ) + self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 1) + + def test_qconfig_multi_mapping_insert_padding(self): + # test that inserting a higher priority qconfig style with fewer elements than a lower priority qconfig will + # result in adding None to the extra QConfigMappings at that same style+key + qconfig_multi_mapping = ( + QConfigMultiMapping() + .set_global( + [ + torch.quantization.default_qconfig, + torch.quantization.default_dynamic_qconfig, + ] + ) + .set_object_type(torch.nn.Linear, [torch.quantization.default_qconfig]) + .set_module_name_regex("fc", [torch.quantization.default_qconfig]) + .set_module_name("fc2", [torch.quantization.default_qconfig]) + .set_module_name_object_type_order( + "", nn.Linear, 0, [torch.quantization.default_qconfig] + ) + ) + + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[ + torch.nn.Linear + ], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[ + "fc" + ], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[ + 1 + ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)], + None, + ) + + def test_qconfig_multi_mapping_retroactive_padding(self): + # test that inserting a lower priority qconfig style with more elements thhan lower priority qconfig styles + # will result in the new QConfigMapping having None at all previously existing styles+keys + qconfig_multi_mapping = ( + QConfigMultiMapping() + .set_object_type(torch.nn.Linear, [torch.quantization.default_qconfig]) + .set_module_name_regex("fc", [torch.quantization.default_qconfig]) + .set_module_name("fc2", [torch.quantization.default_qconfig]) + .set_module_name_object_type_order( + "", nn.Linear, 0, [torch.quantization.default_qconfig] + ) + .set_global( + [ + torch.quantization.default_qconfig, + torch.quantization.default_dynamic_qconfig, + ] + ) + ) + + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].object_type_qconfigs[ + torch.nn.Linear + ], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_regex_qconfigs[ + "fc" + ], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], + None, + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[ + 1 + ].module_name_object_type_order_qconfigs[("", nn.Linear, 0)], + None, + ) + + def test_qconfig_multi_mapping_end_to_end(self): + # test that the prepare/convert_n_shadows_model works as expected + # with qconfig_multi_mapping and avoids unwanted matches + + m = TwoLayerLinearModel().eval() + example_input = m.get_example_inputs() + + qconfig_multi_mapping = ( + QConfigMultiMapping() + .set_global( + [ + torch.quantization.default_qconfig, + torch.quantization.default_dynamic_qconfig, + ] + ) + .set_module_name("fc2", [None, torch.quantization.default_qconfig]) + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], + None, + ) + msq = self._test_impl(m, example_input, qconfig_multi_mapping) + + self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) + self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) + self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0) + self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2) + + def test_qconfig_multi_mapping_from_list(self): + # test QConfigMultiMapping.from_list_qconfig_mapping works as expected + + m = TwoLayerLinearModel().eval() + example_input = m.get_example_inputs() + + qconfig_mappings_list = [ QConfigMapping().set_global(torch.quantization.default_qconfig), - QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig), + QConfigMapping() + .set_global(torch.quantization.default_dynamic_qconfig) + .set_module_name("fc2", torch.quantization.default_qconfig), ] - self._test_impl(m, example_input, qconfig_mappings) + + qconfig_multi_mapping = QConfigMultiMapping().from_list_qconfig_mapping( + qconfig_mappings_list + ) + self.assertEqual( + qconfig_multi_mapping.qconfig_mappings_list[1].module_name_qconfigs["fc2"], + None, + ) + + msq = self._test_impl(m, example_input, qconfig_multi_mapping) + + self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) + self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) + self.checkQuantizedLinear(msq.shadow_wrapper_1_1.mod_0) + self.assertRaisesRegex(AttributeError, ".*", lambda: msq.shadow_wrapper_1_2) + + def test_qconfig_multi_mapping_ordering(self): + # test that the module ordering ignores None + + m = TwoLayerLinearModel().eval() + example_input = m.get_example_inputs() + qconfig_multi_mapping = ( + QConfigMultiMapping() + .set_global( + [ + torch.ao.quantization.default_qconfig, + torch.ao.quantization.default_dynamic_qconfig, + ] + ) + .set_module_name( + "fc2", + [ + None, + torch.ao.quantization.default_dynamic_qconfig, + torch.ao.quantization.default_qat_qconfig_v2, + ], + ) + ) + self.assertEqual(len(qconfig_multi_mapping.qconfig_mappings_list), 2) + msq = self._test_impl(m, example_input, qconfig_multi_mapping) + + self.checkQuantizedLinear(msq.shadow_wrapper_0_1.mod_0) + self.checkDynamicQuantizedLinear(msq.shadow_wrapper_0_2.mod_0, torch.qint8) + self.checkDynamicQuantizedLinear(msq.shadow_wrapper_1_1.mod_0, torch.qint8) + self.checkQuantizedLinear(msq.shadow_wrapper_1_2.mod_0) + + def test_qconfig_multi_mapping_repr(self): + qconfig_multi_mapping = ( + QConfigMultiMapping() + .set_global( + [ + torch.ao.quantization.default_qconfig, + torch.ao.quantization.default_dynamic_qconfig, + ] + ) + .set_module_name( + "fc2", + [ + None, + torch.ao.quantization.default_dynamic_qconfig, + torch.ao.quantization.default_qat_qconfig_v2, + ], + ) + ) + self.assertTrue(isinstance(qconfig_multi_mapping.__repr__(), str)) + + def test_custom_functions_and_tracer(self): + class M(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(2, 2) + self.fc2 = nn.Linear(2, 2) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + m = M().eval() + example_inputs = (torch.randn(2, 2),) + + qconfig_mappings = QConfigMultiMapping().set_global( + [torch.quantization.default_qat_qconfig] + ) + + custom_tracer = torch.ao.quantization.quantize_fx.QuantizationTracer( + ["fc2"], [] + ) + + custom_prepare_fn = torch.ao.quantization.quantize_fx.prepare_qat_fx + + def custom_convert_fn(module, to_print): + print(to_print) + mod = torch.ao.quantization.quantize_fx.convert_fx(module) + return mod + + backend_config = get_native_backend_config() + + # test that input is valid + _ = m(*example_inputs) + + kwargs = {"to_print": "working"} + + msp = prepare_n_shadows_model( + m, + example_inputs, + qconfig_mappings, + backend_config, + custom_prepare_fn=custom_prepare_fn, + custom_prepare_kwargs=None, + custom_tracer=custom_tracer, + ) + + for _ in range(2): + msp(*example_inputs) + + msq = convert_n_shadows_model( + msp, custom_convert_fn=custom_convert_fn, custom_convert_kwargs=kwargs + ) + print(msq) + loggers_set_enabled(msq, True) + msq(*example_inputs) + + results = extract_results_n_shadows_model(msq) + print_comparisons_n_shadows_model(results) class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 6935081a5c923..794d70b56f8f2 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -18,11 +18,13 @@ prepare_fx, convert_fx, convert_to_reference_fx, + _convert_to_reference_decomposed_fx, prepare_qat_fx, fuse_fx, ) -from torch.ao.quantization.fx.quantization_patterns import DefaultNodeQuantizeHandler + +from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler from torch.ao.quantization.fx.match_utils import ( is_match, @@ -31,8 +33,8 @@ from torch.ao.quantization import ( QuantType, - quant_type_to_str, ) +from torch.ao.quantization.quant_type import _get_quant_type_to_str from torch.ao.quantization import ( QuantStub, @@ -40,6 +42,7 @@ QuantWrapper, default_qconfig, default_dynamic_qconfig, + default_per_channel_qconfig, default_qat_qconfig, default_reuse_input_qconfig, default_symmetric_qnnpack_qconfig, @@ -88,18 +91,19 @@ from torch.ao.quantization.qconfig_mapping import ( _get_symmetric_qnnpack_qconfig_mapping, - GLOBAL_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, - OBJECT_TYPE_DICT_KEY, + _GLOBAL_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, QConfigMapping, ) -from torch.ao.quantization.qconfig_mapping_utils import ( - get_object_type_qconfig, - get_module_name_qconfig, - get_module_name_regex_qconfig, +from torch.ao.quantization.fx.qconfig_mapping_utils import ( + _get_object_type_qconfig, + _get_module_name_qconfig, + _get_module_name_regex_qconfig, + maybe_adjust_qconfig_for_module_name_object_type_order, ) from torch.ao.quantization.fx.pattern_utils import ( @@ -128,10 +132,6 @@ StandaloneModuleConfigEntry, ) -from torch.ao.quantization.fx.qconfig_mapping_utils import ( - maybe_adjust_qconfig_for_module_name_object_type_order, -) - from torch.ao.quantization.fx.utils import ( _reroute_tuple_getitem_pattern, NodeInfo, @@ -157,6 +157,7 @@ LinearReluModel, QuantizationTestCase, skipIfNoFBGEMM, + skipIfNoQNNPACK, skip_if_no_torchvision, train_one_epoch, run_ddp, @@ -190,7 +191,7 @@ import operator import unittest import io -from typing import Callable, Optional, List +from typing import Callable, Optional, List, Tuple class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): @@ -489,7 +490,7 @@ def forward(self, x): self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU)) - @unittest.skip("Temprorarily skipping the test case, will enable after the simple" + @unittest.skip("Temporarily skipping the test case, will enable after the simple" "pattern format is supported") def test_fuse_addtional_fuser_method(self): class MyConvReLU(torch.nn.Module): @@ -712,6 +713,22 @@ def forward(self, x, y): if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: self.assertTrue(is_match(modules, n, pattern)) + def test_pattern_match_constant(self): + class M(torch.nn.Module): + def forward(self, x): + x, _ = torch.ops.aten.max_pool2d_with_indices.default(x) + return x + + pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) + m = torch.fx.symbolic_trace(M()) + # eliminate the code that get the second output of maxpool, so that the pattern + # can be matched + m.graph.eliminate_dead_code() + modules = dict(m.named_modules()) + for n in m.graph.nodes: + if n.op == "call_function" and n.target == operator.getitem: + self.assertTrue(is_match(modules, n, pattern)) + def test_fused_module_qat_swap(self): class Tmp(torch.nn.Module): def __init__(self): @@ -1874,9 +1891,9 @@ def test_qconfig_mapping_set_object_type(self): qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_regex(self): qconfig1 = get_default_qconfig() @@ -1896,11 +1913,11 @@ def test_qconfig_mapping_set_module_name_regex(self): qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name(self): qconfig1 = get_default_qconfig() @@ -1920,9 +1937,9 @@ def test_qconfig_mapping_set_module_name(self): qconfig_mapping.set_module_name("mod1", qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_object_type_order(self): qconfig1 = get_default_qconfig() @@ -1970,20 +1987,20 @@ def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, q Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. """ return { - GLOBAL_DICT_KEY: global_qconfig, - OBJECT_TYPE_DICT_KEY: [ + _GLOBAL_DICT_KEY: global_qconfig, + _OBJECT_TYPE_DICT_KEY: [ (torch.nn.Linear, qconfig1), (torch.nn.ReLU, qconfig2), ], - MODULE_NAME_REGEX_DICT_KEY: [ + _MODULE_NAME_REGEX_DICT_KEY: [ ("foo.*bar", qconfig1), ("foo.*", qconfig2), ], - MODULE_NAME_DICT_KEY: [ + _MODULE_NAME_DICT_KEY: [ ("bazbaz", qconfig1), ("borbor", qconfig2), ], - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ ("bazbaz", torch.nn.Linear, 0, qconfig1), ("foofoo", torch.nn.ReLU, 1, qconfig2), ], @@ -2037,6 +2054,34 @@ def test_qconfig_mapping_to_dict(self): qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict) + def test_qconfig_mapping_repr(self): + self.assertTrue(isinstance(get_default_qconfig_mapping().__repr__(), str)) + + def test_default_qconfig_mapping_override_global(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv(x) + + m = M().eval() + my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer) + qconfig_mapping = get_default_qconfig_mapping() + # Override global qconfig + old_global_qconfig = qconfig_mapping.global_qconfig + qconfig_mapping.set_global(my_qconfig) + # Verify the correct qconfig was used + example_inputs = (torch.randn(1, 1, 1, 1),) + m = prepare_fx(m, qconfig_mapping, example_inputs) + self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver)) + self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver)) + self.assertTrue(hasattr(m, "activation_post_process_0")) + self.assertTrue(hasattr(m, "activation_post_process_1")) + self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver)) + self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver)) + # Dummy classes for PrepareCustomConfig testing class _DummyStandaloneModule: @@ -2634,7 +2679,7 @@ def forward(self, x): } for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]: - key = quant_type_to_str(quant_type) + key = _get_quant_type_to_str(quant_type) qconfig, quantized_module_class, num_observers = test_configs[key] qconfig_dict = {"": qconfig} if key == "static": @@ -4215,6 +4260,81 @@ def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): } self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2) + def test_static_lstm_with_custom_fixed_qparams(self): + """ + Test statically quantized LSTM with custom fixed qparams assigned to each of the + inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM` + and use the child class in the custom module mapping. + """ + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.my_lstm = torch.nn.LSTM(50, 50, 1) + + def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): + x = self.my_lstm(inputs, (h0, c0)) + return x + + class UserLSTM(torch.ao.nn.quantizable.LSTM): + """ + Example of user provided LSTM implementation that has fixed qparams assigned + to the inner submodules. + """ + @classmethod + def from_float(cls, other): + assert isinstance(other, cls._FLOAT_MODULE) + # uint16, [-16, 16) + linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) + # uint16, [0, 1) + sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32) + # uint16, [-1, 1) + tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) + # int16, [-16, 16) + cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32) + # uint8, [-1, 1) + hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) + return torch.ao.quantization.utils._get_lstm_with_individually_observed_parts( + float_lstm=other, + linear_output_obs_ctr=linear_output_obs_ctr, + sigmoid_obs_ctr=sigmoid_obs_ctr, + tanh_obs_ctr=tanh_obs_ctr, + cell_state_obs_ctr=cell_state_obs_ctr, + hidden_state_obs_ctr=hidden_state_obs_ctr, + ) + + # Prepare model + qconfig_mapping = get_default_qconfig_mapping() + example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) + prepare_custom_config = PrepareCustomConfig() \ + .set_float_to_observed_mapping(torch.nn.LSTM, UserLSTM) + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(UserLSTM, torch.ao.nn.quantized.LSTM) + model = MyModel() + model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) + + # Validate that the observers inserted to each inner module has the expected qparams + def validate_qparams(inner_module: torch.nn.Module, scale: float, zero_point: int, dtype: torch.dtype): + self.assertTrue(hasattr(inner_module, "activation_post_process")) + obs = inner_module.activation_post_process + self.assertTrue(isinstance(obs, FixedQParamsObserver)) + self.assertEqual(obs.scale, scale) + self.assertEqual(obs.zero_point, zero_point) + self.assertEqual(obs.dtype, dtype) + cell = model.my_lstm.layers[0].layer_fw.cell + validate_qparams(cell.igates, 2 ** -11, 2 ** 15, torch.qint32) + validate_qparams(cell.hgates, 2 ** -11, 2 ** 15, torch.qint32) + validate_qparams(cell.input_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.forget_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.cell_gate, 2 ** -15, 2 ** 15, torch.qint32) + validate_qparams(cell.output_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.fgate_cx_igate_cgate, 2 ** -11, 0, torch.qint32) + validate_qparams(cell.ogate_cy, 2 ** -7, 2 ** 7, torch.quint8) + + # Make sure the rest of the flow runs + model(*example_inputs) + model = convert_fx(model, convert_custom_config=convert_custom_config, _remove_qconfig=False) + model(*example_inputs) + def test_reroute_tuple_getitem_patterns(self): """ The following graph should redirect the output to `b`. After the transformation, @@ -5223,6 +5343,233 @@ def forward(self, x): # make sure this runs m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) + def test_get_default_qconfig_valid_backend(self): + """ Checks that AssertionError is raised when non expected backend input is specified + """ + invalid_backends = ["imaginary_backend", 3] + for invalid_backend in invalid_backends: + with self.assertRaisesRegex(AssertionError, "not supported"): + qconfig = get_default_qconfig(invalid_backend) + with self.assertRaisesRegex(AssertionError, "not supported"): + qconfig = get_default_qat_qconfig(invalid_backend) + with self.assertRaisesRegex(AssertionError, "not supported"): + qconfig_mapping = get_default_qconfig_mapping(invalid_backend) + with self.assertRaisesRegex(AssertionError, "not supported"): + qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend) + + def test__convert_to_reference_decomposed_fx(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + m = M().eval() + qconfig_mapping = get_default_qconfig_mapping("fbgemm") + example_inputs = (torch.randn(1, 5),) + m = prepare_fx(m, qconfig_mapping, example_inputs) + m_ref = copy.deepcopy(m) + m_ref = convert_to_reference_fx(m_ref) + m = _convert_to_reference_decomposed_fx(m) + expected_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2, + } + self.checkGraphModuleNodes( + m, + expected_node_occurrence=expected_occurrence) + # make sure it runs + res_ref = m_ref(*example_inputs) + res = m(*example_inputs) + self.assertEqual(res, res_ref) + + @skipIfNoQNNPACK + def test__convert_to_reference_decomposed_fx_dynamic_quant(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # to avoid reduce_range + with override_quantized_engine("qnnpack"): + m = M().eval() + qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ + .set_object_type(torch.nn.Linear, default_dynamic_qconfig) + example_inputs = (torch.randn(1, 5),) + m = prepare_fx(m, qconfig_mapping, example_inputs) + m(*example_inputs) + m_ref = copy.deepcopy(m) + m_ref = convert_to_reference_fx(m_ref) + m = _convert_to_reference_decomposed_fx(m) + expected_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1, + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1, + } + self.checkGraphModuleNodes( + m, + expected_node_occurrence=expected_occurrence) + # make sure it runs + res_ref = m_ref(*example_inputs) + res = m(*example_inputs) + self.assertEqual(res, res_ref) + + def test__convert_to_reference_decomposed_fx_per_channel_quant(self): + class M(torch.nn.Module): + def forward(self, x, weight, bias): + return F.linear(x, weight, bias) + + m = M().eval() + qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ + .set_object_type(F.linear, default_per_channel_qconfig) + example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,)) + m = prepare_fx(m, qconfig_mapping, example_inputs) + m(*example_inputs) + m_ref = copy.deepcopy(m) + m_ref = convert_to_reference_fx(m_ref) + m = _convert_to_reference_decomposed_fx(m) + expected_occurrence = { + # for input and output activations + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 2, + # for weight + ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel): 1, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel): 1, + } + self.checkGraphModuleNodes( + m, + expected_node_occurrence=expected_occurrence) + # make sure it runs + res_ref = m_ref(*example_inputs) + res = m(*example_inputs) + self.assertEqual(res, res_ref) + + def test_change_backend_config_for_fixed_qparam_ops(self): + """ Making sure we can skip validation of qconfigs for fixedqparam ops based + on BackendConfig + """ + class M(nn.Module): + def __init__(self): + super().__init__() + self.tanh = torch.nn.Tanh() + + def forward(self, x: torch.Tensor): + x = self.tanh(x) + return x + + model = M().eval() + # we set a global default_qconfig, which will be ignored since the backend + # we defined doesn't support anything + # this is to make sure we don't validate the qconfig when BackendConfig does not + # have fixed qparam op related configurations + qconfig_mapping = QConfigMapping().set_global(default_qconfig) + backend_config = BackendConfig() + # make sure this runs + model = prepare_fx( + model, + qconfig_mapping=qconfig_mapping, + example_inputs=(torch.randn(1, 2, 3, 4),), + backend_config=backend_config + ) + + def test_channel_shuffle_lowering(self): + # Three versions of channel shuffle + class M1(torch.nn.Module): + def __init__(self): + super().__init__() + self.op = torch.nn.ChannelShuffle(2) + + def forward(self, x): + return self.op(x + x) + x + + class M2(torch.nn.Module): + def forward(self, x): + return torch.channel_shuffle(x + x, 2) + x + + class M3(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.channel_shuffle(x + x, 2) + x + + x = torch.randn(4, 4, 4, 4) + # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle + model_node_pairs = [ + (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)), + (M2().eval(), ns.call_function(torch.channel_shuffle)), + (M3().eval(), ns.call_function(torch.channel_shuffle)) + ] + for m, node in model_node_pairs: + m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,)) + m_copy = copy.deepcopy(m) + m = convert_fx(m) + m_ref = convert_to_reference_fx(m_copy) + node_occurrence = { + node: 1, + ns.call_function(torch.quantize_per_tensor): 1, + ns.call_method("dequantize"): 1 + } + node_occurrence_ref = { + node: 1, + ns.call_function(torch.quantize_per_tensor): 4, + ns.call_method("dequantize"): 4 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) + + def test_match_pattern_with_multiple_args(self): + """ Test that we can match a pattern that has multiple arguments + Pattern: + shape \ + transpose (observed) -> reshape -> output (observed) -> + + where `reshape` has two arguments + """ + + def _get_pattern_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + dtype_configs = [weighted_op_quint8_dtype_config] + + def root_node_getter(node_pattern): + reshape, transpose, shape = node_pattern + return transpose + + backend_pattern_configs.append( + BackendPatternConfig((torch.reshape, torch.transpose, MatchAllNode)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_root_node_getter(root_node_getter)) + return backend_pattern_configs + + backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs()) + + class M(torch.nn.Module): + def forward(self, x): + x = torch.transpose(x, 0, 1) + x = torch.reshape(x, (-1,)) + return x + + m = M().eval() + qconfig_mapping = QConfigMapping().set_global(default_qconfig) + example_inputs = (torch.randn(1, 3, 3, 3),) + m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) + node_occurrence = { + # one for input of the pattern and one for output of the pattern + ns.call_module(MinMaxObserver): 2 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + @skipIfNoFBGEMM class TestQuantizeFxOps(QuantizationTestCase): def setUp(self): @@ -6677,9 +7024,8 @@ def forward(self, x): M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, is_reference=True) - def test_fixed_qparams_ops_qconfig_error(self): - """ Test that a proper error message is shown when user don't specify the correct - qconfig for fixed qaprams ops + def test_fixed_qparams_ops_wrong_qconfig(self): + """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized. """ class M(torch.nn.Module): def __init__(self): @@ -6699,8 +7045,15 @@ def forward(self, x): data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) qconfig_mapping = QConfigMapping().set_global(default_qconfig) m = M().eval() - with self.assertRaisesRegex(ValueError, "get_default_qconfig_mapping"): - m = prepare_fx(m, qconfig_mapping, data) + node_occurrence = { + ns.call_function(torch.quantize_per_tensor): 0, + ns.call_method("dequantize"): 0, + } + self.checkGraphModeFxOp( + m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping, + expected_node_occurrence=node_occurrence, is_reference=True) + self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid)) + self.assertTrue(isinstance(m.tanh, torch.nn.Tanh)) @skipIfNoFBGEMM def test_general_shape_ops(self): @@ -8085,7 +8438,7 @@ def forward(self, x): inp = torch.randn(5, 5, device=device, requires_grad=True) out_ref = prepared_ref(inp) out = prepared(inp) - torch.testing.assert_allclose(out, out_ref) + torch.testing.assert_close(out, out_ref) # try backward pass labels = torch.randn(5, 5, device=device) @@ -8093,7 +8446,7 @@ def forward(self, x): grad = torch.autograd.grad(loss, [inp]) loss_ref = (out_ref - labels).sum() grad_ref = torch.autograd.grad(loss_ref, [inp]) - torch.testing.assert_allclose(grad[0], grad_ref[0]) + torch.testing.assert_close(grad[0], grad_ref[0]) if 'fbgemm' in torch.backends.quantized.supported_engines: # During the lowering step in convert, fold_weight calls quantized::linear_prepack @@ -8106,7 +8459,7 @@ def forward(self, x): out = converted(inp) out_ref = converted_ref(inp) - torch.testing.assert_allclose(out, out_ref) + torch.testing.assert_close(out, out_ref) if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index 84ab3a723b70f..7726dc04c7111 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -73,7 +73,6 @@ from torch.testing._internal.jit_utils import attrs_with_prefix from torch.testing._internal.jit_utils import get_forward from torch.testing._internal.jit_utils import get_forward_graph -from torch.testing._internal.common_utils import skipIfSlowGradcheckEnv from torch.jit._recursive import wrap_cpp_module @@ -1626,7 +1625,6 @@ def forward(self, x): torch.jit.save(model, b) -@skipIfSlowGradcheckEnv class TestQuantizeJitOps(QuantizationTestCase): """Test graph mode post training static quantization works for individual ops end to end. @@ -2674,6 +2672,7 @@ def forward(self, x): m.graph ) + @override_qengines def test_hardswish(self): class FunctionalHardswish(torch.nn.Module): def __init__(self, inplace): @@ -2698,6 +2697,7 @@ def forward(self, input): m.graph ) + @override_qengines def test_elu(self): class FunctionalELU(torch.nn.Module): def __init__(self, inplace=False): @@ -2714,6 +2714,7 @@ def forward(self, input): m = self.checkGraphModeOp(m, self.img_data_2d, "quantized::elu", tracing) FileCheck().check_not("aten::elu").check_not("aten::elu_").run(m.graph) + @override_qengines def test_layer_norm(self): data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)] for _ in range(2)] layer_norm = torch.nn.LayerNorm([2, 5, 5]) @@ -2723,6 +2724,7 @@ def test_layer_norm(self): ) FileCheck().check_not("aten::layer_norm").run(m.graph) + @override_qengines def test_group_norm(self): data = [[torch.rand((1, 4, 5, 5), dtype=torch.float)] for _ in range(2)] group_norm = torch.nn.GroupNorm(2, 4) @@ -2732,6 +2734,7 @@ def test_group_norm(self): ) FileCheck().check_not("aten::group_norm").run(m.graph) + @override_qengines def test_instance_norm(self): data_1d = [[torch.rand((1, 4, 5), dtype=torch.float)] for _ in range(2)] data_2d = [[torch.rand((1, 4, 5, 1), dtype=torch.float)] for _ in range(2)] diff --git a/test/run_test.py b/test/run_test.py index 35004406d0115..62ce99ae7937a 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -26,6 +26,7 @@ shell, set_cwd, parser as common_parser, + is_slow_gradcheck_env, ) import torch.distributed as dist from torch.multiprocessing import get_context @@ -100,9 +101,6 @@ def skip_test_p(name: str) -> bool: 'test_jit_simple', 'test_jit_string', 'test_kernel_launch_checks', - 'test_metal', - # Right now we have a separate CI job for running MPS - 'test_mps', 'test_nnapi', 'test_segment_reductions', 'test_static_runtime', @@ -115,6 +113,7 @@ def skip_test_p(name: str) -> bool: "distributed/launcher/bin/test_script_is_torchelastic_launched", "distributed/launcher/bin/test_script_local_rank", "distributed/test_c10d_spawn", + "distributed/_tensor/test_dtensor_ops", 'distributions/test_transforms', 'distributions/test_utils', ], @@ -170,6 +169,7 @@ def skip_test_p(name: str) -> bool: "distributed/elastic/events/lib_test", "distributed/elastic/agent/server/test/api_test", "test_deploy", + "distributed/test_c10d_error_logger.py" ] WINDOWS_BLOCKLIST = [ @@ -301,6 +301,7 @@ def skip_test_p(name: str) -> bool: "test_nn", "test_ops", "test_ops_gradients", + "test_ops_fwd_gradients", "test_ops_jit", "test_torch" ] @@ -379,6 +380,23 @@ def skip_test_p(name: str) -> bool: "distributions/test_distributions", ] +# These are just the slowest ones, this isn't an exhaustive list. +TESTS_NOT_USING_GRADCHECK = [ + # Note that you should use skipIfSlowGradcheckEnv if you do not wish to + # skip all the tests in that file, e.g. test_mps + "doctests", + "test_meta", + "test_hub", + "test_fx", + "test_decomp", + "test_cpp_extensions_jit", + "test_jit", + "test_ops", + "test_ops_jit", + "dynamo/test_recompile_ux", + "inductor/test_smoke", + "test_quantization", +] def print_to_stderr(message): print(message, file=sys.stderr) @@ -422,8 +440,11 @@ def run_test( if options.pytest: unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args] elif IS_CI: + ci_args = ["--import-slow-tests", "--import-disabled-tests"] + if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1": + ci_args.append("--rerun-disabled-tests") # use the downloaded test cases configuration, not supported in pytest - unittest_args.extend(["--import-slow-tests", "--import-disabled-tests"]) + unittest_args.extend(ci_args) # Extra arguments are not supported with pytest executable = get_executable_command( @@ -722,33 +743,72 @@ def print_log_file(test: str, file_path: str, failed: bool) -> None: def run_test_ops(test_module, test_directory, options): + if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1": + # When under rerun-disabled-tests mode, run the same tests multiple times to determine their + # flakiness status. Default to 50 re-runs + rerun_options = ["--flake-finder", "--flake-runs=50"] + else: + # When under the normal mode, retry a failed test 2 more times. -x means stop at the first + # failure + rerun_options = ["-x", "--reruns=2"] + + default_unittest_args = [ + "--use-pytest", + "-vv", + "-rfEX" + ] + default_unittest_args.extend(rerun_options) + if 'slow-gradcheck' in os.getenv("BUILD_ENVIRONMENT", ""): + extra_unittest_args = default_unittest_args.copy() # there are a lot of tests that take up a lot of space in slowgrad check, so don't bother parallelizing # it's also on periodic so we don't care about TTS as much - return run_test(test_module, test_directory, copy.deepcopy(options), - extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX'], - ) + return run_test( + test_module, + test_directory, + copy.deepcopy(options), + extra_unittest_args=extra_unittest_args, + ) + return_codes = [] os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) pool = get_context("spawn").Pool(NUM_PROCS) for i in range(NUM_PROCS): - return_code = pool.apply_async(run_test, args=(test_module, test_directory, copy.deepcopy(options)), - kwds={"extra_unittest_args": ["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX', - f'--shard-id={i}', f'--num-shards={NUM_PROCS}', - "-k=not _linalg_cholesky_"], - }) + extra_unittest_args = default_unittest_args.copy() + extra_unittest_args.extend([ + f"--shard-id={i}", + f"--num-shards={NUM_PROCS}", + "-k=not _linalg_cholesky_", + ]) + + return_code = pool.apply_async( + run_test, + args=(test_module, test_directory, copy.deepcopy(options)), + kwds={ + "extra_unittest_args": extra_unittest_args, + }, + ) return_codes.append(return_code) + pool.close() pool.join() - del os.environ['NUM_PARALLEL_PROCS'] + del os.environ["NUM_PARALLEL_PROCS"] for return_code in return_codes: if return_code.get() != 0: return return_code.get() - return_code = run_test(test_module, test_directory, copy.deepcopy(options), - extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX', - "-k=_linalg_cholesky_"], - ) + + extra_unittest_args = default_unittest_args.copy() + extra_unittest_args.extend([ + "-k=_linalg_cholesky_", + ]) + + return_code = run_test( + test_module, + test_directory, + copy.deepcopy(options), + extra_unittest_args=extra_unittest_args, + ) return return_code @@ -765,6 +825,7 @@ def run_test_ops(test_module, test_directory, options): "distributed/test_c10d_common": get_run_test_with_subprocess_fn(), "distributed/test_c10d_spawn_gloo": get_run_test_with_subprocess_fn(), "distributed/test_c10d_spawn_nccl": get_run_test_with_subprocess_fn(), + "distributed/test_c10d_spawn_ucc": get_run_test_with_subprocess_fn(), "distributed/test_store": get_run_test_with_subprocess_fn(), "distributed/test_pg_wrapper": get_run_test_with_subprocess_fn(), "distributed/rpc/test_faulty_agent": get_run_test_with_subprocess_fn(), @@ -772,8 +833,10 @@ def run_test_ops(test_module, test_directory, options): "distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(), "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(), "doctests": run_doctests, + "inductor/test_torchinductor_opinfo": run_test_ops, "test_ops": run_test_ops, "test_ops_gradients": run_test_ops, + "test_ops_fwd_gradients": run_test_ops, "test_ops_jit": run_test_ops, "functorch/test_ops": run_test_ops, } @@ -822,6 +885,14 @@ def parse_args(): "This requires functorch to already be installed." ) ) + parser.add_argument( + "--mps", + "--mps", + action="store_true", + help=( + "If this flag is present, we will only run test_mps and test_metal" + ) + ) parser.add_argument( "-core", "--core", @@ -981,11 +1052,11 @@ def find_test_index(test, selected_tests, find_last_index=False): return found_idx -def exclude_tests(exclude_list, selected_tests, exclude_message=None): +def exclude_tests(exclude_list, selected_tests, exclude_message=None, exact_match=False): for exclude_test in exclude_list: tests_copy = selected_tests[:] for test in tests_copy: - if test.startswith(exclude_test): + if (not exact_match and test.startswith(exclude_test)) or test == exclude_test: if exclude_message is not None: print_to_stderr("Excluding {} {}".format(test, exclude_message)) selected_tests.remove(test) @@ -1031,6 +1102,12 @@ def get_selected_tests(options): # Exclude all functorch tests otherwise options.exclude.extend(FUNCTORCH_TESTS) + if options.mps: + selected_tests = ['test_mps', 'test_metal'] + else: + # Exclude all mps tests otherwise + options.exclude.extend(['test_mps', 'test_metal']) + # process reordering if options.bring_to_front: to_front = set(options.bring_to_front) @@ -1115,6 +1192,11 @@ def get_selected_tests(options): else: print("Found test time stats from artifacts") test_file_times_config = test_file_times[test_config] + if is_slow_gradcheck_env(): + # HACK: hardcode approx test times, so these two don't get put in the same shard + # we can remove this when their actual runtimes are recorded + test_file_times_config["test_ops_fwd_gradients"] = 3600 * 2 + 600 # 2:10 + test_file_times_config["test_ops_gradients"] = 3600 * 2 + 600 # 2:10 shards = calculate_shards(num_shards, selected_tests, test_file_times_config, must_serial=must_serial) _, tests_from_shard = shards[which_shard - 1] @@ -1130,6 +1212,11 @@ def get_selected_tests(options): selected_tests = exclude_tests(TESTS_REQUIRING_LAPACK, selected_tests, "PyTorch is built without LAPACK support.") + if is_slow_gradcheck_env(): + selected_tests = exclude_tests(TESTS_NOT_USING_GRADCHECK, selected_tests, + "Running in slow gradcheck mode, skipping tests " + "that don't use gradcheck.", exact_match=True) + if options.distributed_tests: # Run distributed tests with multiple backends across all shards, one per backend selected_tests.extend(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS.keys()) diff --git a/test/scripts/run_cuda_memcheck.py b/test/scripts/run_cuda_memcheck.py index 10202e416d008..7d882b8c1fff4 100755 --- a/test/scripts/run_cuda_memcheck.py +++ b/test/scripts/run_cuda_memcheck.py @@ -119,7 +119,7 @@ async def run1(coroutine_id): gpuid = coroutine_id % GPUS else: gpu_assignments = args.gpus.split(':') - assert args.nproc == len(gpu_assignments), 'Please specify GPU assignmnent for each process, separated by :' + assert args.nproc == len(gpu_assignments), 'Please specify GPU assignment for each process, separated by :' gpuid = gpu_assignments[coroutine_id] while progress < len(ALL_TESTS): diff --git a/test/test_ao_sparsity.py b/test/test_ao_sparsity.py index ebe89689d6861..3024b3b100d45 100644 --- a/test/test_ao_sparsity.py +++ b/test/test_ao_sparsity.py @@ -14,9 +14,7 @@ from ao.sparsity.test_sparsifier import TestBaseSparsifier # noqa: F401 from ao.sparsity.test_sparsifier import TestWeightNormSparsifier # noqa: F401 from ao.sparsity.test_sparsifier import TestNearlyDiagonalSparsifier # noqa: F401 - -# Pruner -from ao.sparsity.test_pruner import TestBasePruner # noqa: F401 +from ao.sparsity.test_structured_sparsifier import TestBaseStructuredSparsifier # noqa: F401 # Scheduler from ao.sparsity.test_scheduler import TestScheduler # noqa: F401 diff --git a/test/test_autocast.py b/test/test_autocast.py index bfbe46d08b890..1a8263a79f93d 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -1,9 +1,12 @@ # Owner(s): ["module: unknown"] import collections +import unittest + import torch from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists +from torch.utils._python_dispatch import TorchDispatchMode class TestAutocastCPU(TestCase): def setUp(self): @@ -122,6 +125,64 @@ def test_autocast_torch_need_autocast_promote(self): for op, args in self.autocast_lists.torch_need_autocast_promote: self._run_autocast_outofplace(op, args, torch.float32) +@unittest.skipIf(not torch.cuda.is_available(), "requires cuda") +class TestAutocastGPU(TestCase): + def test_cast_cache_is_global(self): + """ + Verifies that the autocast cache is global. This is done by + mocking out cache clearing at the end of the forward pass, + running forward+backward with an explicit call to autocast in the + backward, and verifying that the weight only get cast to float16 once. + """ + + class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w_t): + ctx.save_for_backward(x, w_t) + return torch.nn.functional.linear(x, w_t) + + @staticmethod + def backward(ctx, grad_output): + x, w_t = ctx.saved_tensors + with torch.autocast(device_type='cuda'): + dL_dX = torch.matmul(grad_output, w_t) + dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) + return dL_dX, dL_dW + + data = torch.randn(2, 3).cuda() + weight = torch.nn.Parameter(torch.randn(4, 3).cuda()) + weight_dtype_cast_counter = 0 + + class WeightDTypeCastCounterMode(TorchDispatchMode): + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if ( + func is torch.ops.aten._to_copy.default and + args[0] is weight and + kwargs['dtype'] is torch.float16 + ): + nonlocal weight_dtype_cast_counter + weight_dtype_cast_counter += 1 + return func(*args, **kwargs) + + def __enter__(self): + self.old_clear_cache = torch.clear_autocast_cache + torch.clear_autocast_cache = lambda: None + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.clear_autocast_cache = self.old_clear_cache + return super().__exit__(exc_type, exc_val, exc_tb) + + with WeightDTypeCastCounterMode(): + with torch.autocast(device_type='cuda'): + output = CustomLinear.apply(data, weight) + s = output.sum() + s.backward() + + self.assertEqual(weight_dtype_cast_counter, 1) + + class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): gpu_fast_dtype = torch.get_autocast_gpu_dtype() diff --git a/test/test_autograd.py b/test/test_autograd.py index bcb42449c349f..f7efb9f64216a 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -62,6 +62,83 @@ def graph_desc(fn): class TestAutograd(TestCase): + def test_copy_slices_graph_task_updates(self): + def f1(x, y): + out = x.clone().view(-1) + out += y + return out + + def f2(x, y): + out = x.clone().view(-1) + b = out * 2 + out += y + return out + b + + x = torch.rand(2, requires_grad=True) + y = torch.rand(2, requires_grad=True) + + y_safe = torch._C._functions.DelayedError("Boom!", 1)(y) + + for f in [f1, f2]: + # Ensure that the error Node works + out = f(x, y_safe) + with self.assertRaisesRegex(RuntimeError, "Boom!"): + out.sum().backward() + + out = f(x, y_safe) + with self.assertRaisesRegex(RuntimeError, "Boom!"): + torch.autograd.grad(out.sum(), y) + + # Ensure that if we don't ask for y, it doesn't crash + out = f(x, y_safe) + torch.autograd.grad(out.sum(), x) + + out = f(x, y_safe) + torch.autograd.grad(out.sum(), y_safe) + + out = f(x, y_safe) + torch.autograd.grad(out.sum(), (x, y_safe)) + + # Ensure that we don't run extra view Node + def f3(x, y): + out = x.clone().view(-1) + + def hook(*args): + # This should never be called! + self.assertTrue(False) + out.register_hook(hook) + + b = out + y + out += y + return out + b, b + + out, b = f3(x, y_safe) + torch.autograd.grad(out.sum(), (b, y_safe)) + + + def test_grad_mode_class_decoration(self): + # Decorating class is deprecated and should not be used + with self.assertWarnsRegex(UserWarning, "Decorating classes is deprecated"): + @torch.no_grad() + class Foo(): + pass + + # Decorating functions or methods is fine though + with warnings.catch_warnings(record=True) as w: + @torch.no_grad() + def foo(): + pass + + class Foo2(): + @torch.no_grad() + def __init__(self): + pass + + @torch.no_grad() + def foo(self): + pass + + self.assertEqual(len(w), 0) def test_tensor_grad_warnings(self): dummy = torch.empty(1) @@ -435,7 +512,7 @@ def fn(x): # .backward(inputs=) is OK out = c.sum() - torch.autograd.backward(out, inputs=(a,), retain_graph=True) + torch.autograd.backward(out, inputs=(a, b), retain_graph=True) self.assertEqual(counter[0], 2) # .backward() is OK @@ -467,6 +544,94 @@ def fn(x): with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"): torch._C._will_engine_execute_node(out) + def test_custom_function_setup_context_simple(self): + class MySquare(Function): + @staticmethod + def forward(x): + return x ** 2 + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, = inputs + ctx.save_for_backward(x) + + @staticmethod + def backward(ctx, gO): + x, = ctx.saved_tensors + return gO * 2 * x + + with torch.autograd.function._set_autograd_function_extension_enabled(True): + x = torch.randn([], requires_grad=True) + y = MySquare.apply(x) + gx, = torch.autograd.grad(y, x) + self.assertEqual(gx, 2 * x) + + def test_custom_function_setup_context_multi_output(self): + # Multiple outputs with some non-Tensor outputs. + class MySquare(Function): + @staticmethod + def forward(x): + two_x = x.item() * 2 + return x ** 2, two_x + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, = inputs + _, two_x = outputs + ctx.two_x = two_x + + @staticmethod + @once_differentiable + def backward(ctx, gO, _): + return gO * ctx.two_x + + with torch.autograd.function._set_autograd_function_extension_enabled(True): + x = torch.randn([], requires_grad=True) + y, _ = MySquare.apply(x) + gx, = torch.autograd.grad(y, x) + self.assertEqual(gx, 2 * x) + + def test_custom_function_setup_context_multi_input(self): + class MyReshape(Function): + @staticmethod + def forward(x, shape, scale_forward, scale_backward): + return x.reshape(shape) * scale_forward + + @staticmethod + def setup_context(ctx, inputs, outputs): + x, shape, scale_forward, scale_backward = inputs + ctx.scale_backward = scale_backward + ctx.x_shape = x.shape + + @staticmethod + def backward(ctx, gO): + return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None + + class MyReshapeRef(Function): + @staticmethod + def forward(ctx, x, shape, scale_forward, scale_backward): + ctx.scale_backward = scale_backward + ctx.x_shape = x.shape + return x.reshape(shape) * scale_forward + + @staticmethod + def backward(ctx, gO): + return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None + + def test(x, shape, scale_forward, scale_backward): + y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum() + gx, = torch.autograd.grad(y, x) + + y_expected = MyReshapeRef.apply(x, shape, scale_forward, scale_backward).sum() + gx_expected, = torch.autograd.grad(y_expected, x) + + self.assertEqual(y_expected, y) + self.assertEqual(gx_expected, gx) + + with torch.autograd.function._set_autograd_function_extension_enabled(True): + test(torch.randn(24, requires_grad=True), (3, 8), 7, 11) + test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2) + def test_accumulate_grad(self): grad_output = torch.ones(5, 5) @@ -904,6 +1069,20 @@ def prehook(grad_output): self.assertEqual(pre_counter[0], 4) self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) + def test_autograd_function_extension_feature_flag(self): + try: + prev = torch._C._is_autograd_function_extension_enabled() + + torch._C._set_autograd_function_extension_enabled(True) + state = torch._C._is_autograd_function_extension_enabled() + self.assertTrue(state) + + torch._C._set_autograd_function_extension_enabled(False) + state = torch._C._is_autograd_function_extension_enabled() + self.assertFalse(state) + finally: + torch._C._set_autograd_function_extension_enabled(prev) + def test_grad_fn_prehooks_multiple_outputs(self): # Compute gradients without hooks b = torch.rand(3, 3, requires_grad=True) @@ -1626,7 +1805,7 @@ def coro_enable_grad(n=10): self.assertTrue(torch.is_grad_enabled()) yield (-i if has_raised else i) - except UnrecoverableException: + except UnrecoverableException : self.assertTrue(torch.is_grad_enabled()) raise SecondaryException @@ -2136,7 +2315,7 @@ def backward(ctx, grad_a, grad_b): def test_mark_non_differentiable_none(self): # This used to segfault because MyFunction would send back null # gradients to MulBackward, which is implemented in C++. C++ - # implemented functions expect incoming grad_ouptuts to be non-null. + # implemented functions expect incoming grad_outputs to be non-null. class MyFunction(Function): @staticmethod def forward(ctx, input): @@ -2380,7 +2559,7 @@ def backward(ctx, grad_x): with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"): f = Id() - # # After raising warning, should still return an instance + # After raising warning, should still return an instance self.assertIsInstance(f, Id) x = torch.zeros(1, requires_grad=True) with self.assertRaisesRegex(RuntimeError, "non-static forward method is deprecated"): @@ -2549,7 +2728,7 @@ def test_detach(self): self.assertEqual(x.grad, torch.ones(10, 10) * 2) self.assertEqual(y.grad, torch.ones(10, 10) * 2) - # in-place deatch on a view raises an exception + # in-place detach on a view raises an exception view = x.narrow(0, 1, 4) self.assertRaisesRegex(RuntimeError, 'view', lambda: view.detach_()) @@ -3073,6 +3252,128 @@ def hook(_): self.assertEqual(torch._C._current_graph_task_id(), -1) + def test_current_graph_task_execution_order(self): + predicted = [None] + + def hook(_): + predicted[0] = torch._C._current_graph_task_execution_order() + + def names(nodes): + return ", ".join([node.name().split(' ')[-1] for node in nodes]) + '\n' + + def grad_fns(*tensors): + # or grad accumulator + out = [] + for t in tensors: + if t.requires_grad and t.grad_fn is None: + out.append(t.clone().grad_fn.next_functions[0][0]) + else: + out.append(t.grad_fn) + return out + + actual = [] + + def register_logging_hooks(*tensors): + # register hooks that log the order in which they are called + def get_hook(i): + def hook(t_): + actual.append(tensors[i]) + return hook + + for i, t in enumerate(tensors): + t.register_hook(get_hook(i)) + + # Basic example: single path + t = torch.tensor(1., requires_grad=True).clone().sin().exp() + t.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + t.backward() + self.assertExpectedInline(names(predicted[0]), """\ +ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad +""") + + # We don't exactly follow sequence_nr order + a = torch.tensor(1., requires_grad=True) + b = torch.tensor(2., requires_grad=True) + c = b.sin() + d = a.cos() + out = c * d + register_logging_hooks(a, b, c, d, out) + out.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + out.backward() + self.assertEqual(predicted[0], grad_fns(*actual)) + actual = [] + + # Multiple roots are also OK + a = torch.tensor(1., requires_grad=True) + b = a * 2 + out = b.sin() + out2 = b.cos() + out3 = b.cos() + register_logging_hooks(a, b, out, out2, out3) + out3.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + torch.autograd.grad((out, out3, out2), inputs=(a,)) + self.assertExpectedInline(names(predicted[0]), """\ +CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad +""") + # TODO: Uncomment after update to hooks behavior + # self.assertEqual(predicted[0], grad_fns(*actual)) + actual = [] + + # Case where next node is nullptr + a = torch.tensor(1., requires_grad=True) + b = a * 2 + out = b.sin() + register_logging_hooks(a, b, out) + out.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + out.backward() + self.assertEqual(predicted[0], grad_fns(*actual)) + actual = [] + + # Case where two `inputs` on the same path + a = torch.tensor(1., requires_grad=True) + b = a * 2 + out = b.sin() + register_logging_hooks(a, b, out) + out.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + torch.autograd.grad((out,), inputs=(a, b,)) + self.assertEqual(names(predicted[0]), """\ +SinBackward0, MulBackward0, torch::autograd::AccumulateGrad +""") + # TODO: Uncomment after update to hooks behavior + # self.assertEqual(predicted[0], grad_fns(*actual)) + actual = [] + + # Case where `inputs` specifies a subgraph + a = torch.tensor(1., requires_grad=True) + b = torch.tensor(1., requires_grad=True) + c = a * b + out = c.sin() + register_logging_hooks(a, b, c, out) + out.register_hook(hook) + with torch.autograd.set_multithreading_enabled(False): + torch.autograd.grad((out,), inputs=(a,)) + self.assertEqual(names(predicted[0]), """\ +SinBackward0, MulBackward0, torch::autograd::AccumulateGrad +""") + # TODO: Uncomment after update to hooks behavior + # self.assertEqual(predicted[0], grad_fns(*actual)) + actual = [] + + # Errors when not called in a backward + with self.assertRaisesRegex(RuntimeError, "should only be called during the backward pass"): + torch._C._current_graph_task_execution_order() + + # Errors when context manager not enabled + t = torch.tensor(1., requires_grad=True).clone().sin().exp() + t.register_hook(hook) + with self.assertRaisesRegex(RuntimeError, "expects the current backward to be executed with multithreading disabled"): + t.backward() + def test_profiler(self): x = torch.randn(10, 10) @@ -3192,16 +3493,16 @@ def test_record_function_callbacks(self): foo_event = [event for event in function_events if "foo" in event.name][0] self.assertEqual(foo_event.count, 1) - def test_record_function_new_signatures(self): + def test_record_function_legacy(self): # Test the new _record_function ops work # Note: Remove once record_function uses these directly x = torch.randn(10, 10) with profile(use_kineto=kineto_available()) as p: - record = torch.ops.profiler._record_function_enter_new("bar", None) + handle = torch.ops.profiler._record_function_enter("bar", None) try: y = x * 2 + 4 finally: - torch.ops.profiler._record_function_exit(record) + torch.ops.profiler._record_function_exit(handle) function_events = p.function_events foo_event = [event for event in function_events if "bar" in event.name][0] @@ -3510,19 +3811,6 @@ def test_out_variant_raises_when_inputs_require_grad(self): # we should throw an exception if the output requires grad self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x)) - # TODO: see if this test can be OpInfo'd or moved to diagonal's test suite - def test_diagonal_derivative_requires_grad(self): - # test that the backward requires grad - # we do this is because diagonal_backward uses inplace - # operations and gradgradcheck does not catch whether - # they works as expected (it will succeed even if - # the gradient has requires_grad == False - a = torch.randn(5, 6, requires_grad=True) - b = torch.diagonal(a)**2 - c = b.sum() - d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True) - self.assertTrue(d.requires_grad) - def test_anomaly_detect_nan(self): size = 10 @@ -3962,20 +4250,32 @@ def fn(sparse): check(fast_mode=True) check(fast_mode=False) - @unittest.expectedFailure def test_gradcheck_sparse_csr_input(self): def check(fast_mode): def fn(sparse_csr): return torch.clone(sparse_csr).to_dense() - # Fails because gradcheck can't work with sparse csr inputs yet gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csr().requires_grad_(True), check_sparse_nnz=True, check_batched_grad=False, fast_mode=fast_mode) with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'): gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csr().requires_grad_(True), check_sparse_nnz=False, check_batched_grad=False, fast_mode=fast_mode) - # check(fast_mode=True) # Segmentation fault + # check(fast_mode=True) # RuntimeError: sparse_mask_sparse_csr expects self to be 2D + check(fast_mode=False) + + def test_gradcheck_sparse_csc_input(self): + def check(fast_mode): + def fn(sparse_csc): + return torch.clone(sparse_csc).to_dense() + + gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csc().requires_grad_(True), check_sparse_nnz=True, + check_batched_grad=False, fast_mode=fast_mode) + + with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'): + gradcheck(fn, torch.rand(2, 2, dtype=torch.double).to_sparse_csc().requires_grad_(True), check_sparse_nnz=False, + check_batched_grad=False, fast_mode=fast_mode) + # check(fast_mode=True) # RuntimeError: Expected result Tensor to be of format CSR check(fast_mode=False) def test_gradcheck_nondeterministic(self): @@ -5338,9 +5638,7 @@ def test_grad_fn_attr_bindings(self): a = torch.ones(1, 1, 2, requires_grad=True) out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear") - self.assertIsNone(out.grad_fn._saved_output_size) - self.assertEqual(out.grad_fn._saved_scale_factors, (0.5,)) - self.assertIsInstance(out.grad_fn._saved_scale_factors[0], float) + self.assertEqual(out.grad_fn._saved_scales, 0.5) a = torch.ones(2, 2, requires_grad=True) out = torch.pdist(a, p=1) @@ -5772,6 +6070,23 @@ def run_tests(fn): run_tests(lambda v: v.swapdims_(0, 0)) run_tests(lambda v: v.swapaxes_(0, 0)) + def test_autograd_inplace_view_of_view(self): + x = torch.zeros(2) + with torch.no_grad(): + y = x.view(2) + y.requires_grad_(True) + z = y.view(2) + with self.assertRaisesRegex(RuntimeError, "a view of a view .* is being .* inside the no_grad block"): + z /= 2 + + x = torch.zeros(2) + with torch.inference_mode(): + y = x.view(2) + y.requires_grad_(True) + z = y.view(2) + with self.assertRaisesRegex(RuntimeError, "a view of a view .* is being .* inside the inference_mode"): + z /= 2 + # TODO This is not the correct behavior - # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627 def test_autograd_inplace_views_cross_dtype(self): @@ -6514,6 +6829,25 @@ def forward(self, x): gc.collect() self.assertIsNone(ref_()) + def test_full_backward_hook_double_backward(self): + x = torch.rand(1, requires_grad=True) + y = torch.rand_like(x) + + func = torch.nn.MSELoss() + counter = [0] + + def hook(module, grad_input, grad_output): + counter[0] += 1 + + func.register_full_backward_hook(hook) + + f = func(x, y) + + (gradx_f,) = torch.autograd.grad(f, x, create_graph=True) + self.assertEqual(counter[0], 1) + _ = torch.autograd.grad(gradx_f, x) + # We should not error, and counter should not be incremented + self.assertEqual(counter[0], 1) def test_input_buffer_accum(self): leaf = torch.rand(2, 2, requires_grad=True) @@ -6652,6 +6986,20 @@ def inplace_double(x): # not leaf, not output test(lambda: (1 + torch.randn(5, requires_grad=True)), False) + def test_saved_variable_saved_original_inplace_detach(self): + # Detaching a tensor that is saved input raises + a = torch.tensor(1., requires_grad=True).clone() + b = a.sin() + a.detach_() + with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"): + b.backward() + + # Detaching a tensor that is saved as output is OK + a = torch.tensor(1., requires_grad=True).clone() + b = a.exp() + a.detach_() + b.backward() + def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self): # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks # The saved_original / did_not_save_original distinction corresponds to the `save_original` @@ -6679,8 +7027,8 @@ def pack(x): with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x): a = torch.ones(5, requires_grad=True) - warnings.simplefilter('always') with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') y = a * a # should raise two warnings from a being saved twice self.assertEqual(len(w), 2) @@ -7055,6 +7403,28 @@ def get_out(): err_msg = "RuntimeError: one of the variables needed for gradient computation" self.assertTrue(err_msg in e.output.decode("utf-8")) + def test_view_func_replay(self): + def _assert_match_metadata(a, b): + self.assertEqual(a.size(), b.size()) + self.assertEqual(a.stride(), b.stride()) + self.assertEqual(a.storage_offset(), b.storage_offset()) + + def _test_op(fn, inp, args): + out = fn(inp, *args) + self.assertTrue(out._is_view) + self.assertTrue(out._base is inp) + + new_inp = inp.clone() + _assert_match_metadata(new_inp, inp) + new_out = out._view_func(new_inp) + _assert_match_metadata(new_out, out) + + _test_op(torch.select, torch.rand(2, 2), (0, 0)) + _test_op(torch.as_strided, torch.rand(2, 2), ((4,), (1,))) + _test_op(torch.view_as_complex, torch.rand(2, 2), ()) + _test_op(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat), ()) + + def index_perm_variable(shape, max_indices): if not isinstance(shape, tuple): shape = (shape,) @@ -7794,6 +8164,7 @@ def test_min_max_median_backprops_to_all_values(self, device): self.assertEqual(x.grad.sum(), 1.) self.assertEqual((x.grad == 1 / 3).sum(), 3) + @skipIfMps def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device): # tests that gradients are evenly distributed when there are multiple max/min values # tested here instead of adding a SampleInput as the backward for this case is non-differentiable for gradgrad @@ -7809,6 +8180,7 @@ def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device): gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False) + @skipIfMps def test_scatter_index_reduce_prod_gradgrad_error(self, device): # test that double backward raises an error for the case where 2 zeros in src # are scattered to the same position in self @@ -8240,8 +8612,7 @@ def test_unused_output_device(self, devices): outputs = Broadcast.apply(list(range(len(devices))), x) y = outputs[-1] * 2 y.sum().backward() - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(x.grad, torch.ones(5, 5) * 2) + self.assertEqual(x.grad, torch.ones(5, 5) * 2) @deviceCountAtLeast(2) def test_backward_device(self, devices): @@ -8624,6 +8995,7 @@ def do_test(): self.assertNotWarn(do_test) + @skipIfMps def test_to_r_to_c(self, device): def do_test(): inp_r = torch.randn(3, 2, dtype=torch.double, device=device, @@ -8654,6 +9026,184 @@ def test_warning_in_backward(self, device): with self.assertWarnsRegex(UserWarning, "Warn from backward"): b.backward() +class TestAllowMutationOnSaved(TestCase): + def assertClonedLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.cloned.items())), n) + + def assertTIDMapLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n) + + def test_basic(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + out = (b**2).sum() + b.sin_() + out.sum().backward() + return a.grad + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertTrue(torch.allclose(a * 2, da)) + self.assertClonedLenEqual(ctx, 0) + + def test_views(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + c = b.view_as(b) + out = (b**2).sum() # How does this work? + c.sin_() + out.sum().backward() + return a.grad + + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, da)) + + def test_save_base_and_modify_view(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:1] + out = b**2 + # modify the view + c *= 10 + # self.assertClonedLenEqual(ctx, 1) + out.sum().backward() + self.assertClonedLenEqual(ctx, 0) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_save_view_modify_base(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:] + out = (c**2).sum() + b *= 2 + out.backward() + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_double_backward(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + out = (b**2).sum() + b.sin_() + torch.autograd.grad(out, a, create_graph=True) + da, = torch.autograd.grad(out, a, create_graph=True) + d2a, = torch.autograd.grad(da.sum(), a) + + self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a)) + self.assertClonedLenEqual(ctx, 0) + + def test_saved_but_not_anymore(self): + # Make sure we don't clone if the tensor was once saved, but + # by the time we do in-place, it is no longer saved + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + out = (a**2).sum() + self.assertTIDMapLenEqual(ctx, 1) + self.assertClonedLenEqual(ctx, 0) + out.backward() + a.sin_() + self.assertClonedLenEqual(ctx, 0) + out = (a**2).sum() + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del out + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_many_times(self): + # We should only clone once + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del b, c + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_different_versions(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + a.sin_() + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 2) + del b + self.assertClonedLenEqual(ctx, 1) + del c + self.assertClonedLenEqual(ctx, 0) + + def test_with_math_views(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + a.sin_() + out.backward() + + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + # in this case, it is no longer a view it seems + b.sin_() + out.backward() + + def test_with_out_variant(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1.], requires_grad=True) + b = torch.tensor([1.]) + c = torch.tensor([2.]) + out = a * b + self.assertTIDMapLenEqual(ctx, 1) + torch.sin(c, out=b) + self.assertClonedLenEqual(ctx, 1) + out.backward() + self.assertClonedLenEqual(ctx, 0) + + def test_backward_out_of_context(self): + # Out of context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + # Different context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + def test_disallow_nesting(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + msg = "allow_mutation_on_saved_tensors contexts cannot be nested" + with self.assertRaisesRegex(RuntimeError, msg): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + pass class TestAutogradInferenceMode(TestCase): def _is_inference_tensor(self, tensor): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index abcbb493342bf..8ffab2daa6e28 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -491,9 +491,6 @@ def test_type_promotion(self, device, op): make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs ) - make_lhs_scalar_tensor = partial( - make_tensor, (), device='cpu', **op.lhs_make_tensor_kwargs - ) make_rhs_scalar_tensor = partial( make_tensor, (), device='cpu', **op.rhs_make_tensor_kwargs ) @@ -782,17 +779,14 @@ def _supported(dtypes): ) self.assertEqual(result.dtype, expected_dtype) - # scalar int x scalar float + # scalar x scalar # Note: result dtype is default float type - # TODO: FIXME: re-enable this, scalar x scalar type promotion is currently broken - # https://github.com/pytorch/pytorch/issues/76801 - # if op.supports_two_python_scalars and _supported((torch.long, torch.float32)): - # lhs_i_scalar = 1 - # rhs_f_scalar = 2. - - # result = op(lhs_i_scalar, rhs_f_scalar) - # expected_dtype = torch.get_default_dtype() if not op.always_returns_bool else torch.bool - # self.assertEqual(result.dtype, expected_dtype) + if op.supports_two_python_scalars and _supported((torch.long, torch.float32)): + rhs_f_scalar = 2. + for lhs in (1, 1.): + result = op(lhs, rhs_f_scalar) + expected_dtype = torch.get_default_dtype() if not op.always_returns_bool else torch.bool + self.assertEqual(result.dtype, expected_dtype) # TODO: move to error input test @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 2f505553859fe..77ed19a36381a 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -1,6 +1,10 @@ # Owner(s): ["module: cpp-extensions"] +from itertools import repeat import os +import re +import sys +from typing import Union import unittest import torch.testing._internal.common_utils as common @@ -10,6 +14,16 @@ import torch.backends.cudnn import torch.utils.cpp_extension +if sys.version_info >= (3, 8): + from typing import get_args, get_origin +else: + def get_args(tp): + return tp.__args__ + + def get_origin(tp): + if hasattr(tp, "__origin__"): + return tp.__origin__ + try: import pytest HAS_PYTEST = True @@ -133,6 +147,99 @@ def test_cuda_dlink_libs(self): test = cuda_dlink.add(a, b) self.assertEqual(test, ref) + +class TestPybindTypeCasters(common.TestCase): + """Pybind tests for ahead-of-time cpp extensions + + These tests verify the types returned from cpp code using custom type + casters. By exercising pybind, we also verify that the type casters work + properly. + + For each type caster in `torch/csrc/utils/pybind.h` we create a pybind + function that takes no arguments and returns the type_caster type. The + second argument to `PYBIND11_TYPE_CASTER` should be the type we expect to + receive in python, in these tests we verify this at run-time. + """ + @staticmethod + def expected_return_type(func): + """ + Our Pybind functions have a signature of the form `() -> return_type`. + """ + # Imports needed for the `eval` below. + from typing import List, Tuple # noqa: F401 + + return eval(re.search("-> (.*)\n", func.__doc__).group(1)) + + def check(self, func): + val = func() + expected = self.expected_return_type(func) + origin = get_origin(expected) + if origin is list: + self.check_list(val, expected) + elif origin is tuple: + self.check_tuple(val, expected) + else: + self.assertIsInstance(val, expected) + + def check_list(self, vals, expected): + self.assertIsInstance(vals, list) + list_type = get_args(expected)[0] + for val in vals: + self.assertIsInstance(val, list_type) + + def check_tuple(self, vals, expected): + self.assertIsInstance(vals, tuple) + tuple_types = get_args(expected) + if tuple_types[1] is ...: + tuple_types = repeat(tuple_types[0]) + for val, tuple_type in zip(vals, tuple_types): + self.assertIsInstance(val, tuple_type) + + def check_union(self, funcs): + """Special handling for Union type casters. + + A single cpp type can sometimes be cast to different types in python. + In these cases we expect to get exactly one function per python type. + """ + # Verify that all functions have the same return type. + union_type = set(self.expected_return_type(f) for f in funcs) + assert len(union_type) == 1 + union_type = union_type.pop() + self.assertIs(Union, get_origin(union_type)) + expected_types = set(get_args(union_type)) + for func in funcs: + val = func() + for tp in expected_types: + if isinstance(val, tp): + expected_types.remove(tp) + break + else: + raise AssertionError(f"{val} is not an instance of {expected_types}") + self.assertFalse(expected_types, f"Missing functions for types {expected_types}") + + def test_pybind_return_types(self): + functions = [ + cpp_extension.get_complex, + cpp_extension.get_device, + cpp_extension.get_generator, + cpp_extension.get_intarrayref, + cpp_extension.get_memory_format, + cpp_extension.get_storage, + cpp_extension.get_symfloat, + cpp_extension.get_symintarrayref, + cpp_extension.get_tensor, + ] + union_functions = [ + [cpp_extension.get_symint, cpp_extension.get_symint_symbolic], + ] + for func in functions: + with self.subTest(msg=f"check {func.__name__}"): + self.check(func) + for funcs in union_functions: + with self.subTest(msg=f"check {[f.__name__ for f in funcs]}"): + self.check_union(funcs) + + class TestORTTensor(common.TestCase): def test_unregistered(self): a = torch.arange(0, 10, device='cpu') diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index e4b1e9e550873..2ead8d32ca179 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -15,7 +15,7 @@ import torch.backends.cudnn import torch.utils.cpp_extension from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME -from torch.testing._internal.common_utils import gradcheck, skipIfSlowGradcheckEnv +from torch.testing._internal.common_utils import gradcheck TEST_CUDA = torch.cuda.is_available() and CUDA_HOME is not None @@ -38,7 +38,6 @@ def remove_build_path(): shutil.rmtree(default_build_root) # There's only one test that runs gracheck, run slow mode manually -@skipIfSlowGradcheckEnv class TestCppExtensionJIT(common.TestCase): """Tests just-in-time cpp extensions. Don't confuse this with the PyTorch JIT (aka TorchScript). diff --git a/test/test_cuda.py b/test/test_cuda.py index c8876d5fbb0cd..40eaaa97a3b7e 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -15,6 +15,7 @@ import tempfile import threading import unittest +import warnings from random import randint import torch @@ -595,7 +596,7 @@ def test_serialization_array_with_storage(self): self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage)) - self.assertTrue(isinstance(q_copy[3]._storage, torch.UntypedStorage)) + self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage)) q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) @@ -1576,6 +1577,38 @@ def test_multinomial_invalid_probs_cuda(self): self._spawn_test_multinomial_invalid_probs_cuda([1., -inf, 1.]) self._spawn_test_multinomial_invalid_probs_cuda([1., 1., nan]) + @staticmethod + def _mute_init(): + os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno()) + + def _spawn_method(self, method, arg): + ctx = torch.multiprocessing.get_context("spawn") + with ctx.Pool(1, initializer=self._mute_init) as pool: + errors = pool.map(method, [arg]) + for e in errors: + if 'device-side assert triggered' not in str(e): + self.fail(e) + + @staticmethod + def _test_index_bounds_cuda(idx): + x = torch.arange(10, device="cuda") + try: + y = x[torch.tensor([idx])] + return f"x[torch.tensor([{idx})]={y}" + except RuntimeError as err: + return err + + @slowTest + @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ + don't support multiprocessing with spawn start method") + @skipIfRocm + def test_index_out_of_bounds_exception_cuda(self): + test_method = TestCuda._test_index_bounds_cuda + # Test in-bound access works fine + self.assertEqual(test_method(1), "x[torch.tensor([1)]=tensor([1], device='cuda:0')") + # Test that indexing out of bounds causes assert + self._spawn_method(test_method, 11) + @slowTest @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_huge_index(self): @@ -3030,18 +3063,22 @@ def forward(ctx, a, b): def backward(ctx, grad): self.assertTrue(torch.is_autocast_enabled()) a, b = ctx.saved_tensors - return grad.mm(b.t()), a.t().mm(grad) + a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) + self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) + return a_grad, b_grad mymm = MyMM.apply x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) - with torch.cuda.amp.autocast(): - output = mymm(x, y) - self.assertTrue(output.dtype is torch.float16) - loss = output.sum() - loss.backward() + dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) + for dtype in dtypes: + with torch.cuda.amp.autocast(dtype=dtype): + output = mymm(x, y) + self.assertTrue(output.dtype is dtype) + loss = output.sum() + loss.backward() def test_autocast_custom_cast_inputs(self): class MyMM(torch.autograd.Function): @@ -3255,6 +3292,18 @@ def test_graph_capture_simple(self): self.assertTrue(b.sum().item() == 11000.) + @unittest.skipIf((not TEST_CUDA) or + TEST_WITH_ROCM or + int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") + def test_graph_warn_if_has_zero_nodes(self): + with warnings.catch_warnings(record=True) as caught: + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + with torch.cuda.stream(s): + g.capture_begin() + g.capture_end() + self.assertTrue(any("The CUDA Graph is empty" in str(w.message) for w in caught)) + @unittest.skipIf((not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs") @@ -4763,10 +4812,14 @@ def power2_div(size, div_factor): nelems = 21 * 1024 * 1024 nbytes = 4 * nelems # floats are 4 bytes + nelems_big = 100 * 1024 * 1024 + nbytes_big = 4 * nelems_big # floats are 4 bytes + start_mem = torch.cuda.memory_stats()[key] torch.cuda.memory._set_allocator_settings("") x = torch.rand(nelems, device='cuda') + # test roundup_power2_divisions single value syntax reg_mem = torch.cuda.memory_stats()[key] torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4") y = torch.rand(nelems, device='cuda') @@ -4788,6 +4841,26 @@ def power2_div(size, div_factor): reg_mem = torch.cuda.memory_stats()[key] self.assertTrue(reg_mem - start_mem == nbytes) + # roundup_power2_divisions knob array syntax + torch.cuda.memory.empty_cache() + torch.cuda.memory._set_allocator_settings( + "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]") + start_mem = torch.cuda.memory_stats()[key] + w = torch.rand(nelems, device='cuda') + + pow2_div8_mem = torch.cuda.memory_stats()[key] + if not TEST_CUDAMALLOCASYNC: + # not supported with the cudaMallocAsync backend + self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8)) + + torch.cuda.memory.empty_cache() + start_mem = torch.cuda.memory_stats()[key] + v = torch.rand(nelems_big, device='cuda') + + pow2_div2_mem = torch.cuda.memory_stats()[key] + if not TEST_CUDAMALLOCASYNC: + # not supported with the cudaMallocAsync backend + self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2)) with self.assertRaises(RuntimeError): torch.cuda.memory._set_allocator_settings("foo:1,bar:2") diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 270ca89764ed1..9f1b73cf9ed41 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -20,19 +20,16 @@ ChainDataset, ConcatDataset, DataLoader, - DataLoader2, Dataset, IterableDataset, IterDataPipe, Subset, TensorDataset, - communication, _utils ) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split from torch.utils.data.datapipes.iter import IterableWrapper -from torch.utils.data.datapipes.map import SequenceWrapper from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, @@ -2122,9 +2119,7 @@ def test_default_collate_dtype(self): arr = [1.1, 2.3, -0.9] collated = _utils.collate.default_collate(arr) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(collated, torch.tensor(arr)) - self.assertEqual(collated.dtype, torch.float64) + self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64)) arr = [True, False] collated = _utils.collate.default_collate(arr) @@ -2222,114 +2217,6 @@ def test_excessive_thread_creation_warning(self): r"excessive worker creation might get DataLoader running slow or even freeze"): dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) -# Define a global function for testing purposes since local functions cannot be pickled -def identity(x): - return x - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2(TestCase): - @skipIfNoDill - def test_basics(self): - # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order - # of traversing workers - dp = IterableWrapper(list(range(1000))).sharding_filter() - dl = DataLoader(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2 = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2, parallelism_mode='thread') - self.assertEqual(list(dl), list(dl2)) - self.assertEqual(list(dl), list(dl2_threading)) - - class Sorter(IterDataPipe): - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - return iter(sorted(self.datapipe)) - - def test_shuffle(self): - items = list(range(1000)) - dp = IterableWrapper(items).sharding_filter().shuffle() - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=False) - self.assertEqual(items, list(dl)) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2_EventLoop(TestCase): - @skipIfNoDill - def test_basic_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - it = list(range(100)) - numbers_dp = IterableWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp) - - process.start() - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - - actual = list(local_datapipe) - clean_me(process, req_queue, res_queue) - - self.assertEqual(list(range(100)), actual) - - @skipIfNoDill - def test_basic_mapdatapipe_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - input_len = 100 - it = list(range(input_len)) - numbers_dp = SequenceWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline( - numbers_dp) - - process.start() - - # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - actual = list(local_datapipe) - self.assertEqual([(x, x) for x in range(100)], actual) - - # Functional Test: raise Error when input - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - with self.assertRaisesRegex(IndexError, "out of bound"): - local_datapipe[1000] - - # __len__ Test: Ensure that the correct length is returned - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - self.assertEqual(input_len, len(local_datapipe)) - - clean_me(process, req_queue, res_queue) - class IntegrationTestDataLoaderDataPipe(TestCase): r""" @@ -2827,6 +2714,9 @@ def __getitem__(self, index): @unittest.skipIf(IS_WINDOWS, "Needs fork") +@unittest.skipIf( + TEST_WITH_ASAN, + "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492") class TestConvAfterFork(TestCase): # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565 def test_conv_after_fork(self): diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 49d2ba1ee79cb..b5de6a5f4006c 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -33,7 +33,7 @@ import torch.utils.data.datapipes as dp import torch.utils.data.graph import torch.utils.data.graph_settings -from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings +from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings, skipIfTorchDynamo from torch.utils.data import ( DataLoader, DataChunk, @@ -54,6 +54,7 @@ ) from torch.utils.data.datapipes.dataframe import CaptureDataFrame from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper +from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES try: import dill @@ -219,6 +220,7 @@ def test_dir(self): for api in ['open', 'read', 'close']: self.assertTrue(api in s) + @skipIfTorchDynamo def test_api(self): fd = TestStreamWrapper._FakeFD("") wrap_fd = StreamWrapper(fd) @@ -2352,6 +2354,9 @@ def __iter__(self): for i in range(self.size): yield i + def __len__(self): + return self.size + class TestGraph(TestCase): class CustomIterDataPipe(IterDataPipe): @@ -2663,6 +2668,40 @@ def test_simple_sharding(self): items += list(sharded_dp) self.assertEqual(sorted(all_items), sorted(items)) + def test_sharding_groups(self): + def construct_sharded_pipe(): + sharding_pipes = [] + dp = NumbersDataset(size=90) + dp = dp.sharding_filter(sharding_group_filter=SHARDING_PRIORITIES.DISTRIBUTED) + sharding_pipes.append(dp) + dp = dp.sharding_filter(sharding_group_filter=SHARDING_PRIORITIES.MULTIPROCESSING) + sharding_pipes.append(dp) + dp = dp.sharding_filter(sharding_group_filter=300) + sharding_pipes.append(dp) + return dp, sharding_pipes + + dp, sharding_pipes = construct_sharded_pipe() + + for pipe in sharding_pipes: + pipe.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DISTRIBUTED) + pipe.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) + pipe.apply_sharding(3, 1, sharding_group=300) + + actual = list(dp) + expected = [17, 47, 77] + self.assertEquals(expected, actual) + self.assertEquals(3, len(dp)) + + dp, _ = construct_sharded_pipe() + dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) + with self.assertRaises(Exception): + dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) + + dp, _ = construct_sharded_pipe() + dp.apply_sharding(5, 3, sharding_group=SHARDING_PRIORITIES.MULTIPROCESSING) + with self.assertRaises(Exception): + dp.apply_sharding(2, 1, sharding_group=SHARDING_PRIORITIES.DEFAULT) + def test_sharding_length(self): numbers_dp = dp.iter.IterableWrapper(range(13)) sharded_dp0 = numbers_dp.sharding_filter() diff --git a/test/test_decomp.py b/test/test_decomp.py index dbc754147858f..b947f72c586bd 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -8,6 +8,7 @@ from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from torch.utils._mode_utils import no_dispatch +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( is_iterable_of_tensors, TestCase, @@ -15,13 +16,13 @@ suppress_warnings, TEST_WITH_ASAN, run_tests, - skipIfSlowGradcheckEnv, skipIfTorchDynamo, ) from torch.testing._internal.common_device_type import ( onlyNativeDeviceTypes, ops, instantiate_device_type_tests, + onlyCUDA, ) from torch.testing._internal.common_methods_invocations import op_db from torch._dispatch.python import enable_python_dispatcher @@ -160,8 +161,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs) (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2, (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, - (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-6, - (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-6, + (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, + (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, + (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, + (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, + (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-5, + (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-5, (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2, (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1, } @@ -199,10 +204,20 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): (torch.float32, torch.ops.aten.grid_sampler_2d.default) : (7e-6, 3e-5), # Exceeds tolerances on CUDA, likely due to fma (torch.float32, torch.ops.aten.mv.default) : (1e-5, 3e-5), - (torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 1e-6), (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5), + (torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 5e-4), + (torch.float64, torch.ops.aten.upsample_bicubic2d.default) : (1e-5, 5e-4), + # The decomposition is TOO correct. It computes everything in int64, so sometimes + # there's an off-by-one error. See + # https://github.com/pytorch/pytorch/issues/81996 + # https://github.com/pytorch/pytorch/issues/82230 + (torch.int8, torch.ops.aten.linspace.default) : (0, 1), + (torch.uint8, torch.ops.aten.linspace.default) : (0, 1), + (torch.int16, torch.ops.aten.linspace.default) : (0, 1), + (torch.int32, torch.ops.aten.linspace.default) : (0, 1), + (torch.int64, torch.ops.aten.linspace.default) : (0, 1), } - if (test_dtype, op) in tol_table: + if (decomp.dtype, op) in tol_table: rtol, atol = tol_table[(decomp.dtype, op)] else: rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) @@ -292,12 +307,17 @@ def normalize_op_input_output(f, sample, requires_grad=True): # See https://github.com/pytorch/pytorch/issues/81669 (None, None, "nn.functional.relu6"), (None, None, "meshgrid"), + # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) + (None, None, "diag"), + # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 + ("cpu", torch.bfloat16, "_softmax_backward_data"), + (None, None, "norm"), + # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise) + (None, None, "native_batch_norm"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { # Decomposed backward formula is not as precise - ("cuda", torch.float16, "nn.functional.embedding"), - ("cuda", torch.bfloat16, "nn.functional.embedding"), ("cpu", torch.bfloat16, "nn.functional.hardswish"), ("cuda", torch.float16, "nn.functional.cross_entropy"), } @@ -351,7 +371,6 @@ def test_unsupported(t): return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs)) -@skipIfSlowGradcheckEnv class TestDecomp(TestCase): longMessage = True @@ -374,6 +393,19 @@ def test_quick(self, device, dtype, op): def test_comprehensive(self, device, dtype, op): self.do_cross_ref(device, dtype, op, run_all=True) + def test_uniform(self, device): + size = (2, 3, 4, 5) + dtype = torch.float32 + x = make_tensor(size, dtype=dtype, device=device) + low = 0.3 + high = 0.9 + + torch.manual_seed(123) + ref = torch.ops.aten.uniform(x, low, high) + torch.manual_seed(123) + res = torch._decomp.decompositions.uniform(x, low=low, high=high) + self.assertEqual(ref, res) + @skipIfTorchDynamo("Test does not work with TorchDynamo") def do_cross_ref(self, device, dtype, op, *, run_all): test_keys = [ @@ -560,6 +592,47 @@ def test_contiguous_log_softmax(self, device): instantiate_device_type_tests(DecompContiguousTests, globals()) +class DecompAmpTests(TestCase): + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipIfCrossRef + @onlyCUDA + def test_amp_batch_norm_backward(self): + device = "cuda" + grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + weight = torch.randn((2,), dtype=torch.float32, device=device) + rmean = torch.randn((2,), dtype=torch.float32, device=device) + rvar = torch.randn((2,), dtype=torch.float32, device=device) + mean = torch.randn((0,), dtype=torch.float32, device=device) + + ref = torch.ops.aten.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True]) + res = torch._decomp.decompositions.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True]) + for (a, b) in zip(ref, res): + self.assertEqual(a.stride(), b.stride()) + self.assertEqual(a.dtype, b.dtype) + + +instantiate_device_type_tests(DecompAmpTests, globals()) if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 87b1dd9aa8217..06230f86943a0 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4,16 +4,25 @@ from torch._C import _disabled_torch_function_impl import torch.fx import torch.nn.functional as F -from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, IS_WINDOWS +from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, \ + IS_WINDOWS, parametrize, instantiate_parametrized_tests import unittest import torch import operator import itertools +import random +import contextlib +import math +import builtins +import atexit import io +import os from torch.utils._pytree import tree_map +from torch.fx.experimental import symbolic_shapes from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float +from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node from torch.utils._python_dispatch import TorchDispatchMode +from torch import SymInt aten = torch.ops.aten @@ -112,12 +121,9 @@ def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): raise RuntimeError(f"operator {func_overload} not supported") -def create_symbolic_tensor(name, arg, shape_env, storage_offset=0): - sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg) - return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset) - - -CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1)) +def create_symbolic_tensor(name, arg, shape_env): + sym_shapes, sym_strides, sym_storage_offset = shape_env.create_symbolic_sizes_strides_storage_offset(arg) + return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset) def create_symint(shape_env, i): return shape_env.create_symintnode(shape_env.create_symbol(i)) @@ -156,8 +162,8 @@ def test_roundtrip(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - self.assertTrue(not isinstance(x.shape[0], PySymInt)) - self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) + self.assertTrue(not isinstance(x.shape[0], SymNode)) + self.assertTrue(isinstance(x.shape[0], SymInt)) self.assertTrue(x.shape[0] == 5) self.assertTrue(x.shape[1] == 4) @@ -165,23 +171,17 @@ def test_roundtrip(self): self.assertTrue(x.size()[0], 5) self.assertTrue(x.size()[1], 4) - self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) + self.assertTrue(isinstance(x.size()[1], SymInt)) self.assertTrue(x.size()[2] == 3) self.assertTrue(x.size(0) == 5) self.assertTrue(x.size(1) == 4) self.assertTrue(x.size(2) == 3) - self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) + self.assertTrue(isinstance(x.size(2), SymInt)) - offset = create_symint(shape_env, 2) - y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS)) - self.assertTrue(y.storage_offset() == 2) - - offset = 2 - z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(z.storage_offset(), int)) - self.assertTrue(z.storage_offset() == 2) + y = create_symbolic_tensor("x", torch.randn(5, 4, 3)[1:], shape_env) + self.assertTrue(isinstance(y.storage_offset(), SymInt)) + self.assertTrue(y.storage_offset() == 12) @skipIfNoSympy def test_binary(self): @@ -267,7 +267,7 @@ def test_symint_vargs(self): def test_stride(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) - self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS) + self.assertIsInstance(x.stride()[0], SymInt) @skipIfNoSympy def test_size_expressions(self): @@ -279,18 +279,29 @@ def test_size_expressions(self): else: result = expand_x + expand_x - gt_op = shape_env.guards[0][0] + gt_op, _bt = shape_env.guards[-1] self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) + @skipIfNoSympy + def test_numel(self): + shape_env = ShapeEnv() + x = create_symbolic_tensor("x", torch.randn(5), shape_env) + self.assertIsInstance(x.numel(), torch.SymInt) + self.assertIsInstance(torch.numel(x), torch.SymInt) + + x = torch.rand(3, 3) + self.assertIsInstance(x.numel(), int) + self.assertIsInstance(torch.numel(x), int) + @skipIfNoSympy def test_int_to_float(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) r = sym_float(x.shape[0]) - self.assertTrue(isinstance(r, torch.SymFloatNode)) + self.assertIsInstance(r, torch.SymFloat, msg=type(r)) @skipIfNoSympy def test_aten_ops(self): @@ -320,15 +331,53 @@ def test_meta_symint(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) r = torch.empty(a0, device='meta') - self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) + self.assertIsInstance(r.shape[0], SymInt) @skipIfNoSympy def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) - self.assertEqual(a0.guard_int(), 2) - self.assertEqual(str(shape_env.guards[0][0]), "s0") - self.assertEqual(shape_env.guards[0][1], 2) + self.assertEqual(guard_int(a0), 2) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s1, 2)""") + + @skipIfNoSympy + def test_sym_int(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 5) + r = sym_int(a0) + self.assertEqual(r, 5) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s1, 5)""") + + a1 = create_symint(shape_env, 7) + r = sym_int(a1 / 2) + self.assertEqual(guard_int(r), 3) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s3/2), 3)""") + + a2 = create_symint(shape_env, -3) + r = sym_int(a2 / 2) + self.assertEqual(guard_int(r), -1) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(ceiling(-s5/2), -1)""") + + @skipIfNoSympy + def test_sym_sqrt(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 4) + r = sym_sqrt(a0) + self.assertEqual(r, 2) + self.assertIsInstance(r, torch.SymFloat, msg=type(r)) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s1), 2)""") + + @skipIfNoSympy + def test_sym_floor(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 5) + r = math.floor(a0 / 2) + self.assertEqual(r, 2) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s1/2), 2)""") @skipIfNoSympy def test_int_conversion(self): @@ -348,7 +397,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): assert func == torch.ops.aten.add.Tensor nonlocal sym_int_encountered - sym_int_encountered = kwargs["alpha"] is a0 + # WARNING: do not do identity tests on the outer + # SymInt/SymFloat, they are NOT STABLE + sym_int_encountered = kwargs["alpha"].node is a0.node kwargs["alpha"] = 0 return func(*args) @@ -373,20 +424,174 @@ def f(a, b): self.assertExpectedInline(mock_stdout.getvalue().strip(), """\ class f(torch.nn.Module): - def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]): + def forward(self, a_1: f32[s1, s3], b_1: f32[s8, s3]): # No stacktrace found for following nodes - sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0) - sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0) - add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None - sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1) - sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None - add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None - new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None + sym_size: Sym(s1) = torch.ops.aten.sym_size(a_1, 0) + sym_size_1: Sym(s8) = torch.ops.aten.sym_size(b_1, 0) + add: Sym(s1 + s8) = sym_size + sym_size_1; sym_size = sym_size_1 = None + sym_size_2: Sym(s3) = torch.ops.aten.sym_size(a_1, 1) + sym_size_3: Sym(s3) = torch.ops.aten.sym_size(b_1, 1); b_1 = None + add_1: Sym(2*s3) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None + new_empty: f32[s1 + s8, 2*s3] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None - getitem: f32[s0 + s2, 2*s1] = native_dropout[0] - getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None + getitem: f32[s1 + s8, 2*s3] = native_dropout[0] + getitem_1: b8[s1 + s8, 2*s3] = native_dropout[1]; native_dropout = None return (getitem, getitem_1)""") # noqa: B950 +# This environment variable controls whether or not we print expected failure +# lists at the end of a test suite run. The intended usage looks like this: +# +# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_dynamic_shapes.py -k TestSymNumberMagicMethods`. +# 2. Given the printed xfail list, add them to the set expected_failure_sym_magic_methods. +COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1' + +seen_failed = [] +def print_seen(): + out = [] + for key, reason in seen_failed: + # Make sure the generated line is lint clean + msg = f" {key}, # {reason}" + eol = msg.find("\n") + if eol != -1: + msg = msg[:eol] + out.append(msg[:120]) + + print("expected_failure_sym_magic_methods = {") + print("\n".join(out)) + print("}") + +if COLLECT_EXPECT: + atexit.register(print_seen) + +expected_failure_sym_magic_methods = { + ('floordiv', 'SymFloat', 'float'), # Cannot convert complex to float + ('floordiv', 'float', 'SymFloat'), # Cannot convert complex to float + ('floordiv', 'SymFloat', 'SymFloat'), # Cannot convert complex to float + ('floordiv', 'SymFloat', 'int'), # Scalars are not close! + ('floordiv', 'float', 'SymInt'), # Scalars are not close! + ('floordiv', 'SymFloat', 'SymInt'), # Scalars are not close! + ('floordiv', 'SymInt', 'float'), # Cannot convert complex to float + ('floordiv', 'int', 'SymFloat'), # Cannot convert complex to float + ('floordiv', 'SymInt', 'SymFloat'), # Cannot convert complex to float +} + +@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") +class TestSymNumberMagicMethods(TestCase): + def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + # Helper function + seed_node = (create_symint(shape_env, 1) / 1.).get_pyobj() + + def get_sym_inp(inp): + if isinstance(inp, int): + return torch.SymInt(to_node(seed_node, inp)) + else: + return torch.SymFloat(to_node(seed_node, inp)) + + def maybe_xfail(inp1, inp2): + key = (fn, type(inp1).__name__, type(inp2).__name__) + if COLLECT_EXPECT: + @contextlib.contextmanager + def context(): + try: + yield + except (TypeError, AssertionError) as e: + seen_failed.append((key, str(e))) + return context() + + if key in expected_failure_sym_magic_methods: + return self.assertRaises((TypeError, AssertionError)) + else: + return contextlib.nullcontext() + + # These functions might return plain int/float + has_valid_downcast = fn in ["min", "max"] + if fn in symbolic_shapes.magic_methods_on_builtins: + lambda_apply = getattr(builtins, fn) + elif fn in symbolic_shapes.magic_methods_on_math: + lambda_apply = getattr(math, fn) + elif fn in symbolic_shapes.magic_methods_on_submodule: + lambda_apply = getattr(symbolic_shapes, fn) + else: + lambda_apply = getattr(operator, fn) + + if fn in symbolic_shapes.always_float_magic_methods: + tp = "float" + elif fn in symbolic_shapes.always_int_magic_methods: + tp = "int" + elif is_unary_fn: + tp = "float" if isinstance(inp1, float) else "int" + else: + tp = "float" if any(isinstance(i, float) for i in [inp1, inp2]) else "int" + + def guard_fn(v): + try: + if fn in symbolic_shapes.always_bool_magic_methods: + return bool(v) + else: + return getattr(v.node, f"guard_{tp}")("", 0) + except Exception as e: + if has_valid_downcast: + return v + else: + raise e + + # Get reference result + with maybe_xfail(inp1, inp2): + if is_unary_fn: + ref_out = lambda_apply(inp1) + else: + ref_out = lambda_apply(inp1, inp2) + + # Symified first arg + sym_inp1 = get_sym_inp(inp1) + with maybe_xfail(sym_inp1, inp2): + if is_unary_fn: + out = lambda_apply(sym_inp1) + else: + out = lambda_apply(sym_inp1, inp2) + self.assertEqual(guard_fn(out), ref_out) + + if is_unary_fn: + return + + # Symified second arg + sym_inp2 = get_sym_inp(inp2) + with maybe_xfail(inp1, sym_inp2): + out = lambda_apply(inp1, sym_inp2) + self.assertEqual(guard_fn(out), ref_out) + + # Symified both args + with maybe_xfail(sym_inp1, sym_inp2): + out = lambda_apply(sym_inp1, sym_inp2) + self.assertEqual(guard_fn(out), ref_out) + + + @parametrize("fn", list(symbolic_shapes.magic_methods.keys())) + @parametrize("first_type", ["int", "float"]) + @parametrize("second_type", ["int", "float"]) + def test_method(self, fn, first_type, second_type): + if first_type == "float": + self.skipTest(f"{fn} is not a float magic method") + + is_unary_fn = fn in symbolic_shapes.unary_magic_methods + # Second argument is ignored for unary function. So only run for one type + if is_unary_fn and second_type == "float": + self.skipTest(f"{fn} is unary and already tested") + + # We could pass int/float directly for types but then the + # mangled test name is bad + inp1 = random.random() * 2.5 + if first_type == "int": + inp1 = int(inp1) + inp2 = random.random() * 2.5 + if second_type == "int": + inp2 = int(inp2) + + shape_env = ShapeEnv() + + self._do_test(fn, inp1, inp2, shape_env, is_unary_fn) + +instantiate_parametrized_tests(TestSymNumberMagicMethods) if __name__ == '__main__': run_tests() diff --git a/test/test_dynamic_shapes.py.bak b/test/test_dynamic_shapes.py.bak deleted file mode 100644 index 19c77fe4d7ab0..0000000000000 --- a/test/test_dynamic_shapes.py.bak +++ /dev/null @@ -1,391 +0,0 @@ -# -*- coding: utf-8 -*- -# Owner(s): ["oncall: jit"] - -from torch._C import _disabled_torch_function_impl -import torch.fx -import torch.nn.functional as F -from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo -import unittest -import torch -import operator -import itertools -import io -from torch.utils._pytree import tree_map -from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, PySymInt, sym_float -from torch.utils._python_dispatch import TorchDispatchMode - -aten = torch.ops.aten - -try: - import sympy - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - - -meta_funcs = {} - - -def register_meta(op): - def decorator(f): - def add_func(op): - meta_funcs[op] = f - tree_map(add_func, op) - return f - return decorator - - -@register_meta([aten.add.Tensor, aten.sub.Tensor]) -def binary_meta(a, b): - return a.new_empty(a.shape) - - -@register_meta(aten.cat.default) -def cat_meta(tensors, dim=0): - concat_length = 0 - shape = tensors[0].shape - for tensor in tensors: - for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): - if idx == dim: - concat_length = concat_length + length - else: - assert length == common_length - new_shape = list(shape) - new_shape[dim] = concat_length - return tensors[0].new_empty(new_shape) - - -@register_meta([aten.narrow_copy.default]) -def narrow_copy_symint_meta(a, dim, start, length, **kwargs): - shape = [] - for i, x in enumerate(a.shape): - if i == dim: - shape.append(length) - else: - shape.append(x) - return a.new_empty(tuple(shape)) - - -@register_meta([aten.expand.default]) -def expand_symint_meta(a, size, implicit=False): - return a.new_empty(size) - - -def create_contiguous(shape): - strides = [1] - for dim in reversed(shape[:-1]): - strides.append(dim * strides[-1]) - return list(reversed(strides)) - - -class FakeSymbolicTensor(torch.Tensor): - @staticmethod - def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0): - # TODO: this is wrong in general - sym_stride = create_contiguous(sym_shape) - r = torch.Tensor._make_wrapper_subclass( - cls, sym_shape, - sym_stride, storage_offset, - dtype=dtype, layout=layout, requires_grad=requires_grad, - device=device, - ) - return r - - __torch_function__ = _disabled_torch_function_impl - - def new_empty(self, shape): - return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device) - - @classmethod - def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): - if func_overload in meta_funcs: - return meta_funcs[func_overload](*args, **kwargs) - - if func_overload == torch.ops.aten.new_empty.default: - self = args[0] - shape = args[1] - return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device) - - raise RuntimeError(f"operator {func_overload} not supported") - - -def create_symbolic_tensor(name, arg, shape_env, storage_offset=0): - sym_shapes, sym_strides = shape_env.create_symbolic_sizes_strides(arg) - return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, storage_offset) - - -CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1)) - -def create_symint(shape_env, i): - return shape_env.create_symintnode(shape_env.create_symbol(i)) - -@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") -class TestPySymInt(TestCase): - - @skipIfNoSympy - def test_arith_ops(self): - shape_env = ShapeEnv() - symints = [] - for i in range(2, 5): - symints.append((i, create_symint(shape_env, i))) - - ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod] - - for op in ops: - for args in itertools.permutations(symints, 2): - if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0): - self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])) - - - @skipIfNoSympy - def test_reverse_arith_ops(self): - shape_env = ShapeEnv() - - a = create_symint(shape_env, 2) - self.assertTrue(5 // a == 5 // 2) - - a = create_symint(shape_env, 2) - self.assertTrue(5 * a == 5 * 2) - - - @skipIfNoSympy - def test_roundtrip(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - - self.assertTrue(not isinstance(x.shape[0], PySymInt)) - self.assertTrue(isinstance(x.shape[0], CPP_SYMINT_CLASS)) - - self.assertTrue(x.shape[0] == 5) - self.assertTrue(x.shape[1] == 4) - self.assertTrue(x.shape[2], 3) - - self.assertTrue(x.size()[0], 5) - self.assertTrue(x.size()[1], 4) - self.assertTrue(isinstance(x.size()[1], CPP_SYMINT_CLASS)) - self.assertTrue(x.size()[2] == 3) - - self.assertTrue(x.size(0) == 5) - self.assertTrue(x.size(1) == 4) - self.assertTrue(x.size(2) == 3) - self.assertTrue(isinstance(x.size(2), CPP_SYMINT_CLASS)) - - offset = create_symint(shape_env, 2) - y = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(y.storage_offset(), CPP_SYMINT_CLASS)) - self.assertTrue(y.storage_offset() == 2) - - offset = 2 - z = create_symbolic_tensor("z", torch.randn(5, 4, 3), shape_env, offset) - self.assertTrue(isinstance(z.storage_offset(), int)) - self.assertTrue(z.storage_offset() == 2) - - @skipIfNoSympy - def test_binary(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env) - - z = x + y - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # broadcasting - y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) - z = x + y - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - @skipIfNoSympy - def test_symint_args(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env) - LAST_DIM = 2 - z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM]) - self.assertTrue(z.shape[2] == y.shape[2]) - - # arithmetic expr with two symints - z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM]) - self.assertTrue(z.shape[2] == 2) - - # arithmetic expr with a symint and python int - z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) - self.assertTrue(z.shape[2] == 2) - - @skipIfNoSympy - def test_symint_vargs(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) - - # varargs - z = y.expand(x.shape[0], y.shape[1], x.shape[2]) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # shape list - z = y.expand((x.shape[0], y.shape[1], x.shape[2])) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints - z = y.expand(x.shape[0], y.shape[1], 3) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints in a list - z = y.expand((x.shape[0], y.shape[1], 3)) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python symints and ints - z = y.expand(5, y.shape[1], x.shape[2]) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - # mixed python ints and symints in a list - z = y.expand((5, y.shape[1], x.shape[2])) - self.assertTrue(z.shape[0] == 5) - self.assertTrue(z.shape[1] == 4) - self.assertTrue(z.shape[2] == 3) - - z = y.expand((y.shape[1],)) - z = y.expand(y.shape[1]) - - @skipIfNoSympy - def test_stride(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) - self.assertIsInstance(x.stride()[0], CPP_SYMINT_CLASS) - - @skipIfNoSympy - def test_size_expressions(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - expand_x = x.expand(x.shape[0], x.shape[0]) - if expand_x.shape[0] > 3: - result = expand_x + expand_x - else: - result = expand_x + expand_x - - gt_op = shape_env.guards[0][0] - self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) - self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) - self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) - self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) - - @skipIfNoSympy - def test_int_to_float(self): - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - r = sym_float(x.shape[0]) - self.assertTrue(isinstance(r, torch.SymFloatNode)) - - @skipIfNoSympy - def test_aten_ops(self): - - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5), shape_env) - torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0]) - - shape_env = ShapeEnv() - x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) - torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]]) - - def test_fx_trace_intlist(self): - class CustomModule(torch.nn.Module): - def forward(self, x): - bs, c, h, w = x.shape - return F.pad(x, (0, w % 2, 0, h % 2, 0, 0)) - - m = CustomModule() - x = torch.rand(1, 3, 4, 4) - # should not TypeError: pad(): argument 'pad' (position 2) must be - # tuple of ints, not tuple - torch.fx.symbolic_trace(m) - - @skipIfNoSympy - def test_meta_symint(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - r = torch.empty(a0, device='meta') - self.assertIsInstance(r.shape[0], CPP_SYMINT_CLASS) - - @skipIfNoSympy - def test_guard_int(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - self.assertEqual(a0.guard_int(), 2) - self.assertEqual(str(shape_env.guards[0][0]), "s0") - self.assertEqual(shape_env.guards[0][1], 2) - - @skipIfNoSympy - def test_int_conversion(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0)) - - @skipIfNoSympy - def test_symint_as_scalar(self): - shape_env = ShapeEnv() - a0 = create_symint(shape_env, 2) - - sym_int_encountered = False - - class TestSymInt(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - assert func == torch.ops.aten.add.Tensor - - nonlocal sym_int_encountered - sym_int_encountered = kwargs["alpha"] is a0 - kwargs["alpha"] = 0 - return func(*args) - - x = torch.rand([4, 4]) - with TestSymInt(): - y = torch.add(x, x, alpha=a0) - - self.assertTrue(sym_int_encountered) - - @skipIfNoSympy - @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) - def test_print_readable_with_symints(self, mock_stdout): - def f(a, b): - dim0 = a.shape[0] + b.shape[0] - dim1 = a.shape[1] + b.shape[1] - d = a.new_empty(dim0, dim1) - d = torch.ops.aten.native_dropout(d, 0.5, train=True) - return d - - fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3)) - fx_g.print_readable() - - self.assertExpectedInline(mock_stdout.getvalue().strip(), """\ -class f(torch.nn.Module): - def forward(self, a_1: f32[t0.size(0),t0.size(1)], b_1: f32[t1.size(0),t0.size(1)]): - # No stacktrace found for following nodes - sym_size: Sym(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0) - sym_size_1: Sym(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0) - add: Sym(t0.size(0) + t1.size(0)) = sym_size + sym_size_1; sym_size = sym_size_1 = None - sym_size_2: Sym(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1) - sym_size_3: Sym(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1); b_1 = None - add_1: Sym(2*t0.size(1)) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None - new_empty: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); a_1 = add = add_1 = None - native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None - getitem: f32[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[0] - getitem_1: b8[t0.size(0) + t1.size(0),2*t0.size(1)] = native_dropout[1]; native_dropout = None - return (getitem, getitem_1)""") # noqa: B950 - - -if __name__ == '__main__': - run_tests() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 0d81cdf10f82f..86c1884d50b03 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -2,6 +2,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfCrossRef, skipIfRocm import torch +import torch._dynamo import itertools import numpy as np from torch.testing._internal.jit_utils import RUN_CUDA @@ -11,6 +12,7 @@ FakeTensorConverter, DynamicOutputShapeException, ) +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.testing import FileCheck from torch import nn import unittest @@ -18,6 +20,7 @@ import contextlib import weakref import copy +from torch.utils._pytree import tree_flatten class FakeTensorTest(TestCase): def checkType(self, t, device_str, size): @@ -135,6 +138,18 @@ def test_mode(self): self.assertTrue(isinstance(out, FakeTensor)) + def check_function_with_fake(self, fn): + out = fn() + with torch._subclasses.FakeTensorMode(): + out_fake = fn() + + for a, b in zip(tree_flatten(out), tree_flatten(out_fake)): + if not isinstance(a, FakeTensor): + self.assertTrue(not isinstance(b, FakeTensor)) + continue + + prims.utils.compare_tensor_meta(a, b, check_strides=True) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_non_kwarg_device(self): with FakeTensorMode(): @@ -144,6 +159,13 @@ def test_non_kwarg_device(self): z = x.to(torch.device("cuda")) self.assertEqual(z.device.type, "cuda") + def test_non_overlapping_stride_zero(self): + def foo(): + x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3)) + return x.half() + + self.check_function_with_fake(foo) + def test_fake_mode_error(self): x = torch.rand([4, 4]) @@ -194,6 +216,25 @@ def test_randperm(self): y1 = torch.randperm(5, device="cpu") prims.utils.compare_tensor_meta(y, y1) + def test_print_in_fake_mode(self): + x = torch.zeros(2) + # does not fail + with FakeTensorMode(): + out = str(x) + assert "FakeTensor" not in out + + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_upsample_bilinear_small_channels(self): + out = [] + mode = FakeTensorMode() + for i, context in enumerate([contextlib.nullcontext, lambda: mode]): + with context(): + arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='cuda') + unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0) + out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False)) + + self.assertTrue(out[1].is_contiguous()) + self.checkMetaProps(out[0], out[1]) @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cpu_fallback(self): @@ -352,7 +393,7 @@ def test_data_dependent_operator(self): self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) def checkMetaProps(self, t1, t2): - prims.utils.compare_tensor_meta(t1, t2) + prims.utils.compare_tensor_meta(t1, t2, check_strides=True) @skipIfCrossRef def test_deepcopy(self): @@ -505,7 +546,7 @@ def test_memoized_conversion_from_meta(self): x = torch.rand(2, 2).to(device="meta") mode = FakeTensorMode() converter = mode.fake_tensor_converter - self.assertTrue(converter(mode, x, "cpu") is converter(mode, x, "cpu")) + self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu")) def test_separate_tensor_storages_view(self): x = torch.rand(2, 2, 2) @@ -554,10 +595,10 @@ def test_dead_key(self): converter = FakeTensorConverter() x_conv = converter(mode, x) self.assertEqual(len(converter.tensor_memo), 1) - self.assertEqual(len(converter.meta_converter.tensor_memo), 1) + x_conv2 = converter(mode, x) + assert x_conv2 is x_conv del x self.assertEqual(len(converter.tensor_memo), 0) - self.assertEqual(len(converter.meta_converter.tensor_memo), 0) def test_no_active_mode(self): with FakeTensorMode() as mode: @@ -657,5 +698,62 @@ def test_like_ops(self): op = self.get_aten_op(schema) self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors) +class FakeTensorPropTest(TestCase): + def test_fake_tensor_prop_on_nn_module(self): + class ToyNnModuleWithParameters(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = torch.nn.Linear(4, 3) + self.layer2 = torch.nn.Linear(3, 2) + + def forward(self, value): + value = self.layer1(value) + value = torch.relu(value) + value = self.layer2(value) + return value + + model = ToyNnModuleWithParameters() + value = torch.randn(5, 4) + # Convert nn.Module to GraphModule so that FakeTensorProp runs. + graph_model = torch.fx.symbolic_trace(model, (value,)) + # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode + # + # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule + # with parameters and buffers. + with FakeTensorMode() as fake_tensor_mode: + + def to_fake_tensor(x): + if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor): + return fake_tensor_mode.from_tensor(x) + return x + + fake_parameters_and_buffers = { + k: to_fake_tensor(v) + for k, v in itertools.chain( + graph_model.named_parameters(), graph_model.named_buffers() + ) + } + with torch.nn.utils.stateless._reparametrize_module( + graph_model, fake_parameters_and_buffers + ): + # This case uses the **same** fake tensor mode to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The result should be correct. + result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value) + self.assertTrue(isinstance(result, FakeTensor)) + self.assertEqual(result.shape, (5, 2)) + # This case uses the **different** fake tensor modes to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The following code should fail. + failed = False + try: + FakeTensorProp(graph_model).propagate(value) + except AssertionError: + # AssertionError: tensor's device must be `meta`, got cpu instead + failed = True + self.assertTrue(failed) + if __name__ == "__main__": run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index 3e2921ed73da7..13e0e6ebc9cf1 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -31,6 +31,7 @@ complex(1.0 - random.random(), 1.0 - random.random()), ) + def getScalarLists(N): return ( ("int", [random.randint(0, 9) + 1 for _ in range(N)]), @@ -41,8 +42,10 @@ def getScalarLists(N): ("mixed", [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(N - 4)]), ) + _BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator" + class RegularFuncWrapper: def __init__(self, func): @@ -88,6 +91,7 @@ def __call__(self, inputs, is_cuda, is_fastpath, **kwargs): # note(mkozuki): inplace foreach functions are void functions. return inputs[0] if self._is_inplace else actual + class TestForeach(TestCase): @property @@ -159,7 +163,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis inputs = [ opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath), [ - make_tensor((N - i , 1), device=device, dtype=dtype, noncontiguous=not is_fastpath) for i in range(N) + make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=not is_fastpath) for i in range(N) ], ] self._binary_test(dtype, op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False) @@ -248,7 +252,7 @@ def test_binary_op_scalarlist_slowpath(self, device, dtype, op): for _, scalarlist in getScalarLists(N): self._test_binary_op_scalarlist(device, dtype, op, N, scalarlist, False, False) - def _pointwise_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, values=None): + def _pointwise_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, values=None, custom_values_err=None): ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs try: actual = op(inputs, self.is_cuda, is_fastpath) @@ -262,13 +266,18 @@ def _pointwise_test(self, dtype, op, ref, inputs, is_fastpath, is_inplace, *, va try: actual = op(inputs + [values], self.is_cuda, is_fastpath) except RuntimeError as e: - with self.assertRaisesRegex(type(e), re.escape(str(e))): - ref(ref_inputs, values=values) + # Match with error messages from regular non-foreach reference if no + # custom error message was provided. + if custom_values_err is None: + with self.assertRaisesRegex(type(e), re.escape(str(e))): + ref(ref_inputs, values=values) + else: + self.assertEqual(re.escape(str(e)), re.escape(custom_values_err)) else: expected = ref(ref_inputs, values=values) self.assertEqual(expected, actual) - def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fastpath, *, values=None): + def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fastpath, *, values=None, custom_values_err=None): n_expected_cudaLaunchKernels = N if disable_fastpath else 1 op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, n_expected_cudaLaunchKernels) inputs = [ @@ -276,8 +285,10 @@ def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fast opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath), opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath), ] - self._pointwise_test(dtype, op, ref, inputs, is_fastpath, is_inplace=False, values=values) - self._pointwise_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, is_inplace=True, values=values) + self._pointwise_test(dtype, op, ref, inputs, is_fastpath, is_inplace=False, + values=values, custom_values_err=custom_values_err) + self._pointwise_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, + is_inplace=True, values=values, custom_values_err=custom_values_err) # Tests of implicit broadcasting inputs = [ @@ -289,9 +300,11 @@ def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fast make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=not is_fastpath) for i in range(N) ], ] - self._pointwise_test(dtype, op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False, values=values) + self._pointwise_test(dtype, op, ref, inputs, is_fastpath and disable_fastpath, + is_inplace=False, values=values, custom_values_err=custom_values_err) self._pointwise_test( - dtype, inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath, is_inplace=True, values=values) + dtype, inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath, + is_inplace=True, values=values, custom_values_err=custom_values_err) @skipMeta @ops(foreach_pointwise_op_db) @@ -302,9 +315,24 @@ def test_pointwise_op_fastpath(self, device, dtype, op): self._test_pointwise_op(device, dtype, op, N, True, disable_fastpath) for scalar in Scalars: self._test_pointwise_op(device, dtype, op, N, True, disable_fastpath, values=scalar) - for _, scalarlist in getScalarLists(N): + for case, scalarlist in getScalarLists(N): self._test_pointwise_op( device, dtype, op, N, True, disable_fastpath, values=scalarlist) + self._test_pointwise_op( + device, dtype, op, N, True, disable_fastpath, values=torch.tensor(scalarlist)) + self._test_pointwise_op( + device, dtype, op, N, True, disable_fastpath, values=torch.tensor(scalarlist)[0], + custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.") + if device == "cuda": + self._test_pointwise_op( + device, dtype, op, N, True, disable_fastpath, values=torch.tensor(scalarlist, device="cuda"), + custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.") + self._test_pointwise_op( + device, dtype, op, N, True, disable_fastpath, values=torch.tensor(scalarlist)[:2], + custom_values_err=f"Expected length of scalars to match input of length {len(scalarlist)} but got 2 instead.") + self._test_pointwise_op( + device, dtype, op, N, True, disable_fastpath, values=torch.tensor([[0, 1], [2, 3]])[:, 1], + custom_values_err="Expected scalars to be contiguous.") @ops(foreach_pointwise_op_db) def test_pointwise_op_slowpath(self, device, dtype, op): @@ -313,9 +341,11 @@ def test_pointwise_op_slowpath(self, device, dtype, op): self._test_pointwise_op(device, dtype, op, N, False, False) for scalar in Scalars: self._test_pointwise_op(device, dtype, op, N, False, False, values=scalar) - for _, scalarlist in getScalarLists(N): + for case, scalarlist in getScalarLists(N): self._test_pointwise_op( device, dtype, op, N, False, False, values=scalarlist) + self._test_pointwise_op( + device, dtype, op, N, False, False, values=torch.tensor(scalarlist)) # note(mkozuki): fastpath test uses dtypes which fastpath implementation supports. # To confirm the dtypes of `OpInfo` cover the dtypes that the function support, @@ -476,7 +506,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op): runtime_error = e self.assertIsNone(runtime_error) - @skipIfTorchDynamo("Different error msgs, TODO") @ops(foreach_binary_op_db, dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_binary_op_list_error_cases(self, device, dtype, op): diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 2eb79c73cc0bd..ec1a0caa804c4 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -1,11 +1,16 @@ # Owner(s): ["module: codegen"] import torch -from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO +from contextlib import nullcontext +from torch.testing._internal.common_utils import ( + TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, IS_WINDOWS, + xfail_inherited_tests +) from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_map, tree_map_only, tree_flatten from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.reinplace import reinplace +from torch._dispatch.python import enable_crossref_functionalize, enable_python_dispatcher import unittest @@ -20,60 +25,79 @@ def are_aliased(x, y): # We can unify testing and use functionalize() here instead # if/when functorch moves into core. -# This is basically a crappy version of `functionalize()` for single-tensor-arg inputs. -def _functionalize(f, *, reapply_views: bool): - def wrapped(a): - input_functional = torch._to_functional_tensor(a) - torch._enable_functionalization(reapply_views=reapply_views) - try: - out = f(input_functional) - finally: - torch._disable_functionalization() - torch._sync(input_functional) - inpt_new = torch._from_functional_tensor(input_functional) - if inpt_new is not a: - # Existing deficiency in functionalize(): - # we don't correctly mutate input metadata (yet?) - if inpt_new.shape == a.shape: - a.copy_(inpt_new) - tree_map(torch._sync, out) - out_unwrapped = tree_map(torch._from_functional_tensor, out) - return out_unwrapped +# This is basically a crappy version of `functionalize()`. +def _functionalize(f, *, reapply_views: bool, crossref: bool): + def to_fun(t: torch.Tensor): + func_t = torch._to_functional_tensor(t) + func_t.requires_grad = t.requires_grad + return func_t + + def wrapped(*inputs): + ctx = nullcontext() + if crossref: + ctx = enable_crossref_functionalize() + with ctx: + inputs_functional = tree_map_only(torch.Tensor, to_fun, inputs) + torch._enable_functionalization(reapply_views=reapply_views) + try: + out = f(*inputs_functional) + finally: + torch._disable_functionalization() + flat_inputs, _ = tree_flatten(inputs) + flat_inputs_functional, _ = tree_flatten(inputs_functional) + for inpt, input_functional in zip(flat_inputs, flat_inputs_functional): + torch._sync(input_functional) + inpt_new = torch._from_functional_tensor(input_functional) + if inpt_new is not inpt: + # Existing deficiency in functionalize(): + # we don't correctly mutate input metadata (yet?) + if inpt_new.shape == inpt.shape: + inpt.copy_(inpt_new) + tree_map(torch._sync, out) + out_unwrapped = tree_map(torch._from_functional_tensor, out) + return out_unwrapped return wrapped @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457") class TestFunctionalization(TestCase): - def get_logs(self, func, inpt, *, reapply_views=False, run_reinplace=False): - inpt_clone = inpt.clone() - traced_f = make_fx(_functionalize(func, reapply_views=reapply_views))(inpt) + crossref = False + + def get_logs(self, func, *inpts, reapply_views=False, run_reinplace=False): + inpts_clone = tree_map_only(torch.Tensor, torch.clone, inpts) + traced_f = make_fx(_functionalize(func, reapply_views=reapply_views, crossref=self.crossref))(*inpts) if run_reinplace: - traced_f = reinplace(traced_f, inpt_clone) + traced_f = reinplace(traced_f, *inpts_clone) return traced_f.code - def assert_functionalization(self, func, inpt, *, reapply_views=False, mutated_input_metadata=False): - input_clone = inpt.clone() - input_clone2 = inpt.clone() - input_clone3 = inpt.clone() + def assert_functionalization(self, func, *inpts, reapply_views=False, mutated_input_metadata=False): + clones1 = tree_map_only(torch.Tensor, torch.clone, inpts) + clones2 = tree_map_only(torch.Tensor, torch.clone, inpts) + clones3 = tree_map_only(torch.Tensor, torch.clone, inpts) # Compare outputs (and mutated inputs), with and without functionalization. - out_ref = func(inpt) - out_functional = _functionalize(func, reapply_views=reapply_views)(input_clone) + out_ref = func(*inpts) + out_functional = _functionalize(func, reapply_views=reapply_views, crossref=self.crossref)(*clones1) + # The reinplacing pass is only valid to run with reapply_views=True. - functional_func = make_fx(_functionalize(func, reapply_views=True))(input_clone2) - reinplace_func = reinplace(make_fx(_functionalize(func, reapply_views=True))(input_clone2), input_clone2) + functional_func = make_fx(_functionalize(func, reapply_views=True, crossref=self.crossref))(*clones2) + reinplace_func = reinplace(functional_func, *clones2) # NOTE: for now, need to pass in fresh inputs here, because make_fx # will directly mutate the inputs that you trace with. # Once this is fixed we can clean this up. - out_reinplace = reinplace_func(input_clone3) + out_reinplace = reinplace_func(*clones3) # functionalize() deficiency: input metadata mutations aren't propagated properly, # so we just need to skip checks here for the tests that exercise that. if not mutated_input_metadata: - self.assertEqual(inpt, input_clone) # input mutations should still occur - self.assertEqual(inpt, input_clone3) + flat_inpts, _ = tree_flatten(inpts) + flat_clones1, _ = tree_flatten(clones1) + flat_clones3, _ = tree_flatten(clones3) + for inpt, input_clone, input_clone3 in zip(flat_inpts, flat_clones1, flat_clones3): + self.assertEqual(inpt, input_clone) # input mutations should still occur + self.assertEqual(inpt, input_clone3) # Handle tests with multi-tensor outputs if isinstance(out_ref, tuple): @@ -101,6 +125,76 @@ def f(x): return z2 self.assert_functionalization(f, torch.ones(4)) + def test_freeze(self): + def f(x): + y = x.clone() + z = y[0] + torch._freeze_functional_tensor(y) + x.add_(1) + self.assertRaises(RuntimeError, lambda: y.add_(1)) + self.assertRaises(RuntimeError, lambda: z.add_(1)) + return z + + _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(3, 3)) + + def test_copy_stride_mismatch(self): + def f(x): + y = torch.empty_strided((2, 2), (5, 1)) + y.copy_(x) + return y + + r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2)) + self.assertEqual(r.stride(), (5, 1)) + + def test_view_clone_view_inplace(self): + def f(input): + shape = [1, 1024, 128, 128] + input_reshaped = input.view(shape) + out = input_reshaped.clone() + r = out.view(input.shape) + r.relu_() + return r + + def g(x): + loss = f(x).sum() + from torch._functorch.aot_autograd import setup_stacktrace_preservation_hooks + import torch.fx.traceback as fx_traceback + setup_stacktrace_preservation_hooks([loss.grad_fn]) + with fx_traceback.override_stack_trace(): + loss.backward() + return x.grad + + with torch.autograd.detect_anomaly(check_nan=False): + logs = self.get_logs(g, torch.ones(16, 64, 128, 128, requires_grad=True)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1): + view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 1024, 128, 128]); arg0_1 = None + clone = torch.ops.aten.clone.default(view_copy); view_copy = None + view_copy_1 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]) + relu = torch.ops.aten.relu.default(view_copy_1); view_copy_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(clone, [16, 64, 128, 128]); clone = None + sum_1 = torch.ops.aten.sum.default(relu) + ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None + expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None + view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None + new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) + copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_3); new_empty_strided = view_copy_3 = None + view_copy_4 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) + view_copy_5 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) + clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format) + threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None + copy_1 = torch.ops.aten.copy.default(view_copy_5, threshold_backward); view_copy_5 = threshold_backward = None + view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None + detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None + view_copy_7 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None + detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None + return detach_copy_1 + """) # noqa: B950 + def test_simple(self): def f(x): # simple test: 1 view op, 1 inplace op @@ -115,13 +209,13 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]) + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]) mul = torch.ops.aten.mul.Tensor(view_copy_1, view_copy_1) - copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return add """) @@ -130,13 +224,13 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(a_1, [4, 2]) + view = torch.ops.aten.view.default(arg0_1, [4, 2]) add = torch.ops.aten.add.Tensor(view, ones); view = ones = None view_1 = torch.ops.aten.view.default(add, [4, 2]) mul = torch.ops.aten.mul.Tensor(view_1, view_1) - copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None return add """) @@ -155,9 +249,9 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None @@ -169,9 +263,9 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None + view = torch.ops.aten.view.default(arg0_1, [4, 2]); arg0_1 = None empty = torch.ops.aten.empty.memory_format([], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(view, ones); view = ones = None mul = torch.ops.aten.mul.Tensor(add, add); add = None @@ -192,10 +286,10 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) - aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None + aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None getitem = aminmax[0] getitem_1 = aminmax[1]; aminmax = None return getitem @@ -206,10 +300,10 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) empty_1 = torch.ops.aten.empty.memory_format([4], device = device(type='cpu'), pin_memory = False) - aminmax = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None + aminmax = torch.ops.aten.aminmax.default(arg0_1, dim = 0); arg0_1 = None getitem = aminmax[0] getitem_1 = aminmax[1]; aminmax = None return getitem @@ -230,7 +324,7 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): _tensor_constant0 = self._tensor_constant0 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None view_copy = torch.ops.aten.view_copy.default(lift_fresh_copy, [-1]); lift_fresh_copy = None @@ -244,7 +338,7 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): _tensor_constant0 = self._tensor_constant0 lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None view = torch.ops.aten.view.default(lift_fresh_copy, [-1]); lift_fresh_copy = None @@ -263,7 +357,7 @@ def f(x): out = x[functional_tensor, nonfunctional_tensor] return out out = f(torch.ones(2, 2)) - out_functional = _functionalize(f, reapply_views=True)(torch.ones(2, 2)) + out_functional = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2)) self.assertEqual(out, out_functional) def test_inplace_on_non_view(self): @@ -280,11 +374,11 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]) - add = torch.ops.aten.add.Tensor(a_1, ones); ones = None - copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) + add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None + copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None return view_copy_1 """) @@ -294,11 +388,11 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(a_1, [4, 2]) - add = torch.ops.aten.add.Tensor(a_1, ones); ones = None - copy_ = torch.ops.aten.copy_.default(a_1, add); a_1 = None + view = torch.ops.aten.view.default(arg0_1, [4, 2]) + add = torch.ops.aten.add.Tensor(arg0_1, ones); ones = None + copy_ = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = None view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None return view_1 """) @@ -314,15 +408,15 @@ def f(x): -def forward(self, a_1): - _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0) +def forward(self, arg0_1): + _fused_moving_avg_obs_fq_helper_functional = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, arg0_1, 1.0, 0, 1, 0) getitem = _fused_moving_avg_obs_fq_helper_functional[0] getitem_1 = _fused_moving_avg_obs_fq_helper_functional[1] getitem_2 = _fused_moving_avg_obs_fq_helper_functional[2] getitem_3 = _fused_moving_avg_obs_fq_helper_functional[3] getitem_4 = _fused_moving_avg_obs_fq_helper_functional[4] getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None - copy_ = torch.ops.aten.copy_.default(a_1, getitem_5); a_1 = getitem_5 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = None return (getitem, getitem_1) """) # noqa: B950 @@ -337,11 +431,11 @@ def f(x): -def forward(self, a_1): - as_strided_copy = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1) +def forward(self, arg0_1): + as_strided_copy = torch.ops.aten.as_strided_copy.default(arg0_1, [2], [2], 1) add = torch.ops.aten.add.Tensor(as_strided_copy, 1); as_strided_copy = None - as_strided_scatter = torch.ops.aten.as_strided_scatter.default(a_1, add, [2], [2], 1); add = None - copy_ = torch.ops.aten.copy_.default(a_1, as_strided_scatter); a_1 = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None + copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None return as_strided_scatter """) @@ -356,8 +450,8 @@ def f(x): -def forward(self, a_1): - block_diag = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None +def forward(self, arg0_1): + block_diag = torch.ops.aten.block_diag.default([arg0_1, arg0_1]); arg0_1 = None return block_diag """) @@ -372,9 +466,9 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) - cat = torch.ops.aten.cat.default([a_1]); a_1 = None + cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat """) @@ -383,9 +477,9 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): empty = torch.ops.aten.empty.memory_format([0], device = device(type='cpu'), pin_memory = False) - cat = torch.ops.aten.cat.default([a_1]); a_1 = None + cat = torch.ops.aten.cat.default([arg0_1]); arg0_1 = None return cat """) @@ -404,12 +498,12 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) - clone = torch.ops.aten.clone.default(a_1) + clone = torch.ops.aten.clone.default(arg0_1) diagonal_copy = torch.ops.aten.diagonal_copy.default(clone); clone = None add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None - mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None + mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul """) @@ -418,12 +512,12 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) - clone = torch.ops.aten.clone.default(a_1) + clone = torch.ops.aten.clone.default(arg0_1) diagonal = torch.ops.aten.diagonal.default(clone); clone = None add = torch.ops.aten.add_.Tensor(diagonal, ones); diagonal = ones = None - mul = torch.ops.aten.mul.Tensor(a_1, a_1); a_1 = None + mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None return mul """) @@ -441,12 +535,12 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) - diagonal_copy = torch.ops.aten.diagonal_copy.default(a_1) + diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1) add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None - diagonal_scatter = torch.ops.aten.diagonal_scatter.default(a_1, add); add = None - copy_ = torch.ops.aten.copy_.default(a_1, diagonal_scatter); a_1 = None + diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None + copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None return diagonal_scatter """) @@ -465,20 +559,20 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) - split_copy = torch.ops.aten.split_copy.Tensor(a_1, 2) + split_copy = torch.ops.aten.split_copy.Tensor(arg0_1, 2) getitem = split_copy[0] getitem_1 = split_copy[1]; split_copy = None diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None - split_copy_1 = torch.ops.aten.split_copy.Tensor(a_1, 2) + split_copy_1 = torch.ops.aten.split_copy.Tensor(arg0_1, 2) getitem_2 = split_copy_1[0] getitem_3 = split_copy_1[1]; split_copy_1 = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = None - slice_scatter = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None + slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) - copy_ = torch.ops.aten.copy_.default(a_1, slice_scatter); a_1 = slice_scatter = None + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return add """) # noqa: B950 @@ -496,12 +590,12 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) - transpose_copy = torch.ops.aten.transpose_copy.int(a_1, 1, 0) + transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) select_copy = torch.ops.aten.select_copy.int(transpose_copy, 0, 0); transpose_copy = None add = torch.ops.aten.add.Tensor(select_copy, ones); select_copy = ones = None - transpose_copy_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None @@ -523,13 +617,13 @@ def f(x): -def forward(self, a_1): - view_copy = torch.ops.aten.view_copy.default(a_1, [8]) +def forward(self, arg0_1): + view_copy = torch.ops.aten.view_copy.default(arg0_1, [8]) arange = torch.ops.aten.arange.default(4, device = device(type='cpu'), pin_memory = False) arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) index_put = torch.ops.aten.index_put.default(view_copy, [arange], arange_1); view_copy = arange = arange_1 = None view_copy_1 = torch.ops.aten.view_copy.default(index_put, [4, 2]) - copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return index_put """) # noqa: B950 @@ -548,14 +642,14 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]) + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]) add = torch.ops.aten.add.Tensor(view_copy, 1); view_copy = None mul = torch.ops.aten.mul.Tensor(add, 2) div = torch.ops.aten.div.Tensor(mul, 1); mul = None view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None - copy_ = torch.ops.aten.copy_.default(a_1, view_copy_1); a_1 = view_copy_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = None return div """) @@ -573,8 +667,8 @@ def f(x): -def forward(self, a_1): - clone = torch.ops.aten.clone.default(a_1); a_1 = None +def forward(self, arg0_1): + clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None return _to_copy @@ -585,8 +679,8 @@ def forward(self, a_1): -def forward(self, a_1): - clone = torch.ops.aten.clone.default(a_1); a_1 = None +def forward(self, arg0_1): + clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None ge = torch.ops.aten.ge.Scalar(clone, 0); clone = None _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None return _to_copy @@ -621,8 +715,8 @@ def f(x): -def forward(self, a_1): - view_copy = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None +def forward(self, arg0_1): + view_copy = torch.ops.aten.view_copy.default(arg0_1, [4, 2]); arg0_1 = None return view_copy """) @@ -646,36 +740,44 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None view_copy = torch.ops.aten.view_copy.default(add, [8]) - _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None - transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0) + view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None + transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0) unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None getitem = split_copy[0] getitem_1 = split_copy[1]; split_copy = None add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None - select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None - clone = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format) - _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None - view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None - _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None - transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_1, 1, 0); _reshape_alias_copy_1 = None + select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]) + view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None + view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None - _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None - view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_2, [4, 2]); _reshape_alias_copy_2 = None - view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None - _reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None - select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_3, 0, 0); _reshape_alias_copy_3 = None - add_2 = torch.ops.aten.add.Tensor(select_copy_1, _unsafe_view); select_copy_1 = _unsafe_view = None + view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None + view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8]) + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None + select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None + view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None + view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None + transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None + unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None + squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None + split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None + getitem_2 = split_copy_1[0] + getitem_3 = split_copy_1[1]; split_copy_1 = None + view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None + add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None return add_1 """) # noqa: B950 @@ -684,34 +786,34 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None view = torch.ops.aten.view.default(add, [8]) - _reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None - transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0) + view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None + transpose = torch.ops.aten.transpose.int(view_1, 1, 0) unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None getitem = split[0] getitem_1 = split[1]; split = None add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None - select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None + select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format) _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None - view_1 = torch.ops.aten.view.default(add, [8]); add = None - _reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None - transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None + view_2 = torch.ops.aten.view.default(add, [8]); add = None + view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None + transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None - _reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None - view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None - view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None - _reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None - select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None + view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None + view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None + view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None + view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None + select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None return getitem """) @@ -729,13 +831,13 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): ones = torch.ops.aten.ones.default([4, 2], device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(a_1, [4, 2]) + view = torch.ops.aten.view.default(arg0_1, [4, 2]) add = torch.ops.aten.add.Tensor(view, ones); view = ones = None view_1 = torch.ops.aten.view.default(add, [4, 2]) mul = torch.ops.aten.mul.Tensor(view_1, view_1) - copy_ = torch.ops.aten.copy_.default(a_1, view_1); a_1 = view_1 = None + copy_ = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None return add """) @@ -761,8 +863,8 @@ def f(x): _z = torch._from_functional_tensor(z) self.assertTrue(are_aliased(_y, _z)) - # copy_() gets its own test, because it is special cased in functionalization. - # self.copy_(src) decomposes into src.to(self).expand_as(self). + # copy_() gets its own test, because it used to be special cased in functionalization. + # However, now it works pretty similar to other functional ops def test_copy_(self): def f(x): tmp = torch.zeros(2, 2) @@ -779,10 +881,11 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None return add """) @@ -791,11 +894,12 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None - return add + copy = torch.ops.aten.copy_.default(diagonal, arg0_1) + add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None + return diagonal """) # Test 2: copy_() with same dtype, different shape @@ -805,11 +909,11 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - expand_copy = torch.ops.aten.expand_copy.default(a_1, [2]) - add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None return add """) @@ -818,12 +922,12 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - expand_copy = torch.ops.aten.expand_copy.default(a_1, [2]) - add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None - return expand_copy + copy = torch.ops.aten.copy_.default(diagonal, arg0_1) + add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None + return diagonal """) # Test 3: copy_() with different dtype, same shape @@ -833,11 +937,11 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None return add """) # noqa: B950 @@ -846,12 +950,12 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None - return _to_copy + copy = torch.ops.aten.copy_.default(diagonal, arg0_1) + add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None + return diagonal """) # noqa: B950 # Test 4: copy_() with different dtype, different shape @@ -861,12 +965,11 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None - add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, arg0_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, arg0_1); copy = arg0_1 = None return add """) # noqa: B950 @@ -875,13 +978,12 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None - add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None - return expand_copy + copy = torch.ops.aten.copy_.default(diagonal, arg0_1) + add = torch.ops.aten.add_.Tensor(diagonal, arg0_1); arg0_1 = None + return diagonal """) # noqa: B950 def test_expand_symint(self): @@ -896,8 +998,8 @@ def f(x): -def forward(self, a_1): - expand_copy = torch.ops.aten.expand_copy.default(a_1, [2, 2]); a_1 = None +def forward(self, arg0_1): + expand_copy = torch.ops.aten.expand_copy.default(arg0_1, [2, 2]); arg0_1 = None return expand_copy """) @@ -914,8 +1016,8 @@ def f(x): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None diagonal_copy = torch.ops.aten.diagonal_copy.default(add) fill = torch.ops.aten.fill.Scalar(diagonal_copy, 0); diagonal_copy = None diagonal_scatter = torch.ops.aten.diagonal_scatter.default(add, fill); add = fill = None @@ -927,8 +1029,8 @@ def forward(self, a_1): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None diagonal = torch.ops.aten.diagonal.default(add) fill = torch.ops.aten.fill_.Scalar(diagonal, 0); diagonal = None return add @@ -951,8 +1053,8 @@ def f(w): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None view_copy = torch.ops.aten.view_copy.default(add, [4, 4]) resize = torch.ops.aten.resize.default(view_copy, [3, 3]) as_strided_copy = torch.ops.aten.as_strided_copy.default(view_copy, [3, 3], [3, 1]); view_copy = None @@ -974,8 +1076,8 @@ def forward(self, a_1): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None view = torch.ops.aten.view.default(add, [4, 4]) resize = torch.ops.aten.resize.default(view, [3, 3]) as_strided = torch.ops.aten.as_strided.default(view, [3, 3], [3, 1]); view = None @@ -1014,8 +1116,8 @@ def f(x): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None resize = torch.ops.aten.resize.default(add, [5, 5]); add = None view_copy = torch.ops.aten.view_copy.default(resize, [25]); resize = None fill = torch.ops.aten.fill.Scalar(view_copy, 1); view_copy = None @@ -1029,8 +1131,8 @@ def forward(self, a_1): -def forward(self, a_1): - add = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None resize = torch.ops.aten.resize_.default(add, [5, 5]) view = torch.ops.aten.view.default(add, [25]); add = None fill = torch.ops.aten.fill_.Scalar(view, 1) @@ -1113,7 +1215,7 @@ def f(x): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) select_copy = torch.ops.aten.select_copy.int(zeros, 0, 5) fill = torch.ops.aten.fill.Scalar(select_copy, 1); select_copy = None @@ -1126,12 +1228,156 @@ def forward(self, a_1): -def forward(self, a_1): +def forward(self, arg0_1): zeros = torch.ops.aten.zeros.default([10], device = device(type='cpu'), pin_memory = False) select = torch.ops.aten.select.int(zeros, 0, 5) fill = torch.ops.aten.fill_.Scalar(select, 1); select = None return zeros """) + + def test_instance_norm(self): + size = 100 + + def f(x, running_mean, running_var): + with enable_python_dispatcher(): + return torch.instance_norm(x, None, None, running_mean, running_var, + use_input_stats=True, momentum=0.1, eps=1e-5, cudnn_enabled=False) + self.assert_functionalization(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)) + # On Windows, for instance_norm, the alias_copy's are reordered to come right before they need to be used + # whereas on other platforms, the alias_copy's are before the view_copy's. + # e.g., the alias_copy after the getitem_4 assignment would be moved to be right before the copy assignment. + if not IS_WINDOWS: + logs = self.get_logs(f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1, arg1_1, arg2_1): + repeat = torch.ops.aten.repeat.default(arg1_1, [20]) + repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) + view_copy = torch.ops.aten.view_copy.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view_copy, None, None, repeat, repeat_1, True, 0.1, 1e-05); view_copy = repeat = repeat_1 = None + getitem = _native_batch_norm_legit_functional[0] + getitem_1 = _native_batch_norm_legit_functional[1] + getitem_2 = _native_batch_norm_legit_functional[2] + getitem_3 = _native_batch_norm_legit_functional[3] + getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + alias_copy = torch.ops.aten.alias_copy.default(arg1_1) + view_copy_1 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]) + view_copy_2 = torch.ops.aten.view_copy.default(getitem_3, [20, 100]); getitem_3 = None + mean = torch.ops.aten.mean.dim(view_copy_2, [0]); view_copy_2 = None + copy = torch.ops.aten.copy.default(alias_copy, mean); alias_copy = mean = None + alias_copy_1 = torch.ops.aten.alias_copy.default(arg2_1) + view_copy_3 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]) + view_copy_4 = torch.ops.aten.view_copy.default(getitem_4, [20, 100]); getitem_4 = None + mean_1 = torch.ops.aten.mean.dim(view_copy_4, [0]); view_copy_4 = None + copy_1 = torch.ops.aten.copy.default(alias_copy_1, mean_1); alias_copy_1 = mean_1 = None + view_copy_5 = torch.ops.aten.view_copy.default(getitem, [20, 100, 35, 45]); getitem = None + alias_copy_2 = torch.ops.aten.alias_copy.default(copy); copy = None + copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_2); arg1_1 = alias_copy_2 = None + alias_copy_3 = torch.ops.aten.alias_copy.default(copy_1); copy_1 = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_3); arg2_1 = alias_copy_3 = None + return view_copy_5 + """) # noqa: B950 + + reinplaced_logs = self.get_logs( + f, torch.randn(20, size, 35, 45), torch.zeros(size), torch.ones(size), + reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1, arg1_1, arg2_1): + repeat = torch.ops.aten.repeat.default(arg1_1, [20]) + repeat_1 = torch.ops.aten.repeat.default(arg2_1, [20]) + view = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]); arg0_1 = None + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(view, None, None, repeat, repeat_1, True, 0.1, 1e-05); view = repeat = repeat_1 = None + getitem = _native_batch_norm_legit_functional[0] + getitem_1 = _native_batch_norm_legit_functional[1] + getitem_2 = _native_batch_norm_legit_functional[2] + getitem_3 = _native_batch_norm_legit_functional[3] + getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + alias = torch.ops.aten.alias.default(arg1_1) + view_1 = torch.ops.aten.view.default(getitem_3, [20, 100]) + view_2 = torch.ops.aten.view.default(getitem_3, [20, 100]); getitem_3 = None + mean = torch.ops.aten.mean.dim(view_2, [0]); view_2 = None + copy = torch.ops.aten.copy.default(alias, mean); alias = mean = None + alias_1 = torch.ops.aten.alias.default(arg2_1) + view_3 = torch.ops.aten.view.default(getitem_4, [20, 100]) + view_4 = torch.ops.aten.view.default(getitem_4, [20, 100]); getitem_4 = None + mean_1 = torch.ops.aten.mean.dim(view_4, [0]); view_4 = None + copy_1 = torch.ops.aten.copy.default(alias_1, mean_1); alias_1 = mean_1 = None + view_5 = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]); getitem = None + alias_2 = torch.ops.aten.alias.default(copy); copy = None + copy_ = torch.ops.aten.copy_.default(arg1_1, alias_2); arg1_1 = alias_2 = None + alias_3 = torch.ops.aten.alias.default(copy_1); copy_1 = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_3); arg2_1 = alias_3 = None + return view_5 + """) # noqa: B950 + + + def test_batch_norm(self): + def f(x, running_mean, running_var): + with enable_python_dispatcher(): + return torch.batch_norm(x, None, None, running_mean, running_var, False, 0.1, 1e-5, False) + + self.assert_functionalization(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)) + logs = self.get_logs(f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1, arg1_1, arg2_1): + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, False, 0.1, 1e-05); arg0_1 = None + getitem = _native_batch_norm_legit_functional[0] + getitem_1 = _native_batch_norm_legit_functional[1] + getitem_2 = _native_batch_norm_legit_functional[2] + getitem_3 = _native_batch_norm_legit_functional[3] + getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None + return getitem + """) # noqa: B950 + + reinplaced_logs = self.get_logs( + f, torch.randn(20, 100, 35, 45), torch.zeros(100), torch.ones(100), reapply_views=True, run_reinplace=True + ) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1, arg1_1, arg2_1): + empty = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cpu')) + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(arg0_1, None, None, arg1_1, arg2_1, False, 0.1, 1e-05); arg0_1 = None + getitem = _native_batch_norm_legit_functional[0] + getitem_1 = _native_batch_norm_legit_functional[1] + getitem_2 = _native_batch_norm_legit_functional[2] + getitem_3 = _native_batch_norm_legit_functional[3] + getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = None + copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = None + return getitem + """) # noqa: B950 + + +@xfail_inherited_tests([ + "test_as_strided", + "test_copy_", + "test_diagonal", + "test_diagonal_mutated_input", + "test_everything", + "test_fill_", + "test_split", + "test_view_clone_view_inplace", + "test_view_inplace", +]) +class TestCrossRefFunctionalization(TestFunctionalization): + crossref = True + if __name__ == '__main__': run_tests() diff --git a/test/test_fx.py b/test/test_fx.py index c8da9d3d2cf67..a9e186a2f7f0c 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -57,7 +57,6 @@ IS_WINDOWS, find_library_location, run_tests, - skipIfSlowGradcheckEnv, ) from torch.testing._internal.jit_utils import JitTestCase @@ -234,7 +233,7 @@ def forward(self, x): new_instance.__init__(gm3, gm3.graph) x = torch.randn(5, 3) - torch.testing.assert_allclose(new_instance(x), torch.relu(x)) + torch.testing.assert_close(new_instance(x), torch.relu(x)) def test_custom_import(self): graph = torch.fx.Graph() @@ -809,7 +808,7 @@ def forward(self, x): traced = torch.fx.symbolic_trace(ec) x = torch.randn(bs, d_hid) - torch.testing.assert_allclose(ec(x), traced(x)) + torch.testing.assert_close(ec(x), traced(x)) def test_node_tagging(self): @@ -1126,7 +1125,7 @@ def foo(x : Tuple): traced = torch.fx.symbolic_trace(foo) x = (torch.randn(5, 3),) - torch.testing.assert_allclose(traced(x), x[0]) + torch.testing.assert_close(traced(x), x[0]) bio = io.BytesIO() @@ -1136,7 +1135,7 @@ def foo(x : Tuple): loaded = torch.load(bio) - torch.testing.assert_allclose(loaded(x), x[0]) + torch.testing.assert_close(loaded(x), x[0]) def test_torch_fx_len(self): class FXLenTest(torch.nn.Module): @@ -1680,6 +1679,36 @@ def forward(self, x): if node.op in {'placeholder'}: self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d) + def test_nn_module_stack(self): + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False) + + def forward(self, x): + return self.conv_mod(x) + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.sub_mod = SubModule() + + def forward(self, x): + return self.sub_mod(x) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + mod_stack = {} + expected_stack = [('sub_mod', str(type(m.sub_mod))), + ('sub_mod.conv_mod', str(type(m.sub_mod.conv_mod)))] + for node in gm.graph.nodes: + mod_stack = node.meta.get('nn_module_stack', {}) + if mod_stack: + break + stack_list = list(mod_stack.items()) + self.assertEqual(stack_list, expected_stack) + def test_interpreter(self): class MyModule(torch.nn.Module): def __init__(self): @@ -1806,7 +1835,7 @@ def forward(self, x, y=3.14159): interp = Interpreter(gm) x = torch.randn(5, 3) out = interp.run(x) - torch.testing.assert_allclose(out, x + 3.14159) + torch.testing.assert_close(out, x + 3.14159) def test_interpreter_not_enough_args(self): class Model(torch.nn.Module): @@ -2315,8 +2344,8 @@ def forward(self, x): traced1.recompile() x = torch.randn(15, 15) - torch.testing.assert_allclose(traced1(x), torch.relu(x)) - torch.testing.assert_allclose(copied(x), torch.neg(x)) + torch.testing.assert_close(traced1(x), torch.relu(x)) + torch.testing.assert_close(copied(x), torch.neg(x)) def test_direct_param_use(self): class TransposeTest(torch.nn.Module): @@ -2699,7 +2728,7 @@ def forward(self, x): replica = gm._replicate_for_data_parallel() out_replica = replica(x) - torch.testing.assert_allclose(out_replica, out) + torch.testing.assert_close(out_replica, out) def test_ast_rewriter_rewrites_assert(self): class M(torch.nn.Module): @@ -2791,7 +2820,7 @@ def to_trace(y): def test_profiler_ranges_side_effect(self): g = torch.fx.Graph() - handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',)) + handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',)) g.call_function(torch.ops.profiler._record_function_exit, (handle,)) g.output(None) @@ -2801,7 +2830,7 @@ def test_profiler_ranges_side_effect(self): found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit] + [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] ) g.eliminate_dead_code() @@ -2811,7 +2840,7 @@ def test_profiler_ranges_side_effect(self): found_targets.setdefault(node.target) self.assertEqual( list(found_targets.keys()), - [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit] + [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit] ) def test_ast_rewriter_wrapped_via_decorator(self): @@ -3045,7 +3074,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo traced_graph = MyCustomTracer().trace(model) gm2 = torch.fx.GraphModule(model, traced_graph) gm2.delete_all_unused_submodules() - torch.testing.assert_allclose(gm2(inputs), model(inputs)) + torch.testing.assert_close(gm2(inputs), model(inputs)) def test_fx_stateless(self): class MockModule(torch.nn.Module): @@ -3806,7 +3835,7 @@ def test_class_member_back_compat(self): f"unintended, please revert it. If it was intended, check with the FX " \ f"team to ensure that the proper deprecation protocols have been followed " \ f"and subsequently --accept the change." - raise AssertionError(msg) + raise AssertionError(msg) from e def test_public_api_surface(self): non_back_compat_objects = {} @@ -3926,7 +3955,6 @@ def tearDown(self): "max_pool2d": PROXY_ITERABLE, "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, "lp_pool2d": PROXY_ITERATED, "max_unpool1d": PROXY_ITERATED, "max_unpool2d": PROXY_ITERATED, @@ -3960,6 +3988,7 @@ def tearDown(self): "gaussian_nll_loss": CONTROL_FLOW, "glu": CONTROL_FLOW, "grid_sample": CONTROL_FLOW, + "group_norm": CONTROL_FLOW, "gumbel_softmax": CONTROL_FLOW, "hardsigmoid": CONTROL_FLOW, "hardswish": CONTROL_FLOW, @@ -4030,7 +4059,7 @@ def tearDown(self): "max_pool2d": PROXY_ITERATED, "max_pool3d": PROXY_ITERATED, - "group_norm": LEN_ERROR + "group_norm": CONTROL_FLOW } @classmethod @@ -4110,7 +4139,6 @@ def tearDownClass(cls): instantiate_device_type_tests(TestOperatorSignatures, globals()) @skipIfNoTorchVision -@skipIfSlowGradcheckEnv class TestVisionTracing(JitTestCase): def setUp(self): # Checking for mutable operations while tracing is feature flagged diff --git a/test/test_fx_backends.py b/test/test_fx_backends.py deleted file mode 100644 index f9103d61aa960..0000000000000 --- a/test/test_fx_backends.py +++ /dev/null @@ -1,252 +0,0 @@ -# Owner(s): ["module: fx"] - -import copy -import sys -import logging -from typing import List, Tuple - -import torch -from torch.fx._symbolic_trace import symbolic_trace -from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.passes.backends.nvfuser import NvFuserBackend - -from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - dtypes, -) - -if not TEST_CUDA: - print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = object # noqa: F811 - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -class HF_T5_Partial(torch.nn.Module): - - def inputs_meta(self): - return [ - (torch.Size([512, 512]), torch.float32), - (torch.Size([512, 512]), torch.float32), - (torch.Size([512, 512]), torch.float32), - (torch.Size([512, 512]), torch.float32), - (torch.Size([512]), torch.float32), - (torch.Size([2048, 512]), torch.float32), - (torch.Size([512, 2048]), torch.float32), - (torch.Size([512]), torch.float32), - (torch.Size([8, 1024, 512]), torch.float32), - (torch.Size([8, 8, 1024, 1024]), torch.float32), - ] - - def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, - primals_6, primals_7, primals_8, primals_9, primals_10): - pow_1 = torch.ops.aten.pow(primals_9, 2) - mean = torch.ops.aten.mean(pow_1, [-1], True) - add = torch.ops.aten.add(mean, 1e-06) - rsqrt = torch.ops.aten.rsqrt(add) - mul = torch.ops.aten.mul(primals_9, rsqrt) - mul_1 = torch.ops.aten.mul(primals_5, mul) - t = torch.ops.aten.t(primals_3) - view = torch.ops.aten.view(mul_1, [8192, 512]) - mm = torch.ops.aten.mm(view, t) - _unsafe_view = torch.ops.aten._unsafe_view(mm, [8, 1024, 512]) - view_1 = torch.ops.aten.view(_unsafe_view, [8, -1, 8, 64]) - transpose = torch.ops.aten.transpose(view_1, 1, 2) - t_1 = torch.ops.aten.t(primals_1) - view_2 = torch.ops.aten.view(mul_1, [8192, 512]) - mm_1 = torch.ops.aten.mm(view_2, t_1) - _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [8, 1024, 512]) - view_3 = torch.ops.aten.view(_unsafe_view_1, [8, -1, 8, 64]) - transpose_1 = torch.ops.aten.transpose(view_3, 1, 2) - t_2 = torch.ops.aten.t(primals_4) - view_4 = torch.ops.aten.view(mul_1, [8192, 512]) - mm_2 = torch.ops.aten.mm(view_4, t_2) - _unsafe_view_2 = torch.ops.aten._unsafe_view(mm_2, [8, 1024, 512]) - view_5 = torch.ops.aten.view(_unsafe_view_2, [8, -1, 8, 64]) - transpose_2 = torch.ops.aten.transpose(view_5, 1, 2) - transpose_3 = torch.ops.aten.transpose(transpose_1, 3, 2) - expand = torch.ops.aten.expand(transpose, [8, 8, 1024, 64]) - clone = torch.ops.aten.clone(expand, memory_format=torch.contiguous_format) - _unsafe_view_3 = torch.ops.aten._unsafe_view(clone, [64, 1024, 64]) - expand_1 = torch.ops.aten.expand(transpose_3, [8, 8, 64, 1024]) - clone_1 = torch.ops.aten.clone(expand_1, memory_format=torch.contiguous_format) - _unsafe_view_4 = torch.ops.aten._unsafe_view(clone_1, [64, 64, 1024]) - bmm = torch.ops.aten.bmm(_unsafe_view_3, _unsafe_view_4) - _unsafe_view_5 = torch.ops.aten._unsafe_view(bmm, [8, 8, 1024, 1024]) - add_ = torch.ops.aten.add_(_unsafe_view_5, primals_10) - _softmax = torch.ops.aten._softmax(add_, -1, False) - expand_2 = torch.ops.aten.expand(_softmax, [8, 8, 1024, 1024]) - view_6 = torch.ops.aten.view(expand_2, [64, 1024, 1024]) - expand_3 = torch.ops.aten.expand(transpose_2, [8, 8, 1024, 64]) - clone_2 = torch.ops.aten.clone(expand_3, memory_format=torch.contiguous_format) - _unsafe_view_6 = torch.ops.aten._unsafe_view(clone_2, [64, 1024, 64]) - bmm_1 = torch.ops.aten.bmm(view_6, _unsafe_view_6) - _unsafe_view_7 = torch.ops.aten._unsafe_view(bmm_1, [8, 8, 1024, 64]) - transpose_4 = torch.ops.aten.transpose(_unsafe_view_7, 1, 2) - clone_3 = torch.ops.aten.clone(transpose_4, memory_format=torch.contiguous_format) - view_7 = torch.ops.aten.view(clone_3, [8, -1, 512]) - t_3 = torch.ops.aten.t(primals_2) - view_8 = torch.ops.aten.view(view_7, [8192, 512]) - mm_3 = torch.ops.aten.mm(view_8, t_3) - _unsafe_view_8 = torch.ops.aten._unsafe_view(mm_3, [8, 1024, 512]) - add_1 = torch.ops.aten.add(primals_9, _unsafe_view_8) - pow_2 = torch.ops.aten.pow(add_1, 2) - mean_1 = torch.ops.aten.mean(pow_2, [-1], True) - add_2 = torch.ops.aten.add(mean_1, 1e-06) - rsqrt_1 = torch.ops.aten.rsqrt(add_2) - mul_2 = torch.ops.aten.mul(add_1, rsqrt_1) - mul_3 = torch.ops.aten.mul(primals_8, mul_2) - t_4 = torch.ops.aten.t(primals_6) - view_9 = torch.ops.aten.view(mul_3, [8192, 512]) - mm_4 = torch.ops.aten.mm(view_9, t_4) - _unsafe_view_9 = torch.ops.aten._unsafe_view(mm_4, [8, 1024, 2048]) - relu = torch.ops.aten.relu(_unsafe_view_9) - t_5 = torch.ops.aten.t(primals_7) - view_10 = torch.ops.aten.view(relu, [8192, 2048]) - mm_5 = torch.ops.aten.mm(view_10, t_5) - _unsafe_view_10 = torch.ops.aten._unsafe_view(mm_5, [8, 1024, 512]) - add_3 = torch.ops.aten.add(add_1, _unsafe_view_10) - return [add_3, rsqrt, _unsafe_view_3, t_3, _softmax, view_6, mul_2, t, view_9, t_1, primals_5, add_1, - _unsafe_view_4, view_2, view_10, t_5, t_2, primals_8, view_4, view_8, rsqrt_1, primals_9, t_4, - mul, _unsafe_view_6, relu, view] - - -class TestFxNvFuserBackend(TestCase): - - def _generate_random_inputs(self, device, inputs_meta: List[Tuple[torch.Size, torch.dtype]]): - inputs = [] - for meta in inputs_meta: - shape, dtype = meta - - if dtype in {torch.int, torch.int32, torch.int64, torch.bool, torch.int, torch.uint8}: - input = torch.randint(0, 1, shape, dtype=dtype, device=device) - else: - input = torch.rand(shape, dtype=dtype, device=device) - - inputs.append(input) - - return inputs - - - @dtypes(torch.float32) - def test_nvfuser_call_module_backend(self, device, dtype): - - class Model(torch.nn.Module): - - def __init__(self): - super(Model, self).__init__() - self.bn = torch.nn.BatchNorm2d(3) - self.relu = torch.nn.ReLU() - - def forward(self, inp): - o = self.bn(inp) - o = self.relu(o) - return o - - inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device) - m = Model().to(dtype=dtype, device=device) - - # note that the traced module here contains only `call_module` node, - # which isn't fused by nvfuser backend. But `nvfuser.compile` should run without error - traced = symbolic_trace(m) - - nvfuser = NvFuserBackend() - compiled_module = nvfuser.compile(traced) - - eager_result = m(inp) - nvfuser_result = compiled_module(inp) - - torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) - - - @dtypes(torch.float32) - def test_nvfuser_backend(self, device, dtype): - m = HF_T5_Partial() - m.to(device) - - traced = symbolic_trace(m) - - nvfuser = NvFuserBackend() - compiled_module = nvfuser.compile(traced) - - inputs = self._generate_random_inputs(device, m.inputs_meta()) - - eager_result = m(*inputs) - nvfuser_result = compiled_module(*inputs) - - torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) - - @dtypes(torch.float32) - def test_aten_square(self, device, dtype): - - def fn(x): - square = torch.square(x) - a = square + 1 - b = a + 1 - return b - - inputs = torch.randn(4, device=device) - traced = make_fx(fn)(inputs) - - nvfuser = NvFuserBackend() - compiled_module = nvfuser.compile(copy.deepcopy(traced)) - - for node in compiled_module.graph.nodes: - if node.op == "call_function": - assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" - - eager_result = traced(inputs) - nvfuser_result = compiled_module(inputs) - torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) - - @dtypes(torch.float32) - def test_aten_leakyrelu(self, device, dtype): - - def fn(x): - square = torch.ops.aten.leaky_relu(x, 0.1) - a = square + 1 - b = a + 1 - return b - - inputs = torch.randn(4, device=device) - traced = make_fx(fn)(inputs) - - nvfuser = NvFuserBackend() - compiled_module = nvfuser.compile(copy.deepcopy(traced)) - - for node in compiled_module.graph.nodes: - if node.op == "call_function": - assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" - - eager_result = traced(inputs) - nvfuser_result = compiled_module(inputs) - torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) - - @dtypes(torch.float32) - def test_aten_where(self, device, dtype): - - def fn(x): - where = torch.ops.aten.where(x < 0, -x, x) - a = where + 1 - b = a + 1 - return b - - inputs = torch.randn(4, device=device) - traced = make_fx(fn)(inputs) - - nvfuser = NvFuserBackend() - compiled_module = nvfuser.compile(copy.deepcopy(traced)) - - for node in compiled_module.graph.nodes: - if node.op == "call_function": - assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" - - eager_result = traced(inputs) - nvfuser_result = compiled_module(inputs) - torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) - -instantiate_device_type_tests(TestFxNvFuserBackend, globals(), only_for="cuda") - -if __name__ == "__main__": - run_tests() diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index ae7a2250b8abb..e94c1bc7cc445 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -7,43 +7,44 @@ import sys import tempfile import unittest -from typing import Callable, Dict, Union, List, Optional from types import BuiltinFunctionType +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union import torch +import torch.fx.experimental.meta_tracer import torch.fx.experimental.optimization as optimization from torch.fx._symbolic_trace import symbolic_trace from torch.fx.experimental import merge_matmul from torch.fx.experimental.accelerator_partitioner import Partitioner -from torch.fx.experimental.normalize import NormalizeOperators, NormalizeArgs -from torch.fx.passes import graph_manipulation -from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes +from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators from torch.fx.experimental.partitioner_utils import ( - NodeLatency, - get_partition_to_latency_mapping, - get_latency_of_partitioned_graph, Device, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, PartitionerConfig, PartitionMode, ) from torch.fx.experimental.rewriter import RewritingTracer from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema -import torch.fx.experimental.meta_tracer from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.operator_schemas import ( _torchscript_type_to_python_type, + create_type_hint, normalize_function, normalize_module, type_matches, - create_type_hint, ) +from torch.fx.passes import graph_manipulation +from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.split_module import split_module +from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes from torch.testing._internal.common_device_type import ( - ops, - onlyCPU, instantiate_device_type_tests, + onlyCPU, + ops, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, new_module_tests @@ -782,7 +783,7 @@ def split_callback(n): x = torch.randn(5, 3) foo = torch.randn(5, 3) - torch.testing.assert_allclose(split(x, foo=foo), traced(x, foo=foo)) + torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo)) @skipIfNoTorchVision def test_subgraph_trivial_resnet(self): @@ -814,7 +815,7 @@ def forward(self, x, targets=None): split = split_module(traced, mtt, lambda node: 0) x = torch.randn(50, 512) - torch.testing.assert_allclose(split(x), traced(x)) + torch.testing.assert_close(split(x), traced(x)) def test_normalize_binary_operators(self): ops_to_test = { @@ -1080,6 +1081,37 @@ def is_leaf_module( # Smoke test torchscript compilation since now we're emitting type annotations torch.jit.script(traced_functionals_annotated) + def test_annotate_getitem_node(self): + class CustomType: + pass + + class CustomNamedTuple(NamedTuple): + x: int + y: float + + class MyModule(torch.nn.Module): + def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple): + inp_0 = inp[0] + inp_1 = inp[1] + inp2_0 = inp2[0] + inp3_x = inp3.x + inp3_y = inp3.y + return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y + + my_module = MyModule() + my_module_traced = torch.fx.symbolic_trace(my_module) + + # by default, fx transform loses type annotation of getitem nodes. + for node in my_module_traced.graph.nodes: + if node.target == operator.getitem: + assert node.type is None + + annotate_getitem_nodes(my_module_traced.graph) + + for node in my_module_traced.graph.nodes: + if node.target == operator.getitem: + self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") + def test_subgraph_uniquename(self): class MyModule(torch.nn.Module): def __init__(self): diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py index 0aa721f34a167..d9e5abc921df7 100644 --- a/test/test_fx_passes.py +++ b/test/test_fx_passes.py @@ -182,46 +182,93 @@ def forward13(a, b, c): c1 = a1 + c return b1 + c1 + @staticmethod + def forward14(a, b, c): + a0, a1 = torch.ops.aten.std_mean(a) + out = a0 + 1.0 + return out + + @staticmethod + def forward15(a, b, c): + a0 = torch.ops.aten.view(a, [2, 2]) + a1 = torch.ops.aten.permute(a0, [1, 0]) + a2 = a1 + 1.0 + a3 = torch.ops.aten.permute(a2, [1, 0]) + a4 = a3 + 1.0 + a5 = torch.ops.aten.permute(a4, [1, 0]) + return torch.ops.aten.permute(a5, [1, 0]) + + @staticmethod + def forward16(a, b, c): + a0 = a - 1.0 + a1 = torch.ops.aten.view(a0, [2, 2]) + a2 = torch.ops.aten.permute(a1, [1, 0]) + a3 = a2 + 1.0 + a4 = torch.ops.aten.permute(a3, [1, 0]) + a5 = a4 + 1.0 + a6 = torch.ops.aten.permute(a5, [1, 0]) + a7 = torch.ops.aten.permute(a6, [1, 0]) + return a7 - 1.0 + # A mock OperatorSupport class, where only operator.add is supported class MockOperatorSupport(OperatorSupport): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return node.op == "call_function" and node.target in {operator.add, operator.getitem} - + return (node.op == "call_function" and + node.target in {operator.add, operator.getitem, + torch.ops.aten.view, + torch.ops.aten.permute, + torch.ops.aten.std_mean}) @instantiate_parametrized_tests class TestFXGraphPasses(JitTestCase): - @parametrize("fn, expected_partition", [ - (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]]), - (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]]), + @parametrize("fn, expected_partition, bookend_non_compute_pass", [ + (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False), + (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False), # 1 horizontal fusion with common producer - (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]]), - (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]]), + (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False), + (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False), # 2 branches cases - (TestPartitionFunctions.forward5, [["add_1", "add"]]), - (TestPartitionFunctions.forward6, [["add"]]), - (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]]), - (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]]), + (TestPartitionFunctions.forward5, [["add_1", "add"]], False), + (TestPartitionFunctions.forward6, [["add"]], False), + (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False), + (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False), # 3 branch cases - (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']]), - (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']]), - (TestPartitionFunctions.forward11, [['add_1'], ['add']]), + (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False), + (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False), + (TestPartitionFunctions.forward11, [['add_1'], ['add']], False), # 4 not necessarily the only partition, just to verify that there's no cyclic dependency after partition - (TestPartitionFunctions.forward12, [["add_2"], ["add_3", "add_4", "add_1"], ["add"]]), + (TestPartitionFunctions.forward12, [["add_2"], ["add_3", "add_4", "add_1"], ["add"]], False), # 5 getitem special case - (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]]), + (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False), + (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False), + + # 6 bookend non_compute pass + (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True), + (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), + (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True), + (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False), ]) - def test_partitioner(self, fn, expected_partition): + def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass): traced = symbolic_trace(fn) + non_compute_ops = [] + if bookend_non_compute_pass: + non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"] + supported_ops = MockOperatorSupport() - partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True) + partitioner = CapabilityBasedPartitioner(traced, + supported_ops, + allows_single_node_partition=True, + non_compute_ops=non_compute_ops) partitions = partitioner.propose_partitions() + if bookend_non_compute_pass: + partitioner.remove_bookend_non_compute_ops(partitions) partitions_name = [[node.name for node in partition.nodes] for partition in partitions] assert len(partitions_name) == len(expected_partition) diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py index abb9696225c44..dc512cadea69e 100644 --- a/test/test_fx_reinplace_pass.py +++ b/test/test_fx_reinplace_pass.py @@ -345,9 +345,8 @@ def forward(self): ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False) slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None + copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) - slice_tensor = torch.ops.aten.slice.Tensor(slice_3, 1, 2, 9223372036854775807); slice_3 = None - copy__default = torch.ops.aten.copy_.default(slice_tensor, ones); slice_tensor = ones = None return zeros """) diff --git a/test/test_indexing.py b/test/test_indexing.py index 1d5f2ea68ac21..5b0d9f51360b3 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -11,13 +11,15 @@ import numpy as np from torch.testing import make_tensor -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import ( + TestCase, run_tests, TEST_WITH_TORCHDYNAMO, skipIfMps) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyNativeDeviceTypes) class TestIndexing(TestCase): + @skipIfMps def test_index(self, device): def consec(size, start=1): @@ -737,6 +739,10 @@ def test_byte_mask_accumulate(self, device): self.assertEqual(y, torch.ones(size=(10, 10), device=device)) self.assertEqual(len(w), 2) + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, + "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" + ) def test_index_put_accumulate_large_tensor(self, device): # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 diff --git a/test/test_jit.py b/test/test_jit.py index b1425a4ed71ca..6cbc091d506b5 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -20,7 +20,6 @@ from jit.test_autodiff import TestAutodiffJit # noqa: F401 from jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing # noqa: F401 from jit.test_custom_operators import TestCustomOperators # noqa: F401 -from jit.test_export_modes import TestExportModes # noqa: F401 from jit.test_graph_rewrite_passes import TestGraphRewritePasses # noqa: F401 from jit.test_class_type import TestClassType # noqa: F401 from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 @@ -97,7 +96,7 @@ from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ - freeze_rng_state, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ + freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, IS_MACOS, skipIfTorchDynamo from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ @@ -3952,6 +3951,14 @@ def invalid4(a): return a + 2 torch.jit.script(invalid4) + def test_calls_in_type_annotations(self): + with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): + def spooky(a): + # type: print("Hello") -> Tensor # noqa: F723 + return a + 2 + print(torch.__file__) + torch.jit.annotations.get_signature(spooky, None, 1, True) + def test_is_optional(self): ann = Union[List[int], List[float]] torch._jit_internal.is_optional(ann) @@ -5913,23 +5920,6 @@ def test_fuser_multiple_blocks(this, that, theother, meme): self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs) - def test_dropout_script(self): - - eg = torch.zeros(1, 2, 3, requires_grad=True) - - @_trace(eg) - def foo(x): - x = torch.neg(x) - return F.dropout(x) - - class MyDrop(nn.Module): - def forward(self, x): - return foo(x) - - f = io.BytesIO() - with warnings.catch_warnings(record=True): - torch.onnx.export(MyDrop(), (eg,), f, verbose=False) - @unittest.skip("RuntimeError: VariableType::ID() not implemented") def test_cast(self): script = ''' @@ -9780,50 +9770,6 @@ def forward(self, rep): m = M2() m(torch.zeros(4, 3)) - @skipIfCompiledWithoutNumpy - def test_pack_padded_pad_packed_trace(self): - from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence - T, B, C = 3, 5, 7 - - class PadPackedWrapper(torch.nn.Module): - def __init__(self): - super(PadPackedWrapper, self).__init__() - - def forward(self, x, seq_lens): - x = pack_padded_sequence(x, seq_lens) - x, _ = pad_packed_sequence(x) - return x - - x = np.ones((T, B, C)) - seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32) - # set padding value so we can test equivalence - for b in range(B): - if seq_lens[b] < T: - x[seq_lens[b]:, b, :] = 0 - seq_lens = torch.from_numpy(seq_lens) - x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True) - - m = PadPackedWrapper() - m_traced = torch.jit.trace(m, (x, seq_lens,)) - - y = m(x, seq_lens) - loss = torch.sum(y) - loss.backward() - grad = x.grad.clone() - x.grad.zero_() - - y_traced = m_traced(x, seq_lens) - loss_traced = torch.sum(y_traced) - loss_traced.backward() - grad_traced = x.grad.clone() - - self.assertEqual(y_traced, x) - self.assertEqual(y_traced, y) - self.assertEqual(grad, grad_traced) - - f = io.BytesIO() - torch.onnx._export(m, (x, seq_lens), f, verbose=False) - def test_script_pack_padded_sequence(self): from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence @@ -10024,54 +9970,6 @@ def forward(self, input: torch.Tensor): m_scripted = torch.jit.script(m) self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246)) - # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch. - @suppress_warnings - @skipIfCompiledWithoutNumpy - def test_rnn_trace_override(self): - from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence - num_layers = 3 - T, B, C = 11, 5, 7 - - class RNNTraceWrapper(torch.nn.Module): - def __init__(self, cell_type): - super(RNNTraceWrapper, self).__init__() - if cell_type == 'RNN': - self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers) - elif cell_type == 'LSTM': - self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers) - elif cell_type == 'GRU': - self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers) - - def forward(self, x, seq_lens): - x = pack_padded_sequence(x, seq_lens) - x, _ = self.rnn(x) - x, _ = pad_packed_sequence(x) - return x - - for cell_type in ['RNN', 'LSTM', 'GRU']: - x = torch.ones(T, B, C, requires_grad=True) - seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32)) - - m = RNNTraceWrapper(cell_type) - m_traced = torch.jit.trace(m, (x, seq_lens,)) - - y = m(x, seq_lens) - loss = torch.sum(y) - loss.backward() - grad = x.grad.clone() - x.grad.zero_() - - y_traced = m_traced(x, seq_lens) - loss_traced = torch.sum(y_traced) - loss_traced.backward() - grad_traced = x.grad.clone() - - self.assertEqual(y_traced, y) - self.assertEqual(grad, grad_traced) - - f = io.BytesIO() - torch.onnx._export(m, (x, seq_lens), f, verbose=False) - def test_python_call_non_tensor(self): def foo(a, b, c): # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor] diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index 5d555e7cd9d8d..d311eb687a763 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -797,7 +797,7 @@ def test_nchw_autocast_jit_trace_model(model, x): y = traced_model(x.clone()) with torch.cpu.amp.autocast(), torch.no_grad(): y2 = model(x.clone()) - torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03) + torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03) for i in range(self.models.__len__()): test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i]) @@ -812,13 +812,39 @@ def test_nhwc_autocast_jit_trace_model(model, x): y = traced_model(x.clone().to(memory_format=torch.channels_last)) with torch.cpu.amp.autocast(), torch.no_grad(): y2 = model(x.clone().to(memory_format=torch.channels_last)) - torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03) + torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03) for i in range(self.models.__len__()): if self.inputs[i].size().__len__() == 5: # NHWC 3D case not support yet continue test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i]) + def test_cat_promote(self): + class TestModel(torch.nn.Module): + def __init__(self): + super(TestModel, self).__init__() + + def forward(self, a, b): + return torch.cat([a, b], 0) + with torch.jit.fuser("none"): + # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs. + # To avoid the fusion group from TE, we will disable the fuser here. + for jit_freeze_or_not in [False, True]: + test_model = TestModel().eval() + with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad(): + a = torch.rand(24, 128, 128) + b = torch.rand(24, 128, 128, dtype=torch.bfloat16) + c = test_model(a, b) + traced = torch.jit.trace(test_model, (a, b)) + if jit_freeze_or_not: + traced = torch.jit.freeze(traced) + for _ in range(3): + c2 = traced(a, b) + self.assertTrue(c.dtype, torch.float32) + self.assertTrue(c2.dtype, torch.float32) + traced_graph = traced.graph_for(a, b) + self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes())) + def test_script_autocast_cpu(self): def fn(x): if torch.is_autocast_cpu_enabled(): diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index e51cd01cd4cda..0a13fdb20a823 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -20,7 +20,7 @@ from torch.testing._internal.common_jit import JitCommonTestCase from torch.testing._internal.common_methods_invocations import op_db, SampleInput from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH_EXECUTOR, TEST_WITH_ROCM, slowTest, \ - is_iterable_of_tensors, freeze_rng_state + is_iterable_of_tensors, freeze_rng_state, skipIfRocm from torch.testing._internal.jit_utils import clone_inputs, get_traced_sample_variant_pairs, JitTestCase, RUN_CUDA from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing import FileCheck @@ -143,7 +143,7 @@ def setUp(self): disabled_ops = ("aten::batch_norm", "aten::_batch_norm_impl_index", "aten::_batch_norm_impl_index_backward", - "aten::native_batch_norm_backward") + "aten::native_batch_norm_backward",) for op in disabled_ops: disabled_flag = torch._C._jit_set_nvfuser_skip_node_kind(op, False) if disabled_flag: @@ -383,6 +383,27 @@ def func(x: torch.Tensor): self.assertTrue(self._compare("comparing output failed", o, jit_o, 1e-4)) self.assertGraphContains(t_jit.graph_for(x), FUSION_GUARD) + @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_variance_profiling(self): + with nvfuser_singleton_fusion(True): + for op in [torch.var, torch.std]: + for dtype in [torch.float16, torch.float32, torch.double]: + for axis in [-2, -1, 2, 1]: + for unbiased in [False, True]: + for keepdim in [False, True]: + def t(x: torch.Tensor, dim: List[int], unbiased: bool, keepdim: bool): + o = torch.mul(x, 2.0) + o = op(o, dim=dim, unbiased=unbiased, keepdim=keepdim) + return o + + x = torch.randn(8, 4, 16, dtype=dtype, device="cuda") + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, [axis], unbiased, keepdim, check_stride=False, check_runs=5) + + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -1744,6 +1765,7 @@ def test_norm(self): x[1] = C self._norm_helper(x, torch.float32, "cuda", 1e-4, is_batch_norm_else_instance_norm) + @skipIfRocm @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -3365,6 +3387,7 @@ def test_batch_norm_impl_index_inner_bcast(self): training, track_running_stats = training_and_track self._test_batch_norm_impl_index_helper(2, 1, 1, affine, track_running_stats, training) + @skipIfRocm @unittest.skipIf(is_pre_volta(), "reduction not supported in pre volta device") @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, @@ -4464,6 +4487,125 @@ def t(x, w): self.assertEqual(jit_o, o) self.assertGraphContainsExactly(t_jit.graph_for(x, w), FUSION_GUARD, 2, consider_subgraphs=True) + @skipIfRocm + # see issue here on why we disabled this test https://github.com/csarofeen/pytorch/issues/2127 + @unittest.skipIf(is_pre_volta(), "permutation scheduling can be dangerous on pre-volta device") + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_view_before_permute(self): + view_examples = [[[1, 19, 1, 12, 7, 1, 99], [1, 19, 1, 3, 2772]], + [[3, 17, 80, 1], [51, 1, 2, 4, 10]], + [[3, 17, 80, 1, 9], [51, 1, 2, 4, 10, 9]], + [[2, 3, 4, 5], [1, 6, 1, 2, 2, 5]], + [[22, 22, 2], [22, 11, 1, 1, 4]], + [[37, 9, 7, 6, 10], [333, 2, 2, 3, 35]], + [[8, 1, 1, 8, 1, 8], [8, 2, 4, 1, 8]], + [[1, 333, 1], [1, 37, 9]], + [[1, 333], [1, 1, 1, 111, 1, 3]], + [[1, 27454, 1, 2], [1, 7844, 1, 7]], + [[1, 7844, 1, 7], [1, 27454, 2]]] + + def _getTransposeAxes(sizes): + # broadcast do not change + # always move inner-most dim + # random permutation of other dims + result = [] + valid_sizes = [] + for idx, val in enumerate(sizes): + if val > 1 and idx < len(sizes) - 1: + valid_sizes.append((idx, val)) + result.append(idx) + idx, new_size = valid_sizes[random.randint(0, len(valid_sizes) - 1)] + result[idx] = len(sizes) - 1 + result[len(sizes) - 1] = idx + return result + + def _transposeSize(sizes, dims): + return [sizes[old_pos] for old_pos in dims] + + for example in view_examples: + before_view_size, after_view_size = example + axes = _getTransposeAxes(after_view_size) + output_size = _transposeSize(after_view_size, axes) + self._view_before_permute_helper(before_view_size, after_view_size, output_size, axes) + + def _view_before_permute_helper(self, input_shape, view_shape, output_shape, dims): + def t(x, y, view_shape : List[int], dims : List[int]): + x_v = x.view(view_shape) + x_t = torch.permute(x_v, dims) + o = torch.add(x_t, y) + o = torch.relu(o) + return o + + x = torch.randn(*input_shape, device="cuda") + y = torch.randn(*output_shape, device="cuda") + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y, view_shape, dims) + + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_permute(self): + max_dims = 4 + for ndims in range(2, max_dims + 1): + shape = [idx + 2 for idx in range(ndims)] + for dims in itertools.permutations(range(ndims)): + self._permute_helper(shape, dims) + + def _permute_helper(self, shape, dims): + def t(x, y, dims : List[int]): + x_t = torch.permute(x, dims) + y_t = torch.permute(y, dims) + o = torch.add(x_t, y_t) + o = torch.relu(o) + return o + + x = torch.randn(*shape, device="cuda") + y = torch.randn(*shape, device="cuda") + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y, dims) + + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_transpose(self): + max_dims = 4 + for ndims in range(2, max_dims + 1): + shape = [idx + 2 for idx in range(ndims)] + for idx in range(1, ndims): + for jdx in range(idx): + self._transpose_helper(shape, idx, jdx) + + def _transpose_helper(self, shape, dim0, dim1): + def t(x, y, dim0 : int, dim1 : int): + x_t = torch.transpose(x, dim0, dim1) + y_t = torch.transpose(y, dim0, dim1) + o = torch.add(x_t, y_t) + o = torch.nn.functional.gelu(o) + return o + + x = torch.randn(*shape, device="cuda") + y = torch.randn(*shape, device="cuda") + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y, dim0, dim1) + + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_transpose_default(self): + def t(x, y): + x_t = torch.t(x) + y_t = torch.t(y) + o = torch.add(x_t, y_t) + o = torch.nn.functional.gelu(o) + return o + + x = torch.randn(3, 5, device="cuda") + y = torch.randn(3, 5, device="cuda") + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, y) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires fusion optimization pass to be effective") @@ -4971,6 +5113,27 @@ def t(x, y, w0, w1): t_jit = torch.jit.script(t) self._run_helper(t_jit, t, x0, x1, w0, w1, check_stride=True) + @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, + "Requires fusion optimization pass to be effective") + def test_no_tensor_input(self): + device = "cuda" + x = torch.randn(512, device=device) + + def t(x): + tensor0 = torch.tensor(3, dtype=torch.float32, device='cuda') + tensor1 = torch.tensor(3, dtype=torch.float32, device='cuda') + o = torch.div(x.numel(), tensor0) + o = torch.mul(o, tensor1) + return o + + t_jit = torch.jit.script(t) + self._run_helper(t_jit, t, x, check_stride=True) + + # Note that curently TS embeds constant tensor in the graph + # this triggers memory leak check in CI + torch.jit._state._python_cu.drop_all_functions() + class TestEnableDisableCudaFuser(JitTestCase): def setUp(self): @@ -5086,6 +5249,7 @@ def test_nvfuser_correctness(self, device, dtype, op): # if the CU is not cleared. torch.jit._state._python_cu.drop_all_functions() + @skipIfRocm @slowTest @unittest.skipIf(not RUN_NVFUSER, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index e1c820fda9c46..4cfcfbe4b315c 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -597,7 +597,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_minmax_int_ops(self): def apply(fn): @@ -627,7 +627,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_comparison_eq_ne(self): for device in self.devices: @@ -1046,8 +1046,7 @@ def fn_test_rand2(x, y): script_f = torch.jit.script(fn_test_rand2) warmup_forward(script_f, x, y) out = script_f(x, y) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(out[0, :] + torch.zeros(4, 4, device='cuda'), out) + self.assertEqual(out[0, :] + torch.zeros(4, 4, device='cuda'), out) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skip("rand_like is not supported yet") @@ -1288,7 +1287,7 @@ def fn(input_v, mask): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(self_dtype), op.__name__, device, str(size)]) - ) + ) from e def test_isnan(self): x = torch.rand([4]) @@ -1321,7 +1320,7 @@ def test_isnan(self): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), 'isnan', device]) - ) + ) from e def test_gelu(self): def apply(fn): @@ -1352,7 +1351,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) - ) + ) from e def test_unary_ops(self): with torch._jit_internal._disable_emit_hooks(): @@ -1435,7 +1434,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device, str(size)]) - ) + ) from e def test_binary_ops(self): def apply(fn): @@ -1488,7 +1487,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_binary_scalar_ops(self): def apply(fn): @@ -1534,7 +1533,7 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) + raise RuntimeError(" ".join(["Compilation failed:", device, str(code)])) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1543,7 +1542,7 @@ def apply(fn): res = k.run((x, y)) self.assertEqual(ref, res) except Exception as e: - raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) + raise RuntimeError(" ".join(["Failed at runtime:", device, str(x), str(y), str(code)])) from e def test_matmul(self): if self.dynamic_shapes: @@ -1599,7 +1598,7 @@ def fn(x, y): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), device]) - ) + ) from e def test_binary_tensor_scalar_ops(self): with torch._jit_internal._disable_emit_hooks(): @@ -1643,7 +1642,7 @@ def apply_with_scalar(fn, scalar): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_binary_div_ops(self): def apply_with_scalar(fn, scalar): @@ -1676,7 +1675,7 @@ def apply_with_scalar(fn, scalar): except Exception as e: raise RuntimeError( "Failed: {} {} {} {}".format(dtype, op.__name__, device, scalar) - ) + ) from e def test_binary_pow(self): def apply_with_scalar(fn, scalar): @@ -1714,7 +1713,7 @@ def apply_with_scalar(fn, scalar): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_ternary_ops(self): def apply(fn): @@ -1746,7 +1745,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_ternary_norm_ops(self): def apply(fn): @@ -1777,7 +1776,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e @unittest.skip("FIXME: fuser doesn't include ListConstruct nodes to the group causing a failure") @@ -1810,7 +1809,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_where_ops(self): def apply(fn): @@ -1843,7 +1842,7 @@ def apply(fn): except Exception as e: raise RuntimeError( " ".join(["Failed:", str(dtype), op.__name__, device]) - ) + ) from e def test_unsupported_dtypes(self): for device in self.devices: @@ -2202,7 +2201,9 @@ def test_batch_norm(self): def test(fn, args): trace = torch.jit.trace(fn, args) self.assertAllFused(trace.graph_for(*args)) - torch.testing.assert_allclose(fn(*args), trace(*args)) + # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the + # default? + torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True) def bn(i, x): return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() diff --git a/test/test_jit_llga_fuser.py b/test/test_jit_llga_fuser.py index 4804a442c1d66..12bd955043b96 100644 --- a/test/test_jit_llga_fuser.py +++ b/test/test_jit_llga_fuser.py @@ -774,6 +774,36 @@ def t3(x, y): self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0) +@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") +@unittest.skip("Enable when integration with dynamo aot_autograd is more stable") +class TestDynamoAOT(JitTestCase): + def test_dynamo_aot_ts_onednn(self): + class Seq(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ) + + def forward(self, x): + return self.layers(x) + + mod = Seq() + + import torch._dynamo + aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod) + + for _ in range(10): + with torch.jit.fuser("fuser3"): + loss = aot_mod(torch.rand([10, 10])).sum() + loss.backward() + + torch._dynamo.reset() + + @unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.") @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") class TestModel(JitLlgaTestCase): diff --git a/test/test_vmap.py b/test/test_legacy_vmap.py similarity index 100% rename from test/test_vmap.py rename to test/test_legacy_vmap.py diff --git a/test/test_linalg.py b/test/test_linalg.py index 86790677f56a4..41c3e8a2d9ba2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1357,17 +1357,16 @@ def run_test_case(input, ord, dim, keepdim): def test_norm_fused_type_promotion(self, device, dtype): x = torch.randn(10, device=device, dtype=dtype) - def profile_and_check(fn, x, kwargs, fn_name): + def profile_and_check(fn, x, kwargs): with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p: fn(x, **kwargs, dtype=torch.float) # smoke check that profiler returned some events - self.assertTrue(fn_name in map(lambda e: e.name, p.events())) + self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events())) # test that there was no explicit copy - self.assertFalse("aten::to" in map(lambda e: e.name, p.events())) + self.assertFalse("aten::to" in (e.name for e in p.events())) - for f, kwargs, fn_name in zip((torch.norm, torch.linalg.vector_norm), ({"p" : 2}, {}), - ("aten::norm", "aten::linalg_vector_norm")): - profile_and_check(f, x, kwargs, fn_name) + for f, kwargs, in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})): + profile_and_check(f, x, kwargs) @skipMeta # https://github.com/pytorch/pytorch/issues/53739 @skipCPUIfNoLapack @@ -1540,7 +1539,7 @@ def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.cfloat, torch.cdouble) - @precisionOverride({torch.cfloat: 2e-4}) + @precisionOverride({torch.cfloat: 5e-4}) def test_norm_complex(self, device, dtype): def gen_error_message(input_size, ord, keepdim, dim=None): return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( @@ -2310,10 +2309,10 @@ def test_nuclear_norm_exceptions_old(self, device): x = torch.tensor(lst, dtype=torch.double, device=device) for axes in (), (0,): self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) - self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) + self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1)) x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) + self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) @skipCUDAIfNoCusolver @@ -2476,28 +2475,6 @@ def test_svd_memory_allocation(self, device, dtype): result = torch.linalg.svd(a, full_matrices=False) self.assertEqual(result.S, S) - # This test doesn't work with MAGMA backend https://github.com/pytorch/pytorch/issues/72106 - @skipMeta - @skipCUDAIfRocm - @skipCUDAIfNoCusolver - @skipCPUIfNoLapack - @dtypes(*floating_and_complex_types()) - def test_svd_nan_error(self, device, dtype): - for svd in [torch.svd, torch.linalg.svd]: - # if input contains NaN then an error is triggered for svd - # When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan. - # When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. - error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|The algorithm failed to converge)' - a = torch.full((3, 3), float('nan'), dtype=dtype, device=device) - a[0] = float('nan') - with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg): - svd(a) - error_msg = r'(CUSOLVER_STATUS_EXECUTION_FAILED|\(Batch element 1\): The algorithm failed to converge)' - a = torch.randn(3, 33, 33, dtype=dtype, device=device) - a[1, 0, 0] = float('nan') - with self.assertRaisesRegex(torch.linalg.LinAlgError, error_msg): - svd(a) - def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix diff --git a/test/test_meta.py b/test/test_meta.py index 6b283da39cbe0..b427e75a0c4ff 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -1,11 +1,12 @@ # Owner(s): ["module: primTorch"] +import itertools import torch import os from enum import Enum from torch.overrides import resolve_name from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten -from torch._subclasses.meta_utils import MetaConverter +from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq import torch.utils._python_dispatch from torch._dispatch.python import enable_python_dispatcher from torch.testing._internal.common_utils import ( @@ -14,12 +15,13 @@ suppress_warnings, TEST_WITH_ASAN, run_tests, - skipIfSlowGradcheckEnv, dtype_abbrs ) from torch.testing._internal.common_device_type import ( ops, instantiate_device_type_tests, + onlyCUDA, + OpDTypes, ) from torch.testing._internal.common_methods_invocations import op_db from torchgen.utils import YamlLoader @@ -50,7 +52,6 @@ u8 = torch.uint8 -@skipIfSlowGradcheckEnv class TestMetaConverter(TestCase): def assertSameVersionCounter(self, m1, m2): # Cannot easily test m1 and m2 have same storage due to @@ -63,6 +64,9 @@ def assertSameVersionCounter(self, m1, m2): self.assertNotEqual(m1._version, vc) self.assertEqual(m2._version, m1._version) + def assertMetadataMatches(self, m1, m2): + assert_metadata_eq(self.assertEqual, m1, m2) + def test_view_of_non_leaf(self): x = torch.randn(4, requires_grad=True) y = x.neg() @@ -71,9 +75,14 @@ def test_view_of_non_leaf(self): to_meta = MetaConverter() m1 = to_meta(z1) m2 = to_meta(z2) - self.assertEqual(m1.shape, z1.shape) + + # check the test is actually testing what it claims self.assertTrue(m1._is_view()) self.assertFalse(m1._base.is_leaf) + + self.assertIsNot(m1, m2) + self.assertMetadataMatches(m1, z1) + self.assertMetadataMatches(m2, z2) self.assertSameVersionCounter(m1, m2) def test_view_of_leaf(self): @@ -83,35 +92,133 @@ def test_view_of_leaf(self): to_meta = MetaConverter() m1 = to_meta(z1) m2 = to_meta(z2) - self.assertEqual(m1.shape, z1.shape) + + # check the test is actually testing what it claims self.assertTrue(m1._is_view()) self.assertTrue(m1._base.is_leaf) + + self.assertIsNot(m1, m2) + self.assertMetadataMatches(m1, z1) + self.assertMetadataMatches(m2, z2) self.assertSameVersionCounter(m1, m2) + def test_view_of_view_of_leaf(self): + x = torch.randn(8) + y = x.view(2, 4) + y.requires_grad = True + z = y.view(2, 2, 2) + + to_meta = MetaConverter() + mx = to_meta(x) + mz = to_meta(z) + + self.assertFalse(z.is_leaf) + + self.assertMetadataMatches(mx, x) + self.assertMetadataMatches(mz, z) + def test_leaf(self): x = torch.randn(4, requires_grad=True) to_meta = MetaConverter() m = to_meta(x) - self.assertEqual(m.shape, x.shape) + + # check the test is actually testing what it claims self.assertTrue(m.is_leaf) self.assertTrue(m.requires_grad) + self.assertMetadataMatches(m, x) + def test_non_leaf(self): x = torch.randn(4, requires_grad=True) y = x.neg() to_meta = MetaConverter() m = to_meta(y) - self.assertEqual(m.shape, y.shape) + + # check the test is actually testing what it claims self.assertFalse(m.is_leaf) self.assertTrue(m.requires_grad) + self.assertMetadataMatches(m, y) + def test_requires_grad_false(self): x = torch.randn(4, requires_grad=False) to_meta = MetaConverter() m = to_meta(x) - self.assertEqual(m.shape, x.shape) + + # check the test is actually testing what it claims self.assertFalse(m.requires_grad) + self.assertMetadataMatches(m, x) + + def test_channels_last(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last) + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_channels_last_leaf(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True) + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_channels_last_non_leaf(self): + x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last, requires_grad=True) + y = x + 2 + + # sanity + self.assertEqual(x.stride(), y.stride()) + self.assertFalse(y.is_leaf) + + to_meta = MetaConverter() + m = to_meta(y) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertFalse(m.is_leaf) + + self.assertMetadataMatches(m, y) + + # Check that we can autograd with m as input without erroring; + # see https://github.com/pytorch/pytorch/issues/87956 + loss = m.sum() + torch.autograd.grad(loss, m) + + def test_empty_strided_non_dense_leaf(self): + x = torch.empty_strided((2, 2), (4, 2), requires_grad=True) + + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + + def test_non_leaf_torture(self): + x = torch.empty(20, requires_grad=True) + with torch.no_grad(): + x.set_(x.storage(), 10, (2,), (2,)) + + to_meta = MetaConverter() + m = to_meta(x) + + # check the test is actually testing what it claims + self.assertTrue(m.requires_grad) + self.assertTrue(m.is_leaf) + + self.assertMetadataMatches(m, x) + # NB: complex stuff is not actually exercised right now because # we have a blanket exclusion for complex conversion @@ -119,41 +226,30 @@ def test_view_as_real(self): x = torch.randn(4, dtype=torch.complex64) y = torch.view_as_real(x) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_complex_noncontiguous_bug(self): x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :] m = MetaConverter()(x) - self.assertEqual(m.shape, x.shape) - self.assertEqual(m.stride(), x.stride()) - self.assertEqual(m.dtype, x.dtype) + self.assertMetadataMatches(m, x) def test_view_as_complex(self): x = torch.randn((4, 2), dtype=torch.float32) y = torch.view_as_complex(x) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_view_dtype(self): x = torch.randn(4, dtype=torch.float32) y = x.view(dtype=torch.int32) m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.dtype, y.dtype) + self.assertMetadataMatches(m, y) def test_imag(self): x = torch.randn(4, dtype=torch.complex64) y = x.imag m = MetaConverter()(y) - self.assertEqual(m.shape, y.shape) - self.assertEqual(m.dtype, y.dtype) - self.assertEqual(m.stride(), y.stride()) - self.assertEqual(m.storage_offset(), y.storage_offset()) + self.assertMetadataMatches(m, y) def test_weakref(self): x = torch.randn(4, 4, 4) @@ -168,9 +264,10 @@ def test_weakref(self): m.check_for_expired_weak_storages() self.assertEqual(len(m.storage_memo), 0) li = [] + r = [] for i in range(4): li.append(torch.rand([i])) - m(li[-1]) + r.append(m(li[-1])) self.assertEqual(len(m.tensor_memo), 4) del li self.assertEqual(len(m.tensor_memo), 0) @@ -185,24 +282,70 @@ def test_tensor_outlives_converter(self): del m self.assertIs(ref(), None) +aten = torch.ops.aten + CHECK_STRIDES = { torch.Tensor.__getitem__, } +CHECK_ALL_STRIDES = { + aten.unsqueeze.default +} + +CHECK_STRIDES_SKIPS = { + aten._conj_physical.default, + aten._fft_c2c.default, + aten._fft_c2r.default, + aten._fft_r2c.default, + aten._linalg_svd.default, + aten.binary_cross_entropy.default, + aten.complex.default, + aten.copysign.Tensor, + aten.div.Tensor_mode, + aten.floor_divide.default, + aten.heaviside.default, + aten.lerp.Scalar, + aten.lerp.Tensor, + aten.logical_and.default, + aten.logical_or.default, + aten.logical_xor.default, + aten.pow.Scalar, + aten.prelu.default, + aten.special_xlog1py.default, + aten.xlogy.Tensor, + + # channel_last and channel_last_3d related failures + aten.convolution.default, + + # following ops fails if include_storage_offset = True, but these are a bit edge casey + # we should still fix them, leaving them here for tracking. + # aten._reshape_alias.default, # repro with test_dispatch_symbolic_meta_outplace_all_strides_matmul_cuda_float32 + # aten.view.default, # repro with test_dispatch_symbolic_meta_outplace_all_strides_unflatten_cuda_float32 +} + +class CheckStrides(Enum): + NONE = 0 + SIGNIFICANT = 1 + ALL = 2 + def should_check_strides(func): + if func in CHECK_ALL_STRIDES: + return CheckStrides.ALL if func in CHECK_STRIDES: - return True + return CheckStrides.SIGNIFICANT + if func in CHECK_STRIDES_SKIPS: + return CheckStrides.NONE if not isinstance(func, torch._ops.OpOverload): - return False + return CheckStrides.NONE # Prims are expected to model strides correctly if func.namespace == "prims": - return True + return CheckStrides.SIGNIFICANT # Check if it's a view, by testing if any of the returns have # a non-empty alias set if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info): - return True + return CheckStrides.SIGNIFICANT # TODO: check for TensorIterator - return False + return CheckStrides.SIGNIFICANT def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable): flat_meta_rs, _ = tree_flatten(meta_rs) @@ -218,7 +361,10 @@ def test_assert(cond, msg): test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}") test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}") # See https://github.com/pytorch/pytorch/issues/78050 - if should_check_strides(func): + if should_check_strides(func) == CheckStrides.ALL: + same_strides, _ = torch._prims_common.check_all_strides(meta_r, r) + test_assert(same_strides, f"but real stride was {r.stride()}") + elif should_check_strides(func) == CheckStrides.SIGNIFICANT: same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r) test_assert(same_strides, f"but real stride was {r.stride()}") test_assert( @@ -335,7 +481,7 @@ def run_meta_crossref( # they're not tested outside of gradcheck which only checks # torch.float64 and torch.complex128 (which this second one # often skipped as well). - raise unittest.SkipTest("Original OpInfo is broken") + raise unittest.SkipTest("Original OpInfo is broken") from e # TODO: also handle cases where func raise an exception @@ -489,7 +635,6 @@ def run_meta_crossref( torch.linalg.eig : {f64, f32, c128, c64}, torch.linalg.eigvals : {f64, f32, c128, c64}, torch.linalg.lstsq : {f64, f32, c128, c64}, - torch.Tensor.conj_physical_: {c128, c32, c64}, } meta_function_expected_failures_only_outplace = { @@ -579,6 +724,7 @@ def run_meta_crossref( meta_function_device_expected_failures['cpu'] = { torch.native_batch_norm: {bf16}, + torch._native_batch_norm_legit: {bf16}, torch.native_layer_norm: {bf16}, } @@ -612,8 +758,8 @@ def run_meta_crossref( } meta_function_device_skips['cpu'] = { - torch.narrow_copy: {b8, bf16, c128, c32, c64, f16, f32, f64, i16, i32, i64, i8, u8}, torch.native_batch_norm: {f32, f64}, + torch._native_batch_norm_legit: {f32, f64}, } meta_function_device_skips['cuda'] = { @@ -660,7 +806,12 @@ def __init__(self, test_case, *, device, dtype, inplace): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod): + if ( + torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod) or + # meta converter doesn't work correctly when no_dispatch() is on, so + # skip running the crossref test in this case + torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python) + ): return func(*args, **kwargs) if self.dtype in meta_function_skips.get(func, set()): @@ -684,8 +835,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): kwargs, dtype=self.dtype, device_type=self.device_type, run_symbolic_meta=False ) -aten = torch.ops.aten - # these always fail meta_dispatch_expected_failures = { aten.allclose.default: {f16, bf16, f32, f64, c64, c128}, # NotImplementedError: 'aten::_local_scalar_dense' @@ -758,7 +907,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None): aten.unique_consecutive.default : {i8, f64, i64, bf16, f32, i32, b8, i16, u8}, aten.unique_dim.default : {i8, f64, i64, bf16, f32, i32, b8, i16, u8}, aten.upsample_nearest3d.vec : {bf16, f32, f64, u8}, - aten.conj_physical_.default: {c128, c32, c64}, } # these sometimes pass and sometimes fail @@ -781,6 +929,13 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # For CompositeImplicitAutograd functions that fail before hitting the Mode meta_dispatch_early_skips = set({ torch.Tensor.float_power_, + # Errors out in one of the tests, while ProxyTensor passes... + torch.Tensor.cumsum_, +}) + +meta_inplace_skips = set({ + # Errors out in one of the tests, while ProxyTensor passes... + torch.Tensor.cumsum_, }) meta_dispatch_device_expected_failures = defaultdict(dict) @@ -788,6 +943,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None): meta_dispatch_device_expected_failures['cpu'] = { aten.native_batch_norm.default: {bf16}, + aten._native_batch_norm_legit.default: {bf16}, + aten._native_batch_norm_legit.no_stats: {bf16}, aten.native_layer_norm.default: {bf16}, } @@ -833,6 +990,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None): meta_dispatch_device_skips['cpu'] = { aten._embedding_bag_forward_only.default: {f16, f32, f64}, aten.native_batch_norm.default: {f32, f64}, + aten._native_batch_norm_legit.default: {f32, f64}, + aten._native_batch_norm_legit.no_stats: {f32, f64}, } meta_dispatch_device_skips['cuda'] = { @@ -850,6 +1009,55 @@ def __torch_function__(self, func, types, args=(), kwargs=None): aten.miopen_batch_norm.default: {f32}, } +def get_strided_args(args): + + def get_strided_variants(t, include_storage_offset=False): + variants = [] + + # contiguous + variants.append(t) + + # transposed + if t.ndim > 1: + perm = list(reversed(range(t.ndim))) + transposed = torch.empty( + t.shape[::-1], device=t.device, dtype=t.dtype, requires_grad=t.requires_grad + ).permute(perm).copy_(t) + variants.append(transposed) + + # nondense + if t.ndim > 0: + nondense = torch.repeat_interleave(t, 2, dim=-1)[..., ::2] + variants.append(nondense) + + # channel_last + if t.ndim == 4: + variants.append(t.contiguous(memory_format=torch.channels_last)) + + # channel_last_3d + if t.ndim == 5: + variants.append(t.contiguous(memory_format=torch.channels_last_3d)) + + # storage_offset + if include_storage_offset: + buffer = torch.empty(t.numel() + 1, device=t.device, dtype=t.dtype, requires_grad=t.requires_grad) + buffer = buffer.as_strided(t.shape, t.stride(), storage_offset=1) + buffer.copy_(t) + variants.append(buffer) + + return variants + + strided_args = [] + for arg in args: + if isinstance(arg, torch.Tensor) and not arg.is_sparse_csr and arg.is_contiguous(): + strided_arg_variants = get_strided_variants(arg) + else: + strided_arg_variants = [arg] + strided_args.append(strided_arg_variants) + + for result in itertools.product(*strided_args): + yield result + class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): test_case: TestCase device: torch.device @@ -896,7 +1104,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # inconsistencies between CUDA and CPU, and running on CUDA makes it easier # to ignore the CPU case when inconsistencies arise. Ideally we deal # with the inconsistencies but this takes time. -@skipIfSlowGradcheckEnv class TestMeta(TestCase): # Copies inputs to inplace operations to avoid inplace modifications # to leaves requiring gradient @@ -933,6 +1140,8 @@ def test_meta_inplace(self, device, dtype, op): func = op.get_inplace() if not func: self.skipTest("No inplace variable for this op") + if func in meta_inplace_skips: + self.skipTest("Skipped") func = self._get_safe_inplace(func) samples = op.sample_inputs(device, dtype, requires_grad=False) for sample_input in samples: @@ -943,7 +1152,7 @@ def test_meta_inplace(self, device, dtype, op): with MetaCrossRefFunctionMode(self, dtype=dtype, device=device, inplace=True): expected = func(*args, **kwargs) - def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace): + def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace, all_stride_variants=False): if inplace: func = op.get_inplace() if not func: @@ -962,14 +1171,21 @@ def _run_dispatch_meta_test(self, device, dtype, op, symbolic_meta, inplace): if inplace and sample_input.broadcasts_input: continue - args = [sample_input.input] + list(sample_input.args) + sample_args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs - with MetaCrossRefDispatchMode.push(self, dtype=dtype, device=device, symbolic_meta=symbolic_meta): - expected = func(*args, **kwargs) + if all_stride_variants and sum(isinstance(arg, torch.Tensor) for arg in sample_args) <= 5: + # test inputs <= 5 tensors to avoid combinatorial explosion + strided_args = get_strided_args(sample_args) + else: + strided_args = [sample_args] - if not inplace and isinstance(expected, torch.Tensor) and op.supports_out: - func(*args, **kwargs, out=expected) + for args in strided_args: + with MetaCrossRefDispatchMode.push(self, dtype=dtype, device=device, symbolic_meta=symbolic_meta): + expected = func(*args, **kwargs) + + if not inplace and isinstance(expected, torch.Tensor) and op.supports_out: + func(*args, **kwargs, out=expected) @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @@ -1002,6 +1218,26 @@ def test_dispatch_symbolic_meta_outplace(self, device, dtype, op): def test_dispatch_symbolic_meta_inplace(self, device, dtype, op): self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True) + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipIfCrossRef + @suppress_warnings + # only test one dtype, as output stride behavior is the same for all dtypes + @ops(op_db, dtypes=OpDTypes.any_common_cpu_cuda_one) + # Only test on CUDA, as CUDA kernel's stride is the reference + @onlyCUDA + def test_dispatch_symbolic_meta_outplace_all_strides(self, device, dtype, op): + self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=False, all_stride_variants=True) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipIfCrossRef + @suppress_warnings + # only test one dtype, as output stride behavior is the same for all dtypes + @ops(op_db, dtypes=OpDTypes.any_common_cpu_cuda_one) + # Only test on CUDA, as CUDA kernel's stride is the reference + @onlyCUDA + def test_dispatch_symbolic_meta_inplace_all_strides(self, device, dtype, op): + self._run_dispatch_meta_test(device, dtype, op, symbolic_meta=True, inplace=True, all_stride_variants=True) + def test_empty_quantized(self): r = torch.empty(2 ** 52, device='meta', dtype=torch.qint8) @@ -1013,6 +1249,91 @@ def test_huber_loss_backward(self): self.assertEqual(r.device.type, 'meta') self.assertEqual(r.shape, inps[0].shape) + def test_fill__alias_relationship(self): + inps = torch.rand(2**52, device='meta') + r = torch.ops.aten.fill_(inps, 1.0) + # aten.fill_ returns an aliase + self.assertEqual(id(inps), id(r)) + + # aten.fill returns a new tensor + r2 = torch.ops.aten.fill(inps, 1.0) + self.assertNotEqual(id(inps), id(r2)) + + def test_meta__fused_moving_avg_obs_fq_helper(self, device): + from torch.ao.quantization import FusedMovingAvgObsFakeQuantize + to_meta = MetaConverter() + + x = torch.randn(5, 5, device=device) + running_min_op = torch.tensor(float("inf"), device=device) + running_max_op = torch.tensor(float("-inf"), device=device) + avg_const = 0.01 + scale = torch.tensor([1.0], device=device) + zero_point = torch.tensor([0], dtype=torch.int, device=device) + + mod = FusedMovingAvgObsFakeQuantize() + torch.ao.quantization.enable_fake_quant(mod) + torch.ao.quantization.enable_observer(mod) + mod.to(device) + + meta_x = to_meta(x) + + args = [ + x, + mod.observer_enabled, + mod.fake_quant_enabled, + running_min_op, + running_max_op, + scale, + zero_point, + avg_const, + 0, + 255, + 0, + ] + + meta_args = args.copy() + meta_args[0] = meta_x + + kwargss = [ + {}, + {"per_row_fake_quant": False, "symmetric_quant": False}, + {"per_row_fake_quant": False, "symmetric_quant": True}, + ] + + for kwargs in kwargss: + ref_out = aten._fused_moving_avg_obs_fq_helper.default(*args, **kwargs) + meta_out = aten._fused_moving_avg_obs_fq_helper.default(*meta_args, **kwargs) + + self.assertEqual(ref_out[0].size(), meta_out[0].size()) + self.assertEqual(ref_out[0].stride(), meta_out[0].stride()) + self.assertEqual(ref_out[1].size(), meta_out[1].size()) + self.assertEqual(ref_out[1].stride(), meta_out[1].stride()) + + def test_cdist_forward(self, device): + to_meta = MetaConverter() + x1 = torch.rand([3, 2], device=device) + x2 = torch.rand([2, 2], device=device) + p = 2.0 + for compute_mode in (None, 1, 2): + ref = aten._cdist_forward.default(x1, x2, p, compute_mode) + res = aten._cdist_forward.default(to_meta(x1), to_meta(x2), p, compute_mode) + self.assertEqual(res.device.type, 'meta') + self.assertEqual(ref.shape, res.shape) + + # opinfo test is using aten.fill_, it's not testing aten.fill + @onlyCUDA + def test_fill_stride(self): + to_meta = MetaConverter() + sample_args = [torch.rand(2, 2, 2, 2), 1.0] + + for args in get_strided_args(sample_args): + meta_args = to_meta(args) + ref_out = torch.ops.aten.fill(*args) + meta_out = torch.ops.aten.fill(*meta_args) + self.assertEqual(ref_out.size(), meta_out.size()) + self.assertEqual(ref_out.stride(), meta_out.stride()) + + def test_map_location_deserialize(self): import io diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 04a213b1a13df..f4f427ba659c7 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -1060,6 +1060,11 @@ def test_transpose(self): x.to_mkldnn().transpose(dim1, dim2).to_dense(), ) + def test_transpose_invalid_dime(self): + x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn() + with self.assertRaisesRegex(IndexError, "Dimension out of range"): + torch._mkldnn_transpose(x, 0, 12) + def test_linear_non_contiguous_weight(self): in_features = torch.randint(3, 10, (1,)).item() out_features = torch.randint(3, 100, (1,)).item() diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index cdef4bcfd6a57..9f264337d9567 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -271,8 +271,8 @@ def forward(self, x, other): for pointwise_name, pointwise_fn in self._binary_list().items(): for dim in [2, 3]: channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d - options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]) - for bias, dilation, groups, memory_format in options: + options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]) + for fuse_relu, bias, dilation, groups, memory_format in options: oC = 32 * groups iC = 3 * groups x_shape = (1, iC) + input_shapes[dim] @@ -282,12 +282,26 @@ def forward(self, x, other): other = torch.randn_like(mod.conv(x)) with torch.no_grad(): ref = mod(x, other) + unary_attr = None + if fuse_relu: + ref.relu_() + unary_attr = "relu" attr = pointwise_name fused = torch.ops.mkldnn._convolution_pointwise( x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, - mod.conv.groups, attr + mod.conv.groups, attr, None, unary_attr, [], None ) - self.assertEqual(ref, fused) + # for binary add, we support inplace version. + if attr == "add": + fused_inplace = torch.ops.mkldnn._convolution_pointwise_( + x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, + mod.conv.groups, attr, None, unary_attr, [], None + ) + self.assertEqual(ref, other) + self.assertEqual(ref, fused_inplace) + + self.assertEqual(ref, fused) + def test_linear_binary_fusion_ops(self): class M(nn.Module): diff --git a/test/test_module_init.py b/test/test_module_init.py index dc05a95da6f2b..98dcb3ee694a4 100644 --- a/test/test_module_init.py +++ b/test/test_module_init.py @@ -4,7 +4,7 @@ import torch from unittest import mock from unittest.mock import MagicMock, patch -from torch.testing import floating_types +from torch.testing._internal.common_dtype import floating_types from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_utils import TestCase, run_tests diff --git a/test/test_modules.py b/test/test_modules.py index e06f0cc617d99..2f5008244d548 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -11,7 +11,8 @@ instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta) from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode from torch.testing._internal.common_utils import ( - TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps) + TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, + gradgradcheck, skipIfMps, skipIfTorchInductor) from unittest.mock import patch, call @@ -326,6 +327,7 @@ def inner_zero_grad(obj): @skipIfMps @modules(module_db) + @skipIfTorchInductor("to be fixed") def test_non_contiguous_tensors(self, device, dtype, module_info, training): # Check modules work with non-contiguous tensors @@ -489,6 +491,7 @@ def test_gradgrad(self, device, dtype, module_info, training): @toleranceOverride({torch.float32: tol(5e-2, 0), torch.float64: tol(4e-4, 0)}) @modules(module_db) + @skipIfTorchInductor("to be fixed") def test_cpu_gpu_parity(self, device, dtype, module_info, training): # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a # nicer way for eval mode only. @@ -579,6 +582,7 @@ def check_backward(cpu_output, gpu_output): @skipIfMps @modules(module_db) + @skipIfTorchInductor("to be fixed") def test_memory_format(self, device, dtype, module_info, training): is_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6) # TODO tighten it to a specific module diff --git a/test/test_mps.py b/test/test_mps.py index 9702239df95df..b70ff1c43fae4 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -15,25 +15,48 @@ import torch.nn as nn import torch.nn.functional as F import itertools +import yaml +import platform from collections import defaultdict from torch._six import inf from torch.nn import Parameter +from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ - (gradcheck, gradgradcheck, run_tests, TestCase, download_file, - TEST_WITH_UBSAN, dtype_abbrs) + (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, + TEST_WITH_UBSAN, dtype_abbrs, skipIfSlowGradcheckEnv, TEST_WITH_ASAN, suppress_warnings) from torch.testing import make_tensor from torch.testing._comparison import TensorLikePair from torch.testing._internal.common_dtype import get_all_dtypes, integral_types +import torch.mps import torch.backends.mps from torch.distributions import Uniform, Exponential -from functools import partial - -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests +from functools import partial, reduce + +from torch.testing._internal.common_methods_invocations import ( + op_db, + UnaryUfuncInfo, + ReductionOpInfo, + SpectralFuncInfo, + BinaryUfuncInfo, +) +from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests, onlyMPS from torch.testing._internal.common_nn import NNTestCase import numpy as np import torch import torch.utils._pytree as pytree +from itertools import product + + +# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref` +_ref_test_ops = tuple( + filter( + lambda op: not isinstance( + op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo) + ) + and op.ref is not None, + op_db, + ) +) # Same logic as test_cuda.py if not torch.backends.mps.is_available(): @@ -46,7 +69,7 @@ def _npRelu(self, np_features): return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype) def testNpRelu(self): - torch.testing.assert_allclose( + torch.testing.assert_close( np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]), self._npRelu( np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, @@ -60,7 +83,7 @@ def _testRelu(self, np_features, device): py_relu = torch.nn.ReLU(inplace=False)(py_tensor) py_relu_cpu = py_relu.to("cpu") - torch.testing.assert_allclose(np_relu, py_relu_cpu) + self.assertEqual(np_relu, py_relu_cpu) def _testReluInPlace(self, np_features, device): np_relu = self._npRelu(np_features) @@ -70,9 +93,9 @@ def _testReluInPlace(self, np_features, device): py_relu = torch.nn.ReLU(inplace=True)(py_tensor) py_relu_cpu = py_relu.to("cpu") - torch.testing.assert_allclose(np_relu, py_relu_cpu) + self.assertEqual(np_relu, py_relu_cpu) # Inplace Relu modifies the initial input and it should match the output of Relu - torch.testing.assert_allclose(np_relu, py_tensor.to("cpu")) + self.assertEqual(np_relu, py_tensor.to("cpu")) def testNumbersCPU(self): for t in [np.int32]: @@ -137,7 +160,7 @@ def _npLeakyRelu(self, np_features, negative_slope=0.1): return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype) def testNpLeakyRelu(self): - torch.testing.assert_allclose( + torch.testing.assert_close( np.array([[-0.09, 0.7, -0.05, 0.3, -0.01], [0.1, -0.03, 0.5, -0.07, 0.9]]), self._npLeakyRelu( @@ -152,14 +175,14 @@ def _testLeakyRelu(self, np_features, negative_slope, device): cpu_leaky_relu = relu_op(cpu_x) mps_leaky_relu = relu_op(mps_x) - torch.testing.assert_allclose(cpu_leaky_relu, mps_leaky_relu.to('cpu')) + torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu')) # test backward pass cpu_grad = torch.ones_like(cpu_leaky_relu) mps_grad = cpu_grad.to('mps') cpu_leaky_relu.backward(gradient=cpu_grad) mps_leaky_relu.backward(gradient=mps_grad) - torch.testing.assert_allclose(cpu_x.grad, mps_x.grad.to('cpu')) + torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu')) def testNumbersCPU(self): for t in [np.float32]: @@ -231,6 +254,17 @@ def test_exp1(self, device="mps", dtype=torch.float): input = torch.tensor([-0.1, 3.0, -0.9]).to('mps') output = torch.exp(input).to('cpu') + def test_exp_strided_output(self): + x = torch.rand((256, 10), device='mps') + x_cpu = x.to("cpu") + + x = x.permute(1, 0) + x_cpu = x_cpu.permute(1, 0) + + res = x.exp() + res_cpu = x_cpu.exp() + self.assertEqual(res, res_cpu) + def _testLeakyRelu(self, np_features, negative_slope, device): cpu_x = torch.from_numpy(np_features).requires_grad_() mps_x = torch.from_numpy(np_features).to('mps').requires_grad_() @@ -238,14 +272,14 @@ def _testLeakyRelu(self, np_features, negative_slope, device): cpu_leaky_relu = relu_op(cpu_x) mps_leaky_relu = relu_op(mps_x) - torch.testing.assert_allclose(cpu_leaky_relu, mps_leaky_relu.to('cpu')) + torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu')) # test backward pass cpu_grad = torch.ones_like(cpu_leaky_relu) mps_grad = cpu_grad.to('mps') cpu_leaky_relu.backward(gradient=cpu_grad) mps_leaky_relu.backward(gradient=mps_grad) - torch.testing.assert_allclose(cpu_x.grad, mps_x.grad.to('cpu')) + torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu')) def testNumbersGPU(self): for t in [np.float32]: @@ -274,14 +308,209 @@ def test_mm(self): B = torch.ones(5, 6).to("mps") C = torch.ones(6, 5).to("mps") D = torch.mm(B, C).cpu() - torch.testing.assert_allclose(D, torch.full((5, 5), 6.0)) + torch.testing.assert_close(D, torch.full((5, 5), 6.0)) + + def test_linalg_cross(self): + def helper(dtype): + device = "mps" + if dtype is torch.int32 or dtype is torch.int64: + x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device) + y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device) + else: + x = torch.rand(100, 3, 100, dtype=dtype, device=device) + y = torch.rand(100, 3, 100, dtype=dtype, device=device) + x_cpu = x.to("cpu") + y_cpu = y.to("cpu") + res1 = torch.linalg.cross(x, y, dim=1) + res2 = torch.tensor((), dtype=dtype, device=device) + res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1) + res2_cpu = torch.tensor((), dtype=dtype, device="cpu") + torch.linalg.cross(x, y, dim=1, out=res2) + torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu) + self.assertEqual(res1, res2) + self.assertEqual(res1, res1_cpu) + self.assertEqual(res2, res2_cpu) + + # test for broadcastable inputs + if dtype is torch.int32 or dtype is torch.int64: + x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device) + y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device) + else: + x = torch.rand(1, 3, 2, dtype=dtype, device=device) + y = torch.rand(4, 3, 1, dtype=dtype, device=device) + x_cpu = x.to("cpu") + y_cpu = y.to("cpu") + res1 = torch.linalg.cross(x, y, dim=1) + res2 = torch.tensor((), dtype=dtype, device=device) + res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1) + res2_cpu = torch.tensor((), dtype=dtype, device="cpu") + torch.linalg.cross(x, y, dim=1, out=res2) + torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu) + self.assertEqual(res1, res2) + self.assertEqual(res1, res1_cpu) + self.assertEqual(res2, res2_cpu) + [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]] + + def test_cdist_large(self, device="mps"): + for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + x = torch.randn(100, 10, device=device) + y = torch.randn(100, 10, device=device) + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertEqual(expected, actual) + + def test_cdist_large_batch(self, device="mps"): + for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + x = torch.randn(4, 3, 100, 10, device=device) + y = torch.randn(4, 3, 100, 10, device=device) + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertEqual(expected, actual) + + def test_cdist_non_contiguous(self, device="mps"): + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + x = torch.randn(5, 7, device=device).mT + y = torch.randn(5, 3, device=device).mT + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertEqual(expected, actual) + + x = torch.randn(7, 5, device=device) + y = torch.randn(5, 3, device=device).t() + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertTrue(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertEqual(expected, actual) + + x = torch.randn(5, 7, device=device).t() + y = torch.randn(3, 5, device=device) + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertTrue(y.is_contiguous()) + self.assertEqual(expected, actual) + + def test_cdist_non_contiguous_batch(self, device="mps"): + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + x = torch.randn(4, 3, 2, 5, 7, device=device).mT + y = torch.randn(4, 3, 2, 5, 3, device=device).mT + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertEqual(expected, actual) + + x = torch.randn(7, 2, 7, 5, device=device) + y = torch.randn(7, 2, 5, 3, device=device).mT + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertTrue(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertEqual(expected, actual) + + x = torch.randn(4, 5, 7, device=device).mT + y = torch.randn(4, 3, 5, device=device) + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertTrue(y.is_contiguous()) + self.assertEqual(expected, actual) + + def test_cdist_euclidean_large(self, device="mps"): + def _test_euclidean_large_cdist(sizex, sizey=None): + if sizey is None: + sizey = sizex + x = torch.randn(sizex, device=device, dtype=torch.float) + y = torch.randn(sizey, device=device, dtype=torch.float) + eps = 1e-6 + # to avoid extremum + x = x - (((x - y) < eps).float() * 2 * eps) + x.requires_grad = True + y.requires_grad = True + dist = torch.cdist(x, y, p=2) + # Do a backward pass to check that it is valid for large + # matrices + loss = dist.sum() + loss.backward() + + _test_euclidean_large_cdist((2000, 5)) + + def test_cdist_same_inputs(self, device="mps"): + # Test to detect issues in cdist gradient calculation + # When the distances are 0 + sizex = (1, 27, 32) + for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: + x = torch.randn(sizex, device=device, dtype=torch.float) + dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float) + y = x.clone() + eps = 1e-6 + x.requires_grad = True + d = torch.cdist(x, y) + d.backward(dist_grad) + # Check that the backward passs does not contain invalid + # values such as nan or inf + assert torch.isfinite(x.grad).all() + + + def _brute_cdist(self, x, y, p=2): + r1 = x.shape[-2] + r2 = y.shape[-2] + if r1 == 0 or r2 == 0: + return torch.empty(r1, r2, device=x.device) + return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1) + + def test_cdist_norm(self, device="mps"): + for r1 in [3, 4]: + for m in [2, 3]: + for r2 in [4, 6]: + for p in [0, 1, 1.5, 2.5, float('inf')]: + x = torch.randn(r1, m, device=device) + y = torch.randn(r2, m, device=device) + if p == 2: + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertEqual(expected, actual, rtol=0, atol=0.02) + else: + actual = torch.cdist(x, y, p=p) + expected = self._brute_cdist(x, y, p=p) + self.assertEqual(expected, actual) + + def test_cdist_norm_batch(self, device="mps"): + for r1 in [3, 4]: + for m in [2, 3]: + for r2 in [4, 6]: + for p in [0, 3, 1.5, 2.5, float('inf')]: + x = torch.randn(2, 3, 6, r1, m, device=device) + y = torch.randn(2, 3, 6, r2, m, device=device) + if p == 2: + for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: + actual = torch.cdist(x, y, p=2, compute_mode=cm) + expected = self._brute_cdist(x, y, p=2) + self.assertEqual(expected, actual, rtol=0, atol=0.02) + else: + actual = torch.cdist(x, y, p=p) + expected = self._brute_cdist(x, y, p=p) + self.assertEqual(expected, actual) + + def test_cross(self): + a = torch.randn(4, 3, device="mps") + b = torch.randn(4, 3, device="mps") + a_cpu = a.to("cpu") + b_cpu = b.to("cpu") + res = torch.cross(a, b, dim=1) + res_cpu = torch.cross(a_cpu, b_cpu, dim=1) + self.assertEqual(res, res_cpu) def test_addmm(self): A = torch.ones(5, 5).to("mps") B = torch.ones(5, 6).to("mps") C = torch.ones(6, 5).to("mps") D = torch.addmm(A, B, C).to("cpu") - torch.testing.assert_allclose(D, torch.full((5, 5), 7.0)) + torch.testing.assert_close(D, torch.full((5, 5), 7.0)) def test_bmm(self): batch1_cpu = torch.randn(10, 3, 4) @@ -296,6 +525,16 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + def test_trace(self): + M_cpu = torch.randn(3, 3) + M_mps = M_cpu.detach().clone().to("mps") + + output_cpu = torch.trace(M_cpu) + output_mps = torch.trace(M_mps) + + self.assertEqual(output_cpu, output_mps) + self.assertEqual(output_cpu.size(), output_mps.size()) + def test_addbmm(self): M_cpu = torch.randn(3, 5) batch1_cpu = torch.randn(10, 3, 4) @@ -336,7 +575,7 @@ def helper(input_shape, batch1_shape, batch2_shape): def test_local_scalar_dense_mps(self): x_cpu = torch.randn(1) y_mps = x_cpu.to("mps") - torch.testing.assert_allclose(x_cpu.item(), y_mps.item()) + torch.testing.assert_close(x_cpu.item(), y_mps.item()) def test_linear_1d_weight(self): device = 'cpu' @@ -459,12 +698,46 @@ def test_uniform(self): low.grad.zero_() high.grad.zero_() + def test_randperm(self, device="mps"): + rng_device = None + for n in (5, 100, 50000, 100000): + for dtype in (torch.long, torch.half, torch.float): + if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here. + continue + if n > 256 and dtype == torch.bfloat16: + continue + with torch.random.fork_rng(devices=rng_device): + res1 = torch.randperm(n, dtype=dtype, device=device) + res2 = torch.empty(0, dtype=dtype, device=device) + torch.randperm(n, out=res2, dtype=dtype, device=device) + self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device)) + + # Default type is long + for n in (100, 10000): + self.assertEqual(torch.randperm(n, device=device).dtype, torch.long) + + # randperm of 0 elements is an empty tensor + res1 = torch.randperm(0) + res2 = torch.tensor(5, dtype=dtype, device=device) + torch.randperm(0, out=res2) + self.assertEqual(res1.numel(), 0) + self.assertEqual(res2.numel(), 0) + + # Test non-contiguous tensors + for n in (4, 5, 6, 10, 20): + non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t() + self.assertFalse(non_contiguous_tensor.is_contiguous()) + with torch.random.fork_rng(devices=rng_device): + res = torch.randperm(n, dtype=torch.long, device=device) + torch.randperm(n, out=non_contiguous_tensor) + self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device)) + # Test forward maxpool2d def test_max_pool2d(self): def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False): cpu_x = None - if(test_ties): + if (test_ties): cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True) else: cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) @@ -473,7 +746,7 @@ def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=Fal pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=return_indices) - if(return_indices is False): + if (return_indices is False): y = pool(x) ref_y = pool(cpu_x) @@ -608,7 +881,7 @@ def helper(shape, channels_last=False): np.random.seed(332) arr = (256 - 128) * np.random.random_sample(size=shape) + 128 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -627,7 +900,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= np.random.seed(332) arr = (256 - 128) * np.random.random_sample(size=shape) + 128 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -637,7 +910,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= cpu_running_var = None running_mean = None running_var = None - if(track_running_stats): + if (track_running_stats): mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float) var_arr = 32 * np.random.random_sample(size=mean_shape) @@ -649,7 +922,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= cpu_weight = None bias = None cpu_bias = None - if(wts): + if (wts): cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) weight = cpu_weight.detach().clone().to('mps').requires_grad_() cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) @@ -658,7 +931,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= y = None ref_y = None - if(not test_module): + if (not test_module): y = torch.nn.functional.batch_norm(x, running_mean, running_var, weight=weight, bias=bias, @@ -675,7 +948,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= batchnorm_op = None mps_batchnorm_op = None - if(len(shape) == 3): + if (len(shape) == 3): batchnorm_op = torch.nn.BatchNorm1d(shape[1], eps=eps, momentum=momentum, @@ -688,7 +961,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= affine=wts, track_running_stats=track_running_stats, device='mps') - elif(len(shape) == 4): + elif (len(shape) == 4): batchnorm_op = torch.nn.BatchNorm2d(shape[1], eps=eps, momentum=momentum, @@ -701,7 +974,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= affine=wts, track_running_stats=track_running_stats, device='mps') - elif(len(shape) == 5): + elif (len(shape) == 5): batchnorm_op = torch.nn.BatchNorm3d(shape[1], eps=eps, momentum=momentum, @@ -715,12 +988,12 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= track_running_stats=track_running_stats, device='mps') - if(track_running_stats): + if (track_running_stats): batchnorm_op.running_mean = cpu_running_mean batchnorm_op.running_var = cpu_running_var mps_batchnorm_op.running_mean = running_mean mps_batchnorm_op.running_var = running_var - if(wts): + if (wts): batchnorm_op.weight = torch.nn.Parameter(cpu_weight) batchnorm_op.bias = torch.nn.Parameter(cpu_bias) mps_batchnorm_op.weight = torch.nn.Parameter(weight) @@ -730,7 +1003,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= y = mps_batchnorm_op(x) self.assertEqual(y, ref_y) - if(not test_module): + if (not test_module): self.assertEqual(running_mean, cpu_running_mean) self.assertEqual(running_var, cpu_running_var) else: @@ -743,8 +1016,8 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= y.backward(gradient=grad) self.assertEqual(x.grad, cpu_x.grad) - if(wts): - if(not test_module): + if (wts): + if (not test_module): self.assertEqual(weight.grad, cpu_weight.grad) self.assertEqual(bias.grad, cpu_bias.grad) else: @@ -755,10 +1028,10 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= for test_module in [False, True]: for track_running_stats in [True, False]: for channels_last in [False]: - if(channels_last and len(shape) != 4): + if (channels_last and len(shape) != 4): continue # Running stats must be tracked in eval mode - if(track_running_stats): + if (track_running_stats): helper(shape, eps=0, momentum=1, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, channels_last=channels_last, @@ -780,6 +1053,55 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) + def test_norm(self): + a = torch.arange(9, dtype=torch.float, device="mps") - 4 + b = a.reshape((3, 3)) + + a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4 + b_cpu = a_cpu.reshape((3, 3)) + + res = torch.norm(a) + res_cpu = torch.norm(a_cpu) + self.assertEqual(res, res_cpu) + + res = torch.norm(b) + res_cpu = torch.norm(b_cpu) + self.assertEqual(res, res_cpu) + + res = torch.norm(a, float('inf')) + res_cpu = torch.norm(a_cpu, float('inf')) + self.assertEqual(res, res_cpu) + + res = torch.norm(b, float('inf')) + res_cpu = torch.norm(b_cpu, float('inf')) + self.assertEqual(res, res_cpu) + + c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps") + c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu") + + res = torch.norm(c, dim=0) + res_cpu = torch.norm(c_cpu, dim=0) + self.assertEqual(res, res_cpu) + + res = torch.norm(c, dim=1) + res_cpu = torch.norm(c_cpu, dim=1) + self.assertEqual(res, res_cpu) + + res = torch.norm(c, p=1, dim=1) + res_cpu = torch.norm(c_cpu, p=1, dim=1) + self.assertEqual(res, res_cpu) + + d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2) + d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2) + + res = torch.norm(d, dim=(1, 2)) + res_cpu = torch.norm(d_cpu, dim=(1, 2)) + self.assertEqual(res, res_cpu) + + res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) + res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :]) + self.assertEqual(res, res_cpu) + def test_layer_norm(self): # TODO: Test non-contiguous def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32): @@ -793,7 +1115,7 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True) bias = cpu_bias.detach().clone().to('mps').requires_grad_() - if(elementwise_affine): + if (elementwise_affine): cpu_op.weight = torch.nn.Parameter(cpu_wt) mps_op.weight = torch.nn.Parameter(wt) cpu_op.bias = torch.nn.Parameter(cpu_bias) @@ -810,7 +1132,7 @@ def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dt self.assertEqual(result, cpu_result) self.assertEqual(x.grad, cpu_x.grad) - if(elementwise_affine): + if (elementwise_affine): self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad) self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad) @@ -826,7 +1148,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run np.random.seed(332) arr = (256 - 128) * np.random.random_sample(size=shape) + 128 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -836,7 +1158,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run cpu_running_var = None running_mean = None running_var = None - if(track_running_stats): + if (track_running_stats): mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float) var_arr = 32 * np.random.random_sample(size=mean_shape) @@ -848,7 +1170,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run cpu_weight = None bias = None cpu_bias = None - if(wts): + if (wts): cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) weight = cpu_weight.detach().clone().to('mps').requires_grad_() cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) @@ -857,7 +1179,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run y = None ref_y = None - if(not test_module): + if (not test_module): ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var, weight=cpu_weight, bias=cpu_bias, @@ -872,7 +1194,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run instancenorm_op = None mps_instancenorm_op = None - if(len(shape) == 3): + if (len(shape) == 3): instancenorm_op = torch.nn.InstanceNorm1d(shape[1], eps=eps, momentum=momentum, @@ -885,7 +1207,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run affine=wts, track_running_stats=track_running_stats, device='mps') - elif(len(shape) == 4): + elif (len(shape) == 4): instancenorm_op = torch.nn.InstanceNorm2d(shape[1], eps=eps, momentum=momentum, @@ -898,7 +1220,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run affine=wts, track_running_stats=track_running_stats, device='mps') - elif(len(shape) == 5): + elif (len(shape) == 5): instancenorm_op = torch.nn.InstanceNorm3d(shape[1], eps=eps, momentum=momentum, @@ -912,12 +1234,12 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run track_running_stats=track_running_stats, device='mps') - if(track_running_stats): + if (track_running_stats): instancenorm_op.running_mean = cpu_running_mean instancenorm_op.running_var = cpu_running_var mps_instancenorm_op.running_mean = running_mean mps_instancenorm_op.running_var = running_var - if(wts): + if (wts): instancenorm_op.weight = torch.nn.Parameter(cpu_weight) instancenorm_op.bias = torch.nn.Parameter(cpu_bias) mps_instancenorm_op.weight = torch.nn.Parameter(weight) @@ -927,7 +1249,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run y = mps_instancenorm_op(x) self.assertEqual(y, ref_y) - if(not test_module): + if (not test_module): self.assertEqual(running_mean, cpu_running_mean) self.assertEqual(running_var, cpu_running_var) else: @@ -940,8 +1262,8 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run y.backward(gradient=grad) self.assertEqual(x.grad, cpu_x.grad) - if(wts): - if(not test_module): + if (wts): + if (not test_module): self.assertEqual(weight.grad, cpu_weight.grad) self.assertEqual(bias.grad, cpu_bias.grad) else: @@ -952,10 +1274,10 @@ def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_run for test_module in [False, True]: for track_running_stats in [True, False]: for channels_last in [False]: - if(channels_last and len(shape) != 4): + if (channels_last and len(shape) != 4): continue # Running stats must be tracked in eval mode - if(track_running_stats): + if (track_running_stats): helper(shape, eps=0, momentum=1, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, channels_last=channels_last, @@ -993,7 +1315,7 @@ def helper(input_shape, wt_shape, cpu_bias = None bias = None - if(bias_shape is not None): + if (bias_shape is not None): cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True) bias = cpu_bias.detach().clone().to('mps').requires_grad_() @@ -1011,7 +1333,7 @@ def helper(input_shape, wt_shape, self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04) self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05) self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05) - if(bias_shape is not None): + if (bias_shape is not None): self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05) N = 1 @@ -1067,7 +1389,7 @@ def helper(input_shape, wt_shape, cpu_bias = None bias = None - if(bias_shape is not None): + if (bias_shape is not None): cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True) bias = cpu_bias.detach().clone().to('mps').requires_grad_() @@ -1087,7 +1409,7 @@ def helper(input_shape, wt_shape, self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05) self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05) - # if(bias_shape is not None): + # if (bias_shape is not None): # print(cpu_bias.grad) # print(bias.grad.to('cpu')) # self.assertEqual(bias.grad, cpu_bias.grad) @@ -1106,7 +1428,7 @@ def helper(input_shape, wt_shape, for padding in [0, 1, 2]: for output_padding in [0, 1, 2]: for dilation in [1, 2]: - if(output_padding >= stride or output_padding >= dilation): + if (output_padding >= stride or output_padding >= dilation): continue helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding, output_padding=output_padding, dilation=dilation) @@ -1290,6 +1612,78 @@ def test_expand_cpu_to_mps_copy(self): self.assertEqual(x_cpu, x.cpu()) + def test_cpu_to_strided_mps_copy(self): + # https://github.com/pytorch/pytorch/issues/86975 + + a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps")) + b1 = torch.Tensor([-1, -1]) + a1[1:, 1] = b1 + + a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps")) + b2 = torch.Tensor([-1, -1]).to(torch.device("mps")) + a2[1:, 1] = b2 + + self.assertEqual(a1, a2) + + def test_view_slice_reshape(self): + x = torch.randn([1, 4, 4], device="mps") + y = x[0, :1, 1:] + + x_cpu = x.to("cpu") + y_cpu = x_cpu[0, :1, 1:] + + r = y + 1 + r_cpu = y_cpu + 1 + self.assertEqual(r, r_cpu) + + def test_slice_reshape(self): + x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps") + x_cpu = x.detach().clone().to("cpu") + + x = x[:, 3:].view(2, 3, 4, 1) + x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1) + self.assertEqual(x, x_cpu) + + x = x + 2 + x_cpu = x_cpu + 2 + self.assertEqual(x, x_cpu) + + def test_slice_reshape_contg_view(self): + import torch + + x_mps = torch.randn(1, 4800, 2, device="mps") + x_cpu = x_mps.detach().clone().cpu() + + r_mps = x_mps + 2 + r_cpu = x_cpu + 2 + + self.assertEqual(r_mps, r_cpu) + + def test_view_slice(self): + # https://github.com/pytorch/pytorch/issues/83995 + NUM_SAMPLES = 60 + s = (0, 1) + + X = torch.rand(8000, 3, dtype=torch.float32, device='cpu') + X_mps = X.detach().clone().to("cpu") + + idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s)) + pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1])) + idx_mps = idx.to("mps") + pts_mps = pts.to("mps") + pts[:, s] = idx + pts_mps[:, s] = idx_mps + + actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float) + actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps") + + for i in range(NUM_SAMPLES): + for j in range(X.shape[1]): + actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j] + actual_pts[i, j] = X[pts[i, j], j] + self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j]) + + def test_slice(self): values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] cpu_x = torch.tensor(values, device='cpu') @@ -1613,9 +2007,10 @@ def test_full_bugs(self): # See https://github.com/pytorch/pytorch/issues/84995 def test_div_bugs(self): for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']): - x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype) - y = torch.div(x, 101, rounding_mode=mode) - self.assertEqual(y.sum(), 0) + if dtype != torch.int64: + x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype) + y = torch.div(x, 101, rounding_mode=mode) + self.assertEqual(y.sum(), 0) # See https://github.com/pytorch/pytorch/issues/82663 def test_bool_expand(self): @@ -1629,6 +2024,156 @@ def test_empty_neg(self): y = -x self.assertEqual(x, y) + def _test_unique_scalar_empty(self, dtype, device, f): + # test scalar + x = torch.tensor(0, dtype=dtype, device=device) + unique, inverse, counts = f(x, return_inverse=True, return_counts=True) + expected_unique = torch.tensor([0], dtype=dtype, device=device) + expected_inverse = torch.tensor(0, device=device) + expected_counts = torch.tensor([1], device=device) + self.assertEqual(unique, expected_unique) + self.assertEqual(inverse, expected_inverse) + self.assertEqual(counts, expected_counts) + + # test zero sized tensor + x = torch.zeros((0, 0, 3), dtype=dtype, device=device) + unique, inverse, counts = f(x, return_inverse=True, return_counts=True) + expected_unique = torch.tensor([], dtype=dtype, device=device) + expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device) + expected_counts = torch.tensor([], dtype=torch.long, device=device) + self.assertEqual(unique, expected_unique) + self.assertEqual(inverse, expected_inverse) + self.assertEqual(counts, expected_counts) + + def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): + def ensure_tuple(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + for return_inverse in [True, False]: + for return_counts in [True, False]: + # test with expected + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) + self.assertEqual(expected_unique, ret[0]) + if return_inverse: + self.assertEqual(expected_inverse, ret[1]) + if return_counts: + count_index = 1 + int(return_inverse) + self.assertEqual(expected_counts, ret[count_index]) + + # tests per-element unique on a higher rank tensor. + y = x.view(additional_shape) + y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(additional_shape), y_inverse) + self.assertEqual(expected_counts, y_counts) + + def test_unique_all_dtypes(self, device="mps"): + def helper(dtype): + def ensure_tuple(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + if dtype is torch.bool: + x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) + expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) + expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) + expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) + else: + x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) + expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device) + expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) + expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) + + # test sorted unique + fs = ( + lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs), + lambda x, **kwargs: x.unique(sorted=True, **kwargs), + ) + x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) + xs = (x, x_sliced) + for f, x in product(fs, xs): + self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) + self._test_unique_scalar_empty(dtype, device, f) + + # test unsorted unique + fs = ( + lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), + lambda x, **kwargs: x.unique(sorted=False, **kwargs) + ) + for f, x in product(fs, xs): + self._test_unique_scalar_empty(dtype, device, f) + for return_inverse, return_counts in product((True, False), repeat=2): + ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) + self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) + x_list = x.tolist() + x_unique_list = ret[0].tolist() + self.assertEqual(expected_unique.tolist(), sorted(x_unique_list)) + if return_inverse: + x_inverse_list = ret[1].tolist() + for i, j in enumerate(x_inverse_list): + self.assertEqual(x_list[i], x_unique_list[j]) + if return_counts: + count_index = 1 + int(return_inverse) + x_counts_list = ret[count_index].tolist() + for i, j in zip(x_unique_list, x_counts_list): + count = 0 + for k in x_list: + if k == i: + count += 1 + self.assertEqual(j, count) + [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]] + + def test_unique(self): + def helper(x, return_inverse, return_counts): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts) + result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts) + + self.assertEqual(result, result_cpu) + helper(torch.tensor([1, 2, 4, 2, 1]), False, False) + helper(torch.randint(3, (10,)), False, False) + helper(torch.randint(3, (10,)), True, False) + helper(torch.randint(3, (10,)), False, True) + helper(torch.randint(3, (10,)), True, True) + helper(torch.randint(3, (1,)), True, True) + helper(torch.randint(3, (0,)), True, True) + + def test_unique_consecutive(self): + def helper(x, dim, return_inverse, return_counts): + cpu_x = x + x = cpu_x.detach().clone().to('mps') + + result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts) + result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts) + + self.assertEqual(result, result_cpu) + helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False) + helper(torch.randint(3, (10,)), 0, False, False) + helper(torch.randint(3, (10,)), 0, True, False) + helper(torch.randint(3, (10,)), 0, False, True) + helper(torch.randint(3, (10,)), 0, True, True) + helper(torch.randint(3, (10,)), 0, True, True) + helper(torch.randint(3, (1,)), 0, True, True) + helper(torch.randint(3, (0,)), 0, True, True) + + helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False) + helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True) + helper(torch.randint(2, (20, 2)), 0, True, True) + helper(torch.randint(2, (1, 2)), 0, True, True) + helper(torch.randint(2, (0, 2)), 0, True, True) + + helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False) + helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True) + helper(torch.randint(2, (2, 20)), 1, True, True) + helper(torch.randint(2, (2, 1)), 1, True, True) + helper(torch.randint(2, (2, 0)), 1, True, True) + # See https://github.com/pytorch/pytorch/issues/85675 def test_cat_non_contiguous(self): def rotate_subset(data): @@ -1641,6 +2186,7 @@ def rotate_subset(data): cpu_result = rotate_subset(data) mps_result = rotate_subset(mps_data) self.assertEqual(cpu_result, mps_result.to("cpu")) + self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous()) # See https://github.com/pytorch/pytorch/issues/85967 def test_from_numpy_non_contiguous(self): @@ -1649,6 +2195,77 @@ def test_from_numpy_non_contiguous(self): t_mps = torch.tensor(a, device="mps") self.assertEqual(t_cpu, t_mps.to("cpu")) + def test_cumsum_all_dtypes(self): + def helper(dtype): + t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype) + t_cpu = torch.tensor([1, 1, 1, 1], device="cpu") + + a = t.cumsum(0, dtype=dtype) + a_cpu = t_cpu.cumsum(0, dtype=dtype) + + self.assertEqual(a.cpu(), a_cpu) + [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]] + + try: + helper(torch.int64) + except Exception as e: + e_string = str(e) + self.assertEqual(e_string, "MPS does not support cumsum op with int64 input") + + def test_gelu_tanh(self): + def helper(shape): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float) + x = cpu_x.detach().clone().to('mps') + + gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh') + gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh') + self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu) + + helper((2, 8, 4, 5)) + + # # Failures due to precision issues, enable after resolving from mps + # def test_div_floor_int(self): + # def helper(shape, dtype): + # cpu_x = torch.randint(-9999, -1,shape, device='cpu', dtype=dtype) + # x = cpu_x.detach().clone().to('mps') + + # cpu_y = torch.randint(1, 9999, shape, device='cpu', dtype=dtype) + # y = cpu_y.detach().clone().to('mps') + + # div_result = torch.div(x, y,rounding_mode='floor') + # div_result_cpu = torch.div(cpu_x, cpu_y, rounding_mode='floor') + # self.assertEqual(div_result, div_result_cpu) + + # helper((2, 8, 4, 5), torch.int16) + # helper((2, 8, 4, 5), torch.int32) + + def test_median_int16(self): + def helper(shape, dtype): + cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype) + x = cpu_x.detach().clone().to('mps') + + median_result = torch.median(x) + median_result_cpu = torch.median(cpu_x) + self.assertEqual(median_result, median_result_cpu) + + helper((2, 8, 4, 5), torch.int16) + + def test_cumsum_minus_one_axis(self): + def helper(dtype): + # Test with axis -1 + cpu_x = None + if(dtype == torch.float32): + cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32) + else: + cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32) + x = cpu_x.detach().clone().to('mps') + + cpu_y = cpu_x.cumsum(-1) + y = x.cumsum(-1) + + self.assertEqual(y, cpu_y) + + [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]] class TestLogical(TestCase): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): @@ -1743,6 +2360,24 @@ def helper(x, other): helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) + def test_min_max(self): + def helper(dtype): + for _ in range(10): + if dtype == torch.float32 or dtype == torch.float16: + x = torch.randn((30, 15), device='mps', dtype=dtype) + else: + x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype) + x_cpu = x.to("cpu") + + y = x.max() + y_cpu = x_cpu.max() + self.assertEqual(y, y_cpu) + + z = x.min() + z_cpu = x_cpu.min() + self.assertEqual(z, z_cpu) + + [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]] class TestSmoothL1Loss(TestCase): @@ -1784,6 +2419,20 @@ def test_smooth_l1_loss_reduction_mean_sum_backward(self): class TestNLLLoss(TestCase): + def test_nll2d_loss_backward(self, device='mps'): + a = torch.randn(3, 5, requires_grad=True, device=device) + b = torch.tensor([1, 0, 4], device=device) + loss = nn.NLLLoss() + out = loss(a, b) + self.assertIsNone(out.grad_fn._saved_weight) + loss = nn.NLLLoss(weight=torch.ones((5,), device=device)) + out = loss(a, b) + self.assertEqual(out.grad_fn._saved_weight, torch.ones((5,))) + + out.sum().backward() + with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): + out.grad_fn._saved_weight + def test_nll_loss_mismatched_batch(self, device='mps'): x = torch.randn((10, 3), requires_grad=True, device=device) # t should have size (10,) @@ -1845,16 +2494,17 @@ def _nll_loss_helper(self, input_size, reduction, expected): input = torch.rand(input_size, requires_grad=True, device='cpu') num_channels = input_size[1] target_size = (input_size[0], ) + tuple(input_size[2:]) + weights = torch.randn(num_channels) + weights_mps = weights.to("mps") target = torch.randint(num_channels, target_size, device='cpu') # MPS input_mps = input.detach().clone().to('mps').requires_grad_() target_mps = target.detach().clone().to('mps') - output_cpu = F.nll_loss(input, target, reduction=reduction) - output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu')) + output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction) + output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction) + self.assertEqual(output_cpu, output_mps.to('cpu')) output_cpu.sum().backward() output_mps.sum().backward() @@ -1873,8 +2523,7 @@ def _nll_loss_1d_helper(self, input_size, reduction): output_cpu = F.nll_loss(input, target, reduction=reduction) output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu')) + self.assertEqual(output_cpu, output_mps.to('cpu')) output_cpu.sum().backward() output_mps.sum().backward() @@ -1904,7 +2553,156 @@ def test_as_strided(self): strided_mps_out = strided_mps1 - strided_mps2 self.assertEqual(strided_cpu_out, strided_mps_out) + def test_unfold(self): + x = torch.arange(1., 8) + x_mps = torch.arange(1., 8, device="mps") + + y = x.unfold(0, 2, 1) + y_mps = x_mps.unfold(0, 2, 1) + + self.assertEqual(y, y_mps) + + def test_unfold_all_devices_and_dtypes(self): + supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8] + for dt in supported_dtypes: + x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps") + self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) + + def test_unfold_scalars(self): + x = torch.tensor(0.5, device="mps") + # unfold on a 0-dimensional tensor should always return a 1-d dimensional + # tensor of shape [size] (i.e., the second parameter to unfold) + + self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1)) + self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2)) + self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1)) + + def test_bincount_simple(self): + input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps") + input_cpu = input.to("cpu") + weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32) + weights_cpu = weights.to("cpu") + + x = torch.bincount(input) + x_cpu = torch.bincount(input_cpu) + self.assertEqual(x, x_cpu) + + y = input.bincount(weights) + y_cpu = input_cpu.bincount(weights_cpu) + self.assertEqual(y, y_cpu) + + def test_bincount_reduction(self): + device = "mps" + # negative input throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32)) + # n-d input, with n > 1 throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) + # minlength < 0 throws + with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): + torch.bincount(torch.tensor([1, 3], device=device), + torch.tensor([.2, .2], device=device), + minlength=-1) + # n-d weights, with n > 1 throws + with self.assertRaisesRegex(RuntimeError, '1-d'): + torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32), + torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float)) + # input and weights dim mismatch + with self.assertRaisesRegex(RuntimeError, 'same length'): + torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32), + torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float)) + # 1-d input with no elements and default minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), + torch.zeros(0, dtype=torch.long, device=device)) + # 1-d input with no elements and specified minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), + torch.zeros(10, dtype=torch.long, device=device)) + + # test tensor method without weights + long_counts = torch.tensor( + [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() + self.assertEqual( + torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), + long_counts) + # test avoiding overflow for uint8 (#76979) + count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount() + count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount() + self.assertEqual(count_uint8, count_int16) + # test minlength functionality + int_counts = torch.bincount( + torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5) + self.assertEqual( + torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), + int_counts) + # test weights + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32), + torch.tensor([.1, .2, .3, .4, .5], device=device)) + self.assertEqual( + torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32), + torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) + self.assertEqual( + torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts) + # test non-contiguous inputs and weights + inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32) + weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) + for i in [0, 1]: + assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" + assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" + # inputs are non-contiguous but weights are contiguous + self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) + # inputs and weights are non-contiguous + self.assertEqual( + inputs[:, 1].bincount(weights[:, 1]), + torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) + # weights are non-contiguous but inputs are contiguous + self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), + torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) + + # test bincount on non-contiguous slices + all0s = torch.zeros((32, 2), dtype=torch.int32, device=device) + self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) + + all1s = torch.ones((32, 2), dtype=torch.int32, device=device) + self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) + + # test large number of bins - global memory use + big_exp = torch.zeros(100, device=device) + big_exp[-1] = 50.0 + big_w = torch.tensor([.5] * 100, device=device) + big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w) + self.assertEqual(big_exp, big_out) + # test large input size + big_exp = torch.zeros(2, device=device, dtype=torch.int64) + big_exp[1] = 10 + big_out = torch.ones(10, dtype=torch.int8, device=device).bincount() + self.assertEqual(big_exp, big_out) + + def test_bincount(self): + device = "mps" + input_size = (5000,) + w = torch.randn(input_size, dtype=torch.float, device=device) + w_cpu = w.cpu() + + t = torch.randint(50, input_size, dtype=torch.int8, device=device) + self.assertEqual(t.cpu().bincount(), t.bincount()) + self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) + + t = torch.randint(500, input_size, dtype=torch.int32, device=device) + self.assertEqual(t.cpu().bincount(), t.bincount()) + self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) + + t = torch.randint(2000, input_size, dtype=torch.int32, device=device) + self.assertEqual(t.cpu().bincount(), t.bincount()) + self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) + t = torch.zeros([10], dtype=torch.int32, device=device) + t[0] = 35488 + counted = t.bincount(minlength=65536) + self.assertEqual(torch.sum(counted), 10) def test_sum_backward(self): def helper(n, c): @@ -2239,6 +3037,14 @@ def test_eq(self): self.assertEqual(result_cpu, result_mps.to('cpu')) + def test_signed_vs_unsigned_comparison(self): + cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8) + mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8) + # in the comparison of signed vs. unsigned we should always cast to unsigned + self.assertEqual(cpu_x == -1, mps_x == -1) + self.assertEqual(cpu_x > -1, mps_x > -1) + self.assertEqual(cpu_x < -1, mps_x < -1) + def test_eq_int64(self): values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]] @@ -2381,7 +3187,7 @@ def helper(n, c, h, w, reduction_type, dtype=torch.float32): cpu_x = None x = None - if(dtype not in [torch.float32, torch.bool]): + if (dtype not in [torch.float32, torch.bool]): cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False) x = cpu_x.detach().clone().to('mps') elif (dtype == torch.bool): @@ -2441,7 +3247,7 @@ def helper(n, c, h, w, reduction_type, dtype=torch.float32): def test_max_el(self): def helper(n, c, h, w, dtype=torch.float32): - if(dtype not in [torch.float32, torch.bool]): + if (dtype not in [torch.float32, torch.bool]): cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False) x = cpu_x.detach().clone().to('mps') elif (dtype == torch.bool): @@ -2522,6 +3328,47 @@ def helper(n, c, h, w, dtype=torch.float32): helper(2, 8, 4, 5, torch.int32) # helper(2, 8, 4, 5, torch.int64) + def test_median(self): + def helper_dtype_int32(n1, n2, n3): + cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32) + mps_x = cpu_x.detach().clone().to('mps') + + result_cpu = torch.median(cpu_x) + result_mps = torch.median(mps_x) + + self.assertEqual(result_cpu, result_mps) + + for dim in [0, 1, 2]: + for keepdim in [True, False]: + y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim) + refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim) + self.assertEqual(y, refy) + self.assertEqual(idx, refidx) + + def helper_dtype_float32(n1, n2, n3): + cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32) + mps_x = cpu_x.detach().clone().to('mps') + + result_cpu = torch.median(cpu_x) + result_mps = torch.median(mps_x) + + self.assertEqual(result_cpu, result_mps) + + for dim in [0, 1, 2]: + for keepdim in [True, False]: + y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim) + refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim) + self.assertEqual(y, refy) + self.assertEqual(idx, refidx) + + helper_dtype_int32(10, 10, 10) # median at even place + helper_dtype_int32(3, 3, 3) # median at odd place + helper_dtype_int32(1, 1, 1) + helper_dtype_int32(1, 2, 3) + helper_dtype_float32(10, 10, 10) + helper_dtype_float32(3, 3, 3) + helper_dtype_float32(1, 1, 1) + def test_any(self): def helper(shape): input_xs = [] @@ -2765,7 +3612,7 @@ def test_sum(self): def helper(n, c, h, w, dtype=torch.float32): cpu_x = None x = None - if(dtype not in [torch.float32, torch.bool]): + if (dtype not in [torch.float32, torch.bool]): cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False) x = cpu_x.detach().clone().to('mps') elif (dtype == torch.bool): @@ -2830,7 +3677,7 @@ def test_prod(self): def helper(shape, dtype=torch.float32): cpu_x = None x = None - if(dtype not in [torch.float32, torch.bool]): + if (dtype not in [torch.float32, torch.bool]): cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False) x = cpu_x.detach().clone().to('mps') elif (dtype == torch.bool): @@ -3212,26 +4059,33 @@ def helper(n, c, h, w): def test_divmode(self): def helper(shape, rounding_mode): for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: - cpu_x = None - cpu_y = None - if(dtype in [torch.float32, torch.float16]): - cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) - cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) - else: - cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) - cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) - - mps_x = cpu_x.detach().clone().to('mps') - # clamp to avoid division by 0 - mps_y = cpu_y.detach().clone().to('mps') - - result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode) - result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode) - self.assertEqual(result_div_mps, result_div_cpu) + if (rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) is False: + cpu_x = None + cpu_y = None + if (dtype in [torch.float32, torch.float16]): + cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + else: + cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + + mps_x = cpu_x.detach().clone().to('mps') + # clamp to avoid division by 0 + mps_y = cpu_y.detach().clone().to('mps') + + if (rounding_mode == "floor_divide"): + result_div_cpu = torch.floor_divide(cpu_x, cpu_y) + result_div_mps = torch.floor_divide(mps_x, mps_y) + self.assertEqual(result_div_mps, result_div_cpu) + else: + result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode) + result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode) + self.assertEqual(result_div_mps, result_div_cpu) helper((2, 8, 4, 5), None) helper((2, 8, 4, 5), "floor") helper((2, 8, 4, 5), "trunc") + helper((2, 8, 4, 5), "floor_divide") def test_rounding(self): def helper(shape): @@ -3270,6 +4124,13 @@ def helper(n, c): helper(3, 1) + def test_im2col(self): + def helper(x): + return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3) + x_cpu = torch.rand(1, 1, 200, 100) + x = x_cpu.detach().clone().to('mps') + self.assertEqual(helper(x_cpu), helper(x)) + def test_select(self): def helper(n, c): cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True) @@ -3325,26 +4186,6 @@ def helper(shape): helper((1, 5)) helper((5, 9, 7, 4)) - def test_upsample_nearest_exact2d(self): - def helper(N, C, H, W): - inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, - requires_grad=True).reshape(N, C, H, W) - inputCPU.retain_grad() - inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() - - outputCPU = torch.nn.functional.interpolate(inputCPU, size=(5, 5), mode='nearest-exact') - outputMPS = torch.nn.functional.interpolate(inputMPS, size=(5, 5), mode='nearest-exact') - - self.assertEqual(outputCPU, outputMPS) - - outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3)) - outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3)) - - self.assertEqual(inputCPU.grad, inputMPS.grad) - - helper(1, 1, 4, 4) - helper(7, 5, 3, 2) - def test_upsample_nearest2d(self): def helper(N, C, H, W): inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, @@ -3410,19 +4251,49 @@ def helper(N, C, H, W): helper(1, 1, 4, 4) helper(7, 5, 3, 2) - def test_upsample_nearest1d(self): - def helper(N, C, H, W): - inputCPU = torch.arange(C * H * W, device='cpu', dtype=torch.float, - requires_grad=True).reshape(C, H, W) - inputMPS = inputCPU.detach().clone().to('mps') + def test_interpolate(self): + def helper(shape, output_size, scales, mode, align_corners=False): + inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) + inputCPU.retain_grad() + inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() - outputCPU = torch.nn.functional.interpolate(inputCPU, scale_factor=2.0, mode='nearest') - outputMPS = torch.nn.functional.interpolate(inputMPS, scale_factor=2.0, mode='nearest') + # align_corners is used for 2D interpolation only + if (align_corners is True and len(shape) > 3 and mode == 'bilinear'): + if scales is not None: + outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners) + outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners) + else: + outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners) + outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners) + elif scales is not None: + outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode) + outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode) + else: + outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode) + outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode) self.assertEqual(outputCPU, outputMPS) - helper(1, 1, 4, 4) - helper(7, 5, 3, 2) + # backward pass (chose 0.6 just to have the grad_output != 1) + outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6)) + outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6)) + self.assertEqual(inputCPU.grad, inputMPS.grad) + + # 1D interpolation + for mode in ['nearest', 'nearest-exact']: + helper([2, 3, 4], [3], None, mode) # downsample with size + helper([2, 3, 4], [6], None, mode) # upsample with size + helper([2, 3, 4], None, [0.6], mode) # downsample with scale factor + helper([2, 3, 4], None, [1.7], mode) # upsample with scale factor + # 2D interpolation + for mode in ['nearest', 'nearest-exact', 'bilinear']: + helper([2, 3, 4, 5], [3, 4], None, mode) # downsample_nearest with size + helper([2, 3, 4, 5], [6, 7], None, mode) # upsample_nearest with size + helper([2, 3, 4, 5], None, [0.6, 0.7], mode) # downsample_nearest with scale factor + helper([2, 3, 4, 5], None, [1.4, 1.7], mode) # upsample_nearest with scale factor + # align_corners=True + helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True) + helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True) # Test concat forward def test_cat1(self): @@ -3459,6 +4330,15 @@ def test_constant_pad(self): r_mps = m(input_mps) self.assertEqual(r_cpu, r_mps.to("cpu")) + # Arbitrary input dimensions + pad = (1, 1, 0, 0, 0, 0) + value = 3.5 + input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3)) + input_mps = input_cpu.detach().clone().to("mps") + r_cpu = F.pad(input_cpu, pad=pad, value=value) + r_mps = F.pad(input_mps, pad=pad, value=value) + self.assertEqual(r_cpu, r_mps.to("cpu")) + def test_circular_pad(self): # https://github.com/pytorch/pytorch/issues/80856 k_cpu = torch.ones(3, 3, 9, 9) @@ -3520,10 +4400,20 @@ def helper(shape, padding, op, value=0): helper((2, 1, 6, 8), 2, nn.ReplicationPad2d) # verify if a change in shape of padding would cause problems with graph caching helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d) + # negative padding + helper((1, 3, 4, 4), (-1, 1, -2, 1), nn.ReplicationPad2d) # Constant Pad 2D helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d) # input size < pad size helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d) + # pad dims < input dims + helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d) + # pad dims == input dims + helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d) + # input.numel() == 0 but output.numel() > 0 + helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d) + # pad dims < input dims - 2 + helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d) # 3D Padding helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d) @@ -3531,6 +4421,10 @@ def helper(shape, padding, op, value=0): helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d) # Constant Pad 3D helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) + # check the workaround for the right padding bug in Monterey + helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d) + # input size < pad size + helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) # Test stack forward def test_stack(self): @@ -3541,7 +4435,7 @@ def helper(shape, dtype=torch.float32): y, cpu_y = None, None z, cpu_z = None, None - if(dtype not in [torch.float32, torch.bool]): + if (dtype not in [torch.float32, torch.bool]): cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) x = cpu_x.detach().clone().to('mps') cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) @@ -3813,12 +4707,12 @@ def helper(shape, dim=0): # Test softplus def test_softplus(self): - def helper(shape): + def helper(shape, beta=1, threshold=20): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) x = cpu_x.detach().clone().to('mps').requires_grad_() - softplus_result = torch.nn.Softplus(beta=0.5, threshold=0.5)(x) - softplus_result_cpu = torch.nn.Softplus(beta=0.5, threshold=0.5)(cpu_x) + softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x) + softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x) cpu_grad = torch.randn(softplus_result.shape) grad = cpu_grad.to('mps') @@ -3831,7 +4725,9 @@ def helper(shape): # Test empty shape too for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]: - helper(shape) + for beta in [0.5, 1, 2, 3, 4]: + for threshold in [0.5, 20, 30, 40, 50]: + helper(shape, beta, threshold) # Test silu @@ -3884,7 +4780,7 @@ def helper(src_dtype, dst_dtype): def test_adaptive_avg_pool2d_simple(self): def helper(input_shape, out_shape, channels_last): cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -3929,11 +4825,11 @@ def helper(input_shape, out_shape, channels_last): def test_adaptive_max_pool2d_simple(self): def helper(input_shape, out_shape, return_indices, dtype, channels_last=False): cpu_x = None - if(dtype in [torch.float16, torch.float32]): + if (dtype in [torch.float16, torch.float32]): cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True) else: cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -3941,7 +4837,7 @@ def helper(input_shape, out_shape, return_indices, dtype, channels_last=False): max_result, max_indices = None, None max_result_cpu, max_indices_cpu = None, None - if(return_indices): + if (return_indices): max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x) max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x) else: @@ -3955,7 +4851,7 @@ def helper(input_shape, out_shape, return_indices, dtype, channels_last=False): max_result_cpu.backward(gradient=cpu_grad) self.assertEqual(max_result, max_result_cpu) - if(return_indices): + if (return_indices): self.assertEqual(max_indices, max_indices_cpu) self.assertEqual(x.grad, cpu_x.grad) @@ -4032,7 +4928,7 @@ def helper(shape, min_val, max_val, inplace=False): cpu_x = None x = None - if(not inplace): + if (not inplace): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) x = cpu_x.detach().clone().to('mps').requires_grad_() else: @@ -4044,7 +4940,7 @@ def helper(shape, min_val, max_val, inplace=False): self.assertEqual(hardtanh_result, hardtanh_result_cpu) - if(not inplace): + if (not inplace): cpu_grad = torch.randn(hardtanh_result_cpu.shape) grad = cpu_grad.to('mps') hardtanh_result.backward(gradient=grad) @@ -4057,6 +4953,38 @@ def helper(shape, min_val, max_val, inplace=False): helper(shape, min_val, max_val) helper(shape, min_val, max_val, inplace=True) + def test_hardswish(self): + def helper(shape, inplace=False, requires_grad=True): + m = nn.Hardswish(inplace=inplace) + + input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad) + input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad) + + if inplace and requires_grad: # check that both raise runtime error + self.assertRaises(RuntimeError, lambda: m(input_cpu)) + self.assertRaises(RuntimeError, lambda: m(input_mps)) + return + + output_cpu = m(input_cpu) + output_mps = m(input_mps) + + cpu_grad = torch.ones_like(output_cpu) + mps_grad = cpu_grad.to('mps') + + self.assertEqual(output_cpu, output_mps) + + if requires_grad: + output_cpu.backward(gradient=cpu_grad) + output_mps.backward(gradient=mps_grad) + + self.assertEqual(input_cpu.grad, input_mps.grad) + + for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]: + helper(shape, inplace=False, requires_grad=False) + helper(shape, inplace=True, requires_grad=False) + helper(shape, inplace=False, requires_grad=True) + helper(shape, inplace=True, requires_grad=True) + def test_transpose_2D(self): values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] @@ -4135,6 +5063,20 @@ def helper(shape): helper((2, 8, 4, 5)) + def test_signbit(self): + def helper(shape, dtype): + cpu_x = torch.randn(shape, device='cpu').to(dtype) + x = cpu_x.clone().to('mps') + + signbit_result = torch.signbit(x) + signbit_result_cpu = torch.signbit(cpu_x) + + self.assertEqual(signbit_result, signbit_result_cpu) + + helper((2, 8, 4, 5), torch.int) + helper((2, 8, 4, 5), torch.float) + helper((2, 8, 4, 5), torch.int64) + # Test neg def test_neg(self): def helper(shape): @@ -4224,33 +5166,35 @@ def helper(shape, dim, index, idx_dtype=torch.int32): helper((2, 3, 3), -1, [1, 2]) def test_embedding_dense_backward(self): - def helper(n, d, m): + def helper(n, d, m, idx): embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps') + emedding_weight = embeddingMPS.weight.detach().cpu() W_MPS = torch.randn((m, d), requires_grad=True, device='mps') - idx_MPS = torch.tensor([0, 1, 2]).to('mps') + idx_MPS = torch.tensor(idx, device='mps') a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable a_MPS.retain_grad() b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t() # modifies weight in-place b_MPS.retain_grad() - out_MPS = (a_MPS.unsqueeze(0) + b_MPS.unsqueeze(1)) + out_MPS = (a_MPS.unsqueeze(0) + b_MPS) loss_MPS = out_MPS.sigmoid().prod() loss_MPS.backward() - embeddingCPU = nn.Embedding(n, d, max_norm=True, scale_grad_by_freq=True) + embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=emedding_weight) W_CPU = W_MPS.to('cpu') - idx_CPU = torch.tensor([0, 1, 2]) + idx_CPU = torch.tensor(idx) a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable a_CPU.retain_grad() b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t() # modifies weight in-place b_CPU.retain_grad() - out_CPU = (a_CPU.unsqueeze(0) + b_CPU.unsqueeze(1)) + out_CPU = (a_CPU.unsqueeze(0) + b_CPU) loss_CPU = out_CPU.sigmoid().prod() loss_CPU.backward() self.assertEqual(b_CPU.grad, b_MPS.grad) self.assertEqual(a_CPU.grad, a_MPS.grad) - helper(3, 5, 7) + helper(3, 5, 7, [0, 1, 2]) + helper(3, 5, 7, 2) # test scalar index # Test pytorch gather def test_gather(self): @@ -4318,7 +5262,7 @@ def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True) # Indices should be taken from range of axis along which gathering is done idx_np = None - if(do_add): + if (do_add): idx_np = np.random.randint(0, shape[dim], idx_shape) else: idx_np = np.array([[0, 1, 2], @@ -4333,7 +5277,7 @@ def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True) scatter_result = None scatter_result_cpu = None - if(do_add): + if (do_add): scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src) scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src) else: @@ -4343,14 +5287,14 @@ def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True) cpu_grad = None grad = None - if(idx_shape == src_shape): + if (idx_shape == src_shape): cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float) grad = cpu_grad.to('mps') scatter_result.backward(gradient=grad) scatter_result_cpu.backward(gradient=cpu_grad) self.assertEqual(scatter_result, scatter_result_cpu) - if(idx_shape == src_shape): + if (idx_shape == src_shape): self.assertEqual(cpu_x.grad, x.grad) self.assertEqual(cpu_src.grad, src.grad) @@ -4392,7 +5336,7 @@ def helper(idx_dtype=torch.int64, do_add=True): scatter_result = None scatter_result_cpu = None - if(do_add): + if (do_add): scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src) scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src) else: @@ -4435,22 +5379,22 @@ def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str=" self.assertEqual(scatter_result, scatter_result_cpu) # for reduce in ["sum", "prod", "amax", "amin"]: - for reduce in ["add", "multiply"]: - helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce) - helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce) - helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce) - helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce) - helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce) - helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce) - - helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce) - helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce) - helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce) - helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce) - - helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce) - helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce) - helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce) + for reduce_type in ["add", "multiply"]: + helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type) + helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type) + helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type) + helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type) + helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type) + helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type) + + helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type) + helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type) + helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type) + helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type) + + helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type) + helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type) + helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type) def test_is_nonzero(self): self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps'))) @@ -4484,6 +5428,21 @@ def helper(shape, diag=0): helper((2, 8, 4, 5), diag=-2) helper((2, 8, 4, 5), diag=-3) + # Test inverse + def test_inverse(self): + def helper(n): + cpu_input = torch.randn(n, n, device='cpu') + mps_input = cpu_input.to('mps') + + cpu_result = torch.linalg.inv(cpu_input) + mps_result = torch.linalg.inv(mps_input) + self.assertEqual(cpu_result, mps_result) + + helper(2) + helper(6) + helper(3) + helper(8) + # Test tril def test_tril(self): def helper(shape, diag=0): @@ -4516,7 +5475,7 @@ def helper(n, m, dtype): cpu_result = None result = None - if(n == m): + if (n == m): cpu_result = torch.eye(n, dtype=dtype, device='cpu') result = torch.eye(n, dtype=dtype, device='mps') else: @@ -4575,11 +5534,19 @@ def test_arange(self): self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps')) self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps')) + def test_arange_empty(self): + out_mps = torch.tensor([], device="mps") + out_cpu = torch.tensor([], device="cpu") + + y_mps = torch.arange(0, 0, 1, out=out_mps) + y_cpu = torch.arange(0, 0, 1, out=out_cpu) + self.assertEqual(y_mps, y_cpu) + # Test softmax def test_softmax(self): def helper(shape, dim, channels_last=False): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) - if(channels_last): + if (channels_last): cpu_x = cpu_x.to(memory_format=torch.channels_last) cpu_x.retain_grad() x = cpu_x.detach().clone().to('mps').requires_grad_() @@ -4591,7 +5558,7 @@ def helper(shape, dim, channels_last=False): cpu_grad = None grad = None - if(not channels_last): + if (not channels_last): cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float) grad = cpu_grad.to('mps') @@ -4599,7 +5566,7 @@ def helper(shape, dim, channels_last=False): softmax_result_cpu.backward(gradient=cpu_grad) self.assertEqual(softmax_result, softmax_result_cpu) - if(not channels_last): + if (not channels_last): self.assertEqual(x.grad, cpu_x.grad) def helper2(dim): @@ -4622,7 +5589,7 @@ def helper2(dim): for channels_last in [False]: for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]: - if(len(shape) != 4 and channels_last): + if (len(shape) != 4 and channels_last): continue for dim in [0, 1, 2, 3, -1, -2, -3]: helper(shape, dim, channels_last) @@ -4645,6 +5612,13 @@ def helper(shape, alpha): helper((2, 8, 3, 5), 0.1) helper((2, 8, 3, 5), 0.2) + def test_nan_to_num(self): + inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) + inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() + outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0) + outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0) + self.assertEqual(outputMPS, outputCPU) + # Test where def test_where(self): def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float): @@ -4744,6 +5718,67 @@ def test_bernoulli(self): mps_out = torch.bernoulli(all_ones) self.assertEqual(mps_out, all_ones) + def test_mps_generator(self): + # explicit manual seeding by creating an MPS Generator + g_mps = torch.Generator(device='mps') + g_mps.manual_seed(999) + mps_x = torch.randn(5, device='mps', generator=g_mps) + g_mps.manual_seed(999) + mps_y = torch.randn(5, device='mps', generator=g_mps) + # seed values were the same, so the random tensor contents should match + self.assertEqual(mps_x, mps_y) + # save generator's state to restore it later + g_state = g_mps.get_state() + + # generate random numbers without seeding + mps_x = torch.randn(5, device='mps', generator=g_mps) + # in this case, the random results must differ from the last generated random results + self.assertNotEqual(mps_x, mps_y) + + # restore the previously saved state, and the results should match again + g_mps.set_state(g_state) + mps_x = torch.randn(5, device='mps', generator=g_mps) + self.assertEqual(mps_x, mps_y) + + def test_default_mps_generator(self): + # manual seeding on the "default" MPS generator using + # the global torch.manual_seed() + torch.manual_seed(230) + mps_x = torch.randn(5, device='mps') + # manual seeding using torch.mps.manual_seed() + # which should set the "default" MPS generator + # like the global torch.manual_seed() + torch.mps.manual_seed(230) + mps_y = torch.randn(5, device='mps') + # seed values were the same, so the random tensor contents should match + self.assertEqual(mps_x, mps_y) + + # save the default generator's state to restore it later + g_state = torch.mps.get_rng_state() + + # generate random numbers without seeding + mps_x = torch.randn(5, device='mps') + # in this case, the random results must differ from the last generated random results + self.assertNotEqual(mps_x, mps_y) + + # restore the previously saved state, and the results should match again + torch.mps.set_rng_state(g_state) + mps_x = torch.randn(5, device='mps') + self.assertEqual(mps_x, mps_y) + + def test_device_synchronize(self): + # just running some ops each followed by a synchronize to wait for + # MPS stream to finish running each of them + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='mps', dtype=torch.float) + + x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True) + torch.mps.synchronize() + x = net1(x) + torch.mps.synchronize() + x.backward(torch.randn_like(x)) + torch.mps.synchronize() + # Test random_.to and random_.from def test_random(self): def helper(shape, low, high, dtype=torch.int32): @@ -4868,6 +5903,7 @@ def helper(shape, op): helper((2, 8, 4, 5), torch.exp) helper((2, 8, 3, 5), torch.exp2) + helper((2, 8, 3, 5), torch.expm1) helper((2, 8, 3, 5), torch.log) helper((2, 8, 3, 5), torch.cos) @@ -4895,7 +5931,7 @@ def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True): prob_tensor = cpu_prob_tensor.detach().clone().to('mps') mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement) - if(not replacement): + if (not replacement): print(mps_out.to('cpu')) else: # Compare "real" with theoretical values @@ -4979,10 +6015,14 @@ def test_conv_expand(self): # The test should not crash def test_permute(self): - X = torch.randn(5, 5).to('mps') - torch.log(X) - X = X.permute(1, 0) - torch.log(X) + M_cpu = torch.randn(5, 5) + M_mps = M_cpu.to('mps') + + output_cpu = M_cpu.permute(1, 0) + output_mps = M_mps.permute(1, 0) + + self.assertEqual(output_cpu, output_mps) + self.assertEqual(output_cpu.size(), output_mps.size()) # Printing of non_contiguous should not crash def test_print_non_contiguous(self): @@ -5426,14 +6466,13 @@ def test_T_view(self, device="mps"): v[0, 1] = 0 self.assertEqual(t[1, 0], v[0, 1]) - # requires aten::unfold - # def test_unfold_view(self, device="mps"): - # t = torch.ones(10, device=device) - # v = t.unfold(0, 3, 2) - # self.assertTrue(self.is_view_of(t, v)) + def test_unfold_view(self, device="mps"): + t = torch.ones(10, device=device) + v = t.unfold(0, 3, 2) + self.assertTrue(self.is_view_of(t, v)) - # v[1, 0] = 0 - # self.assertEqual(t[2], v[1, 0]) + v[1, 0] = 0 + self.assertEqual(t[2], v[1, 0]) def test_squeeze_view(self, device="mps"): t = torch.ones(5, 1, 5, device=device) @@ -6040,7 +7079,6 @@ def test_view(self, device="mps"): self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) - # RuntimeError: Invalid device for storage: mps def test_contiguous(self, device="mps"): x = torch.randn(1, 16, 5, 5, device=device) self.assertTrue(x.is_contiguous()) @@ -6135,17 +7173,33 @@ def test_conv_transpose_1d_nn_functional(self): self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04) def test_conv_backward_1d_channels_last(self): - # https://github.com/pytorch/pytorch/issues/84511 - conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3) - conv_mps = copy.deepcopy(conv_cpu).to(device='mps') + def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1): + # https://github.com/pytorch/pytorch/issues/84511 + conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups) + conv_mps = torch.nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps") + conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True) + conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True) + + + data = torch.rand(shape, dtype=torch.float32) + x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True) + x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True) + res_cpu = conv_cpu(x_cpu) + res_mps = conv_mps(x_mps) + self.assertEqual(res_cpu, res_mps) + res_cpu = res_cpu.sum().backward() + res_mps = res_mps.sum().backward() - data = torch.rand(1, 176, 1, dtype=torch.float32) - x_cpu = data.permute(0, 2, 1).contiguous() - x_mps = data.permute(0, 2, 1).contiguous().to("mps") - res_cpu = conv_cpu(x_cpu).sum().backward() - res_mps = conv_mps(x_mps).sum().backward() + self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04) + self.assertEqual(x_cpu.grad, x_mps.grad) - self.assertEqual(res_cpu, res_mps) + helper(shape=(1, 176, 1)) + helper(shape=(2, 12, 1)) + helper(shape=(3, 176, 1)) + helper(shape=(4, 376, 1)) + helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1) + helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3) def test_conv1d_contiguous(self): model_cpu = torch.nn.Conv1d(1, 128, 3) @@ -6181,10 +7235,358 @@ def test_conv2d_single_stride(self): x_gpu = conv_gpu(y_gpu) self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05) + def test_grid_sample(self): + def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad): + def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners): + for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]: + # grid_dim_contig_order specifies the dimension order that can + # make grid to be contiguous. + # i.e., grid.permute(grid_dim_contig_order) is contiguous. + # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be + # initialized with contiguous tensor of shape [N, 2, H, W] + # and permuted to [N, H, W, 2] afterwards. + grid_shape = [N, H, W, 2] + grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order] + grid_fwd_permute = [None, None, None, None] + for i, d in enumerate(grid_dim_contig_order): + grid_fwd_permute[d] = i + + def get_grid(device='cpu', data=None): + if data is not None: + assert list(data.shape) == grid_shape + data = data.permute(grid_dim_contig_order).to(device) + else: + data = torch.randn(grid_init_shape, device=device) + grid = data.permute(grid_fwd_permute) + assert grid.permute(grid_dim_contig_order).is_contiguous() + return grid + + input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad) + grid_cpu = get_grid().requires_grad_() + out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + self.assertTrue(out_cpu.size() == torch.Size([N, C, H, W])) + + gradients = torch.randn_like(out_cpu) + out_cpu.backward(gradients) + + + # Compare against unvectorized CPU fallback + + # NOTE [ grid_sample CPU fallback ] + # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for + # 32-bit floats. So we also have a fallback that is used only for float tensors + # requiring 64-bit indexing. That requires too much memory to run on CI, so we + # also export the fallback and test it here to ensure feature parity with + # the vectorized version. + input_fallback = input_cpu.float().detach_().requires_grad_() + grid_fallback = grid_cpu.float().detach_().requires_grad_() + out_fallback = torch._grid_sampler_2d_cpu_fallback( + input_fallback, grid_fallback, + F.GRID_SAMPLE_INTERPOLATION_MODES[mode], + F.GRID_SAMPLE_PADDING_MODES[padding_mode], + align_corners) + self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5) + + out_fallback.backward(gradients.float()) + if input_requires_grad: + self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5) + self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5) + + input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad) + grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_() + out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners) + self.assertEqual(out_cpu, out_mps) + out_mps.backward(gradients.to("mps")) + if input_requires_grad: + self.assertEqual(input_cpu.grad, input_mps.grad) + self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0) + + # check that zero-dimensional input strides don't error out + base_input = torch.randn(N, C, 1, IW) + input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad) + out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + + input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad) + out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners) + self.assertEqual(out_cpu, out_mps) + + # test same size output + test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners) + + # test larger output + N = random.randint(2, 8) + C = random.randint(2, 8) + IH = random.randint(2, 8) + IW = random.randint(2, 8) + H = random.randint(IH + 1, 12) + W = random.randint(IW + 1, 12) + test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) + + # test smaller output + N = random.randint(2, 8) + C = random.randint(2, 8) + IH = random.randint(2, 8) + IW = random.randint(2, 8) + H = random.randint(2, IH) + W = random.randint(2, IW) + test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) + + # test 1x1 inpput + N = random.randint(2, 8) + C = random.randint(2, 8) + IH = 1 + IW = 1 + H = random.randint(2, 5) + W = random.randint(2, 5) + test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) + + # testing empty grid + N = random.randint(2, 8) + C = random.randint(2, 8) + IH = random.randint(2, 8) + IW = random.randint(2, 8) + W = random.randint(3, IW + 2) + test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners) + + # testing empty channel + N = random.randint(2, 8) + IH = random.randint(2, 8) + IW = random.randint(2, 8) + H = random.randint(3, IH + 2) + W = random.randint(3, IW + 2) + test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners) + + # testing empty batch + C = random.randint(2, 8) + IH = random.randint(2, 8) + IW = random.randint(2, 8) + H = random.randint(3, IH + 2) + W = random.randint(3, IW + 2) + test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners) + + for mode in ('bilinear', 'nearest'): + for padding_mode in ('zeros', 'reflection'): + for align_corners in (True, False): + # test known input + input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5) + grid = torch.tensor( + [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], + [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2) + if mode == 'bilinear': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000], + [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250], + [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000], + [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500], + [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000], + [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500], + [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + elif mode == 'nearest': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[0., 8., 5., 7., 9.], + [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0., 8., 5., 7., 0.], + [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1., 8., 5., 7., 9.], + [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[1., 8., 5., 7., 9.], + [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[1., 8., 5., 7., 9.], + [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[1., 8., 5., 7., 9.], + [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + elif mode == 'bicubic': + if padding_mode == 'zeros': + if align_corners: + groundtruth = torch.tensor( + [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000], + [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264], + [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'border': + if align_corners: + groundtruth = torch.tensor( + [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000], + [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781], + [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5) + elif padding_mode == 'reflection': + if align_corners: + groundtruth = torch.tensor( + [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000], + [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5) + else: + groundtruth = torch.tensor( + [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531], + [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5) + else: + raise AssertionError("missing groundtruth test for padding mode '{}'".format(padding_mode)) + + else: + raise AssertionError("missing groundtruth test for interpolation mode '{}'".format(mode)) + output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode, + align_corners=align_corners) + self.assertEqual(output, groundtruth, atol=1e-5, rtol=0, + msg="groundtruth comparison failed for mode={}, " + "padding_mode={}".format(mode, padding_mode)) + class TestAdvancedIndexing(TestCase): supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8] supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8] + def test_nonzero_no_warning(self): + device = "mps" + t = torch.randn((2, 2), device=device) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + torch.nonzero(t) + t.nonzero() + self.assertEqual(len(w), 0) + + def test_nonzero(self): + def helper(dtype): + device = "mps" + shapes = [ + torch.Size((12,)), + torch.Size((12, 1)), + torch.Size((1, 12)), + torch.Size((6, 2)), + torch.Size((3, 2, 2)), + torch.Size((5, 5, 5)), + ] + + def gen_nontrivial_input(shape, dtype, device): + if dtype != torch.bfloat16: + return torch.randint(2, shape, device=device, dtype=dtype) + else: + # windows does not work for bfloat16 randing + return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype) + + for shape in shapes: + tensor = gen_nontrivial_input(shape, dtype, device) + dst1 = torch.nonzero(tensor, as_tuple=False) + dst2 = tensor.nonzero(as_tuple=False) + dst3 = torch.empty([], dtype=torch.long, device=device) + dst3 = dst3.resize_(0) + torch.nonzero(tensor, out=dst3) + np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy() + np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() + self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) + self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) + self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0) + tup1 = torch.nonzero(tensor, as_tuple=True) + tup2 = tensor.nonzero(as_tuple=True) + tup1 = torch.stack(tup1).t().cpu() + tup2 = torch.stack(tup2).t().cpu() + self.assertEqual(tup1, np_result, atol=0, rtol=0) + self.assertEqual(tup2, np_result, atol=0, rtol=0) + [helper(dtype) for dtype in self.supported_dtypes] + + def test_nonzero_astuple_out(self): + device = "mps" + t = torch.randn((3, 3, 3), device=device) + out = torch.empty([], dtype=torch.long, device=device) + out = out.resize_(0) + + with self.assertRaises(RuntimeError): + torch.nonzero(t, as_tuple=True, out=out) + + self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) + + # Verifies that JIT script cannot handle the as_tuple kwarg + # See Issue https://github.com/pytorch/pytorch/issues/45499. + def _foo(t): + tuple_result = torch.nonzero(t, as_tuple=True) + nontuple_result = torch.nonzero(t, as_tuple=False) + out = torch.empty_like(nontuple_result) + torch.nonzero(t, as_tuple=False, out=out) + return tuple_result, nontuple_result, out + + with self.assertRaises(RuntimeError): + scripted_foo = torch.jit.script(_foo) + + # Verifies that JIT tracing works fine + traced_foo = torch.jit.trace(_foo, t) + traced_tuple, traced_nontuple, traced_out = traced_foo(t) + expected_tuple = torch.nonzero(t, as_tuple=True) + expected_nontuple = torch.nonzero(t) + + self.assertEqual(traced_tuple, expected_tuple) + self.assertEqual(traced_nontuple, expected_nontuple) + self.assertEqual(traced_out, expected_nontuple) + + def test_nonzero_discontiguous(self): + device = "mps" + shape = (4, 4) + tensor = torch.randint(2, shape, device=device) + tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) + dst1 = tensor.nonzero(as_tuple=False) + dst2 = tensor_nc.nonzero(as_tuple=False) + self.assertEqual(dst1, dst2, atol=0, rtol=0) + dst3 = torch.empty_like(dst1) + data_ptr = dst3.data_ptr() + # expect dst3 storage to be reused + torch.nonzero(tensor, out=dst3) + self.assertEqual(data_ptr, dst3.data_ptr()) + self.assertEqual(dst1, dst3, atol=0, rtol=0) + # discontiguous out + dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] + data_ptr = dst4.data_ptr() + strides = dst4.stride() + torch.nonzero(tensor, out=dst4) + self.assertEqual(data_ptr, dst4.data_ptr()) + self.assertEqual(dst1, dst4, atol=0, rtol=0) + self.assertEqual(strides, dst4.stride()) + + def test_nonzero_non_diff(self): + device = "mps" + x = torch.randn(10, requires_grad=True) + nz = x.nonzero() + self.assertFalse(nz.requires_grad) + def test_masked_select(self): x = torch.randn(3, 4) x_mps = x.to("mps") @@ -6537,8 +7939,7 @@ def test_index_put_accumulate_duplicate_indices(self, device="mps"): # lots of duplicates interleaved with each other delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1) - # cumsum not supported on 'mps', fallback on 'cpu' - indices = delta.cpu().cumsum(0).long().to("mps") + indices = delta.cumsum(0).long().to("mps") # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device) @@ -6946,14 +8347,14 @@ def test_no_warning_on_import(self): # On Windows, opening the subprocess with the default CWD makes `import torch` # fail, so just set CWD to this script's directory cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") - self.assertEquals(out, "") + self.assertEqual(out, "") def _get_not_implemented_op(self): - # This can be changed once we actually implement `torch.bincount` + # This can be changed once we actually implement `torch.histc` # Should return fn, args, kwargs, string_version - return (torch.bincount, - torch.tensor([4], device='mps'), {}, - "torch.bincount(torch.tensor([4, 3, 6, 3, 4], device='mps'))") + return (torch.histc, + torch.tensor([100], device='mps'), {}, + "torch.histc(torch.tensor([4], device='mps', dtype=torch.float))") def test_error_on_not_implemented(self): fn, args, kwargs, _ = self._get_not_implemented_op() @@ -7087,149 +8488,402 @@ def test_serialization_map_location(self): class TestConsistency(TestCase): + # TODO: This is only used while some ops are being added. # This list should contain all ops and dtypes eventually # This can be generated automatically in the `new_mps_allowlist.txt` file # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU` # You most likely do NOT want to modify this manually ALLOWLIST_OP = { + 'H': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'T': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__getitem__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__radd__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__rand__': ['b8', 'i16', 'i32', 'i64', 'u8'], - '__rdiv__': ['f16', 'f32', 'i16', 'i32', 'u8'], - '__rmatmul__': ['f32'], + '__rdiv__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + '__rmatmul__': ['f32', 'i16', 'i32', 'i64', 'u8'], + '__rmod__': ['f16', 'f32'], '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'], - '__rpow__': ['f16'], + '__rpow__': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + '__rsub__': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], '__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'masked.argmax': ['i16', 'i64', 'u8'], - 'masked.argmin': ['i16', 'i64', 'u8'], - 'masked.log_softmax': ['f32'], - 'masked.logaddexp': ['f32'], - 'masked.norm': ['f16', 'f32'], - 'masked.normalize': ['f16', 'f32'], - 'masked.softmax': ['f32'], - 'masked.softmin': ['f32'], - 'masked.std': ['f32'], - 'masked.var': ['f32'], - 'abs': ['f16', 'f32', 'i16', 'i32', 'u8'], - 'acos': ['f32', 'i16', 'i32', 'u8'], - 'acosh': ['f32', 'i16', 'i32', 'u8'], - 'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'addbmm': ['f32'], + '_native_batch_norm_legit': ['f32'], + '_softmax_backward_data': ['f32'], + 'abs': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'acos': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'acosh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'addbmm': ['f32', 'i16', 'i32', 'i64', 'u8'], 'addcdiv': ['f32'], 'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'addmm': ['f32'], - 'addmv': ['f32'], - 'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'addmm': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'addmv': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'addr': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'allclose': ['f16', 'f32'], + 'amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'aminmax': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'angle': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'arange': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'amax': ['f32'], - 'amix': ['f32'], - 'logsumexp': ['f32'], - 'mean': ['f32'], - 'sum': ['f32'], - 'asin': ['f32', 'i16', 'i32', 'u8'], - 'asinh': ['f32', 'i16', 'i32', 'u8'], - 'atan': ['f32', 'i16', 'i32', 'u8'], - 'atan2': ['f32'], - 'atanh': ['f32', 'i16', 'i32', 'u8'], + 'argsort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'argwhere': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'as_strided': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'as_strided_scatter': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'asin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'asinh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'atan': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'atan2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'atanh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'atleast_1d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'atleast_2d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'atleast_3d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'baddbmm': ['f32'], + 'baddbmm': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'bernoulli': ['f32'], + 'bfloat16': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'bincount': ['i16', 'i32', 'i64', 'u8'], 'bitwise_and': ['b8', 'i16', 'i32', 'i64', 'u8'], 'bitwise_left_shift': ['i16', 'i32', 'i64', 'u8'], 'bitwise_not': ['b8', 'i16', 'i32', 'i64', 'u8'], 'bitwise_or': ['b8', 'i16', 'i32', 'i64', 'u8'], 'bitwise_right_shift': ['i16', 'i32', 'i64', 'u8'], 'bitwise_xor': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'bmm': ['f32'], + 'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'bmm': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'bool': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'broadcast_shapes': ['f32'], + 'broadcast_tensors': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'broadcast_to': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'bucketize': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'byte': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cartesian_prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'ceil': ['f32', 'int32', 'int64', 'f16'], - 'char': ['b8', 'u8'], + 'cdist': ['f32'], + 'cdouble': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'ceil': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'cfloat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'chalf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cholesky': ['f32'], + 'cholesky_inverse': ['f32'], + 'cholesky_solve': ['f32'], 'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'clone': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'column_stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'combinations': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'complex': ['f16', 'f32'], 'conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'conj_physical': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'constant_pad_nd': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'contiguous': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'corrcoef': ['f32'], - 'cos': ['f32', 'i16', 'i32', 'u8', 'i64'], - 'cosh': ['f32', 'i16', 'i32', 'u8', 'i64'], - 'cov': ['f32'], + 'copysign': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'corrcoef': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'cos': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cosh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'count_nonzero': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cov': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'cross': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'cummax': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cummin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'cumsum': ['f32', 'i16', 'i32', 'i64', 'u8'], 'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'diag': ['f32', 'i32'], - 'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'diagflat': ['f32', 'i32'], - 'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'diff': ['f16', 'f32', 'i16', 'i32', 'i64'], - 'dist': ['f32'], + 'diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'diagflat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'diagonal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'diagonal_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'diagonal_scatter': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'digamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'dist': ['f16', 'f32'], + 'div': ['f16', 'f32', 'u8', 'b8', 'i16', 'i32', 'i64'], 'dot': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'einsum': ['f32'], + 'double': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'dsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'dstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'einsum': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'empty': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'empty_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'eq': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'erf': ['f32', 'i16', 'i32', 'u8'], - 'exp': ['f32', 'i16', 'i32', 'u8'], - 'exp2': ['f16', 'f32', 'i16', 'i32', 'u8'], + 'erf': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'erfc': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'erfinv': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'exp': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'exp2': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'expand': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'expand_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'expm1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'eye': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.fft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.fft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.fftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.fftshift': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.hfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.hfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.hfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ifft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ifft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ifftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ifftshift': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ihfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ihfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.ihfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.irfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.irfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.irfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.rfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.rfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fft.rfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'flatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'float': ['f32'], - 'floor': ['f32', 'f16', 'i16', 'i32', 'i64'], + 'flip': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fliplr': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'flipud': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'float_power': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'floor': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'floor_divide': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fmax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fmin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'fmod': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'frac': ['f16', 'f32'], - 'gradient': ['f16', 'f32', 'i16'], - 'half': ['f16'], + 'frexp': ['f16', 'f32'], + 'full': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'full_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'gcd': ['i16', 'i32', 'i64', 'u8'], + 'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'geqrf': ['f32'], + 'gradient': ['f16', 'f32', 'i16', 'i32', 'i64'], + 'grid_sampler_2d': ['f32'], + 'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'heaviside': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'histc': ['f32'], + 'histogram': ['f32'], + 'histogramdd': ['f32'], + 'hsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'index_select': ['f32', 'i16', 'i32', 'i64'], - 'int': ['i32'], + 'hypot': ['f32'], + 'i0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'igamma': ['f16', 'f32'], + 'igammac': ['f16', 'f32'], + 'index_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'index_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'index_put': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'index_reduce': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'inner': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'isin': ['f32', 'i16', 'i32', 'i64', 'u8'], 'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'isnan': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'isneginf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'isposinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'linalg.matrix_norm': ['f16'], + 'kthvalue': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'lcm': ['i16', 'i32', 'i64', 'u8'], + 'ldexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'le': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'lerp': ['f32'], + 'lgamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'linalg.cholesky': ['f32'], + 'linalg.cholesky_ex': ['f32'], + 'linalg.cond': ['f32'], + 'linalg.cross': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'linalg.det': ['f32'], + 'linalg.eig': ['f32'], + 'linalg.eigh': ['f32'], + 'linalg.eigvals': ['f32'], + 'linalg.eigvalsh': ['f32'], + 'linalg.householder_product': ['f32'], + 'linalg.inv': ['f32'], + 'linalg.inv_ex': ['f32'], + 'linalg.ldl_factor': ['f32'], + 'linalg.ldl_factor_ex': ['f32'], + 'linalg.ldl_solve': ['f32'], + 'linalg.lstsq': ['f32'], + 'linalg.lu': ['f32'], + 'linalg.lu_factor': ['f32'], + 'linalg.lu_factor_ex': ['f32'], + 'linalg.lu_solve': ['f32'], + 'linalg.matrix_norm': ['f16', 'f32'], + 'linalg.matrix_power': ['f32'], + 'linalg.matrix_rank': ['f32'], + 'linalg.multi_dot': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'linalg.norm': ['f16', 'f32'], + 'linalg.pinv': ['f32'], + 'linalg.qr': ['f32'], + 'linalg.slogdet': ['f32'], + 'linalg.solve': ['f32'], + 'linalg.solve_ex': ['f32'], + 'linalg.solve_triangular': ['f32'], 'linalg.svd': ['f32'], + 'linalg.svdvals': ['f32'], + 'linalg.tensorinv': ['f32'], + 'linalg.tensorsolve': ['f32'], + 'linalg.vander': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'linalg.vecdot': ['f32'], 'linalg.vector_norm': ['f16', 'f32'], 'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'log': ['f32', 'i16', 'i32', 'u8'], - 'log10': ['f32', 'i16', 'i32', 'u8'], - 'log2': ['f32', 'i16', 'i32', 'u8'], - 'log_softmax': ['f32'], + 'log': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'log10': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'log1p': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'log2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'log_softmax': ['f32', 'b8', 'f16', 'i16', 'i32', 'i64', 'u8'], 'logaddexp': ['f32'], 'logaddexp2': ['f32'], + 'logcumsumexp': ['f32'], + 'logdet': ['f32'], + 'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'logit': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'masked_fill': ['f16', 'i16', 'i32', 'i64'], + 'logsumexp': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'long': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'lt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'lu': ['f32'], + 'lu_solve': ['f32'], + 'lu_unpack': ['f32'], + 'mH': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'mT': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.amax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.amin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.cumsum': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.log_softmax': ['f32'], + 'masked.logaddexp': ['f32'], + 'masked.logsumexp': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.median': ['f32'], + 'masked.norm': ['f16', 'f32'], + 'masked.normalize': ['f16', 'f32'], + 'masked.prod': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.softmax': ['f32'], + 'masked.softmin': ['f32'], + 'masked.std': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked.var': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'masked_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'matmul': ['f32'], - 'mm': ['f32'], - 'mv': ['f32'], - 'neg': ['f16', 'f32', 'i16', 'i32', 'i64'], + 'matmul': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'matrix_exp': ['f32'], + 'max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'max_pool2d_with_indices_backward': ['f32'], + 'maximum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'mean': ['f16', 'f32'], + 'median': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'meshgrid': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'minimum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'mm': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'mode': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'movedim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'msort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'mul': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'multinomial': ['f32'], + 'mv': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'mvlgamma': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'nan_to_num': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'nanmean': ['f16', 'f32'], + 'nanmedian': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'nanquantile': ['f32'], + 'nansum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'narrow': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'narrow_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'native_batch_norm': ['f32'], + 'native_dropout_backward': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'native_layer_norm': ['f32'], + 'ne': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'neg': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'new_empty': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'new_empty_strided': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'new_full': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'new_ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'new_zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'nextafter': ['f32'], + 'nn.functional._scaled_dot_product_attention': ['f32'], + 'nn.functional.adaptive_avg_pool1d': ['f32'], + 'nn.functional.adaptive_avg_pool2d': ['f32'], + 'nn.functional.adaptive_avg_pool3d': ['f16', 'f32'], 'nn.functional.adaptive_max_pool1d': ['f32'], 'nn.functional.adaptive_max_pool2d': ['f32'], + 'nn.functional.adaptive_max_pool3d': ['f32'], + 'nn.functional.alpha_dropout': ['f32'], + 'nn.functional.avg_pool1d': ['f32', 'i64'], + 'nn.functional.avg_pool2d': ['f32', 'i64'], + 'nn.functional.avg_pool3d': ['f32', 'i64'], + 'nn.functional.batch_norm': ['f32'], + 'nn.functional.bilinear': ['f32', 'i16', 'i32', 'i64', 'u8'], 'nn.functional.binary_cross_entropy': ['f32'], 'nn.functional.binary_cross_entropy_with_logits': ['f32'], 'nn.functional.celu': ['f32'], 'nn.functional.conv1d': ['f32'], 'nn.functional.conv2d': ['f32'], 'nn.functional.conv_transpose1d': ['f32'], + 'nn.functional.conv_transpose2d': ['f32'], 'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', - 'i64'], + 'i64', + 'u8'], + 'nn.functional.cosine_similarity': ['f32'], + 'nn.functional.cross_entropy': ['f32'], + 'nn.functional.ctc_loss': ['f32'], + 'nn.functional.dropout': ['f32'], + 'nn.functional.dropout2d': ['f32'], + 'nn.functional.dropout3d': ['f32'], 'nn.functional.elu': ['f32'], + 'nn.functional.embedding': ['f16', 'f32'], + 'nn.functional.embedding_bag': ['f16', 'f32'], 'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', @@ -7237,51 +8891,141 @@ class TestConsistency(TestCase): 'i32', 'i64', 'u8'], + 'nn.functional.fractional_max_pool2d': ['f32'], + 'nn.functional.fractional_max_pool3d': ['f32'], 'nn.functional.gaussian_nll_loss': ['f32'], + 'nn.functional.gelu': ['f32'], 'nn.functional.glu': ['f32'], + 'nn.functional.grid_sample': ['f32'], 'nn.functional.group_norm': ['f32'], + 'nn.functional.hardshrink': ['f32'], + 'nn.functional.hardsigmoid': ['f32'], + 'nn.functional.hardswish': ['f32'], 'nn.functional.hardtanh': ['f32', 'i16', 'i32', 'i64'], 'nn.functional.hinge_embedding_loss': ['f32'], - 'nn.functional.huber_loss': ['f32'], + 'nn.functional.huber_loss': ['f16', 'f32'], 'nn.functional.instance_norm': ['f32'], + 'nn.functional.interpolate': ['f32', 'u8'], 'nn.functional.kl_div': ['f32'], 'nn.functional.l1_loss': ['f16', 'f32'], + 'nn.functional.layer_norm': ['f32'], 'nn.functional.leaky_relu': ['f32'], - 'nn.functional.linear': ['f32'], - 'nn.functional.local_response_norm': ['f32'], - 'nn.functional.margin_ranking_loss': ['f32', 'i16', 'i32'], + 'nn.functional.linear': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'nn.functional.local_response_norm': ['f32', 'i64'], + 'nn.functional.logsigmoid': ['f32'], + 'nn.functional.margin_ranking_loss': ['f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'nn.functional.max_pool1d': ['f32'], + 'nn.functional.max_pool2d': ['f32'], + 'nn.functional.max_pool3d': ['f32'], + 'nn.functional.max_unpool1d': ['f32'], + 'nn.functional.max_unpool2d': ['f32'], + 'nn.functional.max_unpool3d': ['f32'], + 'nn.functional.mish': ['f32'], 'nn.functional.mse_loss': ['f16', 'f32'], - 'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], + 'nn.functional.multi_margin_loss': ['f32'], + 'nn.functional.multilabel_margin_loss': ['f32'], + 'nn.functional.multilabel_soft_margin_loss': ['f32'], + 'nn.functional.nll_loss': ['f32'], + 'nn.functional.normalize': ['f32'], + 'nn.functional.one_hot': ['i64'], + 'nn.functional.pad': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], 'nn.functional.pairwise_distance': ['f16', 'f32', 'i16', 'i32', - 'i64'], - 'nn.functional.poisson_nll_loss': ['f32', 'i16', 'i32', 'u8'], + 'i64', + 'u8'], + 'nn.functional.pdist': ['f32'], + 'nn.functional.pixel_shuffle': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'nn.functional.pixel_unshuffle': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'nn.functional.poisson_nll_loss': ['f32', + 'i16', + 'i32', + 'i64', + 'u8'], 'nn.functional.prelu': ['f32'], 'nn.functional.relu': ['f32', 'i16', 'i32', 'i64', 'u8'], 'nn.functional.relu6': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'nn.functional.rrelu': ['f32'], 'nn.functional.selu': ['f32'], 'nn.functional.silu': ['f32'], 'nn.functional.smooth_l1_loss': ['f16', 'f32'], 'nn.functional.soft_margin_loss': ['f32'], - 'nn.functional.softmin': ['f32'], - 'nn.functional.softsign': ['f16', 'f32', 'i16', 'u8'], - 'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'u8'], + 'nn.functional.softmin': ['f32', 'f16', 'i16', 'i32', 'i64', 'u8'], + 'nn.functional.softshrink': ['f32'], + 'nn.functional.softsign': ['f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'i64', 'u8'], 'nn.functional.threshold': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'nn.functional.triplet_margin_loss': ['f32', 'i16', 'i32', 'i64'], + 'nn.functional.triplet_margin_loss': ['f32', + 'i16', + 'i32', + 'i64', + 'u8'], 'nn.functional.triplet_margin_with_distance_loss': ['f32', 'i16', 'i32', - 'i64'], + 'i64', + 'u8'], + 'nn.functional.unfold': ['f16', 'f32'], 'nn.functional.upsample_bilinear': ['f32'], + 'nn.functional.upsample_nearest': ['f32', 'u8'], + 'nonzero': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'], 'norm': ['f32', 'f16'], + 'normal': ['f16', 'f32'], + 'ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'ones_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'ormqr': ['f32'], + 'outer': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'pca_lowrank': ['f32'], + 'permute': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'pinverse': ['f32'], + 'polar': ['f32'], + 'polygamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'pow': ['f16'], + 'pow': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'prod': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'put': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'qr': ['f32'], + 'quantile': ['f32'], 'rad2deg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'rand_like': ['f16', 'f32'], + 'randint': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'randint_like': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'randn': ['f16', 'f32'], + 'randn_like': ['f16', 'f32'], + 'ravel': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'reciprocal': ['f16', 'f32', 'i16', 'i32', 'u8'], - 'repeat': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'remainder': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'renorm': ['f16', 'f32'], + 'repeat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'repeat_interleave': ['b8', 'f16', 'f32', @@ -7289,72 +9033,215 @@ class TestConsistency(TestCase): 'i32', 'i64', 'u8'], - 'resize_': ['b8', 'i16', 'i32', 'i64', 'u8'], - 'resize_as_': ['b8', 'i16', 'i32', 'i64', 'u8'], + 'reshape': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'reshape_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'resize_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'resize_as_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'resolve_conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'resolve_neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'rot90': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'round': ['f32', 'f16', 'i16', 'i32', 'i64'], - 'rsqrt': ['f32', 'i16', 'i32', 'u8'], - 'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], + 'roll': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'rot90': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'round': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'rsqrt': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'rsub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'scalar_tensor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'scatter_reduce': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'searchsorted': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'segment_reduce': ['f16', 'f32'], + 'select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'short': ['i16'], - 'sigmoid': ['f32'], - 'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'], - 'sin': ['f32', 'i16', 'i32', 'u8'], - 'sinh': ['f32', 'i16', 'i32', 'u8'], - 'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], - 'softmax': ['f32'], + 'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sigmoid': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'signal.windows.bartlett': ['f16', 'f32'], + 'signal.windows.blackman': ['f16', 'f32'], + 'signal.windows.cosine': ['f16', 'f32'], + 'signal.windows.exponential': ['f16', 'f32'], + 'signal.windows.gaussian': ['f16', 'f32'], + 'signal.windows.general_cosine': ['f16', 'f32'], + 'signal.windows.general_hamming': ['f16', 'f32'], + 'signal.windows.hamming': ['f16', 'f32'], + 'signal.windows.hann': ['f16', 'f32'], + 'signal.windows.kaiser': ['f16', 'f32'], + 'signbit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sinc': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sinh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'slice': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'softmax': ['f32', 'b8', 'f16', 'i16', 'i32', 'i64', 'u8'], + 'sort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.airy_ai': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.bessel_j0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.bessel_j1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.bessel_y0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.bessel_y1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.chebyshev_polynomial_t': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.chebyshev_polynomial_u': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.entr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.erfcx': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.hermite_polynomial_h': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.hermite_polynomial_he': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.i0e': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.i1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.i1e': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.laguerre_polynomial_l': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.log_ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.modified_bessel_i0': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.modified_bessel_i1': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.modified_bessel_k0': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.modified_bessel_k1': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], 'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.ndtri': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.polygamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.scaled_modified_bessel_k0': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.scaled_modified_bessel_k1': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.spherical_bessel_j0': ['b8', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'special.xlog1py': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'special.zeta': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sqrt': ['f32', 'i16', 'i32', 'u8'], - 'square': ['f16', 'f32'], + 'split_with_sizes': ['b8', + 'f16', + 'f32', + 'i16', + 'i32', + 'i64', + 'u8'], + 'sqrt': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'square': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'sub': ['f32', 'i16', 'i32', 'i64'], + 'std': ['f16', 'f32'], + 'std_mean': ['f16', 'f32'], + 'stft': ['f32'], + 'sub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'svd': ['f32'], + 'svd_lowrank': ['f32'], + 'symeig': ['f32'], 't': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'tan': ['i16', 'i32', 'u8'], - 'tanh': ['f32', 'i16', 'i32', 'u8'], - 'tensordot': ['f32'], - 'tile': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'topk': ['f32'], - 'trapz': ['f16', 'f32', 'i16', 'i32', 'i64'], + 'take': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'take_along_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'tan': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'tanh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'tensor_split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'tensordot': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'tile': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'to_sparse': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'topk': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'trace': ['f32', 'i16', 'i32', 'i64', 'u8'], + 'transpose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'trapezoid': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'trapz': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'triangular_solve': ['f32'], 'tril': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'tril_indices': ['i32', 'i64'], 'triu': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'triu_indices': ['i32', 'i64'], - 'true_divide': ['b8', 'f16', 'f32', 'i16', 'u8'], - 'trunc': ['f32'], + 'true_divide': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'trunc': ['f32', 'i16', 'i32', 'i64', 'u8'], 'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'unfold': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'unfold_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'uniform': ['f16', 'f32'], + 'unique_consecutive': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'], 'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'var': ['f16', 'f32'], + 'var_mean': ['f16', 'f32'], + 'vdot': ['f32', 'i16', 'i32', 'i64', 'u8'], 'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'view_as_complex': ['f16', 'f32'], + 'view_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'where': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'xlogy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'clamp_min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], - 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']} - + 'zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'zeros_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'index_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], + 'nn.functional.softplus': ['f32'], + } ALLOWLIST_OP_GRAD = { + 'H': ['f16', 'f32'], + 'T': ['f16', 'f32'], + '__getitem__': ['f16', 'f32'], '__radd__': ['f16', 'f32'], '__rdiv__': ['f16', 'f32'], '__rmatmul__': ['f32'], + '__rmod__': ['f16', 'f32'], '__rmul__': ['f16', 'f32'], - 'masked.log_softmax': ['f32'], - 'masked.logaddexp': ['f32'], - 'masked.softmax': ['f32'], - 'masked.softmin': ['f32'], - 'masked.std': ['f32'], - 'masked.var': ['f32'], + '__rpow__': ['f32'], + '__rsub__': ['f16', 'f32'], + '_native_batch_norm_legit': ['f32'], + '_softmax_backward_data': ['f32'], 'abs': ['f16', 'f32'], 'acos': ['f32'], 'acosh': ['f32'], @@ -7366,290 +9253,1004 @@ class TestConsistency(TestCase): 'addmv': ['f32'], 'addr': ['f32'], 'all': ['f16', 'f32'], + 'amax': ['f16', 'f32'], + 'amin': ['f16', 'f32'], + 'angle': ['f16', 'f32'], 'any': ['f16', 'f32'], 'arange': ['f16', 'f32'], 'argmax': ['f16', 'f32'], 'argmin': ['f16', 'f32'], + 'argsort': ['f16', 'f32'], + 'argwhere': ['f16', 'f32'], + 'as_strided': ['f16', 'f32'], + 'as_strided_scatter': ['f16', 'f32'], 'asin': ['f32'], 'asinh': ['f32'], 'atan': ['f32'], 'atan2': ['f32'], + 'atanh': ['f32'], 'atleast_1d': ['f16', 'f32'], 'atleast_2d': ['f16', 'f32'], 'atleast_3d': ['f16', 'f32'], 'baddbmm': ['f32'], + 'bernoulli': ['f32'], + 'bfloat16': ['f16', 'f32'], 'block_diag': ['f16', 'f32'], 'bmm': ['f32'], + 'bool': ['f16', 'f32'], 'broadcast_shapes': ['f32'], + 'broadcast_tensors': ['f16', 'f32'], + 'broadcast_to': ['f16', 'f32'], + 'bucketize': ['f16', 'f32'], + 'byte': ['f16', 'f32'], + 'cartesian_prod': ['f16', 'f32'], 'cat': ['f16', 'f32'], + 'cdist': ['f32'], 'ceil': ['f32'], + 'char': ['f16', 'f32'], + 'cholesky': ['f32'], + 'cholesky_inverse': ['f32'], + 'cholesky_solve': ['f32'], 'chunk': ['f16', 'f32'], + 'clamp': ['f32'], + 'clamp_max': ['f16', 'f32'], + 'clamp_min': ['f16', 'f32'], 'clone': ['f16', 'f32'], 'column_stack': ['f16', 'f32'], + 'combinations': ['f16', 'f32'], 'conj': ['f16', 'f32'], 'conj_physical': ['f16', 'f32'], + 'constant_pad_nd': ['f16', 'f32'], 'contiguous': ['f16', 'f32'], + 'copysign': ['f16', 'f32'], 'corrcoef': ['f32'], 'cos': ['f32'], 'cosh': ['f32'], + 'count_nonzero': ['f16', 'f32'], + 'cov': ['f32'], + 'cross': ['f32'], + 'cummax': ['f32'], + 'cummin': ['f32'], + 'cumprod': ['f32'], + 'cumsum': ['f32'], + 'cumulative_trapezoid': ['f32'], 'deg2rad': ['f16', 'f32'], - 'diag': ['f32'], + 'diag': ['f16', 'f32'], 'diag_embed': ['f16', 'f32'], - 'diagflat': ['f32'], + 'diagflat': ['f16', 'f32'], + 'diagonal': ['f16', 'f32'], + 'diagonal_copy': ['f16', 'f32'], 'diagonal_scatter': ['f16', 'f32'], 'diff': ['f16', 'f32'], - 'dist': ['f32'], + 'digamma': ['f32'], + 'dist': ['f16', 'f32'], + 'div': ['f16', 'f32'], 'dot': ['f32'], + 'double': ['f16', 'f32'], + 'dsplit': ['f16', 'f32'], + 'dstack': ['f16', 'f32'], 'einsum': ['f32'], + 'empty_like': ['f16', 'f32'], + 'eq': ['f16', 'f32'], 'erf': ['f32'], + 'erfc': ['f32'], + 'erfinv': ['f32'], 'exp': ['f32'], 'exp2': ['f16', 'f32'], + 'expand': ['f16', 'f32'], + 'expand_as': ['f16', 'f32'], + 'expm1': ['f32'], + 'fft.fftshift': ['f16', 'f32'], + 'fft.hfft': ['f32'], + 'fft.hfft2': ['f32'], + 'fft.hfftn': ['f32'], + 'fft.ifftshift': ['f16', 'f32'], + 'fft.irfft': ['f32'], + 'fft.irfft2': ['f32'], + 'fft.irfftn': ['f32'], 'fill': ['f16', 'f32'], 'flatten': ['f16', 'f32'], 'flip': ['f16', 'f32'], 'fliplr': ['f16', 'f32'], 'flipud': ['f16', 'f32'], - 'float': ['f32'], + 'float': ['f16', 'f32'], + 'float_power': ['f16', 'f32'], 'floor': ['f32'], - 'gradient': ['f32'], - 'half': ['f16'], + 'fmax': ['f16', 'f32'], + 'fmin': ['f16', 'f32'], + 'fmod': ['f16', 'f32'], + 'frac': ['f16', 'f32'], + 'frexp': ['f16', 'f32'], + 'full': ['f16', 'f32'], + 'full_like': ['f16', 'f32'], + 'gather': ['f16', 'f32'], + 'ge': ['f16', 'f32'], + 'gradient': ['f16', 'f32'], + 'grid_sampler_2d': ['f32'], + 'gt': ['f16', 'f32'], + 'half': ['f16', 'f32'], + 'histc': ['f32'], + 'hsplit': ['f16', 'f32'], 'hstack': ['f16', 'f32'], - 'index_select': ['f32'], + 'hypot': ['f32'], + 'i0': ['f32'], + 'index_add': ['f16', 'f32'], + 'index_copy': ['f16', 'f32'], + 'index_fill': ['f16', 'f32'], + 'index_put': ['f16', 'f32'], + 'index_reduce': ['f16', 'f32'], + 'index_select': ['f16', 'f32'], + 'inner': ['f32'], + 'int': ['f16', 'f32'], 'isclose': ['f16', 'f32'], 'isfinite': ['f16', 'f32'], + 'isin': ['f32'], 'isinf': ['f16', 'f32'], 'isnan': ['f16', 'f32'], + 'isneginf': ['f16', 'f32'], + 'isposinf': ['f16', 'f32'], 'isreal': ['f16', 'f32'], - 'kron': ['f32'], - 'linalg.matrix_norm': ['f16'], + 'kron': ['f16', 'f32'], + 'kthvalue': ['f32'], + 'ldexp': ['f16', 'f32'], + 'le': ['f16', 'f32'], + 'lerp': ['f32'], + 'lgamma': ['f32'], + 'linalg.cholesky': ['f32'], + 'linalg.cholesky_ex': ['f32'], + 'linalg.cond': ['f32'], + 'linalg.cross': ['f32'], + 'linalg.det': ['f32'], + 'linalg.eigh': ['f32'], + 'linalg.eigvalsh': ['f32'], + 'linalg.householder_product': ['f32'], + 'linalg.inv': ['f32'], + 'linalg.inv_ex': ['f32'], + 'linalg.ldl_factor': ['f32'], + 'linalg.ldl_factor_ex': ['f32'], + 'linalg.lstsq': ['f32'], + 'linalg.lu': ['f32'], + 'linalg.lu_factor': ['f32'], + 'linalg.lu_factor_ex': ['f32'], + 'linalg.lu_solve': ['f32'], + 'linalg.matrix_norm': ['f16', 'f32'], + 'linalg.matrix_power': ['f32'], + 'linalg.matrix_rank': ['f32'], + 'linalg.multi_dot': ['f32'], + 'linalg.norm': ['f16', 'f32'], + 'linalg.pinv': ['f32'], + 'linalg.qr': ['f32'], + 'linalg.slogdet': ['f32'], + 'linalg.solve': ['f32'], + 'linalg.solve_ex': ['f32'], + 'linalg.solve_triangular': ['f32'], 'linalg.svd': ['f32'], + 'linalg.svdvals': ['f32'], + 'linalg.tensorinv': ['f32'], + 'linalg.tensorsolve': ['f32'], + 'linalg.vander': ['f32'], + 'linalg.vecdot': ['f32'], + 'linalg.vector_norm': ['f16', 'f32'], 'linspace': ['f16', 'f32'], 'log': ['f32'], 'log10': ['f32'], + 'log1p': ['f32'], 'log2': ['f32'], - 'log_softmax': ['f32'], + 'log_softmax': ['f32', 'f16'], 'logaddexp': ['f32'], + 'logaddexp2': ['f32'], + 'logcumsumexp': ['f32'], + 'logdet': ['f32'], + 'logical_and': ['f16', 'f32'], 'logical_not': ['f16', 'f32'], + 'logical_or': ['f16', 'f32'], + 'logical_xor': ['f16', 'f32'], + 'logit': ['f32'], 'logspace': ['f32'], + 'logsumexp': ['f32'], + 'long': ['f16', 'f32'], + 'lt': ['f16', 'f32'], + 'lu': ['f32'], + 'lu_solve': ['f32'], + 'lu_unpack': ['f32'], + 'mH': ['f16', 'f32'], + 'mT': ['f16', 'f32'], + 'masked.amax': ['f16', 'f32'], + 'masked.amin': ['f16', 'f32'], + 'masked.argmax': ['f16', 'f32'], + 'masked.argmin': ['f16', 'f32'], + 'masked.cumprod': ['f32'], + 'masked.cumsum': ['f32'], + 'masked.log_softmax': ['f32'], + 'masked.logaddexp': ['f32'], + 'masked.logsumexp': ['f32'], + 'masked.mean': ['f16', 'f32'], + 'masked.median': ['f32'], + 'masked.norm': ['f16', 'f32'], + 'masked.normalize': ['f16', 'f32'], + 'masked.prod': ['f32'], + 'masked.softmax': ['f32'], + 'masked.softmin': ['f32'], + 'masked.std': ['f32'], + 'masked.sum': ['f16', 'f32'], + 'masked.var': ['f16', 'f32'], + 'masked_fill': ['f16', 'f32'], + 'masked_scatter': ['f16', 'f32'], + 'masked_select': ['f16', 'f32'], 'matmul': ['f32'], + 'matrix_exp': ['f32'], + 'max': ['f16', 'f32'], + 'max_pool2d_with_indices_backward': ['f32'], + 'maximum': ['f16', 'f32'], + 'mean': ['f16', 'f32'], + 'median': ['f32'], + 'meshgrid': ['f16', 'f32'], + 'min': ['f16', 'f32'], + 'minimum': ['f16', 'f32'], 'mm': ['f32'], + 'mode': ['f16', 'f32'], + 'movedim': ['f16', 'f32'], + 'msort': ['f16', 'f32'], + 'mul': ['f16', 'f32'], + 'multinomial': ['f32'], 'mv': ['f32'], + 'mvlgamma': ['f32'], + 'nan_to_num': ['f16', 'f32'], + 'nanmean': ['f16', 'f32'], + 'nanmedian': ['f32'], + 'nanquantile': ['f32'], + 'nansum': ['f16', 'f32'], + 'narrow': ['f16', 'f32'], + 'native_batch_norm': ['f32'], + 'native_dropout_backward': ['f16', 'f32'], + 'native_layer_norm': ['f32'], + 'ne': ['f16', 'f32'], 'neg': ['f16', 'f32'], + 'new_empty': ['f16', 'f32'], + 'new_empty_strided': ['f16', 'f32'], + 'new_full': ['f16', 'f32'], + 'new_ones': ['f16', 'f32'], + 'new_zeros': ['f16', 'f32'], + 'nn.functional._scaled_dot_product_attention': ['f32'], + 'nn.functional.adaptive_avg_pool1d': ['f32'], + 'nn.functional.adaptive_avg_pool2d': ['f32'], + 'nn.functional.adaptive_avg_pool3d': ['f16', 'f32'], 'nn.functional.adaptive_max_pool1d': ['f32'], 'nn.functional.adaptive_max_pool2d': ['f32'], + 'nn.functional.adaptive_max_pool3d': ['f32'], + 'nn.functional.alpha_dropout': ['f32'], + 'nn.functional.avg_pool1d': ['f32'], + 'nn.functional.avg_pool2d': ['f32'], + 'nn.functional.avg_pool3d': ['f32'], + 'nn.functional.batch_norm': ['f32'], + 'nn.functional.bilinear': ['f32'], 'nn.functional.binary_cross_entropy': ['f32'], + 'nn.functional.binary_cross_entropy_with_logits': ['f32'], 'nn.functional.celu': ['f32'], 'nn.functional.conv1d': ['f32'], 'nn.functional.conv2d': ['f32'], 'nn.functional.conv_transpose1d': ['f32'], + 'nn.functional.conv_transpose2d': ['f32'], + 'nn.functional.conv_transpose3d': ['f32'], 'nn.functional.cosine_embedding_loss': ['f32'], + 'nn.functional.cosine_similarity': ['f32'], + 'nn.functional.cross_entropy': ['f32'], + 'nn.functional.ctc_loss': ['f32'], + 'nn.functional.dropout': ['f32'], + 'nn.functional.dropout2d': ['f32'], + 'nn.functional.dropout3d': ['f32'], 'nn.functional.elu': ['f32'], - 'nn.functional.feature_alpha_dropout': ['f16', 'f32'], + 'nn.functional.embedding': ['f16', 'f32'], + 'nn.functional.embedding_bag': ['f16', 'f32'], + 'nn.functional.feature_alpha_dropout': ['f32', 'f16'], + 'nn.functional.fractional_max_pool2d': ['f32'], + 'nn.functional.fractional_max_pool3d': ['f32'], + 'nn.functional.gaussian_nll_loss': ['f32'], + 'nn.functional.gelu': ['f32'], 'nn.functional.glu': ['f32'], + 'nn.functional.grid_sample': ['f32'], + 'nn.functional.group_norm': ['f32'], + 'nn.functional.hardshrink': ['f32'], + 'nn.functional.hardsigmoid': ['f32'], + 'nn.functional.hardswish': ['f32'], 'nn.functional.hardtanh': ['f32'], 'nn.functional.hinge_embedding_loss': ['f32'], - 'nn.functional.huber_loss': ['f32'], + 'nn.functional.huber_loss': ['f16', 'f32'], 'nn.functional.instance_norm': ['f32'], + 'nn.functional.interpolate': ['f32'], 'nn.functional.kl_div': ['f32'], 'nn.functional.l1_loss': ['f16', 'f32'], + 'nn.functional.layer_norm': ['f32'], 'nn.functional.leaky_relu': ['f32'], + 'nn.functional.linear': ['f32'], 'nn.functional.local_response_norm': ['f32'], + 'nn.functional.logsigmoid': ['f32'], 'nn.functional.margin_ranking_loss': ['f32'], + 'nn.functional.max_pool1d': ['f32'], + 'nn.functional.max_pool2d': ['f32'], + 'nn.functional.max_pool3d': ['f32'], + 'nn.functional.max_unpool1d': ['f32'], + 'nn.functional.max_unpool2d': ['f32'], + 'nn.functional.max_unpool3d': ['f32'], + 'nn.functional.mish': ['f32'], 'nn.functional.mse_loss': ['f32'], + 'nn.functional.multi_margin_loss': ['f32'], + 'nn.functional.multilabel_margin_loss': ['f32'], + 'nn.functional.multilabel_soft_margin_loss': ['f32'], + 'nn.functional.nll_loss': ['f32'], + 'nn.functional.normalize': ['f32'], 'nn.functional.pad': ['f16', 'f32'], 'nn.functional.pairwise_distance': ['f16', 'f32'], + 'nn.functional.pdist': ['f32'], + 'nn.functional.pixel_shuffle': ['f16', 'f32'], + 'nn.functional.pixel_unshuffle': ['f16', 'f32'], 'nn.functional.poisson_nll_loss': ['f32'], + 'nn.functional.prelu': ['f32'], 'nn.functional.relu': ['f32'], 'nn.functional.relu6': ['f32'], + 'nn.functional.rrelu': ['f32'], 'nn.functional.selu': ['f32'], 'nn.functional.silu': ['f32'], + 'nn.functional.smooth_l1_loss': ['f32'], 'nn.functional.soft_margin_loss': ['f32'], - 'nn.functional.softmin': ['f32'], + 'nn.functional.softmin': ['f32', 'f16'], + 'nn.functional.softplus': ['f32'], + 'nn.functional.softshrink': ['f32'], 'nn.functional.softsign': ['f16', 'f32'], + 'nn.functional.tanhshrink': ['f32'], 'nn.functional.threshold': ['f32'], 'nn.functional.triplet_margin_loss': ['f32'], 'nn.functional.triplet_margin_with_distance_loss': ['f32'], + 'nn.functional.unfold': ['f16', 'f32'], 'nn.functional.upsample_bilinear': ['f32'], - 'norm': ['f32', 'f16'], + 'nn.functional.upsample_nearest': ['f32'], + 'nonzero': ['f16', 'f32'], + 'norm': ['f16', 'f32'], + 'normal': ['f16', 'f32'], + 'ones': ['f16', 'f32'], + 'ones_like': ['f16', 'f32'], + 'ormqr': ['f32'], + 'outer': ['f16', 'f32'], + 'pca_lowrank': ['f32'], + 'permute': ['f16', 'f32'], + 'pinverse': ['f32'], + 'polygamma': ['f32'], 'positive': ['f16', 'f32'], + 'pow': ['f32'], + 'prod': ['f32'], + 'put': ['f16', 'f32'], + 'qr': ['f32'], + 'quantile': ['f32'], 'rad2deg': ['f16', 'f32'], + 'rand_like': ['f16', 'f32'], + 'randint': ['f16', 'f32'], + 'randint_like': ['f16', 'f32'], + 'randn_like': ['f16', 'f32'], + 'ravel': ['f16', 'f32'], 'real': ['f16', 'f32'], 'reciprocal': ['f16', 'f32'], + 'remainder': ['f16', 'f32'], + 'renorm': ['f16', 'f32'], 'repeat': ['f16', 'f32'], 'repeat_interleave': ['f16', 'f32'], + 'reshape': ['f16', 'f32'], + 'reshape_as': ['f16', 'f32'], 'resolve_conj': ['f16', 'f32'], 'resolve_neg': ['f16', 'f32'], + 'roll': ['f16', 'f32'], + 'rot90': ['f16', 'f32'], 'round': ['f32'], 'rsqrt': ['f32'], + 'rsub': ['f16', 'f32'], + 'scatter': ['f16', 'f32'], + 'scatter_add': ['f16', 'f32'], + 'scatter_reduce': ['f16', 'f32'], + 'searchsorted': ['f16', 'f32'], + 'segment_reduce': ['f16', 'f32'], + 'select': ['f16', 'f32'], 'select_scatter': ['f16', 'f32'], + 'sgn': ['f16', 'f32'], + 'short': ['f16', 'f32'], + 'sigmoid': ['f32'], 'sign': ['f16', 'f32'], + 'signbit': ['f16', 'f32'], 'sin': ['f32'], + 'sinc': ['f32'], 'sinh': ['f32'], + 'slice': ['f16', 'f32'], 'slice_scatter': ['f16', 'f32'], - 'softmax': ['f32'], + 'softmax': ['f32', 'f16'], + 'sort': ['f16', 'f32'], + 'special.airy_ai': ['f32'], + 'special.bessel_j0': ['f32'], + 'special.bessel_j1': ['f32'], + 'special.bessel_y0': ['f32'], + 'special.bessel_y1': ['f32'], + 'special.chebyshev_polynomial_t': ['f32'], + 'special.chebyshev_polynomial_u': ['f32'], + 'special.entr': ['f32'], + 'special.erfcx': ['f32'], + 'special.hermite_polynomial_h': ['f32'], + 'special.hermite_polynomial_he': ['f32'], + 'special.i0e': ['f32'], + 'special.i1': ['f32'], + 'special.i1e': ['f32'], + 'special.laguerre_polynomial_l': ['f32'], + 'special.log_ndtr': ['f32'], + 'special.modified_bessel_i0': ['f32'], + 'special.modified_bessel_i1': ['f32'], + 'special.modified_bessel_k0': ['f32'], + 'special.modified_bessel_k1': ['f32'], + 'special.ndtr': ['f32'], + 'special.ndtri': ['f32'], + 'special.polygamma': ['f32'], + 'special.scaled_modified_bessel_k0': ['f32'], + 'special.scaled_modified_bessel_k1': ['f32'], + 'special.spherical_bessel_j0': ['f32'], + 'special.xlog1py': ['f16', 'f32'], 'split': ['f16', 'f32'], + 'split_with_sizes': ['f16', 'f32'], 'sqrt': ['f32'], 'square': ['f16', 'f32'], 'squeeze': ['f16', 'f32'], 'stack': ['f16', 'f32'], - 'sub': ['f32'], + 'std': ['f16', 'f32'], + 'std_mean': ['f16', 'f32'], + 'sub': ['f16', 'f32'], + 'sum': ['f16', 'f32'], 'sum_to_size': ['f16', 'f32'], 'svd': ['f32'], + 'svd_lowrank': ['f32'], + 'symeig': ['f32'], 't': ['f16', 'f32'], + 'take': ['f16', 'f32'], + 'take_along_dim': ['f16', 'f32'], + 'tan': ['f32'], 'tanh': ['f32'], + 'tensor_split': ['f16', 'f32'], 'tensordot': ['f32'], 'tile': ['f16', 'f32'], + 'to': ['f16', 'f32'], + 'topk': ['f32'], + 'trace': ['f32'], + 'transpose': ['f16', 'f32'], + 'trapezoid': ['f16', 'f32'], + 'trapz': ['f16', 'f32'], + 'triangular_solve': ['f32'], 'tril': ['f16', 'f32'], 'triu': ['f16', 'f32'], 'true_divide': ['f16', 'f32'], 'trunc': ['f32'], 'unbind': ['f16', 'f32'], 'unflatten': ['f16', 'f32'], + 'unfold': ['f16', 'f32'], + 'unfold_copy': ['f16', 'f32'], + 'uniform': ['f16', 'f32'], 'unsqueeze': ['f16', 'f32'], + 'var': ['f16', 'f32'], + 'var_mean': ['f16', 'f32'], + 'vdot': ['f32'], 'view': ['f16', 'f32'], 'view_as': ['f16', 'f32'], + 'view_copy': ['f16', 'f32'], 'vsplit': ['f16', 'f32'], 'vstack': ['f16', 'f32'], - 'zero_': ['f16', 'f32']} + 'where': ['f16', 'f32'], + 'xlogy': ['f16', 'f32'], + 'zero_': ['f16', 'f32'], + 'zeros': ['f16', 'f32'], + 'zeros_like': ['f16', 'f32'], + } + + BLOCKLIST_OP_GRAD = { + # Unimplemented ops + '__getitem__': ['f16'], + 'combinations': ['f16', 'f32'], + 'logaddexp2': ['f32'], + 'masked_select': ['f16', 'f32'], + 'nn.functional.binary_cross_entropy_with_logits': ['f16', 'f32'], + 'nn.functional.group_norm': ['f32'], + 'prod': ['f32'], + 'sgn': ['f16', 'f32'], + 'unfold_copy': ['f16', 'f32'], + 'unfold': ['f16', 'f32'], + 'trace': ['f32'], + + # Hard crash + 'linalg.norm': ['f16'], + 'linalg.norm_subgradients': ['f16'], + 'max': ['f16', 'f32'], + 'maximum': ['f16', 'f32'], + 'min': ['f16', 'f32'], + 'minimum': ['f16', 'f32'], + 'nn.functional.linear': ['f32'], + 'nn.functional.prelu': ['f32'], + 'nn.functional.tanhshrink': ['f32'], + 'sigmoid': ['f32'], + + # Correctness issues + 'nn.functional.conv_transpose2d': ['f32'], + 'atanh': ['f32'], + 'div': ['f16'], + 'gradient': ['f16'], + 'kron': ['f16'], + 'linalg.solve_triangular': ['f32'], + 'linalg.vector_norm': ['f16'], + 'nn.functional.bilinear': ['f32'], + 'nn.functional.cross_entropy': ['f32'], + 'nn.functional.gelu': ['f32'], + 'nn.functional.layer_norm': ['f32'], + 'nn.functional.nll_loss': ['f32'], + 'nn.functional.smooth_l1_loss': ['f32'], + 'std': ['f16'], + 'triangular_solve': ['f32'], + 'var': ['f16'], + 'nn.functional.embedding': ['f16'], + + # Unsupported dtype + 'special.ndtr': ['f32'], + 'trapezoid': ['f16', 'f32'], + 'trapz': ['f16', 'f32'], + } # These ops that are problematic. So never run them even when # generating the new allowlist. # If the dtype list is None, all dtypes are excluded. # All the entries in this list should be removed BLOCKLIST = { - # Functions that hang - 'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool], - # + forward when requires_grad=True or running backward - 'masked.mean': [torch.bool, torch.float16], - 'masked.prod': [torch.bool], - 'masked.sum': [torch.bool], - # Functions that hard crash - 'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64], - 'nn.functional.nll_loss': [torch.float32], - 'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32], - 'std': [torch.float16], - 'stft': [torch.float32], 'var': [torch.float16], - # + forward when requires_grad=True or running backward - 'index_select': [torch.float16], - 'nn.functional.embedding': [torch.float32, torch.float16], - '__rpow__': [torch.int64], - 'masked.std': [torch.int32], - 'masked.var': [torch.int32], - 'as_strided_scatter': [torch.uint8], - 'atan2': [torch.int64], - 'bfloat16': None, - 'block_diag': [torch.uint8], - 'byte': None, - 'chalf': None, - 'diag_embed': [torch.uint8], - 'diagonal_scatter': [torch.uint8], - 'index_add': None, - 'log1p': None, - 'long': None, - 'nn.functional.avg_pool1d': [torch.int64], - 'nn.functional.avg_pool2d': [torch.int64], + 'sgn': [torch.bool], + 'linalg.inv': [torch.float32], + 'linalg.inv_ex': [torch.float32], + 'linalg.matrix_power': [torch.float32], + 'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'resize_as_': [torch.float16, torch.float32], + 'topk': [torch.int16, torch.int32, torch.int64, torch.uint8], + + # Functions with correctness issues + 'unique': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'norm': [torch.float16], + 'nn.functional.feature_alpha_dropoutwith_train': [torch.float32], + 'cumulative_trapezoid': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'addr': [torch.float16], + 'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'trace': [torch.int64], + 'normalnumber_mean': [torch.float16, torch.float32], + 'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'multinomial': [torch.float32], + 'dist': [torch.float16], + + # failure due to issue: atan2() may generate NAN in output with + 'atan2': [torch.bool, torch.int16, torch.int32, torch.uint8], + + # Unsupported Border padding mode + 'grid_sampler_2d': [torch.float32], + 'nn.functional.grid_sample': [torch.float32], + + # failures due to issue #103039644: Wrong results from avgPooling2DWithSourceTensor() + # when both ceilMode and includeZeroPadToAverage are True + 'nn.functional.avg_pool1d': [torch.float32, torch.int64], + 'nn.functional.avg_pool2d': [torch.float32, torch.int64], + 'nn.functional.adaptive_avg_pool1d': [torch.float32], + 'nn.functional.adaptive_avg_pool2d': [torch.float32], + + # failures due to issue #102048039: powerWithPrimaryTensor() with integer input may return wrong results + 'pow': [torch.int16, torch.int32, torch.int64, torch.uint8], + '__rpow__': [torch.int16, torch.int32, torch.uint8, torch.int64], + + # failures before macOS 13.3 + 'nn.functional.conv_transpose2d': [torch.float32], + } + + UNIMPLEMENTED_OPS = { + # Failures due to lack of op implementation on MPS backend + 'linalg.eig': [torch.float32], + 'linalg.eigvals': [torch.float32], + 'fft.fft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ifft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.rfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.rfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.rfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'stft': [torch.float32], + 'nn.functional.conv_transpose3d': [torch.int64, torch.float32], + 'rounddecimals_neg_3': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'rounddecimals_3': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'rounddecimals_0': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + '__rmod__': [torch.float16, torch.float32], + '__rsub__': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'aminmax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'angle': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'argsort': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'bucketize': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'cholesky': [torch.float32], + 'cholesky_inverse': [torch.float32], + 'cholesky_solve': [torch.float32], + 'copysign': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'cummax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'cummin': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'cumprod': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'digamma': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'erfc': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'erfinv': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fmax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fmin': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fmod': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'frexp': [torch.float16, torch.float32], + 'gcd': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'geqrf': [torch.float32], + 'heaviside': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'histc': [torch.float32], + 'histogram': [torch.float32], + 'histogramdd': [torch.float32], + 'hypot': [torch.float32], + 'i0': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'igamma': [torch.float16, torch.float32], + 'igammac': [torch.float16, torch.float32], + 'index_copy': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'index_fill': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'index_reduce': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'isin': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'isneginf': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'isposinf': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'kthvalue': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'lcm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'ldexp': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'lerp': [torch.float32], + 'lgamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'linalg.cholesky': [torch.float32], + 'linalg.cholesky_ex': [torch.float32], + 'linalg.cond': [torch.float32], + 'linalg.detsingular': [torch.float32], + 'linalg.det': [torch.float32], + 'linalg.eigh': [torch.float32], + 'linalg.eigvalsh': [torch.float32], + 'linalg.householder_product': [torch.float32], + 'linalg.ldl_factor': [torch.float32], + 'linalg.ldl_factor_ex': [torch.float32], + 'linalg.ldl_solve': [torch.float32], + 'linalg.lstsq': [torch.float32], + 'linalg.lstsqgrad_oriented': [torch.float32], + 'linalg.lu': [torch.float32], + 'linalg.lu_factor': [torch.float32], + 'linalg.lu_factor_ex': [torch.float32], + 'linalg.lu_solve': [torch.float32], + 'linalg.matrix_norm': [torch.float32], + 'linalg.norm': [torch.float32], + 'linalg.normsubgradients_at_zero': [torch.float32], + 'linalg.qr': [torch.float32], + 'linalg.slogdet': [torch.float32], + 'linalg.solve': [torch.float32], + 'linalg.solve_ex': [torch.float32], + 'linalg.svdvals': [torch.float32], + 'linalg.tensorsolve': [torch.float32], + 'linalg.vander': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'linalg.vecdot': [torch.float32], + 'logcumsumexp': [torch.float32], + 'logdet': [torch.float32], + 'logit': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'lu': [torch.float32], + 'lu_solve': [torch.float32], + 'lu_unpack': [torch.float32], + 'masked.cumprod': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'masked.median': [torch.float32], + 'masked_scatter': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'matrix_exp': [torch.float32], + 'mode': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'msort': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'mvlgamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'mvlgammamvlgamma_p_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'mvlgammamvlgamma_p_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'mvlgammamvlgamma_p_5': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'nanquantile': [torch.float32], + 'nanmean': [torch.float32, torch.float16], + 'nanmedian': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'nansum': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'native_dropout_backward': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'nextafter': [torch.float32], + 'normnuc': [torch.float32], + 'nn.functional._scaled_dot_product_attention': [torch.float32], + 'nn.functional.fractional_max_pool2d': [torch.float32], + 'nn.functional.fractional_max_pool3d': [torch.float32], + 'nn.functional.adaptive_avg_pool3d': [torch.float16, torch.float32], + 'nn.functional.adaptive_max_pool3d': [torch.float32], + 'nn.functional.interpolatearea': [torch.float32], + 'nn.functional.interpolatebicubic': [torch.float32], + 'nn.functional.interpolatelinear': [torch.float32], + 'nn.functional.interpolatetrilinear': [torch.float32], + 'nn.functional.max_unpool1dgrad': [torch.float32], + 'nn.functional.max_unpool2dgrad': [torch.float32], + 'nn.functional.max_unpool3dgrad': [torch.float32], + 'nn.functional.avg_pool3d': [torch.float32, torch.int64], + 'nn.functional.ctc_loss': [torch.float32], + 'nn.functional.embedding_bag': [torch.float16, torch.float32], + 'nn.functional.max_pool2d': [torch.float32], + 'nn.functional.hardshrink': [torch.float32], + 'nn.functional.hardsigmoid': [torch.float32], + 'nn.functional.logsigmoid': [torch.float32], + 'nn.functional.max_pool3d': [torch.float32], + 'nn.functional.max_unpool1d': [torch.float32], + 'nn.functional.max_unpool2d': [torch.float32], + 'nn.functional.max_unpool3d': [torch.float32], + 'nn.functional.mish': [torch.float32], + 'nn.functional.multi_margin_loss': [torch.float32], + 'nn.functional.multilabel_margin_loss': [torch.float32], + 'nn.functional.multilabel_soft_margin_loss': [torch.float32], + 'nn.functional.pdist': [torch.float32], + 'nn.functional.rrelu': [torch.float32], + 'nn.functional.softshrink': [torch.float32], + 'nn.functional.unfold': [torch.float16, torch.float32], + 'nn.functional.norm': [torch.float32], + 'ormqr': [torch.float32], + 'pca_lowrank': [torch.float32], + 'pinverse': [torch.float32], + 'polar': [torch.float32], + 'polygamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'polygammapolygamma_n_0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'polygammapolygamma_n_1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'polygammapolygamma_n_2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'polygammapolygamma_n_3': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'polygammapolygamma_n_4': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'qr': [torch.float32], + 'quantile': [torch.float32], + 'remainder': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], + 'renorm': [torch.float16, torch.float32], + 'roll': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'rsub': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reduceamax': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reduceamin': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reducemin': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reducemean': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reduceprod': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'scatter_reducesum': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'searchsorted': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'segment_reduce': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'segment_reduceoffsets': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'segment_reducelengths': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'sinc': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'sort': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.airy_ai': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.bessel_j0': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.bessel_j1': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.bessel_y0': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.bessel_y1': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.chebyshev_polynomial_t': [torch.bool, + torch.float16, + torch.float32, + torch.int16, + torch.int32, + torch.int64, + torch.uint8], + 'special.chebyshev_polynomial_u': [torch.bool, + torch.float16, + torch.float32, + torch.int16, + torch.int32, + torch.int64, + torch.uint8], + 'special.entr': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.erfcx': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.hermite_polynomial_h': [torch.bool, + torch.float16, + torch.float32, + torch.int16, + torch.int32, + torch.int64, + torch.uint8], + 'special.hermite_polynomial_he': [torch.bool, + torch.float16, + torch.float32, + torch.int16, + torch.int32, + torch.int64, + torch.uint8], + 'special.i0e': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.i1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.i1e': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.laguerre_polynomial_l': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.log_ndtr': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.modified_bessel_i0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.modified_bessel_i1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.modified_bessel_k0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.modified_bessel_k1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.ndtri': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.polygamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.polygammaspecial_polygamma_n_0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.scaled_modified_bessel_k0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.scaled_modified_bessel_k1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.spherical_bessel_j0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.xlog1py': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'special.zeta': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'std_mean': [torch.float16, torch.float32], + 'std_meanunbiased': [torch.float16, torch.float32], + 'svd_lowrank': [torch.float32], + 'symeig': [torch.float32], + 'take': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'to_sparse': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'var_mean': [torch.float16, torch.float32], + 'var_meanunbiased': [torch.float16, torch.float32], + 'vdot': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'view_as_complex': [torch.float16, torch.float32], + 'xlogy': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + } + + EXPECTED_FAILURES = { + # Failures due to unsupported data types on MPS backend + 'bfloat16': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'chalf': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], 'nn.functional.conv1d': [torch.int64], 'nn.functional.conv2d': [torch.int64], 'nn.functional.conv_transpose1d': [torch.int64], - 'nn.functional.conv_transpose2d': [torch.int64], - 'nn.functional.conv_transpose3d': [torch.int64, torch.float32], - 'nn.functional.huber_loss': [torch.float16], - 'nn.functional.local_response_norm': [torch.int64], - 'nn.functional.padcircular': [torch.uint8], - 'nn.functional.softplus': [torch.float32], - 'pow': [torch.int64], - 'select_scatter': [torch.uint8], - 'sigmoid': [torch.int64], - 'slice_scatter': [torch.uint8], - 'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below - - - # ALLOW_LIST doesn't know about variants - 'nn.functional.padconstant': None, - - # These were moved from ALLOWLIST to BLOCK as they are not working - # locally - 'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - '__radd__': ['torch.bool', 'torch.uint8'], - '__rmul__': ['torch.uint8'], - 'add': ['torch.bool', 'torch.uint8'], - 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'], - 'diag': ['torch.int64'], - 'diagflat': ['torch.int64'], - - # Functions that are flaky - # These are detected as "ok" by the expect case but actually fail to run sometimes - 'H': None, - 'T': None, - 'as_strided': None, - 'broadcast_tensors': None, - 'broadcast': None, - 'broadcast_to': None, - 'diagonal': None, - 'divfloor_rounding': None, - 'divno_rounding_mode': None, - 'divtrunc_rounding': None, - 'dsplit': None, - 'hsplit': None, - 'empty': None, - 'expand_as': None, - 'expand': None, - 'ge': None, - 'ne': None, - 'le': None, - 'lt': None, - 'gt': None, - 'transpose': None, - 'splitlist_args': None, - 'select': None, - 'reshape': None, - 'reshape_as': None, - 'permute': None, - 'norm': None, - 'nn.functional.pixel_unshuffle': None, - 'nn.functional.pixel_shuffle': None, - 'nn.functional.cross_entropy': None, - 'nn.functional.one_hot': None, - 'narrow': None, - 'movedim': None, - 'minreduction_with_dim': None, - 'minreduction_no_dim': None, - 'minbinary': None, - 'meshgridvariadic_tensors': None, - 'meshgridlist_of_tensors': None, - 'maxreduction_with_dim': None, - 'maxreduction_no_dim': None, - 'maxbinary': None, - 'maximum': None, - 'minimum': None, - 'mT': None, - 'mH': None, - 'outer': None, - 'softmaxwith_dtype': None, - 'rounddecimals_neg_3': None, - 'rounddecimals_3': None, - 'rounddecimals_0': None, - 'normnuc': None, - 'nn.functional.softminwith_dtype': None, - 'nn.functional.feature_alpha_dropoutwith_train': None, - 'log_softmaxwith_dtype': None, - 'split_with_sizes': None, - 'trapezoid': None, - 'eq': None, - 'mul': None, - 'cartesian_prod': None, - 'nonzero': None, - 'bool': None, - 'inner': None, - 'dstack': None, - 'take_along_dim': None, + 'nn.functional.softminwith_dtype': [torch.bool, + torch.float16, + torch.float32, + torch.int16, + torch.int32, + torch.int64, + torch.uint8], + 'log_softmaxwith_dtype': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'softmaxwith_dtype': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + '__rmatmul__': [torch.int16, torch.int32, torch.uint8], + 'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'cdouble': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'cfloat': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'complex': [torch.float16, torch.float32], + 'double': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.fft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.fft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.fftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.fftshift': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.hfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.hfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.hfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ifft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ifft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ifftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ifftshift': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.ihfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.irfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.irfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.irfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'fft.rfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'float_power': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'full_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'inner': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'linalg.matrix_rank': [torch.float32], + 'linalg.matrix_rankhermitian': [torch.float32], + 'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'linalg.pinv': [torch.float32], + 'linalg.pinvhermitian': [torch.float32], + 'log_softmax': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8], # MPS device does not support mm for non-float inputs + 'mm': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'mv': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'nn.functional.batch_norm': [torch.float32], + 'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'nn.functional.softmin': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'ones_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'signal.windows.blackman': [torch.float16], + 'signal.windows.cosine': [torch.float16], + 'signal.windows.exponential': [torch.float16], + 'signal.windows.gaussian': [torch.float16], + 'signal.windows.general_cosine': [torch.float16], + 'signal.windows.general_hamming': [torch.float16], + 'signal.windows.hamming': [torch.float16], + 'signal.windows.hann': [torch.float16], + 'signal.windows.kaiser': [torch.float16], + 'stft': [torch.float32], + 'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8], + 'zeros_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8], + } + + UNDEFINED_BEHAVIOUR = { + # Failures due to random output that they generate using + # Philox engine causing mismatch with CPU results + 'uniform': [torch.float16, torch.float32], + 'rand_like': [torch.float16, torch.float32], + 'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'randn_like': [torch.float16, torch.float32], + 'bernoulli': [torch.float32], + 'normal': [torch.float16, torch.float32, torch.float16, torch.float32], + 'nn.functional.alpha_dropout': [torch.float32], + 'nn.functional.dropout': [torch.float32], + 'nn.functional.dropout2d': [torch.float32], + 'nn.functional.dropout3d': [torch.float32], + # these fill tensors with uninitialized data, causing mismatch with CPU + 'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + 'empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + # problem 103190467, as_strided_scatter has non-deterministic behavior when the update indices are not unique + 'as_strided_scatter': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + # duplicate indices are used in the testcase - undefined behaviour + 'index_put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8], + } + + FAST_MATH_PRECISION_ISSUES = { + # Failures due to precision issues + 'tan': [torch.float32], + 'pow': [torch.float32], + 'masked.softmin': [torch.float32], + 'masked.softmax': [torch.float32], + 'masked.log_softmax': [torch.float32], + 'cdist': [torch.float32], + '__rpow__': [torch.float32] + } + + FP16_LOW_PRECISION_LIST = { + "add", "sub", "div", + "__rdiv__", "__rmul__", + "nn.functional.huber_loss", + "true_divide" + } + + BLOCKLIST_MACOS_12 = { + 'nn.functional.conv_transpose2d': [torch.float32, torch.float16], + } + + ALLOWLIST_MACOS_13_3 = { + 'pow': [torch.int16, torch.int32, torch.int64, torch.uint8], + '__rpow__': [torch.uint8], + 'nn.functional.conv_transpose2d': [torch.float32], } + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "cuda_results.yaml") + with open(filename) as f: + data = yaml.safe_load(f) + CUDA_RESULT = dict() + for key, value in data.items(): + CUDA_RESULT[key] = torch.as_tensor(value) + + MPS_SKIP_LIST = reduce(lambda x, y: dict(x, **y), ( + FAST_MATH_PRECISION_ISSUES, BLOCKLIST, UNDEFINED_BEHAVIOUR, EXPECTED_FAILURES, UNIMPLEMENTED_OPS)) + # Used for accept mode only NEW_ALLOW_LIST = defaultdict(list) NEW_ALLOW_LIST_GRAD = defaultdict(list) + product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2])) + + def get_error_message(self, key, op_name, dtype): + if key in self.FAST_MATH_PRECISION_ISSUES and dtype in self.FAST_MATH_PRECISION_ISSUES[key]: + return f"Running test with {op_name} fails due to precision issues (fast math) so skipping" + elif key in self.BLOCKLIST and dtype in self.BLOCKLIST[key]: + return f"Running test with {op_name} fails so skipping" + elif key in self.UNDEFINED_BEHAVIOUR and dtype in self.UNDEFINED_BEHAVIOUR[key]: + return f"Running test with {op_name} fails due to undefined behaviour / random output so skipping" + elif key in self.EXPECTED_FAILURES and dtype in self.EXPECTED_FAILURES[key]: + return f"Running test with {op_name} expected to fail due to unsupported MPS data type so skipping" + elif key in self.UNIMPLEMENTED_OPS and dtype in self.UNIMPLEMENTED_OPS[key]: + return f"Running test with {op_name} expected to fail due to missing op implementation" + elif self.product_version < 13.0 and key in self.BLOCKLIST_MACOS_12 and dtype in self.BLOCKLIST_MACOS_12[key]: + return f"Running test with {op_name} expected to fail on macOS 12" + return None + + def compare_with_CUDA(self, op, mps_out, atol, rtol): + cuda_out = self.CUDA_RESULT[op.name] + try: + self.assertEqual(cuda_out, mps_out, atol=atol, rtol=rtol) + except Exception as e: + return False + else: + return True + @ops(op_db, allowed_dtypes=MPS_DTYPES) def test_output_match(self, device, dtype, op): self.assertEqual(device, "cpu") @@ -7657,9 +10258,15 @@ def test_output_match(self, device, dtype, op): self.skipTest("MPS is not available") key = op.name + op.variant_test_name - if key in self.BLOCKLIST: - if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]: - self.skipTest(f"Running test with {op.name} hangs so skipping") + if key in self.MPS_SKIP_LIST: + msg = self.get_error_message(key, op.name, dtype) + if msg is not None and not (self.product_version >= 13.3 and + key in self.ALLOWLIST_MACOS_13_3 and dtype in self.ALLOWLIST_MACOS_13_3[key]): + self.skipTest(msg) + if self.product_version < 13.0 and key in self.BLOCKLIST_MACOS_12: + msg = self.get_error_message(key, op.name, dtype) + if msg is not None: + self.skipTest(msg) # Make this an expecttest manually # When this env variable is set, generate a new ALLOWLIST_OP @@ -7677,7 +10284,8 @@ def test_output_match(self, device, dtype, op): if dtype_abbrs[dtype] not in self.ALLOWLIST_OP[op.name]: self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded") - if op.name not in self.ALLOWLIST_OP_GRAD or dtype_abbrs[dtype] not in self.ALLOWLIST_OP_GRAD[op.name]: + if (op.name not in self.ALLOWLIST_OP_GRAD or dtype_abbrs[dtype] not in self.ALLOWLIST_OP_GRAD[op.name] or + (op.name in self.BLOCKLIST_OP_GRAD and dtype_abbrs[dtype] in self.BLOCKLIST_OP_GRAD[op.name])): run_grad_test = False def get_samples(): @@ -7702,22 +10310,41 @@ def get_samples(): mps_args = [mps_sample.input] + list(mps_sample.args) mps_kwargs = mps_sample.kwargs + # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only + if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)): + mps_args[1] = cpu_args[1] + cpu_out = op(*cpu_args, **cpu_kwargs) mps_out = op(*mps_args, **mps_kwargs) if op.name == "nn.functional.conv2d" and dtype == torch.float32: atol = 1e-4 rtol = 3e-5 - elif op.name == "add" and dtype == torch.float16: + elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16: atol = 1e-2 rtol = 1e-2 + elif (op.name == "masked.mean"): + atol = 7e-4 + rtol = 2e-3 + elif (op.name == "native_layer_norm"): + atol = 1e-4 + rtol = 1.3e-5 else: atol = None rtol = None self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol) + if op.name in ["cat"]: + self.assertEqual(cpu_out.is_contiguous(), mps_out.is_contiguous()) + except Exception as e: + if any(s in str(e).lower() for s in ["int64", "macos 13"]): + self.skipTest(f"{str(e)}") + + if op.name in self.CUDA_RESULT and self.compare_with_CUDA(op, mps_out, atol=atol, rtol=rtol): + continue + if not generate_new_truth: raise e forward_failed = True @@ -7769,7 +10396,7 @@ def req_grad(t): cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) - self.assertEqual(cpu_grad_inputs, mps_grad_inputs) + self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) except Exception as e: if not generate_new_truth: raise e @@ -7790,10 +10417,66 @@ def req_grad(t): # So each test append to the dict and write it. with open("new_mps_allowlist_grad.txt", "w") as f: pprint.pprint(self.NEW_ALLOW_LIST_GRAD, stream=f) + + +# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS +@skipIfSlowGradcheckEnv +class TestCommon(TestCase): + + UNIMPLEMENTED_OPS = { + 'aminmax': [torch.float32], + 'roll': [torch.float32], + } + + exact_dtype = True + + # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + if IS_CI: + err_msg = ( + "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries." + "This is OK for testing, but be sure to set the dtypes manually before landing your PR!" + ) + # Assure no opinfo entry has dynamic_dtypes + filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db)) + for op in filtered_ops: + fmt_str = opinfo.utils.str_format_dynamic_dtype(op) + err_msg += "\n" + fmt_str + + assert len(filtered_ops) == 0, err_msg + + # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while + # MPS still requires some fairly heavy special casing in the test framework. + # When MPS becomes more consistent, this can probably be merged with that test using + # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @onlyMPS + @suppress_warnings + # MPS only supports float32 + @ops(_ref_test_ops, allowed_dtypes=(torch.float32,)) + def test_numpy_ref_mps(self, device, dtype, op): + key = op.name + op.variant_test_name + if key in self.UNIMPLEMENTED_OPS and dtype in self.UNIMPLEMENTED_OPS[key]: + self.skipTest(f"Running test with {op.name} expected to fail due to missing op implementation") + + # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS + # does not support float64 Tensors. + # A few ops are currently broken on their reference inputs, but not their sample inputs. These should + # get patched up and this workaround removed. + broken_on_ref_inputs = op.name in ['clamp', 'where'] + inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype) + for sample_input in inputs: + self.compare_with_reference(op, op.ref, sample_input) + # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. # This requires mps to be properly registered in the device generic test framework which is not the -# case right now. +# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342 +# to achieve this. instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu") +instantiate_device_type_tests(TestCommon, globals(), allow_mps=True) if __name__ == "__main__": run_tests() diff --git a/test/test_native_functions.py b/test/test_native_functions.py index 831998cbf6be2..ba7889e10f4c5 100644 --- a/test/test_native_functions.py +++ b/test/test_native_functions.py @@ -19,6 +19,46 @@ def forward(self, values, incr: Optional[List[int]]): class TestNativeFunctions(TestCase): + def _lists_with_str(self): + return [ + ("foo",), + (2, "foo"), + ("foo", 3), + ["foo"], + [2, "foo"], + ["foo", 3], + "foo", + ] + + def _test_raises_str_typeerror(self, fn): + for arg in self._lists_with_str(): + self.assertRaisesRegex(TypeError, "str", lambda: fn(arg)) + try: + fn(arg) + except TypeError as e: + print(e) + + def test_symintlist_error(self): + x = torch.randn(1) + self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) + + def test_vararg_symintlist_error(self): + self._test_raises_str_typeerror(lambda arg: torch.rand(arg)) + self._test_raises_str_typeerror(lambda arg: torch.rand(*arg)) + + def test_symintlist_error_with_overload_but_is_unique(self): + x = torch.randn(1) + y = torch.randn(1) + self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg)) + + def test_symintlist_error_with_overload(self): + x = torch.randn(1) + self._test_raises_str_typeerror(lambda arg: x.view(arg)) + + def test_intlist_error_with_overload(self): + x = torch.randn(1) + self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg)) + # # optional float list # @@ -113,7 +153,7 @@ def fake_module(values, const): self.do_test_optional_intlist_with_module(fake_module) def test_optional_intlist_invalid(self): - with self.assertRaisesRegex(TypeError, "must be .* not"): + with self.assertRaisesRegex(TypeError, "must be .* but found"): IntListWrapperModule()(torch.zeros(1), [0.5]) with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"): diff --git a/test/test_native_mha.py b/test/test_native_mha.py index 41f56e8b89296..2af1db7395b61 100644 --- a/test/test_native_mha.py +++ b/test/test_native_mha.py @@ -1,5 +1,6 @@ # Owner(s): ["module: nn"] import math +import copy import torch from torch.testing._internal.common_device_type import ( @@ -9,7 +10,7 @@ onlyCUDA, skipMeta, ) -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import parametrize, run_tests, TestCase class TestMHADeviceType(TestCase): @torch.no_grad() @@ -116,36 +117,40 @@ def _test_multihead_attention_impl( bs = 16 sl = 8 - q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10 + q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3 if use_padding: if pad_all: for q_i in q: - q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype) + q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32) mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool) for mask_i in mask: mask_i[-1] = True else: - q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=dtype) + q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32) mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool) mask[0][-1] = True if mode == "self": k = q v = q elif mode == "encdec": - k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10 + k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3 v = k elif mode == "generic": - k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10 - v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) * 10 + k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3 + v = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3 else: self.fail(f"invalid mode `{mode}`!") - qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) - proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=torch.float32) + native_qkv = copy.deepcopy(qkv).to(dtype=dtype) + + proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32) + native_proj = copy.deepcopy(proj).to(dtype=dtype) pt = torch.nn.MultiheadAttention( - embed_dim, num_heads, batch_first=True, device=device, dtype=dtype + embed_dim, num_heads, batch_first=True, device=device, dtype=torch.float32 ) + pt.in_proj_weight = qkv.weight pt.in_proj_bias = qkv.bias pt.out_proj.weight = proj.weight @@ -177,7 +182,7 @@ def forward(self, q, k, v, key_padding_mask): ) npt = NativeMHA( - embed_dim=embed_dim, num_heads=num_heads, qkv=qkv, proj=proj + embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj ).to(dtype) if device == "cuda": @@ -209,8 +214,12 @@ def forward(self, q, k, v, key_padding_mask): k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype) v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype) + native_q = q.to(dtype=dtype) + native_k = k.to(dtype=dtype) + native_v = v.to(dtype=dtype) + ynpt, weight_npt = npt( - q, k, v, key_padding_mask=mask if use_padding and not use_nt else None + native_q, native_k, native_v, key_padding_mask=mask if use_padding and not use_nt else None ) if use_nt: ynpt = ynpt.to_padded_tensor(0) @@ -244,7 +253,7 @@ def do_pad_all(tensors): weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype) if dtype == torch.half: - torch.testing.assert_close(ypt, ynpt, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(ypt, ynpt.to(torch.float32), atol=1e-3, rtol=1e-3) else: # High rtol seems necessary for # test_native_multihead_attention_cpu_float32 on Windows, @@ -252,35 +261,40 @@ def do_pad_all(tensors): torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3) if need_weights: - torch.testing.assert_close(weight_pt, weight_npt) + torch.testing.assert_close(weight_pt, weight_npt.to(torch.float32), atol=5e-4, rtol=5e-4) else: self.assertEqual(weight_pt, weight_npt) @dtypesIfCUDA(torch.float, torch.half) @dtypes(torch.float) @skipMeta + @parametrize("use_nt", [False, True]) + @parametrize("use_padding, pad_all", [(False, False), (True, False), (True, True)]) + @parametrize("need_weights", [False]) + @parametrize("average_attn_weights", [False, True]) + @parametrize("fused", [False, True]) @torch.no_grad() - def test_native_multihead_self_attention(self, device, dtype): - for (use_padding, pad_all) in ((False, False), (True, False), (True, True)): - for use_nt in (False, True): - # Figuring out exactly which elements of the weights are garbage in this - # case eludes me, and it's not particularly enlightening to test anyway - # because padding doesn't especially affect the intermediate weights. - for need_weights in (False, not pad_all): - for average_attn_weights in (False, True): - with self.subTest(use_padding=use_padding, pad_all=pad_all, - use_nt=use_nt, need_weights=need_weights, - average_attn_weights=average_attn_weights): - self._test_multihead_attention_impl( - device, - dtype, - "self", - use_nt=use_nt, - use_padding=use_padding, - pad_all=pad_all, - need_weights=need_weights, - average_attn_weights=average_attn_weights, - ) + def test_native_multihead_self_attention(self, device, dtype, use_nt, + need_weights, average_attn_weights, use_padding, pad_all, fused): + for need_weights in (False, not pad_all): + with self.subTest(use_padding=use_padding, pad_all=pad_all, + use_nt=use_nt, need_weights=need_weights, + average_attn_weights=average_attn_weights): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False + ) if not fused else torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_mem_efficient=True + ): + self._test_multihead_attention_impl( + device, + dtype, + "self", + use_nt=use_nt, + use_padding=use_padding, + pad_all=pad_all, + need_weights=need_weights, + average_attn_weights=average_attn_weights, + ) @dtypesIfCUDA(torch.float, torch.half) @dtypes(torch.float) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 7eb7dead38d3d..7107538863158 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1,21 +1,34 @@ # Owner(s): ["module: nestedtensor"] +import unittest + +import numpy as np import torch import torch.nn -import unittest from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, instantiate_device_type_tests, - skipMeta, + onlyCPU, onlyCUDA, - onlyCPU + skipMeta, ) from torch.testing._internal.common_dtype import floating_types_and_half -from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state, parametrize, gradcheck +from torch.testing._internal.common_utils import ( + freeze_rng_state, + gradcheck, + instantiate_parametrized_tests, + IS_FBCODE, + parametrize, + run_tests, + subtest, + TestCase, +) # Tests are ported from pytorch/nestedtensor. # This makes porting as_nested_tensor easier in the future. + + def _iter_constructors(): # yield as_nested_tensor yield torch.nested.nested_tensor @@ -25,6 +38,8 @@ def _iter_constructors(): # an output nested tensor consists of # * `len(ragged_sizes)` matrices # * matrices[i].shape == (20, ragged_sizes[i]) + + def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16): xs = [] for size in ragged_sizes: @@ -41,6 +56,8 @@ def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16 # Helper functions to pad a noncontiguous nested tensor # can be replaced once to_padded_tensor supports noncontiguous memory + + def noncontiguous_to_padded_tensor(input, shape=None): tensors = input.unbind() ntensors = len(tensors) @@ -64,6 +81,8 @@ def noncontiguous_to_padded_tensor(input, shape=None): return result # Helper function to generate a random nested tensor + + def random_nt(device, dtype, num_tensors, max_dims, min_dims=None): if min_dims is None: min_dims = tuple([0] * len(max_dims)) @@ -75,7 +94,78 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None): ts1.append(t1) return torch.nested.nested_tensor(ts1, device=device, dtype=dtype) + class TestNestedTensor(TestCase): + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) + data.append(row) + nested_tensor_ref_list.append(torch.tensor(row)) + nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], + nested_tensor_ref_list[id].type(torch.int64) + ) + + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) + row = [list(item * np.arange(max_seq_len)) for item in row] + data.append(row) + nested_tensor_ref_list.append(torch.Tensor(row)) + nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], + nested_tensor_ref_list[id].type(torch.int64) + ) + + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list( + np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float) + ) + row = [list(item * np.arange(max_seq_len)) for item in row] + data.append(row) + nested_tensor_ref_list.append(torch.Tensor(row)) + nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], + nested_tensor_ref_list[id].type(torch.float) + ) + @torch.inference_mode() def _test_unbind_case(self, a, b): @@ -133,7 +223,6 @@ def _test_fn(unbind_fn): @torch.inference_mode() def test_nested_tensor(self): - self.assertRaises(TypeError, lambda: torch.nested.nested_tensor([3.0])) self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))) self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) @@ -209,9 +298,7 @@ def test_size(self): a1 = constructor([]) self.assertRaisesRegex( RuntimeError, - "Tensors of type NestedTensorImpl do not have sym sizes" - if IS_FBCODE - else "NestedTensorImpl doesn't support sizes", + "NestedTensorImpl doesn't support sizes", lambda: a1.size(), ) @@ -287,20 +374,6 @@ def test_repr_string(self): self.assertEqual(str(a), expected) self.assertEqual(repr(a), expected) - @torch.inference_mode() - def test_activations(self): - for func in (torch.nn.functional.relu, - torch.nn.functional.relu_, - torch.nn.functional.gelu, - torch._C._nn.gelu_, - torch.tanh, - torch.tanh_): - t = torch.tensor([-1, 0, 1], dtype=torch.float) - nt = torch.nested.nested_tensor([t]) - nested_result = func(nt) - self.assertTrue(nested_result.is_nested) - self.assertEqual(func(t), nested_result.unbind()[0]) - def test_to_padded_tensor_on_empty_tensor(self): nt = torch.nested.nested_tensor([]) @@ -365,6 +438,66 @@ def test_data_ptr(getter): self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) + def test_copy_(self): + ntensors = 4 + nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt_copy = torch.empty_like(nt) + nt_copy.copy_(nt) + + for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + self.assertEqual(nt_ub, nt_copy_ub) + + nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])]) + self.assertRaisesRegex( + RuntimeError, + "copy_ only supports tensors that are the same size for Nested implementations", + lambda: nt_error.copy_(nt) + ) + + if torch.cuda.is_available(): + nt = random_nt(torch.device('cuda'), torch.float32, ntensors, (4, 4)) + nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt_copy.copy_(nt, non_blocking=True) + torch.cuda.current_stream(torch.cuda.current_device()).synchronize() + for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + self.assertEqual(nt_ub, nt_copy_ub) + + nt_copy = torch.empty_like(nt, device=torch.device('cpu')) + nt_copy.copy_(nt, non_blocking=False) + for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy): + self.assertEqual(nt_ub, nt_copy_ub) + + def test_fill_(self): + ntensors = 4 + nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + nt.fill_(10.) + for nt_ub in nt.unbind(): + t = torch.empty_like(nt_ub) + t.fill_(10.) + self.assertEqual(nt_ub, t) + + fill_tensor = torch.tensor([11.]) + self.assertRaisesRegex( + RuntimeError, + "fill_ only supports 0-dimension value tensor", + lambda: nt.fill_(fill_tensor) + ) + + nt.fill_(fill_tensor[0]) + for nt_ub in nt.unbind(): + t = torch.empty_like(nt_ub) + t.fill_(11.) + self.assertEqual(nt_ub, t) + + def test_ones_like(self): + ntensors = 4 + nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) + ones_nt = torch.ones_like(nt) + + for nt_ub in ones_nt.unbind(): + t = torch.ones_like(nt_ub) + self.assertEqual(nt_ub, t) + class TestNestedTensorDeviceType(TestCase): @@ -410,7 +543,6 @@ def test_detach(self, device, dtype): self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype)) self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype)) - @dtypes(torch.float, torch.float16, torch.double) def test_unbind_noncontiguous(self, device, dtype): nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) @@ -668,6 +800,13 @@ def test_nested_tensor_indexing(self, device, dtype): self.assertEqual(nt[1, ...], x1) self.assertRaises(IndexError, lambda: nt[1, 4, 2]) self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1]) + # test select on non-batch dimensions + self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0)) + self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0)) + self.assertRaises(IndexError, lambda: nt.select(1, 3)) + self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0)) + self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0)) + self.assertRaises(IndexError, lambda: nt.select(2, 5)) # make sure indexing returns a view nt[0].fill_(100.0) answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5)) @@ -686,6 +825,24 @@ def test_nested_tensor_indexing(self, device, dtype): expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]) self.assertEqual(nt.grad, expected_grad) + @parametrize("func", [subtest(torch.nn.functional.relu, name='relu'), + subtest(torch.nn.functional.relu_, name='relu_'), + subtest(torch.nn.functional.gelu, name='gelu'), + subtest(torch._C._nn.gelu_, name='gelu_'), + subtest(torch.tanh, name='tanh'), + subtest(torch.tanh_, name='tanh_'), + subtest(torch.neg, name='neg')]) + def test_activations(self, device, func): + nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32) + nested_result = func(nt) + self.assertTrue(nested_result.is_nested) + for t, t_res in zip(nt.unbind(), nested_result.unbind()): + self.assertEqual(func(t), t_res) + self.assertRaisesRegex( + RuntimeError, + "NestedTensor must be contiguous to get buffer.", + lambda: func(nt_noncontiguous)) + @dtypes(*floating_types_and_half()) def test_nested_tensor_chunk(self, device, dtype): # Transformer use case @@ -755,6 +912,21 @@ def test_nested_tensor_add(self, device, dtype): out = nt1 + nt2 self.assertEqual(ref, out) + @onlyCUDA + @dtypes(torch.float, torch.float16) + @torch.inference_mode() + @parametrize("embedding_dim", [8, 128, 256, 384]) + def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): + batch_size = 32 + seq_lens = torch.randint(low=0, high=10, size=(batch_size,)) + ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype) + ref_add = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + ref_mul = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) + self.assertEqual(nt.add(t), ref_add) + self.assertEqual(nt.mul(t), ref_mul) + @dtypes(torch.float, torch.float16) @skipMeta @torch.inference_mode() @@ -789,6 +961,38 @@ def test_nested_tensor_mul(self, device, dtype): lambda: vector.mul(nt1) ) + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + def test_nested_tensor_div(self, device, dtype): + nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4)) + scale = 4.0 + ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()]) + out = nt / 4.0 + self.assertEqual(ref, out) + ref_transposed = ref.transpose(1, 2) + out = nt.transpose(1, 2) / 4.0 + self.assertEqual(ref_transposed, out) + + ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]) + out = nt / nt2 + self.assertEqual(ref, out) + + out = nt.transpose(1, 2) / nt2.transpose(1, 2) + self.assertEqual(ref.transpose(1, 2), out) + + nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()]) + + self.assertRaisesRegex( + RuntimeError, "div requires strides to match when given NestedTensors", + lambda: nt_transpose_copy.transpose(1, 2) / nt2) + + nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype) + nt_chunks = nt.chunk(2, -1) + self.assertRaisesRegex( + RuntimeError, "div requires offsets to match when given NestedTensors", + lambda: nt_chunks[0] / nt_chunks[1]) + @dtypes(torch.float, torch.float16) @skipMeta @torch.inference_mode() @@ -1115,6 +1319,16 @@ def _test_bmm(self, device, dtype): else: self.assertEqual(actual, expect) + # test tensorcore path + nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + if dtype == torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + @onlyCUDA @dtypes(torch.float, torch.double, torch.float16) def test_bmm_cuda(self, device, dtype): @@ -1126,15 +1340,48 @@ def test_bmm_cuda(self, device, dtype): def test_bmm_cpu(self, device, dtype): self._test_bmm(device, dtype) - # TODO: Re-enable this test once bmm supports non-contiguous inputs. - # # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' - # @dtypes(torch.float, torch.double) - # def test_bmm_noncontiguous(self, device, dtype): - # nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) - # nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) - # self.assertEqual( - # nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), - # nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) + # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_bmm_noncontiguous(self, device, dtype): + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) + self.assertEqual( + nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), + nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) + + @dtypes(torch.float, torch.double) + def test_matmul_with_bmm_path(self, device, dtype): + def unbind_rebind_matmul(nt1, nt2): + t1s = nt1.unbind() + t2s = nt2.unbind() + out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)] + return torch.nested.nested_tensor(out_ts) + + # [N, n_head, *, head_dim], [N, n_head, head_dim, *] + N = np.random.randint(2, 5) + n_heads = np.random.randint(2, 5) + head_dim = 3 + t1s = [] + t2s = [] + for _ in range(N): + seq_len1 = np.random.randint(2, 5) + seq_len2 = np.random.randint(2, 5) + t1s.append(torch.randn(n_heads, seq_len1, head_dim)) + t2s.append(torch.randn(n_heads, head_dim, seq_len2)) + nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype) + nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype) + self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2)) + + # test with noncontiguous + t3s = [] + t4s = [] + for _ in range(N): + seq_len = np.random.randint(2, 5) + t3s.append(torch.randn(seq_len, n_heads, head_dim)) + t4s.append(torch.randn(seq_len, n_heads, head_dim)) + nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2) + nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3) + self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' @dtypes(torch.float, torch.double) @@ -1656,39 +1903,38 @@ class TestNestedTensorAutograd(TestCase): # Note [Gradcheck args check_batched_grad=False] the common_utils testing version of gradcheck # includes the default parameters used for testing ops with gradcheck. However nested tensor # does not support the stack op therefore we turn it off for these tests - def _create_leaf_nested_tensor_from_list(self, requires_grad=False): - return torch.nested.nested_tensor([torch.randn(1, 2), - torch.randn(7, 8)], requires_grad=requires_grad) + def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False): + return torch.nested.nested_tensor([torch.randn(1, 2,), + torch.randn(7, 8)], requires_grad=requires_grad, device=tensor_device) - def _create_nested_tensor_from_list(self, requires_grad=False): + def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False): return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad), - torch.randn(7, 8, requires_grad=requires_grad)]) + torch.randn(7, 8, requires_grad=requires_grad)], device=tensor_device) - - def _create_nested_tensor_from_mask(self, requires_grad=False): - data = torch.randn(2, 3, 4, requires_grad=requires_grad) + def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False): + data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device) mask = torch.ones_like(data[:, :, 0]).bool() return torch._nested_tensor_from_mask(data, mask) - def test_as_nested_tensor_propagates_gradients(self): - a = torch.arange(3, dtype=torch.float) - b = torch.arange(5, dtype=torch.float) + def test_as_nested_tensor_propagates_gradients(self, device): + a = torch.arange(3, dtype=torch.float, device=device) + b = torch.arange(5, dtype=torch.float, device=device) nt = torch.nested.as_nested_tensor([a, b]) # tensors with requires_grad=False are leaves self.assertTrue(nt.is_leaf) self.assertTrue(not nt.requires_grad) - a = torch.arange(3, dtype=torch.float, requires_grad=True) - b = torch.arange(5, dtype=torch.float, requires_grad=True) + a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) + b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) nt2 = torch.nested.as_nested_tensor([a, b]) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) + fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) nt2.backward(fake_grad) self.assertEqual(a.grad, fake_grad[0]) self.assertEqual(b.grad, fake_grad[1]) - def test_nested_tensor_generates_leaf(self): - a = torch.arange(3, dtype=torch.float, requires_grad=True) - b = torch.arange(5, dtype=torch.float, requires_grad=True) + def test_nested_tensor_generates_leaf(self, device): + a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) + b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) nt = torch.nested.nested_tensor([a, b], requires_grad=False) self.assertTrue(nt.is_leaf) @@ -1698,33 +1944,32 @@ def test_nested_tensor_generates_leaf(self): self.assertTrue(nt2.is_leaf) self.assertTrue(nt2.requires_grad) - fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)]) + fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) nt2.backward(fake_grad) self.assertEqual(nt2.grad, fake_grad) self.assertEqual(a.grad, None) self.assertEqual(b.grad, None) - - def test_set_requires_grad_from_list(self): - nt = self._create_nested_tensor_from_list() + def test_set_requires_grad_from_list(self, device): + nt = self._create_nested_tensor_from_list(device) nt.requires_grad_() assert nt.requires_grad - def test_set_requires_grad_from_mask(self): - nt = self._create_nested_tensor_from_mask() + def test_set_requires_grad_from_mask(self, device): + nt = self._create_nested_tensor_from_mask(device) nt.requires_grad_() assert nt.requires_grad - def test_backward_for_add_op(self): - nt_1 = self._create_nested_tensor_from_mask() - nt_2 = self._create_nested_tensor_from_mask() + def test_backward_for_add_op(self, device): + nt_1 = self._create_nested_tensor_from_mask(device) + nt_2 = self._create_nested_tensor_from_mask(device) nt_1.requires_grad_() c = nt_1 + nt_2 assert nt_1.requires_grad assert c.requires_grad - grad_output = self._create_nested_tensor_from_mask() + grad_output = self._create_nested_tensor_from_mask(device) c.backward(grad_output) # Grad check doesn't work with nested yet. @@ -1732,27 +1977,27 @@ def test_backward_for_add_op(self): self.assertEqual(nt_1.grad, grad_output) # Test Factory Functions - def test_nested_tensor_to_padded_tensor(self): + def test_nested_tensor_to_padded_tensor(self, device): for padding_val in [0, 1]: - nt = self._create_leaf_nested_tensor_from_list(True) + nt = self._create_leaf_nested_tensor_from_list(tensor_device=device, requires_grad=True) out = torch.nested.to_padded_tensor(nt, padding_val) - grad_output = torch.ones(out.shape) + grad_output = torch.ones(out.shape, device=device) out.backward(grad_output) - self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)])) + self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)], device=device)) - def test_nested_tensor_from_mask_and_to_padded(self): + def test_nested_tensor_from_mask_and_to_padded(self, device): N, L, D = 2, 4, 4 - mask = torch.ones(N, L) + mask = torch.ones(N, L, device=device) for i in range(1, N): - end = torch.randint(1, L - 1, (1,)) + end = torch.randint(1, L - 1, (1,), device=device) mask[i, end:] = 0 mask[0, :] = 1 mask = mask.bool() - data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64) + data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(inpt): nt = torch._nested_tensor_from_mask(inpt, mask) @@ -1760,9 +2005,9 @@ def grad_test_func(inpt): return torch.nested.to_padded_tensor(nt, 0) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_nested_tensor_from_padded(self): + def test_nested_tensor_from_padded(self, device): nested_size = torch.tensor([[1, 2], [2, 2]]) - padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64) + padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device) padded_tensor[0, 1, :] = 0 padded_tensor.requires_grad_() @@ -1774,9 +2019,9 @@ def grad_test_func(tensor, nested_size): data = (padded_tensor, nested_size) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_nested_tensor_from_padded_fused(self): + def test_nested_tensor_from_padded_fused(self, device): nested_size = torch.tensor([[1, 8], [2, 8]]) - padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64) + padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device) padded_tensor[0, 1, :] = 0 padded_tensor.requires_grad_() @@ -1787,11 +2032,11 @@ def grad_test_func(tensor, nested_size): data = (padded_tensor, nested_size) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_nested_tensor_from_list(self): + def test_nested_tensor_from_list(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64) - b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) - c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64) + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c): c = torch.nested.as_nested_tensor([a, b, c]) @@ -1807,11 +2052,11 @@ def test_dropout_backward(self): y.backward(nt.clone().detach()) self.assertEqual(nt.grad, y) - def test_nested_tensor_bmm_gradcheck(self): - a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64) - b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64) - c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64) - d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64) + def test_nested_tensor_bmm_gradcheck(self, device): + a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device) + d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, d): nt0 = torch.nested.as_nested_tensor([a, b]) @@ -1822,9 +2067,9 @@ def grad_test_func(a, b, c, d): data = (a, b, c, d) assert torch.autograd.gradcheck(grad_test_func, inputs=data) - def test_nested_tensor_bmm_backward(self): - nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True) - nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True) + def test_nested_tensor_bmm_backward(self, device): + nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) + nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True, device=device) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -1837,11 +2082,11 @@ def test_nested_tensor_bmm_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) - def test_nested_tensor_matmul_gradcheck(self): - a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64) - b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64) - c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64) - d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64) + def test_nested_tensor_matmul_gradcheck(self, device): + a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device) + d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, d): nt0 = torch.nested.as_nested_tensor([a, b]) @@ -1852,9 +2097,9 @@ def grad_test_func(a, b, c, d): data = (a, b, c, d) assert torch.autograd.gradcheck(grad_test_func, inputs=data) - def test_nested_tensor_matmul_backward(self): - nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True) - nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True) + def test_nested_tensor_matmul_backward(self, device): + nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True, device=device) + nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True, device=device) with torch.no_grad(): pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) @@ -1867,9 +2112,9 @@ def test_nested_tensor_matmul_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) - def test_nested_tensor_transpose_gradcheck(self): - a = torch.randn(2, 5, requires_grad=True) - b = torch.randn(3, 4, requires_grad=True) + def test_nested_tensor_transpose_gradcheck(self, device): + a = torch.randn(2, 5, requires_grad=True, device=device) + b = torch.randn(3, 4, requires_grad=True, device=device) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -1879,8 +2124,8 @@ def grad_test_func(a, b): data = (a, b) assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) - def test_nested_tensor_transpose_backward(self): - nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True) + def test_nested_tensor_transpose_backward(self, device): + nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, device=device) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -1891,9 +2136,9 @@ def test_nested_tensor_transpose_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) - def test_nested_tensor_reshape_gradcheck(self): - a = torch.randn(2, 6, requires_grad=True) - b = torch.randn(3, 6, requires_grad=True) + def test_nested_tensor_reshape_gradcheck(self, device): + a = torch.randn(2, 6, requires_grad=True, device=device) + b = torch.randn(3, 6, requires_grad=True, device=device) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -1915,8 +2160,8 @@ def test_nested_tensor_reshape_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) - def test_nested_tensor_squeeze_backward(self): - nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True) + def test_nested_tensor_squeeze_backward(self, device): + nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True, device=device) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -1927,9 +2172,9 @@ def test_nested_tensor_squeeze_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) - def test_nested_tensor_squeeze_gradcheck(self): - a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True) - b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True) + def test_nested_tensor_squeeze_gradcheck(self, device): + a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True, device=device) + b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True, device=device) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -1938,8 +2183,8 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) - def test_nested_tensor_unsqueeze_backward(self): - nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True) + def test_nested_tensor_unsqueeze_backward(self, device): + nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) with torch.no_grad(): pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) @@ -1950,9 +2195,9 @@ def test_nested_tensor_unsqueeze_backward(self): self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) - def test_nested_tensor_unsqueeze_gradcheck(self): - a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True) - b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True) + def test_nested_tensor_unsqueeze_gradcheck(self, device): + a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device) + b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device) def grad_test_func(a, b): nt = torch.nested.as_nested_tensor([a, b]) @@ -1961,14 +2206,14 @@ def grad_test_func(a, b): assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) - def test_nested_tensor_linear(self): + def test_nested_tensor_linear(self, device): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64) - b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) - c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64) + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) - weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) - bias = torch.randn(2, requires_grad=True, dtype=torch.float64) + weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, weight, bias=None): nt = torch.nested.as_nested_tensor([a, b, c]) @@ -1982,10 +2227,10 @@ def grad_test_func(a, b, c, weight, bias=None): data = (a, b, c, weight) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_nested_tensor_softmax(self): - a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64) - b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64) - c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64) + def test_nested_tensor_softmax(self, device): + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c, dim): nt = torch.nested.as_nested_tensor([a, b, c]) @@ -1997,14 +2242,14 @@ def grad_test_func(a, b, c, dim): data = (a, b, c, -1) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_nested_tensor_linear_backward(self): - a = torch.randn(1, 2, requires_grad=False) - b = torch.randn(2, 2, requires_grad=False) - c = torch.randn(3, 2, requires_grad=False) + def test_nested_tensor_linear_backward(self, device): + a = torch.randn(1, 2, requires_grad=False, device=device) + b = torch.randn(2, 2, requires_grad=False, device=device) + c = torch.randn(3, 2, requires_grad=False, device=device) - weight = torch.randn(2, 2, requires_grad=True) - bias = torch.randn(2, requires_grad=True) - nt = torch.nested.as_nested_tensor([a, b, c]) + weight = torch.randn(2, 2, requires_grad=True, device=device) + bias = torch.randn(2, requires_grad=True, device=device) + nt = torch.nested.as_nested_tensor([a, b, c], device=device) out = torch.functional.F.linear(nt, weight, bias) @@ -2017,10 +2262,10 @@ def test_nested_tensor_linear_backward(self): assert b.grad is None assert c.grad is None - def test_values_grad_with_broadcast(self): - a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64) - b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64) - c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64) + def test_values_grad_with_broadcast(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) @@ -2030,10 +2275,10 @@ def grad_test_func(a, b, c): data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_to_buffer_series_ops_grad_with_broadcast(self): - a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64) - b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64) - c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64) + def test_to_buffer_series_ops_grad_with_broadcast(self, device): + a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) @@ -2044,10 +2289,10 @@ def grad_test_func(a, b, c): data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_unbind_flow_through(self): - a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64) - b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64) - c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64) + def test_unbind_flow_through(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) def grad_test_func(a, b, c): nt = torch.nested.as_nested_tensor([a, b, c]) @@ -2060,18 +2305,21 @@ def grad_test_func(a, b, c): data = (a, b, c) assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) - def test_indexing_backward(self): + def test_indexing_backward(self, device): x0 = torch.randn((2, 5)) x1 = torch.randn((3, 4)) - nt = torch.nested.nested_tensor([x0, x1], requires_grad=True) + nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True) self.assertEqual(nt[0], x0) self.assertEqual(nt[-1], x1) - grad_x0 = torch.randn((2, 5)) + grad_x0 = torch.randn((2, 5), device=device) nt[0].backward(grad_x0) - expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4))]) + expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)]) self.assertEqual(nt.grad, expected_grad) + +instantiate_parametrized_tests(TestNestedTensor) instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) +instantiate_device_type_tests(TestNestedTensorAutograd, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 6c7f1e82ccd63..cbff32d480fbb 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3,21 +3,16 @@ import contextlib import math import random -import string import unittest import io -import unittest.mock as mock import itertools import warnings import pickle from copy import deepcopy from itertools import product -from functools import reduce, partial -from operator import mul +from functools import partial from collections import OrderedDict from tempfile import NamedTemporaryFile -import weakref -import gc import torch @@ -30,35 +25,30 @@ import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F -import torch.nn.init as init import torch.nn.utils.rnn as rnn_utils from torch.nn.utils import clip_grad_norm_, clip_grad_value_ -import torch.nn.utils.parametrize as parametrize -import torch.nn.utils.prune as prune from torch.nn.utils import parameters_to_vector, vector_to_parameters +from torch.nn.utils.fusion import fuse_conv_bn_weights +from torch.nn.utils.fusion import fuse_linear_bn_weights from torch.nn import Parameter from torch.nn.parallel._functions import Broadcast -from torch.testing._internal.common_dtype import integral_types, floating_types_and, get_all_math_dtypes, \ - floating_and_complex_types_and +from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ - skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ + TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ download_file, get_function_arglist, load_tests, skipIfMps,\ - TemporaryFileName, TEST_WITH_UBSAN, IS_PPC, \ - parametrize as parametrize_test, subtest, instantiate_parametrized_tests, set_default_dtype, IS_WINDOWS, \ - skipIfTorchDynamo + TEST_WITH_UBSAN, IS_PPC, \ + parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ + skipIfTorchDynamo, IS_WINDOWS from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ - module_tests, criterion_tests, loss_reference_fns, \ + module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ - dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ - skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \ - onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types, \ - disableMkldnn, skipCPUIfNoMkldnn, disablecuDNN, skipCUDAIfMiopen, skipCUDAIfNoMiopen -from torch.nn import MultiheadAttention + dtypesIfMPS, dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ + skipCUDAIfRocm, skipCUDAIf, skipMPSIf, skipCUDAIfNotRocm, \ + onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types from hypothesis import given -from torch.testing import make_tensor import torch.testing._internal.hypothesis_utils as hu from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \ GRADCHECK_NONDET_TOL @@ -74,7 +64,6 @@ load_tests = load_tests if TEST_SCIPY: - from scipy import stats import scipy.signal import scipy.ndimage @@ -146,26 +135,6 @@ def _get_parameters(self, module): d_params.append(p.grad) return params, d_params - def _create_basic_net(self): - class Layer(nn.Module): - def __init__(self): - super(Layer, self).__init__() - self.layer_dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7)) - - class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.l1 = Layer() - self.dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1)) - - l = Layer() - n = Net() - s = nn.Sequential(n, n) - - return l, n, s - def test_parse_to(self): # Test for buggy use of THPMemoryFormat_New self.assertEqual( @@ -174,7 +143,7 @@ def test_parse_to(self): ) def test_requires_grad_(self): - m = self._create_basic_net()[-1] + m = _create_basic_net()[-1] assert len(list(m.buffers())) > 0, 'invalid test' assert all(not b.requires_grad for b in m.buffers()) > 0, 'invalid test' assert len(list(m.parameters())) > 0, 'invalid test' @@ -195,24 +164,6 @@ def test_module_backcompat(self): input = torch.randn(2, 3, dtype=torch.float) self.assertEqual(m(input).size(), (2, 5)) - def test_conv_backcompat(self): - from torch.serialization import SourceChangeWarning - - # This file was generated by running on PyTorch 1.0.1 on Python 2: - # - # import torch - # from torch import nn - # m = nn.Conv2d(1, 1, 1) - # torch.save(m, 'legacy_conv2d.pt') - # - # NB: This Pickle also contains some Unicode data! - path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') - with warnings.catch_warnings(): - warnings.simplefilter('ignore', SourceChangeWarning) - m = torch.load(path, encoding='utf-8') - input = torch.randn((1, 1, 1, 1), dtype=torch.float) - self.assertEqual(m(input).size(), (1, 1, 1, 1)) - def test_share_memory(self): class Net(nn.Module): def __init__(self): @@ -236,464 +187,6 @@ def forward(self, inp): for b in net.buffers(): self.assertTrue(b.storage().is_shared()) - def _test_hooks(self, backward_register_fn): - module = nn.Sigmoid() - input = torch.ones(5, 5, requires_grad=True) - - counter = { - 'forwards': 0, - 'backwards': 0 - } - - def fw_hook(inc, h_module, input, output): - self.assertIsInstance(input, tuple) - self.assertTrue(isinstance(output, torch.Tensor)) - self.assertTrue(h_module is module) - self.assertEqual(input[0], torch.ones(5, 5)) - self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e))) - counter['forwards'] += inc - - def bw_hook(inc, h_module, grad_input, grad_output): - self.assertIsInstance(grad_input, tuple) - self.assertIsInstance(grad_output, tuple) - self.assertTrue(h_module is module) - self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) - counter['backwards'] += inc - - # backward_pre_hook expects callback with only `module` and `grad_output` - # as arguments. - def bw_pre_hook(inc, h_module, grad_output): - self.assertIsInstance(grad_output, tuple) - self.assertTrue(h_module is module) - self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) - counter['backwards'] += inc - - test_fwd = module.register_forward_hook(lambda *args: fw_hook(1, *args)) - - module(input) - module(input) - self.assertEqual(counter['forwards'], 2) - self.assertEqual(counter['backwards'], 0) - - bw_hook_fn = bw_pre_hook if backward_register_fn == 'register_full_backward_pre_hook' else bw_hook - test_bwd = getattr(module, backward_register_fn)( - lambda *args: bw_hook_fn(1, *args)) - - output = module(input) - self.assertEqual(counter['forwards'], 3) - self.assertEqual(counter['backwards'], 0) - - output.backward(torch.ones(5, 5) * 2, retain_graph=True) - self.assertEqual(counter['forwards'], 3) - self.assertEqual(counter['backwards'], 1) - - output.backward(torch.ones(5, 5) * 2, retain_graph=True) - self.assertEqual(counter['forwards'], 3) - self.assertEqual(counter['backwards'], 2) - - test2_fwd = module.register_forward_hook(lambda *args: fw_hook(2, *args)) - - output = module(input) - self.assertEqual(counter['forwards'], 6) - self.assertEqual(counter['backwards'], 2) - - test2_bwd = getattr(module, backward_register_fn)(lambda *args: bw_hook_fn(2, *args)) - - module(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 9) - self.assertEqual(counter['backwards'], 5) - - test2_bwd.remove() - - module(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 12) - self.assertEqual(counter['backwards'], 6) - - test2_fwd.remove() - - module(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 13) - self.assertEqual(counter['backwards'], 7) - - test_fwd.remove() - test_bwd.remove() - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_hooks(self): - self._test_hooks("register_backward_hook") - self._test_hooks("register_full_backward_hook") - self._test_hooks("register_full_backward_pre_hook") - - def test_hook_cpp(self): - bn = nn.BatchNorm1d(5) - - def hook(module, grad_inputs, grad_outputs): - self.assertEqual(len(grad_inputs), 1) - self.assertEqual(len(grad_outputs), 1) - self.assertEqual(module, bn) - - bn.register_full_backward_hook(hook) - output = bn(torch.randn(5, 5, requires_grad=True)) - output.sum().backward() - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_backward_hooks_interaction(self): - # Test to make sure that the grad_outputs - # updated by full_backward_pre_hook are received by - # the full_backward_hook - module = torch.nn.Sigmoid() - - cnt = {'backward_cnt': 0} - - def bw_pre_hook(m, grad_output): - cnt['backward_cnt'] += 1 - return (grad_output[0] * 0.5, ) - - def bw_hook(m, grad_in, grad_output): - self.assertEqual(torch.full_like(grad_output[0], 0.5), grad_output[0]) - cnt['backward_cnt'] += 1 - return grad_output - - module.register_full_backward_pre_hook(bw_pre_hook) - module.register_full_backward_hook(bw_hook) - - t = torch.ones(1, 2, requires_grad=True) - module(t).sum().backward() - self.assertEqual(cnt['backward_cnt'], 2) - - def test_hook_invalid_outputs(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - - def bw_fail1(self, grad_input, grad_output): - return grad_input[:-1] - - def bw_fail2(self, grad_input, grad_output): - return grad_input + (torch.randn(2, 2),) - - with module.register_backward_hook(bw_fail1): - with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): - module(input).sum().backward() - - with module.register_backward_hook(bw_fail2): - with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): - module(input).sum().backward() - - def bw_pre_fail1(self, grad_output): - return () - - def bw_pre_fail2(self, grad_output): - return grad_output + (torch.randn(2, 2),) - - with module.register_full_backward_pre_hook(bw_pre_fail1): - with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): - module(input).sum().backward() - - with module.register_full_backward_pre_hook(bw_pre_fail2): - with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): - module(input).sum().backward() - - def test_hook_requires_grad(self): - test_self = self - - class MyModule(nn.Module): - def forward(self, arg1, arg2, arg3): - test_self.assertTrue(arg1.requires_grad) - test_self.assertFalse(arg2.requires_grad) - test_self.assertTrue(arg3.requires_grad) - return arg1.sum() + arg2.sum() + arg3.sum() - - inp = torch.rand(2, requires_grad=True) - mod = MyModule() - - mod(inp, inp.detach(), inp) - # Ensure that requires grad is properly propagated - mod.register_full_backward_hook(lambda mod, gI, gO: None) - mod(inp, inp.detach(), inp) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_hook_no_requires_grad(self): - mod = nn.Linear(2, 3) - - inp = torch.rand(1, 2) - - return_val = "None" - hook_called = [0] - - def hook(mod, grad_input, grad_output): - hook_called[0] += 1 - for gI in grad_input: - self.assertIsNone(gI) - for gO in grad_output: - self.assertEqual(gO.size(), (1, 3)) - - if return_val == "grad_input": - return grad_input - elif return_val == "invalid": - # If the inputs were requiring gradients, this would be - # a valid return - return inp - elif return_val == "None": - return None - else: - raise RuntimeError("Invalid return_val string") - - mod.register_full_backward_hook(hook) - - # This should run and trigger the hook properly - mod(inp).sum().backward() - self.assertEqual(hook_called[0], 1) - - return_val = "grad_input" - - mod(inp).sum().backward() - self.assertEqual(hook_called[0], 2) - - return_val = "invalid" - with self.assertRaisesRegex(RuntimeError, "where no input requires gradient"): - mod(inp).sum().backward() - - def test_hook_last_arg_requires_grad(self): - mod = nn.L1Loss() - inp = torch.rand(1, requires_grad=True) - mod.register_full_backward_hook(lambda m, gI, gO: None) - - try: - mod(inp.detach(), inp) - except Exception as ex: - self.fail("Unexpected exception: %s" % ex) - - def test_hook_extra_input(self): - class MyModule(nn.Module): - def forward(self, non_tensor, tensor): - return tensor.clone(), non_tensor - - inp = torch.rand(2, requires_grad=True) - mod = MyModule() - - def hook(mod, grad_input, grad_output): - self.assertIsNone(grad_input[0]) - self.assertIsInstance(grad_input[1], torch.Tensor) - - self.assertIsInstance(grad_output[0], torch.Tensor) - self.assertIsNone(grad_output[1]) - - mod.register_full_backward_hook(hook) - out, _ = mod(True, inp) - out.sum().backward() - - def test_hook_inplace(self): - class MyModule(nn.Module): - def forward(self, inp, do_inplace): - self.inp = inp - if do_inplace: - inp += 1 - return inp.clone() - - hook_called = [0] - - def hook(mod, grad_input, grad_output): - hook_called[0] += 1 - - def hook_pre(mod, grad_output): - hook_called[0] += 1 - - inp = torch.rand(10, requires_grad=True) - mod = MyModule() - for hook_fn, register_fn in [(hook, mod.register_full_backward_hook), - (hook_pre, mod.register_full_backward_pre_hook)]: - hook_called[0] = 0 - with register_fn(hook_fn): - # No inplace should work - mod(inp, False).sum().backward() - self.assertEqual(hook_called[0], 1) - - # Input inplace error should throw an error - with self.assertRaisesRegex(RuntimeError, "Output 0 of BackwardHookFunctionBackward is " - "a view and is being modified inplace."): - mod(inp.clone(), True) - - # Input inplace error should throw an error if we try to re-use the view after they have - # been modified - local_inp = inp.clone() - out = mod(local_inp, False) - local_inp[0] *= 1 - with self.assertRaisesRegex(RuntimeError, "Output 0 of BackwardHookFunctionBackward is " - "a view and its base or another view"): - # Any operation involving the view will fail here - mod.inp + 2 - - # Output inplace error should throw an error - out = mod(inp, False) - with self.assertRaisesRegex(RuntimeError, "BackwardHookFunctionBackward is a view " - "and is being modified inplace."): - out += 1 - - def test_hook_non_full_warning(self): - def noop(*args): - pass - - a = torch.rand(2, requires_grad=True) - b = torch.rand(2, requires_grad=True) - - # Check invalid input container - class MyModule(nn.Module): - def forward(self, l): - return l[0].clone(), l[1].clone() - - m = MyModule() - m.register_backward_hook(noop) - - with self.assertWarnsRegex(UserWarning, "does not take as input a single Tensor or a tuple of Tensors"): - m([a, b]) - - # Check invalid output container - class MyModule(nn.Module): - def forward(self, a, b): - return [a.clone(), b.clone()] - - m = MyModule() - m.register_backward_hook(noop) - - with self.assertWarnsRegex(UserWarning, "does not return a single Tensor or a tuple of Tensors"): - m(a, b) - - # Check invalid output from different Nodes - class MyModule(nn.Module): - def forward(self, a, b): - return a.clone(), b.clone() - - m = MyModule() - m.register_backward_hook(noop) - - with self.assertWarnsRegex(UserWarning, "outputs are generated by different autograd Nodes"): - m(a, b) - - # Check invalid forward with multiple Nodes - class MyModule(nn.Module): - def forward(self, a): - return a.clone().clone() - - m = MyModule() - m.register_backward_hook(noop) - - with self.assertWarnsRegex(UserWarning, "the forward contains multiple autograd Nodes"): - m(a) - - def test_hook_backward_size(self): - # Make module with multiple operations in forward - # And different size for input and outputs - class MyModule(nn.Module): - def forward(self, arg1, arg2): - tmp = arg1.sum() * arg2 - tmp = tmp + arg2.sum() * arg1.sum() - tmp = tmp.sum().view(1) - tmp = tmp.expand(8).contiguous() - return tmp - - module = MyModule() - inp1 = torch.randn(5, 5, requires_grad=True) - inp2 = torch.randn(10, 10, requires_grad=True) - - def bw_hook(module, grad_input, grad_output): - self.assertEqual(len(grad_input), 2) - self.assertEqual(grad_input[0].size(), torch.Size([5, 5])) - self.assertEqual(grad_input[1].size(), torch.Size([10, 10])) - self.assertEqual(len(grad_output), 1) - self.assertEqual(grad_output[0].size(), torch.Size([8])) - - with module.register_full_backward_hook(bw_hook): - module(inp1, inp2).sum().backward() - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_hook_backward_writeable(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.nn.functional.sigmoid(input) - - def bw_hook(module, grad_input, grad_output): - for grad in grad_input: - self.assertTrue(isinstance(grad, torch.Tensor)) - for grad in grad_output: - self.assertTrue(isinstance(grad, torch.Tensor)) - return tuple(gi * 2 for gi in grad_input) - - module.register_backward_hook(bw_hook) - module(input).backward(torch.ones(5, 5)) - expected_grad = sig_x * (1 - sig_x) * 2 - self.assertEqual(input.grad, expected_grad) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_hook_forward_preforward_writable(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.nn.functional.sigmoid(input) - - def forward_pre_hook(m, input): - return torch.nn.functional.relu(input[0]) - - def forward_hook(m, input, output): - return -output - - module.register_forward_pre_hook(forward_pre_hook) - module.register_forward_hook(forward_hook) - output = module(input) - expected_res = -torch.nn.functional.sigmoid(torch.nn.functional.relu(input)) - self.assertEqual(output, expected_res) - output.backward(torch.ones(5, 5) * 2, retain_graph=True) - mask = (input > 0).double() - expected_grad = -sig_x * (1 - sig_x) * 2 * mask - self.assertEqual(input.grad, expected_grad) - - def test_hook_buffer_registration(self): - for return_buffer in (True, False): - def buffer_registration_hook(module, name, buffer): - buffer.registered = True - if return_buffer: - return buffer - handle = torch.nn.modules.module.register_module_buffer_registration_hook( - buffer_registration_hook - ) - try: - l, n, s = self._create_basic_net() - for b in s.buffers(): - self.assertTrue(getattr(b, "registered", False)) - finally: - handle.remove() - - def test_hook_submodule_registration(self): - for return_submodule in (True, False): - def module_registration_hook(module, name, submodule): - module.registered = True - submodule.registered = True - if return_submodule: - return submodule - handle = torch.nn.modules.module.register_module_module_registration_hook( - module_registration_hook - ) - try: - l, n, s = self._create_basic_net() - for m in s.modules(): - self.assertTrue(getattr(m, "registered", False)) - finally: - handle.remove() - - def test_hook_parameter_registration(self): - for return_parameter in (True, False): - def parameter_registration_hook(module, name, parameter): - parameter.registered = True - if return_parameter: - return parameter - handle = torch.nn.modules.module.register_module_parameter_registration_hook( - parameter_registration_hook - ) - try: - l, n, s = self._create_basic_net() - for p in s.parameters(): - self.assertTrue(getattr(p, "registered", False)) - finally: - handle.remove() - def test_to(self): m = nn.Linear(3, 5) self.assertIs(m, m.to('cpu')) @@ -761,198 +254,11 @@ def test_no_grad(self): self.assertFalse(output2.requires_grad) self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10))) - def test_invalid_conv1d(self): - for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: - module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype) - input = torch.randn(1, 3, 4).to(dtype) - with self.assertRaisesRegex(RuntimeError, - r'Calculated padded input size per channel: \(4\). ' + - r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'): - module(input) - - # Negative stride check - module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype) - input = torch.randn(1, 3, 4).to(dtype) - with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): - module(input) - - def test_mismatch_shape_conv2d(self): - for dtype in (torch.float, torch.cfloat): - x = torch.randn(1, 10, 1, 28, 28, dtype=dtype) - w = torch.randn(6, 1, 5, 5, dtype=dtype) - - with self.assertRaisesRegex(RuntimeError, - r'Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got ' + - r'input of size: \[1, 10, 1, 28, 28\]'): - - F.conv2d(x, w) - - def test_conv2d_discontiguous_weight(self): - for dtype in (torch.float, torch.cfloat): - # Test for https://github.com/pytorch/pytorch/issues/55781 - x = torch.ones(64, 16, 16, 16, dtype=dtype) - weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2).to(dtype)[:, :, :, ::2] - self.assertFalse(weight.is_contiguous()) - y = torch.nn.functional.conv2d(x, weight, None) - if torch.backends.mkldnn.is_available(): - # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used - with torch.backends.mkldnn.flags(enabled=False): - y_ = torch.nn.functional.conv2d(x, weight, None) - self.assertEqual(y, y_) - self.assertEqual(y.sum(), 4186112.) - - def test_invalid_conv2d(self): - for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: - module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) - input = torch.empty(1, 1, 4, 4).to(dtype) - self.assertRaises(RuntimeError, lambda: module(input)) - - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) - input = torch.randn(1, 3, 1, 1) - with self.assertRaisesRegex(RuntimeError, - r'Calculated padded input size per channel: \(1 x 1\). ' + - r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'): - module(input) - - # Negative stride check - module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype) - input = torch.randn(1, 3, 4, 4).to(dtype) - with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): - module(input) - - # Zero stride check - module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype) - input = torch.randn(1, 3, 4, 4).to(dtype) - with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): - module(input) - - def test_invalid_conv3d(self): - for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: - module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) - input = torch.empty(1, 1, 4, 4, 4).to(dtype) - self.assertRaises(RuntimeError, lambda: module(input)) - - # Negative stride check - module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2) - input = torch.empty(1, 1, 4, 4, 4) - with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): - module(input) - - def test_conv_invalid_groups(self): - with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): - torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0) - with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): - torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1) - with self.assertRaisesRegex(ValueError, 'groups must be a positive integer'): - torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2) - - def test_Conv1d_module_same_padding(self): - # Compare module against functional: without strides/dilation, asymmetric padding - x = torch.rand(1, 1, 20) - module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, - padding='same') - expect = F.conv1d(x, module.weight, module.bias, padding='same') - self.assertEqual(expect, module(x)) - - # Test dilation, symmetric padding - module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, - padding='same', dilation=2) - expect = F.conv1d(x, module.weight, module.bias, padding='same', dilation=2) - self.assertEqual(expect, module(x)) - - # Test non-zero padding_mode, requiring explicit padding - module = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, - padding='same', padding_mode='replicate') - x_padded = F.pad(x, [4, 5], mode='replicate') - expect = F.conv1d(x_padded, module.weight, module.bias, padding='valid') - self.assertEqual(expect, module(x)) - self.assertEqual(x.size(), expect.size()) - - # Test connstruction with invalid padding string raises - with self.assertRaisesRegex(ValueError, 'Invalid padding string'): - module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') - - # Test connstruction with same padding and strides raises - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) - - def test_Conv2d_module_same_padding(self): - # Compare module against functional: - # without strides/dilation, both symmetric and asymmetric padding - x = torch.rand(1, 1, 9, 20) - module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(5, 10), - padding='same') - expect = F.conv2d(x, module.weight, module.bias, padding='same') - self.assertEqual(expect, module(x)) - - # with dilation, symmetric padding - module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), - padding='same', dilation=(1, 2)) - expect = F.conv2d(x, module.weight, module.bias, padding='same', dilation=(1, 2)) - self.assertEqual(expect, module(x)) - - # Test non-zero padding_mode, requiring explicit padding - module = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 4), - padding='same', padding_mode='reflect') - x_padded = F.pad(x, [1, 2, 1, 1], mode='reflect') - expect = F.conv2d(x_padded, module.weight, module.bias, padding='valid') - self.assertEqual(expect, module(x)) - self.assertEqual(x.size(), expect.size()) - - # Test connstruction with invalid padding string raises - with self.assertRaisesRegex(ValueError, 'Invalid padding string'): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') - - # Test connstruction with same padding and strides raises - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 3)) - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(4, 1)) - - def test_Conv3d_module_same_padding(self): - # Compare module against functional: - x = torch.rand(1, 1, 4, 4, 4) - # without dilation, both symmetric and asymmetric padding - module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), - padding='same') - expect = F.conv3d(x, module.weight, module.bias, padding='same') - self.assertEqual(expect, module(x)) - - # with dilation, both symmetric and asymmetric padding - module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), - padding='same', dilation=(3, 2, 1)) - expect = F.conv3d(x, module.weight, module.bias, padding='same', dilation=(3, 2, 1)) - self.assertEqual(expect, module(x)) - - # Test non-zero padding_mode, requiring explicit padding - module = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(2, 3, 4), - padding='same', padding_mode='circular') - x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode='circular') - expect = F.conv3d(x_padded, module.weight, module.bias, padding='valid') - self.assertEqual(expect, module(x)) - self.assertEqual(x.size(), expect.size()) - - # Test connstruction with invalid padding string raises - with self.assertRaisesRegex(ValueError, 'Invalid padding string'): - module = nn.Conv3d(in_channels=3, out_channels=33, kernel_size=10, padding='foo') - - # Test connstruction with same padding and strides raises - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=2) - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 1, 3)) - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(1, 4, 1)) - with self.assertRaisesRegex(ValueError, "padding='same'"): - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, padding='same', stride=(5, 1, 1)) - def test_parameters_and_named_parameters(self): def names(named_parameters): return [k for k, _ in named_parameters] - l, n, s = self._create_basic_net() + l, n, s = _create_basic_net() self.assertEqual(len(list(l.parameters())), 1) self.assertEqual( @@ -974,11 +280,39 @@ def names(named_parameters): names(s.named_parameters()), ['0.dummy_param', '0.l1.layer_dummy_param']) + def test_named_parameters_remove_duplicate(self): + def names(named_parameters): + return [k for k, _ in named_parameters] + + class M1(nn.Module): + def __init__(self): + super().__init__() + self.param1 = nn.Parameter(torch.empty(3, 3)) + self.param2 = self.param1 + + m1 = M1() + self.assertEqual(names(m1.named_parameters()), + ["param1"]) + self.assertEqual(names(m1.named_parameters(remove_duplicate=False)), + ["param1", "param2"]) + + class M2(nn.Module): + def __init__(self): + super().__init__() + self.mod1 = nn.Linear(3, 4, bias=False) + self.mod2 = self.mod1 + + m2 = M2() + self.assertEqual(names(m2.named_parameters()), + ["mod1.weight"]) + self.assertEqual(names(m2.named_parameters(remove_duplicate=False)), + ["mod1.weight", "mod2.weight"]) + def test_buffers_and_named_buffers(self): def names(named_buffers): return [k for k, _ in named_buffers] - l, n, s = self._create_basic_net() + l, n, s = _create_basic_net() self.assertEqual(len(list(l.buffers())), 1) self.assertEqual( @@ -2418,2121 +1752,126 @@ def test_vector_to_parameters(self): sample = next(model.parameters())[0, 0, 0] self.assertTrue(torch.equal(sample.data, vec.data[:5])) - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - # torch/nn/utils/parametrize - @skipIfNoLapack - def test_register_and_remove_parametrization(self): - r"""Test that it is possible to add a few parametrizations - on a parameter or a buffer and that removing them restores the initial state - It also tests that backpropagating through them works as expected - """ - # Define a couple matrix parametrizations - class Skew(nn.Module): - def forward(self, X): - X = X.tril(-1) - return X - X.T - - class Orthogonal(nn.Module): - def forward(self, X): - # Cayley map - # If X is skew-symmetric it returns an orthogonal matrix - Id = torch.eye(X.size(0), device=X.device) - # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous - # and autograd raises a performance warning. - # This happens when we remove the parametrization with leave_parametrized=True, - # which does a set_ with a non-contiguous tensor while the gradient is contiguous - return torch.linalg.solve(Id + X, Id - X).contiguous() - - class Resize(nn.Module): - def forward(self, X): - return X[[0]] - - class NoResize(nn.Module): - def forward(self, X): - return X - - # Define a couple vector parametrizations - class FirstZero(nn.Module): - def forward(self, x): - return torch.cat([x.new_zeros(1), x[1:]]) + def test_rnn_weight_norm(self): + def check_weight_norm(l, name, num_params): + # This Module has 4 or 5 parameters called: + # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', weight_hr_l0 - class LastZero(nn.Module): - def forward(self, x): - return torch.cat([x[:-1], x.new_zeros(1)]) - - model = nn.Linear(8, 8) - initial_weight_id = id(model.weight) - initial_bias_id = id(model.bias) - initial_model = deepcopy(model) - - # Test unsafe flag - with self.assertRaisesRegex(ValueError, "Registering a parametrization may not change the shape of the tensor"): - parametrize.register_parametrization(model, "weight", Resize()) # default unsafe = False - model(torch.ones(8, 8)) - - # One parametrization with unsafe=True - parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - A = model.weight - self.assertTrue(A.shape[0] == 1) - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.weight, initial_model.weight) - self.assertEqual(id(model.weight), initial_weight_id) - self.assertEqual(model.__class__, nn.Linear) - - # Two parametrizations with unsafe=True - parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) - parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - A = model.weight - self.assertTrue(A.shape[0] == 1) - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.weight, initial_model.weight) - self.assertEqual(id(model.weight), initial_weight_id) - self.assertEqual(model.__class__, nn.Linear) - - # Test unsafe flag doesn't change expected behavior - parametrize.register_parametrization(model, "weight", Skew(), unsafe=True) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - # Result should be skew-symmetric - A = model.weight - self.assertEqual(A, -A.T) - # Remove and check consistency - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.weight, initial_model.weight) - self.assertEqual(id(model.weight), initial_weight_id) - self.assertEqual(model.__class__, nn.Linear) - - # Test one parametrization - parametrize.register_parametrization(model, "weight", Skew()) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - # Result should be skew-symmetric - A = model.weight - self.assertEqual(A, -A.T) - # Remove and check consistency - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.weight, initial_model.weight) - self.assertEqual(id(model.weight), initial_weight_id) - self.assertEqual(model.__class__, nn.Linear) - - # Test two parametrizations at the same time and removing them - parametrize.register_parametrization(model, "weight", Skew()) - parametrize.register_parametrization(model, "weight", Orthogonal()) - # Result should be orthogonal - X = model.weight - Id = torch.eye(X.size(0), device=X.device) - self.assertEqual(X.T @ X, Id) - # Structure tests - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertIn("weight", model.parametrizations) - self.assertNotIn("weight", model._parameters) - # Remove - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertEqual(model.weight, initial_model.weight) - self.assertEqual(id(model.weight), initial_weight_id) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.__class__, nn.Linear) - - # Add everything - parametrize.register_parametrization(model, "weight", Skew()) - parametrize.register_parametrization(model, "weight", Orthogonal()) - parametrize.register_parametrization(model, "bias", FirstZero()) - parametrize.register_parametrization(model, "bias", LastZero()) - - # Basic tests - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertTrue(parametrize.is_parametrized(model, "bias")) - self.assertEqual(model.bias[0].item(), 0.) - self.assertEqual(model.bias[-1].item(), 0.) - self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happpened - # Should not throw - - sgd = torch.optim.SGD(model.parameters(), lr=0.01) - - weight_copy = model.weight.clone() - bias_copy = model.bias.clone() - sgd.zero_grad() - (model.weight.T @ model.bias).sum().backward() - sgd.step() - self.assertNotEqual(model.weight, weight_copy) - self.assertNotEqual(model.bias, bias_copy) - - # Remove first parametrization. - # Check that the model is still parametrized and so is the second parameter - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized - self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed - self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized - self.assertEqual(model.bias[0].item(), 0.) # Still parametrized - self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized - self.assertNotEqual(model.weight, initial_model.weight) # Has been updated - self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id - self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened - # Should not throw - weight_copy = model.weight.clone() - bias_copy = model.bias.clone() - sgd.zero_grad() - (model.weight.T @ model.bias).sum().backward() - sgd.step() - self.assertNotEqual(model.weight, weight_copy) - self.assertNotEqual(model.bias, bias_copy) - - # Remove the second parametrization. - # Check that the module is not parametrized - parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) - self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized - self.assertNotEqual(model.bias, initial_model.bias) # Has been updated - self.assertNotEqual(model.bias[0].item(), 0.) # Not parametrized - self.assertNotEqual(model.bias[-1].item(), 0.) # Not parametrized - self.assertEqual(id(model.bias), initial_bias_id) # Keeps the same id - self.assertFalse(hasattr(model, "parametrizations")) # Not parametrized the module - self.assertEqual(model.__class__, nn.Linear) # Resores the previous class - self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happeed - - # Should not throw things are updated - weight_copy = model.weight.clone() - bias_copy = model.bias.clone() - sgd.zero_grad() - (model.weight.T @ model.bias).sum().backward() - sgd.step() - self.assertNotEqual(model.weight, weight_copy) - self.assertNotEqual(model.bias, bias_copy) - - # Test leave_parametrized=True - for _ in range(2): - parametrize.register_parametrization(model, "weight", Skew()) - parametrize.register_parametrization(model, "weight", Orthogonal()) - parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) - # We didn't change the dtype nor had multiple inputs, so the id should be the same - self.assertEqual(id(model.weight), initial_weight_id) - self.assertEqual(id(model.bias), initial_bias_id) - - # Should not throw. Things are updated - weight_copy = model.weight.clone() - bias_copy = model.bias.clone() - sgd.zero_grad() - (model.weight.T @ model.bias).sum().backward() - sgd.step() - self.assertNotEqual(model.weight, weight_copy) - self.assertNotEqual(model.bias, bias_copy) - - def test_register_and_remove_nested_parametrization(self): - r"""Test that it is possible to nest the parametrizations - meaning that the original param is parametrized again - """ - class Skew(nn.Module): - def forward(self, X): - X = X.tril(-1) - return X - X.T - - model = nn.Linear(8, 8) - # Add top level parametrization - parametrize.register_parametrization(model, "weight", Skew()) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - # Result should be skew-symmetric - A = model.weight - self.assertEqual(A, -A.T) - - # Add nested parametrization - param_mod = model.parametrizations.weight - self.assertFalse(hasattr(param_mod, "parametrizations")) - self.assertFalse(parametrize.is_parametrized(param_mod)) - self.assertFalse(parametrize.is_parametrized(param_mod, "original")) - - parametrize.register_parametrization(param_mod, "original", Skew()) - self.assertTrue(hasattr(param_mod, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(param_mod)) - self.assertTrue(parametrize.is_parametrized(param_mod, "original")) - self.assertNotIn("original", param_mod._parameters) - # Result should be skew-symmetric - A = param_mod.original - self.assertEqual(A, -A.T) - - # Remove nested param and check consistency - parametrize.remove_parametrizations(param_mod, "original", leave_parametrized=False) - self.assertFalse(hasattr(param_mod, "parametrizations")) - self.assertEqual(param_mod.__class__, parametrize.ParametrizationList) - - # Remove top level and check consistency - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.__class__, nn.Linear) - - def test_register_and_remove_buffer_parametrization(self): - r"""Test that it is possible to add and remove parametrizations on buffers""" - # Define a couple vector parametrizations - class FirstZero(nn.Module): - def forward(self, x): - return torch.cat([x.new_zeros(1), x[1:]]) + # Applying weight norm on one of them causes it to become a tensor + l = torch.nn.utils.weight_norm(l, name=name) + self.assertEqual( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), + num_params - 1, + ) - class LastZero(nn.Module): - def forward(self, x): - return torch.cat([x[:-1], x.new_zeros(1)]) - - model = nn.Linear(8, 8) - - # Instantiate parametrizations on buffers. It should work as expected - delattr(model, "bias") - model.register_buffer("bias", torch.ones(8)) - parametrize.register_parametrization(model, "bias", FirstZero()) - parametrize.register_parametrization(model, "bias", LastZero()) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "bias")) - self.assertEqual(model.bias[0].item(), 0.) - self.assertEqual(model.bias[-1].item(), 0.) - self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) - self.assertEqual(len(list(model.parameters())), 1) - - # Remove parametrizations on buffers. It should work as expected - parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) - self.assertFalse(parametrize.is_parametrized(model)) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertEqual(model.bias[0].item(), 0.) - self.assertEqual(model.bias[-1].item(), 0.) - self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) - self.assertEqual(len(list(model.parameters())), 1) - - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_serialization_parametrization(self): - r"""Test that it is possible to serialize a parametrized model via state_dict""" - # A stateful parametrization - class Orthogonal(nn.Module): - def __init__(self, n): - super().__init__() - self.register_buffer("id", torch.eye(n)) - self.register_buffer("B", torch.empty(n, n)) - init.orthogonal_(self.B) - - def forward(self, X): - A = X.triu(1) - A = A - A.T - return self.B @ torch.linalg.solve(self.id + A, self.id - A) - - def get_model(): - model = torch.nn.Sequential( - torch.nn.Linear(5, 5), - torch.nn.ReLU(), - torch.nn.Linear(5, 1), + # Removing the weight norm reparametrization restores the Parameter + l = torch.nn.utils.remove_weight_norm(l, name=name) + self.assertEqual( + sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), + num_params, ) - parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) - return model + # Make sure that, upon removal of the reparametrization, the + # `._parameters` and `.named_parameters` contain the right params. + # Specifically, the original weight ('weight_ih_l0') should be placed + # back in the parameters, while the reparametrization components + # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed. + self.assertTrue(name in l._parameters) + self.assertIsNotNone(l._parameters[name]) + self.assertTrue(name + '_v' not in l._parameters) + self.assertTrue(name + '_g' not in l._parameters) + self.assertTrue(name in dict(l.named_parameters())) + self.assertIsNotNone(dict(l.named_parameters())[name]) + self.assertTrue(name + '_v' not in dict(l.named_parameters())) + self.assertTrue(name + '_g' not in dict(l.named_parameters())) - model = get_model() + check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4) + check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5) - prev_weight = model[0].weight - prev_B = model[0].parametrizations.weight[0].B - new_model = get_model() - with TemporaryFileName() as fname: - torch.save(model.state_dict(), fname) - new_model.load_state_dict(torch.load(fname)) + def test_weight_norm(self): + for dtype in [torch.float, torch.bfloat16]: + input = torch.randn(3, 4, dtype=dtype) + m = nn.Linear(4, 5).to(dtype=dtype) + expected_output = m(input) - # Integrity tests - self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) - self.assertEqual(prev_weight, new_model[0].weight) - self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) + # add weight normalization + m = torch.nn.utils.weight_norm(m) + self.assertEqual(m.weight_v.size(), m.weight.size()) + self.assertEqual(m.weight_g.size(), (5, 1)) + self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - # Trying to save the whole parametrized model raises - with self.assertRaisesRegex(RuntimeError, "state_dict"): - with TemporaryFileName() as fname: - torch.save(model, fname) + # remove weight norm + m = torch.nn.utils.remove_weight_norm(m) + self.assertFalse(hasattr(m, 'weight_g')) + self.assertFalse(hasattr(m, 'weight_v')) + self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_initialization_parametrization(self): - r"""Test that it is possible to initialize a parametrization when it - implements a `right_inverse` method - """ - class Skew(nn.Module): - def forward(self, X): - A = X.triu(1) - return A - A.T - - def is_skew(self, A): - return torch.allclose(A, -A.T, atol=1e-6) - - def right_inverse(self, X): - if not self.is_skew(X): - raise ValueError("The matrix is not skew-symmetric.") - return X.triu(1) - - # Implements a Cayley map where right_inverse is not quite the inverse of forward - class Orthogonal(nn.Module): - def __init__(self, n): - super().__init__() - self.register_buffer("B", torch.eye(n)) - - def forward(self, X): - Id = torch.eye(X.size(0)) - return self.B @ torch.linalg.solve(Id + X, Id - X) - - def is_orthogonal(self, X): - Id = torch.eye(X.size(0)) - return torch.allclose(X.T @ X, Id, atol=1e-4) - - def right_inverse(self, X): - if not self.is_orthogonal(X): - raise ValueError("The input is not orthogonal.") - # cayley(0) == Id, so B @ cayley(0) == B - self.B = X - return torch.zeros_like(X) - - N = 5 - model = nn.Linear(N, N) - # Register the skew-symmetric constraint. The result is now skew-symmetric - skew = Skew() - # Make the weight skew-symmetric before registering the parametrization - with torch.no_grad(): - model.weight.set_(skew(model.weight)) - parametrize.register_parametrization(model, "weight", skew) - X = torch.rand(N, N) - # X is not skew-symmetric, so it throws an error - with self.assertRaises(ValueError): - model.weight = X - # Make X skew-symmetric - X = X - X.T - model.weight = X - self.assertEqual(model.parametrizations.weight.original, X.triu(1)) - self.assertEqual(model.weight, X) - - # Having several parametrizations registered should work in the same way - parametrize.register_parametrization(model, "weight", Orthogonal(N)) - # Register now the Cayley map. The result is now orthogonal - X = torch.rand(N, N) - # X is not orthogonal, so it throws an error - with self.assertRaises(ValueError): - model.weight = X - init.orthogonal_(X) - model.weight = X - self.assertEqual(model.weight, X) - self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) - - def test_errors_unparametrized_tensor_parametrization(self): - # Test errors when registering a parametrization on an unparametrized tensor - module = nn.Linear(3, 4) - weight_init = module.weight.clone() - - class Identity(nn.Module): - def forward(self, x): - return x - - # Register a parametrization on a non-existing parameter throws - with self.assertRaisesRegex(ValueError, "does not have a parameter"): - parametrize.register_parametrization(module, "foo", Identity()) - self.assertFalse(parametrize.is_parametrized(module)) - - # Removing parametrizations from an unparametrized tensor throws - with self.assertRaisesRegex(ValueError, "does not have a parametrization"): - parametrize.remove_parametrizations(module, "bias") - self.assertFalse(parametrize.is_parametrized(module)) - - # A correct parametrization with several outputs - class Sum(nn.Module): - def forward(self, x, y): - return x + y - - def right_inverse(self, z): - return z, torch.zeros_like(z) - - parametrize.register_parametrization(module, "weight", Sum()) - # Cannot remove a parametrization with several outputs with `leave_parametrized=False` - with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): - parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) - parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) - - # A parametrization with an incorrect number of outputs - class WrongNumberParams(nn.Module): - def forward(self, x, y, z): - return x + y + z - - def right_inverse(self, w): - return w, torch.zeros_like(w) - - # Makes param(*param.right_inverse(X)) fail - with self.assertRaisesRegex(TypeError, "positional argument"): - parametrize.register_parametrization(module, "weight", WrongNumberParams()) - self.assertFalse(parametrize.is_parametrized(module)) - - # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor] - class WrongRightInverse(Identity): - def right_inverse(self, z): - return None + # test with dim=1 + m = torch.nn.utils.weight_norm(m, dim=1) + self.assertEqual(m.weight_v.size(), m.weight.size()) + self.assertEqual(m.weight_g.size(), (1, 4)) + self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - # right_inverse should return a Tensor or a Sequence[Tensor] - with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"): - parametrize.register_parametrization(module, "weight", WrongRightInverse()) - self.assertFalse(parametrize.is_parametrized(module)) + # test with dim=None + m = nn.Linear(4, 5).to(dtype=dtype) + expected_output = m(input) + m = torch.nn.utils.weight_norm(m, dim=None) + self.assertEqual(m(input), expected_output) - # If it's a sequence, it must to be a sequence of tensors - class WrongRightInverseSequence(nn.Module): - def forward(self, x, y): - return x + with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'): + m = torch.nn.utils.weight_norm(m) + m = torch.nn.utils.weight_norm(m) - def right_inverse(self, z): - return None, z + # For float16, the forward of the Module doesn't work but we must still be able + # to register the weight norm as this is often done before sending the Module to + # CUDA. + m = nn.Linear(4, 5, dtype=torch.float16) + m = torch.nn.utils.weight_norm(m) - with self.assertRaisesRegex(ValueError, "of the sequence with type"): - parametrize.register_parametrization(module, "weight", WrongRightInverseSequence()) - self.assertFalse(parametrize.is_parametrized(module)) + def test_parameterlistdict_setting_attributes(self): + with warnings.catch_warnings(record=True) as w: + mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + self.assertTrue(len(w) == 0) - # A parametrization from one tensor to one tensor that changes the dtype - class ChangeDtypeInverse(nn.Module): - def forward(self, x): - return x.float() + with warnings.catch_warnings(record=True) as w: + mod.train() + mod.eval() + self.assertTrue(len(w) == 0) - def right_inverse(self, w): - return w.bool() + with warnings.catch_warnings(record=True) as w: + mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + self.assertTrue(len(w) == 0) - # For parametrizations that return one tensor, right_inverse may not change the dtype - with self.assertRaisesRegex(ValueError, "outputs one tensor, it may not change the dtype"): - parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) - self.assertFalse(parametrize.is_parametrized(module)) + with warnings.catch_warnings(record=True) as w: + mod.train() + mod.eval() + self.assertTrue(len(w) == 0) - # Doesn't return a tensor - class NotTensor(nn.Module): - def forward(self, x): - return 2 + def test_parameterlistdict_pickle(self): + m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) - # Forward must return a tensor - with self.assertRaisesRegex(ValueError, "must return a tensor"): - parametrize.register_parametrization(module, "weight", NotTensor()) - self.assertFalse(parametrize.is_parametrized(module)) + # Test whether loading from older checkpoints works without triggering warnings + m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) + del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) - # A parametrization from one tensor to one tensor that changes the dtype - class ChangeDtype(nn.Module): - def forward(self, x): - return x.bool() + m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) - # forward should not change the initial dtype - with self.assertRaisesRegex(ValueError, "may not change the dtype"): - parametrize.register_parametrization(module, "weight", ChangeDtype()) - self.assertFalse(parametrize.is_parametrized(module)) - - # Change shape - class ChangeShape(nn.Module): - def forward(self, x): - return x[:-1] - - # forward should not change the original shape - with self.assertRaisesRegex(ValueError, "may not change the shape"): - parametrize.register_parametrization(module, "weight", ChangeShape()) - self.assertFalse(parametrize.is_parametrized(module)) - - # Many to one that changes dtype - class ChangeDtypeMulti(nn.Module): - def forward(self, x, y): - return (x + y).bool() - - def right_inverse(self, w): - return w, w + 1 - - # forward should not change the original shape even for parametrizations with many inputs - with self.assertRaisesRegex(ValueError, "may not change the dtype"): - parametrize.register_parametrization(module, "weight", ChangeDtypeMulti()) - self.assertFalse(parametrize.is_parametrized(module)) - - # Returning a sequence of size one, although weird, it's correct - class SequenceLen1(nn.Module): - def forward(self, x): - return x - - def right_inverse(self, w): - return (w,) - - parametrize.register_parametrization(module, "weight", SequenceLen1()) - self.assertTrue(hasattr(module.parametrizations.weight, "original0")) - self.assertFalse(hasattr(module.parametrizations.weight, "original1")) - _ = module.weight # Does not throw - self.assertTrue(parametrize.is_parametrized(module)) - parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) - - # None of the operations above should have altered the weight - self.assertFalse(parametrize.is_parametrized(module)) - self.assertEqual(module.weight, weight_init) - - def test_errors_parametrized_tensor_parametrization(self): - # Test errors when registering a parametrization on a parametrized tensor - - class Identity(nn.Module): - def forward(self, x): - return x - - module = nn.Linear(3, 4) - parametrize.register_parametrization(module, "weight", Identity()) - - # Has to return a tensor - class WrongReturn(nn.Module): - def forward(self, x): - return x, x - - with self.assertRaisesRegex(ValueError, "must return a tensor"): - parametrize.register_parametrization(module, "weight", WrongReturn()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # Cannot change dtype - class ChangeDtype(nn.Module): - def forward(self, x): - return x.bool() - - with self.assertRaisesRegex(ValueError, "may not change the dtype"): - parametrize.register_parametrization(module, "weight", ChangeDtype()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # Cannot change shape - class ChangeShape(nn.Module): - def forward(self, x): - return x[:-1] - - with self.assertRaisesRegex(ValueError, "may not change the shape"): - parametrize.register_parametrization(module, "weight", ChangeShape()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # The following checks are mostly due to bugs in the code of the parametrization - - # right_inverse has to return a tensor - class WrongReturnInverse(Identity): - def right_inverse(self, x): - return x, x - - with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"): - parametrize.register_parametrization(module, "weight", WrongReturnInverse()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # Cannot change dtype - class ChangeDtypeInverse(Identity): - def right_inverse(self, x): - return x.bool() - - with self.assertRaisesRegex(ValueError, "must have the same dtype"): - parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # Cannot change shape - class ChangeShapeInverse(Identity): - def right_inverse(self, x): - return x[:-1] - - with self.assertRaisesRegex(ValueError, "must have the same shape"): - parametrize.register_parametrization(module, "weight", ChangeShapeInverse()) - self.assertTrue(parametrize.is_parametrized(module)) - self.assertEqual(len(module.parametrizations.weight), 1) - self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) - - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_multiple_inputs_parametrization(self): - # A parametrization with several outputs - class RankOne(nn.Module): - def forward(self, x, y): - # Form a rank-1 matrix from a pair of vectors - return x.unsqueeze(-1) @ y.unsqueeze(-2) - - def right_inverse(self, Y): - # We project the given matrix onto the rank 1 matrices - U, S, Vh = torch.linalg.svd(Y, full_matrices=False) - # S is ordered in a decreasing way. - s0_sqrt = S[0].sqrt().unsqueeze(-1) - return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt - - # Simple parametrisation - class Double(nn.Module): - def forward(self, x): - return 2.0 * x - - def right_inverse(self, w): - return 0.5 * w - - model = nn.Linear(3, 3) - # Test one parametrization - parametrize.register_parametrization(model, "weight", RankOne()) - self.assertTrue(hasattr(model, "parametrizations")) - self.assertTrue(parametrize.is_parametrized(model)) - self.assertTrue(parametrize.is_parametrized(model, "weight")) - self.assertTrue(hasattr(model.parametrizations.weight, "original0")) - self.assertIn("original0", model.parametrizations.weight._parameters) - self.assertTrue(hasattr(model.parametrizations.weight, "original1")) - self.assertIn("original1", model.parametrizations.weight._parameters) - self.assertFalse(parametrize.is_parametrized(model, "bias")) - self.assertNotIn("weight", model._parameters) - # Result should be rank 1 - self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) - - with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): - # Cannot remove a parametrization with multiple inputs and not leave it parametrized - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - # Remove parametrization and check consistency - parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.__class__, nn.Linear) - self.assertFalse(parametrize.is_parametrized(model)) - self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) - self.assertIn("weight", model._parameters) - - # Registering parametrizations with one input on top of one with multiple inputs should work - init_weight = model.weight.clone() - parametrize.register_parametrization(model, "weight", RankOne()) - # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix - self.assertEqual(init_weight, model.weight) - parametrize.register_parametrization(model, "weight", Double()) - # The matrix now is twice the initial matrix - self.assertEqual(2.0 * init_weight, model.weight) - # Multiplying by a scalar does not change the rank - self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) - - # The model has now three parameters - self.assertEqual(len(list(model.parameters())), 3) - - sgd = torch.optim.SGD(model.parameters(), lr=0.1) - - # Test backward. Should not throw - for _ in range(2): - sgd.zero_grad() - loss = (model.weight.T @ model.bias).sum() - loss.backward() - sgd.step() - - # Same drill as before, removing should work as expected - with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): - # Cannot remove a parametrization with multiple inputs and not leave it parametrized - parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) - # Remove parametrization and check consistency - parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) - self.assertFalse(hasattr(model, "parametrizations")) - self.assertEqual(model.__class__, nn.Linear) - self.assertFalse(parametrize.is_parametrized(model)) - self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) - self.assertIn("weight", model._parameters) - - # The model has now two parameters - self.assertEqual(len(list(model.parameters())), 2) - - # Test backward. Should not throw - sgd = torch.optim.SGD(model.parameters(), lr=0.1) - for _ in range(2): - sgd.zero_grad() - loss = (model.weight.T @ model.bias).sum() - loss.backward() - sgd.step() - - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_caching_parametrization(self): - r"""Test the caching system of a parametrization""" - # Define a couple matrix parametrizations - class Skew(nn.Module): - def forward(self, X): - X = X.tril(-1) - return X - X.T - - class Orthogonal(nn.Module): - def forward(self, X): - Id = torch.eye(X.size(0), device=X.device) - return torch.linalg.solve(Id + X, Id - X) - - model = nn.Linear(5, 5) - parametrize.register_parametrization(model, "weight", Skew()) - parametrize.register_parametrization(model, "weight", Orthogonal()) - - # Test that the caching system works - with parametrize.cached(): - X = model.weight - Y = model.weight - self.assertEqual(id(X), id(Y)) - - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_caching_parametrization_with_transfer_parametrizations_and_params(self): - r"""Test that transferring parametrizations doesn't cause issues with caching""" - class Skew(nn.Module): - def forward(self, X): - X = X.tril(-1) - return X - X.T - - class Orthogonal(nn.Module): - def forward(self, X): - Id = torch.eye(X.size(0), device=X.device) - return torch.linalg.solve(Id + X, Id - X) - - model = nn.Linear(5, 5) - parametrize.register_parametrization(model, "weight", Skew()) - parametrize.register_parametrization(model, "weight", Orthogonal()) - - to_model = nn.Linear(5, 5) - parametrize.transfer_parametrizations_and_params(model, to_model) - - with parametrize.cached(): - X = model.weight - Y = model.weight - self.assertEqual(id(X), id(Y)) - - A = to_model.weight - B = to_model.weight - self.assertEqual(id(A), id(B)) - - # test that the results are distinct objects for each module - self.assertNotEqual(id(A), id(X)) - - def test_parametrization_same_training_mode(self): - r"""Test training mode updated on parametrization registration""" - class Identity(nn.Module): - def forward(self, X): - return X - - module = nn.Linear(4, 4) - module.eval() - parametrize.register_parametrization(module, "weight", Identity()) - self.assertFalse(module.parametrizations.weight[0].training) - module.train() - parametrize.register_parametrization(module, "weight", Identity().eval()) - self.assertTrue(module.parametrizations.weight[0].training) - self.assertTrue(module.parametrizations.weight[1].training) - - def test_type_before_parametrizations(self): - r"""Test that type_before_parametrizations always retrieves original type""" - - class Identity(nn.Module): - def forward(self, X): - return X - - model = nn.Linear(5, 5) - original_type = type(model) - self.assertTrue( - parametrize.type_before_parametrizations(model) == original_type - ) - parametrize.register_parametrization(model, "weight", Identity()) - self.assertTrue( - parametrize.type_before_parametrizations(model) == original_type - ) - - def test_deepcopy_after_parametrization(self): - r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" - - class AddOne(nn.Module): - def forward(self, x): - return x + 1.0 - - class ModelWithoutDeepcopy(nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True) - self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True) - self.attr = [1.0, 2.0, 3.0, 4.0] - - class ActualModel(ModelWithoutDeepcopy): - # Emulate custom implementation of the deepcopying. - def __deepcopy__(self, memo): - result = self.__new__(self.__class__) - memo[id(self)] = result - result.__dict__ = deepcopy(self.__dict__, memo) - return result - - def check_deepcopy(m1: nn.Module, m2: nn.Module): - w1 = m1.parametrizations.weight.original - w2 = m2.parametrizations.weight.original - b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias - b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias - # Weights, biases and attributes should be equal but they must be different objects. - self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys()) - self.assertIsNot(m1, m2) - self.assertEqual(w1, w2) - self.assertIsNot(w1, w2) - self.assertEqual(b1, b2) - self.assertIsNot(b1, b2) - self.assertEqual(m1.attr, m2.attr) - self.assertIsNot(m1.attr, m2.attr) - - for model in (ModelWithoutDeepcopy(), ActualModel()): - # General check that we are able to create deepcopy. - parametrize.register_parametrization(model, "weight", AddOne()) - check_deepcopy(model, deepcopy(model)) - # Check that this works on models with several parametrized tensors. - parametrize.register_parametrization(model, "bias", AddOne()) - check_deepcopy(model, deepcopy(model)) - # Check that this works on models where tensors have more than one parametrization. - parametrize.register_parametrization(model, "weight", AddOne()) - check_deepcopy(model, deepcopy(model)) - - def test_transfer_parametrizations_and_params(self): - r"""Test that all parametrizations and their associated parameters are transferred.""" - - class AddOne(nn.Module): - def forward(self, x): - return x + 1.0 - - class Double(nn.Module): - def forward(self, x): - return 2.0 * x - - def right_inverse(self, x): - return 0.5 * x - - class MinusOne(nn.Module): - def forward(self, x): - return x - 1.0 - - model = nn.Linear(5, 5) - parametrize.register_parametrization(model, "weight", AddOne()) - parametrize.register_parametrization(model, "weight", Double()) - parametrize.register_parametrization(model, "weight", MinusOne()) - hold_weight = model.weight - - to_model = torch.ao.nn.qat.Linear( - 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() - ) - parametrize.transfer_parametrizations_and_params(model, to_model) - - # checks that final and original value are correct and the to_model is parametrized - self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) - self.assertEqual(model.weight, to_model.weight) - self.assertEqual( - model.parametrizations.weight.original, - to_model.parametrizations.weight.original, - ) - - # check that the transfer didn't affect the original value - self.assertEqual(hold_weight, model.weight) - - # testing that changes to one set of parametrizations do not affect the other - parametrize.remove_parametrizations(to_model, "weight") - self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) - self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight")) - - # also test that parameters that don't exist in to_model get transferred - model.test_param = Parameter(torch.randn(5, 5)) - - self.assertTrue(not hasattr(to_model, "test_param")) - parametrize.register_parametrization(model, "test_param", Double()) - hold_test_param = model.test_param - parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") - - # check that previously missing params got transferred correctly - self.assertEqual(model.test_param, to_model.test_param) - self.assertEqual( - model.parametrizations.test_param.original, - to_model.parametrizations.test_param.original, - ) - - # check that the new transfer didn't change the value for the from_module - self.assertEqual(hold_test_param, model.test_param) - - def test_transfer_parametrizations_and_params_right_inverse(self): - r"""Test that all parametrizations and their associated parameters are transferred.""" - - class Double(nn.Module): - def forward(self, x): - return 2.0 * x - - def right_inverse(self, x): - return 0.5 * x - - model = nn.Linear(5, 5) - parametrize.register_parametrization(model, "weight", Double()) - hold_weight = model.weight - - to_model = torch.ao.nn.qat.Linear( - 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() - ) - parametrize.transfer_parametrizations_and_params(model, to_model) - - # check that transfer occurs successfully - self.assertEqual(model.weight, to_model.weight) - self.assertEqual( - model.parametrizations.weight.original, - to_model.parametrizations.weight.original, - ) - - # check that transfer doesn't affect the from_model weight - self.assertEqual(hold_weight, model.weight) - - def test_transfer_parametrizations_and_params_single_param(self): - r"""Test that all parametrizations and their associated parameters are transferred.""" - - class AddOne(nn.Module): - def forward(self, x): - return x + 1.0 - - class Double(nn.Module): - def forward(self, x): - return 2.0 * x - - class MinusOne(nn.Module): - def forward(self, x): - return x - 1.0 - - model = nn.Linear(5, 5, bias=True) - parametrize.register_parametrization(model, "weight", AddOne()) - parametrize.register_parametrization(model, "weight", Double()) - parametrize.register_parametrization(model, "weight", MinusOne()) - parametrize.register_parametrization(model, "bias", AddOne()) - parametrize.register_parametrization(model, "bias", Double()) - parametrize.register_parametrization(model, "bias", MinusOne()) - - to_model = torch.ao.nn.qat.Linear( - 5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig() - ) - parametrize.transfer_parametrizations_and_params(model, to_model, "weight") - - # check that weight and only weight was transferred - self.assertEqual(model.weight, to_model.weight) - self.assertEqual( - model.parametrizations.weight.original, - to_model.parametrizations.weight.original, - ) - self.assertTrue("bias" not in to_model.parametrizations) - - # FIXME: Rewrite this test using functions not depending on LAPACK - # and remove the `@skipIfNoLapack` (see #70995) - @skipIfNoLapack - def test_transfer_parametrizations_and_params_many_to_one(self): - # A parametrization with several outputs - class RankOne(nn.Module): - def forward(self, x, y): - # Form a rank-1 matrix from a pair of vectors - return x.unsqueeze(-1) @ y.unsqueeze(-2) - - def right_inverse(self, Y): - # We project the given matrix onto the rank 1 matrices - U, S, Vh = torch.linalg.svd(Y, full_matrices=False) - # S is ordered in a decreasing way. - s0_sqrt = S[0].sqrt().unsqueeze(-1) - return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt - - class Double(nn.Module): - def forward(self, x): - return 2.0 * x - - model = nn.Linear(3, 3) - parametrize.register_parametrization(model, "weight", RankOne()) - parametrize.register_parametrization(model, "weight", Double()) - hold_weight = model.weight - - to_model = torch.ao.nn.qat.Linear( - 3, 3, qconfig=torch.ao.quantization.get_default_qconfig() - ) - - parametrize.transfer_parametrizations_and_params(model, to_model) - - # checks that final and original value are correct and the to_model is parametrized - self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) - self.assertEqual(model.weight, to_model.weight) - self.assertEqual( - model.parametrizations.weight.original0, - to_model.parametrizations.weight.original0, - ) - self.assertEqual( - model.parametrizations.weight.original1, - to_model.parametrizations.weight.original1, - ) - - # check that the transfer didn't affect the original value - self.assertEqual(hold_weight, model.weight) - - # testing that changes to one set of parametrizations do not affect the other - model.test_param = Parameter(torch.randn(3, 3)) - - self.assertTrue(not hasattr(to_model, "test_param")) - parametrize.register_parametrization(model, "test_param", RankOne()) - hold_test_param = model.test_param - parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") - - # also check that previously missing params got transferred correctly - self.assertEqual(model.test_param, to_model.test_param) - self.assertEqual( - model.parametrizations.test_param.original0, - to_model.parametrizations.test_param.original0, - ) - self.assertEqual( - model.parametrizations.test_param.original1, - to_model.parametrizations.test_param.original1, - ) - - # check that the new transfer didn't change the value for the from_module - self.assertEqual(hold_test_param, model.test_param) - - # torch/nn/utils/prune.py - @unittest.skipIf(not TEST_NUMPY, "numpy not found") - def test_validate_pruning_amount_init(self): - r"""Test the first util function that validates the pruning - amount requested by the user the moment the pruning method - is initialized. This test checks that the expected errors are - raised whenever the amount is invalid. - The original function runs basic type checking + value range checks. - It doesn't check the validity of the pruning amount with - respect to the size of the tensor to prune. That's left to - `_validate_pruning_amount`, tested below. - """ - # neither float not int should raise TypeError - with self.assertRaises(TypeError): - prune._validate_pruning_amount_init(amount="I'm a string") - - # float not in [0, 1] should raise ValueError - with self.assertRaises(ValueError): - prune._validate_pruning_amount_init(amount=1.1) - with self.assertRaises(ValueError): - prune._validate_pruning_amount_init(amount=20.) - - # negative int should raise ValueError - with self.assertRaises(ValueError): - prune._validate_pruning_amount_init(amount=-10) - - # all these should pass without errors because they're valid amounts - prune._validate_pruning_amount_init(amount=0.34) - prune._validate_pruning_amount_init(amount=1500) - prune._validate_pruning_amount_init(amount=0) - prune._validate_pruning_amount_init(amount=0.) - prune._validate_pruning_amount_init(amount=1) - prune._validate_pruning_amount_init(amount=1.) - self.assertTrue(True) - - @unittest.skipIf(not TEST_NUMPY, "numpy not found") - def test_validate_pruning_amount(self): - r"""Tests the second util function that validates the pruning - amount requested by the user, this time with respect to the size - of the tensor to prune. The rationale is that if the pruning amount, - converted to absolute value of units to prune, is larger than - the number of units in the tensor, then we expect the util function - to raise a value error. - """ - # if amount is int and amount > tensor_size, raise ValueError - with self.assertRaises(ValueError): - prune._validate_pruning_amount(amount=20, tensor_size=19) - - # amount is a float so this should not raise an error - prune._validate_pruning_amount(amount=0.3, tensor_size=0) - - # this is okay - prune._validate_pruning_amount(amount=19, tensor_size=20) - prune._validate_pruning_amount(amount=0, tensor_size=0) - prune._validate_pruning_amount(amount=1, tensor_size=1) - self.assertTrue(True) - - @unittest.skipIf(not TEST_NUMPY, "numpy not found") - def test_compute_nparams_to_prune(self): - r"""Test that requested pruning `amount` gets translated into the - correct absolute number of units to prune. - """ - self.assertEqual( - prune._compute_nparams_toprune(amount=0, tensor_size=15), - 0 - ) - self.assertEqual( - prune._compute_nparams_toprune(amount=10, tensor_size=15), - 10 - ) - # if 1 is int, means 1 unit - self.assertEqual( - prune._compute_nparams_toprune(amount=1, tensor_size=15), - 1 - ) - # if 1. is float, means 100% of units - self.assertEqual( - prune._compute_nparams_toprune(amount=1., tensor_size=15), - 15 - ) - self.assertEqual( - prune._compute_nparams_toprune(amount=0.4, tensor_size=17), - 7 - ) - - def test_random_pruning_sizes(self): - r"""Test that the new parameters and buffers created by the pruning - method have the same size as the input tensor to prune. These, in - fact, correspond to the pruned version of the tensor itself, its - mask, and its original copy, so the size must match. - """ - # fixturize test - # TODO: add other modules - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - original_tensor = getattr(m, name) - - prune.random_unstructured(m, name=name, amount=0.1) - # mask has the same size as tensor being pruned - self.assertEqual( - original_tensor.size(), - getattr(m, name + '_mask').size() - ) - # 'orig' tensor has the same size as the original tensor - self.assertEqual( - original_tensor.size(), - getattr(m, name + '_orig').size() - ) - # new tensor has the same size as the original tensor - self.assertEqual( - original_tensor.size(), - getattr(m, name).size() - ) - - def test_random_pruning_orig(self): - r"""Test that original tensor is correctly stored in 'orig' - after pruning is applied. Important to make sure we don't - lose info about the original unpruned parameter. - """ - # fixturize test - # TODO: add other modules - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - - # tensor prior to pruning - original_tensor = getattr(m, name) - prune.random_unstructured(m, name=name, amount=0.1) - self.assertEqual( - original_tensor, - getattr(m, name + '_orig') - ) - - def test_random_pruning_new_weight(self): - r"""Test that module.name now contains a pruned version of - the original tensor obtained from multiplying it by the mask. - """ - # fixturize test - # TODO: add other modules - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - # tensor prior to pruning - original_tensor = getattr(m, name) - prune.random_unstructured(m, name=name, amount=0.1) - # weight = weight_orig * weight_mask - self.assertEqual( - getattr(m, name), - getattr(m, name + '_orig') - * getattr(m, name + '_mask').to( - dtype=original_tensor.dtype - ), - ) - - def test_identity_pruning(self): - r"""Test that a mask of 1s does not change forward or backward. - """ - input_ = torch.ones(1, 5) - m = nn.Linear(5, 2) - y_prepruning = m(input_) # output prior to pruning - - # compute grad pre-pruning and check it's equal to all ones - y_prepruning.sum().backward() - old_grad_weight = m.weight.grad.clone() # don't grab pointer! - self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) - old_grad_bias = m.bias.grad.clone() - self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) - - # remove grads - m.zero_grad() - - # force the mask to be made of all 1s - prune.identity(m, name="weight") - - # with mask of 1s, output should be identical to no mask - y_postpruning = m(input_) - self.assertEqual(y_prepruning, y_postpruning) - - # with mask of 1s, grad should be identical to no mask - y_postpruning.sum().backward() - self.assertEqual(old_grad_weight, m.weight_orig.grad) - self.assertEqual(old_grad_bias, m.bias.grad) - - # calling forward twice in a row shouldn't change output - y1 = m(input_) - y2 = m(input_) - self.assertEqual(y1, y2) - - def test_random_pruning_0perc(self): - r"""Test that a mask of 1s does not change forward or backward. - """ - input_ = torch.ones(1, 5) - m = nn.Linear(5, 2) - y_prepruning = m(input_) # output prior to pruning - - # compute grad pre-pruning and check it's equal to all ones - y_prepruning.sum().backward() - old_grad_weight = m.weight.grad.clone() # don't grab pointer! - self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) - old_grad_bias = m.bias.grad.clone() - self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) - - # remove grads - m.zero_grad() - - # force the mask to be made of all 1s - with mock.patch( - "torch.nn.utils.prune.RandomUnstructured.compute_mask" - ) as compute_mask: - compute_mask.return_value = torch.ones_like(m.weight) - prune.random_unstructured(m, name='weight', amount=0.9) # amount won't count - - # with mask of 1s, output should be identical to no mask - y_postpruning = m(input_) - self.assertEqual(y_prepruning, y_postpruning) - - # with mask of 1s, grad should be identical to no mask - y_postpruning.sum().backward() - self.assertEqual(old_grad_weight, m.weight_orig.grad) - self.assertEqual(old_grad_bias, m.bias.grad) - - # calling forward twice in a row shouldn't change output - y1 = m(input_) - y2 = m(input_) - self.assertEqual(y1, y2) - - def test_random_pruning(self): - input_ = torch.ones(1, 5) - m = nn.Linear(5, 2) - - # define custom mask to assign with mock - mask = torch.ones_like(m.weight) - mask[1, 0] = 0 - mask[0, 3] = 0 - - # check grad is zero for masked weights - with mock.patch( - "torch.nn.utils.prune.RandomUnstructured.compute_mask" - ) as compute_mask: - compute_mask.return_value = mask - prune.random_unstructured(m, name='weight', amount=0.9) - - y_postpruning = m(input_) - y_postpruning.sum().backward() - # weight_orig is the parameter, so it's the tensor that will accumulate the grad - self.assertEqual(m.weight_orig.grad, mask) # all 1s, except for masked units - self.assertEqual(m.bias.grad, torch.ones_like(m.bias)) - - # make sure that weight_orig update doesn't modify [1, 0] and [0, 3] - old_weight_orig = m.weight_orig.clone() - # update weights - learning_rate = 1. - for p in m.parameters(): - p.data.sub_(p.grad.data * learning_rate) - # since these are pruned, they should not be updated - self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0]) - self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3]) - - def test_random_pruning_forward(self): - r"""check forward with mask (by hand). - """ - input_ = torch.ones(1, 5) - m = nn.Linear(5, 2) - - # define custom mask to assign with mock - mask = torch.zeros_like(m.weight) - mask[1, 0] = 1 - mask[0, 3] = 1 - - with mock.patch( - "torch.nn.utils.prune.RandomUnstructured.compute_mask" - ) as compute_mask: - compute_mask.return_value = mask - prune.random_unstructured(m, name='weight', amount=0.9) - - yhat = m(input_) - self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0]) - self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1]) - - def test_remove_pruning_forward(self): - r"""Remove pruning and check forward is unchanged from previous - pruned state. - """ - input_ = torch.ones(1, 5) - m = nn.Linear(5, 2) - - # define custom mask to assign with mock - mask = torch.ones_like(m.weight) - mask[1, 0] = 0 - mask[0, 3] = 0 - - # check grad is zero for masked weights - with mock.patch( - "torch.nn.utils.prune.RandomUnstructured.compute_mask" - ) as compute_mask: - compute_mask.return_value = mask - prune.random_unstructured(m, name='weight', amount=0.9) - - y_postpruning = m(input_) - - prune.remove(m, 'weight') - - y_postremoval = m(input_) - self.assertEqual(y_postpruning, y_postremoval) - - def test_pruning_id_consistency(self): - r"""Test that pruning doesn't change the id of the parameters, which - would otherwise introduce issues with pre-existing optimizers that - point to old parameters. - """ - m = nn.Linear(5, 2, bias=False) - - tensor_id = id(list(m.parameters())[0]) - - prune.random_unstructured(m, name="weight", amount=0.9) - self.assertEqual(tensor_id, id(list(m.parameters())[0])) - - prune.remove(m, "weight") - self.assertEqual(tensor_id, id(list(m.parameters())[0])) - - def test_random_pruning_pickle(self): - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - prune.random_unstructured(m, name=name, amount=0.1) - m_new = pickle.loads(pickle.dumps(m)) - self.assertIsInstance(m_new, type(m)) - - def test_multiple_pruning_calls(self): - # if you call pruning twice, the hook becomes a PruningContainer - m = nn.Conv3d(2, 2, 2) - prune.l1_unstructured(m, name='weight', amount=0.1) - weight_mask0 = m.weight_mask # save it for later sanity check - - # prune again - prune.ln_structured(m, name='weight', amount=0.3, n=2, dim=0) - hook = next(iter(m._forward_pre_hooks.values())) - self.assertIsInstance( - hook, - torch.nn.utils.prune.PruningContainer - ) - # check that container._tensor_name is correctly set no matter how - # many pruning methods are in the container - self.assertEqual(hook._tensor_name, 'weight') - - # check that the pruning container has the right length - # equal to the number of pruning iters - self.assertEqual(len(hook), 2) # m.weight has been pruned twice - - # check that the entries of the pruning container are of the expected - # type and in the expected order - self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured) - self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured) - - # check that all entries that are 0 in the 1st mask are 0 in the - # 2nd mask too - self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0)) - - # prune again - prune.ln_structured(m, name='weight', amount=0.1, n=float('inf'), dim=1) - # check that container._tensor_name is correctly set no matter how - # many pruning methods are in the container - hook = next(iter(m._forward_pre_hooks.values())) - self.assertEqual(hook._tensor_name, 'weight') - - def test_pruning_container(self): - # create an empty container - container = prune.PruningContainer() - container._tensor_name = 'test' - self.assertEqual(len(container), 0) - - p = prune.L1Unstructured(amount=2) - p._tensor_name = 'test' - - # test adding a pruning method to a container - container.add_pruning_method(p) - - # test error raised if tensor name is different - q = prune.L1Unstructured(amount=2) - q._tensor_name = 'another_test' - with self.assertRaises(ValueError): - container.add_pruning_method(q) - - # test that adding a non-pruning method object to a pruning container - # raises a TypeError - with self.assertRaises(TypeError): - container.add_pruning_method(10) - with self.assertRaises(TypeError): - container.add_pruning_method('ugh') - - def test_pruning_container_compute_mask(self): - r"""Test `compute_mask` of pruning container with a known `t` and - `default_mask`. Indirectly checks that Ln structured pruning is - acting on the right axis. - """ - # create an empty container - container = prune.PruningContainer() - container._tensor_name = 'test' - - # 1) test unstructured pruning - # create a new pruning method - p = prune.L1Unstructured(amount=2) - p._tensor_name = 'test' - # add the pruning method to the container - container.add_pruning_method(p) - - # create tensor to be pruned - t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) - # create prior mask by hand - default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) - # since we are pruning the two lowest magnitude units, the outcome of - # the calculation should be this: - expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32) - computed_mask = container.compute_mask(t, default_mask) - self.assertEqual(expected_mask, computed_mask) - - # 2) test structured pruning - q = prune.LnStructured(amount=1, n=2, dim=0) - q._tensor_name = 'test' - container.add_pruning_method(q) - # since we are pruning the lowest magnitude one of the two rows, the - # outcome of the calculation should be this: - expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32) - computed_mask = container.compute_mask(t, default_mask) - self.assertEqual(expected_mask, computed_mask) - - # 2) test structured pruning, along another axis - r = prune.LnStructured(amount=1, n=2, dim=1) - r._tensor_name = 'test' - container.add_pruning_method(r) - # since we are pruning the lowest magnitude of the four columns, the - # outcome of the calculation should be this: - expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32) - computed_mask = container.compute_mask(t, default_mask) - self.assertEqual(expected_mask, computed_mask) - - def test_l1_unstructured_pruning(self): - r"""Test that l1 unstructured pruning actually removes the lowest - entries by l1 norm (by hand). It also checks that applying l1 - unstructured pruning more than once respects the previous mask. - """ - m = nn.Linear(4, 2) - # modify its weight matrix by hand - m.weight = torch.nn.Parameter( - torch.tensor( - [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 - ) - ) - - prune.l1_unstructured(m, 'weight', amount=2) - expected_weight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], - dtype=m.weight.dtype) - self.assertEqual(expected_weight, m.weight) - - # check that pruning again removes the next two smallest entries - prune.l1_unstructured(m, 'weight', amount=2) - expected_weight = torch.tensor([[0, 0, 3, 4], [-4, -3, 0, 0]], - dtype=m.weight.dtype) - self.assertEqual(expected_weight, m.weight) - - def test_l1_unstructured_pruning_with_importance_scores(self): - r"""Test that l1 unstructured pruning actually removes the lowest - entries of importance scores and not the parameter by l1 norm (by hand). - It also checks that applying l1 unstructured pruning more than once - respects the previous mask. - """ - m = nn.Linear(4, 2) - # modify its weight matrix by hand - m.weight = torch.nn.Parameter( - torch.tensor( - [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 - ) - ) - importance_scores = torch.tensor( - [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 - ) - - prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) - expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], - dtype=m.weight.dtype) - self.assertEqual(expected_weight, m.weight) - - # check that pruning again removes two entries of m.weight that are colocated with - # the next two smallest absolute values of importance scores. - prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) - expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]], - dtype=m.weight.dtype) - self.assertEqual(expected_weight, m.weight) - - def test_unstructured_pruning_same_magnitude(self): - r"""Since it may happen that the tensor to prune has entries with the - same exact magnitude, it is important to check that pruning happens - consistenly based on the bottom % of weights, and not by threshold, - which would instead kill off *all* units with magnitude = threshold. - """ - AMOUNT = 0.2 - p = prune.L1Unstructured(amount=AMOUNT) - # create a random tensors with entries in {-2, 0, 2} - t = 2 * torch.randint(low=-1, high=2, size=(10, 7)) - nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement()) - - computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) - nparams_pruned = torch.sum(computed_mask == 0) - self.assertEqual(nparams_toprune, nparams_pruned) - - def test_random_structured_pruning_amount(self): - AMOUNT = 0.6 - AXIS = 2 - p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) - t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to( - dtype=torch.float32 - ) - nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS]) - - computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) - # check that 1 column is fully prune, the others are left untouched - remaining_axes = [_ for _ in range(len(t.shape)) if _ != AXIS] - per_column_sums = sorted( - torch.sum(computed_mask == 0, axis=remaining_axes) - ) - assert per_column_sums == [0, 20] - - def test_ln_structured_pruning(self): - r"""Check Ln structured pruning by hand. - """ - m = nn.Conv2d(3, 1, 2) - m.weight.data = torch.tensor( - [[[[1., 2.], [1., 2.5]], - [[0.5, 1.], [0.1, 0.1]], - [[-3., -5.], [0.1, -1.]]]] - ) - # expected effect of pruning 1 of the 3 channels by L2-norm - expected_mask_axis1 = torch.ones_like(m.weight) - expected_mask_axis1[:, 1] = 0. - - prune.ln_structured(m, 'weight', amount=1, n=2, dim=1) - self.assertEqual(expected_mask_axis1, m.weight_mask) - - # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm - expected_mask_axis3 = expected_mask_axis1 - expected_mask_axis3[:, :, :, 0] = 0. - - prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) - self.assertEqual(expected_mask_axis3, m.weight_mask) - - def test_ln_structured_pruning_importance_scores(self): - r"""Check Ln structured pruning by hand. - """ - m = nn.Conv2d(3, 1, 2) - m.weight.data = torch.tensor( - [[[[1., 2.], [1., 2.5]], - [[0.5, 1.], [0.1, 0.1]], - [[-3., -5.], [0.1, -1.]]]] - ) - importance_scores = torch.tensor( - [[[[10., 1.], [10., 1.]], - [[30., 3.], [30., 3.]], - [[-20., -2.], [-20., -2.]]]] - ) - # expected effect of pruning 1 of the 3 channels by L2-norm - expected_mask_axis1 = torch.ones_like(m.weight) - expected_mask_axis1[:, 0] = 0. - - prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) - self.assertEqual(expected_mask_axis1, m.weight_mask) - - # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm - expected_mask_axis3 = expected_mask_axis1 - expected_mask_axis3[:, :, :, 1] = 0. - - prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) - self.assertEqual(expected_mask_axis3, m.weight_mask) - - def test_remove_pruning(self): - r"""`prune.remove` removes the hook and the reparametrization - and makes the pruning final in the original parameter. - """ - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - # first prune - prune.random_unstructured(m, name, amount=0.5) - self.assertIn(name + "_orig", dict(m.named_parameters())) - self.assertIn(name + "_mask", dict(m.named_buffers())) - self.assertNotIn(name, dict(m.named_parameters())) - self.assertTrue(hasattr(m, name)) - pruned_t = getattr(m, name) - - # then remove pruning - prune.remove(m, name) - self.assertIn(name, dict(m.named_parameters())) - self.assertNotIn(name + "_orig", dict(m.named_parameters())) - self.assertNotIn(name + "_mask", dict(m.named_buffers())) - final_t = getattr(m, name) - - self.assertEqual(pruned_t, final_t) - - def test_remove_pruning_exception(self): - r"""Removing from an unpruned tensor throws an assertion error - """ - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - # check that the module isn't pruned - self.assertFalse(prune.is_pruned(m)) - # since it isn't pruned, pruning can't be removed from it - with self.assertRaises(ValueError): - prune.remove(m, name) - - - def test_global_pruning(self): - r"""Test that global l1 unstructured pruning over 2 parameters removes - the `amount=4` smallest global weights across the 2 parameters. - """ - m = nn.Linear(4, 2) - n = nn.Linear(3, 1) - # modify the weight matrices by hand - m.weight = torch.nn.Parameter( - torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( - dtype=torch.float32) - ) - n.weight = torch.nn.Parameter( - torch.tensor([[0, 0.1, -2]]).to( - dtype=torch.float32) - ) - - params_to_prune = ( - (m, 'weight'), - (n, 'weight'), - ) - - # prune the 4 smallest weights globally by L1 magnitude - prune.global_unstructured( - params_to_prune, - pruning_method=prune.L1Unstructured, - amount=4 - ) - - expected_mweight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], - dtype=m.weight.dtype) - self.assertEqual(expected_mweight, m.weight) - - expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) - self.assertEqual(expected_nweight, n.weight) - - def test_global_pruning_importance_scores(self): - r"""Test that global l1 unstructured pruning over 2 parameters removes - the `amount=4` smallest global weights across the 2 parameters. - """ - m = nn.Linear(4, 2) - n = nn.Linear(3, 1) - # modify the weight matrices by hand - m.weight = torch.nn.Parameter( - torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( - dtype=torch.float32) - ) - m_importance_scores = torch.tensor( - [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 - ) - n.weight = torch.nn.Parameter( - torch.tensor([[0, 0.1, -2]]).to( - dtype=torch.float32) - ) - n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) - - params_to_prune = ( - (m, 'weight'), - (n, 'weight'), - ) - importance_scores = { - (m, 'weight'): m_importance_scores, - (n, 'weight'): n_importance_scores, - } - - # prune the 4 smallest weights globally by L1 magnitude - prune.global_unstructured( - params_to_prune, - pruning_method=prune.L1Unstructured, - amount=4, - importance_scores=importance_scores, - ) - - expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], - dtype=m.weight.dtype) - self.assertEqual(expected_m_weight, m.weight) - - expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) - self.assertEqual(expected_n_weight, n.weight) - - def test_custom_from_mask_pruning(self): - r"""Test that the CustomFromMask is capable of receiving - as input at instantiation time a custom mask, and combining it with - the previous default mask to generate the correct final mask. - """ - # new mask - mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]]) - # old mask - default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]) - - # some tensor (not actually used) - t = torch.rand_like(mask.to(dtype=torch.float32)) - - p = prune.CustomFromMask(mask=mask) - - computed_mask = p.compute_mask(t, default_mask) - expected_mask = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype) - - self.assertEqual(computed_mask, expected_mask) - - def test_pruning_rollback(self): - r"""Test that if something fails when the we try to compute the mask, - then the model isn't left in some intermediate half-pruned state. - The try/except statement in `apply` should handle rolling back - to the previous state before pruning began. - """ - modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] - names = ['weight', 'bias'] - - for m in modules: - for name in names: - with self.subTest(m=m, name=name): - - with mock.patch( - "torch.nn.utils.prune.L1Unstructured.compute_mask" - ) as compute_mask: - compute_mask.side_effect = Exception('HA!') - with self.assertRaises(Exception): - prune.l1_unstructured(m, name=name, amount=0.9) - - self.assertTrue( - name in dict(m.named_parameters()) - ) - self.assertFalse( - name + '_mask' in dict(m.named_buffers()) - ) - self.assertFalse( - name + '_orig' in dict(m.named_parameters()) - ) - - def test_pruning_serialization_model(self): - # create a model - model = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 1), - ) - # check that everything looks normal before pruning - self.assertNotIn('0.weight_orig', model.state_dict()) - self.assertNotIn('0.weight_mask', model.state_dict()) - self.assertIn('0.weight', model.state_dict()) - - # prune one of its parameters - prune.l1_unstructured(module=model[0], name='weight', amount=0.9) - - # check that the original weight and the new mask are present - self.assertIn('0.weight_orig', model.state_dict()) - self.assertIn('0.weight_mask', model.state_dict()) - self.assertNotIn('0.weight', model.state_dict()) - self.assertTrue(hasattr(model[0], 'weight')) - - pruned_weight = model[0].weight - - with TemporaryFileName() as fname: - torch.save(model, fname) - new_model = torch.load(fname) - - # check that the original weight and the new mask are present - self.assertIn('0.weight_orig', new_model.state_dict()) - self.assertIn('0.weight_mask', new_model.state_dict()) - self.assertNotIn('0.weight', new_model.state_dict()) - self.assertTrue(hasattr(new_model[0], 'weight')) - - self.assertEqual(pruned_weight, new_model[0].weight) - - def test_pruning_serialization_state_dict(self): - # create a model - model = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 1), - ) - # check that everything looks normal before pruning - self.assertNotIn('0.weight_orig', model.state_dict()) - self.assertNotIn('0.weight_mask', model.state_dict()) - self.assertIn('0.weight', model.state_dict()) - - # prune one of its parameters - prune.l1_unstructured(module=model[0], name='weight', amount=0.9) - - # check that the original weight and the new mask are present - self.assertIn('0.weight_orig', model.state_dict()) - self.assertIn('0.weight_mask', model.state_dict()) - self.assertNotIn('0.weight', model.state_dict()) - self.assertTrue(hasattr(model[0], 'weight')) - - pruned_weight = model[0].weight - - # make pruning permanent and restore parameter names as in base - # architecture - prune.remove(module=model[0], name='weight') - - # check that the original weight and the new mask are no longer present - self.assertNotIn('0.weight_orig', model.state_dict()) - self.assertNotIn('0.weight_mask', model.state_dict()) - self.assertIn('0.weight', model.state_dict()) - - # save the state dict of model and reload it into new_model - new_model = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 1), - ) - with TemporaryFileName() as fname: - torch.save(model.state_dict(), fname) - new_model.load_state_dict(torch.load(fname)) - - # check that the original weight and the new mask are not present in - # new_model either. - self.assertNotIn('0.weight_orig', new_model.state_dict()) - self.assertNotIn('0.weight_mask', new_model.state_dict()) - self.assertIn('0.weight', new_model.state_dict()) - - self.assertEqual(pruned_weight, new_model[0].weight) - - def test_prune(self): - # create a new pruning method - p = prune.L1Unstructured(amount=2) - # create tensor to be pruned - t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) - # create prior mask by hand - default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) - # since we are pruning the two lowest magnitude units, the outcome of - # the calculation should be this: - expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) - pruned_tensor = p.prune(t, default_mask) - self.assertEqual(t * expected_mask, pruned_tensor) - - def test_prune_importance_scores(self): - # create a new pruning method - p = prune.L1Unstructured(amount=2) - # create tensor to be pruned - t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) - importance_scores = torch.tensor( - [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] - ).to(dtype=torch.float32) - # create prior mask by hand - default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) - # since we are pruning the two lowest magnitude units, the outcome of - # the calculation should be this: - expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) - pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) - self.assertEqual(t * expected_mask, pruned_tensor) - - def test_prune_importance_scores_mimic_default(self): - # create a new pruning method - p = prune.L1Unstructured(amount=2) - # create tensor to be pruned - t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) - # create prior mask by hand - default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) - # since we are pruning the two lowest magnitude units, the outcome of - # the calculation should be this: - expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) - pruned_tensor_without_importance_scores = p.prune(t, default_mask) - pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) - self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) - self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) - - def test_rnn_pruning(self): - l = torch.nn.LSTM(32, 32) - # This Module has 4 parameters called: - # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' - - # Pruning one of them causes one of the weights to become a tensor - prune.l1_unstructured(l, 'weight_ih_l0', 0.5) - assert ( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) - == 3 - ) - - # Removing the pruning reparametrization restores the Parameter - prune.remove(l, 'weight_ih_l0') - assert ( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) - == 4 - ) - - # Make sure that, upon removal of the reparametrization, the - # `._parameters` and `.named_parameters` contain the right params. - # Specifically, the original weight ('weight_ih_l0') should be placed - # back in the parameters, while the reparametrization component - # ('weight_ih_l0_orig') should be removed. - assert 'weight_ih_l0' in l._parameters - assert l._parameters['weight_ih_l0'] is not None - assert 'weight_ih_l0_orig' not in l._parameters - assert 'weight_ih_l0' in dict(l.named_parameters()) - assert dict(l.named_parameters())['weight_ih_l0'] is not None - assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) - - def test_rnn_weight_norm(self): - def check_weight_norm(l, name, num_params): - # This Module has 4 or 5 parameters called: - # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', weight_hr_l0 - - # Applying weight norm on one of them causes it to become a tensor - l = torch.nn.utils.weight_norm(l, name=name) - self.assertEqual( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), - num_params - 1, - ) - - # Removing the weight norm reparametrization restores the Parameter - l = torch.nn.utils.remove_weight_norm(l, name=name) - self.assertEqual( - sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]), - num_params, - ) - - # Make sure that, upon removal of the reparametrization, the - # `._parameters` and `.named_parameters` contain the right params. - # Specifically, the original weight ('weight_ih_l0') should be placed - # back in the parameters, while the reparametrization components - # ('weight_ih_l0_v' and 'weight_ih_l0_g') should be removed. - self.assertTrue(name in l._parameters) - self.assertIsNotNone(l._parameters[name]) - self.assertTrue(name + '_v' not in l._parameters) - self.assertTrue(name + '_g' not in l._parameters) - self.assertTrue(name in dict(l.named_parameters())) - self.assertIsNotNone(dict(l.named_parameters())[name]) - self.assertTrue(name + '_v' not in dict(l.named_parameters())) - self.assertTrue(name + '_g' not in dict(l.named_parameters())) - - check_weight_norm(torch.nn.LSTM(32, 32), 'weight_ih_l0', 4) - check_weight_norm(torch.nn.LSTM(32, 32, proj_size=16), 'weight_hr_l0', 5) - - - def test_weight_norm(self): - for dtype in [torch.float, torch.bfloat16]: - input = torch.randn(3, 4, dtype=dtype) - m = nn.Linear(4, 5).to(dtype=dtype) - expected_output = m(input) - - # add weight normalization - m = torch.nn.utils.weight_norm(m) - self.assertEqual(m.weight_v.size(), m.weight.size()) - self.assertEqual(m.weight_g.size(), (5, 1)) - self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - - # remove weight norm - m = torch.nn.utils.remove_weight_norm(m) - self.assertFalse(hasattr(m, 'weight_g')) - self.assertFalse(hasattr(m, 'weight_v')) - self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - - # test with dim=1 - m = torch.nn.utils.weight_norm(m, dim=1) - self.assertEqual(m.weight_v.size(), m.weight.size()) - self.assertEqual(m.weight_g.size(), (1, 4)) - self.assertEqual(m(input), expected_output, atol=dtype2prec_DONTUSE[dtype], rtol=0) - - # test with dim=None - m = nn.Linear(4, 5).to(dtype=dtype) - expected_output = m(input) - m = torch.nn.utils.weight_norm(m, dim=None) - self.assertEqual(m(input), expected_output) - - with self.assertRaisesRegex(RuntimeError, 'register two weight_norm hooks'): - m = torch.nn.utils.weight_norm(m) - m = torch.nn.utils.weight_norm(m) - - # For float16, the forward of the Module doesn't work but we must still be able - # to register the weight norm as this is often done before sending the Module to - # CUDA. - m = nn.Linear(4, 5, dtype=torch.float16) - m = torch.nn.utils.weight_norm(m) - - def test_parameterlistdict_setting_attributes(self): - with warnings.catch_warnings(record=True) as w: - mod = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) - self.assertTrue(len(w) == 0) - - with warnings.catch_warnings(record=True) as w: - mod.train() - mod.eval() - self.assertTrue(len(w) == 0) - - with warnings.catch_warnings(record=True) as w: - mod = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) - self.assertTrue(len(w) == 0) - - with warnings.catch_warnings(record=True) as w: - mod.train() - mod.eval() - self.assertTrue(len(w) == 0) - - def test_parameterlistdict_pickle(self): - m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) - with warnings.catch_warnings(record=True) as w: - m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) - - # Test whether loading from older checkpoints works without triggering warnings - m = nn.ParameterList(map(nn.Parameter, [torch.rand(2), torch.rand(2)])) - del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set - with warnings.catch_warnings(record=True) as w: - m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) - - m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) - with warnings.catch_warnings(record=True) as w: - m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) - - # Test whether loading from older checkpoints works without triggering warnings - m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) - del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set - with warnings.catch_warnings(record=True) as w: - m = pickle.loads(pickle.dumps(m)) - self.assertTrue(len(w) == 0) + # Test whether loading from older checkpoints works without triggering warnings + m = nn.ParameterDict({"a": nn.Parameter(torch.rand(2)), "b": nn.Parameter(torch.rand(2))}) + del m._forward_pre_hooks, m._state_dict_hooks, m._load_state_dict_pre_hooks, m._non_persistent_buffers_set + with warnings.catch_warnings(record=True) as w: + m = pickle.loads(pickle.dumps(m)) + self.assertTrue(len(w) == 0) def test_weight_norm_pickle(self): m = torch.nn.utils.weight_norm(nn.Linear(5, 7)) @@ -4676,277 +2015,17 @@ def fn(input): wrapped_m.train() out2 = wrapped_m(input) wrapped_m.eval() - out3 = wrapped_m(input) - return out0 + out1 + out2 + out3 - - gradcheck(fn, (input.clone().requires_grad_(),)) - - # assert that backprop reaches weight_orig in eval - if requires_grad: - def fn(weight): - return wrapped_m(input) - - gradcheck(fn, (m.weight_orig,)) - - def test_new_spectral_norm(self): - input = torch.randn(3, 5) - m = nn.Linear(5, 7) - m = torch.nn.utils.parametrizations.spectral_norm(m) - spectral_norm_m = m.parametrizations.weight[0] - - self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)])) - - # .parametrizations.weight.original should be trainable - self.assertTrue(hasattr(m.parametrizations.weight, 'original')) - self.assertTrue('original' in m.parametrizations.weight._parameters) - - # u should be just a reused buffer - self.assertTrue(hasattr(spectral_norm_m, '_u')) - self.assertTrue('_u' in spectral_norm_m._buffers) - self.assertTrue('_v' in spectral_norm_m._buffers) - - # weight should be a plain attribute, not counted as a buffer or a param - self.assertIsNotNone(m.weight) - self.assertFalse('weight' in m._buffers) - self.assertFalse('weight' in m._parameters) - - # it should also be sharing storage as `weight_orig` - # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage()) - self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size()) - self.assertEqual(m.parametrizations.weight.original.stride(), m.weight.stride()) - - m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') - - # spectral_norm is the only parametrization - self.assertFalse(hasattr(m, 'parametrizations')) - self.assertTrue('weight' in m._parameters) - - # We can register spectral_norm multiple times on the same parameter - # and on multiple parameters in the same module - m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') - m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') - m = torch.nn.utils.parametrizations.spectral_norm(m, 'bias') - - # If we remove the parametrization on bias, weight is still parametrized - # Removing a parametrization runs forward in eval mode if leave_parametrized=True - m = torch.nn.utils.parametrize.remove_parametrizations(m, 'bias') - self.assertTrue('bias' in m._parameters) - self.assertTrue(hasattr(m, 'parametrizations')) - self.assertFalse('weight' in m._parameters) - - m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') - # Neither weight and bias are parametrized - self.assertFalse(hasattr(m, 'parametrizations')) - self.assertTrue('weight' in m._parameters) - self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m)) - - # test correctness in training/eval modes and cpu/multi-gpu settings - for apply_dp in (True, False): - if apply_dp: - if not TEST_MULTIGPU: - continue - device = torch.device('cuda:0') - - def maybe_wrap(m): - return torch.nn.DataParallel(m, [0, 1]) - else: - device = torch.device('cpu') - - def maybe_wrap(m): - return m - - for requires_grad in (True, False): - def get_modules(): - m = nn.Linear(3, 4).to(device) - m.weight.requires_grad_(requires_grad) - m = torch.nn.utils.parametrizations.spectral_norm(m) - wrapped_m = maybe_wrap(m) - spectral_norm_m = m.parametrizations.weight[0] - return m, wrapped_m, spectral_norm_m - - input = torch.randn(2, 3, device=device) - - m, wrapped_m, spectral_norm_m = get_modules() - - self.assertTrue(hasattr(spectral_norm_m, '_u')) - u0 = spectral_norm_m._u.clone() - v0 = spectral_norm_m._v.clone() - - # TEST TRAINING BEHAVIOR - - # We perform GD first to modify the initial matrix - opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1) - - opt.zero_grad() - wrapped_m(input).sum().backward() - opt.step() - - out = wrapped_m(input) - if requires_grad: - # run forward again and assert that u and v are updated - self.assertNotEqual(u0, spectral_norm_m._u) - self.assertNotEqual(v0, spectral_norm_m._v) - - # assert that backprop reaches original weight - # can't use gradcheck because the function changes as we - # activate through it in training mode - if requires_grad: - torch.autograd.grad(out.sum(), m.parametrizations.weight.original) - - # test backward works with multiple forwards - # it uses training mode so we need to reset `u` and `v` vectors - # to same value at beginning for finite difference test to pass - saved_u = spectral_norm_m._u.clone() - saved_v = spectral_norm_m._v.clone() - - def fn(input): - spectral_norm_m._u.data.copy_(saved_u) - spectral_norm_m._v.data.copy_(saved_v) - out0 = wrapped_m(input) - out1 = wrapped_m(input) - return out0 + out1 - - # Make sure we can compute gradients wrt to all the parameters in the case - # of double forward - fn(input.clone().requires_grad_()).sum().backward() - gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False) - - # test removing - # spectral norm module needs to be in eval mode if we'd like to - # avoid doing another power iteration - m, wrapped_m, _ = get_modules() - pre_remove_out = wrapped_m(input) - m.eval() - m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') - self.assertEqual(wrapped_m(input), pre_remove_out) - - torch.nn.utils.parametrizations.spectral_norm(m) - for _ in range(3): - pre_remove_out = wrapped_m(input) - m.eval() - m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') - self.assertEqual(wrapped_m(input), pre_remove_out) - - # TEST EVAL BEHAVIOR - m, wrapped_m, spectral_norm_m = get_modules() - wrapped_m(input) - last_train_out = wrapped_m(input) - last_train_u = spectral_norm_m._u.clone() - last_train_v = spectral_norm_m._v.clone() - wrapped_m.zero_grad() - wrapped_m.eval() - - eval_out0 = wrapped_m(input) - # assert eval gives same result as last training iteration - self.assertEqual(eval_out0, last_train_out) - # assert doing more iteartion in eval don't change things - self.assertEqual(eval_out0, wrapped_m(input)) - self.assertEqual(last_train_u, spectral_norm_m._u) - self.assertEqual(last_train_v, spectral_norm_m._v) - - # FIXME: the code below is flaky when executed with DataParallel - # see https://github.com/pytorch/pytorch/issues/13818 - if apply_dp: - continue - - # test backward works with multiple forwards in mixed training - # and eval modes - # it uses training mode so we need to reset `u` and `v` vectors - # to same value at beginning for finite difference test to pass - saved_u = spectral_norm_m._u.clone() - saved_v = spectral_norm_m._v.clone() - - def fn(input): - spectral_norm_m._u.data.copy_(saved_u) - spectral_norm_m._v.data.copy_(saved_v) - wrapped_m.train() - out0 = wrapped_m(input) - wrapped_m.eval() - out1 = wrapped_m(input) - wrapped_m.train() - out2 = wrapped_m(input) - wrapped_m.eval() - out3 = wrapped_m(input) - return out0 + out1 + out2 + out3 - - gradcheck(fn, (input.clone().requires_grad_(),)) - - # assert that backprop reaches weight_orig in eval - if requires_grad: - def fn(weight): - return wrapped_m(input) - - gradcheck(fn, (m.parametrizations.weight.original,)) - - def test_new_spectral_norm_load_state_dict(self): - for activate_times in (0, 3): - inp = torch.randn(2, 3) - m = nn.Linear(3, 5) - snm = torch.nn.utils.parametrizations.spectral_norm(m) - snm.train() - - for _ in range(activate_times): - snm(inp) - - state_dict = deepcopy(snm.state_dict()) - self.assertEqual({ - 'parametrizations.weight.original', - 'bias', - 'parametrizations.weight.0._v', - 'parametrizations.weight.0._u' - }, set(state_dict.keys())) - - # test that non-strict loading works - non_strict_state_dict = deepcopy(state_dict) - non_strict_state_dict['nonsense'] = 'nonsense' - with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'): - snm.load_state_dict(non_strict_state_dict, strict=True) - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict['parametrizations.weight.original'] - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict['parametrizations.weight.0._u'] - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict['parametrizations.weight.0._v'] - snm.load_state_dict(non_strict_state_dict, strict=False) - non_strict_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict._metadata['parametrizations.weight.0'] # remove metadata info - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict['weight'] # remove W buffer - snm.load_state_dict(non_strict_state_dict, strict=False) - del non_strict_state_dict['bias'] - snm.load_state_dict(non_strict_state_dict, strict=False) + out3 = wrapped_m(input) + return out0 + out1 + out2 + out3 - # normal state_dict + gradcheck(fn, (input.clone().requires_grad_(),)) - # test that re-wrapping does not matter - m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') - snm = torch.nn.utils.parametrizations.spectral_norm(m) + # assert that backprop reaches weight_orig in eval + if requires_grad: + def fn(weight): + return wrapped_m(input) - snm.load_state_dict(state_dict) - with torch.no_grad(): - snm.eval() - out0_eval = snm(inp) - snm.train() - out1_train = snm(inp) - out2_train = snm(inp) - snm.eval() - out3_eval = snm(inp) - - # test that re-wrapping does not matter - m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') - snm = torch.nn.utils.parametrizations.spectral_norm(m) - - # Test normal loading - snm.load_state_dict(state_dict) - with torch.no_grad(): - snm.eval() - self.assertEqual(out0_eval, snm(inp)) - snm.train() - self.assertEqual(out1_train, snm(inp)) - self.assertEqual(out2_train, snm(inp)) - snm.eval() - self.assertEqual(out3_eval, snm(inp)) + gradcheck(fn, (m.weight_orig,)) @skipIfNoLapack def test_spectral_norm_load_state_dict(self): @@ -5055,16 +2134,6 @@ def test_spectral_norm_dim(self): # check that u refers to the same dimension self.assertEqual(m.weight_u.shape, m.weight_orig[0, :, 0, 0].shape) - def test_new_spectral_norm_dim(self): - inp = torch.randn(2, 3, 10, 12) - m = nn.ConvTranspose2d(3, 4, (5, 6)) - m = torch.nn.utils.parametrizations.spectral_norm(m) - snm = m.parametrizations.weight[0] - # this should not run into incompatible shapes - x = m(inp) - # check that u refers to the same dimension - self.assertEqual(snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape) - def test_spectral_norm_forward(self): input = torch.randn(3, 5) m = nn.Linear(5, 7) @@ -5081,164 +2150,11 @@ def test_spectral_norm_forward(self): expect_out = m(input) self.assertEqual(expect_out, out_hat) - def test_new_spectral_norm_forward(self): - input = torch.randn(3, 5) - m = nn.Linear(5, 7) - m = torch.nn.utils.parametrizations.spectral_norm(m) - snm = m.parametrizations.weight[0] - # naive forward - _weight = m.parametrizations.weight.original - _bias, _v = m.bias, snm._v - _weight_mat = _weight.view(_weight.size(0), -1) - _u = torch.mv(_weight_mat, _v) - _u = F.normalize(_u, dim=0, eps=1e-12) - _v = torch.mv(_weight_mat.t(), _u) - _v = F.normalize(_v, dim=0, eps=1e-12) - _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v)) - out_hat = torch.nn.functional.linear(input, _weight, _bias) - expect_out = m(input) - self.assertEqual(expect_out, out_hat) - def test_spectral_norm_pickle(self): m = torch.nn.utils.spectral_norm(nn.Linear(5, 7)) m = pickle.loads(pickle.dumps(m)) self.assertIsInstance(m, nn.Linear) - @skipIfNoLapack - def test_orthogonal_parametrization(self): - # Orthogonal implements 6 algorithms (3x parametrizations times 2 options of use_trivialization) - - def assert_is_orthogonal(X): - n, k = X.size(-2), X.size(-1) - if n < k: - X = X.mT - n, k = k, n - Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k) - eps = 10 * n * torch.finfo(X.dtype).eps - torch.testing.assert_allclose(X.mH @ X, Id, atol=eps, rtol=0.) - - - def assert_weight_allclose_Q(weight, W): - # Test that weight is equal to the Q part of the QR decomposition of W - # (or of its transpose if the matrix is wide) - wide_matrix = W.size(-2) < W.size(-1) - if wide_matrix: - W = W.mT - Q, R = torch.linalg.qr(W) - Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) - if wide_matrix: - Q = Q.mT - torch.testing.assert_allclose(Q, weight, atol=1e-5, rtol=0.) - - - for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)), # square/ tall / wide - (torch.float32, torch.complex64), - (True, False)): - # Conv2d does not support complex yet - if not use_linear: - continue - - if use_linear: - input = torch.randn(3, shape[0], dtype=dtype) - else: - input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype) - - for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"), - (False, True)): - # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False - # See Note [right_inverse expm cayley] - can_initialize = use_trivialization or parametrization == "householder" - - # We generate them every time to always start with fresh weights - if use_linear: - m = nn.Linear(*shape, dtype=dtype) - else: - m = nn.Conv2d(2, 3, shape, dtype=dtype) - - # We do not support householder for complex inputs - # See Note [Householder complex] - w_init = m.weight.clone() - if parametrization == "householder" and m.weight.is_complex(): - msg = "householder parametrization does not support complex tensors" - with self.assertRaisesRegex(ValueError, msg): - torch.nn.utils.parametrizations.orthogonal(m, - "weight", - parametrization, - use_trivialization=use_trivialization) - continue - - wide_matrix = w_init.size(-2) < w_init.size(-1) - torch.nn.utils.parametrizations.orthogonal(m, - "weight", - parametrization, - use_trivialization=use_trivialization) - # Forwards works as expected - self.assertEqual(w_init.shape, m.weight.shape) - assert_is_orthogonal(m.weight) - if can_initialize: - assert_weight_allclose_Q(m.weight, w_init) - - # Intializing with a given orthogonal matrix works - X = torch.randn_like(m.weight) - if wide_matrix: - X = X.mT - w_new = torch.linalg.qr(X).Q - if wide_matrix: - w_new = w_new.mT - if can_initialize: - m.weight = w_new - torch.testing.assert_allclose(w_new, m.weight, atol=1e-5, rtol=0.) - else: - msg = "assign to the matrix exponential or the Cayley parametrization" - with self.assertRaisesRegex(NotImplementedError, msg): - m.weight = w_new - - # Intializing with a non-orthogonal matrix makes m.weight be the Q part of the given matrix - w_new = torch.randn_like(m.weight) - if can_initialize: - m.weight = w_new - assert_weight_allclose_Q(m.weight, w_new) - else: - msg = "assign to the matrix exponential or the Cayley parametrization" - with self.assertRaisesRegex(NotImplementedError, msg): - m.weight = w_new - - opt = torch.optim.SGD(m.parameters(), lr=0.1) - for _ in range(2): - opt.zero_grad() - m(input).norm().backward() - grad = m.parametrizations.weight.original.grad - self.assertIsNotNone(grad) - # We do not update the upper triangular part of the matrix if tall tril if wide - if grad.size(-2) >= grad.size(-1): - zeros_grad = grad.triu(1) - else: - zeros_grad = grad.tril(-1) - self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad)) - # The gradient in the diagonal can only be imaginary because a skew-Hermitian - # matrix has imaginary diagonal - diag_grad = grad.diagonal(dim1=-2, dim2=-1) - if grad.is_complex(): - diag_grad = diag_grad.real - self.assertEqual(diag_grad, torch.zeros_like(diag_grad)) - opt.step() - assert_is_orthogonal(m.weight) - - @skipIfNoLapack - def test_orthogonal_errors(self): - m = nn.Linear(3, 4) - with self.assertRaisesRegex(ValueError, "has to be one of"): - torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo") - - with self.assertRaisesRegex(ValueError, "Expected a matrix"): - torch.nn.utils.parametrizations.orthogonal(m, "bias") - - torch.nn.utils.parametrizations.orthogonal(m, "weight") - with self.assertRaisesRegex(ValueError, "matrices of shape"): - m.weight = torch.randn(5, 5) - torch.nn.utils.parametrize.remove_parametrizations(m, "weight") - - def test_threshold_int(self): x = torch.tensor([-3, -2, -1, 0, 1, 2, 3]) expected = torch.tensor([99, 99, 99, 99, 1, 2, 3]) @@ -5319,491 +2235,6 @@ def test_nested_tensor_from_mask_error(self): mask[0, 2] = False self.assertRaises(RuntimeError, lambda: torch._nested_tensor_from_mask(input, mask)) - @unittest.skipIf(not TEST_NUMPY, "numpy not found") - @parametrize_test("average_attn_weights", [True, False]) - def test_multihead_attention(self, average_attn_weights): - def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None, - average_attn_weights=average_attn_weights): - """ Numpy-based reference implementation of scaled dot attention - for testing""" - - QKT = _batchmatmul( - Q, - np.transpose(K, axes=[0, 1, 3, 2]) - / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) - ) - b1, b2, s1, s2 = QKT.shape - if unseen_mask is not None or key_padding_mask is not None: - # assert s1 == s2 - for i in range(b1): - for j in range(b2): - for m in range(s1): - for n in range(s2): - if unseen_mask is not None and unseen_mask[m][n] == 0: - QKT[i, j, m, n] = -np.inf - if key_padding_mask is not None and key_padding_mask[i][n]: - QKT[i, j, m, n] = -np.inf - - reference = _softmax(QKT) - ref_attn_weight = reference - if average_attn_weights: - ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2 - reference = _batchmatmul(reference, V) - return reference, ref_attn_weight - - def _batchmatmul(a, b): # batchmatmul over 4 dim matrix - """ Numpy-based batch matrix multiply over 4 dim matrix""" - assert a.shape[0] == b.shape[0] - assert a.shape[1] == b.shape[1] - retval = np.zeros( - (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 - ) - for i in range(a.shape[0]): - for j in range(a.shape[1]): - retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) - return retval - - def _softmax(x): # softmax over 4 dim matrix - """ Numpy-based reference softmax over 4 dim matrix""" - np.seterr(invalid='ignore') - output = np.zeros(x.shape, dtype=np.float64) - for i in range(x.shape[0]): - for j in range(x.shape[1]): - for k in range(x.shape[2]): - x_curr = x[i, j, k, :] - e_x = np.exp(x_curr - np.amax(x_curr)) - output[i, j, k, :] = e_x / np.sum(e_x) - return output - - def _split_heads_ref(X, dims, nheads, d_head): - X_split = np.reshape(X, dims[:2] + [nheads, d_head]) - X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) - reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head]) - return reference - - def _combine_heads_ref(X, dims, nheads, d_head): - X_transposed = np.transpose(X, [0, 2, 1, 3]) - reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) - return reference - - def _fc(X, X_weight, X_bias): - X_fc_b = X_bias.detach().numpy() - X_fc_w = X_weight.detach().numpy() - return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b - - def _create_src_lengths_mask(batch_size, src_lengths): - """ - Generate boolean mask to prevent attention beyond the end of source - Inputs: - batch_size : int - src_lengths : [batch_size] of sentence lengths - Outputs: - [batch_size, max_src_len] - """ - max_srclen = src_lengths.max() - src_indices = torch.arange(0, max_srclen).unsqueeze(0).to(src_lengths) - src_indices = src_indices.expand(batch_size, max_srclen) - src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) - # returns [batch_size, max_seq_len] - return (src_indices < src_lengths).int().detach() - - def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False, - saved_kv=False, same_embed_dim=False, - average_attn_weights=average_attn_weights): - for _ in range(100): - batch_sz, seq_len = [random.randint(2, 10) for r in range(2)] - d_head = random.randint(3, 10) - nheads = random.randint(2, 5) * 2 - d_model = d_head * nheads - if same_embed_dim: - kv_dim = d_model - else: - kv_dim = random.randint(5, 20) - dims = [batch_sz, seq_len, kv_dim] - - saved_k = None - saved_k_tensor = None - saved_v = None - saved_v_tensor = None - if saved_kv: - saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head) - saved_k_tensor = torch.from_numpy(saved_k).to(torch.get_default_dtype()) - saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head) - saved_v_tensor = torch.from_numpy(saved_v).to(torch.get_default_dtype()) - - key_padding_mask = None - key_padding_mask_tensor = None - if add_key_padding_mask: - seq_mask = np.random.randint(0, 2, (1, seq_len)) - key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1) - key_padding_mask_tensor = torch.from_numpy(key_padding_mask) - decoder_state = np.random.rand(batch_sz, d_model) - K = np.random.rand(*dims) - V = K - Q = np.expand_dims(decoder_state, 1) - attn_mask = np.random.randint(0 , 2, size=(1, seq_len)) - attn_mask_tensor = torch.from_numpy(attn_mask).float() - attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf')) - attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0')) - attn_mask_tensor = attn_mask_tensor.double() - - decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype()) - source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1) - - multihead_attn_module = MultiheadAttention(d_model, nheads, - add_bias_kv=add_bias_kv, - add_zero_attn=add_zero_attn, - kdim=kv_dim, vdim=kv_dim) - - if add_bias_kv: - bias_k = multihead_attn_module.bias_k.detach().numpy() - bias_v = multihead_attn_module.bias_v.detach().numpy() - else: - bias_k = None - bias_v = None - - _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1) - _V = source_hid_tensor - _K = source_hid_tensor - - if multihead_attn_module._qkv_same_embed_dim: - result, result_weight = torch.nn.functional.multi_head_attention_forward( - _Q, _K, _V, - d_model, nheads, - multihead_attn_module.in_proj_weight, multihead_attn_module.in_proj_bias, - multihead_attn_module.bias_k, multihead_attn_module.bias_v, - multihead_attn_module.add_zero_attn, multihead_attn_module.dropout, - multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias, - multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor, - static_k=saved_k_tensor, static_v=saved_v_tensor, - average_attn_weights=average_attn_weights) - else: - result, result_weight = torch.nn.functional.multi_head_attention_forward( - _Q, _K, _V, - d_model, nheads, - None, multihead_attn_module.in_proj_bias, - multihead_attn_module.bias_k, multihead_attn_module.bias_v, - multihead_attn_module.add_zero_attn, multihead_attn_module.dropout, - multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias, - multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor, - True, multihead_attn_module.q_proj_weight, - multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight, - static_k=saved_k_tensor, static_v=saved_v_tensor, - average_attn_weights=average_attn_weights) - - result = result.squeeze(0).detach().numpy() - - if multihead_attn_module._qkv_same_embed_dim: - q_proj_weight = multihead_attn_module.in_proj_weight[:d_model] - k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)] - v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):] - else: - q_proj_weight = multihead_attn_module.q_proj_weight - k_proj_weight = multihead_attn_module.k_proj_weight - v_proj_weight = multihead_attn_module.v_proj_weight - - Q_fc = _fc(Q, q_proj_weight, multihead_attn_module.in_proj_bias[:d_model]) - K_fc = _fc(K, k_proj_weight, multihead_attn_module.in_proj_bias[d_model:(d_model * 2)]) - V_fc = _fc(V, v_proj_weight, multihead_attn_module.in_proj_bias[(d_model * 2):]) - - if add_bias_kv: - K_fc = np.concatenate((K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1) - V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1) - if attn_mask is not None: - attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) - if key_padding_mask is not None: - key_padding_mask = np.concatenate((key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1) - dims[1] += 1 - Q_split = _split_heads_ref( - Q_fc, [batch_sz, 1, d_model], nheads, d_head - ) - - if saved_k is not None: - K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head]) - else: - K_split = _split_heads_ref(K_fc, dims, nheads, d_head) - - if saved_v is not None: - V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head]) - else: - V_split = _split_heads_ref(V_fc, dims, nheads, d_head) - - if add_zero_attn: - dims[1] += 1 - K_split = np.concatenate((K_split, np.zeros([K_split.shape[0], K_split.shape[1], 1, K_split.shape[3]])), axis=2) - V_split = np.concatenate((V_split, np.zeros([V_split.shape[0], V_split.shape[1], 1, V_split.shape[3]])), axis=2) - - if attn_mask is not None: - attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1) - - if key_padding_mask is not None: - key_padding_mask = np.concatenate((key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1) - attn_heads, ref_attn_weight = _scaled_dot_attn_ref( - Q=Q_split, - K=K_split, - V=V_split, - dims=Q_split.shape, - unseen_mask=attn_mask, - key_padding_mask=key_padding_mask - ) - combined_attn_heads = _combine_heads_ref( - X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head - ) - - reference = _fc(combined_attn_heads, multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias) - reference = np.squeeze(reference, axis=1) - - # result = reference - self.assertEqual(tuple(result.shape), (batch_sz, d_model)) - np.testing.assert_allclose(result, reference, atol=1e-5) - - # result_weight = ref_attn_weight - result_weight = result_weight.detach().numpy() - self.assertEqual(tuple(result_weight.shape), tuple(ref_attn_weight.shape)) - np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5) - - def test_multihead_attn_add_bias_kv(): - _multihead_attn_test_helper(add_bias_kv=True) - - def test_multihead_attn_add_zero_attn(): - _multihead_attn_test_helper(add_zero_attn=True) - - def test_multihead_attn_no_masking(): - _multihead_attn_test_helper() - - def test_multihead_attn_key_padding_mask(): - _multihead_attn_test_helper(add_key_padding_mask=True) - - def test_multihead_attn_saved_kv(): - _multihead_attn_test_helper(saved_kv=True) - - def test_multihead_attn_add_bias_kv_zero_attn(): - _multihead_attn_test_helper(add_key_padding_mask=True, add_bias_kv=True, - add_zero_attn=True) - - def test_multihead_attn_all_arguments1(): - _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True, saved_kv=True) - - def test_multihead_attn_all_arguments2(): - _multihead_attn_test_helper(add_key_padding_mask=True, add_bias_kv=True, - add_zero_attn=True, saved_kv=True) - - def test_multihead_attn_all_arguments3(): - _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True, - saved_kv=True, same_embed_dim=True) - - test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn - test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv - test_multihead_attn_no_masking() # Test MultiheadAttention without masking - test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths - test_multihead_attn_saved_kv() # Test MultiheadAttention with static kv. - test_multihead_attn_add_bias_kv_zero_attn() # Test MultiheadAttention with bias_kv and zero_attn. - test_multihead_attn_all_arguments1() # Test MultiheadAttention with all the argument. - with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."): - test_multihead_attn_all_arguments2() # Test MultiheadAttention with all the argument. - test_multihead_attn_all_arguments3() # Test MultiheadAttention with all the argument. - - def test_multihead_attn_3d_attn_mask(self): - embed_dim = 8 - num_heads = 4 - batch_size = 8 - src_len = 3 - tgt_len = 2 - - query = torch.rand(batch_size, tgt_len, embed_dim) # [N, T, D] - key = torch.rand(batch_size, src_len, embed_dim) # [N, S, D] - value = key # [N, S, D] - attn_mask = torch.randint(0, 2, (batch_size, tgt_len, src_len)).float() # [N, T, S] - attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0)) - - mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads) - - # Generate 3D results - attn_mask_3d = torch.repeat_interleave(attn_mask, num_heads, dim=0) # [N * H, T, S] - output_3d = mta_model(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), attn_mask=attn_mask_3d)[0] - output_3d = output_3d.transpose(0, 1) # [N, T, D] - - for i in range(0, batch_size): - output_2d = mta_model(query[i].unsqueeze(0).transpose(0, 1), - key[i].unsqueeze(0).transpose(0, 1), - value[i].unsqueeze(0).transpose(0, 1), - attn_mask=attn_mask[i])[0] - - # output_2d in shape of [T, 1, D] - self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d) - - def test_multihead_attn_no_bias(self): - embed_dim = 8 - num_heads = 4 - mha = torch.nn.MultiheadAttention(embed_dim, num_heads, bias=False) - - # Verify that bias=False applies to both in and out projection layers. - self.assertIsNone(mha.in_proj_bias) - self.assertIsNone(mha.out_proj.bias) - - def _test_multihead_attn_invalid_shape_impl(self, mha): - # Batched (3D) query cases - query = torch.randn(4, 4, 4) - key = torch.randn(4, 4, 4) - value = torch.randn(4, 4, 4) - - msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" - # 3D query, 2D key and 3D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, torch.randn(4, 4), value) - - msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" - # 3D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, torch.randn(4, 4)) - - msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" - # 3D query, 3D key, 3D value and 1D key_padding_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, key_padding_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) - - msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" - # 3D query, 3D key, 3D value and 1D attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) - - # Unbatched (2D) query cases - query = torch.randn(4, 4) - key = torch.randn(4, 4) - value = torch.randn(4, 4) - - msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" - # 2D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, torch.randn(4, 4, 4), value) - - msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" - # 2D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, torch.randn(4, 4, 4)) - - msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" - # 2D query, 2D key, 2D value and 1D key_padding_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, key_padding_mask=torch.tensor([[False, False, True, True] * 2], dtype=torch.bool)) - - msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" - # 2D query, 2D key, 2D value and 1D attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) - - msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" - # 2D query, 2D key, 2D value and 3D incorrect attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool)) - - def test_multihead_attn_invalid_shape(self): - mha = torch.nn.MultiheadAttention(4, 4) - self._test_multihead_attn_invalid_shape_impl(mha) - # Give the test a chance to hit the fast path. (Right now, it - # won't, but gating may be less restricted in the future.) - with torch.no_grad(): - self._test_multihead_attn_invalid_shape_impl(mha.eval()) - - @torch.no_grad() - def test_multihead_attn_fast_path_invalid_shape(self): - mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() - - # Batched (3D) query cases - query = torch.randn(4, 4, 4) - key = torch.randn(4, 4, 4) - value = torch.randn(4, 4, 4) - - # Currently, this case will just go to the slow path and get - # the usual message because it fails the requirement to be - # batched. - msg = "expected `key` and `value` to be 3-D but found 2-D and 3-D tensors respectively" - # 3D query, 2D key and 3D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, torch.randn(3, 3), value, need_weights=False) - - # Currently, this case will just go to the slow path and get - # the usual message because it fails the requirement to be - # batched. - msg = "expected `key` and `value` to be 3-D but found 3-D and 2-D tensors respectively" - # 3D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, torch.randn(3, 3), need_weights=False) - - msg = "expected `key_padding_mask` to be `None` or 2-D but found 1-D tensor instead" - # 3D query, 3D key, 3D value and 1D key_padding_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, key_padding_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False) - - msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" - # 3D query, 3D key, 3D value and 1D attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.tensor([False, True, True], dtype=torch.bool), need_weights=False) - - # Unbatched (2D) query cases - # NOTE: error messages are the same as regular path because the fast path doesn't support 2D. - query = torch.randn(4, 4) - key = torch.randn(4, 4) - value = torch.randn(4, 4) - - msg = "expected `key` and `value` to be 2-D but found 3-D and 2-D tensors respectively" - # 2D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, torch.randn(4, 4, 4), value) - - msg = "expected `key` and `value` to be 2-D but found 2-D and 3-D tensors respectively" - # 2D query, 3D key and 2D value - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, torch.randn(4, 4, 4)) - - msg = "expected `key_padding_mask` to be `None` or 1-D but found 2-D tensor instead" - # 2D query, 2D key, 2D value and 1D key_padding_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, key_padding_mask=torch.tensor([[False, False, True, True] * 2], dtype=torch.bool)) - - msg = "expected `attn_mask` to be `None`, 2-D or 3-D but found 1-D tensor instead" - # 2D query, 2D key, 2D value and 1D attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.tensor([False, False, True, True], dtype=torch.bool)) - - msg = r"Expected `attn_mask` shape to be \(4, 4, 4\)" - # 2D query, 2D key, 2D value and 3D incorrect attn_mask - with self.assertRaisesRegex(AssertionError, msg): - mha(query, key, value, attn_mask=torch.randn(5, 4, 4).bernoulli_().to(torch.bool)) - - def test_multihead_attn_nested_tensor_outside_fast_path(self): - mha = torch.nn.MultiheadAttention(4, 4, batch_first=True).eval() - nt = torch.nested.nested_tensor([torch.randn(4, 4)]) - # One tested platform (linux-bionic-py3.7-clang) has a torch_function for one - # or more of these. Take advantage of that to test the torch_function bailout. - has_torch_func = torch.overrides.has_torch_function( - (nt, mha.in_proj_weight, mha.in_proj_bias, mha.out_proj.weight, mha.out_proj.bias)) - if has_torch_func: - msg = "MultiheadAttention does not support NestedTensor.*argument has_torch_function" - else: - msg = ("MultiheadAttention does not support NestedTensor outside of its fast path.*grad is " + - "enabled and.*or biases requires_grad") - with self.assertRaisesRegex(AssertionError, msg): - mha(nt, nt, nt) - - if has_torch_func: - # Just give up, they're all going to fail with the same message. - return - - with torch.no_grad(): - mha(nt, nt, nt) - with torch.inference_mode(): - mha(nt, nt, nt) - nt = torch.nested.nested_tensor([torch.randn(4, 4, requires_grad=False)]) - nt.requires_grad = False - with self.assertRaisesRegex(AssertionError, msg): - mha(nt, nt, nt) - mha.in_proj_weight.requires_grad = False - mha.in_proj_bias.requires_grad = False - mha.out_proj.weight.requires_grad = False - mha.out_proj.bias.requires_grad = False - mha(nt, nt, nt) - def test_normalize(self): inputs = torch.randn(1, 3, 4, 4, requires_grad=True) self.assertTrue(gradcheck(lambda x: F.normalize(x, p=1, dim=-1), (inputs,))) @@ -6020,6 +2451,60 @@ def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, un model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True) model.load_state_dict(model.state_dict(), strict=True) + @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") + def test_register_state_dict_pre_hook_backward_compat(self): + called = False + + def my_state_dict_pre_hook(*args, **kwargs): + nonlocal called + called = True + + m = nn.Linear(1, 1) + self.assertTrue(hasattr(m, '_state_dict_pre_hooks')) + delattr(m, '_state_dict_pre_hooks') + # Save and load, ensure we can still call state_dict + # without running into issues. + with NamedTemporaryFile() as f: + # Note that torch.save / torch.load is not recommended + # to save / load modules. + torch.save(m, f.name) + m = torch.load(f.name) + + # Ensure we can run state_dict without issues + _ = m.state_dict() + self.assertFalse(called) + m.register_state_dict_pre_hook(my_state_dict_pre_hook) + _ = m.state_dict() + self.assertTrue(called) + + def test_register_state_dict_pre_hook(self): + _state_dict_prefix = "foo." + state_dict_pre_hook_count = 0 + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)) + + def forward(self, x): + return self.a(x) + + def my_state_dict_pre_hook(module, prefix, keep_vars): + nonlocal keep_var_setting + self.assertEqual(keep_vars, keep_var_setting) + nonlocal state_dict_pre_hook_count + state_dict_pre_hook_count += 1 + self.assertTrue(prefix.startswith(_state_dict_prefix)) + + mod = MyModule() + mod.register_state_dict_pre_hook(my_state_dict_pre_hook) + # Test to ensure submodules run the hook as well. + mod.a.register_state_dict_pre_hook(my_state_dict_pre_hook) + for keep_var_setting in [True, False]: + _ = mod.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting) + self.assertEqual(2, state_dict_pre_hook_count) + state_dict_pre_hook_count = 0 + @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_load_state_dict_ref_cycle(self): # load_state_dict shouldn't cause a reference cycle involving Tensors @@ -6252,378 +2737,6 @@ def test_assignments(get_list, a, b, c): self.assertIn('buf', l.state_dict()) self.assertEqual(l.state_dict()['buf'], buf) - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - def test_thnn_conv_strided_padded_dilated(self): - for convfn, dims, transposed in ( - (torch.nn.functional.conv2d, 2, False), - (torch.nn.functional.conv_transpose2d, 2, True), - (torch.nn.functional.conv3d, 3, False), - (torch.nn.functional.conv_transpose3d, 3, True)): - for stride, padding, dilation in ( - (2, 0, 1), (1, 1, 1), (2, 1, 1), (1, 0, 2)): - kwargs = {"stride": stride, "padding": padding, "dilation": dilation} - inp_shape = (1, 2) + dims * (4,) - weight_shape = (2, 2) + dims * (1,) - inputs = torch.randn(inp_shape, dtype=torch.double, device="cuda", requires_grad=True) - weight = torch.randn(weight_shape, dtype=torch.double, device="cuda", requires_grad=True) - bias = torch.randn(2, dtype=torch.double, device="cuda", requires_grad=True) - with torch.backends.cudnn.flags(enabled=False): - res = convfn(inputs, weight, bias, **kwargs) - res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs) - self.assertEqual(res, res_cpu) - with torch.backends.cudnn.flags(enabled=False): - torch.autograd.gradcheck( - lambda x, w, b: convfn(x, w, b, **kwargs), - (inputs, weight, bias) - ) - torch.autograd.gradcheck( - lambda x, w, b: convfn(x, w, b, **kwargs), - (inputs.cpu(), weight.cpu(), bias.cpu()) - ) - - def test_Conv2d_inconsistent_types(self): - inputs = torch.randn(4, 1, 7, 7, dtype=torch.float) - weights = torch.randn(1, 1, 3, 3, dtype=torch.double) - # inconsistent types should raise an exception - self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) - # but it should work with the same type - nn.functional.conv2d(inputs.float(), weights.float()) - - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): - inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") - weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") - bias = torch.randn(1, dtype=torch.double, device="cuda") - - with torch.backends.cudnn.flags(enabled=False): - # inconsistent types should raise an exception - self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) - self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias)) - - # but it should work with the same type - nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) - - def test_Conv2d_1x1(self): - in_channels = 2 - out_channels = 2 - mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double) - input = torch.randn(1, in_channels, 5, 5, requires_grad=True, dtype=torch.double) - for enabled in (False, True): - with torch.backends.mkldnn.flags(enabled=enabled): - gradcheck(F.conv2d, (input, mod.weight)) - - def test_Conv2d_OneDNN(self): - def run_once(group_val=24, dilation=1): - ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32) - weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32) - op = torch.nn.Conv2d( - in_channels=group_val, - out_channels=group_val, - kernel_size=[3, 3], - stride=[2, 2], - padding=[1, 1], - dilation=[dilation, dilation], - groups=group_val, - bias=False, - padding_mode='zeros' - ) - - op.weight.data = weights - res = op(ifm) - grad_in = torch.ones(res.shape, dtype=torch.float32) - res.backward(grad_in) - return op.weight.grad - - for gorup_val in (24, 48, 23, 25): - for dilation in (1, 2): - with torch.backends.mkldnn.flags(enabled=False): - without_onednn = run_once(gorup_val, dilation) - - with torch.backends.mkldnn.flags(enabled=True): - with_onednn = run_once(gorup_val, dilation) - - self.assertEqual(without_onednn, with_onednn) - - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - def test_cudnn_non_contiguous(self): - x = torch.randn(192, 16, 50).cuda() - x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1) - m = torch.nn.Conv1d( - in_channels=16, - out_channels=32, - kernel_size=2, - bias=True).cuda() - result = m(x) - - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') - def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self): - inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") - weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") - bias = torch.randn(1, dtype=torch.double, device="cuda") - - with torch.backends.cudnn.flags(enabled=True): - # inconsistent types should raise an exception - self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) - self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights.float(), bias)) - - # but it should work with the same type - nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) - - def test_Conv2d_missing_argument(self): - c = nn.Conv2d(3, 3, 3) - self.assertRaises(TypeError, lambda: c(None)) - - def test_Conv2d_backward_twice(self): - input = torch.randn(2, 3, 5, 5) - c = nn.Conv2d(3, 3, 3) - o1 = c(input) - o1.sum().backward() - self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', - lambda: o1.sum().backward()) - - - def test_conv_modules_raise_error_on_incorrect_input_size(self): - for dtype in [torch.bfloat16, torch.double, torch.float]: - modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype), - nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype), - nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)] - - invalid_input_dims = [(1, 4), (1, 4), - (2, 5), (2, 5), - (3, 6), (3, 6)] - - for invalid_dims, module in zip(invalid_input_dims, modules): - for dims in invalid_dims: - input = torch.empty(torch.Size((3, ) * dims)) - self.assertRaises(RuntimeError, lambda: module(input)) - - def test_conv_shapecheck(self): - def test(should_raise, module, input_size, dtype): - input = torch.empty(3, *input_size).to(dtype) - if should_raise: - self.assertRaises(RuntimeError, lambda: module(input)) - else: - # just run it to ensure no exception raised. - module(input) - - for dtype in [torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble]: - # Conv1d - test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype) - test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype) - test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype) - test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype) - test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype) - - # Conv2d - test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype) - test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype) - test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype) - - # Conv3D - test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype) - test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype) - test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), (1, 2, 2, 2), dtype) - - def test_ConvTranspose2d_output_size(self): - m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) - i = torch.randn(2, 3, 6, 6) - for h in range(15, 22): - for w in range(15, 22): - if 18 <= h <= 20 and 18 <= w <= 20: - output = m(i, output_size=(h, w)) - self.assertEqual(output.size()[2:], (h, w)) - else: - self.assertRaises(ValueError, lambda: m(i, (h, w))) - - def test_ConvTranspose2d_output_size_downsample_upsample(self): - b, c, hid_c = 2, 3, 2 - for h in range(13, 24): - for w in range(13, 17): - for k in range(2, 5): - for d in range(1, 5): - for s in range(1, 4): - for p in range(3): - conv = nn.Conv2d( - in_channels=c, - out_channels=hid_c, - kernel_size=k, - stride=s, - padding=p, - dilation=d, - ) - - t_conv = nn.ConvTranspose2d( - in_channels=hid_c, - out_channels=c, - kernel_size=k, - stride=s, - padding=p, - dilation=d, - ) - - i = torch.randn(b, c, h, w) - - out = t_conv(conv(i), output_size=i.shape) - - self.assertEqual(out.size()[2:], i.size()[2:]) - - def test_ConvTranspose3d_correct_output_size(self): - # Check that ConvTranspose3d can take a 5d output_size. - m = nn.ConvTranspose3d(2, 2, 2) - i = torch.rand(1, 2, 1, 1, 1) - out = m(i, output_size=(1, 2, 2, 2, 2)) - - @unittest.skipIf(not TEST_CUDA, 'CUDA not available') - def test_ConvTranspose2d_half_cublas_gemm(self): - with torch.backends.cudnn.flags(enabled=False): - inputs = torch.randn(1, 1, 16, 16, device='cuda', dtype=torch.half) - deconv = nn.ConvTranspose2d( - 1, 1, 3, stride=2, padding=1, output_padding=1).cuda().half() - output = deconv(inputs) - output.mean().backward() - - # For https://github.com/pytorch/pytorch/pull/1273 - # Almost identical to the above `test_Conv2d_naive_groups` - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_Conv2d_groups_nobias(self): - dev_dtypes = [("cpu", torch.float)] - if TEST_CUDA: - dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] - if AMPERE_OR_ROCM: - dev_dtypes += [("cuda", torch.bfloat16)] - for device, dtype in dev_dtypes: - m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype) - i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) - output = m(i) - grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) - output.backward(grad_output) - - m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) - m1.weight.data.copy_(m.weight.data[:2]) - i1 = i.data[:, :2].contiguous().requires_grad_(True) - output1 = m1(i1) - output1.backward(grad_output[:, :2].contiguous()) - - m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) - m2.weight.data.copy_(m.weight.data[2:]) - i2 = i.data[:, 2:].contiguous().requires_grad_(True) - output2 = m2(i2) - output2.backward(grad_output[:, 2:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1)) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), - atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0) - - # Almost identical to the above `test_Conv2d_naive_groups` - # Covering special case when group > 1, input-channel / group < 16 and output-channel is multiple of 16 - # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 - # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_Conv2d_groups_nobias_v2(self): - torch.manual_seed(123) - dev_dtypes = [("cpu", torch.float)] - if TEST_CUDA: - dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] - if AMPERE_OR_ROCM: - dev_dtypes += [("cuda", torch.bfloat16)] - for device, dtype in dev_dtypes: - m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype) - i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) - output = m(i) - grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype) - output.backward(grad_output) - - m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) - m1.weight.data.copy_(m.weight.data[:8]) - i1 = i.data[:, :2].contiguous().requires_grad_(True) - output1 = m1(i1) - output1.backward(grad_output[:, :8].contiguous()) - - m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) - m2.weight.data.copy_(m.weight.data[8:]) - i2 = i.data[:, 2:].contiguous().requires_grad_(True) - output2 = m2(i2) - output2.backward(grad_output[:, 8:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1)) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), - atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], rtol=0) - - # CPU-only test for group conv3d fast implementation using bmm - # See: https://github.com/pytorch/pytorch/pull/36355 - def test_Conv3d_groups_nobias(self): - torch.manual_seed(123) - m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float) - i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True) - output = m(i) - grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) - output.backward(grad_output) - - m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) - m1.weight.data.copy_(m.weight.data[:8]) - i1 = i.data[:, :2].contiguous().requires_grad_(True) - output1 = m1(i1) - output1.backward(grad_output[:, :8].contiguous()) - - m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) - m2.weight.data.copy_(m.weight.data[8:]) - i2 = i.data[:, 2:].contiguous().requires_grad_(True) - output2 = m2(i2) - output2.backward(grad_output[:, 8:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1)) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[torch.float], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), - atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float]) - - def test_Conv3d_groups_wbias(self): - torch.manual_seed(123) - m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float) - i = torch.randn(2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True) - output = m(i) - grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) - output.backward(grad_output) - - m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) - m1.weight.data.copy_(m.weight.data[:8]) - m1.bias.data.copy_(m.bias.data[:8]) - i1 = i.data[:, :2].contiguous().requires_grad_(True) - output1 = m1(i1) - output1.backward(grad_output[:, :8].contiguous()) - - m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) - m2.weight.data.copy_(m.weight.data[8:]) - m2.bias.data.copy_(m.bias.data[8:]) - i2 = i.data[:, 2:].contiguous().requires_grad_(True) - output2 = m2(i2) - output2.backward(grad_output[:, 8:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1)) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[torch.float], - rtol=dtype2prec_DONTUSE[torch.float]) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), - atol=dtype2prec_DONTUSE[torch.float], - rtol=dtype2prec_DONTUSE[torch.float]) - self.assertEqual(m.bias.grad.data, - torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), - atol=dtype2prec_DONTUSE[torch.float], rtol=dtype2prec_DONTUSE[torch.float]) - def test_container_copy(self): class Model(nn.Module): def __init__(self): @@ -6711,33 +2824,6 @@ def test_mse_loss_size_warning(self): self.assertEqual(len(w), 1) self.assertIn('Please ensure they have the same size.', str(w[0])) - def test_poisson_nll_loss_reduction_modes(self): - input = torch.tensor([0.5, 1.5, 2.5]) - target = torch.tensor([1., 2., 3.]) - component_wise_loss = torch.exp(input) - target * input - self.assertEqual(component_wise_loss, - F.poisson_nll_loss(input, target, reduction='none')) - self.assertEqual(torch.sum(component_wise_loss), - F.poisson_nll_loss(input, target, reduction='sum')) - self.assertEqual(torch.mean(component_wise_loss), - F.poisson_nll_loss(input, target, reduction='mean')) - with self.assertRaisesRegex(ValueError, 'is not valid'): - F.poisson_nll_loss(input, target, reduction='total') - - def test_gaussian_nll_loss_reduction_modes(self): - input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]]) - target = torch.tensor([[1., 2., 3.], [4., 5., 6.]]) - var = torch.tensor([[0.5, 1., 1.5], [1., 1.5, 2.]]) - component_wise_loss = 0.5 * (torch.log(var) + (input - target)**2 / var) - self.assertEqual(component_wise_loss, - F.gaussian_nll_loss(input, target, var, reduction='none')) - self.assertEqual(torch.sum(component_wise_loss), - F.gaussian_nll_loss(input, target, var, reduction='sum')) - self.assertEqual(torch.mean(component_wise_loss), - F.gaussian_nll_loss(input, target, var, reduction='mean')) - with self.assertRaisesRegex(ValueError, 'is not valid'): - F.gaussian_nll_loss(input, target, var, reduction='total') - def test_gaussian_nll_loss_broadcasting(self): input = torch.tensor([[0.5, 1.5, 2.5], [2., 4., 6.]]) target_full = torch.tensor([[1., 2., 3.], [1., 2., 3.]]) @@ -8917,11 +5003,12 @@ def helper(self, size, dtype, mixed_dtype=False): helper(self, shape, torch.bfloat16, False) helper(self, shape, torch.bfloat16, True) - def test_batchnorm_non_contig_cpu(self): + @parametrize_test('bn_module', [torch.nn.BatchNorm2d, torch.nn.SyncBatchNorm]) + def test_batchnorm_non_contig_cpu(self, bn_module): input = torch.arange(6, dtype=torch.float).reshape(1, 3, 2, 1).cpu() input = input.permute(0, 2, 1, 3) - bn = torch.nn.BatchNorm2d(2).cpu().float().eval() + bn = bn_module(2).cpu().float().eval() bn.weight.data.uniform_() bn.bias.data.uniform_() @@ -8939,7 +5026,7 @@ def test_batchnorm_non_contig_cpu(self): input_bf = torch.arange(24, dtype=torch.bfloat16).reshape(1, 3, 2, 4) input_bf = input_bf.permute(0, 2, 1, 3) input_f = input_bf.float() - bn_mix = torch.nn.BatchNorm2d(2).float().eval() + bn_mix = bn_module(2).float().eval() ref_bn_f = deepcopy(bn_mix) out_bf = bn_mix(input_bf) ref_out_bf = ref_bn_f(input_f) @@ -10343,12 +6430,10 @@ def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_form outf = m(inputf) out = m(input) - self.assertEqual(out.dtype, dtype) - self.assertEqualIgnoreType(out, outf, atol=0.1, rtol=0.0) + self.assertEqual(out, outf.to(dtype), atol=0.1, rtol=0.0) out.sum().backward() outf.sum().backward() - self.assertEqual(input.grad.dtype, dtype) self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0) for device in ['cpu']: @@ -10368,7 +6453,7 @@ def test_interpolate_illegal_memory_access(self): input = torch.ones((1, 1, in_s), device='cuda', requires_grad=True) # note we allocated grad_output to be larger so out of bound access - # woudl be visible in grad_input + # would be visible in grad_input grad = torch.ones((1, 1, out_s * 2), device='cuda', requires_grad=True) grad = grad[:, :, :out_s] @@ -10384,6 +6469,67 @@ def test_interpolate_illegal_memory_access(self): self.assertEqual(out_ref, out) self.assertEqual(input_ref.grad, input.grad) + def test_interpolate_buffer_overflow(self): + # Test buffer overflow issue due to inaccurate floating point + # representation for integer values. See issue below for details. + # https://github.com/pytorch/pytorch/issues/88939 + + def helper(size, dtype, mode, device, is_channels_last): + input = torch.ones(size, dtype=dtype, device=device) + if is_channels_last: + if len(size) == 3: + input = input.transpose(1, 2).contiguous().transpose(1, 2) + elif len(size) == 4: + input = input.to(memory_format=torch.channels_last) + else: + input = input.to(memory_format=torch.channels_last_3d) + output1 = F.interpolate(input, 2, mode=mode, align_corners=True) + # reset the corner value and expect the output is changed as well + # the output won't be changed on buffer overflow + input[(-1,) * len(size)] = 0.5 + output2 = F.interpolate(input, 2, mode=mode, align_corners=True) + self.assertNotEqual(output1, output2) + + size_dtype_list = [] + # We set the size larger than the floating point exactly representable range + # float: exact representable range (-2**24,2**24) + size_dtype_list.append(([1, 10, 2**24 + 4], torch.float)) + size_dtype_list.append(([1, 10, 2, 2**24 + 4], torch.float)) + size_dtype_list.append(([1, 10, 2, 2, 2**24 + 4], torch.float)) + # bfloat16: exact representable range (-2**8, 2**8) + size_dtype_list.append(([1, 10, 2**8 + 4], torch.bfloat16)) + size_dtype_list.append(([1, 10, 2, 2**8 + 4], torch.bfloat16)) + size_dtype_list.append(([1, 10, 2, 2, 2**8 + 4], torch.bfloat16)) + # half: exact representable range (-2**11, 2**11) + size_dtype_list.append(([1, 10, 2**11 + 4], torch.half)) + size_dtype_list.append(([1, 10, 2, 2**11 + 4], torch.half)) + size_dtype_list.append(([1, 10, 2, 2, 2**11 + 4], torch.half)) + + # TODO: turn on cuda test after buffer overflow issue is fixed in cuda kernel + # devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else []) + devices = ['cpu'] + + for mode in ('linear', 'bilinear', 'bicubic', 'trilinear'): + for size_dtype in size_dtype_list: + size, dtype = size_dtype + if ( + mode == 'linear' and len(size) != 3 + or (mode == 'bilinear' and len(size) != 4) + or (mode == 'bicubic' and len(size) != 4) + or (mode == 'trilinear' and len(size) != 5) + ): + continue + for device in devices: + if ( + device == 'cpu' and dtype == torch.half + or (device == 'cuda' and dtype == torch.bfloat16) + ): + # no half precision support on cpu or bfloat16 on cuda yet + continue + for is_channels_last in (True, False): + helper(size, dtype, mode, device, is_channels_last) + + def test_interpolate(self): def _test_interpolate_helper(in_t, scale_factor, layer): out_size = int(math.floor(in_t.shape[-1] * scale_factor)) @@ -10562,140 +6708,6 @@ def test_bilinear_broadcasting(self): expected = m(input1.view(6, 5), input2.view(6, 6)).view(2, 3, 8) self.assertEqual(expected, m(input1, input2)) - def test_conv_tbc(self): - inp = torch.randn(9, 4, 5, requires_grad=True) - weight = torch.randn(3, 5, 6, requires_grad=True) - bias = torch.randn(6, requires_grad=True) - - gradcheck(lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3)) - - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - @unittest.skipIf(not TEST_CUDNN, "needs cudnn") - @skipIfRocmVersionLessThan((4, 3)) - @skipIfNotMiopenSuggestNHWC - def test_grouped_conv_cudnn_nhwc_support(self): - # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version - input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) - weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) - out = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4) - input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to(memory_format=torch.channels_last) - out_transpose = torch.convolution(input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4) - - @unittest.expectedFailure - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - @unittest.skipIf(not TEST_CUDNN, "needs cudnn") - def test_conv_cudnn_memory_layout_dominance(self): - # desired behavior here is to have the memory_layout of conv.weight to - # dominante the layout of output. - # which is not the same as current behavior, we'll fix this in - # following up PRs and remove the `expectedFailure` tag - input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True) - conv = nn.Conv2d(8, 4, 3).cuda().float() - - out = conv(input) - self.assertTrue(out.is_contiguous()) - - input = input.contiguous(memory_format=torch.channels_last) - out = conv(input) - self.assertTrue(out.is_contiguous()) - - conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last) - out = conv(input) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - - input = input.contiguous() - out = conv(input) - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_cudnn_noncontiguous_weight(self): - # Noncontiguous weights must be contiguous() before being - # passed to cuDNN - input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3) - weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2) - weights2 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2).contiguous() - self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2), - F.conv1d(input, weights2, bias=None, stride=2, dilation=2)) - - - def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'): - for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: - for batch, stride, padding, chan_in, chan_out, dilation in \ - product([1, 2], [1, 2], [0, 1, 2], [2], [3], [1]): - - for has_bias in [True, False]: - input_shape = [batch, chan_in] - weight_shape = [chan_out, chan_in] - for _ in range(dim): - input_shape.append(inp_size) - weight_shape.append(kern) - - input = torch.randn(input_shape, requires_grad=True) - weight = torch.randn(weight_shape, requires_grad=True) - if has_bias: - bias = torch.randn([chan_out], requires_grad=True) - output = func_forward(input, weight, stride=stride, padding=padding, dilation=dilation, bias=bias) - - gradient_o = torch.randn(output.shape) - gradient_w = torch.autograd.grad(output, input if (gradient == 'input') else weight, gradient_o) - - self.assertEqual(gradient_w[0], - func_backward( - input_shape if (gradient == 'input') else input, - weight_shape if (gradient == 'weight') else weight, - gradient_o, - stride=stride, - padding=padding, - dilation=dilation)) - - def test_grad_conv1d_input(self): - self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, 'input') - - def test_grad_conv1d_weight(self): - self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, 'weight') - - def test_grad_conv2d_input(self): - self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, 'input') - - def test_grad_conv2d_weight(self): - self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, 'weight') - - def test_grad_conv3d_input(self): - self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, 'input') - - def test_grad_conv3d_weight(self): - self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, 'weight') - - @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable") - def test_nnpack_conv(self): - for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: - for batch, stride, padding, chan_in, chan_out in \ - product([1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3]): - - for has_bias in [True, False]: - input_shape = [batch, chan_in] - weight_shape = [chan_out, chan_in] - for _ in range(2): - input_shape.append(inp_size) - weight_shape.append(kern) - - input = torch.randn(input_shape, requires_grad=True, dtype=torch.float) - weight = torch.randn(weight_shape, requires_grad=True, dtype=torch.float) - if has_bias: - bias = torch.randn([chan_out], requires_grad=True, dtype=torch.float) - output = torch._nnpack_spatial_convolution(input, weight, stride=stride, padding=padding, bias=bias) - output_expected = torch.nn.functional.conv2d(input, weight, stride=stride, padding=padding, bias=bias) - self.assertEqual(output, output_expected, atol=3e-4, rtol=0) - - gradient_o = torch.randn(output.shape, dtype=torch.float) - - grads = torch.autograd.grad(output, [input, weight], gradient_o) - grads_expected = torch.autograd.grad(output_expected, [input, weight], gradient_o) - for gr, gr_expected in zip(grads, grads_expected): - self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0) - def test_fold_invalid_arg(self): # input.size(1) not divisible by \prod(kernel_size) @@ -10742,16 +6754,6 @@ def test_unfold_invalid_arg(self): unfold = nn.Unfold(kernel_size=(1, 3), padding=(1, 1), dilation=(1, 2)) unfold(torch.randn(1, 2, 2, 2)) - def test_conv_padding_mode(self): - with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): - nn.Conv2d(3, 3, 3, padding_mode="xyz") - - with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): - nn.Conv2d(3, 3, 3, padding_mode=3) - - with self.assertRaisesRegex(ValueError, "Only \"zeros\" "): - nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect") - def test_softmin(self): x = torch.randn(2, 16) self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1)) @@ -10763,12 +6765,10 @@ def test_log_softmax_cpu(self, dtype=torch.bfloat16): input = inputf.to(dtype).detach().requires_grad_(True) outf = F.log_softmax(inputf, dim=dim) out = F.log_softmax(input, dim=dim) - self.assertEqual(out.dtype, dtype) self.assertEqual(out, outf.to(dtype=dtype), atol=0.1, rtol=0) out.sum().backward() outf.sum().backward() - self.assertEqual(input.grad.dtype, dtype) self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0.1, rtol=0) def test_softmax_cpu(self, dtype=torch.bfloat16): @@ -10777,12 +6777,10 @@ def test_softmax_cpu(self, dtype=torch.bfloat16): input = inputf.to(dtype).detach().requires_grad_(True) outf = F.softmax(inputf, dim=dim) out = F.softmax(input, dim=dim) - self.assertEqual(out.dtype, dtype) - self.assertEqualIgnoreType(out, outf, atol=1e-3, rtol=0) + self.assertEqual(out, outf.to(dtype), atol=1e-3, rtol=0) out.sum().backward() outf.sum().backward() - self.assertEqual(input.grad.dtype, dtype) self.assertEqual(input.grad, inputf.grad.to(dtype), atol=1e-3, rtol=0) def test_adaptive_log_softmax(self): @@ -10893,12 +6891,10 @@ def test_cross_entropy_loss(self, dtype=torch.bfloat16): outf = loss_cpu(inputf, target) out = loss_cpu(input, target) - self.assertEqual(out.dtype, dtype) self.assertEqual(out, outf.to(dtype=dtype), atol=1e-1, rtol=0) outf.backward() out.backward() - self.assertEqual(input.grad.dtype, dtype) self.assertEqual(input.grad, inputf.grad.to(dtype=dtype), atol=1e-1, rtol=0) def test_cross_entropy_loss_precision(self): @@ -10993,95 +6989,16 @@ def test_sync_batchnorm_accuracy_cuda(self): # fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt # bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt - def _batch_norm_stats(data): + def _batch_norm_stats(data, memory_format, mean_axes): mean1, _ = torch.batch_norm_stats(data, 1e-5) - mean2, _ = torch.batch_norm_stats(data.to(memory_format=torch.channels_last), 1e-5) - mean_ref = torch.mean(data, (0, 2, 3), keepdim=False) + mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5) + mean_ref = torch.mean(data, mean_axes, keepdim=False) self.assertEqual(mean_ref, mean1) self.assertEqual(mean_ref, mean2) - data = torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda') - _batch_norm_stats(data) - - def test_functional_grad_conv(self): - # Conv 1D - input = torch.randn(1, 1, 5, requires_grad=True) - weight = torch.randn(1, 1, 3, requires_grad=True) - output = F.conv1d(input, weight, dilation=2) - grad_output = torch.randn(output.shape) - - grad_input_autograd, grad_weight_autograd = torch.autograd.grad(output, (input, weight), grad_output) - - grad_input_functional = torch.nn.grad.conv1d_input(input.shape, weight, grad_output, dilation=2) - self.assertEqual(grad_input_functional, grad_input_autograd) - - grad_weight_functional = torch.nn.grad.conv1d_weight(input, weight.shape, grad_output, dilation=2) - self.assertEqual(grad_weight_functional, grad_weight_autograd) - - # Conv 2D - input = torch.randn(1, 1, 5, 5, requires_grad=True) - weight = torch.randn(1, 1, 3, 3, requires_grad=True) - output = F.conv2d(input, weight, dilation=2) - grad_output = torch.randn(output.shape) - - (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) - - grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, dilation=2) - self.assertEqual(grad_input_functional, grad_input_autograd) - - grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, dilation=2) - self.assertEqual(grad_weight_functional, grad_weight_autograd) - - # Conv 3D - input = torch.randn(1, 1, 5, 5, 5, requires_grad=True) - weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True) - output = F.conv3d(input, weight, dilation=2) - grad_output = torch.randn(output.shape) - - (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) - - grad_input_functional = torch.nn.grad.conv3d_input(input.shape, weight, grad_output, dilation=2) - self.assertEqual(grad_input_functional, grad_input_autograd) - - grad_weight_functional = torch.nn.grad.conv3d_weight(input, weight.shape, grad_output, dilation=2) - self.assertEqual(grad_weight_functional, grad_weight_autograd) - - def test_functional_grad_conv2d(self): - BATCH_SIZE = 4 - IN_CH = 8 - OUT_CH = 16 - SPATIAL = 32 - - def _test_conv2d(stride, kernel_size, groups, dilation): - padding = kernel_size // 2 - - input = torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL).uniform_(-8.0, 8.0).requires_grad_(True) - - weight = torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size).uniform_(-4.0, 4.0).requires_grad_(True) - - output = F.conv2d(input, weight, - stride=stride, padding=padding, dilation=dilation, groups=groups) - - grad_output = torch.randn(output.shape) - - (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad(output, (input, weight), grad_output) - - grad_input_functional = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, - stride=stride, padding=padding, dilation=dilation, groups=groups) - self.assertEqual(grad_input_functional, grad_input_autograd) - - grad_weight_functional = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, - stride=stride, padding=padding, dilation=dilation, groups=groups) - self.assertEqual(grad_weight_functional, grad_weight_autograd) - - strides = [1, 2] - kernel_sizes = [1, 3, 5] - groups = [1, 2, 4] - dilates = [1, 2] - - for s, k, g, d in product(strides, kernel_sizes, groups, dilates): - _test_conv2d(s, k, g, d) + _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3)) + _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4)) def test_flatten(self): tensor_input = torch.randn(2, 1, 2, 3) @@ -11171,411 +7088,6 @@ def test_padding_list(self): y = net(x) -class TestNNInit(TestCase): - def setUp(self): - super(TestNNInit, self).setUp() - random.seed(123) - - def _is_normal(self, tensor, mean, std): - samples = tensor.view(-1).tolist() - p_value = stats.kstest(samples, 'norm', args=(mean, std))[1] - return p_value > 0.0001 - - def _is_trunc_normal(self, tensor, mean, std, a, b): - # scipy's trunc norm is suited for data drawn from N(0, 1), - # so we need to transform our data to test it using scipy. - z_samples = (tensor.view(-1) - mean) / std - z_samples = z_samples.tolist() - a0 = (a - mean) / std - b0 = (b - mean) / std - p_value = stats.kstest(z_samples, 'truncnorm', args=(a0, b0))[1] - return p_value > 0.0001 - - def _is_uniform(self, tensor, a, b): - samples = tensor.view(-1).tolist() - p_value = stats.kstest(samples, 'uniform', args=(a, (b - a)))[1] - return p_value > 0.0001 - - def _create_random_nd_tensor(self, dims, size_min, size_max): - size = [random.randint(size_min, size_max) for _ in range(dims)] - tensor = torch.zeros(size) - return tensor - - def _random_float(self, a, b): - return (b - a) * random.random() + a - - def test_calculate_gain_linear(self): - for fn in ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose2d', 'conv_transpose2d', 'conv_transpose3d']: - gain = init.calculate_gain(fn) - self.assertEqual(gain, 1) - - def test_calculate_gain_nonlinear(self): - for fn in ['sigmoid', 'tanh', 'relu', 'leaky_relu']: - gain = init.calculate_gain(fn) - if fn == 'sigmoid': - self.assertEqual(gain, 1) - elif fn == 'tanh': # 5 / 3 - self.assertEqual(gain, 1.6666666666666667) - elif fn == 'relu': # sqrt(2) - self.assertEqual(gain, 1.4142135623730951) - elif fn == 'leaky_relu': # sqrt(2 / 1 + slope^2)) - self.assertEqual(gain, 1.4141428569978354) - elif fn == 'selu': - self.assertEqual(gain, 0.75) - - def test_calculate_gain_leaky_relu(self): - for param in [None, 0, 0.01, 10]: - gain = init.calculate_gain('leaky_relu', param) - if param is None: # Default slope is 0.01 - self.assertEqual(gain, 1.4141428569978354) - elif param == 0: # No slope = same gain as normal ReLU - self.assertEqual(gain, 1.4142135623730951) - elif param == 0.01: - self.assertEqual(gain, 1.4141428569978354) - elif param == 10: - self.assertEqual(gain, 0.14071950894605836) - - def test_calculate_gain_leaky_relu_only_accepts_numbers(self): - for param in [True, [1], {'a': 'b'}]: - with self.assertRaises(ValueError): - init.calculate_gain('leaky_relu', param) - - def test_calculate_gain_only_accepts_valid_nonlinearities(self): - for n in [2, 5, 25]: - # Generate random strings of lengths that definitely aren't supported - random_string = ''.join([random.choice(string.ascii_lowercase) for i in range(n)]) - with self.assertRaises(ValueError): - init.calculate_gain(random_string) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_uniform(self): - for dims in [1, 2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) - a = self._random_float(-3, 3) - b = a + self._random_float(1, 5) - init.uniform_(input_tensor, a=a, b=b) - assert self._is_uniform(input_tensor, a, b) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_normal(self): - for dims in [1, 2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) - mean = self._random_float(-3, 3) - std = self._random_float(1, 5) - init.normal_(input_tensor, mean=mean, std=std) - - assert self._is_normal(input_tensor, mean, std) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_trunc_normal(self): - for dims in [1, 2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=30, size_max=50) - mean = self._random_float(-3, 3) - std = self._random_float(.01, 1) - a = self._random_float(mean - 2 * std, mean) - b = self._random_float(mean, mean + 2 * std) - init.trunc_normal_(input_tensor, mean=mean, std=std, a=a, b=b) - - assert self._is_trunc_normal(input_tensor, mean, std, a, b) - - def test_constant(self): - for dims in [1, 2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) - val = self._random_float(1, 10) - init.constant_(input_tensor, val) - - self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) - - def test_ones_and_zeros(self): - for init_fn_, val in zip([init.ones_, init.zeros_], [1, 0]): - for dims in [1, 2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=5) - init_fn_(input_tensor) - - self.assertEqual(input_tensor, input_tensor.clone().fill_(val)) - - def test_eye(self): - input_tensor = self._create_random_nd_tensor(2, size_min=1, size_max=5) - init.eye_(input_tensor) - - # Check every single element - for i in range(input_tensor.size(0)): - for j in range(input_tensor.size(1)): - if i == j: - assert input_tensor[i][j] == 1 - else: - assert input_tensor[i][j] == 0 - - def test_eye_only_works_on_2d_inputs(self): - for dims in [1, 3]: - with self.assertRaises(ValueError): - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) - init.eye_(tensor) - - def test_dirac_properties(self): - for dims in [3, 4, 5]: - for groups in [1, 2, 3]: - # prepare random tensor with random sizes, but fits groups - a, c, d, e = (random.randint(1, 5) for _ in range(4)) - b = random.randint(1, 5 * groups) # same range as a*groups but all range allowed - # make sure first dim divides by groups - input_tensor = torch.randn((a * groups, b, c, d, e)[:dims]) - - init.dirac_(input_tensor, groups) - - c_out, c_in = input_tensor.size(0) // groups, input_tensor.size(1) - min_d = min(c_out, c_in) - # Check number of nonzeros is equivalent to smallest dim (for each group) - assert torch.nonzero(input_tensor).size(0) == min_d * groups - # Check sum of values (can have precision issues, hence assertEqual) is also equivalent - self.assertEqual(input_tensor.sum(), min_d * groups) - - - def test_dirac_identity(self): - for groups in [1, 3]: - batch, in_c, out_c, size, kernel_size = 8, 3, 9, 5, 3 # in_c, out_c must divide by groups - eff_out_c = out_c // groups - - # Test 1D - input_var = torch.randn(batch, in_c, size) - filter_var = torch.zeros(eff_out_c, in_c, kernel_size) - filter_var = torch.cat([filter_var] * groups) - init.dirac_(filter_var, groups) - output_var = F.conv1d(input_var, filter_var) - input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero - for g in range(groups): - # Assert in_c outputs are preserved (per each group) - self.assertEqual(input_tensor[:, :, 1:-1], - output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :]) - # Assert extra outputs are 0 - assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :]).numel() == 0 - - # Test 2D - input_var = torch.randn(batch, in_c, size, size) - filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size) - filter_var = torch.cat([filter_var] * groups) - init.dirac_(filter_var, groups) - output_var = F.conv2d(input_var, filter_var) - input_tensor, output_tensor = input_var.data, output_var.data # Variables do not support nonzero - for g in range(groups): - # Assert in_c outputs are preserved (per each group) - self.assertEqual(input_tensor[:, :, 1:-1, 1:-1], - output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :, :]) - # Assert extra outputs are 0 - assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :, :]).numel() == 0 - - # Test 3D - input_var = torch.randn(batch, in_c, size, size, size) - filter_var = torch.zeros(eff_out_c, in_c, kernel_size, kernel_size, kernel_size) - filter_var = torch.cat([filter_var] * groups) - init.dirac_(filter_var, groups) - output_var = F.conv3d(input_var, filter_var) - input_tensor, output_tensor = input_var.data, output_var.data - for g in range(groups): - # Assert in_c outputs are preserved (per each group) - self.assertEqual(input_tensor[:, :, 1:-1, 1:-1, 1:-1], - output_tensor[:, eff_out_c * g:eff_out_c * g + in_c, :, :, :]) - # Assert extra outputs are 0 - assert torch.nonzero(output_tensor[:, eff_out_c * g + in_c:eff_out_c * (g + 1), :, :, :]).numel() == 0 - - def test_dirac_only_works_on_3_4_5d_inputs(self): - for dims in [1, 2, 6]: - with self.assertRaises(ValueError): - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) - init.dirac_(tensor) - - def test_xavier_uniform_errors_on_inputs_smaller_than_2d(self): - for dims in [0, 1]: - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) - with self.assertRaises(ValueError): - init.xavier_uniform_(tensor) - - def test_xavier_normal_errors_on_inputs_smaller_than_2d(self): - for dims in [0, 1]: - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) - with self.assertRaises(ValueError): - init.xavier_normal_(tensor) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_xavier_uniform(self): - for use_gain in [True, False]: - for dims in [2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) - gain = 1 - - if use_gain: - gain = self._random_float(0.1, 2) - init.xavier_uniform_(input_tensor, gain=gain) - else: - init.xavier_uniform_(input_tensor) - - fan_in = input_tensor.size(1) - fan_out = input_tensor.size(0) - if input_tensor.dim() > 2: - fan_in *= input_tensor[0, 0].numel() - fan_out *= input_tensor[0, 0].numel() - - expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) - bounds = expected_std * math.sqrt(3) - assert self._is_uniform(input_tensor, -bounds, bounds) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_xavier_normal(self): - for use_gain in [True, False]: - for dims in [2, 4]: - input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) - gain = 1 - - if use_gain: - gain = self._random_float(0.1, 2) - init.xavier_normal_(input_tensor, gain=gain) - else: - init.xavier_normal_(input_tensor) - - fan_in = input_tensor.size(1) - fan_out = input_tensor.size(0) - if input_tensor.dim() > 2: - fan_in *= input_tensor[0, 0].numel() - fan_out *= input_tensor[0, 0].numel() - - expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out)) - assert self._is_normal(input_tensor, 0, expected_std) - - def test_kaiming_uniform_errors_on_inputs_smaller_than_2d(self): - for dims in [0, 1]: - with self.assertRaises(ValueError): - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) - init.kaiming_uniform_(tensor) - - def test_kaiming_normal_errors_on_inputs_smaller_than_2d(self): - for dims in [0, 1]: - with self.assertRaises(ValueError): - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=1) - init.kaiming_normal_(tensor) - - def test_kaiming_uniform_warning_on_0element_tensor(self): - tensor = torch.empty(0, 1) - with self.assertWarnsRegex(UserWarning, "Initializing zero-element tensors is a no-op"): - _ = init.kaiming_uniform_(tensor) - - def test_kaiming_normal_warning_on_0element_tensor(self): - tensor = torch.empty(0, 1) - with self.assertWarnsRegex(UserWarning, "Initializing zero-element tensors is a no-op"): - _ = init.kaiming_normal_(tensor) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_kaiming_uniform(self): - for use_a in [True, False]: - for dims in [2, 4]: - for mode in ['fan_in', 'fan_out']: - input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) - if use_a: - a = self._random_float(0.1, 2) - init.kaiming_uniform_(input_tensor, a=a, mode=mode) - else: - a = 0 - init.kaiming_uniform_(input_tensor, mode=mode) - - fan_in = input_tensor.size(1) - fan_out = input_tensor.size(0) - if input_tensor.dim() > 2: - fan_in *= input_tensor[0, 0].numel() - fan_out *= input_tensor[0, 0].numel() - - if mode == 'fan_in': - n = fan_in - else: - n = fan_out - - expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) - bounds = expected_std * math.sqrt(3.0) - assert self._is_uniform(input_tensor, -bounds, bounds) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_kaiming_normal(self): - for use_a in [True, False]: - for dims in [2, 4]: - for mode in ['fan_in', 'fan_out']: - input_tensor = self._create_random_nd_tensor(dims, size_min=20, size_max=25) - if use_a: - a = self._random_float(0.1, 2) - init.kaiming_normal_(input_tensor, a=a, mode=mode) - else: - a = 0 - init.kaiming_normal_(input_tensor, mode=mode) - - fan_in = input_tensor.size(1) - fan_out = input_tensor.size(0) - if input_tensor.dim() > 2: - fan_in *= input_tensor[0, 0].numel() - fan_out *= input_tensor[0, 0].numel() - - if mode == 'fan_in': - n = fan_in - else: - n = fan_out - - expected_std = math.sqrt(2.0 / ((1 + a**2) * n)) - assert self._is_normal(input_tensor, 0, expected_std) - - def test_sparse_only_works_on_2d_inputs(self): - for dims in [1, 3]: - with self.assertRaises(ValueError): - sparsity = self._random_float(0.1, 0.9) - tensor = self._create_random_nd_tensor(dims, size_min=1, size_max=3) - init.sparse_(tensor, sparsity) - - @unittest.skipIf(not TEST_SCIPY, "Scipy not found.") - def test_sparse_default_std(self): - for use_random_std in [True, False]: - input_tensor = self._create_random_nd_tensor(2, size_min=30, size_max=35) - rows, cols = input_tensor.size(0), input_tensor.size(1) - sparsity = self._random_float(0.1, 0.2) - - std = 0.01 # default std - if use_random_std: - std = self._random_float(0.01, 0.2) - init.sparse_(input_tensor, sparsity=sparsity, std=std) - else: - init.sparse_(input_tensor, sparsity=sparsity) - - for col_idx in range(input_tensor.size(1)): - column = input_tensor[:, col_idx] - assert column[column == 0].nelement() >= math.ceil(sparsity * rows) - - assert self._is_normal(input_tensor[input_tensor != 0], 0, std) - - @skipIfNoLapack - def test_orthogonal(self): - for use_gain in [True, False]: - for tensor_size in [[3, 4], [4, 3], [20, 2, 3, 4], [2, 3, 4, 5]]: - input_tensor = torch.zeros(tensor_size) - gain = 1.0 - - if use_gain: - gain = self._random_float(0.1, 2) - init.orthogonal_(input_tensor, gain=gain) - else: - init.orthogonal_(input_tensor) - - rows, cols = tensor_size[0], reduce(mul, tensor_size[1:]) - flattened_tensor = input_tensor.view(rows, cols) - if rows > cols: - self.assertEqual(torch.mm(flattened_tensor.t(), flattened_tensor), - torch.eye(cols) * gain ** 2, atol=1e-6, rtol=0) - else: - self.assertEqual(torch.mm(flattened_tensor, flattened_tensor.t()), - torch.eye(rows) * gain ** 2, atol=1e-6, rtol=0) - - def test_deprecation(self): - x = torch.randn(3, 3) - - def fn(): - init.normal(x) - - with self.assertWarnsRegex(UserWarning, 'deprecated', msg='methods not suffixed with underscore should be deprecated'): - fn() - class TestFusionEval(TestCase): @given(X=hu.tensor(shapes=((5, 3, 5, 5),)), running_mean=hu.tensor(shapes=(6,)), @@ -12083,50 +7595,6 @@ def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_ra class TestNNDeviceType(NNTestCase): - def run_conv_double_back_test(self, kern, stride, padding, chan_in, chan_out, batch_size, - inp_size, dilation, no_weight, groups=1, use_cuda=False, - use_bias=True, dtype=torch.double): - if use_cuda: - device = torch.device("cuda") - else: - device = torch.device("cpu") - - x = torch.randn(batch_size, chan_in, inp_size, inp_size, device=device, - dtype=dtype, requires_grad=True) - weight = torch.randn(chan_out, chan_in // groups, kern, kern, device=device, - dtype=dtype, requires_grad=not no_weight) - if use_bias: - bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True) - else: - bias = None - - def func(*inputs): - if use_bias: - lx, lweight, lbias = inputs - else: - lx, lweight = inputs - lbias = None - # We disable cudnn during forward to avoid finite difference imprecision issues - with cudnn.flags(enabled=False): - out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups) - return out - - if use_bias: - inputs = x, weight, bias - else: - inputs = x, weight - - dummy_out = func(*inputs) - grad_y = torch.randn_like(dummy_out, device=device, dtype=dtype, requires_grad=True) - - # Issue #15353: test mkldnn double backward, don't run gradgradcheck due - # to imprecision issues - if dtype == torch.float: - g, = torch.autograd.grad(dummy_out.sum(), x, create_graph=True) - return g.requires_grad - - return gradgradcheck(func, inputs, (grad_y,)) - def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float): # default case track_running_stats=False b, c = input.size(0), input.size(1) @@ -12251,6 +7719,17 @@ def _test_LayerNorm_cuda_half(self, device): output.sum().backward() self.assertEqualTypeString(output, input) + def _test_LayerNorm_cpu_mixed_dtype(self, device): + for elementwise_affine in [True, False]: + # layer norm input shape is normalized to m x n, cpu vectorized on n, + # so make sure n exceeds vector length + input = torch.empty(2, 3, 11, 3, device=device, dtype=torch.bfloat16).random_(1, 10) + m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, torch.bfloat16) + m2 = deepcopy(m).to(device, torch.float) + out = m(input) + out2 = m2(input) + self.assertEqual(out, out2) + def _test_GroupNorm_general(self, device, dtype=torch.float): good_shape_g = { (1, 2, 3, 4): 2, @@ -12555,177 +8034,9 @@ def test_affine_3d_rotateRandom(self, device): for r in range(affine_tensor.size(2)): for c in range(affine_tensor.size(3)): grid_out = np.dot(grid_ary, [i, r, c, 1]) - self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3], exact_dtype=False) - - self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) - - - @onlyCUDA - @skipCUDAIfNoCudnn - @dtypes(*floating_and_complex_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) - def test_Conv2d_deterministic_cudnn(self, device, dtype): - inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True) - with cudnn.flags(enabled=True, benchmark=True, deterministic=True): - conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) - conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) - conv2.bias.data.copy_(conv1.bias.data) - conv2.weight.data.copy_(conv1.weight.data) - out1 = conv1(inputs) - out2 = conv2(inputs) - self.assertEqual(out1, out2, atol=0.0, rtol=0) - y = torch.randn(out1.size(), device=device, dtype=dtype) - out1.backward(y) - out2.backward(y) - self.assertEqual(conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0) - self.assertEqual(conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0) - - - @onlyCUDA - @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) - def test_Conv2d_large_workspace(self, device, dtype): - # These sizes require huge cuDNN workspaces. Make sure we choose a - # reasonable algorithm that does not run out of memory - sizes = [ - (1, 256, 109, 175), - (1, 256, 80, 128), - (1, 256, 120, 192), - ] - - def run_test(benchmark): - with torch.backends.cudnn.flags(benchmark=benchmark): - conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to(device, dtype) - for size in sizes: - x = torch.randn(size, device=device, dtype=dtype) - out = conv(x.detach().clone().requires_grad_()) - out.backward(torch.ones_like(out)) - - run_test(benchmark=False) - run_test(benchmark=True) - - - @onlyCUDA - @dtypes(torch.half, torch.float) - def test_ConvTranspose2d_large_output_padding(self, device, dtype): - net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ - .to(device=device, dtype=dtype) - net2 = torch.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)\ - .to(device=device, dtype=dtype) - net3 = torch.nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=1)\ - .to(device=device, dtype=dtype) - x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True) - x = net1(x) - x = net2(x) - x = net3(x) - x.backward(torch.randn_like(x)) - torch.cuda.synchronize() - - - @onlyCUDA - @tf32_on_and_off(0.01) - @dtypes(torch.float, torch.double, torch.half) - # Very similar to test_Conv2d_naive_groups but with special care to handle - # the number of groups == number of input channels - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_Conv2d_depthwise_naive_groups(self, device, dtype): - for depth_multiplier in [1, 2]: - m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) - i = torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() - output = m(i) - grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) / 2 - output.backward(grad_output) - - offset = 1 * depth_multiplier - - m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) - m1.weight.data = m.weight.data[:offset].clone() - m1.bias.data = m.bias.data[:offset].clone() - i1 = i.detach()[:, :1].clone().requires_grad_() - output1 = m1(i1) - output1.backward(grad_output[:, :offset].contiguous()) - - m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) - m2.weight.data.copy_(m.weight.data[offset:]) - m2.bias.data.copy_(m.bias.data[offset:]) - i2 = i.detach()[:, 1:].clone().requires_grad_() - output2 = m2(i2) - output2.backward(grad_output[:, offset:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.bias.grad.data, - torch.cat([m1.bias.grad.data, - m2.bias.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, - m2.weight.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - - @onlyCUDA - @dtypes(torch.float, torch.double, torch.half) - @tf32_on_and_off(0.005) - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_Conv3d_depthwise_naive_groups(self, device, dtype): - for depth_multiplier in [1, 2]: - m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to(device, dtype) - i = torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype).div_(2).requires_grad_() - output = m(i) - grad_output = torch.randn(2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype) / 2 - output.backward(grad_output) - - offset = 1 * depth_multiplier - - m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) - m1.weight.data = m.weight.data[:offset].clone() - m1.bias.data = m.bias.data[:offset].clone() - i1 = i.detach()[:, :1].clone().requires_grad_() - output1 = m1(i1) - output1.backward(grad_output[:, :offset].contiguous()) - - m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) - m2.weight.data.copy_(m.weight.data[offset:]) - m2.bias.data.copy_(m.bias.data[offset:]) - i2 = i.detach()[:, 1:].clone().requires_grad_() - output2 = m2(i2) - output2.backward(grad_output[:, offset:].contiguous()) - is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6) - atol, rtol = (3e-4, 3e-2) if dtype == torch.float32 and is_cuda_sm86 else (dtype2prec_DONTUSE[dtype], 0) - - self.assertEqual(output, torch.cat([output1, output2], 1), - atol=atol, rtol=rtol) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.bias.grad.data, - torch.cat([m1.bias.grad.data, - m2.bias.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, - m2.weight.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - - - @onlyCUDA - @dtypes(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) - def test_noncontig_conv_grad(self, device, dtype): - # FIXME: remove after adding non-contiguous grad tests for all modules - module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype) - input = torch.randn(2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True) - output = module(input) - - grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1] - assert not grad.is_contiguous() - output.backward(grad, retain_graph=True) - self.assertIsNotNone(input.grad) - result = input.grad.data.clone() - input.grad.data.zero_() + self.assertEqual(affine_tensor[0, i, r, c], grid_out[:3], exact_dtype=False) - output.backward(grad.contiguous()) - self.assertEqual(result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0) + self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary)) @onlyCUDA @@ -12735,769 +8046,35 @@ def test_batchnorm_large_batch(self, device, dtype): data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype) out = bn(data).sum().backward() - - @onlyCUDA - @dtypes(torch.double) - def test_conv_double_backward(self, device, dtype): - with torch.backends.cudnn.flags(deterministic=True): - # Double backward only runs with DoubleTensor due to precision reason - batch_size = 1 - for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]: - for stride, padding, chan_in, chan_out, dilation in product([1], [2], [2], [3], dilations): - no_weight = stride == 2 - result = self.run_conv_double_back_test(kern, stride, - padding, chan_in, chan_out, - batch_size, inp_size, dilation, - no_weight, use_cuda=True, dtype=dtype) - self.assertTrue(result, - "Conv double backward test failed with parameters:" + - "\nkern: " + str(kern) + - "\nstride: " + str(stride) + - "\npadding: " + str(padding) + - "\nchan_in: " + str(chan_in) + - "\nchan_out: " + str(chan_out) + - "\nbatch_size: " + str(batch_size) + - "\ninp_size: " + str(inp_size) + - "\ndilation: " + str(dilation)) - - - def test_conv_double_backward_no_bias(self): - kern = 3 - stride = 2 - chan_in, chan_out = 2, 4 - batch_size = 2 - inp_size = 5 - padding = 1 - dilation = 1 - no_weight = False - use_bias = True - result = self.run_conv_double_back_test(kern, stride, - padding, chan_in, chan_out, - batch_size, inp_size, dilation, - no_weight, use_bias=use_bias) - self.assertTrue(result, - "Conv double backward test failed with parameters:" + - "\nkern: " + str(kern) + - "\nstride: " + str(stride) + - "\npadding: " + str(padding) + - "\nchan_in: " + str(chan_in) + - "\nchan_out: " + str(chan_out) + - "\nbatch_size: " + str(batch_size) + - "\ninp_size: " + str(inp_size) + - "\ndilation: " + str(dilation)) - - - def test_conv_double_backward_groups(self): - kern = 3 - stride = 1 - padding = 2 - chan_in, chan_out = 2, 4 - batch_size = 2 - inp_size = 6 - dilation = 1 - no_weight = False - groups = 2 - result = self.run_conv_double_back_test(kern, stride, - padding, chan_in * groups, chan_out * groups, - batch_size, inp_size, dilation, - no_weight, groups=groups) - self.assertTrue(result, - "Conv double backward test failed with parameters:" + - "\nkern: " + str(kern) + - "\nstride: " + str(stride) + - "\npadding: " + str(padding) + - "\nchan_in: " + str(chan_in) + - "\nchan_out: " + str(chan_out) + - "\nbatch_size: " + str(batch_size) + - "\ninp_size: " + str(inp_size) + - "\ndilation: " + str(dilation) + - "\ngroups: " + str(groups)) - - - def test_conv_double_backward_stride(self): - batch_size = 2 - - # Cannot provide ggW when stride is > 1 - for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]: - for stride, padding, chan_in, chan_out, dilation in product([2], [0, 1], [1], [2], dilations): - no_weight = False - self.run_conv_double_back_test(kern, stride, - padding, chan_in, chan_out, - batch_size, inp_size, dilation, - no_weight) - - @dtypes(torch.float, torch.cfloat) - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_conv1d_same_padding(self, device, dtype): - # Test padding='same' outputs the correct shape - test_args = [ - # in_size - range(50, 55), - # kernel_size - [1, 2, 3, 8], - # dilation - range(1, 4), - # stride - [1], - ] - for in_size, k_size, dilation, stride in itertools.product(*test_args): - x = torch.rand(1, 1, in_size, device=device, dtype=dtype) - y = torch.rand(1, 1, k_size, device=device, dtype=dtype) - z = F.conv1d(x, y, padding='same', dilation=dilation, stride=stride) - self.assertEqual(z.size(2), int(math.ceil(in_size / stride))) - - # Compare F.conv1d padding='same' output against manual padding - # Without strides/dilation - x = torch.rand(1, 1, 12, device=device, dtype=dtype) - y = torch.rand(1, 1, 3, device=device, dtype=dtype) - expect = F.conv1d(x, y, padding=1) - actual = F.conv1d(x, y, padding='same') - self.assertEqual(expect, actual) - - # With dilation - x = torch.rand(1, 1, 12, device=device, dtype=dtype) - y = torch.rand(1, 1, 4, device=device, dtype=dtype) - expect = F.conv1d(x, y, padding=3, dilation=2) - actual = F.conv1d(x, y, padding='same', dilation=2) - self.assertEqual(expect, actual) - - # Dilation with asymmetric padding - expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:] - actual = F.conv1d(x, y, padding='same', dilation=3) - self.assertEqual(expect, actual) - - @dtypes(torch.float, torch.cfloat) - def test_conv2d_same_padding(self, device, dtype): - if dtype is torch.cfloat: - rtol, atol = 2e-6, 2e-6 - else: - rtol, atol = None, None - # Compare F.conv2d padding='same' output against manual padding - # Without strides/dilation - x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype) - y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype) - expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] - actual = F.conv2d(x, y, padding='same') - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - # With dilation - y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) - expect = F.conv2d(x, y, padding=(2, 3), dilation=2) - actual = F.conv2d(x, y, padding='same', dilation=2) - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - # Dilation with asymmetric padding - y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype) - expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] - actual = F.conv2d(x, y, padding='same', dilation=3) - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - @dtypes(torch.float, torch.cfloat) - def test_conv3d_same_padding(self, device, dtype): - if dtype is torch.cfloat: - rtol, atol = 2e-6, 2e-6 - else: - rtol, atol = None, None - # Compare F.conv3d padding='same' output against manual padding - # Without strides/dilation - x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype) - y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype) - expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :] - actual = F.conv3d(x, y, padding='same') - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - # With dilation - expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) - actual = F.conv3d(x, y, padding='same', dilation=2) - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - # Dilation with asymmetric padding - y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype) - expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:] - actual = F.conv3d(x, y, padding='same', dilation=3) - self.assertEqual(expect, actual, rtol=rtol, atol=atol) - - @dtypes(torch.float, torch.cfloat) - def test_conv1d_valid_padding(self, device, dtype): - # Test F.conv1d padding='valid' is the same as no padding - x = torch.rand(1, 1, 10, device=device, dtype=dtype) - y = torch.rand(1, 1, 4, device=device, dtype=dtype) - expect = F.conv1d(x, y) - actual = F.conv1d(x, y, padding='valid') - self.assertEqual(expect, actual) - - @dtypes(torch.float, torch.cfloat) - def test_conv2d_valid_padding(self, device, dtype): - # Test F.conv2d padding='valid' is the same as no padding - x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype) - y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype) - expect = F.conv2d(x, y) - actual = F.conv2d(x, y, padding='valid') - self.assertEqual(expect, actual) - - @dtypes(torch.float, torch.cfloat) - def test_conv3d_valid_padding(self, device, dtype): - # Test F.conv3d padding='valid' is the same as no padding - x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device) - y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device) - expect = F.conv3d(x, y) - actual = F.conv3d(x, y, padding='valid') - self.assertEqual(expect, actual) - - @dtypes(torch.float, torch.cfloat) - def test_conv1d_same_padding_backward(self, device, dtype): - # Test F.conv1d gradients work with padding='same' - x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True) - y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) - - # Symmetric padding - z = F.conv1d(x, y, padding=3, dilation=2) - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv1d(x, y, padding='same', dilation=2) - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - x.grad, y.grad = None, None - - # Asymmetric padding - z = F.conv1d(x, y, padding=2)[..., 1:] - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv1d(x, y, padding='same') - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - - @dtypes(torch.float, torch.cfloat) - def test_conv2d_same_padding_backward(self, device, dtype): - # Test F.conv2d gradients work with padding='same' - x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True) - y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True) - - # Symmetric padding - z = F.conv2d(x, y, padding=(3, 4), dilation=2) - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv2d(x, y, padding='same', dilation=2) - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - x.grad, y.grad = None, None - - # Asymmetric padding - y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True) - z = F.conv2d(x, y, padding=2)[..., 1:, 1:] - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv2d(x, y, padding='same') - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - - @dtypes(torch.double, torch.cdouble) - def test_conv3d_same_padding_backward(self, device, dtype): - check_forward_ad = torch.device(device).type != 'xla' - - # Test F.conv3d gradients work with padding='same' - x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True) - y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True) - - # Symmetric padding - z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv3d(x, y, padding='same', dilation=2) - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - x.grad, y.grad = None, None - - gradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), - check_forward_ad=check_forward_ad, nondet_tol=1e-5) - if torch.device(device).type != 'cuda': - # https://github.com/pytorch/pytorch/issues/70702 - gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same', dilation=2), (x, y), - check_fwd_over_rev=True) - - # Asymmetric padding - y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True) - z = F.conv3d(x, y, padding=2)[..., 1:, 1:] - z.sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - z = F.conv3d(x, y, padding='same') - z.sum().backward() - self.assertEqual(gx_expect, x.grad) - self.assertEqual(gy_expect, y.grad) - - gradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), - check_forward_ad=check_forward_ad, nondet_tol=1e-5) - if torch.device(device).type != 'cuda': - # https://github.com/pytorch/pytorch/issues/70702 - gradgradcheck(lambda x, y: F.conv3d(x, y, padding='same'), (x, y), - check_fwd_over_rev=True) - - @dtypes(torch.float, torch.cfloat) - def test_conv1d_valid_padding_backward(self, device, dtype): - # Test F.conv1d gradients work with padding='valid' - x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True) - y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) - F.conv1d(x, y, padding=0).sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - F.conv1d(x, y, padding='valid').sum().backward() - gx_actual, gy_actual = x.grad, y.grad - self.assertEqual(gx_expect, gx_actual) - self.assertEqual(gy_expect, gy_actual) - - @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") - @dtypes(torch.float, torch.cfloat) - @parametrize_test("mode", ('valid', 'same')) - def test_conv1d_vs_scipy(self, device, dtype, mode): - t = make_tensor((1, 10), device=device, dtype=dtype) - feat_dim = t.shape[1] - weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype) - weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype) - - def _test(t, weight, mode): - # SciPy expects two 1-D inputs. - t_a = t.view(-1).cpu().numpy() - w_a = weight.view(-1).cpu().numpy() - expected = scipy.signal.convolve(t_a, w_a, mode=mode) - - kwargs = {'padding': mode} - if mode == 'same': - # `same` padding in PyTorch conv1d is different - # from SciPy - p = weight.shape[2] // 2 - t = torch.nn.functional.pad(t, (p, p)) - # We have already taken care of padding - kwargs.pop("padding") - - # second input is flipped in SciPy's convolve - weight_flipped = torch.flip(weight, (2,)) - actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0) - if mode == 'same': - actual = actual[:feat_dim] - - self.assertEqual(actual, expected) - - # Global dtype for this test suite is torch.double - # This leads to change in type-promotion - # and conv1d outputs `complex128` for `complex64` input. - with set_default_dtype(torch.float): - _test(t, weight_even, mode) - _test(t, weight_odd, mode) - - @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") - @dtypes(torch.float, torch.cfloat) - @parametrize_test("mode", ('valid', 'same')) - def test_conv2d_vs_scipy(self, device, dtype, mode): - t = make_tensor((1, 5, 10), device=device, dtype=dtype) - weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype) - weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype) - - def _test(t, weight, mode): - # SciPy expects two 2-D inputs. - t_a = t.squeeze(0).cpu().numpy() - w_a = weight.squeeze(0).squeeze(0).cpu().numpy() - expected = scipy.signal.convolve2d(t_a, w_a, mode=mode) - - kwargs = {'padding': mode} - if mode == 'same': - # `same` padding in PyTorch conv2d is different - # from SciPy - left_right_pad = weight.shape[3] // 2 - top_bottom_pad = weight.shape[2] // 2 - p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad) - t = torch.nn.functional.pad(t, p) - # We have already taken care of padding - kwargs.pop("padding") - - # second input is flipped in SciPy's convolve2d - weight_flipped = torch.flip(weight, (2, 3)) - actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0) - if mode == 'same': - actual = actual[:5, :10] - - self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) - - # Global dtype for this test suite is torch.double - # This leads to change in type-promotion - # and conv1d outputs `complex128` for `complex64` input. - with set_default_dtype(torch.float): - _test(t, weight_even, mode) - _test(t, weight_odd, mode) - - @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") - @dtypes(torch.float, torch.cfloat) - @parametrize_test("mode", ('valid', 'same')) - def test_conv3d_vs_scipy(self, device, dtype, mode): - t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype) - weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype) - weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype) - - def _test(t, weight, mode): - # SciPy expects two 3-D inputs. - t_a = t.squeeze(0).cpu().numpy() - w_a = weight.squeeze(0).squeeze(0).cpu().numpy() - expected = scipy.signal.convolve(t_a, w_a, mode=mode) - - kwargs = {'padding': mode} - if mode == 'same': - # `same` padding in PyTorch conv3d is different - # from SciPy - left_right_pad = weight.shape[4] // 2 - top_bottom_pad = weight.shape[3] // 2 - front_back_pad = weight.shape[2] // 2 - p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad, - front_back_pad, front_back_pad) - t = torch.nn.functional.pad(t, p) - # We have already taken care of padding - kwargs.pop("padding") - - # second input is flipped in SciPy's convolve - weight_flipped = torch.flip(weight, (2, 3, 4)) - actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0) - if mode == 'same': - actual = actual[:5, :5, :10] - - if tf32_is_not_fp32() and (dtype == torch.float or dtype == torch.complex64): - self.assertEqual(actual, expected, atol=0.05, rtol=0.05) - else: - self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) - - # Global dtype for this test suite is torch.double - # This leads to change in type-promotion - # and conv1d outputs `complex128` for `complex64` input. - with set_default_dtype(torch.float): - _test(t, weight_even, mode) - _test(t, weight_odd, mode) - - @dtypes(torch.float, torch.complex64) - def test_conv2d_valid_padding_backward(self, device, dtype): - # Test F.conv2d gradients work with padding='valid' - x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True) - y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True) - F.conv2d(x, y, padding=0).sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - F.conv2d(x, y, padding='valid').sum().backward() - gx_actual, gy_actual = x.grad, y.grad - self.assertEqual(gx_expect, gx_actual) - self.assertEqual(gy_expect, gy_actual) - - @dtypes(torch.double, torch.cdouble) - def test_conv3d_valid_padding_backward(self, device, dtype): - check_forward_ad = torch.device(device).type != 'xla' - - # Test F.conv3d gradients work with padding='valid' - x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True) - y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True) - F.conv3d(x, y, padding=0).sum().backward() - gx_expect, gy_expect = x.grad, y.grad - x.grad, y.grad = None, None - - F.conv3d(x, y, padding='valid').sum().backward() - gx_actual, gy_actual = x.grad, y.grad - self.assertEqual(gx_expect, gx_actual) - self.assertEqual(gy_expect, gy_actual) - - gradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_forward_ad=check_forward_ad) - gradgradcheck(lambda x, y: F.conv3d(x, y, padding='valid'), (x, y), check_fwd_over_rev=check_forward_ad) - - @parametrize_test("N", range(2, 4), name_fn=lambda N: 'ConvTranspose{}d'.format(N)) - def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): - # For inputs with no batch dim, verify output is the correct shape when output_size is set. - # See https://github.com/pytorch/pytorch/issues/75889 - inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device) - output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200) - ConvTransposeNd = getattr(nn, 'ConvTranspose{}d'.format(N)) - m = ConvTransposeNd(1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device) - output = m(inp, output_size=output_size) - self.assertEqual(output.shape, output_size) - - @skipMeta - @parametrize_test("input_shape,transposed,dilated,groups,layout,backend_expected", [ - # === slow === - subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d'), - subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_transposed'), - subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated'), - subtest(((2, 6, 7), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow1d_dilated_transposed'), - subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Slow2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d'), - subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_transposed'), - subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated'), - subtest(((2, 6, 7, 8), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose2d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow2d_dilated_transposed'), - subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Slow3d), - decorators=[onlyCPU, disableMkldnn], name='slow3d_cpu'), - # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead - subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), - decorators=[onlyCUDA, disablecuDNN], name='slow3d_cuda'), - # FIXME: RuntimeError: CUDA out of memory. - # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), - # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), - subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.SlowDilated3d), - decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated'), - # FIXME: RuntimeError: CUDA out of memory. - # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), - # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), - subtest(((0, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch1d'), - subtest(((2, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel1d'), - subtest(((0, 0, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel1d'), - subtest(((0, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch2d'), - subtest(((2, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel2d'), - subtest(((0, 0, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel2d'), - subtest(((0, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch3d'), - subtest(((2, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_channel3d'), - subtest(((0, 0, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Empty), - decorators=[onlyNativeDeviceTypes, disableMkldnn], name='empty_batch_channel3d'), - # === cuda === - # Note that disablecuDNN disables miopen as well. - subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), - decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise1d'), - subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise2d), - decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise2d'), - subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.CudaDepthwise3d), - decorators=[onlyCUDA, disablecuDNN], name='cuda_depthwise3d'), - # === cudnn === - subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d'), - subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d'), - subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Cudnn), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d'), - subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn1d_transposed'), - subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), - decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn2d_transposed'), - # FIXME: RuntimeError: CUDA out of memory. - # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), - # decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), - # === miopen === - subtest(((2, 6, 7), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d'), - subtest(((2, 6, 7, 8), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d'), - subtest(((2, 6, 7, 8, 9), False, False, 3, torch.strided, torch._C._ConvBackend.Miopen), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d'), - subtest(((2, 6, 7), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen1d_transposed'), - subtest(((2, 6, 7, 8), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen2d_transposed'), - subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.MiopenTranspose), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen3d_transposed'), - subtest(((2, 6, 7), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise1d'), - subtest(((2, 6, 7, 8), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise2d'), - subtest(((2, 6, 7, 8, 9), False, False, 6, torch.strided, torch._C._ConvBackend.MiopenDepthwise), - decorators=[onlyCUDA, skipCUDAIfNoMiopen], name='miopen_depthwise3d'), - # === mkldnn === - subtest(((2, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d'), - subtest(((2, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d'), - subtest(((2, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d'), - # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775. - subtest(((2, 6, 7), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn1d_transposed'), - subtest(((2, 6, 7, 8), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn2d_transposed'), - subtest(((2, 6, 7, 8, 9), True, False, 3, torch._mkldnn, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], name='mkldnn3d_transposed'), - subtest(((2, 6, 7), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn1d_cpu_input'), - subtest(((2, 6, 7, 8), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn2d_cpu_input'), - subtest(((2, 6, 7, 8, 9), False, True, 3, torch.strided, torch._C._ConvBackend.Mkldnn), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn3d_cpu_input'), - subtest(((0, 6, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch1d'), - subtest(((2, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel1d'), - subtest(((0, 0, 7), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel1d'), - subtest(((0, 6, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch2d'), - subtest(((2, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel2d'), - subtest(((0, 0, 7, 8), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel2d'), - subtest(((0, 6, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch3d'), - subtest(((2, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_channel3d'), - subtest(((0, 0, 7, 8, 9), False, False, 3, torch._mkldnn, torch._C._ConvBackend.MkldnnEmpty), - decorators=[onlyCPU, skipCPUIfNoMkldnn], name='mkldnn_empty_batch_channel3d'), - # Note: Tests for mobile backends are not currently supported. This comprises - # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these - # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1. - ]) - # Test with both bias and no bias. - @parametrize_test("has_bias", [False, True]) - # Test with both stride=1 and stride>1 cases. - @parametrize_test("strided", [False, True]) - # Test with both contiguous and non-contiguous inputs. - @parametrize_test("contiguous", [False, True]) - def test_conv_backend( - self, device, input_shape, has_bias, strided, contiguous, transposed, dilated, groups, - layout, backend_expected): - # Build up inputs. - dtype = torch.float32 - C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3 - x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True) - weight = torch.randn(C_in if transposed else C_out, - C_out // groups if transposed else C_in // groups, - *[kernel_size for _ in range(dim)], - device=device, dtype=dtype, requires_grad=True) - bias = torch.randn(C_out, device=device, dtype=dtype, requires_grad=True) if has_bias else None - - def _make_noncontiguous(inp): - if inp is None: - return None - old_requires_grad = inp.requires_grad - inp = torch.repeat_interleave(inp, 2, dim=-1) - inp = inp[..., ::2].detach().requires_grad_(old_requires_grad) - return inp - - if not contiguous: - x = _make_noncontiguous(x) - weight = _make_noncontiguous(weight) - bias = _make_noncontiguous(bias) - - if layout is torch._mkldnn: - x = x.to_mkldnn() - # Note that weight and bias are not supported as mkldnn tensors during training. - - stride = (2,) * dim if strided else (1,) * dim - padding = (0,) * dim - dilation = (2,) * dim if dilated else (1,) * dim - output_padding = (0,) * dim - inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] - - # Ensure correct backend is selected. - backend_actual = torch._C._select_conv_backend(*inputs) - self.assertEqual(backend_actual, backend_expected) - - # Ensure backward call succeeds. - convolution = torch.ops.aten.convolution - output = convolution(*inputs) - grad_output = torch.randn(output.shape, device=device, dtype=dtype) - if not contiguous: - grad_output = _make_noncontiguous(grad_output) - if layout is torch._mkldnn: - grad_output = grad_output.to_mkldnn() - output.backward(grad_output) - - # mkldnn doesn't support gradcheck :( - if layout is torch._mkldnn: - return - - if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails - # Forward AD and forward-over-reverse AD smoke test in float32 - # TODO: remove this if we introduce per-op gradient tests for float32 - with fwAD.dual_level(): - dual_inputs = [(fwAD.make_dual(i, torch.rand_like(i)) if isinstance(i, torch.Tensor) else i) for i in inputs] - # Forward AD - output = convolution(*dual_inputs) - # Forward over reverse AD - grad_output_d = fwAD.make_dual(torch.rand_like(output), torch.rand_like(output)) - if has_bias: - torch.autograd.grad(output, [x, weight, bias], grad_output_d) - else: - torch.autograd.grad(output, [x, weight], grad_output_d) - - # Convert to float64 for gradcheck. - x = x.to(torch.float64).detach().requires_grad_(True) - weight = weight.to(torch.float64).detach().requires_grad_(True) - if bias is not None: - bias = bias.to(torch.float64).detach().requires_grad_(True) - inputs = [x, weight, bias, stride, padding, dilation, transposed, output_padding, groups] - - # Set some backend-specific validation settings. - gradcheck_nondet_tol = 0.0 - if torch.backends.cudnn.is_available(): - # cuDNN introduces non-determinism - gradcheck_nondet_tol = GRADCHECK_NONDET_TOL - - self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) - - # double backward doesn't support bias gradients - if bias is not None: - bias.requires_grad_(False) - self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) - - - @onlyCPU - def test_conv_contiguous_for_oneDNN(self): - # See https://github.com/pytorch/pytorch/issues/80837. - for dtype in [torch.float, torch.bfloat16]: - conv = nn.Conv2d( - 1, - 128, - kernel_size=(5, 2), - stride=(2, 1), - padding=(0, 1), - dilation=(1, 1), - groups=1, - bias=True, - padding_mode='zeros').to(dtype=dtype) - - x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype) - x = torch.transpose(x, 1, 4) - x2 = x[..., 0] - inputs = [x2, conv.weight, conv.bias, (2, 1), (0, 1), (1, 1), False, (0, 1), 1] - if torch.backends.mkldnn.is_available(): - y = conv(x2) - # Disable MKLDNN explicitly - with torch.backends.mkldnn.flags(enabled=False): - y_ = conv(x2) - self.assertEqual(y, y_) - - @onlyCPU - def test_conv_ic1_channels_last_for_oneDNN(self): - # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path. - for dtype in [torch.float, torch.bfloat16]: - conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False) - conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype) - x = torch.rand(2, 1, 100, 100).to(dtype=dtype) - if torch.backends.mkldnn.is_available(): - y = conv(x) - # Disable MKLDNN explicitly - with torch.backends.mkldnn.flags(enabled=False): - y_ = conv(x) - self.assertEqual(y, y_) + @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128) + @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) + def test_conv_empty_input(self, device, dtype): + def help(input, conv, memory_format): + ref_out = conv(input) + conv_cl = conv.to(memory_format=memory_format) + out_cl = conv_cl(input) + self.assertEqual(ref_out, out_cl) + input_cl = input.to(memory_format=memory_format) + out_cl2 = conv(input_cl) + self.assertEqual(out_cl, out_cl2) + out_cl3 = conv_cl(input_cl) + self.assertEqual(out_cl, out_cl3) + + # channels_last case + input2d = torch.randn((0, 4, 20, 20)).to(device=device, dtype=dtype) + conv2d = torch.nn.Conv2d(4, 4, 3, 1).to(device=device, dtype=dtype) + help(input2d, conv2d, torch.channels_last) + # channels_last_3d case + input3d = torch.randn((0, 4, 20, 20, 20)).to(device=device, dtype=dtype) + conv3d = torch.nn.Conv3d(4, 4, 3, 1).to(device=device, dtype=dtype) + help(input3d, conv3d, torch.channels_last_3d) + # non-contiguous case + weight = torch.rand(4, 8, 3, 3)[:, ::2, :, :].to(device=device, dtype=dtype) + bias = torch.rand(4).to(device=device, dtype=dtype) + out = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1) + weight = weight.contiguous() + out_ref = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1) + self.assertEqual(out_ref, out) def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) @@ -13540,6 +8117,7 @@ def test_instancenorm_raises_error_if_less_than_one_value_per_channel(self, devi with self.assertRaises(ValueError): torch.nn.InstanceNorm1d(10)(x).to(device) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_instancenorm_raises_error_for_single_spatial_element_during_training(self, device): BATCH_SIZE = 10 NUM_CHANNELS = 3 @@ -13567,6 +8145,9 @@ def test_LayerNorm_general(self, device): if self.device_type == 'cuda': self._test_LayerNorm_cuda_half(device) + if self.device_type == 'cpu': + self._test_LayerNorm_cpu_mixed_dtype(device) + @onlyNativeDeviceTypes def test_LayerNorm_numeric(self, device): def layer_norm_ref(X, gamma, beta, normalized_shape, eps): @@ -13627,6 +8208,7 @@ def test_GroupNorm_raises_error_if_one_value_per_group(self, device): with self.assertRaises(ValueError): torch.nn.GroupNorm(10, 10)(x).to(device) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_GroupNorm_empty(self, device): mod = torch.nn.GroupNorm(2, 4).to(device) inp = torch.randn(0, 4, 2, 2, device=device) @@ -13702,6 +8284,7 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps): @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_pad(self, device, dtype): # Assert assertion errors are raised for invalid circular padding values inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True) @@ -13733,6 +8316,7 @@ def test_pad(self, device, dtype): @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_ReplicationPad_empty(self, device, dtype): for mod, inp in [ (torch.nn.ReplicationPad1d(3), torch.randn(0, 3, 10, device=device, dtype=dtype)), @@ -13755,6 +8339,7 @@ def test_ReplicationPad_empty(self, device, dtype): inp = torch.randn(3, 0, 10, 10, 10, device=device, dtype=dtype) mod(inp) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_ReplicationPad1d_large(self, device): shapes = ([2, 65736, 4], [65736, 2, 4]) pl, pr = 3, 4 @@ -13779,6 +8364,7 @@ def test_ReplicationPad1d_large(self, device): self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1)) self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1)) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_ReplicationPad2d_large(self, device): shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4]) pl, pr, pt, pb = 3, 4, 5, 6 @@ -13844,6 +8430,8 @@ def test_ReplicationPad3d_large(self, device): self.assertEqual(x.grad[:, :, 1:-1, 1:-1, 1:-1], g[:, :, pf + 1 : -pbk - 1, pt + 1 : -pbt - 1, pl + 1 : -pr - 1]) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") + def test_Bilinear_empty(self, device): mod = torch.nn.Bilinear(20, 30, 40).to(device) inp1 = torch.randn(0, 10, 20, requires_grad=True, device=device) @@ -13860,6 +8448,7 @@ def test_Bilinear_empty(self, device): @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_TransformerEncoderLayer_empty(self, device): for training in (True, False): for batch_first, input_shape in [(True, (0, 10, 512)), @@ -13887,6 +8476,7 @@ def test_TransformerEncoderLayer_empty(self, device): @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_TransformerEncoder_empty(self, device): for batch_first, input_shape in [(True, (0, 10, 512)), (False, (10, 0, 512))]: @@ -13897,6 +8487,7 @@ def test_TransformerEncoder_empty(self, device): @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_TransformerDecoderLayer_empty(self, device): for batch_first, memory_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), (False, (10, 0, 512), (20, 0, 512))]: @@ -13918,6 +8509,7 @@ def test_TransformerDecoder_empty(self, device): @expectedFailureMeta # RuntimeError: cannot reshape tensor of 0 elements into shape [1, 0, -1] @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_Transformer_empty(self, device): for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]: transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12).to(device) @@ -13927,6 +8519,7 @@ def test_Transformer_empty(self, device): @onlyNativeDeviceTypes @dtypes(torch.float32, torch.complex64) + @dtypesIfMPS(torch.float32) def test_ReflectionPad_empty(self, device, dtype): for mod, inp in [ (torch.nn.ReflectionPad1d(2), torch.randn(0, 3, 10, device=device, dtype=dtype)), @@ -13971,6 +8564,7 @@ def test_ReflectionPad2d_large(self, device): self.assertEqual(x.grad, ref_x.grad) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_LocalResponseNorm_empty(self, device): mod = torch.nn.LocalResponseNorm(2).to(device) inp = torch.ones(0, 5, 24, 24, device=device) @@ -13999,6 +8593,7 @@ def test_ReflectionPad3d_large(self, device): @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) + @dtypesIfMPS(torch.float) def test_MarginLoss_empty(self, device, dtype): for mod, x, y in [ (torch.nn.MultiMarginLoss().to(device), @@ -14103,6 +8698,7 @@ def check_rnn_grads(rnn1, rnn2): else: self.assertEqual(hx.grad, hx_device.grad) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_BatchNorm_empty(self, device): mod = torch.nn.BatchNorm2d(3).to(device) inp = torch.randn(0, 3, 2, 2, device=device) @@ -14116,57 +8712,6 @@ def test_BatchNorm_empty(self, device): self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device)) self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device)) - @dtypes(torch.float, torch.cfloat) - def test_conv_empty_channel(self, device, dtype): - in_channels = 0 - mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device) - inp = torch.randn(2, 0, 15, device=device, dtype=dtype) - _test_module_empty_input(self, mod, inp, check_size=False) - - with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): - inp = torch.randn(2, 1, 0, device=device, dtype=dtype) - mod(inp) - - mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) - inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype) - _test_module_empty_input(self, mod, inp, check_size=False) - - with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): - inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype) - mod(inp) - - mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) - inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype) - _test_module_empty_input(self, mod, inp, check_size=False) - - with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): - inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype) - mod(inp) - - def test_group_conv_empty(self, device): - mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) - inp = torch.randn(0, 4, 4, 4, device=device) - _test_module_empty_input(self, mod, inp, check_size=False) - if self.device_type == 'cuda' and self.has_cudnn(): - with torch.backends.cudnn.flags(enabled=False): - _test_module_empty_input(self, mod, inp, check_size=False) - - def test_group_convTranspose_empty(self, device): - mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device) - inp = torch.randn(0, 4, 4, 4, device=device) - _test_module_empty_input(self, mod, inp, check_size=False) - if self.device_type == 'cuda' and self.has_cudnn(): - with torch.backends.cudnn.flags(enabled=False): - _test_module_empty_input(self, mod, inp, check_size=False) - - def test_convTranspose_empty(self, device): - mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to(device) - inp = torch.randn(0, 4, 4, 4, device=device) - _test_module_empty_input(self, mod, inp, check_size=False) - if self.device_type == 'cuda' and self.has_cudnn(): - with torch.backends.cudnn.flags(enabled=False): - _test_module_empty_input(self, mod, inp, check_size=False) - @onlyCUDA @largeTensorTest('16GB') def test_prelu_backward_32bit_indexing(self, device): @@ -14175,6 +8720,7 @@ def test_prelu_backward_32bit_indexing(self, device): output = m(input_) output.backward(input_) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_linear_empty(self, device): mod = torch.nn.Linear(7, 7).to(device) inp = torch.randn(0, 7, device=device) @@ -14230,6 +8776,7 @@ def test_one_hot(self, device): with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nn_empty(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -14245,6 +8792,7 @@ def verify_scalars(input, output): output = m(input) verify_scalars(input, output) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nn_scalars(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -14264,6 +8812,7 @@ def verify_scalars(input, output): output = m(input) verify_scalars(input, output) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nn_scalars_reductions(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_reduction_scalars(input, reduction, output): @@ -14289,6 +8838,7 @@ def verify_reduction_scalars(input, reduction, output): # verify that bogus reduction strings are errors @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_invalid_reduction_strings(self, device): input = torch.randn(3, 5, requires_grad=True, device=device) cinput = torch.randn(3, 5, requires_grad=True, device=device, dtype=torch.cfloat) @@ -14337,6 +8887,7 @@ def v(fn): v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction)) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_smooth_l1_loss_vs_huber_loss(self, device): def _make_test_tensor(shape, contiguous=True): if contiguous: @@ -14425,6 +8976,7 @@ def func(device): # We don't want to make propagating NaN a hard requirement on ops, but for # these easy ones, we should make them do so. + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nonlinearity_propagate_nan(self, device): def test(nonlinearity, *args, **kwargs): x = torch.tensor([nan], device=device) @@ -14562,6 +9114,7 @@ def helper(isize, osize): helper(20, 11) helper(10, 15) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_upsamplingNearest2d(self, device): # Forward AD does not support XLA because XLA tensors don't have storage check_forward_ad = torch.device(device).type != 'xla' @@ -14680,6 +9233,7 @@ def helper(memory_format, isize, osize): helper(torch.contiguous_format, 10, 15) helper(torch.channels_last, 10, 15) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_upsamplingNearest3d(self, device): # Forward AD does not support XLA because XLA tensors don't have storage check_forward_ad = torch.device(device).type != 'xla' @@ -14792,6 +9346,7 @@ def helper(memory_format, isize, osize): @parametrize_test("antialias", [True, False]) @parametrize_test("align_corners", [True, False]) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_upsamplingBilinear2d(self, device, antialias, align_corners): # Forward AD does not support XLA because XLA tensors don't have storage check_forward_ad = torch.device(device).type != 'xla' @@ -14870,6 +9425,7 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): @parametrize_test("antialias", [True, False]) @parametrize_test("align_corners", [True, False]) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_upsamplingBicubic2d(self, device, antialias, align_corners): kwargs = dict(mode='bicubic', align_corners=align_corners, antialias=antialias) # test float scale factor up & downsampling @@ -14889,6 +9445,7 @@ def test_upsamplingBicubic2d(self, device, antialias, align_corners): inpt = torch.ones(2, 3, 8, 8, requires_grad=True, device=device) gradcheck(lambda x: F.interpolate(x, out_size, **kwargs), [inpt], nondet_tol=nondet_tol) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_upsamplingBicubic2d_correctness(self, device): # test output against known input: align_corners=False result must match opencv in_t = torch.arange(8., device=device).view(1, 2, 2, 2) @@ -14922,12 +9479,118 @@ def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True) self.assertEqual(expected_out, t_out) + @onlyCUDA + @dtypes(torch.half) + @largeTensorTest('40GB') + def test_upsampling_64bit_indexing_channels_last(self, device, dtype): + x = torch.rand((32, 64, 512, 512), dtype=dtype, device=device) + out = torch.nn.functional.interpolate(x.to(memory_format=torch.channels_last), scale_factor=2, mode='nearest') + out_ref = torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest') + del x + self.assertTrue(torch.allclose(out, out_ref)) + def _slow_masked_softmax(self, input, mask): exp = torch.exp(input) exp = exp * mask s = exp.sum(dim=3, keepdim=True).expand(exp.size()) return exp / s + def test_masked_softmax_mask_types(self, device): + # Test that mask type 0 (LxL attention mask), mask type 1 (BxL padding mask), + # and mask type 2 (generic BxHxLxL mask) are processed correctly on the + # fast path and the results match explicit slow calculation. + sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] + + for (B, num_heads, L) in sizes: + + # mask_type == 0 => attention mask of shape LxL + src_mask_orig = torch.randint(0, 2, (L, L)).bool() + src_mask = src_mask_orig.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool() + + # mask_type == 1 => padding mask of shape BxL + src_key_padding_mask_orig = torch.randint(0, 2, (B, L)).bool() + src_key_padding_mask = src_key_padding_mask_orig.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool() + + # mask_type == 2 => shape BxHxLxL + generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool() + masks = [(src_mask_orig, src_mask, 0), + (src_key_padding_mask_orig, src_key_padding_mask, 1), + (generic_mask, generic_mask, 2) + ] + for dim in [0, 3]: + for mask_orig, mask, mask_type in masks: + if (self.device_type == "cuda") and (num_heads % 2) and (mask_type == 1): + # CUDA path doesn't support padding mask when the number of heads is odd + continue + input = torch.randn((B, num_heads, L, L)) + if (self.device_type == "cuda"): + input = input.cuda() + mask = mask.cuda() + mask_orig = mask_orig.cuda() + native_res = torch._masked_softmax(input, mask_orig, dim, mask_type) + mask = ~mask + + def slow_masked_softmax(input, mask): + exp = torch.exp(input) + exp = exp * mask + s = exp.sum(dim=dim, keepdim=True).expand(exp.size()) + return exp / s + + pt_res = slow_masked_softmax(input, mask) + pt_res = torch.nan_to_num(pt_res) + + mask_not = mask.logical_not() + # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0) + # Converts rows with all True's to False + mask_out = mask_not.all(dim, keepdim=True).expand(mask_not.shape) + self.assertEqual( + pt_res.masked_fill(mask_out, 0), + native_res.masked_fill(mask_out, 0), + exact_dtype=True + ) + + @onlyCUDA + def test_masked_softmax_devices_parity(self): + # Test that softmax with mask type 0 (LxL attention mask), mask type 1 (BxL padding mask), + # and mask type 2 (BxHxLxL generic mask) gives the same result on CPU and on CUDA. + + sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] + for (B, num_heads, L) in sizes: + # mask_type == 0 => attention mask of shape LxL + src_mask = torch.randint(0, 2, (L, L)).bool() + # mask_type == 1 => padding mask of shape BxL + src_key_padding_mask = torch.randint(0, 2, (B, L)).bool() + # mask_type == 2 => generic mask of shape BxHxLxL + generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool() + masks = [(src_mask, 0), (src_key_padding_mask, 1), (generic_mask, 2)] + input = torch.randn((B, num_heads, L, L)) + for dim in [0, 3]: + for mask, mask_type in masks: + if (num_heads % 2) and (mask_type == 1): + # CUDA path doesn't support padding mask when the number of heads is odd + continue + + def softmax_on_device(mask, input, device): + # Compute softmax on a given device + input_device = input.to(device) + mask_device = mask.to(device) + softmax_res = torch._masked_softmax(input_device, mask_device, dim, mask_type) + if mask_type == 0: + mask_expanded = mask_device.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool() + elif mask_type == 1: + mask_expanded = mask_device.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool() + else: + mask_expanded = mask_device + # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0) + # Fill rows with all True's with 0 + mask_out = mask_expanded.all(dim, keepdim=True).expand(mask_expanded.shape) + softmax_res = softmax_res.masked_fill(mask_out, 0) + return softmax_res + + cpu_res = softmax_on_device(mask, input, "cpu") + cuda_res = softmax_on_device(mask, input, "cuda") + self.assertEqual(cpu_res, cuda_res, exact_dtype=True) + def test_masked_softmax(self, device): sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] for (B, num_heads, L) in sizes: @@ -15145,59 +9808,6 @@ def _test_helper(shape): # test non-persistent softmax kernel _test_helper((4, 1536)) - @onlyCUDA - @largeTensorTest('12GB') - def test_conv_large_nosplit(self, device): - # Here we just test the convolution correctly route to the fallback implementation - # that is, it does not crash. The correctness of fallback implementation should be - # covered in other tests - dtype = torch.half if self.device_type == 'cuda' else torch.float - conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype) - input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device) - conv1(input_large) - conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype) - input_large = torch.randn(1, 1, 2048, 1024 , dtype=dtype, device=device) - conv2(input_large) - - def test_conv_noncontig_weights(self, device): - for dim in (1, 2, 3): - for grouped in (False, True): - nc = 3 - groups = 3 if grouped else 1 - w = torch.randn([3] * dim, device=device) - w = w.expand([nc, int(nc / groups)] + list(w.shape)) - w = w.detach().requires_grad_() - x = torch.randn([1, nc] + ([5] * dim), device=device, requires_grad=True) - y = getattr(F, 'conv{}d'.format(dim))(x, w, groups=groups) - y.sum().backward() - y = getattr(F, 'conv_transpose{}d'.format(dim))(x, w, groups=groups) - y.sum().backward() - - def test_conv_noncontig_weights_and_bias(self, device): - # need floats to exercise https://github.com/pytorch/pytorch/issues/16018 - for bias in [True, False]: - conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=bias).to(device, torch.float) - - input_nc = torch.randn((1, 3, 224, 224, 2), device=device, dtype=torch.float)[:, :, :, :, 1] - input_c = input_nc.contiguous() - - weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[:, :, :, :, 1] - conv1.weight = nn.Parameter(weight_nc) - weight_c = conv1.weight.contiguous() - - if bias: - bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1] - conv1.bias = nn.Parameter(bias_nc) - bias_c = conv1.bias.contiguous() - - out1 = conv1(input_nc) - conv1.weight = nn.Parameter(weight_c) - if bias: - conv1.bias = nn.Parameter(bias_c) - out2 = conv1(input_c) - self.assertEqual(out1, out2) - def test_save_lstm_compatibility(self, device): # Test that saving an LSTM in PyTorch 1.7 and older can still be # loaded in newer versions of PyTorch. @@ -15348,63 +9958,6 @@ def test_grid_sample_large_index_3d(self, device, dtype): small_image.grad.zero_() large_view.grad.zero_() - @onlyCUDA - @largeTensorTest('12GB') - def test_conv_transposed_large(self, device): - dtype = torch.half if self.device_type == 'cuda' else torch.float - conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) - input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device) - # forward - ret = conv(input_large) - maxdiff0 = (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024))).abs_().max().item() - maxdiff1 = (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024))).abs_().max().item() - maxdiff2 = (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024))).abs_().max().item() - maxdiff3 = (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024))).abs_().max().item() - if self.device_type == 'cuda': - # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0 - self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5) - self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5) - self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5) - self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5) - else: - self.assertEqual(maxdiff0, 0) - self.assertEqual(maxdiff1, 0) - self.assertEqual(maxdiff2, 0) - self.assertEqual(maxdiff3, 0) - - @onlyCUDA - @skipCUDAIfRocm - @largeTensorTest('12GB') - def test_conv_large(self, device): - dtype = torch.half if self.device_type == 'cuda' else torch.float - conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype) - input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device) - # forward - ret = conv(input_large) - self.assertEqual(ret[:2048], conv(input_large[:2048])) - self.assertEqual(ret[2048:4096], conv(input_large[2048:4096])) - self.assertEqual(ret[4096:], conv(input_large[4096:])) - - # backward - conv.zero_grad() - # When computing the backward, we are using the `max(dim=1)`` to create - # some sparsity. Without this sparsity, the rounding error would be - # too large (as large as 1e-5) to satisfy the creterion (1e-6) of `assertEqual` - ret.view(4097, -1).max(dim=1).values.sum().backward() - del ret - grad1 = conv.weight.grad.detach().clone() - conv.zero_grad() - conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward() - conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward() - conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward() - grad2 = conv.weight.grad.detach().clone() - # gradients are at the order of hundreds, we need to scale it to - # the order of one so that we can compare - scale = 1 / grad2.abs().mean() - grad1 = grad1 * scale - grad2 = grad2 * scale - self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) - def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected): logits = torch.randn(shape, dtype=torch.float, device=device) logits = logits.to(dtype) @@ -15422,6 +9975,7 @@ def _test_gumbel_softmax_straight_through(self, device, dtype): num_draws = 100 logits = torch.tensor([[0.2, 0.8, 0.1]], device=device) + logits = logits.reshape([1, 3]) logits = logits.to(dtype).requires_grad_() probs = logits.softmax(dim=-1) @@ -15461,9 +10015,9 @@ def _test_gumbel_softmax_grad(self, device, dtype): tol = 2 * torch.finfo(dtype).eps self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0) - @skipIfMps @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float, torch.double) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_gumbel_softmax(self, device, dtype): self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1) self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=-1, count_expected=1) @@ -15490,6 +10044,7 @@ def _test_rnn_retain_variables(self, device, dtype): self.assertEqual(grads, grads2) @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypesIfMPS(torch.half, torch.float) @dtypes(torch.double) def test_rnn_retain_variables(self, device, dtype): self._test_rnn_retain_variables(device, dtype) @@ -15540,6 +10095,7 @@ def flatten_out(mod, inp): # Merge into OpInfo? @skipMeta # LSTM cell reuses output which was resized @dtypes(torch.double) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_LSTM_grad_and_gradgrad(self, device, dtype): hsize = 4 inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True) @@ -15549,6 +10105,7 @@ def test_LSTM_grad_and_gradgrad(self, device, dtype): @skipMeta # GRU cell reuses output which was resized @dtypes(torch.double) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_GRU_grad_and_gradgrad(self, device, dtype): hsize = 4 inp = torch.rand(1, 3, hsize, device=device, dtype=dtype, requires_grad=True) @@ -15714,76 +10271,6 @@ def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form): self._assertEqual_list((input_length, 1, vocab_size), [t.grad.shape for t in log_probs_refs]) self._assertEqual_list((input_length, vocab_size), [t.grad.shape for t in log_probs_no_bd_refs]) - @onlyCUDA - @skipCUDAIfNoCudnn - def test_contig_wrong_stride_cudnn(self, device): - # x has to have batch_size 1 to test contiguous checks - x = torch.randn(1, 16, 5, 5, device=device) - stride = list(x.stride()) - stride[0] = 20 - # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 - x.set_(x.storage(), 0, x.size(), stride) - self.assertTrue(x.is_contiguous()) - F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device)) - F.conv2d(x, torch.randn(1, 16, 1, 1, device=device)) - - @onlyCUDA - def test_Conv2d_size_1_kernel(self, device): - x_cpu = torch.randn(2, 3, 5, 5) - conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1) - y_cpu = conv_cpu(x_cpu) - y = torch.rand_like(y_cpu) - y_cpu.backward(y) - - with cudnn.flags(enabled=False): - conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device) - conv_cuda.bias.data.copy_(conv_cpu.bias.data) - conv_cuda.weight.data.copy_(conv_cpu.weight.data) - y_cuda = conv_cuda(x_cpu.to(device)) - y_cuda.backward(y.to(device)) - - self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) - - @onlyCUDA - def test_ConvTranspose2d_size_1_kernel(self, device): - x_cpu = torch.randn(2, 3, 5, 5) - conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1) - y_cpu = conv_cpu(x_cpu) - y = torch.rand_like(y_cpu) - y_cpu.backward(y) - - with cudnn.flags(enabled=False): - conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device) - conv_cuda.bias.data.copy_(conv_cpu.bias.data) - conv_cuda.weight.data.copy_(conv_cpu.weight.data) - y_cuda = conv_cuda(x_cpu.to(device)) - y_cuda.backward(y.to(device)) - - self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) - - @onlyCUDA - def test_ConvTranspose3d_size_1_kernel(self, device): - x_cpu = torch.randn(2, 3, 3, 5, 5) - conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1) - y_cpu = conv_cpu(x_cpu) - y = torch.rand_like(y_cpu) - y_cpu.backward(y) - - with cudnn.flags(enabled=False): - conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device) - conv_cuda.bias.data.copy_(conv_cpu.bias.data) - conv_cuda.weight.data.copy_(conv_cpu.weight.data) - y_cuda = conv_cuda(x_cpu.to(device)) - y_cuda.backward(y.to(device)) - - self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.bias.grad.data, conv_cuda.bias.grad.data, atol=1e-5, rtol=0, exact_device=False) - self.assertEqual(conv_cpu.weight.grad.data, conv_cuda.weight.grad.data, atol=1e-5, rtol=0, exact_device=False) - def _ordered_sequence(self, device, dtype): """Create ordered list of random sequences""" seqs = [torch.empty(random.randint(1, 6), device=device, dtype=dtype) @@ -15869,96 +10356,6 @@ def test_softmax(self, device, dtype): # should be bitwise equal self.assertEqual(input.grad, inputf.grad.to(dtype), atol=0, rtol=0) - @onlyCUDA - @dtypes(torch.half, torch.float, torch.double) - def test_multihead_attention_dtype(self, device, dtype): - embed_dim = 128 - num_heads = 8 - sl = 10 - bs = 8 - model = nn.MultiheadAttention(embed_dim, num_heads).cuda().to(dtype) - q = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) - k = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) - v = torch.randn(sl, bs, embed_dim, device=device, dtype=dtype) - out = model(q, k, v) - self.assertEqual(q.size(), out[0].size()) - self.assertEqual(dtype, out[0].dtype) - - @onlyCUDA - @dtypes(torch.half, torch.float, torch.double) - def test_multihead_attention_dtype_batch_first(self, device, dtype): - embed_dim = 128 - num_heads = 8 - sl = 10 - bs = 8 - # With batch_first=True, we have the possibility of hitting - # the native fast path if we call .eval() and enable inference - # mode. Test both paths. - for training in (True, False): - model = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda().to(dtype) - if not training: - model = model.eval() - cm = torch.no_grad() - else: - cm = contextlib.nullcontext() - with cm: - q = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) - k = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) - v = torch.randn(bs, sl, embed_dim, device=device, dtype=dtype) - # fast path currently doesn't support weights - out = model(q, k, v, need_weights=False) - self.assertEqual(q.size(), out[0].size()) - self.assertEqual(dtype, out[0].dtype) - - @dtypesIfCUDA(*floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [])) - @dtypes(torch.float) - @torch.backends.cudnn.flags(enabled=True, benchmark=False) - def test_Conv2d_naive_groups(self, device, dtype): - # Check that grouped convolutions matches two half convolutions - m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) - i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) - output = m(i) - grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) - output.backward(grad_output) - - m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) - m1.weight.data.copy_(m.weight.data[:2]) - m1.bias.data.copy_(m.bias.data[:2]) - i1 = i.data[:, :2].contiguous().requires_grad_(True) - output1 = m1(i1) - output1.backward(grad_output[:, :2].contiguous()) - - m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) - m2.weight.data.copy_(m.weight.data[2:]) - m2.bias.data.copy_(m.bias.data[2:]) - i2 = i.data[:, 2:].contiguous().requires_grad_(True) - output2 = m2(i2) - output2.backward(grad_output[:, 2:].contiguous()) - - self.assertEqual(output, torch.cat([output1, output2], 1)) - self.assertEqual(i.grad.data, - torch.cat([i1.grad.data, i2.grad.data], 1), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.bias.grad.data, - torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - self.assertEqual(m.weight.grad.data, - torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), - atol=dtype2prec_DONTUSE[dtype], rtol=0) - - @dtypes(torch.double, torch.cdouble) - def test_Conv2d_backward_depthwise(self, device, dtype): - x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True) - weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True) - - def conv2d_depthwise(x, weight): - return torch.nn.functional.conv2d( - x, weight, bias=None, stride=(1, 10), groups=2) - - for cudnn_enabled in [False, True]: - with torch.backends.cudnn.flags(enabled=cudnn_enabled): - torch.autograd.gradcheck(conv2d_depthwise, (x, weight)) - def _test_batchnorm_grad(self, device, dtype=torch.double): bs, n_feat, size_feat = 4, 5, 6 input = torch.arange(bs * n_feat * size_feat, device=device, @@ -15971,6 +10368,7 @@ def _test_batchnorm_grad(self, device, dtype=torch.double): _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias, training, 0.1, 0.0001)) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_batchnorm_grad(self, device): self._test_batchnorm_grad(device) @@ -16009,6 +10407,7 @@ def test_layernorm_weight_bias(self): out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps) self.assertEqual(out_none_bias, out_zero_bias) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_hardsigmoid_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device) - 0.5) * 10 inputs.requires_grad = True @@ -16016,6 +10415,7 @@ def test_hardsigmoid_grad(self, device): # currently fails on XLA @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_hardswish_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device) - 0.5) * 10 inputs.requires_grad = True @@ -16207,6 +10607,7 @@ def test_batchnorm_simple_average_mixed(self, device, dtype): @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) + @dtypesIfMPS(torch.float) def test_grid_sample_nan_inf(self, device, dtype): input = torch.zeros([1, 1, 3, 3], device=device, dtype=dtype) grid = torch.tensor([[[[nan, 0], [0, inf]]]], device=device, dtype=dtype) @@ -16235,6 +10636,7 @@ def test_CTCLoss_empty_target(self, device): # Merge into OpInfo? @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message: https://github.com/pytorch/pytorch/issues/34870""") + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_ctc_loss(self, device): batch_size = 64 num_labels = 101 @@ -16427,6 +10829,7 @@ def test_batchnorm_update_stats(self, device): with torch.backends.cudnn.flags(enabled=False): self._test_batchnorm_update_stats(device) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_multi_margin_loss_errors(self, device): self.assertRaises(RuntimeError, lambda: nn.functional.multi_margin_loss(torch.randn(5, device=device), @@ -16460,6 +10863,9 @@ def test_bfloat16(fn, device, inp_dims, prec): test_bfloat16(torch.nn.Softshrink(), device, shape, prec=1e-2) test_bfloat16(torch.nn.Hardswish(), device, shape, prec=2e-2) test_bfloat16(torch.nn.Softplus(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.SiLU(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.Hardtanh(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.Mish(), device, shape, prec=1e-2) @onlyCUDA def test_activations_bfloat16(self, device): @@ -16478,370 +10884,6 @@ def test_softmax_bfloat16(self, device): # test softmax with large input value which casues exp() to overflow _test_bfloat16_ops(self, torch.nn.Softmax(dim=dim), device, inp_dims=(16, 33, 15, 16), prec=0.05, scale_factor=1000.0) - @onlyCPU - @dtypes(torch.float, torch.double) - def test_conv_thnn_nhwc(self, device, dtype): - def helper(n, c, h, w, out_channels, kernel_size, dilation, groups, input_format, weight_format): - input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ - .to(memory_format=input_format) - input.requires_grad_() - conv = nn.Conv2d(c, out_channels, kernel_size, dilation=dilation, groups=groups)\ - .to(device='cpu', dtype=dtype, memory_format=weight_format) - for p in conv.parameters(): - p.data = torch.randint_like(p, -3, 3) - - ref_input = input.detach().clone().contiguous().requires_grad_() - ref_conv = nn.Conv2d(c, out_channels, kernel_size, dilation=dilation, groups=groups) - # load_state_dict will restore the stride & memory_layout on ref_conv.weight. - ref_conv.load_state_dict(conv.state_dict()) - ref_conv = ref_conv.to(device='cpu', dtype=dtype, memory_format=torch.contiguous_format) - - out = conv(input) - ref_out = ref_conv(ref_input) - - grad = torch.randint_like(out, -3, 3) - ref_grad = grad.detach().clone().contiguous() - - out.backward(grad) - ref_out.backward(ref_grad) - - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_out.is_contiguous()) - self.assertEqual(out, ref_out, exact_dtype=False) - self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) - self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) - self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) - - with torch.backends.mkldnn.flags(enabled=False): - formats = [[torch.channels_last, torch.channels_last], - [torch.channels_last, torch.contiguous_format], - [torch.contiguous_format, torch.channels_last]] - for input_format, weight_format in formats: - # non-dilated conv: thnn_conv2d normal path (with im2col) - helper(2, 8, 4, 4, out_channels=4, kernel_size=3, dilation=1, groups=1, - input_format=input_format, weight_format=weight_format) - helper(2, 8, 4, 4, out_channels=8, kernel_size=3, dilation=1, groups=8, - input_format=input_format, weight_format=weight_format) - # test when input chanels is 1 and not converted to channels last - helper(2, 1, 10, 10, out_channels=8, kernel_size=3, dilation=1, groups=1, - input_format=torch.contiguous_format, weight_format=torch.channels_last) - # non-dilated conv: thnn_conv2d fast path (skip im2col) - helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=1, - input_format=input_format, weight_format=weight_format) - # ic == oc == 1 here, so need to stick input to CL to activate channels last - helper(1, 16, 56, 56, out_channels=16, kernel_size=1, dilation=1, groups=16, - input_format=torch.channels_last, weight_format=weight_format) - # dilated conv: slow_conv_dilated2d - helper(2, 8, 11, 13, out_channels=16, kernel_size=3, dilation=2, groups=1, - input_format=input_format, weight_format=weight_format) - helper(2, 16, 11, 13, out_channels=32, kernel_size=3, dilation=2, groups=16, - input_format=input_format, weight_format=weight_format) - - @onlyCUDA - @skipCUDAIfRocmVersionLessThan((4, 3)) - @skipCUDAIfNotMiopenSuggestNHWC - @skipCUDAIfCudnnVersionLessThan(7603) - @dtypes(torch.half, torch.float, torch.cfloat) - def test_conv_cudnn_nhwc(self, device, dtype): - def helper(n, c, h, w, out_channels, kernel_size, groups): - input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device)\ - .to(memory_format=torch.channels_last) - input.requires_grad_() - conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups)\ - .to(device='cuda', dtype=dtype, memory_format=torch.channels_last) - for p in conv.parameters(): - p.data = torch.randint_like(p, -3, 3) - - # use FP64 channels-first conv as reference - ref_input = input.detach().clone().contiguous().double().requires_grad_() - ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups) - # load_state_dict will restore the stride & memory_layout on ref_conv.weight. - ref_conv.load_state_dict(conv.state_dict()) - ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) - - out = conv(input) - ref_out = ref_conv(ref_input) - - grad = torch.randint_like(out, -3, 3) - ref_grad = grad.detach().clone().double().contiguous() - - out.backward(grad) - ref_out.backward(ref_grad) - - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last)) - - self.assertTrue(ref_out.is_contiguous()) - self.assertTrue(ref_input.grad.is_contiguous()) - self.assertTrue(ref_conv.weight.grad.is_contiguous()) - - self.assertEqual(out, ref_out, exact_dtype=False) - self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) - self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) - self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) - - helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1) - helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8) - helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1) - helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) - - @onlyCUDA - @skipCUDAIfRocm - @skipCUDAIfCudnnVersionLessThan(8005) - @dtypes(torch.half, torch.float) - def test_conv_cudnn_ndhwc(self, device, dtype): - def helper(n, c, d, h, w, out_channels, kernel_size, groups): - input = torch.randint(-2, 2, (n, c, d, h, w), dtype=dtype, device=device)\ - .to(memory_format=torch.channels_last_3d) - input.requires_grad_() - conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups)\ - .to(device='cuda', dtype=dtype, memory_format=torch.channels_last_3d) - for p in conv.parameters(): - p.data = torch.randint_like(p, -2, 2) - - # use FP64 channels-first conv as reference - ref_input = input.detach().clone().contiguous().double().requires_grad_() - ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups) - # load_state_dict will restore the stride & memory_layout on ref_conv.weight. - ref_conv.load_state_dict(conv.state_dict()) - ref_conv = ref_conv.to(device='cuda', dtype=torch.double, memory_format=torch.contiguous_format) - - out = conv(input) - ref_out = ref_conv(ref_input) - - grad = torch.randint_like(out, -2, 2) - ref_grad = grad.detach().clone().double().contiguous() - - out.backward(grad) - ref_out.backward(ref_grad) - - self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d)) - self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last_3d)) - self.assertTrue(conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d)) - - self.assertTrue(ref_out.is_contiguous()) - self.assertTrue(ref_input.grad.is_contiguous()) - self.assertTrue(ref_conv.weight.grad.is_contiguous()) - - self.assertEqual(out, ref_out, exact_dtype=False) - self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) - self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) - self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) - - helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1) - helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8) - helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1) - helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16) - - def _run_conv(self, layer, device, inp, grad, ref_conv, ref_input, ref_out, - input_format, weight_format, grad_format, output_format): - conv = layer(inp.size(1), grad.size(1), - ref_conv.weight.size(2)).float().to(device) - # load_state_dict will restore the stride & memory_layout on ref_conv.weight. - conv.load_state_dict(ref_conv.state_dict()) - weight_data = conv.weight.detach().clone().contiguous(memory_format=weight_format) - conv.weight.data = weight_data.resize_(weight_data.size(), memory_format=weight_format) - input = inp.clone().contiguous(memory_format=input_format) - input.resize_(input.size(), memory_format=input_format) - input = input.requires_grad_() - grad = grad.contiguous(memory_format=grad_format) - grad.resize_(grad.size(), memory_format=grad_format) - out = conv(input) - out.backward(grad) - self.assertTrue(out.is_contiguous(memory_format=output_format)) - self.assertEqual(out, ref_out) - self.assertEqual(conv.weight.grad, ref_conv.weight.grad) - self.assertEqual(conv.bias.grad, ref_conv.bias.grad) - self.assertEqual(input.grad, ref_input.grad) - - def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): - data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device) - ref_input = data.clone().contiguous().requires_grad_(True) - ref_conv = layer(c, k, filter_size).float().to(device) - ref_out = ref_conv(ref_input) - grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device="cuda") - ref_out.backward(grad) - - for w_f in [torch.contiguous_format, torch.channels_last]: - for g_f in [torch.contiguous_format, torch.channels_last]: - for input_format in [torch.contiguous_format, torch.channels_last]: - output_format = torch.contiguous_format - # Older versions of CudNN have Channels Last support disabled - if torch.backends.cudnn.version() >= 7603: - if input_format == torch.channels_last: - output_format = torch.channels_last - # This is because we have N111 weight that cannot handle - # the ambiguous memory_format - if w_f == torch.channels_last: - if layer == nn.Conv2d and filter_size * c != 1: - output_format = torch.channels_last - if layer == nn.ConvTranspose2d and filter_size * k != 1: - output_format = torch.channels_last - self._run_conv(layer, device, data, grad, ref_conv, ref_input, - ref_out, input_format, w_f, g_f, output_format) - - @onlyCUDA - @skipCUDAIfRocmVersionLessThan((4, 3)) - @skipCUDAIfNotMiopenSuggestNHWC - @skipCUDAIfCudnnVersionLessThan(7603) - @tf32_on_and_off(0.05) - def test_conv_cudnn_mismatch_memory_format(self, device): - configs = [ - [4, 2, 8, 8, 4, 2], - [4, 1, 8, 8, 4, 2], - [1, 1, 8, 8, 4, 2], - [4, 2, 2, 8, 4, 1], - [4, 2, 1, 8, 4, 1], - [4, 2, 8, 8, 4, 1], - [4, 1, 8, 8, 4, 1], - ] - for n, c, h, w, k, filter_size in configs: - self._test_conv_cudnn_nhwc_nchw(nn.Conv2d, n, c, h, w, k, filter_size, device) - self._test_conv_cudnn_nhwc_nchw(nn.ConvTranspose2d, n, c, h, w, k, filter_size, device) - - # torch.half is erroring out on Windows with CUDA 10.1 + cuDNN 7.6.4 - # returning CUDNN_STATUS_BAD_PARAM - # Disabling that specific test for now [see issue # 33918] - @onlyCUDA - @skipCUDAIfNoCudnn - @dtypes(torch.float, torch.double) - def test_conv_cudnn_nhwc_support(self, device, dtype): - input = torch.randn((1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True) - weight = torch.randn((8, 16, 3, 3), dtype=dtype, device="cuda", requires_grad=True) - weight = weight.to(memory_format=torch.channels_last) - o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) - self.assertTrue(o.is_contiguous(memory_format=torch.channels_last)) - o.sum().backward() - - # Test that faster algorithms used for inference produce the same results - # Validates depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176 - @onlyCPU - @dtypes(torch.float) - def test_conv2d_no_grad(self, device, dtype): - for batch in [1, 2, 3]: - for groups in [1, 2, 4]: - input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device) - m = nn.Conv2d(groups, 8, kernel_size=(3, 3), groups=groups, dtype=dtype, device=device) - with torch.no_grad(): - output_ng = m(input) - output = m(input) - self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5) - - @onlyCUDA - @skipCUDAIfNoCudnn - @dtypes(torch.float, torch.float16) - @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) - def test_cudnn_convolution_relu(self, device, dtype): - for batch, groups, image_size, kernel_size, memory_format in \ - product((1, 2, 3), - (1, 2, 4), - ((1, 1), (8, 8)), - ((1, 1), (3, 3)), - (torch.channels_last, torch.contiguous_format)): - if image_size[0] < kernel_size[0]: - continue - inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) - w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) - conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) - inp = inp.to(memory_format=memory_format) - w = w.to(memory_format=memory_format) - if torch.version.hip: - cudnn_out = torch.miopen_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) - else: - cudnn_out = torch.cudnn_convolution_relu(inp, w, None, (1, 1), (0, 0), (1, 1), 1) - self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) - if tf32_is_not_fp32() and dtype == torch.float: - self.assertEqual(conv2d_out.relu(), cudnn_out, atol=2e-4, rtol=0.006) - else: - self.assertEqual(conv2d_out.relu(), cudnn_out) - - @onlyCUDA - @skipCUDAIfNoCudnn - @dtypes(torch.float, torch.float16) - @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) - def test_cudnn_convolution_add_relu(self, device, dtype): - for batch, groups, image_size, kernel_size, memory_format in \ - product((1, 2, 3), - (1, 2, 4), - ((1, 1), (8, 8)), - ((1, 1), (3, 3)), - (torch.channels_last, torch.contiguous_format)): - if image_size[0] < kernel_size[0]: - continue - inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) - w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) - conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) - alpha = 2.0 - z = torch.randn_like(conv2d_out) - - inp = inp.to(memory_format=memory_format) - w = w.to(memory_format=memory_format) - z = z.to(memory_format=memory_format) - if torch.version.hip: - cudnn_out = torch.miopen_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) - else: - cudnn_out = torch.cudnn_convolution_add_relu(inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1) - - self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) - if tf32_is_not_fp32() and dtype == torch.float: - self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out, atol=3e-4, rtol=0.006) - else: - self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) - - @onlyCUDA - @skipCUDAIfRocm - @skipCUDAIfCudnnVersionLessThan(7603) - def test_convert_conv2d_weight_memory_format(self, device): - input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) - model = nn.Sequential( - nn.Conv2d(8, 4, 3), - nn.BatchNorm2d(4)).to(device).float() - for memory_format in [torch.channels_last, torch.contiguous_format]: - model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) - out = model(input) - self.assertTrue(out.is_contiguous(memory_format=memory_format)) - - model = nn.Sequential( - nn.ConvTranspose2d(8, 4, 3), - nn.BatchNorm2d(4)).to(device).float() - for memory_format in [torch.channels_last, torch.contiguous_format]: - model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) - out = model(input) - self.assertTrue(out.is_contiguous(memory_format=memory_format)) - - def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): - # Test that _convolution_double_backward() outputs the correct grad shapes - # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a - # specific case that was uncovered during the convolution consolidation effort. - # The test can be safely deleted if _convolution_double_backward() is removed. - - input = torch.randn(2, 3, 6, device=device) - weight = torch.randn(3, 3, 3, device=device) - bias = torch.randn(3, device=device) - stride = (2,) - padding = (1,) - dilation = (1,) - transposed = False - output_padding = (0,) - groups = 1 - output = torch.ops.aten.convolution(input, weight, bias, stride, padding, dilation, transposed, - output_padding, groups) - - ggI = torch.randn(input.shape, device=device) - ggW = torch.randn(weight.shape, device=device) - ggB = torch.randn(bias.shape, device=device) - gO = torch.randn(output.shape, device=device) - output_mask = [True, True, True] - grad_grad_output, grad_input, grad_weight = torch.ops.aten._convolution_double_backward( - ggI, ggW, ggB, gO, weight, input, stride, padding, dilation, transposed, - output_padding, groups, output_mask) - - # Make sure the correct shapes are computed. - self.assertEqual(grad_grad_output.shape, gO.shape) - self.assertEqual(grad_input.shape, input.shape) - self.assertEqual(grad_weight.shape, weight.shape) - def test_nll_loss_mismatched_batch(self, device): x = torch.randn((10, 3), requires_grad=True, device=device) # t should have size (10,) @@ -16849,18 +10891,21 @@ def test_nll_loss_mismatched_batch(self, device): with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'): F.nll_loss(x, t) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_out_of_bounds_ignore_index(self, device): x = torch.randn(6, 3, requires_grad=True, device=device) t = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device) for reduction in ['mean', 'none']: F.nll_loss(x, t, ignore_index=255, reduction=reduction).sum().backward() + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_invalid_target_dim(self, device): x = torch.randn((10, 3), device=device) t = torch.zeros((10, 2), dtype=torch.int64, device=device) with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"): F.nll_loss(x, t) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_invalid_weights(self, device): x = torch.randn((10, 3), device=device) t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3) @@ -16918,6 +10963,7 @@ def _nll_loss_helper(self, input_size, reduction, expected, device): output.sum().backward() self.assertEqual(input.grad.size(), input.size()) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_empty_tensor_reduction_none(self, device): self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device) self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device) @@ -16925,6 +10971,7 @@ def test_nll_loss_empty_tensor_reduction_none(self, device): self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device) self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") def test_nll_loss_empty_tensor_reduction_mean(self, device): nan = torch.tensor(float('nan'), device=device) @@ -16934,6 +10981,7 @@ def test_nll_loss_empty_tensor_reduction_mean(self, device): self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device) self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_empty_tensor_reduction_sum(self, device): zero = torch.tensor(0, device=device) self._nll_loss_helper([0, 3], "sum", zero, device) @@ -16943,6 +10991,7 @@ def test_nll_loss_empty_tensor_reduction_sum(self, device): self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device) @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_total_weight_is_zero(self, device): def helper(input_size): @@ -16960,6 +11009,7 @@ def helper(input_size): helper([2, 3, 5, 7, 9]) @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_all_ignored(self, device): def helper(input_size): @@ -16975,6 +11025,7 @@ def helper(input_size): helper([2, 3, 5, 7]) helper([2, 3, 5, 7, 9]) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_nll_loss_byte_target_matches_long(self, device): N, C = 10, 4 input = torch.randn(N, C, device=device, requires_grad=True) @@ -16997,6 +11048,7 @@ def compute_result_and_gradient(reduction, target_dtype): self.assertEqual(result_long, result_byte) self.assertEqual(grad_long, grad_byte) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_loss_prob_target_all_reductions(self, device): # Test with k-dimensional loss. for k in range(5): @@ -17013,6 +11065,7 @@ def test_cross_entropy_loss_prob_target_all_reductions(self, device): input, target, reduction=reduction, weight=w) self.assertEqual(output, output_ref) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_loss_prob_target_unit_weights(self, device): # Test with k-dimensional loss. for k in range(5): @@ -17032,6 +11085,7 @@ def test_cross_entropy_loss_prob_target_unit_weights(self, device): @parametrize_test('reduction', ['none', 'mean', 'sum']) @parametrize_test('weighted', [False, True]) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_loss_prob_target_no_batch_dim(self, device, reduction, weighted): C = 5 input = torch.randn(C, device=device).log_softmax(dim=-1) @@ -17044,6 +11098,7 @@ def test_cross_entropy_loss_prob_target_no_batch_dim(self, device, reduction, we loss_batch = loss_batch.squeeze(0) self.assertEqual(loss_no_batch, loss_batch) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_loss_index_target_unit_weights(self, device): # Test with k-dimensional loss. for k in range(5): @@ -17061,6 +11116,7 @@ def test_cross_entropy_loss_index_target_unit_weights(self, device): output_unit = m_unit(input, target) self.assertEqual(output, output_unit) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_loss_one_hot_target(self, device): # Test with k-dimensional loss. for k in range(5): @@ -17088,6 +11144,7 @@ def test_cross_entropy_loss_one_hot_target(self, device): output_one_hot = m(input, target_one_hot) self.assertEqual(output, output_one_hot) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_label_smoothing_errors(self, device): N, C = 3, 4 input_args = [ @@ -17100,6 +11157,7 @@ def test_cross_entropy_label_smoothing_errors(self, device): r"label_smoothing must be between 0\.0"): loss(*input_arg) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device): N, C = 10, 4 ks = range(5) @@ -17133,6 +11191,7 @@ def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, d self.assertEqual(output_with_prob, output_with_index, rtol=1e-07, atol=1e-05) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_label_smoothing_with_probs(self, device): N, C = 10, 4 ks = range(5) @@ -17159,7 +11218,7 @@ def test_cross_entropy_label_smoothing_with_probs(self, device): self.assertEqual(output_with_smoothing, output_with_manual_smoothing) - + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_cross_entropy_label_smoothing_weight_ignore_indices(self, device): reductions = ['none', 'sum', 'mean'] label_smoothings = [0.05, 0.15] @@ -17244,6 +11303,7 @@ def test_softshrink_negative(self, device): r'lambda must be greater or equal to 0, but found to be -1\.'): m(input) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_fold(self, device): def test_dtype(fn, input, dtype): input = input.detach().clone().to(dtype=dtype).requires_grad_(True) @@ -17269,7 +11329,7 @@ def func(x): if device == 'cpu': test_dtype(func, x, torch.bfloat16) - + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_logsigmoid_out(self, device): # this isn't actually documented, but was broken previously: # https://github.com/pytorch/pytorch/issues/36499 @@ -17406,6 +11466,7 @@ def __init__(self): for p, pe in zip(test_model.parameters(), ref_model.parameters()): self.assertEqual(p.grad.to(devices[0]), pe.grad) + @skipIfMps def test_elu_inplace_overlap(self, device): x = torch.randn((1, 6), dtype=torch.bfloat16, device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -17415,6 +11476,7 @@ def test_elu_inplace_overlap(self, device): # Merge into OpInfo? @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_elu_inplace_with_neg_alpha(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.elu_(a.clone(), alpha=-2) @@ -17427,27 +11489,32 @@ def test_elu_inplace_with_neg_alpha(self, device): b.backward(torch.ones(2, device=device)) @expectedFailureMeta # https://github.com/pytorch/pytorch/issues/54897 + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_hardswish_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.hardswish(x, inplace=True) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_silu_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.silu(x, inplace=True) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_mish_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.mish(x, inplace=True) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_softplus_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.softplus(x, out=x) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_softplus_low_threshold(self, device): # Ensure gradients are computed correctly with a low threshold. model = torch.nn.Softplus(threshold=1).double() @@ -17456,11 +11523,13 @@ def test_softplus_low_threshold(self, device): output = model(input) torch.autograd.gradcheck(model, input) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_softshrink_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.softshrink(x, out=x) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_leaky_relu_inplace_overlap(self, device): x = torch.randn((1, 6), device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -17469,6 +11538,7 @@ def test_leaky_relu_inplace_overlap(self, device): F.leaky_relu_(x) # Merge into OpInfo? + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_leaky_relu_inplace_with_neg_slope(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), -2) @@ -17481,6 +11551,7 @@ def test_leaky_relu_inplace_with_neg_slope(self, device): b.backward(torch.ones(2, device=device)) # Merge into OpInfo? + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_leaky_relu_inplace_with_zero_slope(self, device): a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), 0.0) @@ -17520,6 +11591,7 @@ def test_softshrink(self, device): out = softshrink(x) self.assertEqual(out, expected, atol=1e-2, rtol=0) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_threshold_inplace_overlap(self, device): # Inplace threshold is okay, because it is idempotent x = torch.randn((1, 6), device=device).expand((6, 6)) @@ -17527,6 +11599,7 @@ def test_threshold_inplace_overlap(self, device): F.threshold_(x, 0.5, 0.5) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_triplet_margin_with_distance_loss_default_parity(self, device): # Test for `nn.TripletMarginWithDistanceLoss` and # `F.triplet_margin_with_distance_loss`. Checks @@ -17561,6 +11634,7 @@ def test_triplet_margin_with_distance_loss_default_parity(self, device): (anchor, positive, negative))) @onlyNativeDeviceTypes + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_triplet_margin_with_distance_loss(self, device): # Test for parity between `nn.TripletMarginWithDistanceLoss` and # `F.triplet_margin_with_distance_loss`. @@ -17604,6 +11678,7 @@ def cosine_distance(x, y): self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6) self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6) + @skipMPSIf(True, "the test doesn't work on MPS as double/complex types are not supported") def test_to_complex(self, device): m = nn.Linear(3, 5).to(device) self.assertIs(m, m.to(device)) @@ -17622,6 +11697,7 @@ def test_to_complex(self, device): @skipMeta @dtypes(torch.float32, torch.float64) + @dtypesIfMPS(torch.float32) def test_module_to_empty(self, device, dtype): class MyModule(nn.Module): def __init__(self, in_features, out_features, device=None, dtype=None): @@ -17649,6 +11725,7 @@ def forward(self, x): m(input) @skipMeta + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_skip_init(self, device): torch.manual_seed(1) m_initialized = torch.nn.Linear(5, 1) @@ -17662,6 +11739,7 @@ def test_skip_init(self, device): @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) + @skipMPSIf(True, "the test doesn't work on MPS as double types are not supported") def test_transformerencoderlayer(self, device, dtype): # this is a deterministic test for TransformerEncoderLayer d_model = 4 @@ -17840,49 +11918,38 @@ def perm_fn(x): with cm: _test(batch_first=batch_first, training=training, atol=atol, rtol=rtol) + @onlyCPU @dtypes(torch.double) - @torch.no_grad() - def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(self, device, dtype): - mha = torch.nn.MultiheadAttention(4, 4, batch_first=True, dtype=dtype, device=device).eval() - mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half).to(device)) - query = torch.randn(4, 4, 4, dtype=dtype, device=device) - mha(query, query, query) + def test_transformerencoderlayer_fast_path(self, device, dtype): + """ + Test transformer fast path on CPU with different valid mask types and shapes + """ + d_model = 512 + nhead = 8 + batch_size = 32 + src_len = 10 - @dtypes(torch.double) - @torch.no_grad() - def test_multihead_attn_fast_path_small_test(self, device, dtype): - mha = torch.nn.MultiheadAttention(4, 4, batch_first=True, dtype=dtype, device=device).eval() - query = torch.randn(4, 4, 4, dtype=dtype, device=device) - mha(query, query, query) + model = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, + device=device, dtype=dtype, dropout=0) + model.eval() - @dtypes(torch.double) - @torch.no_grad() - def test_multihead_attn_in_proj_bias_none(self, device, dtype): - mha = torch.nn.MultiheadAttention(2, 2, bias=False, dtype=dtype, device=device) - query = torch.rand(2, 2, 2, dtype=dtype, device=device) - mha(query, query, query) + # Batched inputs + src = torch.rand(batch_size, src_len, 512) - @dtypes(torch.double) - @torch.no_grad() - def test_multihead_attn_in_proj_weight_none(self, device, dtype): - # Setting kdim == vdim == 2 means that vdim != embed_dim - # will cause the logic to use per-input project weights, thereby - # forcing self.in_proj_weight = None - mha = torch.nn.MultiheadAttention(4, 4, vdim=2, kdim=2, dtype=dtype, device=device) - query = torch.rand(4, 4, 4, dtype=dtype, device=device) - key = torch.rand(4, 4, 2, dtype=dtype, device=device) - mha(query, key, key) + # Attention mask of shape (src_len, src_len) + src_mask = torch.zeros(src_len, src_len).to(torch.bool) + with torch.no_grad(): + model(src, src_mask=src_mask) - @onlyCPU - @dtypes(torch.double) - def test_transformerencoderlayer_fast_path(self, device, dtype): - model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, device=device, dtype=dtype) - src = torch.rand(32, 10, 512) - src_mask = torch.zeros(10, 10).to(torch.bool) + # Padding mask of shape (batch_size, src_len) + src_key_padding_mask = torch.zeros(batch_size, src_len).to(torch.bool) + with torch.no_grad(): + model(src, src_key_padding_mask=src_key_padding_mask) - model.eval() + # Provide both masks with torch.no_grad(): - model(src, src_mask) + model(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask) + @dtypes(torch.float) @dtypesIfCUDA(torch.half, torch.float) @@ -17965,291 +12032,6 @@ def perm_fn(x): _test(activation=activation, batch_first=batch_first, training=training) -class TestModuleGlobalHooks(TestCase): - - def tearDown(self): - nn.modules.module._global_backward_hooks = OrderedDict() - nn.modules.module._global_forward_hooks = OrderedDict() - nn.modules.module._global_forward_pre_hooks = OrderedDict() - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_module_global_hooks(self): - module = nn.Sigmoid - - module_1 = module() - module_2 = module() - module_3 = module() - - input = torch.ones(5, 5, requires_grad=True) - - counter = { - 'forwards': 0, - 'backwards': 0 - } - - def fw_hook(inc, h_module, input, output): - self.assertIsInstance(input, tuple) - self.assertTrue(isinstance(output, torch.Tensor)) - self.assertTrue(isinstance(h_module, module)) - self.assertEqual(input[0], torch.ones(5, 5)) - self.assertEqual(output, torch.empty(5, 5).fill_(1 / (1 + 1 / math.e))) - counter['forwards'] += inc - - def bw_hook(inc, h_module, grad_input, grad_output): - self.assertIsInstance(grad_input, tuple) - self.assertIsInstance(grad_output, tuple) - self.assertTrue(isinstance(h_module, module)) - self.assertEqual(grad_output[0], torch.ones(5, 5) * 2) - counter['backwards'] += inc - - test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args)) - - module_1(input) - module_2(input) - module_3(input) - self.assertEqual(counter['forwards'], 3) - self.assertEqual(counter['backwards'], 0) - - test_bwd = nn.modules.module.register_module_backward_hook( - lambda *args: bw_hook(1, *args)) - - output_1 = module_1(input) - output_2 = module_2(input) - output_3 = module_3(input) - self.assertEqual(counter['forwards'], 6) - self.assertEqual(counter['backwards'], 0) - - output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) - output_2.backward(torch.ones(5, 5) * 2, retain_graph=False) - output_3.backward(torch.ones(5, 5) * 2, retain_graph=False) - self.assertEqual(counter['forwards'], 6) - self.assertEqual(counter['backwards'], 3) - - output_1.backward(torch.ones(5, 5) * 2, retain_graph=True) - self.assertEqual(counter['forwards'], 6) - self.assertEqual(counter['backwards'], 4) - - test2_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(2, *args)) - - output = module_1(input) - output = module_2(input) - output = module_3(input) - self.assertEqual(counter['forwards'], 15) - self.assertEqual(counter['backwards'], 4) - - test2_bwd = nn.modules.module.register_module_backward_hook(lambda *args: bw_hook(2, *args)) - - module_1(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 18) - self.assertEqual(counter['backwards'], 7) - - test2_bwd.remove() - - module_2(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 21) - self.assertEqual(counter['backwards'], 8) - - test2_fwd.remove() - - module_3(input).backward(torch.ones(5, 5) * 2) - self.assertEqual(counter['forwards'], 22) - self.assertEqual(counter['backwards'], 9) - - test_fwd.remove() - test_bwd.remove() - - def test_module_global_hook_invalid_outputs(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - - def bw_fail1(self, grad_input, grad_output): - return grad_input[:-1] - - def bw_fail2(self, grad_input, grad_output): - return grad_input + (torch.randn(2, 2),) - - with nn.modules.module.register_module_backward_hook(bw_fail1): - with self.assertRaisesRegex(RuntimeError, 'got 0, but expected 1'): - module(input).sum().backward() - - with nn.modules.module.register_module_backward_hook(bw_fail2): - with self.assertRaisesRegex(RuntimeError, 'got 2, but expected 1'): - module(input).sum().backward() - - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/847") - def test_module_backward_global_hook_writeable(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.sigmoid(input) - - def bw_hook(module, grad_input, grad_output): - for grad in grad_input: - self.assertTrue(isinstance(grad, torch.Tensor)) - for grad in grad_output: - self.assertTrue(isinstance(grad, torch.Tensor)) - return tuple(gi * 2 for gi in grad_input) - - nn.modules.module.register_module_backward_hook(bw_hook) - module(input).backward(torch.ones(5, 5)) - expected_grad = sig_x * (1 - sig_x) * 2 - self.assertEqual(input.grad, expected_grad) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_module_global_forward_preforward_hook_writeable(self): - module = nn.Sigmoid() - input = torch.randn(5, 5, requires_grad=True) - sig_x = torch.sigmoid(input) - - def forward_pre_hook(m, input): - return torch.nn.functional.relu(input[0]) - - def forward_hook(m, input, output): - return -output - - nn.modules.module.register_module_forward_pre_hook(forward_pre_hook) - nn.modules.module.register_module_forward_hook(forward_hook) - output = module(input) - expected_res = -torch.sigmoid(torch.nn.functional.relu(input)) - self.assertEqual(output, expected_res) - output.backward(torch.ones(5, 5) * 2, retain_graph=True) - mask = (input > 0).double() - expected_grad = -sig_x * (1 - sig_x) * 2 * mask - self.assertEqual(input.grad, expected_grad) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_module_forward_preforward_hook_removable(self): - """ - This test is to test when multiple pre-forward hook functions can be - registered successfully and used correctly, if the handle can be removable - during the pre-forward hook function call. - """ - module = nn.Sigmoid() - - def removable_hook(m, input): - nonlocal handle - handle.remove() - return input - - def removable_hook_2(m, input): - nonlocal handle_2 - handle_2.remove() - return input - - handle = module.register_forward_pre_hook(removable_hook) - handle_2 = module.register_forward_pre_hook(removable_hook_2) - - # make sure hook register is successful - self.assertEqual(len(handle.hooks_dict_ref()), 2) - self.assertEqual(len(handle_2.hooks_dict_ref()), 2) - - input = torch.randn(2, 2) - output = module(input) - self.assertEqual(torch.sigmoid(input), output) - - # make sure hook removal is successful - self.assertFalse(handle.id in handle.hooks_dict_ref()) - self.assertFalse(handle_2.id in handle.hooks_dict_ref()) - self.assertEqual(len(handle.hooks_dict_ref()), 0) - self.assertEqual(len(handle_2.hooks_dict_ref()), 0) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_module_forward_forward_hook_removable(self): - """ - This test is to test when multiple forward hook functions can be registered - successfully and used correctly, if the handle can be removable during the - forward hook function call. - """ - module = nn.Sigmoid() - - def removable_hook(m, input, output): - nonlocal handle - handle.remove() - return output - - def removable_hook_2(m, input, output): - nonlocal handle_2 - handle_2.remove() - return output - - handle = module.register_forward_hook(removable_hook) - handle_2 = module.register_forward_hook(removable_hook_2) - - # make sure hook register is successful - self.assertEqual(len(handle.hooks_dict_ref()), 2) - self.assertEqual(len(handle_2.hooks_dict_ref()), 2) - - input = torch.randn(2, 2) - output = module(input) - self.assertEqual(torch.sigmoid(input), output) - - # make sure hook removal is successful - self.assertFalse(handle.id in handle.hooks_dict_ref()) - self.assertFalse(handle_2.id in handle.hooks_dict_ref()) - self.assertEqual(len(handle.hooks_dict_ref()), 0) - self.assertEqual(len(handle_2.hooks_dict_ref()), 0) - - @skipIfTorchDynamo("TorchDynamo does not work well with hooks") - def test_global_and_local_hooks_order(self): - module = nn.Sigmoid() - - global_forward_pre_called = False - local_forward_pre_called = False - global_forward_called = False - local_forward_called = False - global_backward_called = False - local_backward_called = False - - def global_forward_pre_hook(m, input): - nonlocal global_forward_pre_called - self.assertTrue(not local_forward_pre_called) - global_forward_pre_called = True - return input - - def local_forward_pre_hook(m, input): - nonlocal local_forward_pre_called - self.assertTrue(global_forward_pre_called) - local_forward_pre_called = True - return input - - def global_forward_hook(m, input, output): - nonlocal global_forward_called - self.assertTrue(not local_forward_called) - global_forward_called = True - return output - - def local_forward_hook(m, input, output): - nonlocal local_forward_called - self.assertTrue(global_forward_called) - local_forward_called = True - return output - - def global_backward_hook(m, input, output): - nonlocal global_backward_called - self.assertTrue(not local_backward_called) - global_backward_called = True - return input - - def local_backward_hook(m, input, output): - nonlocal local_backward_called - self.assertTrue(global_backward_called) - local_backward_called = True - return input - - input = torch.randn(5, 5, requires_grad=True) - nn.modules.module.register_module_forward_pre_hook(global_forward_pre_hook) - module.register_forward_pre_hook(local_forward_pre_hook) - nn.modules.module.register_module_forward_hook(global_forward_hook) - module.register_forward_hook(local_forward_hook) - nn.modules.module.register_module_backward_hook(global_backward_hook) - module.register_backward_hook(local_backward_hook) - - output = module(input) - self.assertTrue(local_forward_called and local_forward_pre_called and global_forward_called and global_forward_pre_called) - - output.backward(torch.ones(5, 5), retain_graph=True) - self.assertTrue(local_backward_called and global_backward_called) - - class TestFunctionalPickle(TestCase): # issue gh-38137 @@ -18257,210 +12039,33 @@ def test_pickle_softsign(self): # Make sure it does not throw an exception s = pickle.dumps(F.softsign) -def _hook_to_pickle(*args, **kwargs): - pass - -class TestStateDictHooks(TestCase): - - def test_load_state_dict_pre_hook(self): - - m = nn.Linear(10, 10) - m_state_dict = m.state_dict() - - m_load = nn.Linear(10, 10) - - hook_called = 0 - - def hook_without_module(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self.assertEqual(m_state_dict, state_dict) - nonlocal hook_called - hook_called += 1 - - def hook_with_module(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self.assertEqual(m_state_dict, state_dict) - self.assertTrue(m_load is module) - nonlocal hook_called - hook_called += 1 - - hook_called = 0 - m_load._register_load_state_dict_pre_hook(hook_without_module) - m_load.load_state_dict(m_state_dict) - self.assertEqual(1, hook_called) - - hook_called = 0 - m_load._register_load_state_dict_pre_hook(hook_with_module, True) - m_load.load_state_dict(m_state_dict) - self.assertEqual(2, hook_called) - - def test_no_extra_ref_to_module(self): - try: - gc.disable() - m = nn.Linear(10, 10) - - m._register_load_state_dict_pre_hook(_hook_to_pickle, True) - weak_m = weakref.ref(m) - del m - - self.assertEqual(weak_m(), None) - finally: - gc.enable() - - def test_pickled_hook(self): - m = nn.Linear(10, 10) - m._register_load_state_dict_pre_hook(_hook_to_pickle, True) - pickle.loads(pickle.dumps(m)) - - def test_load_state_dict_module_pre_hook(self): - hook_called = 0 - - # Test with module instance method as hook - class MyModule(nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.foo = torch.nn.Parameter(torch.rand(10)) - - def my_pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - assert [] == error_msgs - assert [] == unexpected_keys - assert [] == missing_keys - assert strict - nonlocal hook_called - hook_called += 1 - - def my_pre_load_hook_with_module( - self, - module, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - assert [] == error_msgs - assert [] == unexpected_keys - assert [] == missing_keys - assert strict - assert self is module - nonlocal hook_called - hook_called += 1 - - # Test that hooks registered on a submodule are also called - # appropriately, i.e. with the submodule as module argument in - # my_pre_load_hook_with_module. - class MyModuleContainer(nn.Module): - def __init__(self, mod): - super().__init__() - self.mod = mod - - for ctor in [MyModuleContainer, lambda x: x]: - m = ctor(MyModule()) - state_dict = m.state_dict() - if isinstance(m, MyModuleContainer): - mod = m.mod - else: - mod = m - - hook_called = 0 - mod._register_load_state_dict_pre_hook( - mod.my_pre_load_hook - ) - m.load_state_dict(state_dict) - self.assertEqual(1, hook_called) - - hook_called = 0 - mod._register_load_state_dict_pre_hook( - mod.my_pre_load_hook_with_module, True - ) - m.load_state_dict(state_dict) - self.assertEqual(2, hook_called) - - def test_load_state_dict_post_hook(self): - hook_called = 0 - - class MyModule(nn.Module): - def __init__(self): - super(MyModule, self).__init__() - self.foo = torch.nn.Parameter(torch.rand(10)) - - def my_post_load_hook(self, module, incompatible_keys): - assert module is self - nonlocal hook_called - incompatible_keys.missing_keys.append("foo") - incompatible_keys.unexpected_keys.append("bar") - hook_called += 1 - - nested = MyModule() - wrapped = nn.ModuleList([nested]) - handle = nested.register_load_state_dict_post_hook( - nested.my_post_load_hook, - ) - # Hook must be called even if it is wrapped - ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) - self.assertEqual(hook_called, 1) - # Ensure that the hook modified missing_keys and unexpected_keys - missing = ret.missing_keys - unexpected = ret.unexpected_keys - self.assertEqual(missing, ["foo"]) - self.assertEqual(unexpected, ["bar"]) - # When called with strict=True, the error raised should mention the - # missing and unexpected keys the hook added. - with self.assertRaisesRegex(RuntimeError, "foo.*\n.*bar"): - wrapped.load_state_dict(wrapped.state_dict(), strict=True) - self.assertEqual(hook_called, 2) - # Removing the hook via handle.remove() should cause it not to - # fire anymore. - handle.remove() - # Hook did not run so it should not have added any keys - ret = wrapped.load_state_dict(wrapped.state_dict(), strict=False) - self.assertEqual(ret.missing_keys, []) - self.assertEqual(ret.unexpected_keys, []) - # hook_called should not have been incremented - self.assertEqual(hook_called, 2) - - def load_hook_clear_incompatible(module, incompatible_keys): - incompatible_keys.missing_keys.clear() - incompatible_keys.unexpected_keys.clear() - - nested.register_load_state_dict_post_hook(load_hook_clear_incompatible) - state_dict = wrapped.state_dict() - state_dict["extra"] = torch.ones(1) - # load state_dict with strict=True should not throw. - ret = wrapped.load_state_dict(state_dict, strict=True) - # explicitly ensure that the post hook clearned out incompatible_keys - self.assertEqual([], ret.missing_keys) - self.assertEqual([], ret.unexpected_keys) - - @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows") - def test_load_state_dict_post_hook_backward_compatibility(self): - def my_post_load_hook(mod, _): - nonlocal called - called = True - - for m in [nn.Softmin(10), nn.Softmax(10), nn.LogSoftmax(10)]: - called = False - sd = deepcopy(m.state_dict()) - self.assertTrue(hasattr(m, '_load_state_dict_post_hooks')) - # Simulate an older model that did not have this attr - delattr(m, '_load_state_dict_post_hooks') - # Save and load, and ensure that load_state_dict works (without proper - # BC we would run into errors because this attribute would be expected). - # In particular, Softmax runs into the issue described here: - # https://github.com/pytorch/pytorch/issues/77280 - with NamedTemporaryFile() as f: - # Note that torch.save / torch.load is not recommended to save/load - # modules. - torch.save(m, f.name) - m = torch.load(f.name) - m.load_state_dict(sd) - self.assertFalse(called) - - # Ensure hooks can be registered and called. - m.register_load_state_dict_post_hook(my_post_load_hook) - m.load_state_dict(sd) - self.assertTrue(called) +class TestFusionUtils(TestCase): + def test_fuse_conv_bn_requires_grad(self): + conv = torch.nn.Conv2d(3, 3, 3) + bn = torch.nn.BatchNorm2d(3) + cases = itertools.product([True, False], [True, False]) + for w_rg, b_rg in cases: + conv.weight.requires_grad = w_rg + conv.bias.requires_grad = b_rg + weight, bias = \ + fuse_conv_bn_weights(conv.weight, conv.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + self.assertEqual(weight.requires_grad, w_rg) + self.assertEqual(bias.requires_grad, b_rg) + + def test_fuse_linear_bn_requires_grad(self): + linear = torch.nn.Linear(3, 3) + bn = torch.nn.BatchNorm1d(3) + cases = itertools.product([True, False], [True, False]) + for w_rg, b_rg in cases: + linear.weight.requires_grad = w_rg + linear.bias.requires_grad = b_rg + weight, bias = \ + fuse_linear_bn_weights(linear.weight, linear.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + self.assertEqual(weight.requires_grad, w_rg) + self.assertEqual(bias.requires_grad, b_rg) instantiate_device_type_tests(TestNNDeviceType, globals()) instantiate_parametrized_tests(TestNN) diff --git a/test/test_nvfuser_dynamo.py b/test/test_nvfuser_dynamo.py new file mode 100644 index 0000000000000..e59ead80fe13c --- /dev/null +++ b/test/test_nvfuser_dynamo.py @@ -0,0 +1,148 @@ +# Owner(s): ["module: nvfuser"] + +import unittest +import warnings +from functools import partial + +import torch +import torch._dynamo as torchdynamo +from torch.testing import make_tensor +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + run_tests, + skipIfTorchDynamo, + TEST_WITH_ROCM, + TestCase, +) +from torch.testing._internal.jit_utils import RUN_CUDA + +RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM + + +def is_pre_volta(): + if not RUN_NVFUSER: + return False + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + return prop.major < 7 + + +def is_networkx_available(): + try: + import networkx # noqa: F401 + + return True + except ImportError: + return False + + +@skipIfTorchDynamo("Not a suitable test for TorchDynamo") +@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows") +@unittest.skipIf(not RUN_NVFUSER, "requires CUDA") +@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.") +class TestNvFuserDynamo(TestCase): + def test_basic(self): + input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32) + input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32) + + @torchdynamo.optimize("nvprims_nvfuser") + def func(a, b): + return a.sin() + b.cos() + + # No warnings and no errors + with warnings.catch_warnings(record=True) as w: + nvfuser_result = func(input1, input2) + self.assertEqual(len(w), 0) + eager_result = func.__wrapped__(input1, input2) + self.assertEqual(eager_result, nvfuser_result) + + @unittest.skipIf(not is_networkx_available(), "networkx not available") + def test_min_cut(self): + from functorch.compile import default_partition + from torch._dynamo.optimizations.training import nvprims_fw_bw_partition_fn + + def get_fw_bw_graph(f, inps, partitioner): + from functorch.compile import aot_function + + # Helper functions are taken from functorch/test_aotdispatch.py + def extract_graph(fx_g, _, graph_cell): + graph_cell[0] = fx_g + return fx_g + + fw_graph_cell = [None] + bw_graph_cell = [None] + aot_function( + f, + fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), + bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), + partition_fn=partitioner, + )(*inps).sum().backward() + return (fw_graph_cell[0], bw_graph_cell[0]) + + def get_ins_outs(fx_g): + ins = [] + outs = [] + for n in fx_g.graph.nodes: + if n.op == "placeholder": + ins.append(n) + elif n.op == "output": + outs = tuple(n.args[0]) + return ins, outs + + def get_num_ins_outs(fx_g): + return tuple(len(i) for i in get_ins_outs(fx_g)) + + def func(x): + return x * x * x + + input1 = make_tensor( + (3,), device="cpu", dtype=torch.float32, requires_grad=True + ) + fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition) + self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) + self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) + + input1 = make_tensor( + (3,), device="cpu", dtype=torch.float32, requires_grad=True + ) + fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn) + self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) + self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) + + def test_batch_norm_implicit_dtype_promotion(self): + input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32) + input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32) + w = make_tensor((3), device="cuda", dtype=torch.float32) + b = make_tensor((3), device="cuda", dtype=torch.float32) + + @torchdynamo.optimize("nvprims_nvfuser") + def func(mat1, mat2, w, b): + o = torch.matmul(mat1, mat2) + return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True) + + # No warnings and no errors + with torch.cuda.amp.autocast(): + with warnings.catch_warnings(record=True) as warning: + nvfuser_result = func(input1, input2, w, b) + self.assertEqual(len(warning), 0) + eager_result = func.__wrapped__(input1, input2, w, b) + self.assertEqual(eager_result, nvfuser_result) + + def test_dtype_correctness(self): + input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16) + + @torchdynamo.optimize("nvprims_nvfuser") + def func(a): + tmp = a + 1.0 + # nvfuser would promote output to fp32 in math, FusionDefinition should cast output dtype back + return torch.where(tmp > 0, tmp, 0.0) + + # No warnings and no errors + with warnings.catch_warnings(record=True) as w: + nvfuser_result = func(input1) + self.assertEqual(len(w), 0) + eager_result = func.__wrapped__(input1) + self.assertEqual(eager_result, nvfuser_result) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_nvfuser_frontend.py b/test/test_nvfuser_frontend.py index 28c5894a002c9..9974eb29c7271 100644 --- a/test/test_nvfuser_frontend.py +++ b/test/test_nvfuser_frontend.py @@ -189,6 +189,24 @@ def test_broadcast_mixing(self) : eager_out = refs.add(input1, prims.broadcast_in_dim(input2, [3, 3], [0])) self.assertEqual(eager_out, nvf_out) + def test_ops_broadcast(self) : + fs = Fusion() + with FusionDefinition(fs) as fd : + t0 = fd.define_tensor(1) + t1 = fd.define_tensor(3) + + t0_b = fd.ops.broadcast(t0, [True, False, True]) + t2 = fd.ops.add(t0_b, t1) + + fd.add_output(t2) + + input1 = torch.randn(3, device='cuda') + input2 = torch.randn(2, 3, 4, device='cuda') + + nvf_out = fs.execute([input1, input2])[0] + eager_out = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2) + self.assertEqual(eager_out, nvf_out) + def test_prim_layer_norm_fwd(self) : def primitive_definition( inputs: torch.Tensor, diff --git a/test/test_ops.py b/test/test_ops.py index c63de0a4778d3..b78d2d8e096e9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -7,11 +7,13 @@ import itertools import torch import contextlib +import re +import os + from collections import defaultdict from importlib import import_module from torch.utils._pytree import tree_map from typing import Dict - from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( floating_and_complex_types_and, @@ -26,6 +28,7 @@ IS_SANDCASTLE, clone_input_helper, IS_CI, + set_default_dtype, suppress_warnings, noncontiguous_like, TEST_WITH_ASAN, @@ -35,7 +38,7 @@ IS_FBCODE, first_sample, parametrize, - skipIfSlowGradcheckEnv, + skipIfTorchInductor, slowTest, ) from torch.testing._internal.common_methods_invocations import ( @@ -56,7 +59,6 @@ onlyCPU, onlyNativeDeviceTypes, OpDTypes, - skipCUDAIfRocm, skipMeta, ) from torch._subclasses.fake_tensor import ( @@ -108,7 +110,6 @@ # Tests that apply to all operators and aren't related to any particular # system -@skipIfSlowGradcheckEnv class TestCommon(TestCase): exact_dtype = True @@ -151,6 +152,90 @@ def test_multiple_devices(self, devices, dtype, op): "Skipped! Only supports single tensor or iterable of tensor outputs." ) + def test_pointwise_tag_coverage(self): + + pytorch_dir = os.path.abspath(__file__ + "/../../") + files = [ + "aten/src/ATen/native/UnaryOps.cpp", + "aten/src/ATen/native/BinaryOps.cpp", + "aten/src/ATen/native/PointwiseOps.cpp", + "aten/src/ATen/native/TensorCompare.cpp", + ] + + allowed_functions = ( + # reduction version of these operators + "aten.max.default", + "aten.max.dim", + "aten.max.dim_max", + "aten.max.names_dim", + "aten.max.names_dim_max", + "aten.max.unary_out", + "aten.min.default", + "aten.min.dim", + "aten.min.dim_min", + "aten.min.names_dim", + "aten.min.names_dim_min", + # not pointwise + "aten.isin.Tensor_Tensor", + "aten.isin.Tensor_Tensor_out", + "aten.isin.Tensor_Scalar", + "aten.isin.Tensor_Scalar_out", + "aten.isin.Scalar_Tensor", + "aten.isin.Scalar_Tensor_out", + "aten.mode.default", + "aten.mode.dimname", + "aten.mode.dimname_out", + "aten.mode.values", + ) + + regex = re.compile(r"DEFINE_DISPATCH\(.*_stub") + + def get_opoverloadpacket_from_dispatch(kernel): + if hasattr(torch.ops.aten, kernel): + return kernel + if hasattr(torch.ops.aten, f"__{kernel}__"): + return f"__{kernel}__" + if hasattr(torch.ops.aten, f"special_{kernel}"): + return f"special_{kernel}" + if "_" in kernel: + kernel_split = kernel.split("_") + new_kernel = "_".join(kernel_split[:-1]) + if hasattr(torch.ops.aten, new_kernel): + return new_kernel + + # could not find op from kernel dispatch string + self.assertTrue(False) + + for file_name in files: + with open(os.path.join(pytorch_dir, file_name), "r") as f: + lines = f.read() + matches = regex.findall(lines) + for match in matches: + kernel = match[len("DEFINE_DISPATCH("):-len("_stub")] + + # no op definition for it, but defined with DEFINE_DISPATCH ? + if kernel == "trigamma": + continue + + kernel = get_opoverloadpacket_from_dispatch(kernel) + overloadpacket = getattr(torch.ops.aten, kernel) + + for overload_name in overloadpacket.overloads(): + overload = getattr(overloadpacket, overload_name) + + if not torch._C._dispatch_has_kernel(overload.name()): + continue + + # TODO: tags are not propagated to generated overload, + # and there's no way of specifying them + if torch.Tag.generated in overload.tags: + continue + + if str(overload) in allowed_functions: + continue + + self.assertTrue(torch.Tag.pointwise in overload.tags) + # Tests that the function and its (ndarray-accepting) reference produce the same # values on the tensors from sample_inputs func for the corresponding op. # This test runs in double and complex double precision because @@ -161,16 +246,12 @@ def test_multiple_devices(self, devices, dtype, op): @suppress_warnings @ops(_ref_test_ops, allowed_dtypes=(torch.float64, torch.long, torch.complex128)) def test_numpy_ref(self, device, dtype, op): - try: - # Sets the default dtype to NumPy's default dtype of double - cur_default = torch.get_default_dtype() - torch.set_default_dtype(torch.double) + # Sets the default dtype to NumPy's default dtype of double + with set_default_dtype(torch.double): for sample_input in op.reference_inputs(device, dtype): self.compare_with_reference( op, op.ref, sample_input, exact_dtype=(dtype is not torch.long) ) - finally: - torch.set_default_dtype(cur_default) # Tests that the cpu and gpu results are consistent @onlyCUDA @@ -184,7 +265,7 @@ def to_cpu(arg): return arg.to(device='cpu') return arg - samples = op.sample_inputs(device, dtype) + samples = op.reference_inputs(device, dtype) for sample in samples: cpu_sample = sample.transform(to_cpu) @@ -209,6 +290,7 @@ def to_cpu(arg): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyNativeDeviceTypes @ops(python_ref_db) + @skipIfTorchInductor("Takes too long for inductor") def test_python_ref_meta(self, device, dtype, op): with FakeTensorMode() as mode: pass @@ -374,6 +456,7 @@ def _distance(a, b): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyNativeDeviceTypes @ops(python_ref_db) + @skipIfTorchInductor("Takes too long for inductor") def test_python_ref(self, device, dtype, op): # In this test, primTorch refs call into the refs namespace # For example, a ref with torch.foo in it will calls refs.foo instead @@ -386,6 +469,7 @@ def test_python_ref(self, device, dtype, op): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyNativeDeviceTypes @ops(python_ref_db) + @skipIfTorchInductor("Takes too long for inductor") def test_python_ref_torch_fallback(self, device, dtype, op): # In this test, refs call into the torch namespace (after the initial invocation) # For example, a ref with torch.foo in it will call torch.foo instead of refs.foo @@ -394,9 +478,9 @@ def test_python_ref_torch_fallback(self, device, dtype, op): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyCUDA - @skipCUDAIfRocm @ops(python_ref_db) @parametrize('executor', ['aten', 'nvfuser']) + @skipIfTorchInductor("Takes too long for inductor") def test_python_ref_executor(self, device, dtype, op, executor): # TODO: Not all dtypes are supported with nvfuser from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map @@ -417,9 +501,10 @@ def test_python_ref_executor(self, device, dtype, op, executor): # skip zero-dim tensors for some composites of reduction operations and view skip_zero_dim_ops = [ - "_refs.softmax", "_refs.logsumexp", "_refs.log_softmax", + "_refs.native_group_norm", + "_refs.softmax", "_refs.sum_to_size", "ops.nvprims.view", ] @@ -452,11 +537,13 @@ def test_errors(self, device, op): for ei in error_inputs: si = ei.sample_input with self.assertRaisesRegex(ei.error_type, ei.error_regex): - op(si.input, *si.args, **si.kwargs) + out = op(si.input, *si.args, **si.kwargs) + self.assertFalse(isinstance(out, type(NotImplemented))) @skipMeta @onlyNativeDeviceTypes @ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) + @skipIfTorchInductor("Takes too long for inductor") def test_python_ref_errors(self, device, op): mode = FakeTensorMode() with mode: @@ -471,8 +558,7 @@ def _to_tensormeta(x): for ei in error_inputs: si = ei.sample_input meta_sample = si.transform(_to_tensormeta) - # TODO: match strings - with self.assertRaisesRegex(ei.error_type, ""): + with self.assertRaisesRegex(ei.error_type, ei.error_regex): op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) # Tests that the function produces the same result when called with @@ -500,11 +586,6 @@ def test_noncontiguous_samples(self, device, dtype, op): noncontig_sample.kwargs, ) - # Verifies sample input tensors should have no grad or history - sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0] - assert sample_tensor.grad is None - assert sample_tensor.grad_fn is None - # validates forward expected = op(t_inp, *t_args, **t_kwargs) actual = op(n_inp, *n_args, **n_kwargs) @@ -1069,6 +1150,7 @@ def _test_inplace_preserve_storage(samples, variants): # Reference testing for operations in complex32 against complex64. # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. @ops(op_db, allowed_dtypes=(torch.complex32,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_complex_half_reference_testing(self, device, dtype, op): if not op.supports_dtype(torch.complex32, device): unittest.skip("Does not support complex32") @@ -1096,8 +1178,10 @@ def test_complex_half_reference_testing(self, device, dtype, op): # `cfloat` input -> `float` output self.assertEqual(actual, expected, exact_dtype=False) + @ops(op_db, allowed_dtypes=(torch.bool,)) @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") + @skipIfTorchInductor("Inductor does not support view with dtype yet") def test_non_standard_bool_values(self, device, dtype, op): # Test boolean values other than 0x00 and 0x01 (gh-54789) def convert_boolean_tensors(x): @@ -1383,7 +1467,6 @@ def test_forward_ad(self, device, dtype, op): op.get_op(), args, kwargs, op.gradcheck_wrapper, self.assertEqual) -@skipIfSlowGradcheckEnv class TestMathBits(TestCase): # Tests that # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors @@ -1498,6 +1581,7 @@ def clone_and_perform_view(input, **kwargs): self.assertEqual(tensor.grad, cloned1_tensor.grad) @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_conj_view(self, device, dtype, op): if not op.test_conjugated_samples: self.skipTest("Operation doesn't support conjugated inputs.") @@ -1520,6 +1604,7 @@ def test_conj_view(self, device, dtype, op): ) @ops(ops_and_refs, allowed_dtypes=(torch.double,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_neg_view(self, device, dtype, op): if not op.test_neg_view: self.skipTest("Operation not tested with tensors with negative bit.") @@ -1539,6 +1624,7 @@ def test_neg_view(self, device, dtype, op): ) @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_neg_conj_view(self, device, dtype, op): if not op.test_neg_view: self.skipTest("Operation not tested with tensors with negative bit.") @@ -1575,7 +1661,7 @@ def is_bit_set(x): def check_inplace_view(func, input, rs, input_size, input_strides): if func is None: return - # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm.out + # TODO: extend this test to test ops with multiple outputs and ops like native_batch_norm(_legit).out # which mutate not necessarily the first input. if isinstance(rs, torch.Tensor) and rs is input: unequal_size = rs.size() != input_size @@ -1593,7 +1679,6 @@ def check_inplace_view(func, input, rs, input_size, input_strides): # A mode that when enabled runs correctness checks to ensure # that operators have expected tags based on their input and # ouput tensor properties -@skipIfSlowGradcheckEnv class TestTagsMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): if isinstance(args[0], torch.Tensor): @@ -1606,7 +1691,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return rs # Test to verify the correctness for tags in `tags.yaml`, also available for access through `torch.Tags` -@skipIfSlowGradcheckEnv class TestTags(TestCase): @onlyCPU @ops(ops_and_refs, dtypes=OpDTypes.any_one) @@ -1626,7 +1710,6 @@ def test_tags(self, device, dtype, op): check_inplace_view(opoverloadpacket, input, rs, old_size, old_stride) -@skipIfSlowGradcheckEnv class TestRefsOpsInfo(TestCase): import_paths = ["_refs", "_refs.special", "_refs.nn.functional", "_refs.fft", "_refs._conversions"] @@ -1662,13 +1745,13 @@ class TestRefsOpsInfo(TestCase): '_refs.index_add_', '_refs.index_copy_', '_refs.index_fill_', + '_refs.native_group_norm', } not_in_decomp_table = { # duplicated in _decomp and _refs - '_refs.nn.functional.elu', + '_refs.nn.functional.group_norm', '_refs.nn.functional.mse_loss', - '_refs.var', '_refs.rsub', # duplicated due to efficiency concerns of the ref vs the decomp '_refs.index_add_', @@ -1744,7 +1827,6 @@ class TestRefsOpsInfo(TestCase): '_refs.unflatten', '_refs.sum_to_size', # ref implementation missing kwargs - '_refs.full', # missing "layout" '_refs.full_like', # missing "layout" '_refs.ones_like', # missing "layout" '_refs.round', # missing "decimals" @@ -1762,9 +1844,13 @@ class TestRefsOpsInfo(TestCase): @parametrize("op", ref_ops_names) def test_refs_are_in_python_ref_db(self, op): + inplace = op[-1] == "_" if op in self.skip_ref_ops: raise unittest.SkipTest(f"{op} does not have an entry in python_ref_db") - self.assertIn(op, self.ref_db_names) + elif inplace: + self.assertNotIn(op, self.ref_db_names, msg=f"{op} is an in-place operation and should not have an OpInfo") + else: + self.assertIn(op, self.ref_db_names) @parametrize("op", ref_ops_names) def test_refs_are_in_decomp_table(self, op): @@ -1891,8 +1977,6 @@ def test_refs_are_in_decomp_table(self, op): "svd_lowrank", "sgn", "cholesky", - "linalg.eigh", - "symeig", } fake_backward_xfails = {xfail(stride_skip) for stride_skip in fake_backward_xfails} | { @@ -1911,7 +1995,6 @@ def test_refs_are_in_decomp_table(self, op): skip('pinverse'), } -@skipIfSlowGradcheckEnv class TestFakeTensor(TestCase): def _test_fake_helper(self, device, dtype, op, context): name = op.name @@ -1977,6 +2060,59 @@ def map_to_fake(e): except torch._subclasses.fake_tensor.DataDependentOutputException: self.assertTrue(name in data_dependent_op_tests) + @ops(op_db, dtypes=OpDTypes.any_one) + def test_pointwise_ops(self, device, dtype, op): + name = op.name + if op.variant_test_name: + name += "." + op.variant_test_name + if name in fake_skips or "sparse" in name or "jiterator" in name: + self.skipTest("Skip failing test") + + test_self = self + + class TestPointwiseMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + out = func(*args, **kwargs) + + if torch.Tag.pointwise in func.tags: + shapes = [] + for inp in tree_flatten((args, kwargs)): + if isinstance(inp, torch.Tensor): + shapes.append(inp.shape) + + out_shape = torch._refs._broadcast_shapes(*shapes) + + for out_elem in tree_flatten(out): + if isinstance(out_elem, torch.Tensor): + test_self.assertEqual(out_elem.shape, out_shape) + + return out + + samples = op.sample_inputs(device, dtype, requires_grad=False) + for sample in samples: + mode = FakeTensorMode(throw_on_data_dependent_ops=True) + + def map_to_fake(e): + if isinstance(e, torch.Tensor): + return mode.from_tensor(e) + else: + return e + + input = tree_map(map_to_fake, sample.input) + args = tree_map(map_to_fake, sample.args) + kwargs = tree_map(map_to_fake, sample.kwargs) + + try: + op(input, *args, **kwargs) + except Exception as e: + continue + + with TestPointwiseMode(): + with mode: + op(input, *args, **kwargs) + @ops(op_db, dtypes=OpDTypes.any_one) def test_fake(self, device, dtype, op): self._test_fake_helper(device, dtype, op, contextlib.nullcontext) diff --git a/test/test_ops_fwd_gradients.py b/test/test_ops_fwd_gradients.py new file mode 100644 index 0000000000000..4b7b1c785d5f0 --- /dev/null +++ b/test/test_ops_fwd_gradients.py @@ -0,0 +1,76 @@ +# Owner(s): ["module: unknown"] + +from functools import partial +import torch + +from torch.testing._internal.common_utils import ( + TestGradients, run_tests, skipIfTorchInductor, IS_MACOS) +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_device_type import \ + (instantiate_device_type_tests, ops, OpDTypes) + +# TODO: fixme https://github.com/pytorch/pytorch/issues/68972 +torch.set_default_dtype(torch.float32) + +# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033 +# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The +# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33 +if IS_MACOS: + torch.set_num_threads(1) + +# gradcheck requires double precision +_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, + allowed_dtypes=[torch.double, torch.cdouble]) + +class TestFwdGradients(TestGradients): + # Test that forward-over-reverse gradgrad is computed correctly + @_gradcheck_ops(op_db) + def test_fn_fwgrad_bwgrad(self, device, dtype, op): + self._skip_helper(op, device, dtype) + + if op.supports_fwgrad_bwgrad: + self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad") + else: + err_msg = r"Trying to use forward AD with .* that does not support it" + hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not " + "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.") + with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg): + self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad") + + + def _forward_grad_helper(self, device, dtype, op, variant, is_inplace): + # TODO: clean up how attributes are passed to gradcheck from OpInfos + def call_grad_test_helper(): + check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or + (op.check_inplace_batched_forward_grad and is_inplace)) + self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False, + check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad) + if op.supports_forward_ad: + call_grad_test_helper() + else: + err_msg = r"Trying to use forward AD with .* that does not support it" + hint_msg = ("Running forward AD for an OP that has does not support it did not " + "raise any error. If your op supports forward AD, you should set supports_forward_ad=True") + with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg): + call_grad_test_helper() + + @_gradcheck_ops(op_db) + def test_forward_mode_AD(self, device, dtype, op): + self._skip_helper(op, device, dtype) + + self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False) + + @_gradcheck_ops(op_db) + @skipIfTorchInductor("to be fixed") + def test_inplace_forward_mode_AD(self, device, dtype, op): + self._skip_helper(op, device, dtype) + + if not op.inplace_variant or not op.supports_inplace_autograd: + self.skipTest("Skipped! Operation does not support inplace autograd.") + + self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True) + +instantiate_device_type_tests(TestFwdGradients, globals()) + +if __name__ == '__main__': + run_tests() diff --git a/test/test_ops_gradients.py b/test/test_ops_gradients.py index 0411f043df9c0..b4401af543d01 100644 --- a/test/test_ops_gradients.py +++ b/test/test_ops_gradients.py @@ -1,11 +1,9 @@ # Owner(s): ["module: unknown"] -from functools import partial, wraps -from itertools import chain +from functools import partial import torch -from torch.testing._internal.common_utils import \ - (TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env) +from torch.testing._internal.common_utils import TestGradients, run_tests from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) @@ -17,137 +15,7 @@ _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]) -class TestGradients(TestCase): - exact_dtype = True - - # Copies inputs to inplace operations to avoid inplace modifications - # to leaves requiring gradient - def _get_safe_inplace(self, inplace_variant): - @wraps(inplace_variant) - def _fn(t, *args, **kwargs): - return inplace_variant(t.clone(), *args, **kwargs) - - return _fn - - def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True, - check_batched_grad=None, check_batched_forward_grad=False): - assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad') - # NB: check_backward_ad does not affect gradgradcheck (always True) - if variant is None: - self.skipTest("Skipped! Variant not implemented.") - if not op.supports_dtype(dtype, torch.device(device).type): - self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}") - - def is_inplace(variant): - if hasattr(variant, "__wrapped__"): - return variant.__wrapped__ is op.get_inplace() - return variant is op.get_inplace() - - include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex - - samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs, - small_inputs_only=is_slow_gradcheck_env()) - - for sample in samples: - if sample.broadcasts_input and is_inplace(variant): - continue - - # Gradcheck expects tensors as its input, but autograd actually supports tensorlists - # and tensors passed as kwargs. The following creates a function that accepts just - # the tensors that require grad as varargs, and then recomposes them back into the - # original input. - - # Creates gradcheck inputs by identifying tensors requiring grad - all_args = None - if is_iterable_of_tensors(sample.input): - all_args = chain(sample.input, sample.args, sample.kwargs.values()) - else: - all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) - gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) - - def _input_recomposition_helper(inputs, inp, input_idx): - if is_iterable_of_tensors(inp): - tensor_list = [] - for x in inp: - if isinstance(x, torch.Tensor) and x.requires_grad: - tensor_list.append(inputs[input_idx]) - input_idx = input_idx + 1 - else: - tensor_list.append(x) - return tensor_list, input_idx - elif isinstance(inp, torch.Tensor) and inp.requires_grad: - return inputs[input_idx], input_idx + 1 - else: - return inp, input_idx - - def fn(*inputs): - # Puts inputs back into sample properly - positional_args = [] - input_idx = 0 - inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx) - positional_args.append(inp) - - for x in sample.args: - inp, input_idx = _input_recomposition_helper(inputs, x, input_idx) - positional_args.append(inp) - - # Recreates kwargs - kwargs = {} - for k, v in sample.kwargs.items(): - inp, input_idx = _input_recomposition_helper(inputs, v, input_idx) - kwargs[k] = inp - - output = op.gradcheck_wrapper(variant, *positional_args, **kwargs) - if sample.output_process_fn_grad is not None: - return sample.output_process_fn_grad(output) - return output - - if check == 'gradcheck': - if check_batched_grad is None: - check_batched_grad = op.check_batched_grad - self.assertTrue(gradcheck(fn, gradcheck_args, - check_batched_grad=check_batched_grad, - check_grad_dtypes=True, - nondet_tol=op.gradcheck_nondet_tol, - fast_mode=op.gradcheck_fast_mode, - check_forward_ad=check_forward_ad, - check_backward_ad=check_backward_ad, - check_undefined_grad=True, - check_batched_forward_grad=check_batched_forward_grad)) - elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'): # gradgrad check - self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck") - for gen_non_contig_grad_outputs in (False, True): - kwargs = { - "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs, - "check_batched_grad": op.check_batched_gradgrad, - "check_grad_dtypes": True, - "nondet_tol": op.gradcheck_nondet_tol, - "fast_mode": op.gradcheck_fast_mode - } - if check == "fwgrad_bwgrad": - kwargs["check_fwd_over_rev"] = True - kwargs["check_rev_over_rev"] = False - kwargs["check_batched_grad"] = False - kwargs["check_undefined_grad"] = False - - self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs)) - else: - self.assertTrue(False, msg="Unknown check requested!") - - def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True, - check_batched_grad=None, check_batched_forward_grad=False): - return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad, - check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad, - check_batched_forward_grad=check_batched_forward_grad) - - def _skip_helper(self, op, device, dtype): - if dtype not in op.supported_backward_dtypes(torch.device(device).type): - self.skipTest("Skipped! Op doesn't support autograd for this dtype.") - if not op.supports_autograd and not op.supports_forward_ad: - self.skipTest("Skipped! autograd not supported.") - if op.name == "cat": - self.skipTest("TODO(whc) fix pre-existing bug with cat for newly added opinfo for empty+nonempty") - +class TestBwdGradients(TestGradients): # Tests that gradients are computed correctly @_gradcheck_ops(op_db) def test_fn_grad(self, device, dtype, op): @@ -191,20 +59,6 @@ def test_fn_gradgrad(self, device, dtype, op): else: self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad') - # Test that forward-over-reverse gradgrad is computed correctly - @_gradcheck_ops(op_db) - def test_fn_fwgrad_bwgrad(self, device, dtype, op): - self._skip_helper(op, device, dtype) - - if op.supports_fwgrad_bwgrad: - self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad") - else: - err_msg = r"Trying to use forward AD with .* that does not support it" - hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not " - "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.") - with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg): - self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad") - # Test that gradients of gradients are properly raising @_gradcheck_ops(op_db) def test_fn_fail_gradgrad(self, device, dtype, op): @@ -230,39 +84,8 @@ def test_inplace_gradgrad(self, device, dtype, op): self.skipTest("Skipped! Operation does not support inplace autograd.") self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad") - def _forward_grad_helper(self, device, dtype, op, variant, is_inplace): - # TODO: clean up how attributes are passed to gradcheck from OpInfos - def call_grad_test_helper(): - check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or - (op.check_inplace_batched_forward_grad and is_inplace)) - self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False, - check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad) - if op.supports_forward_ad: - call_grad_test_helper() - else: - err_msg = r"Trying to use forward AD with .* that does not support it" - hint_msg = ("Running forward AD for an OP that has does not support it did not " - "raise any error. If your op supports forward AD, you should set supports_forward_ad=True") - with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg): - call_grad_test_helper() - - @_gradcheck_ops(op_db) - def test_forward_mode_AD(self, device, dtype, op): - self._skip_helper(op, device, dtype) - - self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False) - - @_gradcheck_ops(op_db) - def test_inplace_forward_mode_AD(self, device, dtype, op): - self._skip_helper(op, device, dtype) - - if not op.inplace_variant or not op.supports_inplace_autograd: - self.skipTest("Skipped! Operation does not support inplace autograd.") - - self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True) - -instantiate_device_type_tests(TestGradients, globals()) +instantiate_device_type_tests(TestBwdGradients, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_ops_jit.py b/test/test_ops_jit.py index 57d1120978e4b..e03e051ff012e 100644 --- a/test/test_ops_jit.py +++ b/test/test_ops_jit.py @@ -7,7 +7,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_utils import \ - (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, skipIfSlowGradcheckEnv) + (run_tests, IS_SANDCASTLE, clone_input_helper, first_sample) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference @@ -30,7 +30,6 @@ # autodifferentiation behavior. # Inherits from JitCommonTestCase instead of TestCase directly to share # functionality with original test_jit.py method operator tests -@skipIfSlowGradcheckEnv class TestJit(JitCommonTestCase): exact_dtype = True diff --git a/test/test_optim.py b/test/test_optim.py index 104bdb046d345..36de7b18eab34 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -11,15 +11,42 @@ import torch.optim as optim import torch.nn.functional as F from torch.nn import Parameter -from torch.optim import SGD +from torch.optim import Adam, SGD, Optimizer from torch import sparse -from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \ - MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ - _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler, PolynomialLR, \ - EPOCH_DEPRECATION_WARNING +from torch.optim.lr_scheduler import ( + LambdaLR, + MultiplicativeLR, + SequentialLR, + StepLR, + MultiStepLR, + ConstantLR, + LinearLR, + ExponentialLR, + CosineAnnealingLR, + ReduceLROnPlateau, + LRScheduler, + CyclicLR, + CosineAnnealingWarmRestarts, + OneCycleLR, + ChainedScheduler, + PolynomialLR, + EPOCH_DEPRECATION_WARNING, +) from torch.optim.swa_utils import AveragedModel, SWALR, update_bn -from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ - parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + TEST_WITH_UBSAN, + load_tests, + parametrize, + instantiate_parametrized_tests, + gradcheck, + skipIfRocm, + skipIfTorchDynamo +) +from typing import Dict, Any, Tuple +from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -27,19 +54,24 @@ def rosenbrock(tensor): x, y = tensor - return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2 + return (1 - x) ** 2 + 100 * (y - x**2) ** 2 def drosenbrock(tensor): x, y = tensor - return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2))) + return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))) class TestOptim(TestCase): exact_dtype = True - def _test_rosenbrock_sparse(self, constructor, scheduler_constructors=None, - sparse_only=False, maximize=False): + def _test_rosenbrock_sparse( + self, + constructor, + scheduler_constructors=None, + sparse_only=False, + maximize=False, + ): if scheduler_constructors is None: scheduler_constructors = [] params_t = torch.tensor([1.5, 1.5]) @@ -69,11 +101,11 @@ def eval(params, sparse_grad, w): if w: i = torch.LongTensor([[0, 0]]) x = grad[0] - v = torch.tensor([x / 4., x - x / 4.]) + v = torch.tensor([x / 4.0, x - x / 4.0]) else: i = torch.LongTensor([[1, 1]]) y = grad[1] - v = torch.tensor([y - y / 4., y / 4.]) + v = torch.tensor([y - y / 4.0, y / 4.0]) x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype) with torch.no_grad(): if sparse_grad: @@ -100,8 +132,16 @@ def eval(params, sparse_grad, w): else: self.assertGreaterEqual(rosenbrock(params), rosenbrock(params_t)) - def _test_basic_cases_template(self, weight_tensor, bias_tensor, input_tensor, constructor, - scheduler_constructors, constructor_accepts_maximize=True, constructor_accepts_foreach=False): + def _test_basic_cases_template( + self, + weight_tensor, + bias_tensor, + input_tensor, + constructor, + scheduler_constructors, + constructor_accepts_maximize=True, + constructor_accepts_foreach=False, + ): maximize_options = set([False, constructor_accepts_maximize]) foreach_options = set([False, constructor_accepts_foreach]) @@ -109,14 +149,19 @@ def _test_basic_cases_template(self, weight_tensor, bias_tensor, input_tensor, c if constructor_accepts_maximize and constructor_accepts_foreach: pass elif constructor_accepts_maximize: + def four_arg_constructor(weight, bias, maximize, foreach): self.assertFalse(foreach) return constructor(weight, bias, maximize) + elif constructor_accepts_foreach: + def four_arg_constructor(weight, bias, maximize, foreach): self.assertFalse(maximize) return constructor(weight, bias, foreach) + else: + def four_arg_constructor(weight, bias, maximize, foreach): self.assertFalse(maximize or foreach) return constructor(weight, bias) @@ -198,31 +243,35 @@ def fn_base(optimizer, weight, bias): self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict()) # Make sure repeated parameters have identical representation in state dict optimizer_c.param_groups.extend(optimizer_c.param_groups) - self.assertEqual(optimizer.state_dict()['param_groups'][-1], - optimizer_c.state_dict()['param_groups'][-1]) + self.assertEqual( + optimizer.state_dict()["param_groups"][-1], + optimizer_c.state_dict()["param_groups"][-1], + ) # Make sure that optimizers that support maximize can load older models state_dict = optimizer.state_dict() - if 'maximize' in state_dict['param_groups'][0]: - for group in state_dict['param_groups']: - del group['maximize'] + if "maximize" in state_dict["param_groups"][0]: + for group in state_dict["param_groups"]: + del group["maximize"] optimizer.load_state_dict(state_dict) # Make sure we can still step optimizer.step() # Make sure that optimizers that support foreach can load older models state_dict = optimizer.state_dict() - if 'foreach' in state_dict['param_groups'][0]: - for group in state_dict['param_groups']: - del group['foreach'] + if "foreach" in state_dict["param_groups"][0]: + for group in state_dict["param_groups"]: + del group["foreach"] optimizer.load_state_dict(state_dict) # Make sure we can still step optimizer.step() # Make sure that loading optimizers with step not wrapped in tensor can work state_dict = optimizer.state_dict() - if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']): - for state in state_dict['state'].values(): - state['step'] = state['step'].item() + if "step" in state_dict["state"][0] and torch.is_tensor( + state_dict["state"][0]["step"] + ): + for state in state_dict["state"].values(): + state["step"] = state["step"].item() optimizer.load_state_dict(state_dict) optimizer.step() @@ -233,8 +282,12 @@ def fn_base(optimizer, weight, bias): with torch.no_grad(): input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda") - weight_cuda = Parameter(weight.clone().detach().to(dtype=torch.float32, device="cuda")) - bias_cuda = Parameter(bias.clone().detach().to(dtype=torch.float32, device="cuda")) + weight_cuda = Parameter( + weight.clone().detach().to(dtype=torch.float32, device="cuda") + ) + bias_cuda = Parameter( + bias.clone().detach().to(dtype=torch.float32, device="cuda") + ) optimizer_cuda = constructor(weight_cuda, bias_cuda) fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda) @@ -247,9 +300,11 @@ def fn_base(optimizer, weight, bias): # Make sure that device of state['step'] is still CPU new_state_dict = optimizer_cuda.state_dict() - if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']): - for state in new_state_dict['state'].values(): - self.assertEqual(state['step'].device.type, 'cpu') + if "step" in state_dict["state"][0] and torch.is_tensor( + state_dict["state"][0]["step"] + ): + for state in new_state_dict["state"].values(): + self.assertEqual(state["step"].device.type, "cpu") for _i in range(20): optimizer.step(fn) @@ -259,16 +314,26 @@ def fn_base(optimizer, weight, bias): # validate deepcopy() copies all public attributes def getPublicAttr(obj): - return set(k for k in obj.__dict__ if not k.startswith('_')) + return set(k for k in obj.__dict__ if not k.startswith("_")) + self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer))) - def _test_basic_cases(self, constructor, scheduler_constructors=None, - ignore_multidevice=False, constructor_accepts_maximize=False, constructor_accepts_foreach=False, - atol=None, rtol=None): + def _test_basic_cases( + self, + constructor, + scheduler_constructors=None, + ignore_multidevice=False, + constructor_accepts_maximize=False, + constructor_accepts_foreach=False, + atol=None, + rtol=None, + ): if scheduler_constructors is None: scheduler_constructors = [] - def make_two_arg_constructor(constructor, maximize: bool = False, foreach: bool = False): + def make_two_arg_constructor( + constructor, maximize: bool = False, foreach: bool = False + ): if constructor_accepts_maximize and constructor_accepts_foreach: return lambda weight, bias: constructor(weight, bias, maximize, foreach) if constructor_accepts_maximize: @@ -286,7 +351,8 @@ def make_two_arg_constructor(constructor, maximize: bool = False, foreach: bool torch.randn(10), torch.randn(5), make_two_arg_constructor(constructor, maximize, foreach), - atol=atol, rtol=rtol + atol=atol, + rtol=rtol, ) self._test_basic_cases_template( torch.randn(10, 5), @@ -373,90 +439,162 @@ def _test_complex_2d(self, optimizer_constructor, f=None): self.assertEqual(a1.imag, a1_imag) def _build_params_dict(self, weight, bias, **kwargs): - return [{'params': [weight]}, dict(params=[bias], **kwargs)] + return [{"params": [weight]}, dict(params=[bias], **kwargs)] def _build_params_dict_single(self, weight, bias, **kwargs): return [dict(params=bias, **kwargs)] def test_sgd(self): self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.SGD( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.SGD( self._build_params_dict_single(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.SGD( - self._build_params_dict_single(weight, bias, lr=1e-2), maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + self._build_params_dict_single(weight, bias, lr=1e-2), + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + [ + lambda opt: LinearLR( + opt, start_factor=0.4, end_factor=0.8, total_iters=4 + ) + ], + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: LinearLR( + opt, start_factor=0.4, end_factor=0.6, total_iters=4 + ), + ], + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.99, step_size=10), + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), + [ + lambda opt: StepLR(opt, gamma=0.99, step_size=10), lambda opt: ExponentialLR(opt, gamma=0.99), - lambda opt: ReduceLROnPlateau(opt)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda opt: ReduceLROnPlateau(opt), + ], + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: - optim.SGD([weight, bias], lr=1e-3, momentum=0.5, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], + lr=1e-3, + momentum=0.5, + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: - optim.SGD([weight, bias], lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], + lr=1e-3, + momentum=0.5, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: - optim.SGD([weight, bias], nesterov=True, lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach), - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], + nesterov=True, + lr=1e-3, + momentum=0.5, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.SGD( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], - constructor_accepts_maximize=True, constructor_accepts_foreach=True, + constructor_accepts_maximize=True, + constructor_accepts_foreach=True, ) with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): optim.SGD(None, lr=1e-2, momentum=-0.5) @@ -468,7 +606,7 @@ def test_sgd_sparse(self): ) self._test_rosenbrock_sparse( lambda params: optim.SGD(params, lr=0.0048, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)] + [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], ) def test_sgd_complex(self): @@ -480,13 +618,29 @@ def test_sgd_complex(self): lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach) ) self._test_complex_optimizer( - lambda param: optim.SGD([param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach) + lambda param: optim.SGD( + [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach + ) ) self._test_complex_optimizer( - lambda param: optim.SGD([param], lr=0.001, nesterov=True, momentum=1, weight_decay=1, foreach=foreach) + lambda param: optim.SGD( + [param], + lr=0.001, + nesterov=True, + momentum=1, + weight_decay=1, + foreach=foreach, + ) ) self._test_complex_optimizer( - lambda param: optim.SGD([param], lr=0.001, momentum=1, dampening=0.5, weight_decay=1, foreach=foreach) + lambda param: optim.SGD( + [param], + lr=0.001, + momentum=1, + dampening=0.5, + weight_decay=1, + foreach=foreach, + ) ) def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): @@ -495,21 +649,27 @@ def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): assert flag in ("foreach", "fused") kIterations = 4 - device = 'cuda' + device = "cuda" for optimizer_constructor, params in optimizer_pairs_with_flags: res, state = [], [] for foreach in (False, True): - input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device).reshape(3, 2) + input = torch.tensor( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device + ).reshape(3, 2) torch.manual_seed(1) - model = torch.nn.Sequential(torch.nn.Linear(2, 3), - torch.nn.Sigmoid(), - torch.nn.Linear(3, 1), - torch.nn.Sigmoid()) + model = torch.nn.Sequential( + torch.nn.Linear(2, 3), + torch.nn.Sigmoid(), + torch.nn.Linear(3, 1), + torch.nn.Sigmoid(), + ) model.to(dtype=torch.float64, device=device) params_with_foreach = deepcopy(params) params_with_foreach["foreach"] = foreach - optimizer = optimizer_constructor(model.parameters(), **params_with_foreach) + optimizer = optimizer_constructor( + model.parameters(), **params_with_foreach + ) for _ in range(kIterations): optimizer.zero_grad() @@ -539,26 +699,36 @@ def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): actual = mt_p_state[k] # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`, # `step` Tensor is 1D while usually it's 0D. - if k == "step" and isinstance(actual, torch.Tensor) and actual.ndim == 1: + if ( + k == "step" + and isinstance(actual, torch.Tensor) + and actual.ndim == 1 + ): actual = actual[0] self.assertEqual(st_p_state[k], actual, atol=5e-5, rtol=0) def test_multi_tensor_optimizers(self): optimizer_pairs_with_flags = [ - (optim.Adam, dict(weight_decay=1., amsgrad=True)), - (optim.Adam, dict(weight_decay=1., amsgrad=False)), - (optim.Adam, dict(weight_decay=0., amsgrad=True)), - (optim.Adam, dict(weight_decay=0., amsgrad=False)), - (optim.AdamW, dict(weight_decay=1., amsgrad=True)), - (optim.AdamW, dict(weight_decay=1., amsgrad=False)), - (optim.AdamW, dict(weight_decay=0., amsgrad=True)), - (optim.AdamW, dict(weight_decay=0., amsgrad=False)), - (optim.NAdam, dict(weight_decay=0., momentum_decay=6e-3)), - (optim.NAdam, dict(weight_decay=1., momentum_decay=6e-3)), - (optim.NAdam, dict(weight_decay=0., momentum_decay=4e-3)), + (optim.Adam, dict(weight_decay=1.0, amsgrad=True)), + (optim.Adam, dict(weight_decay=1.0, amsgrad=False)), + (optim.Adam, dict(weight_decay=0.0, amsgrad=True)), + (optim.Adam, dict(weight_decay=0.0, amsgrad=False)), + (optim.AdamW, dict(weight_decay=1.0, amsgrad=True)), + (optim.AdamW, dict(weight_decay=1.0, amsgrad=False)), + (optim.AdamW, dict(weight_decay=0.0, amsgrad=True)), + (optim.AdamW, dict(weight_decay=0.0, amsgrad=False)), + (optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)), + (optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)), + (optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)), (optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)), - (optim.SGD, dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)), - (optim.SGD, dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)), + ( + optim.SGD, + dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True), + ), + ( + optim.SGD, + dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False), + ), (optim.RAdam, dict(weight_decay=0)), (optim.RAdam, dict(weight_decay=1)), (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)), @@ -579,48 +749,71 @@ def test_multi_tensor_optimizers(self): def test_fused_optimizers(self): optimizer_pairs_with_flags = [ - (optim.Adam, dict(weight_decay=1., amsgrad=False)), - (optim.Adam, dict(weight_decay=1., amsgrad=True)), - (optim.Adam, dict(weight_decay=0., amsgrad=False)), - (optim.Adam, dict(weight_decay=0., amsgrad=True)), + (optim.Adam, dict(weight_decay=1.0, amsgrad=False)), + (optim.Adam, dict(weight_decay=1.0, amsgrad=True)), + (optim.Adam, dict(weight_decay=0.0, amsgrad=False)), + (optim.Adam, dict(weight_decay=0.0, amsgrad=True)), ] self._test_derived_optimizers(optimizer_pairs_with_flags, "fused") def test_adam(self): self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.Adam([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.Adam( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( - self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize, foreach=foreach), + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( - [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), + [weight, bias], + lr=1e-3, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( - [weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize, foreach=foreach), + [weight, bias], + lr=1e-3, + weight_decay=0.1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), + lr=1e-3, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), [lambda opt: ExponentialLR(opt, gamma=0.9)], constructor_accepts_maximize=True, constructor_accepts_foreach=True, @@ -628,7 +821,10 @@ def test_adam(self): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)], constructor_accepts_maximize=True, constructor_accepts_foreach=True, @@ -636,33 +832,56 @@ def test_adam(self): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( - [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), - [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), - lambda opt: ExponentialLR(opt, gamma=0.9)], + [weight, bias], + lr=1e-3, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), + [ + lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), + lambda opt: ExponentialLR(opt, gamma=0.9), + ], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( - [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), - [lambda opt: ExponentialLR(opt, gamma=0.9), - lambda opt: ReduceLROnPlateau(opt)], + [weight, bias], + lr=1e-3, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), + [ + lambda opt: ExponentialLR(opt, gamma=0.9), + lambda opt: ReduceLROnPlateau(opt), + ], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)], + lr=1e-3, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) @@ -670,7 +889,10 @@ def test_adam(self): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adam( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, maximize=maximize, foreach=foreach), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)], constructor_accepts_maximize=True, constructor_accepts_foreach=True, @@ -678,7 +900,9 @@ def test_adam(self): self._test_complex_2d(optim.Adam) self._test_complex_2d(functools.partial(optim.Adam, foreach=True)) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + with self.assertRaisesRegex( + ValueError, "Invalid beta parameter at index 0: 1.0" + ): optim.Adam(None, lr=1e-2, betas=(1.0, 0.0)) with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): @@ -686,25 +910,42 @@ def test_adam(self): def test_adamw(self): self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.AdamW([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.AdamW( + [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.AdamW( - self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize, foreach=foreach), + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.AdamW( - [weight, bias], lr=1e-3, weight_decay=1, maximize=maximize, foreach=foreach), + [weight, bias], + lr=1e-3, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.AdamW( - [weight, bias], lr=1e-3, weight_decay=1, amsgrad=True, maximize=maximize, foreach=foreach), + [weight, bias], + lr=1e-3, + weight_decay=1, + amsgrad=True, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) @@ -715,21 +956,25 @@ def test_adamw(self): def test_sparse_adam(self): self._test_rosenbrock_sparse( - lambda params: optim.SparseAdam(params, lr=4e-2), - [], - True + lambda params: optim.SparseAdam(params, lr=4e-2), [], True ) self._test_rosenbrock_sparse( lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True), [], True, - True + True, ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + with self.assertRaisesRegex( + ValueError, "Invalid beta parameter at index 0: 1.0" + ): optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0)) - with self.assertRaisesRegex(ValueError, "SparseAdam requires dense parameter tensors"): + with self.assertRaisesRegex( + ValueError, "SparseAdam requires dense parameter tensors" + ): optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)]) - with self.assertRaisesRegex(ValueError, "SparseAdam requires dense parameter tensors"): + with self.assertRaisesRegex( + ValueError, "SparseAdam requires dense parameter tensors" + ): optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}]) # ROCm precision is too low to pass this test @@ -737,27 +982,38 @@ def test_adadelta(self): # Handles https://github.com/pytorch/pytorch/issues/69698 self.rel_tol = 4e-3 self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.Adadelta([weight, bias], maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.Adadelta( + [weight, bias], maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adadelta( - self._build_params_dict(weight, bias, rho=0.95), maximize=maximize, foreach=foreach), + self._build_params_dict(weight, bias, rho=0.95), + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adadelta( - self._build_params_dict(weight, bias, rho=0.95), maximize=maximize, foreach=foreach), - [lambda opt: StepLR(opt, gamma=0.9, step_size=10), - lambda opt: ReduceLROnPlateau(opt)], + self._build_params_dict(weight, bias, rho=0.95), + maximize=maximize, + foreach=foreach, + ), + [ + lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: ReduceLROnPlateau(opt), + ], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adadelta( - [weight, bias], weight_decay=1, maximize=maximize, foreach=foreach), + [weight, bias], weight_decay=1, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) @@ -768,52 +1024,68 @@ def test_adadelta_complex(self): # Handles https://github.com/pytorch/pytorch/issues/69698 self.rel_tol = 2e-2 for optimizer in [optim.Adadelta]: - self._test_complex_optimizer( - lambda weight: optimizer([weight]) - ) - self._test_complex_optimizer( - lambda weight: optimizer([weight], rho=0.95) - ) + self._test_complex_optimizer(lambda weight: optimizer([weight])) + self._test_complex_optimizer(lambda weight: optimizer([weight], rho=0.95)) self._test_complex_optimizer( lambda weight: optimizer([weight], rho=0.95, weight_decay=1) ) def test_nadam(self): self._test_basic_cases( - lambda weight, bias, foreach: optim.NAdam([weight, bias], lr=1e-3, foreach=foreach), + lambda weight, bias, foreach: optim.NAdam( + [weight, bias], lr=1e-3, foreach=foreach + ), constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, foreach: optim.NAdam( - self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, foreach=foreach), + self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach + ), constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, foreach: optim.NAdam( - [weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3, foreach=foreach), + [weight, bias], + lr=1e-3, + weight_decay=0.1, + momentum_decay=6e-3, + foreach=foreach, + ), constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, foreach: optim.NAdam( - [weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3, foreach=foreach), + [weight, bias], + lr=1e-3, + weight_decay=0.1, + momentum_decay=6e-3, + foreach=foreach, + ), [lambda opt: ExponentialLR(opt, gamma=0.9)], constructor_accepts_foreach=True, ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + with self.assertRaisesRegex( + ValueError, "Invalid beta parameter at index 0: 1.0" + ): optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0)) with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"): optim.NAdam(None, lr=1e-2, momentum_decay=-0.2) def test_adagrad(self): self._test_basic_cases( - lambda weight, bias, maximize, foreach: optim.Adagrad([weight, bias], lr=1e-1, maximize=maximize, foreach=foreach), + lambda weight, bias, maximize, foreach: optim.Adagrad( + [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adagrad( - [weight, bias], lr=1e-1, initial_accumulator_value=0.1, maximize=maximize, foreach=foreach, + [weight, bias], + lr=1e-1, + initial_accumulator_value=0.1, + maximize=maximize, + foreach=foreach, ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, @@ -822,7 +1094,9 @@ def test_adagrad(self): lambda weight, bias, maximize, foreach: optim.Adagrad( self._build_params_dict(weight, bias, lr=1e-2), lr=1e-1, - maximize=maximize, foreach=foreach), + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) @@ -830,7 +1104,9 @@ def test_adagrad(self): lambda weight, bias, maximize, foreach: optim.Adagrad( self._build_params_dict(weight, bias, lr=1e-2), lr=1e-1, - maximize=maximize, foreach=foreach), + maximize=maximize, + foreach=foreach, + ), [lambda opt: ReduceLROnPlateau(opt)], constructor_accepts_maximize=True, constructor_accepts_foreach=True, @@ -839,9 +1115,13 @@ def test_adagrad(self): lambda weight, bias, maximize, foreach: optim.Adagrad( self._build_params_dict(weight, bias, lr=1e-2), lr=1e-1, - maximize=maximize, foreach=foreach), - [lambda opt: ReduceLROnPlateau(opt), - lambda opt: ExponentialLR(opt, gamma=0.99)], + maximize=maximize, + foreach=foreach, + ), + [ + lambda opt: ReduceLROnPlateau(opt), + lambda opt: ExponentialLR(opt, gamma=0.99), + ], constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) @@ -855,8 +1135,10 @@ def test_adagrad_sparse(self): ) self._test_rosenbrock_sparse( lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach), - [lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), - lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)] + [ + lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), + lambda opt: ReduceLROnPlateau(opt, threshold=1e-4), + ], ) def test_adagrad_complex(self): @@ -866,55 +1148,81 @@ def test_adagrad_complex(self): ) self._test_complex_optimizer( lambda param: optim.Adagrad( - [param], lr=1e-1, initial_accumulator_value=0.1, foreach=foreach, + [param], + lr=1e-1, + initial_accumulator_value=0.1, + foreach=foreach, ) ) def test_adamax(self): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adamax( - [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach), + [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adamax( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-1, maximize=maximize, foreach=foreach), + lr=1e-1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Adamax( - [weight, bias], lr=1e-1, weight_decay=1, maximize=maximize, foreach=foreach), + [weight, bias], + lr=1e-1, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_complex_2d(optim.Adamax) self._test_complex_2d(functools.partial(optim.Adamax, foreach=True)) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"): + with self.assertRaisesRegex( + ValueError, "Invalid beta parameter at index 1: 1.0" + ): optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0)) def test_radam(self): self._test_basic_cases( - lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, foreach=foreach), + lambda weight, bias, foreach: optim.RAdam( + [weight, bias], lr=1e-3, foreach=foreach + ), constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, foreach: optim.RAdam( - self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach), + self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach + ), constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach), + lambda weight, bias, foreach: optim.RAdam( + [weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach + ), constructor_accepts_foreach=True, ) self._test_basic_cases( - lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, foreach=foreach), - [lambda opt: ExponentialLR(opt, gamma=0.9), lambda opt: ReduceLROnPlateau(opt)], + lambda weight, bias, foreach: optim.RAdam( + [weight, bias], lr=1e-3, foreach=foreach + ), + [ + lambda opt: ExponentialLR(opt, gamma=0.9), + lambda opt: ReduceLROnPlateau(opt), + ], constructor_accepts_foreach=True, ) - with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): + with self.assertRaisesRegex( + ValueError, "Invalid beta parameter at index 0: 1.0" + ): optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0)) with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): @@ -924,53 +1232,89 @@ def test_rmsprop(self): for foreach in (False, True): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( - [weight, bias], lr=1e-2, maximize=maximize, foreach=foreach), + [weight, bias], lr=1e-2, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, maximize=maximize, foreach=foreach), + lr=1e-2, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, centered=True, maximize=maximize, foreach=foreach), + lr=1e-2, + centered=True, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, centered=True, momentum=0.1, maximize=maximize, foreach=foreach), + lr=1e-2, + centered=True, + momentum=0.1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, momentum=0.1, maximize=maximize, foreach=foreach), + lr=1e-2, + momentum=0.1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.RMSprop( self._build_params_dict(weight, bias, lr=1e-3), - lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize, foreach=foreach), + lr=1e-2, + momentum=0.1, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_complex_2d(lambda param: optim.RMSprop(param, foreach=foreach)) - self._test_complex_2d(lambda param: optim.RMSprop(param, centered=True, foreach=foreach)) - self._test_complex_2d(lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach)) - self._test_complex_2d(lambda param: optim.RMSprop(param, maximize=True, foreach=foreach)) - self._test_complex_optimizer(lambda param: optim.RMSprop([param], foreach=foreach)) - self._test_complex_optimizer(lambda param: optim.RMSprop([param], centered=True, foreach=foreach)) - self._test_complex_optimizer(lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach)) - self._test_complex_optimizer(lambda param: optim.RMSprop([param], maximize=True, foreach=foreach)) + self._test_complex_2d( + lambda param: optim.RMSprop(param, centered=True, foreach=foreach) + ) + self._test_complex_2d( + lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach) + ) + self._test_complex_2d( + lambda param: optim.RMSprop(param, maximize=True, foreach=foreach) + ) + self._test_complex_optimizer( + lambda param: optim.RMSprop([param], foreach=foreach) + ) + self._test_complex_optimizer( + lambda param: optim.RMSprop([param], centered=True, foreach=foreach) + ) + self._test_complex_optimizer( + lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach) + ) + self._test_complex_optimizer( + lambda param: optim.RMSprop([param], maximize=True, foreach=foreach) + ) with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach) @@ -978,69 +1322,103 @@ def test_asgd(self): for foreach in (False, True): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.ASGD( - [weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach), + [weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.ASGD( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, t0=100, maximize=maximize, foreach=foreach), + lr=1e-3, + t0=100, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.ASGD( self._build_params_dict(weight, bias, lr=1e-2), - lr=1e-3, weight_decay=1, maximize=maximize, foreach=foreach), + lr=1e-3, + weight_decay=1, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) # Ref: https://github.com/pytorch/pytorch/issues/84560 # self._test_complex_2d(optimizer) - self._test_complex_optimizer(lambda params: optim.ASGD([params], foreach=foreach)) - self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=True, foreach=foreach)) - self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=True, weight_decay=0.9, foreach=foreach)) - self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=False, weight_decay=0.9, foreach=foreach)) - self._test_complex_optimizer(lambda params: optim.ASGD([params], weight_decay=0.9, foreach=foreach)) + self._test_complex_optimizer( + lambda params: optim.ASGD([params], foreach=foreach) + ) + self._test_complex_optimizer( + lambda params: optim.ASGD([params], maximize=True, foreach=foreach) + ) + self._test_complex_optimizer( + lambda params: optim.ASGD( + [params], maximize=True, weight_decay=0.9, foreach=foreach + ) + ) + self._test_complex_optimizer( + lambda params: optim.ASGD( + [params], maximize=False, weight_decay=0.9, foreach=foreach + ) + ) + self._test_complex_optimizer( + lambda params: optim.ASGD([params], weight_decay=0.9, foreach=foreach) + ) with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach) @skipIfRocm def test_rprop(self): - is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6) + is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability( + 0 + ) == (8, 6) for foreach in (False, True): self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Rprop( - [weight, bias], lr=2e-4, maximize=maximize, foreach=foreach), + [weight, bias], lr=2e-4, maximize=maximize, foreach=foreach + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, ) self._test_basic_cases( lambda weight, bias, maximize, foreach: optim.Rprop( - self._build_params_dict(weight, bias, lr=1e-2), lr=2e-4, maximize=maximize, foreach=foreach), + self._build_params_dict(weight, bias, lr=1e-2), + lr=2e-4, + maximize=maximize, + foreach=foreach, + ), constructor_accepts_maximize=True, constructor_accepts_foreach=True, - atol=4e-5 if is_cuda_sm86 else None, rtol=3e-5 if is_cuda_sm86 else None + atol=4e-5 if is_cuda_sm86 else None, + rtol=3e-5 if is_cuda_sm86 else None, ) self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach)) self._test_complex_optimizer( lambda param: optim.Rprop([param], lr=0.001, foreach=foreach) ) self._test_complex_optimizer( - lambda param: optim.Rprop([param], lr=0.001, maximize=True, foreach=foreach) + lambda param: optim.Rprop( + [param], lr=0.001, maximize=True, foreach=foreach + ) ) with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach) def test_lbfgs(self): self._test_basic_cases( - lambda weight, bias: optim.LBFGS([weight, bias]), - ignore_multidevice=True + lambda weight, bias: optim.LBFGS([weight, bias]), ignore_multidevice=True ) self._test_basic_cases( - lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_wolfe"), - ignore_multidevice=True + lambda weight, bias: optim.LBFGS( + [weight, bias], line_search_fn="strong_wolfe" + ), + ignore_multidevice=True, ) @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") @@ -1066,7 +1444,9 @@ def test_duplicate_params_in_param_group(self): warnings.simplefilter("always") optim.SGD([param, param], lr=0.1) self.assertEqual(len(w), 1) - self.assertIn('a parameter group with duplicate parameters', str(w[0].message)) + self.assertIn( + "a parameter group with duplicate parameters", str(w[0].message) + ) def test_no_grad_for_all_params(self): params = [torch.randn(5, 5, requires_grad=False) for _ in range(2)] @@ -1097,13 +1477,25 @@ def test_functional_fused_adam_with_foundinf(self): num_tensors = 5 for amsgrad in (False, True): - params, grads, exp_avgs, exp_avg_sqs = [[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)] - max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else [] - state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)] + params, grads, exp_avgs, exp_avg_sqs = [ + [torch.ones((1,), device="cuda") for _ in range(num_tensors)] + for _ in range(4) + ] + max_exp_avg_sqs = ( + [torch.ones((1,), device="cuda") for _ in range(num_tensors)] + if amsgrad + else [] + ) + state_steps = [ + torch.ones((1,), dtype=torch.float32, device="cuda") + for _ in range(num_tensors) + ] grad_scale = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( - torch.ones((1,), dtype=torch.float32, device="cuda")) + torch.ones((1,), dtype=torch.float32, device="cuda") + ) found_inf = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( - torch.ones((1,), dtype=torch.float32, device="cuda")) + torch.ones((1,), dtype=torch.float32, device="cuda") + ) adam.adam( params, @@ -1119,7 +1511,7 @@ def test_functional_fused_adam_with_foundinf(self): beta1=0.9, beta2=0.99, lr=1e-2, - weight_decay=.0, + weight_decay=0.0, eps=1e-8, maximize=False, grad_scale=grad_scale, @@ -1128,16 +1520,32 @@ def test_functional_fused_adam_with_foundinf(self): self.assertEqual( state_steps, - [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)], + [ + torch.ones((1,), dtype=torch.float32, device="cuda") + for _ in range(num_tensors) + ], ) def test_empty_grad(self): - optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam, torch.optim.AdamW, - torch.optim.Adamax, torch.optim.ASGD, torch.optim.NAdam, torch.optim.RAdam, - torch.optim.RMSprop, torch.optim.Rprop, torch.optim.SGD, torch.optim.SparseAdam] + optimizers = [ + torch.optim.Adadelta, + torch.optim.Adagrad, + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.Adamax, + torch.optim.ASGD, + torch.optim.NAdam, + torch.optim.RAdam, + torch.optim.RMSprop, + torch.optim.Rprop, + torch.optim.SGD, + torch.optim.SparseAdam, + ] for optimizer in optimizers: - net = torch.nn.Embedding(5, 1, padding_idx=0, sparse=optimizer is torch.optim.SparseAdam) + net = torch.nn.Embedding( + 5, 1, padding_idx=0, sparse=optimizer is torch.optim.SparseAdam + ) original_params = (param.detach().clone() for param in net.parameters()) # Simulate a batch that only indexes the embedding at padding_idx x = torch.tensor([[0, 0]]).int() @@ -1151,6 +1559,101 @@ def test_empty_grad(self): # assert that the parameters have not changed self.assertEqual(original_param, param) + @skipIfTorchDynamo() + def test_post_hook(self): + def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data += 2 + + params = [torch.Tensor([1, 1])] + opt = SGD(params, lr=0.001) + data = 2 + hook_handle = opt.register_step_post_hook(post_hook) + + opt.step() + opt.step() + # check if pre hooks were registered + self.assertEqual(data, 6) + + # remove handles, take step and verify that hook is no longer registered + hook_handle.remove() + + opt.step() + self.assertEqual(data, 6) + + @skipIfTorchDynamo() + def test_pre_hook(self): + def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data += 2 + + params = [torch.Tensor([1, 1])] + opt = SGD(params, lr=0.001) + data = 5 + hook_handle = opt.register_step_pre_hook(pre_hook) + + opt.step() + opt.step() + # check if pre hooks were registered + self.assertEqual(data, 9) + + # remove handles, take step and verify that hook is no longer registered + hook_handle.remove() + + opt.step() + self.assertEqual(data, 9) + + @skipIfTorchDynamo() + def test_pre_and_post_hook(self): + def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data.append(0) + + def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data.append(5) + + def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data.append(1) + + def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + nonlocal data + data.append(2) + + params = [torch.Tensor([1, 1])] + opt1 = SGD(params, lr=0.001) + opt2 = Adam(params, lr=0.01) + data = [] + + # register global hooks to both optimizers + global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook) + global_post_handle = register_optimizer_step_post_hook(global_post_hook) + + # register local hooks + first_pre_handle = opt1.register_step_pre_hook(local_pre_hook) + first_post_handle = opt1.register_step_post_hook(local_post_hook) + second_pre_handle = opt2.register_step_pre_hook(local_pre_hook) + second_post_handle = opt2.register_step_post_hook(local_post_hook) + + opt1.step() + self.assertListEqual(data, [0, 1, 2, 5]) + opt2.step() + self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5]) + opt1.step() + self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5]) + + # remove all hooks + global_pre_handle.remove() + global_post_handle.remove() + first_pre_handle.remove() + first_post_handle.remove() + second_pre_handle.remove() + second_post_handle.remove() + + opt1.step() + opt2.step() + self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5]) class SchedulerTestNet(torch.nn.Module): @@ -1184,8 +1687,12 @@ def setUp(self): super(TestLRScheduler, self).setUp() self.net = SchedulerTestNet() self.opt = SGD( - [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], - lr=0.05) + [ + {"params": self.net.conv1.parameters()}, + {"params": self.net.conv2.parameters(), "lr": 0.5}, + ], + lr=0.05, + ) def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1): """This function swallows the epoch deprecation warning which is produced when we @@ -1199,25 +1706,32 @@ def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING) def test_error_when_getlr_has_epoch(self): - class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): + class MultiStepLR(torch.optim.lr_scheduler.LRScheduler): def __init__(self, optimizer, gamma, milestones, last_epoch=-1): - self.init_lr = [group['lr'] for group in optimizer.param_groups] + self.init_lr = [group["lr"] for group in optimizer.param_groups] self.gamma = gamma self.milestones = milestones super().__init__(optimizer, last_epoch) def get_lr(self, step): global_step = self.last_epoch - gamma_power = ([0] + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m])[-1] - return [init_lr * (self.gamma ** gamma_power) for init_lr in self.init_lr] + gamma_power = ( + [0] + + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m] + )[-1] + return [ + init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr + ] optimizer = torch.optim.SGD([torch.rand(1)], lr=1) with self.assertRaises(TypeError): scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20]) + @skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames") def test_no_cyclic_references(self): import gc + param = Parameter(torch.empty(10)) optim = SGD([param], lr=0.5) scheduler = LambdaLR(optim, lambda epoch: 1.0) @@ -1225,23 +1739,29 @@ def test_no_cyclic_references(self): # Prior to Python 3.7, local variables in a function will be referred by the current frame. import sys + if sys.version_info < (3, 7): import inspect + referrers = gc.get_referrers(optim) self.assertTrue( len(referrers) == 1 and referrers[0] is inspect.currentframe(), - "Optimizer should contain no cyclic references (except current frame)") + "Optimizer should contain no cyclic references (except current frame)", + ) del referrers else: self.assertTrue( len(gc.get_referrers(optim)) == 0, - "Optimizer should contain no cyclic references") + "Optimizer should contain no cyclic references", + ) gc.collect() del optim self.assertEqual( - gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__") + gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__" + ) + @skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames") def test_no_cyclic_references_in_step(self): import gc import weakref @@ -1261,6 +1781,7 @@ def run(): # automatically collect unreachable objects. gc.disable() ref = run() + assert ref() is None gc.enable() # restore @@ -1276,7 +1797,7 @@ def old_pattern(): scheduler.step() self.opt.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) def test_old_pattern_warning_with_arg(self): epochs = 35 @@ -1290,12 +1811,12 @@ def old_pattern2(): scheduler.step() self.opt.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) def test_old_pattern_warning_resuming(self): epochs = 35 for i, group in enumerate(self.opt.param_groups): - group['initial_lr'] = 0.01 + group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised @@ -1307,12 +1828,12 @@ def old_pattern(): scheduler.step() self.opt.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) def test_old_pattern_warning_resuming_with_arg(self): epochs = 35 for i, group in enumerate(self.opt.param_groups): - group['initial_lr'] = 0.01 + group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised @@ -1324,12 +1845,12 @@ def old_pattern2(): scheduler.step() self.opt.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) def test_old_pattern_warning_with_overridden_optim_step(self): epochs = 35 for i, group in enumerate(self.opt.param_groups): - group['initial_lr'] = 0.01 + group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised @@ -1352,7 +1873,7 @@ def old_pattern2(): scheduler.step() self.opt.step() - self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) + self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) def test_new_pattern_no_warning(self): epochs = 35 @@ -1405,7 +1926,9 @@ def new_pattern(): self.opt.step() scheduler.step() - self.assertWarnsRegex(UserWarning, r'`optimizer.step\(\)` has been overridden', new_pattern) + self.assertWarnsRegex( + UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern + ) def _test_lr_is_constant_for_constant_epoch(self, scheduler): l = [] @@ -1416,7 +1939,7 @@ def _test_lr_is_constant_for_constant_epoch(self, scheduler): scheduler.step(2) self._check_warning_is_epoch_deprecation_warning(w) - l.append(self.opt.param_groups[0]['lr']) + l.append(self.opt.param_groups[0]["lr"]) self.assertEqual(min(l), max(l)) def test_step_lr_is_constant_for_constant_epoch(self): @@ -1451,8 +1974,11 @@ def test_step_lr(self): def test_get_last_lr_step_lr(self): from torch.nn import Parameter + epochs = 10 - optimizer = torch.optim.SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1) + optimizer = torch.optim.SGD( + [Parameter(torch.randn(2, 2, requires_grad=True))], 0.1 + ) targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) self._test_get_last_lr(scheduler, targets, epochs) @@ -1507,12 +2033,21 @@ def test_get_last_lr_linearlr(self): # lr = 0.005 if 4 <= epoch epochs = 10 start_factor = 1.0 / 4 - end_factor = 3. / 5 + end_factor = 3.0 / 5 iters = 4 - interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] - single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (epochs - iters) + interpolation = [ + start_factor + i * (end_factor - start_factor) / iters for i in range(iters) + ] + single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * ( + epochs - iters + ) targets = [single_targets, [x * epochs for x in single_targets]] - scheduler = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) + scheduler = LinearLR( + self.opt, + start_factor=start_factor, + end_factor=end_factor, + total_iters=iters, + ) self._test_get_last_lr(scheduler, targets, epochs) def test_constantlr(self): @@ -1533,14 +2068,16 @@ def test_linearlr(self): epochs = 10 start_factor = 1.0 / 2 iters = 4 - interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] + interpolation = [ + start_factor + i * (1 - start_factor) / iters for i in range(iters) + ] single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) targets = [single_targets, [x * epochs for x in single_targets]] scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test(scheduler, targets, epochs) def test_linearlr_start_factor_limits1(self): - start_factor = 0. + start_factor = 0.0 iters = 4 with self.assertRaises(ValueError): LinearLR(self.opt, start_factor=start_factor, total_iters=iters) @@ -1568,9 +2105,11 @@ def test_linearlr_with_epoch(self): # lr = 0.005 if 4 <= epoch epochs = 10 start_factor = 1.0 / 2 - end_factor = 1. + end_factor = 1.0 iters = 4 - interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] + interpolation = [ + start_factor + i * (end_factor - start_factor) / iters for i in range(iters) + ] single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) targets = [single_targets, [x * epochs for x in single_targets]] scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) @@ -1578,7 +2117,7 @@ def test_linearlr_with_epoch(self): def test_exp_lr(self): epochs = 10 - single_targets = [0.05 * (0.9 ** x) for x in range(epochs)] + single_targets = [0.05 * (0.9**x) for x in range(epochs)] targets = [single_targets, [x * epochs for x in single_targets]] scheduler = ExponentialLR(self.opt, gamma=0.9) self._test(scheduler, targets, epochs) @@ -1587,7 +2126,9 @@ def test_poly_lr(self): epochs = 10 power = 0.9 total_iters = 5 - single_targets = [(1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)] + [0.0] * (epochs - total_iters) + single_targets = [ + (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters) + ] + [0.0] * (epochs - total_iters) targets = [single_targets, [x * epochs for x in single_targets]] scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters) self._test(scheduler, targets, epochs) @@ -1595,9 +2136,10 @@ def test_poly_lr(self): def test_cos_anneal_lr(self): epochs = 10 eta_min = 1e-10 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] targets = [single_targets, [x * epochs for x in single_targets]] scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) self._test(scheduler, targets, epochs) @@ -1608,8 +2150,12 @@ def test_closed_form_step_lr(self): self._test_against_closed_form(scheduler, closed_form_scheduler, 20) def test_closed_form_linearlr(self): - scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) - closed_form_scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) + scheduler = LinearLR( + self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 + ) + closed_form_scheduler = LinearLR( + self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 + ) self._test_against_closed_form(scheduler, closed_form_scheduler, 20) def test_closed_form_constantlr(self): @@ -1637,7 +2183,9 @@ def test_closed_form_cos_anneal_lr(self): epochs = 20 T_max = 5 scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) - closed_form_scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) + closed_form_scheduler = CosineAnnealingLR( + self.opt, T_max=T_max, eta_min=eta_min + ) self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) def test_cos_anneal_lr_continue(self): @@ -1648,97 +2196,135 @@ def test_cos_anneal_lr_continue(self): scheduler.step() original_lrs = scheduler._last_lr new_scheduler = CosineAnnealingLR( - self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0) + self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0 + ) new_lrs = new_scheduler._last_lr - torch.testing.assert_allclose(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) + torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) def test_reduce_lr_on_plateau1(self): epochs = 10 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 20] metrics = [10 - i * 0.0167 for i in range(20)] - scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', - threshold=0.01, patience=5, cooldown=5) + scheduler = ReduceLROnPlateau( + self.opt, + threshold_mode="abs", + mode="min", + threshold=0.01, + patience=5, + cooldown=5, + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau2(self): epochs = 22 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] metrics = [10 - i * 0.0165 for i in range(22)] - scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', - mode='min', threshold=0.1) + scheduler = ReduceLROnPlateau( + self.opt, + patience=5, + cooldown=0, + threshold_mode="abs", + mode="min", + threshold=0.1, + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau3(self): epochs = 22 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] metrics = [-0.8] * 2 + [-0.234] * 20 - scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, - threshold_mode='abs') + scheduler = ReduceLROnPlateau( + self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau4(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 20] - metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25 - scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3, - threshold_mode='rel', threshold=0.1) + metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 + scheduler = ReduceLROnPlateau( + self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau5(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] - metrics = [1.5 * (1.005 ** i) for i in range(20)] - scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', - threshold=0.1, patience=5, cooldown=5) + metrics = [1.5 * (1.005**i) for i in range(20)] + scheduler = ReduceLROnPlateau( + self.opt, + mode="max", + threshold_mode="rel", + threshold=0.1, + patience=5, + cooldown=5, + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau6(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 20] - metrics = [1.5 * (0.85 ** i) for i in range(20)] - scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', - threshold=0.1) + metrics = [1.5 * (0.85**i) for i in range(20)] + scheduler = ReduceLROnPlateau( + self.opt, mode="min", threshold_mode="rel", threshold=0.1 + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau7(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] metrics = [1] * 7 + [0.6] + [0.5] * 12 - scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', - threshold=0.1, patience=5, cooldown=5) + scheduler = ReduceLROnPlateau( + self.opt, + mode="min", + threshold_mode="rel", + threshold=0.1, + patience=5, + cooldown=5, + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_reduce_lr_on_plateau8(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] - metrics = [1.5 * (1.005 ** i) for i in range(20)] - scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3], - threshold=0.1, patience=5, cooldown=5) + metrics = [1.5 * (1.005**i) for i in range(20)] + scheduler = ReduceLROnPlateau( + self.opt, + mode="max", + threshold_mode="rel", + min_lr=[0.4, 0.3], + threshold=0.1, + patience=5, + cooldown=5, + ) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) def test_sequentiallr1(self): epochs = 19 schedulers = [None] * 2 - targets = [[0.05, 0.04, 0.032] + [0.05 for x in range(4)] - + [0.05 * 0.1 for x in range(4)] - + [0.05 * 0.01 for x in range(4)] - + [0.05 * 0.001 for x in range(4)]] + targets = [ + [0.05, 0.04, 0.032] + + [0.05 for x in range(4)] + + [0.05 * 0.1 for x in range(4)] + + [0.05 * 0.01 for x in range(4)] + + [0.05 * 0.001 for x in range(4)] + ] milestones = [3] schedulers[0] = ExponentialLR(self.opt, gamma=0.8) schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4) @@ -1748,7 +2334,7 @@ def test_sequentiallr1(self): def test_sequentiallr2(self): epochs = 13 schedulers = [None] * 2 - targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9 ** x for x in range(10)]] + targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]] milestones = [3] schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) @@ -1758,8 +2344,11 @@ def test_sequentiallr2(self): def test_sequentiallr3(self): epochs = 12 schedulers = [None] * 3 - targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032] - + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]] + targets = [ + [0.005, 0.005, 0.005] + + [0.05, 0.04, 0.032] + + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] + ] milestones = [3, 6] schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) schedulers[1] = ExponentialLR(self.opt, gamma=0.8) @@ -1773,9 +2362,11 @@ def test_sequentiallr4(self): schedulers = [ torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1), - torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1) + torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1), ] - scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=[10]) + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, schedulers, milestones=[10] + ) new_lr = optimizer.param_groups[0]["lr"] @@ -1800,7 +2391,7 @@ def test_get_last_lr_sequentiallr(self): def test_chained_lr2_get_last_lr_before_step(self): schedulers = [ LinearLR(self.opt, start_factor=0.4, total_iters=3), - MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) + MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1), ] scheduler = ChainedScheduler(schedulers) self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) @@ -1826,7 +2417,9 @@ def test_chained_lr2(self): def test_chained_lr3(self): epochs = 10 schedulers = [None] * 2 - targets = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3] + targets = [ + [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3 + ] schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) scheduler = ChainedScheduler(schedulers) @@ -1836,10 +2429,12 @@ def test_chained_lr3(self): def test_chained_lr4(self): epochs = 9 schedulers = [None] * 3 - targets = [[0.05 * 0.2 * 0.9 ** x for x in range(3)] - + [0.05 * 0.2 * 0.9 ** 3 * 0.1] - + [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)] - + [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]] + targets = [ + [0.05 * 0.2 * 0.9**x for x in range(3)] + + [0.05 * 0.2 * 0.9**3 * 0.1] + + [0.05 * 0.9**x * 0.1 for x in range(4, 6)] + + [0.05 * 0.9**x * 0.01 for x in range(6, 9)] + ] schedulers[0] = ExponentialLR(self.opt, gamma=0.9) schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) @@ -1877,10 +2472,10 @@ def test_compound_step_and_multistep_lr(self): def test_compound_step_and_exp_lr(self): epochs = 10 schedulers = [None] * 2 - single_targets = [0.05 * (0.9 ** x) for x in range(3)] - single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)] - single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)] - single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)] + single_targets = [0.05 * (0.9**x) for x in range(3)] + single_targets += [0.005 * (0.9**x) for x in range(3, 6)] + single_targets += [0.0005 * (0.9**x) for x in range(6, 9)] + single_targets += [0.00005 * (0.9**x) for x in range(9, 12)] targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) @@ -1889,10 +2484,10 @@ def test_compound_step_and_exp_lr(self): def test_compound_exp_and_multistep_lr(self): epochs = 10 schedulers = [None] * 2 - single_targets = [0.05 * (0.9 ** x) for x in range(2)] - single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)] - single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)] - single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)] + single_targets = [0.05 * (0.9**x) for x in range(2)] + single_targets += [0.005 * (0.9**x) for x in range(2, 5)] + single_targets += [0.0005 * (0.9**x) for x in range(5, 9)] + single_targets += [0.00005 * (0.9**x) for x in range(9, 11)] targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) @@ -1904,13 +2499,18 @@ def test_compound_exp_and_linearlr(self): start_factor = 0.4 end_factor = 0.9 schedulers = [None] * 2 - single_targets = [0.05 * (0.9 ** x) for x in range(11)] + single_targets = [0.05 * (0.9**x) for x in range(11)] for i in range(iters): single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) for i in range(iters, 11): single_targets[i] *= end_factor targets = [single_targets, [x * epochs for x in single_targets]] - schedulers[0] = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) + schedulers[0] = LinearLR( + self.opt, + start_factor=start_factor, + end_factor=end_factor, + total_iters=iters, + ) schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) @@ -1919,7 +2519,13 @@ def test_compound_step_and_constantlr(self): iters = 4 factor = 0.4 schedulers = [None] * 2 - single_targets = [0.05 * 0.4] * 3 + [0.005 * 0.4] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3 + single_targets = ( + [0.05 * 0.4] * 3 + + [0.005 * 0.4] + + [0.005] * 2 + + [0.0005] * 3 + + [0.00005] * 3 + ) targets = [single_targets, [x * epochs for x in single_targets]] schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) @@ -1941,9 +2547,10 @@ def test_compound_linearlr_and_multistep_lr(self): def test_compound_cosanneal_and_step_lr(self): epochs = 10 eta_min = 1e-10 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] targets = [single_targets, [x * epochs for x in single_targets]] schedulers = [None] * 2 @@ -1954,9 +2561,10 @@ def test_compound_cosanneal_and_step_lr(self): def test_compound_cosanneal_and_multistep_lr(self): epochs = 10 eta_min = 1e-10 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] single_targets = [x * y for x, y in zip(single_targets, multipliers)] targets = [single_targets, [x * epochs for x in single_targets]] @@ -1971,9 +2579,10 @@ def test_compound_cosanneal_and_linearlr(self): start_factor = 0.4 eta_min = 1e-10 schedulers = [None] * 2 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] for i in range(iters): single_targets[i] *= start_factor + i / iters * (1 - start_factor) targets = [single_targets, [x * epochs for x in single_targets]] @@ -1984,10 +2593,11 @@ def test_compound_cosanneal_and_linearlr(self): def test_compound_cosanneal_and_exp_lr(self): epochs = 10 eta_min = 1e-10 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] - multipliers = [0.1 ** i for i in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] + multipliers = [0.1**i for i in range(epochs)] single_targets = [x * y for x, y in zip(single_targets, multipliers)] targets = [single_targets, [x * epochs for x in single_targets]] schedulers = [None] * 2 @@ -1998,7 +2608,7 @@ def test_compound_cosanneal_and_exp_lr(self): def test_compound_reduce_lr_on_plateau1(self): epochs = 10 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 single_targets = [0.5] * 20 multipliers = [0.1 ** (i // 3) for i in range(20)] single_targets = [x * y for x, y in zip(multipliers, single_targets)] @@ -2006,15 +2616,21 @@ def test_compound_reduce_lr_on_plateau1(self): targets = targets[1:] # test runs step before checking lr metrics = [10 - i * 0.0167 for i in range(20)] schedulers = [None, None] - schedulers[0] = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', - threshold=0.01, patience=5, cooldown=5) + schedulers[0] = ReduceLROnPlateau( + self.opt, + threshold_mode="abs", + mode="min", + threshold=0.01, + patience=5, + cooldown=5, + ) schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) def test_compound_reduce_lr_on_plateau2(self): epochs = 22 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 single_targets = [x * y for x, y in zip(single_targets, multipliers)] @@ -2022,42 +2638,51 @@ def test_compound_reduce_lr_on_plateau2(self): targets = targets[1:] # test runs step before checking lr metrics = [10 - i * 0.0165 for i in range(22)] schedulers = [None] * 2 - schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', - mode='min', threshold=0.1) + schedulers[0] = ReduceLROnPlateau( + self.opt, + patience=5, + cooldown=0, + threshold_mode="abs", + mode="min", + threshold=0.1, + ) schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) def test_compound_reduce_lr_on_plateau3(self): epochs = 22 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 - multipliers = [0.1 ** i for i in range(epochs)] + multipliers = [0.1**i for i in range(epochs)] single_targets = [x * y for x, y in zip(multipliers, single_targets)] targets = [single_targets] targets = targets[1:] # test runs step before checking lr metrics = [-0.8] * 2 + [-0.234] * 20 schedulers = [None, None] - schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, - threshold_mode='abs') + schedulers[0] = ReduceLROnPlateau( + self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" + ) schedulers[1] = ExponentialLR(self.opt, gamma=0.1) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) def test_compound_reduce_lr_on_plateau4(self): epochs = 20 for param_group in self.opt.param_groups: - param_group['lr'] = 0.05 + param_group["lr"] = 0.05 epochs = 10 eta_min = 1e-10 - single_targets = [eta_min + (0.05 - eta_min) * - (1 + math.cos(math.pi * x / epochs)) / 2 - for x in range(epochs)] + single_targets = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs) + ] targets = [single_targets] targets = targets[1:] # test runs step before checking lr - metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25 + metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 schedulers = [None, None] - schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=3, - threshold_mode='rel', threshold=0.1) + schedulers[0] = ReduceLROnPlateau( + self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 + ) schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) @@ -2066,7 +2691,7 @@ def test_compound_reduce_lr_on_plateau5(self): start_factor = 0.4 epochs = 22 for param_group in self.opt.param_groups: - param_group['lr'] = 0.5 + param_group["lr"] = 0.5 single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 multipliers = [1] * 22 for i in range(iters): @@ -2076,8 +2701,14 @@ def test_compound_reduce_lr_on_plateau5(self): targets = targets[1:] # test runs step before checking lr metrics = [10 - i * 0.0165 for i in range(22)] schedulers = [None] * 2 - schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', - mode='min', threshold=0.1) + schedulers[0] = ReduceLROnPlateau( + self.opt, + patience=5, + cooldown=0, + threshold_mode="abs", + mode="min", + threshold=0.1, + ) schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) @@ -2090,30 +2721,94 @@ def test_cycle_lr_triangular_mode_one_lr(self): momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, - cycle_momentum=True, base_momentum=1, max_momentum=5, - mode='triangular') + scheduler = CyclicLR( + self.opt, + base_lr=1, + max_lr=5, + step_size_up=4, + cycle_momentum=True, + base_momentum=1, + max_momentum=5, + mode="triangular", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] lr_targets = [lr_target, lr_target] - momentum_target = [self.opt.defaults['momentum']] * len(lr_target) + momentum_target = [self.opt.defaults["momentum"]] * len(lr_target) momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, - cycle_momentum=False, mode='triangular') + scheduler = CyclicLR( + self.opt, + base_lr=1, + max_lr=5, + step_size_up=4, + cycle_momentum=False, + mode="triangular", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_triangular2_mode_one_lr(self): - lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, - 1, 1.25, 1.50, 1.75, 2.00, 1.75] - momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, - 3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] + lr_target = [ + 1, + 2, + 3, + 4, + 5, + 4, + 3, + 2, + 1, + 1.5, + 2.0, + 2.5, + 3.0, + 2.5, + 2.0, + 1.5, + 1, + 1.25, + 1.50, + 1.75, + 2.00, + 1.75, + ] + momentum_target = [ + 5.0, + 4.0, + 3.0, + 2.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 4.5, + 4.0, + 3.5, + 3.0, + 3.5, + 4.0, + 4.5, + 5.0, + 4.75, + 4.5, + 4.25, + 4.0, + 4.25, + ] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, - cycle_momentum=True, base_momentum=1, max_momentum=5, - mode='triangular2') + scheduler = CyclicLR( + self.opt, + base_lr=1, + max_lr=5, + step_size_up=4, + cycle_momentum=True, + base_momentum=1, + max_momentum=5, + mode="triangular2", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_exp_range_mode_one_lr(self): @@ -2125,10 +2820,17 @@ def test_cycle_lr_exp_range_mode_one_lr(self): momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=base_lr, - max_lr=max_lr, step_size_up=4, - cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr, - mode='exp_range', gamma=gamma) + scheduler = CyclicLR( + self.opt, + base_lr=base_lr, + max_lr=max_lr, + step_size_up=4, + cycle_momentum=True, + base_momentum=base_lr, + max_momentum=max_lr, + mode="exp_range", + gamma=gamma, + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_triangular_mode(self): @@ -2138,23 +2840,81 @@ def test_cycle_lr_triangular_mode(self): momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] momentum_target_2 = [x + 1 for x in momentum_target_1] momentum_targets = [momentum_target_1, momentum_target_2] - scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4, - cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6], - mode='triangular') + scheduler = CyclicLR( + self.opt, + base_lr=[1, 2], + max_lr=[5, 6], + step_size_up=4, + cycle_momentum=True, + base_momentum=[1, 2], + max_momentum=[5, 6], + mode="triangular", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) def test_cycle_lr_triangular2_mode(self): - lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1, - 1.25, 1.50, 1.75, 2.00, 1.75] + lr_target_1 = [ + 1, + 2, + 3, + 4, + 5, + 4, + 3, + 2, + 1, + 1.5, + 2.0, + 2.5, + 3.0, + 2.5, + 2.0, + 1.5, + 1, + 1.25, + 1.50, + 1.75, + 2.00, + 1.75, + ] lr_target_2 = [x + 2 for x in lr_target_1] lr_targets = [lr_target_1, lr_target_2] - momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5, - 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] + momentum_target_1 = [ + 5.0, + 4.0, + 3.0, + 2.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 4.5, + 4.0, + 3.5, + 3.0, + 3.5, + 4.0, + 4.5, + 5.0, + 4.75, + 4.5, + 4.25, + 4.0, + 4.25, + ] momentum_target_2 = [x + 2 for x in momentum_target_1] momentum_targets = [momentum_target_1, momentum_target_2] - scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4, - cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7], - mode='triangular2') + scheduler = CyclicLR( + self.opt, + base_lr=[1, 3], + max_lr=[5, 7], + step_size_up=4, + cycle_momentum=True, + base_momentum=[1, 3], + max_momentum=[5, 7], + mode="triangular2", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) def test_cycle_lr_exp_range_mode(self): @@ -2169,46 +2929,129 @@ def test_cycle_lr_exp_range_mode(self): lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] lr_targets = [lr_target_1, lr_target_2] - momentum_target_1 = [max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] - momentum_target_2 = [max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] + momentum_target_1 = [ + max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs) + ] + momentum_target_2 = [ + max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs) + ] momentum_targets = [momentum_target_1, momentum_target_2] - scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2], - max_lr=[max_lr_1, max_lr_2], step_size_up=4, - cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2], - max_momentum=[max_lr_1, max_lr_2], - mode='exp_range', gamma=gamma) + scheduler = CyclicLR( + self.opt, + base_lr=[base_lr_1, base_lr_2], + max_lr=[max_lr_1, max_lr_2], + step_size_up=4, + cycle_momentum=True, + base_momentum=[base_lr_1, base_lr_2], + max_momentum=[max_lr_1, max_lr_2], + mode="exp_range", + gamma=gamma, + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) def test_cycle_lr_triangular_mode_step_size_up_down(self): - lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0] + lr_target = [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 13.0 / 3, + 11.0 / 3, + 9.0 / 3, + 7.0 / 3, + 5.0 / 3, + 1.0, + ] lr_targets = [lr_target, lr_target] - momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0] + momentum_target = [ + 5.0, + 4.0, + 3.0, + 2.0, + 1.0, + 5.0 / 3, + 7.0 / 3, + 3.0, + 11.0 / 3, + 13.0 / 3, + 5.0, + ] momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, - step_size_up=4, - step_size_down=6, - cycle_momentum=True, - base_momentum=1, max_momentum=5, - mode='triangular') + scheduler = CyclicLR( + self.opt, + base_lr=1, + max_lr=5, + step_size_up=4, + step_size_down=6, + cycle_momentum=True, + base_momentum=1, + max_momentum=5, + mode="triangular", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_triangular2_mode_step_size_up_down(self): - lr_base_target = ([ - 1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3, - 7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6, - 8.0 / 6, 7.0 / 6 - ]) - momentum_base_target = ([ - 5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3, - 11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3, - 29.0 / 6 - ]) + lr_base_target = [ + 1.0, + 3.0, + 5.0, + 13.0 / 3, + 11.0 / 3, + 9.0 / 3, + 7.0 / 3, + 5.0 / 3, + 1.0, + 2.0, + 3.0, + 8.0 / 3, + 7.0 / 3, + 6.0 / 3, + 5.0 / 3, + 4.0 / 3, + 1.0, + 3.0 / 2, + 2.0, + 11.0 / 6, + 10.0 / 6, + 9.0 / 6, + 8.0 / 6, + 7.0 / 6, + ] + momentum_base_target = [ + 5.0, + 3.0, + 1.0, + 5.0 / 3, + 7.0 / 3, + 3.0, + 11.0 / 3, + 13.0 / 3, + 5.0, + 4.0, + 3.0, + 10.0 / 3, + 11.0 / 3, + 4.0, + 13.0 / 3, + 14.0 / 3, + 5.0, + 4.5, + 4.0, + 25.0 / 6, + 13.0 / 3, + 4.5, + 14.0 / 3, + 29.0 / 6, + ] deltas = [2 * i for i in range(0, 2)] base_lrs = [1 + delta for delta in deltas] max_lrs = [5 + delta for delta in deltas] lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] - momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas] + momentum_targets = [ + [x + delta for x in momentum_base_target] for delta in deltas + ] scheduler = CyclicLR( self.opt, base_lr=base_lrs, @@ -2218,26 +3061,47 @@ def test_cycle_lr_triangular2_mode_step_size_up_down(self): cycle_momentum=True, base_momentum=base_lrs, max_momentum=max_lrs, - mode='triangular2') - self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target)) + mode="triangular2", + ) + self._test_cycle_lr( + scheduler, lr_targets, momentum_targets, len(lr_base_target) + ) def test_cycle_lr_exp_range_mode_step_size_up_down(self): base_lr, max_lr = 1, 5 diff_lr = max_lr - base_lr gamma = 0.9 - xs = ([ - 0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6, - 4.0 / 6 - ]) + xs = [ + 0.0, + 0.5, + 1.0, + 5.0 / 6, + 4.0 / 6, + 3.0 / 6, + 2.0 / 6, + 1.0 / 6, + 0.0, + 0.5, + 1.0, + 5.0 / 6, + 4.0 / 6, + ] lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] lr_targets = [lr_target, lr_target] momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr, - step_size_up=2, step_size_down=6, - cycle_momentum=True, base_momentum=base_lr, - max_momentum=max_lr, - mode='exp_range', gamma=gamma) + scheduler = CyclicLR( + self.opt, + base_lr=base_lr, + max_lr=max_lr, + step_size_up=2, + step_size_down=6, + cycle_momentum=True, + base_momentum=base_lr, + max_momentum=max_lr, + mode="exp_range", + gamma=gamma, + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) def test_cycle_lr_with_momentumless_optimizer(self): @@ -2250,15 +3114,25 @@ def test_cycle_lr_with_momentumless_optimizer(self): # in more detail in https://github.com/pytorch/pytorch/issues/19003 ). old_opt = self.opt self.opt = optim.Adam( - [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], - lr=0.05) + [ + {"params": self.net.conv1.parameters()}, + {"params": self.net.conv2.parameters(), "lr": 0.5}, + ], + lr=0.05, + ) lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] lr_targets = [lr_target, lr_target] momentum_target = [None] * len(lr_target) momentum_targets = [momentum_target, momentum_target] - scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, - cycle_momentum=False, mode='triangular') + scheduler = CyclicLR( + self.opt, + base_lr=1, + max_lr=5, + step_size_up=4, + cycle_momentum=False, + mode="triangular", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) self.opt = old_opt # set optimizer back to SGD @@ -2271,6 +3145,7 @@ def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self): def test_cycle_lr_removed_after_out_of_scope(self): import gc import weakref + gc.disable() def test(): @@ -2284,7 +3159,9 @@ def test(): def test_onecycle_lr_invalid_anneal_strategy(self): with self.assertRaises(ValueError): - scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS") + scheduler = OneCycleLR( + self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS" + ) def test_onecycle_lr_invalid_pct_start(self): with self.assertRaises(ValueError): @@ -2299,8 +3176,15 @@ def test_onecycle_lr_linear_annealing(self): momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, - total_steps=10, anneal_strategy='linear') + scheduler = OneCycleLR( + self.opt, + max_lr=25, + final_div_factor=2, + base_momentum=1, + max_momentum=22, + total_steps=10, + anneal_strategy="linear", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) def test_onecycle_lr_linear_annealing_three_phases(self): @@ -2308,59 +3192,111 @@ def test_onecycle_lr_linear_annealing_three_phases(self): momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = OneCycleLR(self.opt, max_lr=25, div_factor=25, - base_momentum=1, max_momentum=22, - total_steps=10, anneal_strategy='linear', - pct_start=0.4, final_div_factor=4, - three_phase=True) + scheduler = OneCycleLR( + self.opt, + max_lr=25, + div_factor=25, + base_momentum=1, + max_momentum=22, + total_steps=10, + anneal_strategy="linear", + pct_start=0.4, + final_div_factor=4, + three_phase=True, + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) def test_onecycle_lr_cosine_annealing(self): def annealing_cos(start, end, pct): cos_out = math.cos(math.pi * pct) + 1 return end + (start - end) / 2.0 * cos_out - lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0), - annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0), - annealing_cos(25, 0.5, 6 / 7.0), 0.5] - momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0), - annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0), - annealing_cos(1, 22, 6 / 7.0), 22] + + lr_target = [ + 1, + 13, + 25, + annealing_cos(25, 0.5, 1 / 7.0), + annealing_cos(25, 0.5, 2 / 7.0), + annealing_cos(25, 0.5, 3 / 7.0), + annealing_cos(25, 0.5, 4 / 7.0), + annealing_cos(25, 0.5, 5 / 7.0), + annealing_cos(25, 0.5, 6 / 7.0), + 0.5, + ] + momentum_target = [ + 22, + 11.5, + 1, + annealing_cos(1, 22, 1 / 7.0), + annealing_cos(1, 22, 2 / 7.0), + annealing_cos(1, 22, 3 / 7.0), + annealing_cos(1, 22, 4 / 7.0), + annealing_cos(1, 22, 5 / 7.0), + annealing_cos(1, 22, 6 / 7.0), + 22, + ] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, - total_steps=10) + scheduler = OneCycleLR( + self.opt, + max_lr=25, + final_div_factor=2, + base_momentum=1, + max_momentum=22, + total_steps=10, + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) def test_cycle_lr_with_adam(self): old_opt = self.opt self.opt = optim.Adam( - [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], - lr=0.05) + [ + {"params": self.net.conv1.parameters()}, + {"params": self.net.conv2.parameters(), "lr": 0.5}, + ], + lr=0.05, + ) lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] lr_targets = [lr_target, lr_target] momentum_targets = [momentum_target, momentum_target] - scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, - total_steps=10, anneal_strategy='linear') + scheduler = OneCycleLR( + self.opt, + max_lr=25, + final_div_factor=2, + base_momentum=1, + max_momentum=22, + total_steps=10, + anneal_strategy="linear", + ) self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True) self.opt = old_opt # set optimizer back to SGD def test_lambda_lr(self): epochs = 10 - self.opt.param_groups[0]['lr'] = 0.05 - self.opt.param_groups[1]['lr'] = 0.4 - targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]] - scheduler = LambdaLR(self.opt, - lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2]) + self.opt.param_groups[0]["lr"] = 0.05 + self.opt.param_groups[1]["lr"] = 0.4 + targets = [ + [0.05 * (0.9**x) for x in range(epochs)], + [0.4 * (0.8**x) for x in range(epochs)], + ] + scheduler = LambdaLR( + self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2] + ) self._test(scheduler, targets, epochs) def test_multiplicative_lr(self): epochs = 10 - self.opt.param_groups[0]['lr'] = 0.05 - self.opt.param_groups[1]['lr'] = 0.4 - targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]] - scheduler = MultiplicativeLR(self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]) + self.opt.param_groups[0]["lr"] = 0.05 + self.opt.param_groups[1]["lr"] = 0.4 + targets = [ + [0.05 * (0.9**x) for x in range(epochs)], + [0.4 * (0.8**x) for x in range(epochs)], + ] + scheduler = MultiplicativeLR( + self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8] + ) self._test(scheduler, targets, epochs) @parametrize("T_mult", [1, 2, 4]) @@ -2370,14 +3306,20 @@ def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): T_i = 10 T_cur = 0 targets = [[0.05], [0.5]] - scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min) + scheduler = CosineAnnealingWarmRestarts( + self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min + ) for _ in range(1, iters, 1): T_cur += 1 if T_cur >= T_i: T_cur = T_cur - T_i T_i = int(T_mult) * T_i - targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] - targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] + targets[0] += [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] + targets[1] += [ + eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] self._test(scheduler, targets, iters) def test_CosineAnnealingWarmRestarts_lr2(self): @@ -2388,41 +3330,69 @@ def test_CosineAnnealingWarmRestarts_lr2(self): T_i = 10 T_cur = 0 targets = [[0.05], [0.5]] - scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min) + scheduler = CosineAnnealingWarmRestarts( + self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min + ) for _ in torch.arange(0.1, iters, 0.1): T_cur = round(T_cur + 0.1, 1) if T_cur >= T_i: T_cur = T_cur - T_i T_i = int(T_mult) * T_i - targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] - targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] + targets[0] += [ + eta_min + + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] + targets[1] += [ + eta_min + + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters) def test_CosineAnnealingWarmRestarts_lr3(self): - epochs_for_T_mults = [[0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], - [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], - [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50]] - T_curs_for_T_mults = [[1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], - [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], - [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10]] - T_is_for_T_mults = [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], - [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], - [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90]] + epochs_for_T_mults = [ + [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], + [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], + [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50], + ] + T_curs_for_T_mults = [ + [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], + [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], + [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10], + ] + T_is_for_T_mults = [ + [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], + [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], + [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90], + ] eta_min = 1e-10 T_mults = [1, 2, 3] - for epochs, T_mult, T_curs, T_is in zip(epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults): + for epochs, T_mult, T_curs, T_is in zip( + epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults + ): targets = [[0.05], [0.5]] - scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min) + scheduler = CosineAnnealingWarmRestarts( + self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min + ) for T_cur, T_i in zip(T_curs, T_is): - targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] - targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] - self._test_interleaved_CosineAnnealingWarmRestarts(scheduler, targets, epochs) + targets[0] += [ + eta_min + + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] + targets[1] += [ + eta_min + + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 + ] + self._test_interleaved_CosineAnnealingWarmRestarts( + scheduler, targets, epochs + ) def test_swalr_no_anneal(self): epochs, swa_start, swa_lr = 10, 5, 0.01 - initial_lrs = [group['lr'] for group in self.opt.param_groups] - targets = [[lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) - for lr in initial_lrs] + initial_lrs = [group["lr"] for group in self.opt.param_groups] + targets = [ + [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) + for lr in initial_lrs + ] swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) @@ -2435,15 +3405,22 @@ def test_swalr_cosine_anneal_after_multiplicative(self): def anneal_coef(t): if t + 1 >= anneal_epochs: - return 0. + return 0.0 return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 - initial_lrs = [group['lr'] for group in self.opt.param_groups] - targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)] - for lr in initial_lrs] + initial_lrs = [group["lr"] for group in self.opt.param_groups] + targets_before_swa = [ + [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs + ] swa_epochs = epochs - swa_start - 1 - targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)] - for lrs in targets_before_swa] + targets = [ + lrs + + [ + lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) + for t in range(swa_epochs) + ] + for lrs in targets_before_swa + ] self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) @@ -2452,29 +3429,46 @@ def test_swalr_linear_anneal_after_multiplicative(self): epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4 mult_factor = 0.9 scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) - swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, - anneal_strategy="linear", swa_lr=swa_lrs) + swa_scheduler = SWALR( + self.opt, + anneal_epochs=anneal_epochs, + anneal_strategy="linear", + swa_lr=swa_lrs, + ) def anneal_coef(t): if t + 1 >= anneal_epochs: - return 0. + return 0.0 return 1 - (t + 1) / anneal_epochs - initial_lrs = [group['lr'] for group in self.opt.param_groups] - targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)] - for lr in initial_lrs] + initial_lrs = [group["lr"] for group in self.opt.param_groups] + targets_before_swa = [ + [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs + ] swa_epochs = epochs - swa_start - 1 - targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)] - for lrs, swa_lr in zip(targets_before_swa, swa_lrs)] + targets = [ + lrs + + [ + lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) + for t in range(swa_epochs) + ] + for lrs, swa_lr in zip(targets_before_swa, swa_lrs) + ] self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): for epoch in range(epochs): for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[epoch], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[epoch], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[epoch], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) if epoch >= swa_start: self.opt.step() swa_scheduler.step() @@ -2485,29 +3479,32 @@ def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): def test_swalr_hypers(self): # Test that SWALR raises errors for incorrect hyper-parameters with self.assertRaisesRegex(ValueError, "anneal_strategy must"): - swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.) + swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0) with self.assertRaisesRegex(ValueError, "anneal_epochs must"): - swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.) + swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0) with self.assertRaisesRegex(ValueError, "anneal_epochs must"): - swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.) + swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0) with self.assertRaisesRegex(ValueError, "swa_lr must"): - swa_scheduler = SWALR(self.opt, swa_lr=[1., 0.1, 0.01]) + swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01]) def test_step_lr_state_dict(self): self._check_scheduler_state_dict( lambda: StepLR(self.opt, gamma=0.1, step_size=3), - lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1)) + lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1), + ) def test_multi_step_lr_state_dict(self): self._check_scheduler_state_dict( lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]), - lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6])) + lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]), + ) def test_exp_step_lr_state_dict(self): self._check_scheduler_state_dict( lambda: ExponentialLR(self.opt, gamma=0.1), - lambda: ExponentialLR(self.opt, gamma=0.01)) + lambda: ExponentialLR(self.opt, gamma=0.01), + ) def test_cosine_lr_state_dict(self): epochs = 10 @@ -2515,49 +3512,56 @@ def test_cosine_lr_state_dict(self): self._check_scheduler_state_dict( lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min), lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2), - epochs=epochs) + epochs=epochs, + ) def test_reduce_lr_on_plateau_state_dict(self): - scheduler = ReduceLROnPlateau(self.opt, mode='min', factor=0.1, patience=2) + scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2) for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]: scheduler.step(score) - scheduler_copy = ReduceLROnPlateau(self.opt, mode='max', factor=0.5, patience=10) + scheduler_copy = ReduceLROnPlateau( + self.opt, mode="max", factor=0.5, patience=10 + ) scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): - if key not in {'optimizer', 'is_better'}: + if key not in {"optimizer", "is_better"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) def test_lambda_lr_state_dict_fn(self): scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x) state = scheduler.state_dict() - self.assertIsNone(state['lr_lambdas'][0]) + self.assertIsNone(state["lr_lambdas"][0]) scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) scheduler_copy.load_state_dict(state) for key in scheduler.__dict__.keys(): - if key not in {'optimizer', 'lr_lambdas'}: + if key not in {"optimizer", "lr_lambdas"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) def test_lambda_lr_state_dict_obj(self): scheduler = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(10)) state = scheduler.state_dict() - self.assertIsNotNone(state['lr_lambdas'][0]) + self.assertIsNotNone(state["lr_lambdas"][0]) scheduler_copy = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(-1)) scheduler_copy.load_state_dict(state) for key in scheduler.__dict__.keys(): - if key not in {'optimizer'}: + if key not in {"optimizer"}: self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) def test_CosineAnnealingWarmRestarts_lr_state_dict(self): self._check_scheduler_state_dict( lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2), - lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100)) + lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100), + ) def test_swa_lr_state_dict(self): self._check_scheduler_state_dict( lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), - lambda: SWALR(self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.)) + lambda: SWALR( + self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0 + ), + ) def _check_scheduler_state_dict(self, constr, constr2, epochs=10): scheduler = constr() @@ -2567,12 +3571,12 @@ def _check_scheduler_state_dict(self, constr, constr2, epochs=10): scheduler_copy = constr2() scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): - if key != 'optimizer': + if key != "optimizer": self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) def _test_get_last_lr(self, schedulers, targets, epochs=10): - if isinstance(schedulers, _LRScheduler): + if isinstance(schedulers, LRScheduler): schedulers = [schedulers] optimizers = {scheduler.optimizer for scheduler in schedulers} for epoch in range(epochs): @@ -2581,32 +3585,54 @@ def _test_get_last_lr(self, schedulers, targets, epochs=10): [scheduler.step() for scheduler in schedulers] target = [[t[epoch] for t in targets]] * len(schedulers) for t, r in zip(target, result): - self.assertEqual(target, result, - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, t, r), atol=1e-5, rtol=0) + self.assertEqual( + target, + result, + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, t, r + ), + atol=1e-5, + rtol=0, + ) def _test_with_epoch(self, schedulers, targets, epochs=10): - if isinstance(schedulers, _LRScheduler): + if isinstance(schedulers, LRScheduler): schedulers = [schedulers] optimizers = {scheduler.optimizer for scheduler in schedulers} for epoch in range(epochs): [optimizer.step() for optimizer in optimizers] with warnings.catch_warnings(record=True) as w: - [scheduler.step(epoch) for scheduler in schedulers] # step before assert: skip initial lr - self._check_warning_is_epoch_deprecation_warning(w, num_warnings=len(schedulers)) + [ + scheduler.step(epoch) for scheduler in schedulers + ] # step before assert: skip initial lr + self._check_warning_is_epoch_deprecation_warning( + w, num_warnings=len(schedulers) + ) for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[epoch], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[epoch], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[epoch], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) def _test(self, schedulers, targets, epochs=10): - if isinstance(schedulers, _LRScheduler): + if isinstance(schedulers, LRScheduler): schedulers = [schedulers] for epoch in range(epochs): for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[epoch], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[epoch], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[epoch], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) [scheduler.step() for scheduler in schedulers] def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): @@ -2614,17 +3640,29 @@ def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): epoch = round(epoch.item(), 1) scheduler.step(epoch) for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[index], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[index], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[index], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[index], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs): for index, epoch in enumerate(epochs): scheduler.step(epoch) for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[index], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[index], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[index], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[index], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): self.setUp() @@ -2634,18 +3672,28 @@ def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10) with warnings.catch_warnings(record=True) as w: closed_form_scheduler.step(epoch) self._check_warning_is_epoch_deprecation_warning(w) - targets.append([group['lr'] for group in self.opt.param_groups]) + targets.append([group["lr"] for group in self.opt.param_groups]) self.setUp() for epoch in range(epochs): self.opt.step() scheduler.step() for i, param_group in enumerate(self.opt.param_groups): - self.assertEqual(targets[epoch][i], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, targets[epoch][i], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + targets[epoch][i], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, targets[epoch][i], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) - def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, verbose=False): - if isinstance(schedulers, _LRScheduler) or isinstance(schedulers, ReduceLROnPlateau): + def _test_reduce_lr_on_plateau( + self, schedulers, targets, metrics, epochs=10, verbose=False + ): + if isinstance(schedulers, LRScheduler) or isinstance( + schedulers, ReduceLROnPlateau + ): schedulers = [schedulers] for epoch in range(epochs): self.opt.step() @@ -2655,40 +3703,89 @@ def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, ve else: scheduler.step() if verbose: - print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr'])) + print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"])) for param_group, target in zip(self.opt.param_groups, targets): - self.assertEqual(target[epoch], param_group['lr'], - msg='LR is wrong in epoch {}: expected {}, got {}'.format( - epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) + self.assertEqual( + target[epoch], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[epoch], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) - def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False): + def _test_cycle_lr( + self, + scheduler, + lr_targets, + momentum_targets, + batch_iterations, + verbose=False, + use_beta1=False, + ): for batch_num in range(batch_iterations): if verbose: - if 'momentum' in self.opt.param_groups[0].keys(): - print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'], - self.opt.param_groups[0]['momentum'])) - elif use_beta1 and 'betas' in self.opt.param_groups[0].keys(): - print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'], - self.opt.param_groups[0]['betas'][0])) + if "momentum" in self.opt.param_groups[0].keys(): + print( + "batch{}:\tlr={},momentum={}".format( + batch_num, + self.opt.param_groups[0]["lr"], + self.opt.param_groups[0]["momentum"], + ) + ) + elif use_beta1 and "betas" in self.opt.param_groups[0].keys(): + print( + "batch{}:\tlr={},beta1={}".format( + batch_num, + self.opt.param_groups[0]["lr"], + self.opt.param_groups[0]["betas"][0], + ) + ) else: - print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr'])) - - for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets): + print( + "batch{}:\tlr={}".format( + batch_num, self.opt.param_groups[0]["lr"] + ) + ) + + for param_group, lr_target, momentum_target in zip( + self.opt.param_groups, lr_targets, momentum_targets + ): self.assertEqual( - lr_target[batch_num], param_group['lr'], - msg='LR is wrong in batch_num {}: expected {}, got {}'.format( - batch_num, lr_target[batch_num], param_group['lr']), atol=1e-5, rtol=0) + lr_target[batch_num], + param_group["lr"], + msg="LR is wrong in batch_num {}: expected {}, got {}".format( + batch_num, lr_target[batch_num], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) - if use_beta1 and 'betas' in param_group.keys(): + if use_beta1 and "betas" in param_group.keys(): self.assertEqual( - momentum_target[batch_num], param_group['betas'][0], - msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format( - batch_num, momentum_target[batch_num], param_group['betas'][0]), atol=1e-5, rtol=0) - elif 'momentum' in param_group.keys(): + momentum_target[batch_num], + param_group["betas"][0], + msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format( + batch_num, + momentum_target[batch_num], + param_group["betas"][0], + ), + atol=1e-5, + rtol=0, + ) + elif "momentum" in param_group.keys(): self.assertEqual( - momentum_target[batch_num], param_group['momentum'], - msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format( - batch_num, momentum_target[batch_num], param_group['momentum']), atol=1e-5, rtol=0) + momentum_target[batch_num], + param_group["momentum"], + msg="Momentum is wrong in batch_num {}: expected {}, got {}".format( + batch_num, + momentum_target[batch_num], + param_group["momentum"], + ), + atol=1e-5, + rtol=0, + ) self.opt.step() scheduler.step() @@ -2701,7 +3798,9 @@ def test_cosine_then_cyclic(self): model = torch.nn.Linear(2, 1) optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr) - lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1) + lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=20, eta_min=0.1 + ) lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 ) @@ -2737,7 +3836,9 @@ class SWATestCNN(torch.nn.Module): def __init__(self, input_channels): super(SWATestCNN, self).__init__() self.n_features = 10 - self.conv1 = torch.nn.Conv2d(input_channels, self.n_features, kernel_size=3, padding=1) + self.conv1 = torch.nn.Conv2d( + input_channels, self.n_features, kernel_size=3, padding=1 + ) self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3) def compute_preactivation(self, x): @@ -2750,7 +3851,6 @@ def forward(self, x): class TestSWAUtils(TestCase): - def _test_averaged_model(self, net_device, swa_device): dnn = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), @@ -2761,7 +3861,7 @@ def _test_averaged_model(self, net_device, swa_device): torch.nn.ReLU(), torch.nn.Linear(5, 5), torch.nn.ReLU(), - torch.nn.Linear(5, 10) + torch.nn.Linear(5, 10), ).to(net_device) averaged_dnn = AveragedModel(dnn, device=swa_device) @@ -2793,8 +3893,7 @@ def test_averaged_model_mixed_device(self): if not torch.cuda.is_available(): return dnn = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10) + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) ) dnn[0].cuda() dnn[1].cpu() @@ -2814,8 +3913,7 @@ def test_averaged_model_mixed_device(self): def test_averaged_model_state_dict(self): dnn = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10) + torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) ) averaged_dnn = AveragedModel(dnn) averaged_dnn2 = AveragedModel(dnn) @@ -2833,12 +3931,14 @@ def test_averaged_model_exponential(self): # Test AveragedModel with EMA as avg_fn dnn = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10) + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10), ) alpha = 0.9 def avg_fn(p_avg, p, n_avg): return alpha * p_avg + (1 - alpha) * p + averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn) averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] n_updates = 10 @@ -2849,29 +3949,40 @@ def avg_fn(p_avg, p, n_avg): if i == 0: updated_averaged_params.append(p.clone()) else: - updated_averaged_params.append((p_avg * alpha + - p * (1 - alpha)).clone()) + updated_averaged_params.append( + (p_avg * alpha + p * (1 - alpha)).clone() + ) + for b in dnn.buffers(): + if b.size() != torch.Size([]): + b.detach_().add_(torch.randn_like(b)) + averaged_dnn.update_parameters(dnn) averaged_params = updated_averaged_params for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): self.assertEqual(p_avg, p_swa) + for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()): + self.assertEqual(b_avg, b_swa) def test_averaged_model_exponential_buffers(self): # Test AveragedModel with EMA as avg_fn and use_buffers as True. dnn = torch.nn.Sequential( torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.BatchNorm2d(5, momentum=0.3), - torch.nn.Linear(5, 10) + torch.nn.Linear(5, 10), ) alpha = 0.9 def avg_fn(p_avg, p, n_avg): return alpha * p_avg + (1 - alpha) * p + averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True) dnn_params = itertools.chain(dnn.parameters(), dnn.buffers()) - averaged_params = [torch.zeros_like(param) for param in dnn_params - if param.size() != torch.Size([])] + averaged_params = [ + torch.zeros_like(param) + for param in dnn_params + if param.size() != torch.Size([]) + ] n_updates = 10 for i in range(n_updates): updated_averaged_params = [] @@ -2882,13 +3993,18 @@ def avg_fn(p_avg, p, n_avg): if i == 0: updated_averaged_params.append(p.clone()) else: - updated_averaged_params.append((p_avg * alpha + - p * (1 - alpha)).clone()) + updated_averaged_params.append( + (p_avg * alpha + p * (1 - alpha)).clone() + ) averaged_dnn.update_parameters(dnn) averaged_params = updated_averaged_params for p_avg, p_swa in zip( - averaged_params, itertools.chain(averaged_dnn.module.parameters(), averaged_dnn.module.buffers())): + averaged_params, + itertools.chain( + averaged_dnn.module.parameters(), averaged_dnn.module.buffers() + ), + ): self.assertEqual(p_avg, p_swa) def _test_update_bn(self, dnn, dl_x, dl_xy, cuda): @@ -2923,10 +4039,10 @@ def _test_update_bn(self, dnn, dl_x, dl_xy, cuda): self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) def _reset_bn(module): - if issubclass(module.__class__, - torch.nn.modules.batchnorm._BatchNorm): + if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): module.running_mean = torch.zeros_like(module.running_mean) module.running_var = torch.ones_like(module.running_var) + # reset batch norm and run update_bn again dnn.apply(_reset_bn) update_bn(dl_xy, dnn, device=x.device) @@ -3011,17 +4127,29 @@ def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored): opt.state[p].update(opt_differentiable_state) opt.step() return (p,) + tuple( - v for v in opt.state[p].values() if isinstance(v, torch.Tensor) and v.requires_grad) + v + for v in opt.state[p].values() + if isinstance(v, torch.Tensor) and v.requires_grad + ) class TestDifferentiableOptimizer(TestCase): - def test_sgd(self): p = torch.rand(10, requires_grad=True, dtype=torch.float64) grad = torch.rand(10, requires_grad=True, dtype=torch.float64) mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64) - state = {'momentum_buffer': mbuff} - gradcheck(_diff_fn, (p, grad, state, torch.optim.SGD, {'lr': 0.9, 'differentiable': True}, *state.values())) + state = {"momentum_buffer": mbuff} + gradcheck( + _diff_fn, + ( + p, + grad, + state, + torch.optim.SGD, + {"lr": 0.9, "differentiable": True}, + *state.values(), + ), + ) def test_adam(self): state = {} @@ -3029,31 +4157,56 @@ def test_adam(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['max_exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["max_exp_avg_sq"] = torch.rand( + 10, requires_grad=True, dtype=torch.float64 + ) gradcheck( _diff_fn, - (p, grad, state, torch.optim.Adam, - {'lr': 0.9, 'differentiable': True, 'amsgrad': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.Adam, + {"lr": 0.9, "differentiable": True, "amsgrad": True}, + *state.values(), + ), ) def test_rmsprop(self): state = {} p = torch.rand(10, requires_grad=True, dtype=torch.float64) grad = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['step'] = 0 - state['square_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['momentum_buffer'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = 0 + state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["momentum_buffer"] = torch.rand( + 10, requires_grad=True, dtype=torch.float64 + ) # This can cause issues with large values and nan due to sqrt ops - state['grad_avg'] = 1e-2 * torch.rand(10, requires_grad=True, dtype=torch.float64) + state["grad_avg"] = 1e-2 * torch.rand( + 10, requires_grad=True, dtype=torch.float64 + ) gradcheck( _diff_fn, - (p, grad, state, torch.optim.RMSprop, - {'lr': 0.9, 'maximize': True, 'momentum': 0.9, 'differentiable': True, 'centered': True, 'weight_decay': 0.1}, - *state.values())) + ( + p, + grad, + state, + torch.optim.RMSprop, + { + "lr": 0.9, + "maximize": True, + "momentum": 0.9, + "differentiable": True, + "centered": True, + "weight_decay": 0.1, + }, + *state.values(), + ), + ) def test_adadelta(self): state = {} @@ -3061,13 +4214,19 @@ def test_adadelta(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['square_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['acc_delta'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.Adadelta, - {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.Adadelta, + {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, + *state.values(), + ), ) def test_adagrad(self): @@ -3076,12 +4235,18 @@ def test_adagrad(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['sum'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.Adagrad, - {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.Adagrad, + {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, + *state.values(), + ), ) def test_adamax(self): @@ -3090,13 +4255,19 @@ def test_adamax(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['exp_inf'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.Adamax, - {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.Adamax, + {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, + *state.values(), + ), ) def test_asgd(self): @@ -3105,15 +4276,21 @@ def test_asgd(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` `eta` & `mu` are not continuous variables (even though we define them as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['eta'] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) - state['mu'] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) - state['ax'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) + state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) + state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.ASGD, - {'lr': 0.9, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.ASGD, + {"lr": 0.9, "differentiable": True}, + *state.values(), + ), ) def test_rprop(self): @@ -3122,32 +4299,45 @@ def test_rprop(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['prev'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['step_size'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.Rprop, - {'lr': 0.9, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.Rprop, + {"lr": 0.9, "differentiable": True}, + *state.values(), + ), ) - def test_adamw(self): state = {} p = torch.rand(10, requires_grad=True, dtype=torch.float64) grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['max_exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["max_exp_avg_sq"] = torch.rand( + 10, requires_grad=True, dtype=torch.float64 + ) gradcheck( _diff_fn, - (p, grad, state, torch.optim.AdamW, - {'lr': 0.9, 'differentiable': True, 'amsgrad': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.AdamW, + {"lr": 0.9, "differentiable": True, "amsgrad": True}, + *state.values(), + ), ) def test_nadam(self): @@ -3156,15 +4346,21 @@ def test_nadam(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['mu_product'] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.NAdam, - {'lr': 0.9, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.NAdam, + {"lr": 0.9, "differentiable": True}, + *state.values(), + ), ) def test_radam(self): @@ -3173,16 +4369,22 @@ def test_radam(self): grad = torch.rand(10, requires_grad=True, dtype=torch.float64) # `step` is not a continuous variable (even though we define it as a float) # and so it shouldn't require gradients. - state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) - state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) - state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) + state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) + state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) gradcheck( _diff_fn, - (p, grad, state, torch.optim.RAdam, - {'lr': 0.9, 'differentiable': True}, *state.values()) + ( + p, + grad, + state, + torch.optim.RAdam, + {"lr": 0.9, "differentiable": True}, + *state.values(), + ), ) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/test/test_overrides.py b/test/test_overrides.py index e9e01684bda53..629cd2a106806 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -387,6 +387,10 @@ def test_mean_semantics(self): self.assertEqual(torch.mean(t3), 4.0) self.assertEqual(bar(t3), 0) + def test_has_torch_function_non_sequence(self): + with self.assertRaisesRegex(TypeError, "expected a sequence"): + has_torch_function(object()) + def test_mm_semantics(self): """Test that a function with multiple arguments can be overrided""" t1 = DiagonalTensor(5, 2) @@ -897,7 +901,6 @@ def run_test(fast_mode): 'dtype', 'is_floating_point', 'is_sparse', - 'is_sparse_csr', 'layout', 'new_zeros', 'numel', @@ -1175,7 +1178,7 @@ def __torch_function__(self, *args, **kwargs): self.assertEqual(torch.mm(x, x), -1) self.assertEqual(bar(x), 1) self.assertRaisesRegex( - TypeError, r'SubTensor.+TorchFunctionStackMode', + TypeError, r'SubTensor', lambda: self.assertEqual(torch.max(x, x))) def test_with_mode(self): @@ -1248,7 +1251,7 @@ def __torch_function__(cls, func, _, args=(), kwargs=None): return func(args, kwargs) x = torch.tensor(5.) - with self.assertRaisesRegex(RuntimeError, "should be a normal method not a class method"): + with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"): with A(): x + x diff --git a/test/test_prims.py b/test/test_prims.py index 674a032796044..cadef6097df15 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -8,11 +8,11 @@ import torch from torch.testing import make_tensor -from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf +from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY, + set_default_dtype, skipCUDAMemoryLeakCheckIf) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, - skipCUDAIfRocm, dtypes, OpDTypes, ) @@ -38,7 +38,6 @@ class TestPrims(TestCase): @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32) def test_broadcast_in_dim(self, device, dtype): def _wrapper(a, b, broadcast_dimensions): @@ -102,7 +101,6 @@ def _wrapper(a, b, broadcast_dimensions): """ @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32) def test_broadcast_in_dim_sum(self, device, dtype): def _wrapper(a): @@ -130,11 +128,8 @@ def test_cbrt_prim(self, device, dtype): batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)] shapes = [(), (0,), (1,), (5,)] - try: - # Sets the default dtype to NumPy's default dtype of double - cur_default = torch.get_default_dtype() - torch.set_default_dtype(torch.double) - + # Sets the default dtype to NumPy's default dtype of double + with set_default_dtype(torch.double): # Tested here, as this OP is not currently exposed or tested in ATen for b, s in product(batches, shapes): x = make_arg(b + s) @@ -144,11 +139,8 @@ def test_cbrt_prim(self, device, dtype): y_np = scipy.special.cbrt(x_np) self.assertEqual(y, y_np, exact_device=False) - finally: - torch.set_default_dtype(cur_default) @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_impl_is_used(self, device): # This test is to ensure that when the nvfuser implementation exists it is used # Assuming one-to-one mapping between prims and nvfuser implementations @@ -215,8 +207,69 @@ def func(a): ) self.assertFalse(include_any_nvprims_sin) + def test_partitioner_tuple_output(self, device): + # This test verifies that the partitioner doesn't segment on nodes with + # tuple outputs. + from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport + + a = make_tensor(5, 3, 3, device=device, dtype=torch.float32) + + def func(x): + xx = torch.ops.nvprims.add(x, 1) + var, mean = torch.ops.nvprims.var_mean(x, correction=0) + var_cos = torch.ops.nvprims.cos(var) + mean_sin = torch.ops.nvprims.sin(mean) + return torch.ops.nvprims.add(var_cos, mean_sin) + + gm = make_fx(func)(a) + supported_ops = NvfuserPrimOperatorSupport() + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=False + ) + partitions = partitioner.propose_partitions() + self.assertEqual(len(partitions), 1) + + @onlyCUDA + @dtypes(torch.float32) + def test_full(self, device, dtype): + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsNvfuserCapabilityMode + from torch._prims.executor import execute + + def func1(size, value, b): + return (torch.full(size, value, dtype=dtype, device=device),) + + def func2(size, value, b): + a = torch.full(size, value, dtype=dtype, device=device) + b_sin = b.sin() + return (torch.add(a, b_sin),) + + def func3(size, value, b): + return (torch.full(size, value, dtype=dtype, device=device), b) + + def func4(size, value, b): + b_sin = b.sin() + return (torch.full(size, value, dtype=dtype, device=device), b_sin) + + def func5(size, value, b): + b_sin = b.sin() + a = torch.full(size, value, dtype=dtype, device=device) + a_sin = a.sin() + return (a, b_sin, a_sin) + + for func in (func1, func3, func2, func3, func4, func5): + size = (3, 3) + value = 10 + b = torch.randn(*size, dtype=dtype, device=device) + + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(func)(size, value, b) + + out = execute(gm, size, value, b, executor="strictly_nvfuser") + self.assertEqual(out, func(size, value, b)) + @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_empty_fusion(self, device): from torch.fx.experimental.proxy_tensor import make_fx from torch._prims.executor import execute @@ -268,7 +321,6 @@ def func(x, dtype): self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag) @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_rand_like_fusion(self, device): from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch.fx.experimental.proxy_tensor import make_fx @@ -287,7 +339,6 @@ def func(a): @skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529 @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_no_args(self, device): from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch.fx.experimental.proxy_tensor import make_fx @@ -318,7 +369,6 @@ def func(): self.assertEqual(out, func()) @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_constant_tensors(self, device): from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch.fx.experimental.proxy_tensor import make_fx @@ -341,11 +391,10 @@ def func(b): self.assertEqual(out, gm(b)) @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_executor_cached_noncontiguous(self, device): # This test is to ensure that nvfuser computes correct results for noncontiguous tensors from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 3, device=device) @@ -353,16 +402,18 @@ def test_nvfuser_executor_cached_noncontiguous(self, device): def func(a): return torch.sigmoid(a) - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a) # First run to create the cache - execute(gm, a, executor="nvfuser") + execute(gm, a, executor="strictly_nvfuser") # a.mT is noncontiguous, but it shouldn't affect correctness expected = execute(gm, a.mT, executor="aten") - actual = execute(gm, a.mT, executor="nvfuser") - self.assertEqual(expected, actual) + for use_python_cache in [True, False]: + params = {"use_python_fusion_cache": use_python_cache} + actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params) + self.assertEqual(expected, actual) def test_nvfuser_capability_context(self, device): # This test is to ensure that the torch calls are replaced with refs @@ -442,7 +493,6 @@ def func(a): @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_executor_parameters(self, device): from torch.fx.experimental.proxy_tensor import make_fx from torch._prims.executor import execute @@ -475,7 +525,6 @@ def func(a): @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_executor_partitioned(self, device): # This test is to ensure that nvfuser partitioned executor works correctly # It's assumed that digamma is not supported by nvfuser @@ -483,7 +532,7 @@ def test_nvfuser_executor_partitioned(self, device): self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 4, device=device) @@ -496,7 +545,7 @@ def func(a, b, c): dd = torch.sqrt(d) return torch.mul(aa, dd.digamma()) - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a, b, c) expected = execute(gm, a, b, c, executor="aten") @@ -504,7 +553,6 @@ def func(a, b, c): self.assertEqual(expected, actual) @onlyCUDA - @skipCUDAIfRocm def test_nvfuser_executor_partitioned_no_partitions_error(self, device): # This test is to ensure that nvfuser partitioned executor works correctly # It's assumed that digamma is not supported by nvfuser @@ -512,7 +560,7 @@ def test_nvfuser_executor_partitioned_no_partitions_error(self, device): self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None) from torch.fx.experimental.proxy_tensor import make_fx - from torch._prims.context import TorchRefsMode + from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch._prims.executor import execute a = torch.randn(3, 4, device=device) @@ -520,7 +568,7 @@ def test_nvfuser_executor_partitioned_no_partitions_error(self, device): def func(a): return torch.digamma(a) # not supported by nvfuser - with TorchRefsMode(): + with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)(a) with catch_warnings(record=True) as w: @@ -549,7 +597,6 @@ def func(a): self.assertFalse(node.target == torch.ops.aten.add.default) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32, torch.float64) def test_native_batch_norm_nvprims(self, device, dtype): from torch._prims.context import TorchRefsNvfuserCapabilityMode @@ -612,7 +659,6 @@ def func( self.assertEqual(out, gm(sample.input, *sample.args)) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32, torch.float64) def test_cudnn_batch_norm_nvprims(self, device, dtype): from torch._prims.context import TorchRefsNvfuserCapabilityMode @@ -666,7 +712,13 @@ def func( # Check that the graph can be executed with nvFuser out = execute(gm, sample.input, *sample.args, executor="nvfuser") - self.assertEqual(out, gm(sample.input, *sample.args)) + ref_out = gm(sample.input, *sample.args) + for idx, (left, right) in enumerate(zip(out, ref_out)): + # Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar + if idx == 3: + self.assertTrue(torch.all(torch.eq(left, 0))) + else: + self.assertEqual(left, right) # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build @onlyCUDA @@ -711,7 +763,6 @@ def func2(grad, input, weight, rm, rv, eps, train): self.assertTrue(all_nvprims) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32) def test_silu_backward_no_filled_tensor(self, device, dtype): # This test verifies a workaround for @@ -760,7 +811,6 @@ def func(a): @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32) @parametrize("correction", [0, 1]) def test_var(self, device, dtype, correction): @@ -781,7 +831,6 @@ def _wrapper(a): self.assertEqual(_wrapper(a), result) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float16, torch.float32) @parametrize("correction", [0, 1]) @parametrize("keepdim", [True, False]) @@ -806,7 +855,6 @@ def _wrapper(a): self.assertTrue(includes_nvprims_var_mean) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float16, torch.float32) def test_nvprims_view(self, device, dtype): from torch.fx.experimental.proxy_tensor import make_fx @@ -853,7 +901,31 @@ def func7(a): self.assertEqual(out, func(a)) @onlyCUDA - @skipCUDAIfRocm + @dtypes(torch.float16, torch.float32) + def test_nvprims_view_partitioner(self, device, dtype): + # This test verifies that views that are not fused with other ops are + # correctly overriden to call aten implementation. + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsNvfuserCapabilityMode + from torch._prims.nvfuser_executor import maybe_partition_graph + + make_arg = partial(make_tensor, device=device, dtype=dtype) + a = make_arg((4, 5)) + b = make_arg((5, 4)) + + def func(a, b): + aa = a.view(b.shape) + aa = aa.view(a.shape) + return aa.digamma() + + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(func)(a, b) + gm, _ = maybe_partition_graph(gm, False, False) + + out = gm(a, b) + self.assertEqual(out, func(a, b)) + + @onlyCUDA @dtypes(torch.float32, torch.float16) def test_cpu_tensor(self, device, dtype): from torch.fx.experimental.proxy_tensor import make_fx @@ -894,7 +966,6 @@ def _wrapper(t0, t1, cpu_scalar): self.assertEqual(expected, nvprim_aten_fallback) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float32) def test_pytree_input_output(self, device, dtype): @make_traced @@ -1051,7 +1122,6 @@ def test_constant_pad_nd_memory_format(self, device, dtype): class TestDecomp(TestCase): @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float16, torch.float32) def test_decomposition_type_promotion_nvprim_amp(self, device, dtype): x = torch.rand(5, device=device).to(dtype) @@ -1092,7 +1162,6 @@ def fn1(x): self.assertFalse(includes_aten_to_copy) @onlyCUDA - @skipCUDAIfRocm @dtypes(torch.float16, torch.float32) def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype): # masked_fill decomposition extracts cpu scalar tensor value when diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 914261ae1c6ab..38911c7981be7 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1,6 +1,6 @@ # Owner(s): ["module: ProxyTensor"] -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS +from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, xfail_inherited_tests import torch import unittest import warnings @@ -13,6 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException from torch._decomp import decomposition_table +from torch.fx.experimental.symbolic_shapes import sym_float from torch.testing._internal.common_device_type import ops from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy @@ -20,7 +21,6 @@ from torch import nn import re -import types import functools import itertools @@ -70,16 +70,6 @@ def create_normalized_name(op): print("}") -def copy_func(f): - """Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)""" - g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, - argdefs=f.__defaults__, - closure=f.__closure__) - g = functools.update_wrapper(g, f) - g.__kwdefaults__ = f.__kwdefaults__ - return g - - # Copied from functorch def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): return (op_name, variant_name, device_type, dtypes, True) @@ -400,6 +390,19 @@ def f(x): ) ) + def test_val_metadata_mutation(self): + def f(x): + y = x.clone() + y.unsqueeze_(0) + return y + + traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True)) + self.assertEqual([ + tuple(node.meta['val'].shape) + for node in traced.graph.nodes + if 'val' in node.meta + ], [(3,), (3,), (1, 3)]) + def test_make_fx_overloads(self): def f(x): return x.cos() + torch.randn(x.shape) @@ -701,28 +704,9 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor): tracing_mode = "fake" -def xfail_inherited_tests(tests): - """ - Given a list of test names which are defined by a superclass of the - class this decorates, mark them as expected failure. This is useful - if you are doing poor man's parameterized tests by subclassing a generic - test class. - """ - def deco(cls): - for t in tests: - # NB: expectedFailure operates by mutating the method in question, - # which is why you have to copy the function first - setattr(cls, t, unittest.expectedFailure(copy_func(getattr(cls, t)))) - return cls - return deco - - @skipIfNoSympy @xfail_inherited_tests([ - "test_inplace_metadata", - "test_mode_tracing_factory_function", "test_make_fx_overloads", - "test_resnet18_backward_trace", "test_trace_subclasses", ]) class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): @@ -849,23 +833,21 @@ def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None mul = sym_size * 2; sym_size = None empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None - detach = torch.ops.aten.detach.default(empty); empty = None - return detach""") + return empty""") def test_neg_shape(self): def f(a): return torch.empty(-a.shape[0] + 10) - r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(1)).code).strip() + r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip() self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None neg = -sym_size; sym_size = None add = neg + 10; neg = None empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None - detach = torch.ops.aten.detach.default(empty); empty = None - return detach""") + return empty""") def test_sqrt_size(self): def f(a): @@ -875,8 +857,7 @@ def f(a): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0) - sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None - pow_1 = sym_float ** 0.5; sym_float = None + pow_1 = sym_size ** 0.5; sym_size = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") @@ -949,7 +930,7 @@ def f(a, b): fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4)) meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default) meta_d = _get_node(fx_g, lambda x: x.target == operator.add) - self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr) + self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].node.expr) def test_metadata_fresh(self): def f(x): @@ -964,8 +945,27 @@ def f(x): # happened afterwards self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3) + def test_elementwise_meta_with_sym_numbers(self): + def f(x, offset, as_sym_float=False): + x0 = x.size()[0] + if as_sym_float: + x0 = sym_float(x0) + return torch.add(x0, offset) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.int64) + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) def test_return_symint(self): def f(x): @@ -976,6 +976,28 @@ def f(x): return x.shape self._test_dynamic(f, [(5, 3)], [[(4, 6)]]) + def test_rmethod(self): + def f(x): + return x.size(0) + x + self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]]) + + def test_mega_guard(self): + def f(a, b): + assert a.shape[0] == b.shape[0] * 2 + assert b.shape[0] == 8 + return a.cos() + fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) + self.assertExpectedInline(str(fx_g.shape_env.get_guard_expr()), """Eq(s5, 8) & Eq(s1, 2*s5)""") + + def test_sym_storage_offset(self): + def f(x, y): + return x + y + + inp = (torch.randn(8)[3:], torch.randn(5)) + fx_g = make_fx(f, tracing_mode="symbolic")(*inp) + inp = (torch.randn(8)[3:], torch.randn(5)) + self.assertEqual(fx_g(*inp), f(*inp)) + def _assert_no_guards(self, fx_g, free_symbols): assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards() @@ -1084,6 +1106,8 @@ def f(a, b, c, d, e): xfail('multinomial'), xfail('cholesky'), xfail('cholesky_inverse'), + # cannot do these as they rely on tensor data + xfail('repeat_interleave'), # ASAN failures due to divide by 0 skip('nn.functional.nll_loss'), } @@ -1094,37 +1118,17 @@ def f(a, b, c, d, e): xfail('linalg.eig'), xfail('linalg.eigvals'), skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel - xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... - xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition - xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition - xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... - xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition - xfail('argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition - xfail('argsort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition - xfail('bernoulli', ''), # aten.bernoulli.default - couldn't find symbolic meta function/decomposition xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... - xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel xfail('combinations', ''), xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba... @@ -1133,10 +1137,7 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition - xfail('diagonal_scatter', ''), # aten.diagonal_scatter.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1173,9 +1174,6 @@ def f(a, b, c, d, e): xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition - xfail('lerp', ''), # aten.lerp.Scalar - couldn't find symbolic meta function/decomposition - xfail('linalg.cholesky', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition - xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition @@ -1219,49 +1217,34 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition - xfail('max', 'reduction_with_dim'), # aten.max.dim - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... xfail('meshgrid', 'list_of_tensors'), # Tensors of type TensorImpl do not have numel xfail('meshgrid', 'variadic_tensors'), # Tensors of type TensorImpl do not have numel xfail('min', 'reduction_with_dim'), # aten.min.dim - couldn't find symbolic meta function/decomposition xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition - xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func... + xfail('max_pool2d_with_indices_backward', ''), # (symint math failure) Given input size: (s0xs1x2). Calculated ... xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom... - xfail('nn.functional.conv1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.conv2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.cosine_embedding_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema' xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition - xfail('nn.functional.dropout2d', ''), # Tensors of type TensorImpl do not have numel - xfail('nn.functional.dropout3d', ''), # Tensors of type TensorImpl do not have numel - xfail('nn.functional.dropout', ''), # Tensors of type TensorImpl do not have numel xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun... - xfail('nn.functional.embedding', ''), # argument 'size' must be tuple of ints, but found element of type tor... xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t... xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t... xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos... - xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int' - xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco... xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... - xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function... xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... - xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema' xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes. xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d... xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... @@ -1274,11 +1257,9 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de... xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... + xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition @@ -1291,28 +1272,16 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition - xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition + xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('roll', ''), # Tensors of type TensorImpl do not have numel - xfail('round', ''), # aten.round.default - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_0'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_neg_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('scatter', ''), # aten.scatter.src - couldn't find symbolic meta function/decomposition - xfail('scatter_reduce', 'amax'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition - xfail('scatter_reduce', 'amin'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition - xfail('scatter_reduce', 'mean'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition - xfail('scatter_reduce', 'prod'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition - xfail('scatter_reduce', 'sum'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ... xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition - xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition @@ -1328,8 +1297,6 @@ def f(a, b, c, d, e): xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... - xfail('special.xlog1py', ''), # aten.special_xlog1py.default - couldn't find symbolic meta function/decomposition - xfail('split', ''), # 'torch._C.SymIntNode' and 'int' xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1338,14 +1305,13 @@ def f(a, b, c, d, e): xfail('take_along_dim', ''), # dtype of indices should be Long but got Float xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition - xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition + xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition + xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition } symbolic_tensor_segfaults = { skip('nn.functional.batch_norm') # Segfault?? @@ -1353,101 +1319,29 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) +outplace_symbolic_tensor_failures = { + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition +} + inplace_symbolic_tensor_failures = { - xfail('abs', ''), # aten.abs_.default - couldn't find symbolic meta function/decomposition - xfail('acos', ''), # aten.acos_.default - couldn't find symbolic meta function/decomposition - xfail('acosh', ''), # aten.acosh_.default - couldn't find symbolic meta function/decomposition - xfail('addbmm', ''), # aten.addbmm_.default - couldn't find symbolic meta function/decomposition - xfail('addcdiv', ''), # aten.addcdiv_.default - couldn't find symbolic meta function/decomposition - xfail('addcmul', ''), # aten.addcmul_.default - couldn't find symbolic meta function/decomposition - xfail('addmm', ''), # aten.addmm_.default - couldn't find symbolic meta function/decomposition - xfail('addmm', 'decomposed'), # aten.addmm_.default - couldn't find symbolic meta function/decomposition - xfail('as_strided', ''), # aten.as_strided_.default - couldn't find symbolic meta function/decomposition - xfail('asin', ''), # aten.asin_.default - couldn't find symbolic meta function/decomposition - xfail('asinh', ''), # aten.asinh_.default - couldn't find symbolic meta function/decomposition - xfail('atan2', ''), # aten.atan2_.default - couldn't find symbolic meta function/decomposition - xfail('atan', ''), # aten.atan_.default - couldn't find symbolic meta function/decomposition - xfail('atanh', ''), # aten.atanh_.default - couldn't find symbolic meta function/decomposition - xfail('ceil', ''), # aten.ceil_.default - couldn't find symbolic meta function/decomposition - xfail('clamp', ''), # aten.clamp_.Tensor - couldn't find symbolic meta function/decomposition - xfail('clamp_max', ''), # aten.clamp_max_.Tensor - couldn't find symbolic meta function/decomposition - xfail('clamp_min', ''), # aten.clamp_min_.Tensor - couldn't find symbolic meta function/decomposition - xfail('conj_physical', ''), # aten.conj_physical_.default - couldn't find symbolic meta function/decomposition - xfail('copysign', ''), # aten.copysign_.Tensor - couldn't find symbolic meta function/decomposition - xfail('cos', ''), # aten.cos_.default - couldn't find symbolic meta function/decomposition - xfail('cosh', ''), # aten.cosh_.default - couldn't find symbolic meta function/decomposition - xfail('cumsum', ''), # aten.cumsum_.default - couldn't find symbolic meta function/decomposition - xfail('digamma', ''), # aten.digamma_.default - couldn't find symbolic meta function/decomposition - xfail('div', 'floor_rounding'), # aten.div_.Tensor_mode - couldn't find symbolic meta function/decomposition - xfail('div', 'no_rounding_mode'), # aten.div_.Tensor - couldn't find symbolic meta function/decomposition - xfail('div', 'trunc_rounding'), # aten.div_.Tensor_mode - couldn't find symbolic meta function/decomposition - xfail('eq', ''), # aten.eq_.Tensor - couldn't find symbolic meta function/decomposition - xfail('erf', ''), # aten.erf_.default - couldn't find symbolic meta function/decomposition - xfail('erfc', ''), # aten.erfc_.default - couldn't find symbolic meta function/decomposition - xfail('erfinv', ''), # aten.erfinv_.default - couldn't find symbolic meta function/decomposition - xfail('exp2', ''), # aten.exp2_.default - couldn't find symbolic meta function/decomposition - xfail('exp', ''), # aten.exp_.default - couldn't find symbolic meta function/decomposition - xfail('expm1', ''), # aten.expm1_.default - couldn't find symbolic meta function/decomposition - xfail('float_power', ''), # the base given to float_power_ has dtype Float but the operation's result requires dtype Double - xfail('floor', ''), # aten.floor_.default - couldn't find symbolic meta function/decomposition - xfail('floor_divide', ''), # aten.floor_divide_.Tensor - couldn't find symbolic meta function/decomposition - xfail('fmod', ''), # aten.fmod_.Tensor - couldn't find symbolic meta function/decomposition - xfail('frac', ''), # aten.frac_.default - couldn't find symbolic meta function/decomposition - xfail('ge', ''), # aten.ge_.Tensor - couldn't find symbolic meta function/decomposition - xfail('gt', ''), # aten.gt_.Tensor - couldn't find symbolic meta function/decomposition - xfail('heaviside', ''), # aten.heaviside_.default - couldn't find symbolic meta function/decomposition - xfail('hypot', ''), # aten.hypot_.default - couldn't find symbolic meta function/decomposition - xfail('igamma', ''), # aten.igamma_.default - couldn't find symbolic meta function/decomposition - xfail('igammac', ''), # aten.igammac_.default - couldn't find symbolic meta function/decomposition - xfail('le', ''), # aten.le_.Tensor - couldn't find symbolic meta function/decomposition - xfail('lgamma', ''), # aten.lgamma_.default - couldn't find symbolic meta function/decomposition - xfail('log10', ''), # aten.log10_.default - couldn't find symbolic meta function/decomposition - xfail('log1p', ''), # aten.log1p_.default - couldn't find symbolic meta function/decomposition - xfail('log2', ''), # aten.log2_.default - couldn't find symbolic meta function/decomposition - xfail('log', ''), # aten.log_.default - couldn't find symbolic meta function/decomposition - xfail('logical_and', ''), # aten.logical_and_.default - couldn't find symbolic meta function/decomposition - xfail('logical_or', ''), # aten.logical_or_.default - couldn't find symbolic meta function/decomposition - xfail('logical_xor', ''), # aten.logical_xor_.default - couldn't find symbolic meta function/decomposition - xfail('logit', ''), # aten.logit_.default - couldn't find symbolic meta function/decomposition - xfail('lt', ''), # aten.lt_.Tensor - couldn't find symbolic meta function/decomposition - xfail('mvlgamma', 'mvlgamma_p_1'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition - xfail('mvlgamma', 'mvlgamma_p_3'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition - xfail('mvlgamma', 'mvlgamma_p_5'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition - xfail('nan_to_num', ''), # aten.nan_to_num_.default - couldn't find symbolic meta function/decomposition - xfail('ne', ''), # aten.ne_.Tensor - couldn't find symbolic meta function/decomposition - xfail('neg', ''), # aten.neg_.default - couldn't find symbolic meta function/decomposition - xfail('nextafter', ''), # aten.nextafter_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.celu', ''), # aten.celu_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.elu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.hardsigmoid', ''), # aten.hardsigmoid_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.mish', ''), # aten.mish_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.selu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.threshold', ''), # aten.threshold_.default - couldn't find symbolic meta function/decomposition - xfail('pow', ''), # aten.pow_.Tensor - couldn't find symbolic meta function/decomposition - xfail('reciprocal', ''), # aten.reciprocal_.default - couldn't find symbolic meta function/decomposition - xfail('remainder', ''), # aten.remainder_.Tensor - couldn't find symbolic meta function/decomposition - xfail('rsqrt', ''), # aten.rsqrt_.default - couldn't find symbolic meta function/decomposition - xfail('scatter_add', ''), # aten.scatter_add_.default - couldn't find symbolic meta function/decomposition - xfail('sgn', ''), # aten.sgn_.default - couldn't find symbolic meta function/decomposition - xfail('sigmoid', ''), # aten.sigmoid_.default - couldn't find symbolic meta function/decomposition - xfail('sign', ''), # aten.sign_.default - couldn't find symbolic meta function/decomposition - xfail('sin', ''), # aten.sin_.default - couldn't find symbolic meta function/decomposition - xfail('sinc', ''), # aten.sinc_.default - couldn't find symbolic meta function/decomposition - xfail('sinh', ''), # aten.sinh_.default - couldn't find symbolic meta function/decomposition - xfail('sqrt', ''), # aten.sqrt_.default - couldn't find symbolic meta function/decomposition - xfail('square', ''), # aten.pow_.Scalar - couldn't find symbolic meta function/decomposition - xfail('squeeze', ''), # aten.squeeze_.default - couldn't find symbolic meta function/decomposition - xfail('t', ''), # aten.t_.default - couldn't find symbolic meta function/decomposition - xfail('tan', ''), # aten.tan_.default - couldn't find symbolic meta function/decomposition - xfail('tanh', ''), # aten.tanh_.default - couldn't find symbolic meta function/decomposition - xfail('transpose', ''), # aten.transpose_.default - couldn't find symbolic meta function/decomposition - xfail('tril', ''), # aten.tril_.default - couldn't find symbolic meta function/decomposition - xfail('triu', ''), # aten.triu_.default - couldn't find symbolic meta function/decomposition - xfail('true_divide', ''), # aten.div_.Tensor - couldn't find symbolic meta function/decomposition - xfail('trunc', ''), # aten.trunc_.default - couldn't find symbolic meta function/decomposition - xfail('uniform', ''), # aten.uniform_.default - couldn't find symbolic meta function/decomposition - xfail('unsqueeze', ''), # aten.unsqueeze_.default - couldn't find symbolic meta function/decomposition - xfail('xlogy', ''), # aten.xlogy_.Tensor - couldn't find symbolic meta function/decomposition + # bugs + xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double + # decomp not implemented + xfail('addmm', ''), + xfail('addmm', 'decomposed'), + xfail('nn.functional.hardsigmoid', ''), + xfail('round', ''), # ref missing a kwarg + xfail('round', 'decimals_0'), # ref missing a kwarg + xfail('round', 'decimals_3'), # ref missing a kwarg + xfail('round', 'decimals_neg_3'), # ref missing a kwarg + xfail('unique', ''), + # in-place has a different signature than out-of-place + xfail('uniform', ''), + # Views + xfail('t', ''), + xfail('transpose', ''), } # Copies inputs to inplace operations to avoid inplace modifications @@ -1460,10 +1354,13 @@ def _fn(t, *args, **kwargs): return _fn def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False): - def f(args, kwargs, extra_args): + def f(args, kwargs, extra_args, extra_kwargs): if extra_args: for i, t in extra_args: args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op return fn(*args, **kwargs) @@ -1484,23 +1381,26 @@ def f(args, kwargs, extra_args): # - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in # symbolic mode, a no-op otherwise) extra_args = [] + extra_kwargs = {} for i, arg in enumerate(args): if isinstance(arg, torch.Size): - extra_args.append((i, torch.empty((), device="cpu").expand(arg))) - # TODO: support kwargs + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") try: - new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args) + new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs) except DynamicOutputShapeException as e: self.skipTest("Dynamic output shape operation in trace") for arg in args: if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: arg.uniform_(0, 1) try: - old_out = f(args, kwargs, extra_args) + old_out = f(args, kwargs, extra_args, extra_kwargs) except Exception: continue - new_out = wrapper_set_seed(new_f, args, kwargs, extra_args) + new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs) self.assertEqual(new_out, old_out) class TestProxyTensorOpInfo(TestCase): @@ -1517,7 +1417,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 5215281b7ac62..4d2df65126983 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -207,8 +207,8 @@ def test_no_new_bindings(self): "StreamObjType", "StringType", "SUM", - "SymFloatNode", - "SymIntNode", + "SymFloat", + "SymInt", "TensorType", "ThroughputBenchmark", "TracingState", @@ -261,7 +261,7 @@ def test_no_new_bindings(self): "set_num_threads", "unify_type_list", "vitals_enabled", - + "VULKAN_AUTOMATIC_GPU_TRANSFER", "wait", "Tag", } diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index dea96d19b74c4..33465217bbbc0 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -390,6 +390,24 @@ def test_produce_real_type(self) -> None: $4 = torch._ops.aten.select.int($3, 1, 1) $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') + def test_optional_tensor_list(self) -> None: + def weird(xs): + print("woof") + return torch.empty(()) + + my_lib = Library("my_lib", "DEF") + my_lib.define("weird(Tensor?[] self) -> Tensor") + my_lib.impl("weird", weird, "CPU") + with capture_logs() as logs: + x = LoggingTensor(torch.ones(2, 2)) + log_input("x", x) + torch.ops.my_lib.weird.default([None, x]) + + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], + [1., 1.]]))])''') + def test_list_ret(self) -> None: # test all sequence types are permissible returns for list_type in (list, tuple): @@ -1050,7 +1068,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(args, kwargs) x = torch.tensor(5.) - with self.assertRaisesRegex(RuntimeError, "should be a normal method not a class method"): + with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"): with A(): x + x diff --git a/test/test_quantization.py b/test/test_quantization.py index 2dc6f7ac7850d..2726e0f82eec5 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -52,7 +52,6 @@ from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQStatic # noqa: F401 from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerPTQDynamic # noqa: F401 from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerOps # noqa: F401 -from quantization.eager.test_quantize_eager_ptq import TestQuantizeEagerONNXExport # noqa: F401 # 2. Eager mode quantization aware training from quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQAT # noqa: F401 from quantization.eager.test_quantize_eager_qat import TestQuantizeEagerQATNumerics # noqa: F401 diff --git a/test/test_reductions.py b/test/test_reductions.py index a4be31cd6f929..7a360888e6592 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -464,9 +464,9 @@ def test_dim_reduction_less_than_64(self, device): torch.norm] for op in ops: with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, 64) + op(x, dim=64) with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, -1) + op(x, dim=-1) @onlyCPU @dtypes(torch.float, torch.bfloat16) @@ -1793,11 +1793,9 @@ def test_repeated_dim(self, device): x = torch.randn(3, 3, 3, 3, device=device) error_msg = r'appears multiple times in the list of dims' - norm_error_msg = r'Expected dims to be different, got' for op in ops: for dim in [(0, 0), (0, -4)]: - e_msg = norm_error_msg if op == torch.norm else error_msg - with self.assertRaisesRegex(RuntimeError, e_msg): + with self.assertRaisesRegex(RuntimeError, error_msg): op(x, dim=dim) # TODO: update this test to comapre against NumPy @@ -2843,6 +2841,9 @@ def test_against_np(tensor, bins=100, min=0, max=0): expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) test_against_np(expanded) + linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device) + test_against_np(linear, bins=20, min=0, max=0.99) + @onlyCPU def test_histc_bfloat16(self, device): actual = torch.histc( diff --git a/test/test_serialization.py b/test/test_serialization.py index d8cfd08aea084..b97c35c46762a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -297,6 +297,9 @@ def _test_serialization(conversion): self.assertEqual(x, y["tensor"]) _test_serialization(lambda x: x.to_sparse()) _test_serialization(lambda x: x.to_sparse_csr()) + _test_serialization(lambda x: x.to_sparse_csc()) + _test_serialization(lambda x: x.to_sparse_bsr(1, 1)) + _test_serialization(lambda x: x.to_sparse_bsc(1, 1)) def test_serialization_sparse(self): self._test_serialization(False) @@ -333,36 +336,60 @@ def __reduce_ex__(self, proto): "size is inconsistent with indices"): y = torch.load(f) - def test_serialization_sparse_csr_invalid(self): + def _test_serialization_sparse_compressed_invalid(self, + conversion, + get_compressed_indices, + get_plain_indices): x = torch.zeros(3, 3) x[1][1] = 1 - x = x.to_sparse_csr() + x = conversion(x) class TensorSerializationSpoofer(object): def __init__(self, tensor): self.tensor = tensor def __reduce_ex__(self, proto): - invalid_crow_indices = self.tensor.crow_indices().clone() - invalid_crow_indices[0] = 3 + invalid_compressed_indices = get_compressed_indices(self.tensor).clone() + invalid_compressed_indices[0] = 3 return ( torch._utils._rebuild_sparse_tensor, ( self.tensor.layout, ( - invalid_crow_indices, - self.tensor.col_indices(), + invalid_compressed_indices, + get_plain_indices(self.tensor), self.tensor.values(), self.tensor.size()))) + if x.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices_name = 'crow_indices' + else: + compressed_indices_name = 'ccol_indices' + with tempfile.NamedTemporaryFile() as f: torch.save({"spoofed": TensorSerializationSpoofer(x)}, f) f.seek(0) with self.assertRaisesRegex( RuntimeError, - "rebuilding sparse tensor for layout torch.sparse_csr"): + f"`{compressed_indices_name}[[]..., 0[]] == 0` is not satisfied."): y = torch.load(f) + def test_serialization_sparse_csr_invalid(self): + self._test_serialization_sparse_compressed_invalid( + torch.Tensor.to_sparse_csr, torch.Tensor.crow_indices, torch.Tensor.col_indices) + + def test_serialization_sparse_csc_invalid(self): + self._test_serialization_sparse_compressed_invalid( + torch.Tensor.to_sparse_csc, torch.Tensor.ccol_indices, torch.Tensor.row_indices) + + def test_serialization_sparse_bsr_invalid(self): + self._test_serialization_sparse_compressed_invalid( + lambda x: x.to_sparse_bsr(1, 1), torch.Tensor.crow_indices, torch.Tensor.col_indices) + + def test_serialization_sparse_bsc_invalid(self): + self._test_serialization_sparse_compressed_invalid( + lambda x: x.to_sparse_bsc(1, 1), torch.Tensor.ccol_indices, torch.Tensor.row_indices) + def test_serialize_device(self): device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] device_obj = [torch.device(d) for d in device_str] @@ -567,6 +594,34 @@ def test_serialization_filelike_uses_readinto(self): b = torch.load(data) self.assertTrue(data.was_called('readinto')) + def test_serialization_filelike_exceptions(self): + # Try to serialize to buffers that does not have write method + # Or have a malfrormed one, and make sure it does not cause an abort + # See https://github.com/pytorch/pytorch/issues/87997 + x = torch.rand(10) + with self.assertRaises(AttributeError): + # Tries to serialize str into tensor + torch.save('foo', x) + x.write = "bar" + x.flush = "baz" + with self.assertRaises(TypeError): + # Tries to serialize str into tensor with write property + torch.save('foo', x) + x.write = str.__add__ + x.flush = str.__mul__ + with self.assertRaises(TypeError): + # Tries to serialize str into tensor with wrong callable write property + torch.save('foo', x) + s_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + s = torch.CharStorage(s_data) + with self.assertRaises(AttributeError): + # Tries to serialize list into CharStorage + torch.save(s_data, s) + x = torch.randint(10, (3, 3), dtype=torch.float).cpu().numpy() + with self.assertRaises(AttributeError): + # Tries to serialize ndarray into ndarray + torch.save(x, x) + def test_serialization_storage_slice(self): # Generated using: @@ -877,6 +932,28 @@ def test_meta_serialization(self, weights_only): self.assertEqual(state['weight'].size(), big_model.weight.size()) + def test_serialization_python_attr(self): + def _test_save_load_attr(t): + t.foo = 'foo' + t.pi = 3.14 + + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + loaded_t = torch.load(f) + + self.assertEqual(t, loaded_t) + self.assertEqual(t.foo, loaded_t.foo) + self.assertEqual(t.pi, loaded_t.pi) + + t = torch.zeros(3, 3) + _test_save_load_attr(t) + # This should start failing once Parameter + # supports saving Python Attribute. + err_msg = "'Parameter' object has no attribute" + with self.assertRaisesRegex(AttributeError, err_msg): + _test_save_load_attr(torch.nn.Parameter(t)) + def test_weights_only_assert(self): class HelloWorld: def __reduce__(self): @@ -892,6 +969,48 @@ def __reduce__(self): with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported class"): torch.load(f, weights_only=True) + @parametrize('weights_only', (False, True)) + def test_serialization_math_bits(self, weights_only): + t = torch.randn(1, dtype=torch.cfloat) + + def _save_load_check(t): + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + # Unsafe load should work + self.assertEqual(torch.load(f, weights_only=weights_only), t) + + t_conj = torch.conj(t) + _save_load_check(t_conj) + + t_neg = torch._neg_view(t) + _save_load_check(t_neg) + + t_n_c = torch._neg_view(torch.conj(t)) + _save_load_check(t_n_c) + + @parametrize('weights_only', (False, True)) + def test_serialization_efficient_zerotensor(self, weights_only): + # We don't support serializing `ZeroTensor` as it is not public + # facing yet. + # If in future, `ZeroTensor` serialization is supported, this test + # should start failing! + t = torch._efficientzerotensor((4, 5)) + + def _save_load_check(t): + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + # Unsafe load should work + self.assertEqual(torch.load(f, weights_only=weights_only), t) + + # NOTE: `torch.save` fails before we hit the TORCH_CHECK in `getTensoMetadata` + # as nullptr storage is disabled. + err_msg = (r'python bindings to nullptr storage \(e.g., from torch.Tensor._make_wrapper_subclass\)' + ' are currently unsafe and thus disabled') + with self.assertRaisesRegex(RuntimeError, err_msg): + _save_load_check(t) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) diff --git a/test/test_sparse.py b/test/test_sparse.py index 8ae982c034ae4..93a2241d06804 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -9,7 +9,8 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \ - DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo + DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \ + parametrize, subtest from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version from numbers import Number from typing import Dict, Any @@ -40,6 +41,17 @@ IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2" ) or (not IS_WINDOWS and CUDA11OrLater) +def all_sparse_layouts(test_name='layout', include_strided=False): + return parametrize(test_name, [ + subtest(torch.strided, name='Strided'), + subtest(torch.sparse_coo, name='SparseCOO'), + subtest(torch.sparse_csr, name='SparseCSR'), + subtest(torch.sparse_csc, name='SparseCSC'), + subtest(torch.sparse_bsr, name='SparseBSR'), + subtest(torch.sparse_bsc, name='SparseBSC'), + ][(0 if include_strided else 1):]) + + class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode): def __init__(self): super(CrossRefSparseFakeMode, self).__init__( @@ -413,9 +425,6 @@ def test_to_sparse(self, device, dtype, coalesced): self.assertEqual(expected.size(), result.size()) self.assertEqual(dim, result.sparse_dim()) - sp, _, _ = self._gen_sparse(2, 10, [3, 3, 3], dtype=value_type, device=device, coalesced=coalesced) - self.assertRaises(RuntimeError, lambda: sp.to_sparse()) - @dtypes(torch.double, torch.cdouble) def test_sparse_bool(self, device, dtype): a = torch.tensor([True, False], dtype=dtype, device=device).to(torch.bool) @@ -2184,16 +2193,7 @@ def is_integral(dtype): with self.assertRaisesRegex(RuntimeError, "log1p_ requires coalesced input"): sparse_tensor.log1p_() - if not is_integral_dtype: - sparse_tensor.requires_grad_() - self.assertTrue(sparse_tensor.requires_grad) - - # test autograd - x = sparse_tensor.clone() - y = sparse_tensor.log1p() - with self.assertRaisesRegex(RuntimeError, "log1p of a sparse tensor is made to be non-differentiable"): - y.backward(x) - else: + if is_integral_dtype: with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"): sparse_tensor.requires_grad_() @@ -3019,7 +3019,6 @@ def test_change_tensor_metadata(self, device, dtype): self.assertEqual(list(t.coalesce().indices().size()), [2, 1]) self.assertEqual(list(t.coalesce().values().size()), [1, 3]) - @skipIfRocm @coalescedonoff @dtypes(torch.double) def test_pickle(self, device, dtype, coalesced): @@ -3718,7 +3717,7 @@ def check_empty(sparse_shape, nnz, dense_shape, coalesce): check(self, s, d) check_empty(shape, nnz, sub_shape, coalesced) - @unittest.skipIf(not TEST_NUMPY, "NumPy is not availible") + @unittest.skipIf(not TEST_NUMPY, "NumPy is not available") @onlyCPU @dtypes(*all_types_and_complex_and(torch.bool)) def test_sparse_spdiags(self, device, dtype): @@ -4059,6 +4058,190 @@ def test_basic(self): self.assertEqual(r.values(), torch.empty(0, 4, device='meta')) +class TestSparseAny(TestCase): + + def test_generate_simple_inputs(self): + layouts = [torch.strided, torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc] + + tested_combinations = set() + for tensors in zip(*map(self.generate_simple_inputs, layouts)): + for i, t in enumerate(tensors): + self.assertEqual(t.layout, layouts[i]) + + # all layouts must produce semantically the same tensors + self.assertEqual(t, tensors[0]) + + if t.layout is torch.strided: + is_hybrid = None + else: + is_hybrid = t.dense_dim() > 0 + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + is_batch = t.crow_indices().ndim > 1 + elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: + is_batch = t.ccol_indices().ndim > 1 + else: + is_batch = None + if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: + blocksize = t.values().shape[1:3] + nontrivial_blocksize = 1 not in blocksize + else: + nontrivial_blocksize = None + tested_combinations.add((t.layout, is_hybrid, is_batch, nontrivial_blocksize)) + + # Ensure that the inputs generation covers all layout, + # non-hybrid/hybrid, and non-batch/batch combinations: + for layout in layouts: + for is_hybrid in [False, True]: + if layout is torch.strided: + is_hybrid = None + for is_batch in [False, True]: + if layout in {torch.sparse_coo, torch.strided}: + is_batch = None + for nontrivial_blocksize in [False, True]: + if layout not in {torch.sparse_bsr, torch.sparse_bsc}: + nontrivial_blocksize = None + key = (layout, is_hybrid, is_batch, nontrivial_blocksize) + assert key in tested_combinations, key + + @all_sparse_layouts('from_layout', include_strided=True) + @all_sparse_layouts('to_layout', include_strided=False) + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_to_sparse(self, from_layout, to_layout, device, dtype): + """ + This test tests conversion from any layout to any sparse layout. + """ + + for t in self.generate_simple_inputs( + from_layout, device=device, dtype=dtype, + enable_hybrid=( + # TODO: to support conversion strided->hybrid + # CSR/CSC/BSR/BSC, to_sparse() requires extra keyword + # argument, either nof_batch_dims or + # nof_dense_dims + not (from_layout is torch.strided and to_layout in + {torch.sparse_bsr, torch.sparse_bsc, torch.sparse_csr, torch.sparse_csc}))): + + if to_layout in {torch.sparse_bsr, torch.sparse_bsc}: + if from_layout == torch.sparse_bsr: + batch_ndim = t.crow_indices().dim() - 1 + blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] + elif from_layout == torch.sparse_bsc: + batch_ndim = t.ccol_indices().dim() - 1 + blocksize = t.values().shape[batch_ndim + 1:batch_ndim + 3] + else: + blocksize = (1, 1) + else: + blocksize = None + + if from_layout is torch.strided: + is_batch = None + is_hybrid = None + else: + is_batch = t.dim() > (t.sparse_dim() + t.dense_dim()) + is_hybrid = t.dense_dim() > 0 + + def explicit_to_sparse(x): + # Used to check that the explicit conversion methods + # are consistent with the `to_sparse(*, layout, + # blocksize)` method. + if to_layout is torch.sparse_coo: + return x.to_sparse_coo() + elif to_layout is torch.sparse_csr: + return x.to_sparse_csr() + elif to_layout is torch.sparse_csc: + return x.to_sparse_csc() + elif to_layout is torch.sparse_bsr: + return x.to_sparse_bsr(blocksize) + elif to_layout is torch.sparse_bsc: + return x.to_sparse_bsc(blocksize) + else: + assert 0 # unreachable + + # TODO: The following exception cases all correspond to + # not implemented conversions + if from_layout is torch.sparse_coo and to_layout in { + torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() == 2 and is_hybrid: + with self.assertRaisesRegex(RuntimeError, "conversion from Csr to Bsr is only possible for 2d inputs"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex(RuntimeError, "conversion from Csr to Bsr is only possible for 2d inputs"): + explicit_to_sparse(t) + continue + elif from_layout is torch.sparse_csr and to_layout in {torch.sparse_bsr} and (is_batch or is_hybrid): + with self.assertRaisesRegex(RuntimeError, "conversion from Csr to Bsr is only possible for 2d inputs"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex(RuntimeError, "conversion from Csr to Bsr is only possible for 2d inputs"): + explicit_to_sparse(t) + continue + elif from_layout is torch.sparse_coo and to_layout in { + torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and t.sparse_dim() != 2: + with self.assertRaisesRegex( + RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex( + RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"): + explicit_to_sparse(t) + continue + elif from_layout in {torch.sparse_csr, torch.sparse_csc} and to_layout is torch.sparse_coo and is_batch: + with self.assertRaisesRegex(RuntimeError, + "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex(RuntimeError, + "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"): + explicit_to_sparse(t) + continue + elif from_layout in {torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo: + with self.assertRaisesRegex( + RuntimeError, + "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex( + RuntimeError, + "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"): + explicit_to_sparse(t) + self.skipTest('NOT IMPL') + elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc), + (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc), + (torch.sparse_csc, torch.sparse_bsr), (torch.sparse_csc, torch.sparse_bsc), + (torch.sparse_csr, torch.sparse_bsc)}: + with self.assertRaisesRegex( + RuntimeError, + r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc) expected\s*(SparseCsr[,]|)\s*Sparse(Csr|Bsr)" + " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"): + t.to_sparse(layout=to_layout, blocksize=blocksize) + with self.assertRaisesRegex( + RuntimeError, + r"sparse_compressed_to_sparse_(csr|csc|bsr|bsc) expected\s*(SparseCsr[,]|)\s*Sparse(Csr|Bsr)" + " or Sparse(Csc|Bsc) layout but got Sparse(Csr|Csc|Bsr|Bsc)"): + explicit_to_sparse(t) + self.skipTest('NOT IMPL') + else: + r = t.to_sparse(layout=to_layout, blocksize=blocksize) + + self.assertEqual(r.layout, to_layout) + + # to_sparse method uses unsafe construction of sparse + # tensors. Here we explicitly validate the results to + # make sure that the sparse tensors are consistent + # with the corresponding sparse tensor invariants. + if r.layout in {torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}: + if r.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = r.crow_indices(), r.col_indices() + else: + compressed_indices, plain_indices = r.ccol_indices(), r.row_indices() + torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(), + r.shape, r.layout) + elif r.layout is torch.sparse_coo: + torch._validate_sparse_coo_tensor_args(r._indices(), r._values(), r.shape) + else: + assert 0 # unreachable + + # Finally, we'll test tensor equality: + self.assertEqual(r, t) + + # Also, check consistency with explicit conversion methods: + r2 = explicit_to_sparse(t) + self.assertEqual(r2, r) + # e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta') @@ -4068,5 +4251,7 @@ def test_basic(self): # e.g., TestSparseCPU and TestSparseCUDA instantiate_device_type_tests(TestSparse, globals(), except_for='meta') +instantiate_device_type_tests(TestSparseAny, globals(), except_for='meta') + if __name__ == '__main__': run_tests() diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 7e364ad94e071..e8eb8564b860d 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1,6 +1,5 @@ # Owner(s): ["module: sparse"] -import copy import torch import random import itertools @@ -61,7 +60,13 @@ def _check_cusparse_sddmm_available(): UNARY_EWISE_CSR_ALLOW_AUTOGRAD = [ 'abs', 'conj_physical', + 'deg2rad', 'neg', + 'positive', + 'frac', + 'nn.functional.relu', + 'log1p', + 'rad2deg' ] # This should be just an import from test_linalg instead of code duplication @@ -192,147 +197,6 @@ def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_ device = self.device_type return self.genSparseCompressedTensor(size, nnz, device=device, dtype=dtype, index_dtype=index_dtype, layout=layout) - def _generate_small_inputs_utils(self, layout, device=None, dtype=None): - - def shape(shape, basedim=0, blocksize=(1, 1), dense_shape=()): - # Below, we define compressed and plain indices that - # correspond to row compressed tensors. In order to reuse - # the indices tensors for column compressed tensors, we - # swap the row and columns in shape dims (basedim and - # basedim + 1, respectively) to obtain the correct shape - # for column compressed tensors. Batch and dense - # dimensions remain as they are. - # - # Similarly, we reuse indices of non-block tensors for - # block tensors, that means, we'll need to multiply the - # base shape of the non-block tensor with blocksize to get - # the base shape of a block tensor. - if layout is torch.sparse_csc: - shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:] - elif layout is torch.sparse_bsc: - shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:] - elif layout is torch.sparse_bsr: - shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:] - return shape - - def values(lst, basedim=0, blocksize=(1, 1), densesize=(), device=device, dtype=dtype): - # Below, we define values for non-blocked and non-hybrid - # tensors. To reuse these for blocked tensors, we replace - # all values in lst with a double-list that "shape" - # corresponds to blocksize. - # To support hybrid tensors, the values in lst are further - # replaced with a N-list where N==len(densesize) and the - # shape corresponds to densesize. - - max_val = torch.iinfo(dtype).max if dtype in [torch.int16, torch.int8, torch.uint8] else None - - def list_add(lst, value): - # recursively add a value to lst items - if isinstance(lst, list): - return [list_add(item, value) for item in lst] - rc = lst + value - return rc if max_val is None else (rc % max_val) - - def stretch_values(value, bdim, values_item_shape): - # replace a value with a new value that extends the - # dimensionality of the value by - # len(values_item_shape) from right. The left - # dimensions up to bdim are considered as batch - # dimensions. - if not values_item_shape: - return value - if isinstance(value, list) and bdim >= 0: - return [stretch_values(item, bdim - 1, values_item_shape) for item in value] - new_value = functools.reduce(lambda x, dims: [copy.deepcopy(x) for _ in range(dims)], - reversed(values_item_shape), None) - for p in itertools.product(*map(list, map(range, values_item_shape))): - row = functools.reduce(lambda x, i: x.__getitem__(i), p[:-1], new_value) - row[p[-1]] = list_add(value, sum([i * 10 ** d for d, i in enumerate(p)])) - return new_value - - if layout is torch.sparse_bsr: - values_item_shape = blocksize + densesize - elif layout is torch.sparse_bsc: - values_item_shape = tuple(reversed(blocksize)) + densesize - else: - values_item_shape = densesize - - if not lst: - return torch.tensor(lst, device=device, dtype=dtype).reshape(0, *values_item_shape) - - lst = stretch_values(lst, basedim, values_item_shape) - - return torch.tensor(lst, device=device, dtype=dtype) - - return shape, values - - def _generate_small_inputs(self, layout, device=None, dtype=None, index_dtype=None, - enable_batched=True, enable_hybrid=True): - """Generator of inputs to sparse compressed tensor factory functions. - - The input is defined as a 4-tuple: - compressed_indices, plain_indices, values, expected_size_from_shape_inference - """ - if index_dtype is None: - index_dtype = torch.int64 - - shape, values = self._generate_small_inputs_utils(layout, device, dtype) - - # a regular tensor - yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), - torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), - values([1, 2, 3, 4], 0, (2, 1)), - shape((2, 3), 0, (2, 1))) - - # a tensor with zero dimensions - yield (torch.tensor([0, ], device=device, dtype=index_dtype), - torch.tensor([], device=device, dtype=index_dtype), - values([], 0, (2, 1)), - shape((0, 0), 0, (2, 1))) - - if enable_batched: - # a batched tensor with one batch dimension - yield (torch.tensor([[0, 2, 4], [0, 3, 4]], device=device, dtype=index_dtype), - torch.tensor([[0, 1, 0, 1], [0, 1, 2, 0]], device=device, dtype=index_dtype), - values([[1, 2, 3, 4], [5, 6, 7, 8]], 1, (1, 2)), - shape((2, 2, 3), 1, (1, 2))) - - # a batched tensor with two batch dimensions - yield (torch.tensor([[[0, 2, 4], [0, 3, 4], [0, 1, 4]], - [[0, 1, 4], [0, 2, 4], [0, 3, 4]]], - device=device, dtype=index_dtype), - torch.tensor([[[0, 1, 0, 1], [0, 1, 2, 0], [0, 0, 1, 2]], - [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]], - device=device, dtype=index_dtype), - values([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 2, (2, 3)), - shape((2, 3, 2, 3), 2, (2, 3))) - - if enable_hybrid: - # a tensor with one dense dimension - yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), - torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), - values([1, 2, 3, 4], 0, (3, 2), (2,)), - shape((2, 3, 2), 0, (3, 2))) - - # a tensor with two dense dimensions - yield (torch.tensor([0, 2, 4], device=device, dtype=index_dtype), - torch.tensor([0, 1, 0, 2], device=device, dtype=index_dtype), - values([1, 2, 3, 4], 0, (2, 3), (4, 2)), - shape((2, 3, 4, 2), 0, (2, 3))) - - if enable_batched and enable_hybrid: - # a batched tensor with two batch dimensions and two dense dimensions - yield (torch.tensor([[[0, 2, 4], [0, 3, 4], [0, 1, 4]], - [[0, 1, 4], [0, 2, 4], [0, 3, 4]]], - device=device, dtype=index_dtype), - torch.tensor([[[0, 1, 0, 1], [0, 1, 2, 0], [0, 0, 1, 2]], - [[1, 0, 1, 2], [0, 2, 0, 1], [0, 1, 2, 1]]], - device=device, dtype=index_dtype), - values([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], - [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 2, (3, 2), (2, 1)), - shape((2, 3, 2, 3, 2, 1), 2, (3, 2))) - @all_sparse_compressed_layouts() @onlyCPU def test_layout(self, layout): @@ -346,11 +210,14 @@ def test_layout(self, layout): @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_sparse_compressed_constructor(self, layout, device, dtype, use_factory_function, shape_and_device_inference, input_kind): - if input_kind == 'list' and shape_and_device_inference and torch.device(device).type == 'cuda': - # list inputs to factory/constructor function without - # specifying device will result a sparse compressed tensor - # on CPU. So, skip testing against cuda device as unused. - self.skipTest("nothing to test") + if input_kind == 'list' and shape_and_device_inference: + if torch.device(device).type == 'cuda': + # list inputs to factory/constructor function without + # specifying device will result a sparse compressed tensor + # on CPU. So, skip testing against cuda device as unused. + self.skipTest("nothing to test") + if dtype not in {torch.float32, torch.complex64, torch.int64, torch.bool}: + self.skipTest("dtype not supported with list values") expected_devices = [torch.device(device)] if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2 and not shape_and_device_inference: @@ -363,29 +230,34 @@ def test_sparse_compressed_constructor(self, layout, device, dtype, torch.sparse_bsc: torch.sparse_bsc_tensor, }[layout] compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout] - for index_dtype in [torch.int32, torch.int64]: + if input_kind == 'list': + index_dtypes = [torch.int64] + else: + index_dtypes = [torch.int32, torch.int64] + for index_dtype in index_dtypes: for expected_device in expected_devices: - for compressed_indices, plain_indices, values, size in self._generate_small_inputs( - layout, expected_device, dtype, index_dtype): + for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs( + layout, device=expected_device, dtype=dtype, index_dtype=index_dtype, + # skip zero-sized tensors for list inputs: + enable_zero_sized=input_kind != 'list', + output_tensor=False): + size = kwargs['size'] + if shape_and_device_inference and 0 in size: + # skip shape inference for zero-sized tensor + # inputs because (i) the shape determined from + # an empty list is ambiguous, and (ii) the + # size of the plain dimension defined as + # max(plain_indices) is undefined if + # plain_indices has no values + continue + compressed_indices_expect = compressed_indices + plain_indices_expect = plain_indices + values_expect = values + if input_kind == 'list': - if size == (0, 0): - # for this degenerate case, plain_indices must - # remain a tensor because - # tensor(plain_indices) results a float dtype - # when plain_indices is an empty list - if index_dtype == torch.int32: - # skip testing int32 case because - # tensor(compressed_indices) results a - # int64 dtype when compressed_indices is - # [0] (a list of single int zero). - continue - else: - plain_indices = plain_indices.tolist() compressed_indices = compressed_indices.tolist() + plain_indices = plain_indices.tolist() values = values.tolist() - if size == (0, 0) and layout in {torch.sparse_bsr, torch.sparse_bsc}: - # in the block sparse case, values of type list needs to represent a 3-D tensor - values = [[[]]] if use_factory_function: if shape_and_device_inference: @@ -401,9 +273,9 @@ def test_sparse_compressed_constructor(self, layout, device, dtype, dtype=dtype, layout=layout, device=expected_device) self.assertEqual(layout, sparse.layout) self.assertEqual(size, sparse.shape) - self.assertEqual(compressed_indices, compressed_indices_mth(sparse)) - self.assertEqual(plain_indices, plain_indices_mth(sparse)) - self.assertEqual(values, sparse.values()) + self.assertEqual(compressed_indices_expect, compressed_indices_mth(sparse)) + self.assertEqual(plain_indices_expect, plain_indices_mth(sparse)) + self.assertEqual(values_expect, sparse.values()) self.assertEqual(sparse.device, sparse.values().device) self.assertEqual(sparse.device, expected_device) @@ -449,10 +321,8 @@ def test_empty_errors(self, layout, device, dtype): @all_sparse_compressed_layouts() @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) def test_clone(self, layout, device, dtype): - for compressed_indices, plain_indices, values, size in self._generate_small_inputs( - layout, device, dtype, index_dtype=torch.int32): - sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, - dtype=dtype, layout=layout, device=device) + for sparse in self.generate_simple_inputs( + layout, device=device, dtype=dtype, index_dtype=torch.int32): cloned_sparse = sparse.clone() self.assertEqual(sparse, cloned_sparse) @@ -461,10 +331,37 @@ def test_print(self, layout, device): compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout] printed = [] for enable_hybrid in [False, True]: + # using local patterns for test_print stability + patterns = [ + # 2 x 3 batch of 3 x 2 tensors, trivial blocksize, non-hybrid/hybrid: + ([[[[1, 2, 0], + [1, 0, 3]], + [[1, 2, 3], + [1, 0, 0]], + [[1, 0, 0], + [1, 2, 3]]], + [[[0, 2, 0], + [1, 2, 3]], + [[1, 0, 3], + [1, 2, 0]], + [[1, 2, 3], + [0, 2, 0]]]], [(2, 1)], [(), (4,)] if enable_hybrid else [()]), + # tensor with non-trivial blocksize, non-hybrid/hybrid: + ([[0, 1, 0, 2, 0, 2], + [0, 1, 0, 0, 2, 0], + [3, 3, 3, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 5, 0, 6, 6, 6], + [5, 0, 5, 6, 6, 6], + [0, 0, 0, 0, 8, 8], + [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 2)] if enable_hybrid else [()]), + ] for index_dtype in [torch.int32, torch.int64]: for dtype in [torch.float32, torch.float64]: - for compressed_indices, plain_indices, values, size in self._generate_small_inputs( - layout, device, dtype, index_dtype, enable_hybrid=enable_hybrid): + for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs( + layout, device=device, dtype=dtype, index_dtype=index_dtype, enable_hybrid=enable_hybrid, + enable_zero_sized=False, output_tensor=False, patterns=patterns): + size = tuple(kwargs['size']) block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0 base_ndim = 2 batch_ndim = compressed_indices.dim() - 1 @@ -587,9 +484,6 @@ def test_consistency(self, layout, device, dtype, op): if require_mask and layout in {torch.sparse_bsr, torch.sparse_bsc}: self.skipTest(f"{op.name} does not support input with {layout} layout") - if layout is torch.sparse_bsc: - self.skipTest(f"test requires conversion from Strided layout to {layout} layout") - samples = list(op.sample_inputs(device, dtype)) # Fail early to prevent silent success with this test @@ -644,9 +538,7 @@ def test_consistency(self, layout, device, dtype, op): @all_sparse_compressed_layouts('layout2') @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) def test_empty_like(self, layout, layout2, device, dtype): - for compressed_indices, plain_indices, values, size in self._generate_small_inputs(layout): - sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, - dtype=dtype, layout=layout, device=device) + for sparse in self.generate_simple_inputs(layout): if layout == layout2: result = torch.empty_like(sparse, layout=layout2) compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout] @@ -668,14 +560,28 @@ def test_empty_like(self, layout, layout2, device, dtype): @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_validate(self, layout, device, dtype): for index_dtype in [torch.int32, torch.int64]: - for compressed_indices, plain_indices, values, size in self._generate_small_inputs( - layout, device, dtype, index_dtype, enable_batched=True, enable_hybrid=True): + for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs( + layout, device=device, dtype=dtype, index_dtype=index_dtype, output_tensor=False): + size = kwargs['size'] torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout) def _generate_invalid_input(self, layout, device): from functools import partial - shape, values = self._generate_small_inputs_utils(layout, device=device) + def shape(shape, basedim=0): + blocksize = (1, 1) + if layout is torch.sparse_csc: + shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:] + elif layout is torch.sparse_bsc: + shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:] + elif layout is torch.sparse_bsr: + shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:] + return shape + + def values(lst, device=device): + if layout in {torch.sparse_bsr, torch.sparse_bsc}: + lst = [[[item]] for item in lst] + return torch.tensor(lst, device=device) tensor = partial(torch.tensor, device=device) values = partial(values, device=device) @@ -708,7 +614,7 @@ def _generate_invalid_input(self, layout, device): shape((2, 3)), 'compressed_indices must have dimensionality >= 1 but got 0') - yield ('compressed/plain_indices mismatch of dimensionalites', + yield ('compressed/plain_indices mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), @@ -716,14 +622,14 @@ def _generate_invalid_input(self, layout, device): 'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively') if layout in {torch.sparse_csr, torch.sparse_csc}: - yield ('indices and values mismatch of dimensionalites', + yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), shape((2, 3)), r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1') else: - yield ('indices and values mismatch of dimensionalites', + yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), @@ -735,7 +641,7 @@ def _generate_invalid_input(self, layout, device): tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), (2,), - r'tensor dimensionality must be sum of batch, base, and dense dimensionalites \(=0 \+ 2 \+ 0\) but got 1') + r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1') yield ('invalid batchsize', tensor([[0, 2, 4]]), @@ -922,7 +828,8 @@ def test_invalid_input(self, layout, device, target): @onlyCPU @all_sparse_compressed_layouts() def test_dim(self, layout): - for compressed_indices, plain_indices, values, size in self._generate_small_inputs(layout): + for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(layout, output_tensor=False): + size = kwargs['size'] batch_dim = compressed_indices.dim() - 1 sparse_dim = 2 block_dim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0 @@ -932,6 +839,75 @@ def test_dim(self, layout): self.assertEqual(sparse.dense_dim(), dense_dim) + @skipMeta + @all_sparse_compressed_layouts() + @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) + def test_to_dtype(self, layout, device, dtype): + # to_dense does not support hybrid inputs + for sparse in self.generate_simple_inputs(layout, dtype=dtype, device=device, enable_hybrid=False): + for to_dtype in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16): + sparse_to_dtype = sparse.to(to_dtype) + dense_to_dtype = sparse.to_dense().to(to_dtype) + self.assertEqual(sparse_to_dtype.to_dense(), dense_to_dtype) + + @skipMeta + @all_sparse_compressed_layouts() + @dtypes(torch.double) + def test_pickle(self, layout, dtype, device): + import pickle + + for sparse in self.generate_simple_inputs(layout, device=device, dtype=dtype): + serialized = pickle.dumps(sparse) + sparse_loaded = pickle.loads(serialized) + + self.assertEqual(sparse, sparse_loaded) + + @all_sparse_compressed_layouts() + @parametrize("index_dtype", [torch.int32, torch.int64]) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + def test_select_copy(self, device, dtype, index_dtype, layout): + + def is_view_of(base, other): + # a shameless copy of TestViewOps.is_view_of + if ((not other._is_view() or + other is base or + other._base is not base or + base.device != other.device)): + return False + if base.device.type == 'cpu' or base.device.type == 'cuda': + if base._storage().data_ptr() != other._storage().data_ptr(): + return False + return True + + kwargs = dict(device=device, dtype=dtype, index_dtype=index_dtype) + for sparse, dense in zip(self.generate_simple_inputs(layout, **kwargs), + self.generate_simple_inputs(torch.strided, **kwargs)): + if layout in {torch.sparse_csr, torch.sparse_bsr}: + n_batchdim = sparse.crow_indices().ndim - 1 + elif layout in {torch.sparse_csc, torch.sparse_bsc}: + n_batchdim = sparse.ccol_indices().ndim - 1 + else: + assert 0 # unreachable + self.assertEqual(sparse, dense) + for dim in range(sparse.ndim): + if sparse.shape[dim] == 0: + with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"): + torch.select_copy(sparse, dim, 0) + with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"): + torch.select_copy(dense, dim, 0) + elif n_batchdim and dim >= n_batchdim and dim < n_batchdim + 2: + with self.assertRaisesRegex( + RuntimeError, + "selecting sparse dimensions is not implemented for batched sparse compressed tensors"): + torch.select_copy(sparse, dim, 0) + else: + for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}: + dense_select = torch.select_copy(dense, dim, index) + sparse_select = torch.select_copy(sparse, dim, index) + self.assertEqual(sparse_select, dense_select) + self.assertFalse(is_view_of(sparse_select.values(), sparse.values())) + + def _npref_block_addmm_addmv(c, a, b, alpha, beta): return alpha * (a @ b) + beta * c @@ -1006,52 +982,26 @@ def test_select(self, device, dtype, index_dtype, layout): device=device) self.assertEqual(expected_sparse_selected12, sparse_selected12) - # Select from dense dimensions - sparse_hybrid = self.genSparseCompressedTensor(shape + (4, 2), - nnz, - device=device, - layout=layout, - dtype=dtype, - index_dtype=index_dtype, - blocksize=blocksize, - dense_dims=2) - sparse_hybrid_dense_selected = sparse_hybrid.select(4, 1) - expected_sparse_hybrid_dense_selected = sparse_hybrid.values().select(-2, 1) - self.assertEqual(expected_sparse_hybrid_dense_selected, sparse_hybrid_dense_selected) - - - # selecting rows/col with batch dims not allowed sparse_non_batched = sparse[0, 0] - # select from sparse dimensions if layout supports is - if layout in {torch.sparse_csr, torch.sparse_csc}: + # select from sparse dimensions + for select_args in [(0, 0), (1, 1)]: + sparse_selected = sparse_non_batched.select(*select_args) + dense_selected = sparse_non_batched.to_dense().select(*select_args) + self.assertEqual(dense_selected, sparse_selected) - for select_args in [(0, 0), (1, 1)]: - sparse_selected = sparse_non_batched.select(*select_args) - dense_selected = sparse_non_batched.to_dense().select(*select_args) - self.assertEqual(dense_selected, sparse_selected) + self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0]) + # assigning to sparse through indexing is disabled + with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"): + sparse[0, 0, 0, 0] = 99.0 - self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0]) - # assigning to sparse through indexing is disabled, not tested generally because only layouts supporting - # sparse dim select will get far enough to test - with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"): - sparse[0, 0, 0, 0] = 99.0 + # select from sparse dimensions without removing batch dims + msg = "selecting sparse dimensions is not implemented for batched sparse compressed tensors." + with self.assertRaisesRegex(RuntimeError, msg): + sparse.select(-2, 0) - # select from sparse dimensions without removing batch dims, not tested generally because only layouts - # supporting sparse dim select will get far enough - msg = "selecting rows or columns is not implemented for batched sparse compressed tensors." - with self.assertRaisesRegex(RuntimeError, msg): - sparse.select(-2, 0) - - with self.assertRaisesRegex(RuntimeError, msg): - sparse.select(-1, 0) - # ensure raises if layout does not support - else: - msg = ( - "selecting non-batch dimensions is currently only supported for non-blocked sparse " - "compressed layouts tensors.") - with self.assertRaisesRegex(RuntimeError, msg): - sparse_non_batched.select(0, 0) + with self.assertRaisesRegex(RuntimeError, msg): + sparse.select(-1, 0) @skipMeta @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @@ -2477,6 +2427,9 @@ def test_autograd_sparse_csr_unary(self, device, dtype, op): raise ValueError("Expected at least one 2D tensor in samples.") for sample in samples: + # We must skip samples of low dimensionality, we can't covert them to sparsed compressed layouts + if sample.input.ndim < 2: + continue sparse_input = sample.input.to_sparse_csr().requires_grad_(True) def fn(input): @@ -2808,33 +2761,6 @@ def test_exercise_detach(self, device, dtype): detached_inp = inp.detach() self.assertEqual(inp, detached_inp) - def _convert_to_layout(self, a, target_layout, blocksize=(2, 2)): - """ - Helper function to call the correct layout conversion - with reasonable defaults for the block size. Clearly there - is a need for a to.layout overload. - """ - if target_layout is torch.sparse_csr: - result = a.to_sparse_csr() - elif target_layout is torch.sparse_csc: - result = a.to_sparse_csc() - elif target_layout is torch.sparse_bsr: - result = a.to_sparse_bsr(blocksize) - elif target_layout is torch.sparse_bsc: - result = a.to_sparse_bsc(blocksize) - else: - raise NotImplementedError(repr(a)) - assert result.layout is target_layout - # to_sparse_xyz methods use unsafe construction of sparse - # compressed tensors. Here we explicitly validate the results - # to make sure that the sparse tensors are consistent with the - # corresponding sparse tensor invariants. - compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout] - compressed_indices, plain_indices = compressed_indices_mth(result), plain_indices_mth(result) - torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, result.values(), - result.shape, result.layout) - return result - def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)): if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]: tensor = tensor.to_dense() @@ -2853,9 +2779,13 @@ def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)): @all_sparse_compressed_layouts('to_layout') @all_sparse_compressed_layouts('from_layout') def test_compressed_layout_conversions_coverage(self, device, from_layout, to_layout): - """ - This test performs a smoke test for covered conversion and verifies + """This test performs a smoke test for covered conversion and verifies that an exception is thrown for unsupported conversions. + + TODO: This test covers a subset of + TestSparseAny.test_to_sparse tests and can be + eliminated. Keeping the test until the new + `Tensor.to_sparse(*, layout, blocksize)` has landed. """ allowed_pairwise_layouts_sets = { @@ -2882,19 +2812,21 @@ def _to_from_layout(layout_a, layout_b, a): if a.dim() > 2: expect_error = True - b = self._convert_to_layout(a, layout_a) + blocksize_a = (1, 1) if layout_a in {torch.sparse_bsr, torch.sparse_bsc} else None + blocksize_b = (1, 1) if layout_b in {torch.sparse_bsr, torch.sparse_bsc} else None + b = a.to_sparse(layout=layout_a, blocksize=blocksize_a) if expect_error: with self.assertRaises(RuntimeError): - self._convert_to_layout(b, layout_b) + b.to_sparse(layout=layout_b, blocksize=blocksize_b) else: - c = self._convert_to_layout(b, layout_b) + c = b.to_sparse(layout=layout_b, blocksize=blocksize_b) self.assertEqual(a.to_dense(), c.to_dense()) # change of blocksize upon conversion is not yet supported. if b.layout in block_layouts: for block_layout in block_layouts: - with self.assertRaisesRegex(RuntimeError, "blocksize does not match the blocksize"): - self._convert_to_layout(b, block_layout, blocksize=3) + with self.assertRaisesRegex(RuntimeError, "conversion from.*to.*is not implemented"): + b.to_sparse(layout=block_layout, blocksize=(3, 3)) batch_dims = [(), (2,), (2, 2), (2, 2, 2)] sparse_dims = (6, 12) @@ -2908,11 +2840,13 @@ def _to_from_layout(layout_a, layout_b, a): @hybrid_nonhybrid() @unittest.skipIf(not TEST_SCIPY, "SciPy not found") def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout): - """ - This test tests conversion from dense to/from CSR and CSC + """This test tests conversion from dense to/from CSR and CSC by comparing to SciPy's implementation. - TODO: Eventually this is meant to be merged into test_compressed_layout_conversions_coverage + Here we test only those conversion combinations that SciPy + supports to ensure that PyTorch conversions are in the same + page with SciPy. Independent from SciPy, all conversion + combinations are tested in TestSparseAny.test_to_sparse. """ # adjust this block as support is added @@ -2959,7 +2893,7 @@ def _check_batched(pt_tensor, dense, check_batch=None, batch_shape=(), blocksize for batch_index in np.ndindex(batch_shape): pt_matrix = pt_tensor[batch_index] dense_matrix = dense[batch_index] - dense_matrix_pt = self._convert_to_layout(dense_matrix, layout, blocksize) + dense_matrix_pt = dense_matrix.to_sparse(layout=layout, blocksize=blocksize or None) # sanity check, selecting batch of to_ and dense[batch].to_ should give the same result self.assertEqual(pt_matrix, dense_matrix_pt) check_batch(pt_matrix, dense_matrix, blocksize, **kwargs) @@ -3003,12 +2937,12 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): batch_sizes = [(3,), (1, 3), (2, 1, 3)] if batched else [()] hybrid_sizes = [(4, ), (2, 2)] if hybrid else [()] if not hybrid: - # general cases, always run, hybrid excluded untill dense->sparse api exists + # general cases, always run, hybrid excluded until dense->sparse api exists for sparse_shape, blocksize, batch_shape, hybrid_shape in itertools.product( sparse_sizes, blocksizes, batch_sizes, hybrid_sizes): dense = _generate_subject(sparse_shape, batch_shape, hybrid_shape) if expect_to_layout_support: - sparse = self._convert_to_layout(dense, layout, blocksize) + sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None) check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape) if expect_from_layout_support: dense_back = sparse.to_dense() @@ -3018,7 +2952,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): sparse.to_dense() else: with self.assertRaises(RuntimeError): - self._convert_to_layout(dense, layout, blocksize) + dense.to_sparse(layout=layout, blocksize=blocksize or None) # special cases for batched tensors if batched and expect_to_layout_support: @@ -3052,7 +2986,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): mask = mask.transpose(-3, -2) mask = mask.reshape_as(dense) dense = dense * mask - sparse = self._convert_to_layout(dense, layout, blocksize) + sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None) check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape) if expect_from_layout_support: @@ -3070,14 +3004,14 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape): dense = dense * mask msg = "Expect the same number of specified elements per batch." with self.assertRaisesRegex(RuntimeError, msg): - self._convert_to_layout(dense, layout, blocksize) + dense.to_sparse(layout=layout, blocksize=blocksize or None) # Should throw if there is a zero in the batch size dense = make_tensor((0,) + shape, dtype=torch.float, device=device) layout_code = str(layout).split("_")[-1] msg = f"to_sparse_{layout_code}: Expected product of batch dimensions to be non-zero." with self.assertRaisesRegex(RuntimeError, msg): - self._convert_to_layout(dense, layout, blocksize=blocksize) + dense.to_sparse(layout=layout, blocksize=blocksize or None) if hybrid: # conversion from sparse -> dense should be blocked with dense dims @@ -3110,21 +3044,24 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout): This test tests conversion from COO to CSR and CSC and CSC to CSR and CSC by comparing to SciPy's implementation. - TODO: Eventually this is meant to be merged into test_compressed_layout_conversions_coverage + Here we test only those conversion combinations that SciPy + supports to ensure that PyTorch conversions are in the same + page with SciPy. Independent from SciPy, all conversion + combinations are tested in TestSparseAny.test_to_sparse. """ if layout is torch.sparse_bsc: # TODO: Remove this once support has been enabled - return + self.skipTest('NOT IMPL') if layout is torch.sparse_bsr: # TODO: Remove this once support has been enabled - return + self.skipTest('NOT IMPL') for shape in [(0, 10), (6, 0), (6, 10), (0, 0)]: sparse_dim = 2 nnz = shape[0] * shape[1] // 2 sparse, _, _ = self.genSparseTensor(shape, sparse_dim, nnz, coalesced, device, dtype) sp_matrix = self._construct_sp_matrix(sparse, layout) - pt_matrix = self._convert_to_layout(sparse, layout) + pt_matrix = sparse.to_sparse(layout=layout) compressed_indices_mth = { torch.sparse_csr: torch.Tensor.crow_indices, @@ -3144,7 +3081,7 @@ def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout): sparse_csc = sparse.to_sparse_csc() sp_matrix = self._construct_sp_matrix(sparse_csc, layout) - pt_matrix = self._convert_to_layout(sparse_csc, layout) + pt_matrix = sparse_csc.to_sparse(layout=layout) self.assertEqual(layout, pt_matrix.layout) self.assertEqual(sp_matrix.shape, pt_matrix.shape) diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index e836b0f1ba8dc..3899bbead7f1f 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -7,6 +7,7 @@ import sys import unittest import uuid +import expecttest TEST_TENSORBOARD = True try: @@ -520,6 +521,11 @@ def read_expected_content(function_ptr): return f.read() def compare_image_proto(actual_proto, function_ptr): + if expecttest.ACCEPT: + expected_file = get_expected_file(function_ptr) + with open(expected_file, 'w') as f: + f.write(text_format.MessageToString(actual_proto)) + return True expected_str = read_expected_content(function_ptr) expected_proto = Summary() text_format.Parse(expected_str, expected_proto) @@ -537,6 +543,9 @@ def compare_image_proto(actual_proto, function_ptr): ) def compare_proto(str_to_compare, function_ptr): + if expecttest.ACCEPT: + write_proto(str_to_compare, function_ptr) + return True expected = read_expected_content(function_ptr) str_to_compare = str(str_to_compare) return remove_whitespace(str_to_compare) == remove_whitespace(expected) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 20deb0a43c429..cf894f3749eb9 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -1325,6 +1325,17 @@ def do_exp(x, y, z): x = warmup_and_run_forward(traced, x, y, z) self.assertLastGraphAllFused() + def test_sin_pow(self): + def test(x): + return torch.sin(torch.pow(x, 0)) + + for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]): + x = torch.rand(shape, dtype=data_type) + scripted = torch.jit.script(test) + out = warmup_and_run_forward(scripted, x) + self.assertLastGraphAllFused() + self.assertEqual(out, test(x)) + def test_transpose(self): @torch.jit.script def test(x, y, z): diff --git a/test/test_testing.py b/test/test_testing.py index e31872f7da6fd..821a30ab432b2 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,7 +12,7 @@ import subprocess import sys import unittest.mock -from typing import Any, Callable, Iterator, List, Tuple +from typing import Any, Callable, Iterator, List, Tuple, Generator import torch @@ -23,7 +23,7 @@ from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, - deviceCountAtLeast, ops, expectedFailureMeta) + deviceCountAtLeast, ops, expectedFailureMeta, OpDTypes) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal import opinfo from torch.testing._internal.common_dtype import all_types_and_complex_and @@ -68,7 +68,7 @@ def test_assertEqual_longMessage(self): self.longMessage = True extra_msg = "sentinel" - with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg} : {extra_msg}")): + with self.assertRaisesRegex(AssertionError, re.escape(f"{default_msg}\n{extra_msg}")): self.assertEqual(actual, expected, msg=extra_msg) finally: self.longMessage = long_message @@ -1178,6 +1178,17 @@ def test_mismatching_values_msg(self): with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")): fn() + @unittest.expectedFailure + def test_hybrid_support(self): + # If you read this after the test unexpectedly succeeded, this is a good thing. It means that you added support + # for `.to_dense()` for hybrid sparse CSR tensors and in turn enabled support for them in + # `torch.testing.assert_close` if comparing to strided tensors. You can safely remove this test as well as the + # patch on `TensorOrArrayPair` in `torch.testing._internal.common_utils`. + actual = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [[1, 11], [2, 12], [3, 13], [4, 14]]) + expected = torch.stack([actual[0].to_dense(), actual[1].to_dense()]) + + torch.testing.assert_close(actual, expected, check_layout=False) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing") class TestAssertCloseSparseCSC(TestCase): @@ -1777,18 +1788,24 @@ def test_circular_dependencies(self) -> None: "torch.contrib.", # something weird "torch.testing._internal.distributed.", # just fails "torch.ao.pruning._experimental.", # depends on pytorch_lightning, not user-facing - "torch.cuda._dynamo_graphs", # depends on torchdynamo ] # See https://github.com/pytorch/pytorch/issues/77801 if not sys.version_info >= (3, 9): ignored_modules.append("torch.utils.benchmark") if IS_WINDOWS or IS_MACOS: - # Distributed does not work on Windows or by default on Mac - ignored_modules.append("torch.distributed.") + # Distributed should be importable on Windows(except nn.api.), but not on Mac + if IS_MACOS: + ignored_modules.append("torch.distributed.") + else: + ignored_modules.append("torch.distributed.nn.api.") + ignored_modules.append("torch.distributed.optim.") + ignored_modules.append("torch.distributed.pipeline.") + ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop") ignored_modules.append("torch.testing._internal.common_fsdp") + ignored_modules.append("torch.testing._internal.common_distributed") torch_dir = os.path.dirname(torch.__file__) for base, folders, files in os.walk(torch_dir): @@ -1818,6 +1835,27 @@ def test_no_warning_on_import(self) -> None: cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") self.assertEquals(out, "") + @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") + @parametrize('path', ['torch', 'functorch']) + def test_no_mutate_global_logging_on_import(self, path) -> None: + # Calling logging.basicConfig, among other things, modifies the global + # logging state. It is not OK to modify the global logging state on + # `import torch` (or other submodules we own) because users do not expect it. + expected = 'abcdefghijklmnopqrstuvwxyz' + commands = [ + 'import logging', + f'import {path}', + '_logger = logging.getLogger("torch_test_testing")', + 'logging.root.addHandler(logging.StreamHandler())', + 'logging.root.setLevel(logging.INFO)', + f'_logger.info("{expected}")' + ] + out = subprocess.check_output( + [sys.executable, "-W", "all", "-c", "; ".join(commands)], + stderr=subprocess.STDOUT, + ).decode("utf-8") + self.assertEqual(out.strip(), expected) + class TestOpInfos(TestCase): def test_sample_input(self) -> None: a, b, c, d, e = [object() for _ in range(5)] @@ -1881,5 +1919,31 @@ def test_sample_input_metadata(self) -> None: self.assertEqual(s2.name, "foo") +# Tests that validate the various sample generating functions on each OpInfo. +class TestOpInfoSampleFunctions(TestCase): + + @ops(op_db, dtypes=OpDTypes.any_one) + def test_opinfo_sample_generators(self, device, dtype, op): + # Test op.sample_inputs doesn't generate multiple samples when called + samples = op.sample_inputs(device, dtype) + self.assertIsInstance(samples, Generator) + + @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one) + def test_opinfo_reference_generators(self, device, dtype, op): + # Test op.reference_inputs doesn't generate multiple samples when called + samples = op.reference_inputs(device, dtype) + self.assertIsInstance(samples, Generator) + + @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) + def test_opinfo_error_generators(self, device, op): + # Test op.error_inputs doesn't generate multiple inputs when called + samples = op.error_inputs(device) + self.assertIsInstance(samples, Generator) + + +instantiate_device_type_tests(TestOpInfoSampleFunctions, globals()) +instantiate_parametrized_tests(TestImports) + + if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index f84d8aff08950..26ab9b61f5d9c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -30,16 +30,17 @@ from torch import multiprocessing as mp from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, TEST_WITH_ROCM, run_tests, + TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, - IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest, + IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, skipIfTorchInductor, slowTest, TEST_WITH_CROSSREF, skipIfTorchDynamo, skipCUDAMemoryLeakCheckIf, BytesIOContext, skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, - skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps) + skipIfNotRegistered, bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like) from multiprocessing.reduction import ForkingPickler from torch.testing._internal.common_device_type import ( + dtypesIfMPS, expectedFailureMeta, expectedFailureXLA, instantiate_device_type_tests, @@ -55,7 +56,7 @@ tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN) from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, - all_types_and, floating_types, floating_and_complex_types, + all_types_and, floating_types, floating_and_complex_types, integral_types_and, ) # Protects against includes accidentally setting the default dtype @@ -114,6 +115,8 @@ def test_cuda_vitals_gpu_only(self, device): self.assertIn('CUDA.used\t\t true', torch.read_vitals()) +is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6) + class TestTorchDeviceType(TestCase): exact_dtype = True @@ -161,6 +164,8 @@ def rand_byte(): @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, torch.bool, torch.float32, torch.complex64, torch.float64, torch.complex128) + @dtypesIfMPS(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, + torch.bool, torch.float32) def test_storage(self, device, dtype): v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9) self.assertEqual(v.storage()[0], v[0][0]) @@ -222,6 +227,7 @@ def test_storage_setitem(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_tensor_storage_type(self, device, dtype): a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) @@ -314,6 +320,7 @@ def test_untyped_storage_meta(self, device): @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_storage_meta_from_tensor(self, device, dtype): t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) t = t_check.to('meta') @@ -356,7 +363,7 @@ def test_storage_meta_errors(self, device, dtype): s0.tolist() with tempfile.NamedTemporaryFile() as f: - with self.assertRaisesRegex(RuntimeError, r'Device not recognized'): + with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): s0._write_file(f, True, True, s0.element_size()) for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']: @@ -374,6 +381,7 @@ def test_module_share_memory(self): model.share_memory() @dtypes(torch.float32, torch.complex64) + @dtypesIfMPS(torch.float32) def test_deepcopy(self, device, dtype): from copy import deepcopy a = torch.randn(5, 5, dtype=dtype, device=device) @@ -401,6 +409,7 @@ def test_deepcopy(self, device, dtype): self.assertEqual(deepcopy(a).foo, 3) @dtypes(torch.float32, torch.complex64) + @dtypesIfMPS(torch.float32) def test_deepcopy_scalar(self, device, dtype): from copy import deepcopy a = torch.tensor(5, dtype=dtype, device=device) @@ -819,6 +828,7 @@ def test_warn_always_caught(self, device): torch.from_numpy(a) @onlyNativeDeviceTypes + @skipIfMps def test_complex_half_experimental_warning(self, device): msg = 'ComplexHalf support is experimental' with self.assertWarnsOnceRegex(UserWarning, msg): @@ -856,6 +866,7 @@ def test_complex_half_experimental_warning(self, device): t + 1 # TODO: this test should be in test_nn.py + @skipIfMps def test_conv_transposed_backward_agnostic_to_memory_format(self, device): in_channels = 64 out_channels = 128 @@ -1445,6 +1456,7 @@ def test_nondeterministic_alert_EmbeddingBag_max(self, device): torch.device(device).type == 'cuda') @dtypes(*all_types_and_complex_and(torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.float32)) def test_nondeterministic_alert_cumsum(self, device, dtype): input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex) @@ -1544,6 +1556,7 @@ def test_nondeterministic_alert_grid_sample_3d(self, device): 'grid_sampler_3d_backward_cuda', torch.device(device).type == 'cuda') + @skipIfMps def test_invalid_shapes_grid_sampler(self, device): make_arg = partial( make_tensor, device=device, dtype=torch.float64, requires_grad=True) @@ -1811,6 +1824,7 @@ def test_repeat_interleave(self, device): @dtypes(*floating_types()) @dtypesIfCPU(*floating_types_and(torch.bfloat16)) @dtypesIfCUDA(*floating_types_and(torch.half)) + @dtypesIfMPS(torch.half, torch.float) # crashes for half def test_bernoulli_p(self, device, dtype): for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): x = torch.tensor(trivial_p, dtype=dtype, device=device) @@ -1833,6 +1847,7 @@ def isBinary(t): @dtypes(*floating_types()) @dtypesIfCPU(*all_types_and(torch.bool)) @dtypesIfCUDA(*all_types_and(torch.bool, torch.half)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_bernoulli_self(self, device, dtype): def isBinary(t): @@ -1872,7 +1887,7 @@ def test_bernoulli_edge_cases(self, device, dtype): self.assertEqual(num_zeros, 0) @dtypes(*floating_types_and(torch.half, torch.bfloat16)) - @skipIfMps + @dtypesIfMPS(torch.half, torch.float) def test_exponential(self, device, dtype): a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5) self.assertEqual(a.dtype, dtype) @@ -1920,6 +1935,7 @@ def test_corrcoef(self, device, dtype): self.assertEqual(res, ref, exact_dtype=False) @dtypes(torch.int, torch.float, torch.cfloat) + @dtypesIfMPS(torch.int, torch.float) def test_cov(self, device, dtype): def check(t, correction=1, fweights=None, aweights=None): res = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights) @@ -1940,6 +1956,7 @@ def check(t, correction=1, fweights=None, aweights=None): @skipIfNoSciPy @dtypes(*floating_types_and(torch.half, torch.bfloat16)) + @dtypesIfMPS(torch.half, torch.float) def test_uniform_kstest(self, device, dtype): from scipy import stats size = 1000 @@ -2312,8 +2329,9 @@ def test_cumprod(self, device): x = torch.rand(100, 100, device=device) res1 = torch.cumprod(x, 1) res2 = torch.tensor([]).to(device) - torch.cumprod(x, 1, out=res2) - self.assertEqual(res1, res2) + if not TEST_WITH_TORCHINDUCTOR: + torch.cumprod(x, 1, out=res2) + self.assertEqual(res1, res2) x.cumprod_(1) self.assertEqual(res1, x) @@ -2500,6 +2518,7 @@ def to_np(t): # All tensors appear contiguous on XLA @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) # crashes for torch.int8_t def test_diff_noncontig(self, device, dtype): shapes = ( (1,), @@ -2522,6 +2541,7 @@ def test_diff_noncontig(self, device, dtype): @dtypes(*all_types_and_complex_and(torch.bool)) @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool)) @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_diff(self, device, dtype): shapes = ( (1,), @@ -2784,6 +2804,7 @@ def test_unfold_scalars(self, device): self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1)) # FIXME: move to data movement test suite + @skipIfMps def test_copy_all_dtypes_and_devices(self, device): from copy import copy for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): @@ -2869,6 +2890,7 @@ def test_copy_transpose_math_view(self, device, dtype): dst.copy_(src.conj()) self.assertEqual(dst, src.conj_physical()) + @skipIfMps def test_clone_all_dtypes_and_devices(self, device): for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): x = torch.tensor((1, 1), dtype=dt, device=device) @@ -2890,6 +2912,7 @@ def test_clone_not_memory_dense(self): # FIXME: move to elementwise ternary test suite @dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) + @dtypesIfMPS(*integral_types_and(torch.half, torch.float32)) @dtypes(*set(get_all_math_dtypes('cpu'))) def test_addcmul(self, device, dtype): # Returns floating or integral scalar corresponding to dtype @@ -2943,6 +2966,7 @@ def test_narrow_empty(self, device): # FIXME: move to indexing test suite @parametrize("reduce", ['prod', 'amin', 'amax', 'mean']) @dtypes(*all_types_and(torch.half, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.half, torch.float32)) def test_index_reduce(self, device, dtype, reduce): size = (3, 4, 5) index_dtypes = [torch.int, torch.long] @@ -2959,10 +2983,9 @@ def test_index_reduce(self, device, dtype, reduce): dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig) src_size = size[:dim] + (num_src,) + size[dim + 1:] src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig) - idx = torch.randint(num_dest, (num_src,), dtype=idx_dtype, device=device) - if index_noncontig: - # noncontiguous_like fails with RuntimeError: XLA tensors do not have storage - idx = torch.testing.make_non_contiguous(idx) + idx = torch.testing.make_tensor( + num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig + ) expected = dest.clone() dest.index_reduce_(dim, idx, src, reduce, include_self=include_self) # fill rows in idx with reduction inits if include_self=False @@ -2993,6 +3016,7 @@ def test_index_reduce(self, device, dtype, reduce): # FIXME: move to test indexing @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_index_copy(self, device, dtype): # We just test for num_copy <= num_dest, as otherwise there are repeated indices # and the behavior is undefined @@ -3203,6 +3227,7 @@ def ref_index_select(src, dim, idx): # FIXME: find a test suite for the take operator @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_take(self, device, dtype): idx_size = (4,) @@ -3238,6 +3263,7 @@ def ref_take(src, idx): # The bool instance does not work on GPU. See # https://github.com/pytorch/pytorch/issues/54317 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.half, torch.float32)) def test_put(self, device, dtype): src_size = (4,) @@ -3309,6 +3335,7 @@ def ref_put(dst, idx, src, accumulate): # The bool instance does not work on GPU. See # https://github.com/pytorch/pytorch/issues/54317 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.half, torch.float32)) def test_put_accumulate(self, device, dtype): # Test for parallel adds with accumulate == True low_precision = dtype == torch.half or dtype == torch.bfloat16 @@ -3356,6 +3383,7 @@ def scatter_allow_reduce(self, device, dtype, reduceop): @dtypes(*floating_and_complex_types()) @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_scatter_reduce_operations_to_large_input(self, device, dtype): index = torch.tensor([[1], [2]], device=device, dtype=torch.long) test_data = [ @@ -3383,6 +3411,7 @@ def test_scatter_reduce_operations_to_large_input(self, device, dtype): @dtypes(*floating_and_complex_types()) @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_scatter_reduce_scalar(self, device, dtype): index = torch.tensor([[1], [2]], device=device, dtype=torch.long) test_data = [ @@ -3422,6 +3451,7 @@ def test_scatter_add_non_unique_index(self, device): @dtypes(*floating_and_complex_types()) @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_scatter_reduce_non_unique_index(self, device, dtype): height = 2 width = 2 @@ -3496,6 +3526,7 @@ def test_scatter_add_bool(self, device): # FIXME: find a test suite for the masked scatter operator @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_masked_scatter(self, device, dtype): dt = dtype with warnings.catch_warnings(record=True) as w: @@ -3584,6 +3615,7 @@ def test_masked_scatter_large_tensor(self, device): # FIXME: find a test suite for the masked select operator @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_masked_select(self, device, dtype): if device == 'cpu': warn = 'masked_select received a mask with dtype torch.uint8,' @@ -3652,6 +3684,7 @@ def test_masked_select_discontiguous(self, device): # FIXME: find a test suite for the masked fill operator @dtypes(*product(all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16), (torch.uint8, torch.bool))) + @dtypesIfMPS(*product(integral_types_and(torch.half, torch.float, torch.bool), (torch.uint8, torch.bool))) def test_masked_fill(self, device, dtypes): dtype = dtypes[0] mask_dtype = dtypes[1] @@ -3914,7 +3947,7 @@ def test_dim_function_empty(self, device): @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") @skipIfRocm @onlyCUDA - @largeTensorTest('10GB', device='cpu') + @largeTensorTest('32GB', device='cpu') @largeTensorTest('5GB', device='cuda') def test_pdist_norm_large(self, device): # use dim0>=46342 for forward, see: @@ -3924,7 +3957,8 @@ def test_pdist_norm_large(self, device): # Will require 1249975000 float32s expected_cpu = torch.pdist(x, p=2) # ~1250M * 4 bytes = 5 GB on CPU actual_gpu = torch.pdist(x.to(device), p=2) # 5 GB on GPU - self.assertEqual(expected_cpu, actual_gpu.cpu()) # Another 5 GB on CPU + # Workaround for large memory overhead of self.assertTrue (see #84944) + self.assertTrue(torch.allclose(expected_cpu, actual_gpu.cpu())) # FIXME: move to elementwise ternary test suite @onlyNativeDeviceTypes @@ -4649,6 +4683,7 @@ def compare_strides(s1, s2, div): @onlyCUDA @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @skipIfTorchInductor("pin_memory isn't yet supported in TorchInductor") def test_pin_memory_from_constructor(self, device): def _get_like(t, **kwargs): return [ @@ -4870,6 +4905,7 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf x = x.permute(permutation) self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride()) + @skipIfMps def test_memory_format_to(self, device): def get_generator(memory_format, shape): def input_generator_fn(device): @@ -4887,6 +4923,7 @@ def transformation_fn(tensor, **kwargs): self._test_memory_format_transformations( device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True) + @skipIfMps def test_memory_format_type(self, device): def get_generator(memory_format, shape): def input_generator_fn(device): @@ -4946,6 +4983,7 @@ def input_generator_fn(device): self._test_memory_format_transformations( device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True) + @skipIfMps def test_memory_format_type_shortcuts(self, device): def get_generator(memory_format, shape, dtype): def input_generator_fn(device): @@ -5191,6 +5229,7 @@ def test_assertRaisesRegex_ignore_msg_non_native_device(self, device): torch.nn.functional.nll_loss(x, t, weight=invalid_weight) @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_copy_(self, device, dtype): def can_cast(src_dtype, dst_dtype): # torch.can_cast(torch.int16, torch.uint8) returns True @@ -5228,6 +5267,7 @@ def make_tensor_wrapper(shape, dtype): self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol) @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_item(self, device, dtype): t = torch.ones((), device=device, dtype=dtype) self.assertEqual(1, t.item()) @@ -5308,6 +5348,9 @@ def test_type_conversions_same_device(self, devices): @dtypesIfCUDA(torch.half, torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long, torch.uint8) + @dtypesIfMPS(torch.half, torch.float, + torch.int8, torch.short, torch.int, torch.long, + torch.uint8) @dtypes(torch.float, torch.double, torch.int8, torch.short, torch.int, torch.long, torch.uint8) @@ -5588,10 +5631,10 @@ def test_index_add(self): dest = make_tensor(dest.shape, device=device, dtype=dest.dtype, noncontiguous=True) src = torch.randn(num_copy, *other_sizes, device=device) if not src_contig: - src = torch.testing.make_non_contiguous(src) + src = noncontiguous_like(src) idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy) if not index_contig: - idx = torch.testing.make_non_contiguous(idx) + idx = noncontiguous_like(idx) # index_add_ without alpha argument dest2 = dest.clone() dest.index_add_(0, idx, src) @@ -5674,36 +5717,6 @@ def test_unflatten(self): r"the unspecified dimension size -1 can be any value and is ambiguous"): torch.randn(2, 0).unflatten(1, (2, -1, 0)) - def test_pytorch_library_disabled_env(self): - import subprocess - env = os.environ.copy() - env['PYTORCH_DISABLE_LIBRARY'] = '1' - try: - subprocess.check_output([sys.executable, '-c', 'import torch'], env=env) - except subprocess.CalledProcessError as e: - raise RuntimeError("Could not 'import torch' with PYTORCH_DISABLE_LIBRARY=0") from e - - # Test that warnings generated from C++ are translated to the correct type - def test_warn_types(self): - test_cases = [ - # function, warning type, message - (torch._C._warn, UserWarning, r"Test message for TORCH_WARN"), - (torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"), - ] - - for fn, warning_type, message in test_cases: - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.filterwarnings('always', category=warning_type) - fn() - - self.assertEqual(len(w), 1, msg=f'{warning_type} not raised') - warning = w[0].message - self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised') - self.assertTrue(re.search( - message, - str(warning))) - def test_structseq_repr(self): a = torch.arange(250).reshape(5, 5, 10) expected = """ @@ -6459,6 +6472,114 @@ def test_storage_casts(self): self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') self.assertIs(complexdouble_storage.dtype, torch.complex128) + # Test that internal versions of functions related to TypedStorage do not + # produce a deprecation warning + def test_typed_storage_internal_no_warning(self): + s0 = torch.FloatStorage(10) + s0_untyped = s0.untyped() + t0 = torch.randn(10) + + funcs = [ + lambda: torch.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cpu', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s0_untyped, + dtype=s0.dtype, + _internal=True), + lambda: torch.FloatStorage._dtype, + lambda: s0._resize_(20), + lambda: s0._size(), + lambda: s0._untyped_storage, + lambda: s0._is_shared(), + lambda: s0._share_memory_(), + lambda: s0._pickle_storage_type(), + lambda: s0._setitem(slice(0, s0._size()), 1), + lambda: s0._element_size(), + lambda: s0._deepcopy({}), + lambda: s0._data_ptr(), + lambda: s0._nbytes(), + lambda: t0._typed_storage(), + ] + + if torch.cuda.is_available(): + s1 = torch.cuda.FloatStorage(10) + s1_untyped = s1.untyped() + t1 = torch.randn(10, device='cuda') + + funcs += [ + lambda: torch.cuda.FloatStorage(_internal=True), + lambda: torch.TypedStorage( + dtype=torch.float, + device='cuda', + _internal=True), + lambda: torch.TypedStorage( + wrap_storage=s1_untyped, + dtype=s1.dtype, + _internal=True), + lambda: torch.cuda.FloatStorage._dtype, + lambda: s1._resize_(20), + lambda: s1._size(), + lambda: s1._untyped_storage, + lambda: s1._is_shared(), + lambda: s1._share_memory_(), + lambda: s1._pickle_storage_type(), + lambda: s1._setitem(slice(0, s1._size()), 1), + lambda: s1._element_size(), + lambda: s1._deepcopy({}), + lambda: s1._data_ptr(), + lambda: s1._nbytes(), + lambda: t1._typed_storage(), + ] + + # Check that each of the TypedStorage internal function calls do not + # produce a deprecation warning + for f in funcs: + with warnings.catch_warnings(): + warnings.filterwarnings('error', "TypedStorage is deprecated") + f() + + # Test that public functions related to TypedStorage produce a deprecation + # warning + def test_typed_storage_deprecation_warning(self): + s0 = torch.FloatStorage(10) + funcs = [ + lambda: torch.FloatStorage(), + lambda: torch.FloatStorage.dtype, + lambda: s0.fill_(0), + lambda: s0.is_cuda, + lambda: s0.untyped(), + lambda: len(s0), + lambda: s0[0], + ] + + if torch.cuda.is_available(): + s1 = torch.cuda.FloatStorage(10) + funcs += [ + lambda: torch.cuda.FloatStorage(), + lambda: torch.cuda.FloatStorage.dtype, + lambda: s1.fill_(0), + lambda: s1.is_cuda, + lambda: s1.untyped(), + lambda: len(s1), + lambda: s1[0], + ] + + # Check that each of the TypedStorage function calls produce a warning + # if warnings are reset between each + for f in funcs: + with warnings.catch_warnings(record=True) as w: + warnings.resetwarnings() + f() + self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) + warning = w[0].message + self.assertTrue(warning, DeprecationWarning) + self.assertTrue(re.search( + '^TypedStorage is deprecated', + str(warning))) + def test_from_file(self): def assert_with_filename(filename): size = 10000 @@ -6824,6 +6945,7 @@ def test_new(self) -> None: self.assertRaises(RuntimeError, lambda: x.new(z.storage())) @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") + @skipIfTorchInductor("pin_memory isn't yet supported in TorchInductor") def test_pin_memory(self): x = torch.randn(3, 5) self.assertFalse(x.is_pinned()) @@ -7568,6 +7690,51 @@ def test_copy_many_to_one(self): # storage to a single storage would cause RuntimeError to be thrown self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) + def test_copy_float16(self): + # Check that fbgemm code no longer reads memory out of bounds, see + # copy_impl and fbgemm::Float16ToFloat_ref. + # https://github.com/pytorch/pytorch/issues/88543 + + # Types to test different code paths in copy_impl. + dtypes = ( + # out_dtype, src_dtype + (torch.float32, torch.float16), # fbgemm + (torch.float16, torch.float32), # fbgemm + (torch.float32, torch.float32), # TensorIterator + ) + + cases = ( + # out_shape, src_shape, is_ok + # These cases used to crash with fbgemm, make sure these also raise + # exceptions with TensorIterator. + ((1, 2, 3), (0, 2, 3), False), # same strides, not allowed by TI + ((1, 5, 6), (4, 5, 6), False), # same strides, not allowed by TI + (1, (0, 2, 3), False), # different strides + ((4, 5, 6), (0, 2, 3), False), # different strides + ((4, 5, 6), (1, 2, 3), False), # different strides + ((4, 5, 6), (6, 5, 4), False), # same numel + + # These cases should pass with fbgemm and TensorIterator. + ((4, 5, 6), (1, 5, 6), True), # same strides + ((4, 5, 6), (4, 5, 6), True), # same strides + ((0, 2, 3), 1, True), # different strides, allowed by TI + ((4, 5, 6), (4, 5, 1), True), # different strides, allowed by TI + ) + + for (out_shape, src_shape, is_ok), (out_dtype, src_dtype) in itertools.product(cases, dtypes): + out = torch.zeros(out_shape, dtype=out_dtype, device=torch.device('cpu')) + src = torch.ones(src_shape, dtype=src_dtype, device=torch.device('cpu')) + if is_ok: + if torch.cuda.is_available(): + out_cuda = out.cuda() + src_cuda = src.cuda() + res = out.copy_(src) + if torch.cuda.is_available(): + res_cuda = out_cuda.copy_(src_cuda) + self.assertEqual(res, res_cuda) + else: + self.assertRaises(RuntimeError, lambda: out.copy_(src)) + # FIXME: Port to a more appropriate test suite def _test_to_with_layout(self, layout): def test_copy_behavior(t, non_blocking=False): @@ -8337,6 +8504,19 @@ def test_conj_neg_tolist(self): self.assertEqual(y1, y1_expect.tolist()) self.assertEqual(y2, y1_expect.imag.tolist()) + @unittest.skipIf(torch.backends.cuda.is_built(), "Skipped for cuda-enabled build") + def test_no_cuda_monkeypatch(self): + # Note that this is not in test_cuda.py as this whole file is skipped when cuda + # is not available. + with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"): + torch.cuda.Stream() + + with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"): + torch.cuda.Event() + + with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class CUDAGraph"): + torch.cuda.graphs.CUDAGraph() + # The following block extends TestTorch with negative dim wrapping tests # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests # Functions to test negative dimension wrapping diff --git a/test/test_transformers.py b/test/test_transformers.py index 1eff4d61fb203..0260c822498d3 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1,14 +1,16 @@ # Owner(s): ["module: nn"] import contextlib +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F import unittest from unittest.mock import patch import math -from torch.backends.cuda import sdp_kernel +from torch.backends.cuda import sdp_kernel, SDPBackend import torch.optim as optim +from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( @@ -19,22 +21,18 @@ freeze_rng_state, TEST_WITH_CROSSREF, TEST_WITH_ROCM, - IS_WINDOWS + IS_WINDOWS, + slowTest, + set_default_dtype, + gradcheck ) -from torch.testing._internal.common_cuda import TEST_CUDA + +from torch.testing._internal.common_methods_invocations import wrapper_set_seed +from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater if TEST_FAIRSEQ: import fairseq.models.transformer as fairseq_transformer -@contextlib.contextmanager -def set_default_dtype(dtype): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - try: - yield - finally: - torch.set_default_dtype(saved_dtype) - class TestTransformers(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True @@ -71,6 +69,7 @@ def test_self_attn_TxT_attn_mask(self): self.assertEqual(output_mask_4d, output_mask_TxT) @parametrize("device", device_list) + @slowTest def test_train_with_pad_and_catch_error(self, device): iters = 100 pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device) @@ -148,14 +147,14 @@ def test_transformerencoderlayer_src_mask(self, device, nhead): @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) - def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast): + @parametrize("d_model", [12, 256]) + def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model): """ Test TransformerEncoder fastpath output matches slowpath output """ torch.manual_seed(1234) - d_model = 12 nhead = 4 - dim_feedforward = 12 + dim_feedforward = d_model batch_first = True model = torch.nn.TransformerEncoder( @@ -213,7 +212,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste ] input_mask_pairs = [ ( - torch.tensor(pair[0], device=device, dtype=torch.float32), # float input + torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()), # float input torch.tensor(pair[1], device=device, dtype=torch.bool) # bool mask ) for pair in input_mask_pairs ] @@ -224,7 +223,6 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste with torch.no_grad(): fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask) slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask) # reference - # Make sure fastpath_output is same shape as slowpath_output and mask. # When enable_nested_tensor=true, fastpath_output may be smaller than input tensor. # Eg if input bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4. @@ -266,7 +264,7 @@ def test_transformerencoder_square_input(self, with_no_grad, training, enable_ne model = model.train() else: model = model.eval() - x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device) + x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device) src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device) if with_no_grad: @@ -867,26 +865,38 @@ def rand_tensor(*shape): actual = torch.ops.aten._scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) - # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. - # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. - if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) + if attn_mask_dim is None: + q = q.double().clone() + k = k.double().clone() + v = v.double().clone() + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + + assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') @torch.no_grad() def test_mask_check_fastpath(self): """ - Test that fastpath is executed independently of the mask that is passed. - If the passed mask is left aligned or mask_check=False, test that nested tensors are used (sparsity fastpath), - otherwise use fastpath with traditional tensors. + Test that fastpath is executed independently of the masks that are passed. + If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used + (sparsity fastpath), otherwise use fastpath with traditional tensors. + Also test that fast path is executed with both key padding mask and attention mask passed at the same time. """ x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float) - def _test_fastpath(model, mask, mock_return_value, nested_tensors=True): + def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True): with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock: fastpath_mock.return_value = mock_return_value - model(x, src_key_padding_mask=mask) + model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask) # If mock was called, fastpath was taken self.assertTrue(fastpath_mock.called) @@ -900,44 +910,52 @@ def _test_fastpath(model, mask, mock_return_value, nested_tensors=True): model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True) model.eval() - aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) - not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) + aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) + not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) + attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool) nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)]) tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float) # Left aligned mask results in sparsity fastpath - _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) # Not aligned mask results in fastpath - _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) + _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True) model.eval() # If nested tensor disabled, fastpath is always taken - _test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False) - _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) - + _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False) + _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) + # Fast path is taken if both attention mask and key padding mask are present + _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False) model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False) model.eval() # Mask check disabled results in sparisty fastpath, independently of the mask - _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True) - _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) + + def rand_nt(self, shape, device, dtype, requires_grad=False, packed=False): + batch, seq_len, num_heads, head_dim = shape + size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim) + return torch.nested.nested_tensor([ + torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad) + for _ in range(batch)]) + + def rand_tensor(self, shape, device, dtype, requires_grad=False, packed=False): + batch, seq_len, num_heads, head_dim = shape + size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim) + return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad) @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels(self, type: str, is_contiguous: bool): - def rand_nt(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.nested.nested_tensor([torch.randn(seq_len, num_heads, head_dim, - device="cuda", dtype=torch.float16) for _ in range(batch)]) - - def rand_tensor(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16) + rand_nt = partial(self.rand_nt, device="cuda", dtype=torch.float16) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16) batch, seq_len, num_heads, head_dim = 32, 64, 16, 64 shape = (batch, seq_len, num_heads, head_dim) @@ -975,14 +993,8 @@ def rand_tensor(shape): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, type: str, is_contiguous: bool): - def rand_nt(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.nested.nested_tensor([torch.randn(seq_len, 3 * num_heads * head_dim, - device="cuda", dtype=torch.float16) for _ in range(batch)]) - - def rand_tensor(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float16) + rand_nt = partial(self.rand_nt, device="cuda", dtype=torch.float16, packed=True) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 shape = (batch_size, seq_len, num_heads, head_dim) @@ -1010,6 +1022,197 @@ def rand_tensor(shape): self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=2e-3, rtol=1e-2) + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("type", ["dense", "nested"]) + @parametrize("fused_kernel", ["flash", "mem_efficient"]) + def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, type: str, fused_kernel: str): + if (not SM80OrLater) and fused_kernel == "flash": + return + + def rand_nt(shape): + batch, seq_len, num_heads, head_dim = shape + tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device="cuda", dtype=torch.float32) - 3 + for _ in range(batch)] + return (torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float32), + torch.nested.nested_tensor(tensors, device="cuda", dtype=torch.float16)) + + def rand_tensor(shape): + batch, seq_len, num_heads, head_dim = shape + tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device="cuda", dtype=torch.float32) - 3 + return tensor, tensor.to(dtype=torch.float16) + + batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64 + shape = (batch_size, seq_len, num_heads, head_dim) + + # Test Packed + qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape) + query, key, value = qkv.chunk(3, dim=-1) + query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if fused_kernel == "flash": + with sdp_kernel(enable_mem_efficient=False, enable_math=False): + # TODO Flash for the nested path is currently not working due to cuda memory issues + if type == "nested": + self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)) + return + actual = torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False) + elif fused_kernel == "mem_efficient": + with sdp_kernel(enable_flash=False, enable_math=False): + actual = torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False) + + with sdp_kernel(enable_flash=False, enable_mem_efficient=False): + math_ref_lp = torch.nn.functional._scaled_dot_product_attention( + query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(), + attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False) + + with sdp_kernel(enable_flash=False, enable_mem_efficient=False): + math_query = query.contiguous() + math_key = key.contiguous() + math_value = value.contiguous() + + math_ref = torch.nn.functional._scaled_dot_product_attention( + math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False) + + actual_test = actual[0] + math_ref_test = math_ref[0] + math_ref_lp_test = math_ref_lp[0] + + if actual_test.is_nested: + actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0) + math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0) + math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0) + + actual_test = actual_test.to(dtype=torch.float32).contiguous() + math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous() + math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() + + self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) + self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_sdp_math_gradcheck(self, contiguous_inputs: bool): + + batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) + + qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) + query, key, value = qkv.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (query, key, value, None, 0.0, False, False) + ) + + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool): + batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) + + qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) + qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() + + query, key, value = qkv.chunk(3, dim=-1) + query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + query_lp = query_lp.contiguous() + key_lp = key_lp.contiguous() + value_lp = value_lp.contiguous() + + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False) + + with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False): + out_lp, atten_lp = torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, None, 0.0, False, False) + + rand_upward = torch.rand_like(out) + rand_upward_lp = rand_upward.to(torch.float32) + + out.backward(rand_upward) + out_lp.backward(rand_upward_lp) + + # Cast up and compare + self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) + + @parametrize("type", ["dense", "nested"]) + def test_fused_sdp_choice(self, type: str): + device = "cpu" + # Test that cpu and nestedtensor cpu return MATH backend + for dtype in floating_types_and_half(): + make_tensor = partial(self.rand_tensor, device=device, dtype=dtype) + size = (2, 2, 3, 4) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + assert torch._fused_sdp_choice(q, k, v) == SDPBackend.MATH + + if TEST_CUDA and not TEST_WITH_ROCM and not IS_WINDOWS: + batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 + shape = (batch_size, seq_len, num_heads, head_dim) + device = "cuda" + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float16, packed=True) + make_nt = partial(self.rand_nt, device=device, dtype=torch.float16, packed=True) + + qkv = make_tensor(shape) if type == "dense" else make_nt(shape) + query, key, value = qkv.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if SM80OrLater and not type == "nested": + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION + else: + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION + + # Change dtype to float32 so that efficient attention should get chosen + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float32, packed=True) + make_nt = partial(self.rand_nt, device=device, dtype=torch.float32, packed=True) + + qkv = make_tensor(shape) if type == "dense" else make_nt(shape) + query, key, value = qkv.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): # We will test all the constraints that we know will cause a failure @@ -1017,37 +1220,49 @@ def test_sdp_runtime_dispatch(self): # will fail on CI/CD becuase it is not compiled with the right flags device = 'cuda' dtype = torch.float16 + make_tensor = partial(self.rand_tensor, device=device, dtype=dtype) - def make_tensor(*size, device=device, dtype=dtype): - return torch.randn(size, device=device, dtype=dtype) - - with sdp_kernel(enable_flash=False, enable_math=False): - q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4) + with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False): + size = (2, 3, 4) + q = torch.randn(size, device=device, dtype=dtype) + k = torch.randn(size, device=device, dtype=dtype) + v = torch.randn(size, device=device, dtype=dtype) + self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", + lambda: torch._fused_sdp_choice(q, k, v)) self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", lambda: torch.nn.functional._scaled_dot_product_attention(q, k, v)) - with sdp_kernel(enable_flash=True, enable_math=False): + with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False): # Failures for invalid input # Dim is not 4 - q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4) + q = torch.randn(size, device=device, dtype=dtype) + k = torch.randn(size, device=device, dtype=dtype) + v = torch.randn(size, device=device, dtype=dtype) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) # Xformers can now cover this case but will add back in next PR - # # Invalid last_dim size - # q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4) - # self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( - # q, k, v, None, 0.0, False, False)) + # Invalid last_dim size + size = (2, 2, 3, 4) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( + q, k, v, None, 0.0, False, False)) # Invalid dtype - q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float64), make_tensor( - 2, 2, 3, 16, dtype=torch.float64), make_tensor(2, 2, 3, 16, dtype=torch.float64) + size = (2, 2, 3, 16) + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float64) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( + q, k, v, None, 0.0, False, False)) + + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float32) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) # Failures for unsupported SDP args - q, k, v = make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) # Needs attention weights self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( @@ -1057,6 +1272,15 @@ def make_tensor(*size, device=device, dtype=dtype): self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, torch.ones_like(q), 0.0, False, False)) + # Test failing MHA when bias was NoneType + def test_bias_is_none(self): + x = torch.rand((1, 5, 10)) + model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True) + model.eval() + model(x, x, x) + # completes without error + + # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for # cross device / dtype testing. instantiate_parametrized_tests(TestTransformers) diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index a881c36075e3c..1d80556a7d48f 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -7,7 +7,8 @@ import torch from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, make_tensor, - TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict) + TEST_NUMPY, set_default_dtype, torch_to_numpy_dtype_dict, + numpy_to_torch_dtype_dict) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, dtypes, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( @@ -30,14 +31,10 @@ def float_double_default_dtype(fn): @wraps(fn) def wrapped_fn(*args, **kwargs): - cur_dtype = torch.get_default_dtype() - try: - torch.set_default_dtype(torch.float) + with set_default_dtype(torch.float): fn(*args, **kwargs) - torch.set_default_dtype(torch.double) + with set_default_dtype(torch.double): fn(*args, **kwargs) - finally: - torch.set_default_dtype(cur_dtype) return wrapped_fn @@ -476,7 +473,7 @@ def _get_dtype(x): elif isinstance(x, complex): return torch.complex64 else: - raise AssertionError(f"Unkonwn type {x}") + raise AssertionError(f"Unknown type {x}") # tensor against tensor a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0]) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 5a9bdb53ab6b3..3676d88de5680 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1094,6 +1094,62 @@ def test_mish(self, device, dtype): rtol=rtol, ) + @dtypes(torch.complex64, torch.complex128) + @onlyCPU + def test_log1p_complex(self, device, dtype): + # The output values here were obtained using arbitrary precision math (mpmath) + # and double checked with WolframAlpha. + # Not using numpy's log1p here because by the time of writing this, + # np.log1p has precision problems for small complex input values, see here: + # https://github.com/numpy/numpy/issues/22609 + inouts = [ + (0.2 + 0.3j, 0.21263386770217202 + 0.24497866312686414j), + (1e-19 + 1e-18j, 1e-19 + 1e-18j), + (1e-18 + 0.1j, 0.00497517 + 0.0996687j), + (0.1 + 1e-18j, 0.0953102 + 9.090909090909090909e-19j), + (0.5 + 0j, 0.40546510810816 + 0j), + (0.0 + 0.5j, 0.111571776 + 0.463647609j), + (2.0 + 1.0j, 1.151292546497023 + 0.3217505543966422j), + (-1.0 + 2.0j, 0.6931471805599453 + 1.570796326794897j), + (2.0j, 0.80471895621705014 + 1.1071487177940904j), + (-2.0j, 0.80471895621705014 - 1.1071487177940904j), + ] + # test the extreme values + if dtype == torch.complex128: + inouts += [ + (-1 + 1e250j, 575.6462732485114 + 1.5707963267948966j), + (1e250 + 1j, 575.6462732485114 + 1e-250j), + (1e250 + 1e250j, 575.9928468387914 + 0.7853981633974483j), + (1e-250 + 1e250j, 575.6462732485114 + 1.5707963267948966j), + (1e-250 + 2e-250j, 1e-250 + 2e-250j), + (1e250 + 1e-250j, 575.6462732485114 + 0.0j), + ] + elif dtype == torch.complex64: + inouts += [ + (-1 + 1e30j, 69.07755278982137 + 1.5707963267948966j), + (1e30 + 1j, 69.07755278982137 + 1e-30j), + (1e30 + 1e30j, 69.42412638010134 + 0.7853981633974483j), + (1e-30 + 1e30j, 69.07755278982137 + 1.5707963267948966j), + (1e-30 + 2e-30j, 1e-30 + 2e-30j), + (1e30 + 1e-30j, 69.07755278982137 + 0.0j), + ] + + # test the log1p individually + for inp, out in inouts: + res = torch.log1p(torch.tensor(inp, dtype=dtype, device=device)) + self.assertFalse(torch.any(torch.isnan(res))) + # setting up atol == 0.0 because some part has very small values + self.assertEqual(res.real, out.real, atol=0.0, rtol=1e-6) + self.assertEqual(res.imag, out.imag, atol=0.0, rtol=1e-6) + + # test the log1p in tensor + inp_lst, out_lst = [list(elmt) for elmt in zip(*inouts)] + inp_tens = torch.tensor(inp_lst, dtype=dtype, device=device) + out_tens = torch.tensor(out_lst, dtype=dtype, device=device) + res_tens = torch.log1p(inp_tens) + self.assertEqual(res_tens.real, out_tens.real, atol=0.0, rtol=1e-6) + self.assertEqual(res_tens.imag, out_tens.imag, atol=0.0, rtol=1e-6) + # do ops like threshold need a test_unary(_nonufunc) test suite? @onlyCPU @dtypes(*get_all_math_dtypes("cpu")) diff --git a/test/test_utils.py b/test/test_utils.py index 3ad9bf73aaf78..b745e771abd12 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,7 +19,7 @@ import torch.utils.cpp_extension from torch.autograd._functions.utils import check_onnx_broadcast from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings -from torch.testing._internal.common_utils import load_tests, IS_SANDCASTLE, IS_WINDOWS +from torch.testing._internal.common_utils import load_tests, IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -610,6 +610,7 @@ def test_bottleneck_cuda(self): from torch.utils.collect_env import get_pretty_env_info +@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") class TestCollectEnv(TestCase): def test_smoke(self): info_output = get_pretty_env_info() diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 6c65457ae24f1..49ffbd872f154 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -6,16 +6,17 @@ from itertools import product, permutations, combinations from functools import partial import random +from torch._C import dtype from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, + IS_FBCODE, TestCase, run_tests, skipIfMps, suppress_warnings, gradcheck, gradgradcheck, numpy_to_torch_dtype_dict, skipIfTorchDynamo ) from torch.testing._internal.common_device_type import \ - (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) + (dtypesIfMPS, instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) from torch.testing._internal.common_dtype import ( - all_types_and_complex_and, complex_types, all_types_and, floating_and_complex_types_and, + all_types_and_complex_and, complex_types, all_types_and, floating_and_complex_types_and, integral_types_and, ) # TODO: replace this with make_tensor() in common_utils.py @@ -102,7 +103,7 @@ def is_view_of(self, base, other): # Note: only validates storage on native device types # because some accelerators, like XLA, do not expose storage if base.device.type == 'cpu' or base.device.type == 'cuda': - if base.storage().data_ptr() != other.storage().data_ptr(): + if base._storage().data_ptr() != other._storage().data_ptr(): return False return True @@ -369,6 +370,7 @@ def test_view_tensor_dsplit(self, device, dtype): self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) @onlyNativeDeviceTypes + @skipIfMps @dtypes(*all_types_and(torch.half, torch.bfloat16)) def test_imag_noncomplex(self, device, dtype): t = torch.ones((5, 5), dtype=dtype, device=device) @@ -409,6 +411,7 @@ def compare_with_numpy(contiguous_input=True): @onlyNativeDeviceTypes @dtypes(*complex_types()) + @skipIfMps def test_conj_imag_view(self, device, dtype) -> None: t = _make_tensor((4, 5,), dtype, device) t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device) @@ -423,6 +426,7 @@ def test_conj_imag_view(self, device, dtype) -> None: self.assertTrue(v_imag.is_neg()) @onlyNativeDeviceTypes + @skipIfMps def test_conj_view_with_shared_memory(self, device) -> None: a = _make_tensor((4, 5,), torch.cfloat, device) b = a.conj() @@ -857,6 +861,7 @@ def test_advanced_indexing_nonview(self, device): nv[1, 1] = 0 self.assertNotEqual(t[2, 2], nv[1, 1]) + @unittest.skipIf(IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds") def test_advanced_indexing_assignment(self, device): t = torch.ones(3, 3, device=device) rows = torch.tensor([[0, 0], [2, 2]], device=device) @@ -865,6 +870,7 @@ def test_advanced_indexing_assignment(self, device): self.assertEqual(t[2, 2], 0) @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") + @skipIfMps def test_chunk_view(self, device): t = torch.zeros(3, 3, device=device) l = torch.chunk(t, 3) @@ -926,6 +932,12 @@ def test_view_copy(self, device): self.assertEqual(a_view_copy, a_view) self.assertEqual(a.grad, a_ref.grad) + # Testing that the output of a view_copy kernel (by default) is contiguous. + def test_view_copy_output_contiguous(self, device): + a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last) + b = torch.ops.aten.slice_copy(a, 0, 0, 2) + self.assertTrue(b.is_contiguous()) + def test_view_copy_out(self, device): a = torch.randn(2, 2, device=device) out = torch.empty(2, device=device) @@ -1293,6 +1305,7 @@ def test_big_transpose(self, device): t2 = torch.from_numpy(t.cpu().numpy().transpose()) self.assertEqual(t1, t2) + @skipIfMps def test_T(self, device): a = torch.randn(2, 3, 4, device=device) t1 = a.T @@ -1304,6 +1317,7 @@ def test_T(self, device): self.assertEqual(scalar, scalar.T) @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_transposes(self, device, dtype): for op in ("T", "H", "mT", "mH", "adjoint"): shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),) @@ -1320,6 +1334,7 @@ def test_transposes(self, device, dtype): self.assertEqual(t2, t1) @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_transposes_errors(self, device, dtype): for op in ("H", "mT", "mH", "adjoint"): shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) @@ -1330,6 +1345,7 @@ def test_transposes_errors(self, device, dtype): if op == "adjoint": t1 = t1() + @skipIfMps def test_python_types(self, device): a1 = torch.randn((1, 2), device=device, dtype=torch.float64) a2 = torch.randn((1, 2), device=device, dtype=float) @@ -1371,6 +1387,7 @@ def test_helper(shape, numel, memory_format, device): @onlyNativeDeviceTypes @dtypes(torch.int64, torch.float, torch.complex128) + @dtypesIfMPS(torch.int64, torch.float) def test_transpose_invalid(self, device, dtype): for fn in (torch.swapdims, torch.swapaxes, torch.transpose): shape = _rand_shape(4, min_size=5, max_size=10) @@ -1384,6 +1401,7 @@ def test_transpose_invalid(self, device, dtype): fn(x, 0, 5) @dtypes(torch.int64, torch.float, torch.complex128) + @dtypesIfMPS(torch.int64, torch.float) def test_transpose_vs_numpy(self, device, dtype): for fn in (torch.swapdims, torch.swapaxes, torch.transpose): for nd in range(5): @@ -1446,6 +1464,7 @@ def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): # TODO: are these view ops? @dtypes(*all_types_and_complex_and(torch.half)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_atleast(self, device, dtype): self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) @@ -1549,6 +1568,7 @@ def test_broadcast_shapes(self, device): # Skip BFloat16 since numpy does not support it @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_broadcast_to(self, device, dtype): def can_broadcast(s0, s1): # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension @@ -1576,6 +1596,7 @@ def can_broadcast(s0, s1): r"must match the existing size \(\d\)"): torch.broadcast_to(t, s1) + @skipIfMps def test_view(self, device): tensor = torch.rand(15, device=device) template = torch.rand(3, 5, device=device) @@ -1652,6 +1673,7 @@ def test_view(self, device): self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @dtypesIfMPS(*integral_types_and(torch.bool, torch.half, torch.float32)) def test_reshape_view_semantics(self, device, dtype): tensor = make_tensor((15, 4), dtype=dtype, device=device) target = (20, 3) @@ -1786,13 +1808,14 @@ def test_tensor_split_errors(self, device): + ' zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims'): torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) + @skipIfMps def test_resize_all_dtypes_and_devices(self, device): shape = (2, 2) for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) x.resize_(shape) self.assertEqual(shape, x.shape) - + @skipIfMps def test_resize_as_all_dtypes_and_devices(self, device): for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) @@ -1808,6 +1831,7 @@ def test_resize_overflow(self, device): with self.assertRaisesRegex(RuntimeError, 'overflow'): x.resize_([8, 8, 2**29, 2**29]) + @skipIfMps def test_view_all_dtypes_and_devices(self, device): for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py index 9e510d1715b10..17ac2d9e7fc3a 100644 --- a/test/test_xnnpack_integration.py +++ b/test/test_xnnpack_integration.py @@ -14,7 +14,7 @@ import io import itertools -from torch.testing._internal.common_utils import TEST_WITH_TSAN +from torch.testing._internal.common_utils import IS_FBCODE, TEST_WITH_TSAN @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." @@ -987,6 +987,7 @@ def validate_transform_conv1d_to_conv2d( torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + @unittest.skipIf(IS_FBCODE, "T137513244") def test_conv1d_basic(self): batch_size_list = range(1, 3) input_channels_per_group_list = range(10, 12) diff --git a/third_party/build_bundled.py b/third_party/build_bundled.py index 4da1b84a6f32e..d60a2c1354fd2 100644 --- a/third_party/build_bundled.py +++ b/third_party/build_bundled.py @@ -181,9 +181,14 @@ def squeeze(t): ), help="location to output new bundled licenses file", ) - + parser.add_argument( + "--include-files", + action="store_true", + default=False, + help="include actual license terms to the output", + ) args = parser.parse_args() fname = args.out_file print(f"+ Writing bundled licenses to {args.out_file}") with open(fname, 'w') as fid: - create_bundled(third_party, fid) + create_bundled(third_party, fid, args.include_files) diff --git a/third_party/fbgemm b/third_party/fbgemm index 4d1738b3142a6..908a8f361ac5c 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 4d1738b3142a6cb0c032cd639e239566010b054a +Subproject commit 908a8f361ac5c6103e55fbbb38ef8110457ff6eb diff --git a/third_party/gloo b/third_party/gloo index 5b14351326313..4a5e339b76426 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 5b143513263133af2b95547e97c07cebeb72bf72 +Subproject commit 4a5e339b764261d20fc409071dc7a8b8989aa195 diff --git a/third_party/gloo.BUILD b/third_party/gloo.BUILD index 3f623e54e6ad4..e9deaa13fc63f 100644 --- a/third_party/gloo.BUILD +++ b/third_party/gloo.BUILD @@ -75,8 +75,7 @@ cc_library( ] ) + if_cuda(glob(["gloo/cuda*.cc"])), copts = [ - "-std=gnu++11", - "-std=c++11", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [":gloo_headers"] + if_cuda( diff --git a/third_party/ideep b/third_party/ideep index ececd0a4f53c3..e533c771a1e75 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit ececd0a4f53c39f2d91caaddee0de1cd214f5b99 +Subproject commit e533c771a1e75a1c225c14b2261eefa62681d9e6 diff --git a/third_party/kineto b/third_party/kineto index 0703c78999061..6c1629809068e 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 0703c78999061b8329dfab7ec5046fc5764a5573 +Subproject commit 6c1629809068efd78a8d56b4aa479c7ec49ae562 diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index fb41d31e89a84..6179b860c2030 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -55,8 +55,8 @@ template_rule( substitutions = { "@DNNL_VERSION_MAJOR@": "2", "@DNNL_VERSION_MINOR@": "7", - "@DNNL_VERSION_PATCH@": "0", - "@DNNL_VERSION_HASH@": "650085b2f3643aad05c629425983491d63b5c289", + "@DNNL_VERSION_PATCH@": "2", + "@DNNL_VERSION_HASH@": "fbec3e25a559ee252022ae066817b204e106a6ba", }, ) diff --git a/third_party/pybind11 b/third_party/pybind11 index aa304c9c7d725..80dc998efced8 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit aa304c9c7d725ffb9d10af08a3b34cb372307020 +Subproject commit 80dc998efced8ceb2be59756668a7e90e8bef917 diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index ee07488e26749..41f6e2e7c8150 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,4 +1,5 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") +load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") load( @@ -237,6 +238,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_sse", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_SSE_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -259,12 +264,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_SSE_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -316,6 +321,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_sse2", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_SSE2_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -338,12 +347,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_SSE2_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -397,6 +406,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_ssse3", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_SSSE3_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -419,12 +432,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_SSSE3_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -478,6 +491,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_sse41", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_SSE41_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -500,12 +517,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_SSE41_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -559,6 +576,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avx", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_AVX_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/src", "**/*.c"), @@ -582,12 +603,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_AVX_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -640,6 +661,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_f16c", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_F16C_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/src", "**/*.c"), @@ -663,12 +688,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_F16C_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), platforms = (APPLE, ANDROID, CXX, WINDOWS), preferred_linkage = "static", preprocessor_flags = [ @@ -723,6 +748,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_xop", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_XOP_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/src", "**/*.c"), @@ -746,12 +775,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_XOP_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -804,6 +833,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_fma3", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_FMA3_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/src", "**/*.c"), @@ -829,12 +862,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_FMA3_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -901,6 +934,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avx2", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_AVX2_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -928,12 +965,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_AVX2_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -1006,6 +1043,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avx512", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_AVX512F_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -1029,12 +1070,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_AVX512F_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", @@ -1087,6 +1128,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "ukernels_avx512skx", + srcs = (select({ + "DEFAULT": [], + "ovr_config//os:macos-x86_64": PROD_AVX512SKX_MICROKERNEL_SRCS, + }) if is_arvr_mode() else []), headers = subdir_glob([ ("XNNPACK/src", "**/*.c"), ("XNNPACK/src", "**/*.h"), @@ -1118,12 +1163,12 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], - platform_srcs = [ + platform_srcs = ([ ( "x86|x86_64|platform009|platform010", PROD_AVX512SKX_MICROKERNEL_SRCS, ), - ], + ] if not is_arvr_mode() else []), preferred_linkage = "static", preprocessor_flags = [ "-DXNN_LOG_LEVEL=0", diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index e61ab02e48a26..58a49fded0eec 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -62,10 +62,11 @@ def define_tools_targets( ("code_analyzer", "gen_oplist.py"), ("code_analyzer", "gen_op_registration_allowlist.py"), ]), - base_module = "", + base_module = "tools.code_analyzer", tests = [ ":gen_oplist_test", ], + visibility = ["PUBLIC"], deps = [ ":gen_selected_mobile_ops_header", torchgen_deps, @@ -75,7 +76,7 @@ def define_tools_targets( python_binary( name = "gen_oplist", - main_module = "gen_oplist", + main_module = "tools.code_analyzer.gen_oplist", visibility = ["PUBLIC"], deps = [ ":gen_oplist_lib", @@ -211,6 +212,18 @@ def define_tools_targets( "gen_vulkan_spv.py", ], base_module = "", + deps = [ + torchgen_deps, + ":gen_aten_vulkan_glsl_lib", + ], + ) + + python_library( + name = "gen_aten_vulkan_glsl_lib", + srcs = [ + "gen_vulkan_glsl.py", + ], + base_module = "tools", deps = [ torchgen_deps, ], @@ -223,6 +236,20 @@ def define_tools_targets( "PUBLIC", ], deps = [ + ":gen_aten_vulkan_glsl_lib", + ":gen_aten_vulkan_spv_lib", + ], + ) + + python_test( + name = "vulkan_codegen_test", + srcs = [ + "test/test_vulkan_codegen.py", + ], + contacts = contacts, + visibility = ["PUBLIC"], + deps = [ + ":gen_aten_vulkan_glsl_lib", ":gen_aten_vulkan_spv_lib", ], ) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7ddbe8dd6cf70..0eded5c1ab53d 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -200,7 +200,7 @@ # preferable since it would be less efficient. # # NB: The parameter names here MUST be consistent with the parameter names -# in Decalarations.yaml +# in native_functions.yaml - name: abs(Tensor self) -> Tensor self: grad * self.sgn() result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn()) @@ -526,10 +526,6 @@ self: grad.diagonal(offset, dim1, dim2) result: auto_linear -- name: diag(Tensor self, int diagonal=0) -> Tensor - self: diag_backward_symint(grad, self.sym_sizes(), diagonal) - result: auto_linear - - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2) result: auto_linear @@ -1148,6 +1144,14 @@ input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) +- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps) + - name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) save_mean: not_implemented("native_batch_norm_backward save_mean") @@ -1164,7 +1168,7 @@ rstd: not_implemented("native_layer_norm_backward rstd") - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) - input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" + input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(), input.device().is_xpu() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) result1: group_norm_mean_jvp(input_t, result1, group) result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) @@ -1383,16 +1387,16 @@ src: grad.gather(dim, index) result: scatter_add(self_t, dim, index, src_t) -- name: select.int(Tensor(a) self, int dim, int index) -> Tensor(a) +- name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) dispatch: Default: self: select_backward_symint(grad, self.sym_sizes(), dim, index) result: auto_linear AutogradNestedTensor: - self: _nested_select_backward(grad, self, dim, index) + self: _nested_select_backward_symint(grad, self, dim, index) -- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, int index) -> Tensor - grad_output: grad.select(dim, index) +- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + grad_output: grad.select_symint(dim, index) result: auto_linear - name: sigmoid(Tensor self) -> Tensor @@ -1439,9 +1443,9 @@ src: grad.slice_symint(dim, start, end, step) result: auto_linear -- name: select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor - self: select_scatter(grad, zeros_like(src), dim, index) - src: grad.select(dim, index) +- name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + self: select_scatter_symint(grad, zeros_like(src), dim, index) + src: grad.select_symint(dim, index) result: auto_linear - name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor @@ -1510,12 +1514,12 @@ self: unsqueeze_to(grad, dim, self.sym_sizes()) result: auto_linear -- name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor +- name: std.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor self: std_backward(result, grad, self, dim, correction, keepdim) # pointwise (variance) + sum + sqrt result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) -- name: std_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- name: std_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0) # linear @@ -1658,7 +1662,7 @@ - name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor self: to_dense_backward(grad, self) -- name: to_sparse(Tensor self) -> Tensor +- name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None) -> Tensor self: grad.to_dense() - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor @@ -1719,12 +1723,12 @@ self: grad.squeeze(dim) result: auto_linear -- name: var.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor +- name: var.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> Tensor self: var_backward(grad, self, dim, correction, keepdim) # pointwise + sum result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) -- name: var_mean.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> (Tensor, Tensor) +- name: var_mean.correction(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False) -> (Tensor, Tensor) self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) # linear @@ -1770,7 +1774,7 @@ self: grad.to_dense().sparse_mask(mask).to_dense() mask: non_differentiable -- name: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor values: sparse_constructor_values_backward(grad, indices) - name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor @@ -1829,13 +1833,13 @@ + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None), reduction)" -- name: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor +- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor indices: non_differentiable weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse) result: auto_linear -- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor - grad_output: embedding_dense_double_backward(grad, indices, padding_idx) +- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx) indices: non_differentiable result: auto_linear @@ -2086,53 +2090,6 @@ self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) result: auto_linear -- name: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: upsample_linear1d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - result: auto_linear - -- name: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: upsample_bilinear2d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - result: auto_linear - -- name: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: _upsample_bilinear2d_aa_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - result: auto_linear - -- name: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: upsample_trilinear3d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - result: auto_linear - -- name: upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: upsample_bicubic2d_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - result: auto_linear - -- name: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor - input: _upsample_bicubic2d_aa_backward_symint(grad, output_size, input.sym_sizes(), align_corners, scale_factors) - -- name: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: upsample_nearest1d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: _upsample_nearest_exact1d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - -- name: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: upsample_nearest2d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: _upsample_nearest_exact2d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - -- name: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: upsample_nearest3d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor - input: _upsample_nearest_exact3d_backward_symint(grad, output_size, input.sym_sizes(), scale_factors) - result: auto_linear - - name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor self: pixel_unshuffle(grad, upscale_factor) result: auto_linear @@ -2210,19 +2167,19 @@ indices: non_differentiable result: auto_linear -- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor +- name: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups) # TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution. # Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context # by convolution_backward instead of being passed along from the forward pass. -- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor +- name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) -- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) - grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) +- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) result2: convolution_backward_jvp_grad_bias(grad_output_t, result2) @@ -2233,10 +2190,10 @@ - name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) -- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor +- name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, int[2] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" -- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor +- name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, int[3] dilation=1) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" - name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor @@ -2245,20 +2202,20 @@ - name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) -- name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor +- name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, SymInt[2] padding, int[2] dilation) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" -- name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor +- name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding, int[3] dilation) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" -- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor +- name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, SymInt[3] padding) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" -- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" +- name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, SymInt[2] padding=0, int[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" -- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" +- name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, SymInt[3] padding=0, int[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" - name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor self: im2col(grad, kernel_size, dilation, padding, stride) @@ -2511,53 +2468,6 @@ grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w) result: auto_linear -- name: upsample_linear1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scale_factors) - result: auto_linear - -- name: upsample_bilinear2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scale_factors) - result: auto_linear - -- name: _upsample_bilinear2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scale_factors) - result: auto_linear - -- name: upsample_trilinear3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scale_factors) - result: auto_linear - -- name: upsample_bicubic2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scale_factors) - result: auto_linear - -- name: _upsample_bicubic2d_aa_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, bool align_corners, float[]? scale_factors) -> Tensor - grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scale_factors) - -- name: upsample_nearest1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: upsample_nearest1d_symint(grad, output_size, scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact1d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scale_factors) - result: auto_linear - -- name: upsample_nearest2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: upsample_nearest2d_symint(grad, output_size, scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact2d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scale_factors) - result: auto_linear - -- name: upsample_nearest3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: upsample_nearest3d_symint(grad, output_size, scale_factors) - result: auto_linear - -- name: _upsample_nearest_exact3d_backward.vec(Tensor grad_output, SymInt[]? output_size, SymInt[] input_size, float[]? scale_factors) -> Tensor - grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scale_factors) - result: auto_linear - - name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor grad_output: sigmoid_backward(grad, output.conj()) output: grad.conj() * grad_output * (-2 * output.conj() + 1) @@ -2612,9 +2522,9 @@ # nnpack -- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor +- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, int[2] stride=1) -> Tensor # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here. - input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" #LSTM MPS - name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor) @@ -2645,14 +2555,14 @@ # miopen -- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor +- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" -- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" +- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" -- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" +- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" @@ -2671,8 +2581,8 @@ dropout_state: non_differentiable # mkldnn -- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" +- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, int[] stride, int[] dilation, int groups) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) @@ -2689,7 +2599,7 @@ - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor self: grad.reshape_symint(self.sym_sizes()) -# Nested Tensor +# NestedTensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" @@ -2710,6 +2620,15 @@ nested_size: non_differentiable nested_strides: non_differentiable +# Transformers +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal) + +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal) + # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) @@ -2825,6 +2744,10 @@ AutogradCUDA: self: grad.reshape_as(self) + 1 +- name: _test_inductor_realize(Tensor self) -> Tensor + self: grad + result: auto_linear + - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor output_differentiability: [False] @@ -2988,3 +2911,7 @@ - name: special_spherical_bessel_j0(Tensor x) -> Tensor x: non_differentiable + +- name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 7b120593eb539..3e9e125bfb9f6 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -291,7 +291,7 @@ for (auto i : c10::irange(prop.size())) { auto si = prop[i]; if (si.is_symbolic()) { - auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + auto py_symint = py::cast(si).release().ptr(); PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); } else { PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(si.as_int_unchecked())); @@ -313,7 +313,7 @@ """ GETTER_BODY_SYMINT = """\ -return prop.is_symbolic() ? py::cast(prop.toSymIntNodeImpl()).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked()); +return prop.is_symbolic() ? py::cast(prop).release().ptr() : PyLong_FromUnsignedLong(prop.as_int_unchecked()); """ GETTER_BODY_DOUBLE = """\ diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index f90ec74459de4..ee06a8ed12384 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -124,6 +124,7 @@ "_local_scalar_dense", "to", "_to_copy", + "_reshape_copy", "copy_sparse_to_sparse_", "copy_", "numpy_T", diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 53bd60b76e6bd..3d1ff895c837f 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -295,6 +295,7 @@ "reflection_pad3d", "linalg_cholesky_ex", "linalg_eig", + "diagonal_copy", "select_backward", "diagonal_backward", "slice_backward", @@ -347,6 +348,7 @@ "conj_physical_", "_neg_view", "_reshape_alias", + "_reshape_copy", "_linalg_det", "lu_solve", "linalg_solve_triangular", diff --git a/tools/autograd/templates/python_functions.cpp b/tools/autograd/templates/python_functions.cpp index 57343a53ea982..eacf56b31d88e 100644 --- a/tools/autograd/templates/python_functions.cpp +++ b/tools/autograd/templates/python_functions.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include "torch/csrc/autograd/generated/Functions.h" #include "torch/csrc/autograd/python_cpp_function.h" #include diff --git a/tools/autograd/templates/python_nested_functions.cpp b/tools/autograd/templates/python_nested_functions.cpp index cdfc4336163f4..5515ca6f8a0b3 100644 --- a/tools/autograd/templates/python_nested_functions.cpp +++ b/tools/autograd/templates/python_nested_functions.cpp @@ -4,7 +4,7 @@ #include "torch/csrc/Device.h" #include "torch/csrc/DynamicTypes.h" #include "torch/csrc/Exceptions.h" -#include "torch/csrc/autograd/python_special_functions.h" +#include "torch/csrc/autograd/python_nested_functions.h" #include "torch/csrc/autograd/python_return_types.h" #include "torch/csrc/autograd/python_variable.h" #include "torch/csrc/autograd/utils/wrap_outputs.h" @@ -47,6 +47,7 @@ namespace torch { namespace autograd { ${py_forwards} static PyMethodDef nested_functions[] = { + {NULL, NULL, 0, NULL}, ${py_method_defs} {NULL} }; @@ -54,6 +55,7 @@ static PyMethodDef nested_functions[] = { static PyObject* THPNestedVariableFunctionsModule = NULL; void initNestedFunctions(PyObject* module) { + nested_functions[0] = get_nested_functions_manual()[0]; static struct PyModuleDef def = { PyModuleDef_HEAD_INIT, "torch._C._nested", diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index e4df2a8dc61da..6ad042c0b903a 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -240,12 +240,7 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args) if (jit::tracer::isTracing()) { return wrap(jit::tracer::getNumelOf(self_)); } else { - auto si = self_.sym_numel(); - if (si.is_symbolic()) { - return py::cast(si.toSymIntNodeImpl()).release().ptr(); - } else { - return THPUtils_packInt64(si.as_int_unchecked()); - } + return py::cast(self_.sym_numel()).release().ptr(); } END_HANDLE_TH_ERRORS } @@ -984,7 +979,7 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS if (check_has_torch_function(self)) { - return handle_torch_function(self, "storage"); + return handle_torch_function(self, "_storage"); } auto& self_ = THPVariable_Unpack(self); return createPyObject(self_.storage()); @@ -1140,7 +1135,7 @@ static PyObject* THPVariable_set_( { "set_()", "set_(Storage source)", - "set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride=None)", + "set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", "set_(Tensor source)", "set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)", }, @@ -1186,19 +1181,19 @@ static PyObject* THPVariable_set_( " for argument 1 'storage'"); auto dispatch_set_ = [](const Tensor& self, Storage source, - int64_t storage_offset, - IntArrayRef size, - IntArrayRef stride) -> Tensor { + c10::SymInt storage_offset, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) -> Tensor { pybind11::gil_scoped_release no_gil; - return self.set_(source, storage_offset, size, stride); + return self.set__symint(source, storage_offset, size, stride); }; return wrap(dispatch_set_( - self, storage, _r.toInt64(1), _r.intlist(2), _r.intlist(3))); + self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3))); } case 3: { // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) auto dispatch_set_ = [](const Tensor& self, const Tensor& source) -> Tensor { - TORCH_INTERNAL_ASSERT(source.dtype() == self.dtype()); + TORCH_CHECK(source.dtype() == self.dtype(), "Could not set tensor of type ", source.dtype(), " to a tensor of type ", self.dtype()); pybind11::gil_scoped_release no_gil; return self.set_(source); }; diff --git a/tools/bazel.bzl b/tools/bazel.bzl index f7da1839930d2..3c6f98154aebb 100644 --- a/tools/bazel.bzl +++ b/tools/bazel.bzl @@ -1,5 +1,5 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") -load("@rules_cuda//cuda:defs.bzl", "requires_cuda_enabled") +load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda_enabled") load("//c10/macros:cmake_configure_file.bzl", "cmake_configure_file") load("//tools/config:defs.bzl", "if_cuda") @@ -25,6 +25,7 @@ rules = struct( cc_library = cc_library, cc_test = cc_test, cmake_configure_file = cmake_configure_file, + cuda_library = cuda_library, filegroup = native.filegroup, genrule = _genrule, glob = native.glob, diff --git a/tools/code_analyzer/gen_oplist.py b/tools/code_analyzer/gen_oplist.py index 1e5d1277afcdf..18104ab30cb6c 100644 --- a/tools/code_analyzer/gen_oplist.py +++ b/tools/code_analyzer/gen_oplist.py @@ -127,7 +127,7 @@ def main(argv: List[Any]) -> None: default=False, required=False, ) - options = parser.parse_args() + options = parser.parse_args(argv) if os.path.isfile(options.model_file_list_path): print("Processing model file: ", options.model_file_list_path) @@ -186,4 +186,4 @@ def main(argv: List[Any]) -> None: if __name__ == "__main__": - main(sys.argv) + main(sys.argv[1:]) diff --git a/tools/code_coverage/README.md b/tools/code_coverage/README.md index 67adb445d053d..32fbc89e6aace 100644 --- a/tools/code_coverage/README.md +++ b/tools/code_coverage/README.md @@ -51,7 +51,7 @@ Great, you are ready to run the code coverage tool for the first time! Start fro ``` python oss_coverage.py --run-only=atest ``` -This command will run `atest` binary in `build/bin/` folder and generate reoports over the entire *Pytorch* folder. You can find the reports in `profile/summary`. But you may only be interested in the `aten` folder, in this case, try: +This command will run `atest` binary in `build/bin/` folder and generate reports over the entire *Pytorch* folder. You can find the reports in `profile/summary`. But you may only be interested in the `aten` folder, in this case, try: ``` python oss_coverage.py --run-only=atest --interest-only=aten ``` @@ -91,9 +91,9 @@ python oss_coverage.py --run-only=atest --interest-only=c10 --summary **2. Run tests yourself** -When you are developing a new feature, you may first run the tests yourself to make sure the implementation is all right and then want to learn its coverage. But sometimes the test take very long time and you don't want to wait to run it again when doing code coverage. In this case, you can use these arguments to accerate your development (make sure you build pytorch with the coverage option!): +When you are developing a new feature, you may first run the tests yourself to make sure the implementation is all right and then want to learn its coverage. But sometimes the test take very long time and you don't want to wait to run it again when doing code coverage. In this case, you can use these arguments to accelerate your development (make sure you build pytorch with the coverage option!): ``` -# run tests when you are devloping a new feature, assume the the test is `test_nn.py` +# run tests when you are developing a new feature, assume the test is `test_nn.py` python oss_coverage.py --run-only=test_nn.py # or you can run it yourself cd test/ && python test_nn.py diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py new file mode 100644 index 0000000000000..df03e6331728b --- /dev/null +++ b/tools/dynamo/verify_dynamo.py @@ -0,0 +1,167 @@ +import os +import re +import subprocess +import sys +import traceback +import warnings + +from pkg_resources import packaging + +MIN_CUDA_VERSION = packaging.version.parse("11.6") +MIN_PYTHON_VERSION = (3, 7) + + +class VerifyDynamoError(BaseException): + pass + + +def check_python(): + if sys.version_info < MIN_PYTHON_VERSION: + raise VerifyDynamoError( + f"Python version not supported: {sys.version_info} " + f"- minimum requirement: {MIN_PYTHON_VERSION}" + ) + return sys.version_info + + +def check_torch(): + import torch + + return packaging.version.parse(torch.__version__) + + +# based on torch/utils/cpp_extension.py +def get_cuda_version(): + from torch.utils import cpp_extension + + CUDA_HOME = cpp_extension._find_cuda_home() + if not CUDA_HOME: + raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE) + + nvcc = os.path.join(CUDA_HOME, "bin", "nvcc") + cuda_version_str = ( + subprocess.check_output([nvcc, "--version"]) + .strip() + .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS) + ) + cuda_version = re.search(r"release (\d+[.]\d+)", cuda_version_str) + if cuda_version is None: + raise VerifyDynamoError("CUDA version not found in `nvcc --version` output") + + cuda_str_version = cuda_version.group(1) + return packaging.version.parse(cuda_str_version) + + +def check_cuda(): + import torch + + if not torch.cuda.is_available(): + return None + + torch_cuda_ver = packaging.version.parse(torch.version.cuda) + + # check if torch cuda version matches system cuda version + cuda_ver = get_cuda_version() + if cuda_ver != torch_cuda_ver: + # raise VerifyDynamoError( + warnings.warn( + f"CUDA version mismatch, `torch` version: {torch_cuda_ver}, env version: {cuda_ver}" + ) + + if torch_cuda_ver < MIN_CUDA_VERSION: + # raise VerifyDynamoError( + warnings.warn( + f"(`torch`) CUDA version not supported: {torch_cuda_ver} " + f"- minimum requirement: {MIN_CUDA_VERSION}" + ) + if cuda_ver < MIN_CUDA_VERSION: + # raise VerifyDynamoError( + warnings.warn( + f"(env) CUDA version not supported: {cuda_ver} " + f"- minimum requirement: {MIN_CUDA_VERSION}" + ) + + return cuda_ver + + +def check_dynamo(backend, device, err_msg): + import torch + + if device == "cuda" and not torch.cuda.is_available(): + print(f"CUDA not available -- skipping CUDA check on {backend} backend\n") + return + + try: + import torch._dynamo as dynamo + + if device == "cuda": + import torch._inductor.utils as utils + + if not utils.has_triton(): + print( + f"WARNING: CUDA available but triton cannot be used. " + f"Your GPU may not be supported. " + f"Skipping CUDA check on {backend} backend\n" + ) + return + + dynamo.reset() + + @dynamo.optimize(backend, nopython=True) + def fn(x): + return x + x + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + mod = Module() + opt_mod = dynamo.optimize(backend, nopython=True)(mod) + + for f in (fn, opt_mod): + x = torch.randn(10, 10).to(device) + x.requires_grad = True + y = f(x) + torch.testing.assert_close(y, x + x) + z = y.sum() + z.backward() + torch.testing.assert_close(x.grad, 2 * torch.ones_like(x)) + except Exception: + sys.stderr.write(traceback.format_exc() + "\n" + err_msg + "\n\n") + sys.exit(1) + + +_SANITY_CHECK_ARGS = ( + ("eager", "cpu", "CPU eager sanity check failed"), + ("eager", "cuda", "CUDA eager sanity check failed"), + ("aot_eager", "cpu", "CPU aot_eager sanity check failed"), + ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"), + ("inductor", "cpu", "CPU inductor sanity check failed"), + ( + "inductor", + "cuda", + "CUDA inductor sanity check failed\n" + + "NOTE: Please check that you installed the correct hash/version of `triton`", + ), +) + + +def main(): + python_ver = check_python() + torch_ver = check_torch() + cuda_ver = check_cuda() + print( + f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n" + f"`torch` version: {torch_ver}\n" + f"CUDA version: {cuda_ver}\n" + ) + for args in _SANITY_CHECK_ARGS: + check_dynamo(*args) + print("All required checks passed") + + +if __name__ == "__main__": + main() diff --git a/tools/gen_vulkan_glsl.py b/tools/gen_vulkan_glsl.py new file mode 100644 index 0000000000000..6d89da0c743cb --- /dev/null +++ b/tools/gen_vulkan_glsl.py @@ -0,0 +1,111 @@ +import copy +import os + +from collections import OrderedDict + +import yaml +from torchgen.code_template import CodeTemplate +from yaml.constructor import ConstructorError +from yaml.nodes import MappingNode + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader # type: ignore[misc] + +# https://gist.github.com/pypt/94d747fe5180851196eb +class UniqueKeyLoader(Loader): + def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def] + if not isinstance(node, MappingNode): + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) + mapping = {} + for key_node, value_node in node.value: + key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] + try: + hash(key) + except TypeError as e: + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found unacceptable key ", + key_node.start_mark, + ) from e + # check for duplicate keys + if key in mapping: + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found duplicate key", + key_node.start_mark, + ) + value = self.construct_object(value_node, deep=deep) # type: ignore[no-untyped-call] + mapping[key] = value + return mapping + + +class GLSLGenerator(object): + standard_header = """ +#version 450 core +#define PRECISION $precision +#define FORMAT $format + +""" + + def __init__(self): # type: ignore[no-untyped-def] + self.ops_template_params = {} + + def add_params_yaml(self, parameters_yaml_file): # type: ignore[no-untyped-def] + all_template_params = OrderedDict() + with open(parameters_yaml_file, "r") as f: + contents = yaml.load(f, Loader=UniqueKeyLoader) + for key in contents: + all_template_params[key] = contents[key] + self.validate_and_construct_op_params(all_template_params) # type: ignore[no-untyped-call] + + def validate_and_construct_op_params(self, all_template_params): # type: ignore[no-untyped-def] + for op in all_template_params: + if op in self.ops_template_params: + raise KeyError(f"{op} params file has already been parsed") + op_params_default_vals = all_template_params[op][ + "parameter_names_with_default_values" + ] + template_params_set = set(op_params_default_vals.keys()) + self.ops_template_params[op] = [] + self.ops_template_params[op].append(op_params_default_vals) + op_template_params_values = all_template_params[op]["parameter_values"] + for param_vals in op_template_params_values: + param_vals_set = set(param_vals.keys()) + invalid_keys = param_vals_set - template_params_set + if (len(invalid_keys)) > 0: + raise KeyError(f"Invalid keys {invalid_keys} are found") + param_vals_copy = copy.deepcopy(op_params_default_vals) + for key in param_vals: + param_vals_copy[key] = param_vals[key] + self.ops_template_params[op].append(param_vals_copy) + + def generate(self, glsl_template_in, out_dir): # type: ignore[no-untyped-def] + glsl_template_name = os.path.basename(glsl_template_in) + op_name, extension_name = glsl_template_name.split(".") + if extension_name != "glslt": + raise TypeError(f"invalid file type for glsl template {extension_name}") + if op_name not in self.ops_template_params: + raise KeyError(f"{op_name} params have not been populated") + code_template = CodeTemplate.from_file(glsl_template_in) + for template_params in self.ops_template_params[op_name]: + content = GLSLGenerator.standard_header + param_vals_string = "x".join([str(i) for i in template_params.values()]) + output_file_name = op_name + "_" + param_vals_string + ".glsl" + content += code_template.substitute(template_params) + output_file = os.path.join(out_dir, output_file_name) + with open(output_file, "w") as f: + f.write(content) + + +# Remove this +if __name__ == "__main__": + pass diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index 74b1212bdbe26..cc317eba7d4a7 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -8,11 +8,23 @@ import sys import subprocess from torchgen.code_template import CodeTemplate +from dataclasses import dataclass +from typing import List + +from tools.gen_vulkan_glsl import GLSLGenerator H_NAME = "spv.h" CPP_NAME = "spv.cpp" DEFAULT_ENV = {"precision": "highp", "format": "rgba32f"} + +@dataclass +class ShaderInfo: + tile_size: List[int] + layouts: List[str] + weight_storage_type: str = "" + bias_storage_type: str = "" + def getName(filePath): return os.path.basename(filePath).replace("/", "_").replace(".", "_") @@ -20,11 +32,44 @@ def isDescriptorLine(lineStr): descriptorLineId = r"^layout\(set" return re.search(descriptorLineId, lineStr) +def isTileSizeLine(lineStr): + tile_size_id = r"^ \* TILE_SIZE = \(" + return re.search(tile_size_id, lineStr) + +def findTileSizes(lineStr): + tile_size_id = r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)" + matches = re.search(tile_size_id, lineStr) + return [int(matches.group(1)), int(matches.group(2)), int(matches.group(3))] + +def isWeightStorageTypeLine(lineStr): + weight_storage_id = r"^ \* WEIGHT_STORAGE = " + return re.search(weight_storage_id, lineStr) + +def getWeightStorageType(lineStr): + weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+_\dD)" + matches = re.search(weight_storage_id, lineStr) + return matches.group(1) + +def isBiasStorageTypeLine(lineStr): + weight_storage_id = r"^ \* BIAS_STORAGE = " + return re.search(weight_storage_id, lineStr) + +def getBiasStorageType(lineStr): + weight_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+_\dD)" + matches = re.search(weight_storage_id, lineStr) + return matches.group(1) + typeIdMapping = { r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE", r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER", r"\bbuffer\b": "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER", - r"\buniform\b.*\bBlock\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER", + r"\buniform\b": "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER", +} + +storageTypeToEnum = { + "TEXTURE_2D" : "api::StorageType::TEXTURE_2D", + "TEXTURE_3D" : "api::StorageType::TEXTURE_3D", + "BUFFER" : "api::StorageType::BUFFER", } def determineDescriptorType(lineStr): @@ -32,16 +77,40 @@ def determineDescriptorType(lineStr): if re.search(identifier, lineStr): return typeNum - raise Exception("Could not identify descriptor type of line: {}".format(lineStr)) - -def getLayout(srcFilePath): - layout = [] +def getShaderInfo(srcFilePath): + shader_info = ShaderInfo([], [], "") with open(srcFilePath, 'r') as srcFile: for line in srcFile: if isDescriptorLine(line): - layout.append(determineDescriptorType(line)) + shader_info.layouts.append(determineDescriptorType(line)) + if isTileSizeLine(line): + shader_info.tile_size = findTileSizes(line) + if isWeightStorageTypeLine(line): + shader_info.weight_storage_type = getWeightStorageType(line) + if isBiasStorageTypeLine(line): + shader_info.bias_storage_type = getBiasStorageType(line) + + return shader_info + +def genGLSLFromGLSLT(src_dir_path, tmp_dir_path): + template_dir_path = os.path.join(src_dir_path, "templates") + vexs = glob.glob(os.path.join(template_dir_path, '**', '*.yaml'), recursive=True) + parameter_yaml_files = [] + for f in vexs: + if len(f) > 1: + parameter_yaml_files.append(f) + generator = GLSLGenerator() + for params_yaml in parameter_yaml_files: + generator.add_params_yaml(params_yaml) # type: ignore[no-untyped-call] - return layout + vexs = glob.glob(os.path.join(src_dir_path, '**', '*.glslt'), recursive=True) + templateSrcPaths = [] + for f in vexs: + if len(f) > 1: + templateSrcPaths.append(f) + templateSrcPaths.sort() + for glslt in templateSrcPaths: + generator.generate(glslt, tmp_dir_path) # type: ignore[no-untyped-call] def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): print("hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".format( @@ -49,6 +118,14 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True) templateSrcPaths = [] + for f in vexs: + if len(f) > 1: + templateSrcPaths.append(f) + templateSrcPaths.sort() + + # Now add glsl files that are generated from templates + genGLSLFromGLSLT(srcDirPath, tmpDirPath) + vexs = glob.glob(os.path.join(tmpDirPath, '**', '*.glsl'), recursive=True) for f in vexs: if len(f) > 1: templateSrcPaths.append(f) @@ -74,6 +151,7 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): glslcPath, "-fshader-stage=compute", srcPath, "-o", spvPath, "--target-env=vulkan1.0", + "-I", srcDirPath, "-Werror" ] @@ -85,6 +163,8 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): h = "#pragma once\n" h += "#include \n" h += "#include \n" + h += "#include \n" + h += "#include \n" h += "#include " nsbegin = "\nnamespace at {\nnamespace native {\nnamespace vulkan {\n" @@ -101,7 +181,7 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): h += "extern const uint32_t {}[];\n".format(name) h += "extern const uint32_t {};\n".format(name_len) - layout = getLayout(srcPath) + shader_info = getShaderInfo(srcPath) name_layout = name + "_layout" h += "extern const std::vector {};\n".format(name_layout) @@ -117,10 +197,33 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env): # Add layout cpp += "const std::vector {} = {{\n".format(name_layout) - for descriptor in layout: + for descriptor in shader_info.layouts: cpp += " {},\n".format(descriptor) cpp += "};\n" + # Add tile size + if (len(shader_info.tile_size) > 0): + name_tile_size = name + "_tile_size" + h += "extern const std::vector {};\n".format(name_tile_size) + cpp += "const std::vector {} = {{\n".format(name_tile_size) + for s in shader_info.tile_size: + cpp += " {},\n".format(s) + cpp += "};\n" + + # Add weight type + if (shader_info.weight_storage_type != ""): + name_weight_storage_type = name + "_weight_storage_type" + h += "extern const api::StorageType {};\n".format(name_weight_storage_type) + cpp += "const api::StorageType {} = \n".format(name_weight_storage_type) + cpp += " {};\n".format(storageTypeToEnum[shader_info.weight_storage_type]) + + # Add bias type + if (shader_info.bias_storage_type != ""): + name_bias_storage_type = name + "_bias_storage_type" + h += "extern const api::StorageType {};\n".format(name_bias_storage_type) + cpp += "const api::StorageType {} = \n".format(name_bias_storage_type) + cpp += " {};\n".format(storageTypeToEnum[shader_info.bias_storage_type]) + cpp += nsend h += nsend diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index 96970bd2b1c35..1586ff15fd207 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -25,13 +25,12 @@ def get_sha(pytorch_root: Union[str, Path]) -> str: def get_tag(pytorch_root: Union[str, Path]) -> str: try: - tag = ( - subprocess.check_output( - ["git", "describe", "--tags", "--exact"], cwd=pytorch_root - ) - .decode("ascii") - .strip() - ) + tag = subprocess.run( + ["git", "describe", "--tags", "--exact"], + cwd=pytorch_root, + encoding="ascii", + capture_output=True, + ).stdout.strip() if RELEASE_PATTERN.match(tag): return tag else: diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py index ebeaa21bc7be9..79c594a9afa07 100644 --- a/tools/jit/gen_unboxing.py +++ b/tools/jit/gen_unboxing.py @@ -116,7 +116,9 @@ def __call__(self, f: NativeFunction) -> str: # from wrapping/unwrapping TensorOptios. # However, we would look to include default args for schema parsing. # Default args only show up in the nonfaithful C++ API, - arg_default = cpp.default_expr(arg.argument.default, arg.argument.type) + arg_default = cpp.default_expr( + arg.argument.default, arg.argument.type, symint=False + ) if arg_default.startswith("{"): arg_cpp = f"c10::IntArrayRef({arg_default})" else: diff --git a/tools/linter/adapters/mypy_linter.py b/tools/linter/adapters/mypy_linter.py index 65ee8850e667c..cd94879fa0f93 100644 --- a/tools/linter/adapters/mypy_linter.py +++ b/tools/linter/adapters/mypy_linter.py @@ -87,6 +87,7 @@ def check_files( filenames: List[str], config: str, retries: int, + code: str, ) -> List[LintMessage]: try: proc = run_command( @@ -100,7 +101,7 @@ def check_files( path=None, line=None, char=None, - code="MYPY", + code=code, severity=LintSeverity.ERROR, name="command-failed", original=None, @@ -118,7 +119,7 @@ def check_files( char=int(match["column"]) if match["column"] is not None and not match["column"].startswith("-") else None, - code="MYPY", + code=code, severity=severities.get(match["severity"], LintSeverity.ERROR), original=None, replacement=None, @@ -143,6 +144,11 @@ def main() -> None: required=True, help="path to an mypy .ini config file", ) + parser.add_argument( + "--code", + default="MYPY", + help="the code this lint should report as", + ) parser.add_argument( "--verbose", action="store_true", @@ -182,7 +188,7 @@ def main() -> None: else: filenames[filename] = True - lint_messages = check_files(list(filenames), args.config, args.retries) + lint_messages = check_files(list(filenames), args.config, args.retries, args.code) for lint_message in lint_messages: print(json.dumps(lint_message._asdict()), flush=True) diff --git a/tools/linter/adapters/s3_init_config.json b/tools/linter/adapters/s3_init_config.json index 0b0e87e8e26cf..d48f264f83d5d 100644 --- a/tools/linter/adapters/s3_init_config.json +++ b/tools/linter/adapters/s3_init_config.json @@ -27,12 +27,12 @@ }, "actionlint": { "Darwin": { - "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.15/Darwin_amd64/actionlint", - "hash": "e9a0e0b17e54cfefe7964b6aa1da8921b1f8f2318c31c0eb1a17ea3e8ab10db2" + "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.21/Darwin_amd64/actionlint", + "hash": "b354db83815384d3c3a07f68f44b30cb0a70899757a0d185d7322de9952e8813" }, "Linux": { - "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.15/Linux_arm64/actionlint", - "hash": "d6b45ae67f29a2bf9ddd226071ddd8f158fdf2992e8515a06838e5fef90f3a2d" + "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.6.21/Linux_arm64/actionlint", + "hash": "025ac157db121b33971ef24af72d73d71cda3cb1e3a94795bb2708ef4032ca76" } } } diff --git a/tools/onnx/gen_diagnostics.py b/tools/onnx/gen_diagnostics.py index ba6fd43bee292..92960024e048d 100644 --- a/tools/onnx/gen_diagnostics.py +++ b/tools/onnx/gen_diagnostics.py @@ -14,6 +14,7 @@ import argparse import os +import string import subprocess import textwrap from typing import Any, Mapping, Sequence @@ -30,19 +31,37 @@ Diagnostic rules for PyTorch ONNX export. """ -_PY_RULE_TEMPLATE = """\ -{0}: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif(**{1}), +_PY_RULE_CLASS_COMMENT = """\ +GENERATED CODE - DO NOT EDIT DIRECTLY +The purpose of generating a class for each rule is to override the `format_message` +method to provide more details in the signature about the format arguments. +""" + +_PY_RULE_CLASS_TEMPLATE = """\ +class _{pascal_case_name}(infra.Rule): + \"\"\"{short_description}\"\"\" + def format_message(self, {message_arguments}) -> str: # type: ignore[override] + \"\"\"Returns the formatted default message of this Rule. + + Message template: {message_template} + \"\"\" + return self.message_default_template.format({message_arguments_assigned}) + +""" + +_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\ +{snake_case_name}: _{pascal_case_name} = dataclasses.field( + default=_{pascal_case_name}.from_sarif(**{sarif_dict}), init=False, ) -\"\"\"{2}\"\"\" +\"\"\"{short_description}\"\"\" """ _CPP_RULE_TEMPLATE = """\ /** - * @brief {1} + * @brief {short_description} */ -{0}, +{name}, """ _RuleType = Mapping[str, Any] @@ -56,24 +75,62 @@ def _kebab_case_to_pascal_case(name: str) -> str: return "".join(word.capitalize() for word in name.split("-")) -def _format_rule_for_python(rule: _RuleType) -> str: - name = _kebab_case_to_snake_case(rule["name"]) +def _format_rule_for_python_class(rule: _RuleType) -> str: + pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) short_description = rule["short_description"]["text"] + message_template = rule["message_strings"]["default"]["text"] + field_names = [ + field_name + for _, field_name, _, _ in string.Formatter().parse(message_template) + if field_name is not None + ] + for field_name in field_names: + assert isinstance( + field_name, str + ), f"Unexpected field type {type(field_name)} from {field_name}. " + "Field name must be string.\nFull message template: {message_template}" + assert ( + not field_name.isnumeric() + ), f"Unexpected numeric field name {field_name}. " + "Only keyword name formatting is supported.\nFull message template: {message_template}" + message_arguments = ", ".join(field_names) + message_arguments_assigned = ", ".join( + [f"{field_name}={field_name}" for field_name in field_names] + ) + return _PY_RULE_CLASS_TEMPLATE.format( + pascal_case_name=pascal_case_name, + short_description=short_description, + message_template=repr(message_template), + message_arguments=message_arguments, + message_arguments_assigned=message_arguments_assigned, + ) + - return _PY_RULE_TEMPLATE.format(name, rule, short_description) +def _format_rule_for_python_field(rule: _RuleType) -> str: + snake_case_name = _kebab_case_to_snake_case(rule["name"]) + pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) + short_description = rule["short_description"]["text"] + + return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format( + snake_case_name=snake_case_name, + pascal_case_name=pascal_case_name, + sarif_dict=rule, + short_description=short_description, + ) def _format_rule_for_cpp(rule: _RuleType) -> str: name = f"k{_kebab_case_to_pascal_case(rule['name'])}" short_description = rule["short_description"]["text"] - return _CPP_RULE_TEMPLATE.format(name, short_description) + return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description) def gen_diagnostics_python( rules: Sequence[_RuleType], out_py_dir: str, template_dir: str ) -> None: - rule_lines = [_format_rule_for_python(rule) for rule in rules] + rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules] + rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules] fm = torchgen_utils.FileManager( install_dir=out_py_dir, template_dir=template_dir, dry_run=False @@ -83,7 +140,9 @@ def gen_diagnostics_python( "rules.py.in", lambda: { "generated_comment": _RULES_GENERATED_COMMENT, - "rules": textwrap.indent("\n".join(rule_lines), " " * 4), + "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT, + "rule_classes": "\n".join(rule_class_lines), + "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4), }, ) _lint_file(os.path.join(out_py_dir, "_rules.py")) diff --git a/tools/onnx/templates/rules.py.in b/tools/onnx/templates/rules.py.in index e29c202dc6a70..2137119d14c23 100644 --- a/tools/onnx/templates/rules.py.in +++ b/tools/onnx/templates/rules.py.in @@ -7,10 +7,14 @@ import dataclasses # flake8: noqa from torch.onnx._internal.diagnostics import infra +""" +${generated_rule_class_comment} +""" + +${rule_classes} @dataclasses.dataclass class _POERules(infra.RuleCollection): ${rules} - rules = _POERules() diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 79f97c4e9f30c..43118edb98bd5 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -443,6 +443,10 @@ def gen_pyi( "_to_functional_tensor": [ "def _to_functional_tensor(t: Tensor) -> Tensor: ..." ], + "_enable_functionalization": [ + "def _enable_functionalization(*, reapply_views: _bool = False): ..." + ], + "_disable_functionalization": ["def _disable_functionalization(): ..."], "range": [ "def range(start: Number, end: Number," " step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...".format( @@ -597,7 +601,7 @@ def gen_pyi( "def size(self, dim: _int) -> _int: ...", ], "stride": [ - "def stride(self) -> Tuple[_int]: ...", + "def stride(self) -> Tuple[_int, ...]: ...", "def stride(self, _int) -> _int: ...", ], "new_ones": [ @@ -722,7 +726,7 @@ def gen_pyi( binop += "_" out_suffix = "" unsorted_tensor_method_hints[binop].append( - "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode]{})" + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{})" " -> Tensor: ...".format(binop, out_suffix) ) for binop in ["add", "sub"]: @@ -732,7 +736,7 @@ def gen_pyi( binop += "_" out_suffix = "" unsorted_tensor_method_hints[binop].append( - "def {}(self, other: Union[Tensor, Number, torch.SymIntNode, torch.SymFloatNode], " + "def {}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], " "*, alpha: Optional[Number]=1{})" " -> Tensor: ...".format(binop, out_suffix) ) diff --git a/tools/render_junit.py b/tools/render_junit.py index 95c281d99d492..0d6effbd09063 100644 --- a/tools/render_junit.py +++ b/tools/render_junit.py @@ -12,10 +12,10 @@ TestCase, TestSuite, ) -except ImportError: +except ImportError as e: raise ImportError( "junitparser not found, please install with 'pip install junitparser'" - ) + ) from e try: import rich diff --git a/tools/stats/check_disabled_tests.py b/tools/stats/check_disabled_tests.py new file mode 100644 index 0000000000000..636af668a13d3 --- /dev/null +++ b/tools/stats/check_disabled_tests.py @@ -0,0 +1,277 @@ +import argparse +import json +import os +import xml.etree.ElementTree as ET +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Dict, Generator, Tuple + +from tools.stats.upload_stats_lib import ( + download_gha_artifacts, + download_s3_artifacts, + is_rerun_disabled_tests, + unzip, + upload_to_s3, +) +from tools.stats.upload_test_stats import process_xml_element + +TESTCASE_TAG = "testcase" +SEPARATOR = ";" + + +def process_report( + report: Path, +) -> Dict[str, Dict[str, int]]: + """ + Return a list of disabled tests that should be re-enabled and those that are still + flaky (failed or skipped) + """ + root = ET.parse(report) + + # All rerun tests from a report are grouped here: + # + # * Success test should be re-enable if it's green after rerunning in all platforms + # where it is currently disabled + # * Failures from pytest because pytest-flakefinder is used to run the same test + # multiple times, some could fails + # * Skipped tests from unittest + # + # We want to keep track of how many times the test fails (num_red) or passes (num_green) + all_tests: Dict[str, Dict[str, int]] = {} + + if not is_rerun_disabled_tests(root): + return all_tests + + for test_case in root.iter(TESTCASE_TAG): + parsed_test_case = process_xml_element(test_case) + + # Under --rerun-disabled-tests mode, a test is skipped when: + # * it's skipped explicitly inside PyToch code + # * it's skipped because it's a normal enabled test + # * or it's falky (num_red > 0 and num_green > 0) + # * or it's failing (num_red > 0 and num_green == 0) + # + # We care only about the latter two here + skipped = parsed_test_case.get("skipped", None) + if skipped and "num_red" not in skipped.get("message", ""): + continue + + name = parsed_test_case.get("name", "") + classname = parsed_test_case.get("classname", "") + filename = parsed_test_case.get("file", "") + + if not name or not classname or not filename: + continue + + # Check if the test is a failure + failure = parsed_test_case.get("failure", None) + + disabled_test_id = SEPARATOR.join([name, classname, filename]) + if disabled_test_id not in all_tests: + all_tests[disabled_test_id] = { + "num_green": 0, + "num_red": 0, + } + + # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's + # counted as a success. Otherwise, it's still flaky or failing + if skipped: + try: + stats = json.loads(skipped.get("message", "")) + except json.JSONDecodeError: + stats = {} + + all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0) + all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0) + elif failure: + # As a failure, increase the failure count + all_tests[disabled_test_id]["num_red"] += 1 + else: + all_tests[disabled_test_id]["num_green"] += 1 + + return all_tests + + +def get_test_reports( + repo: str, workflow_run_id: int, workflow_run_attempt: int +) -> Generator[Path, None, None]: + """ + Gather all the test reports from S3 and GHA. It is currently not possible to guess which + test reports are from rerun_disabled_tests workflow because the name doesn't include the + test config. So, all reports will need to be downloaded and examined + """ + with TemporaryDirectory() as temp_dir: + print("Using temporary directory:", temp_dir) + os.chdir(temp_dir) + + artifact_paths = download_s3_artifacts( + "test-reports", workflow_run_id, workflow_run_attempt + ) + for path in artifact_paths: + unzip(path) + + artifact_paths = download_gha_artifacts( + "test-report", workflow_run_id, workflow_run_attempt + ) + for path in artifact_paths: + unzip(path) + + for report in Path(".").glob("**/*.xml"): + yield report + + +def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]: + """ + Follow flaky bot convention here, if that changes, this will also need to be updated + """ + name, classname, filename = test_id.split(SEPARATOR) + return f"{name} (__main__.{classname})", name, classname, filename + + +def prepare_record( + workflow_id: int, + workflow_run_attempt: int, + name: str, + classname: str, + filename: str, + flaky: bool, + num_red: int = 0, + num_green: int = 0, +) -> Tuple[Any, Dict[str, Any]]: + """ + Prepare the record to save onto S3 + """ + key = ( + workflow_id, + workflow_run_attempt, + name, + classname, + filename, + ) + + record = { + "workflow_id": workflow_id, + "workflow_run_attempt": workflow_run_attempt, + "name": name, + "classname": classname, + "filename": filename, + "flaky": flaky, + "num_green": num_green, + "num_red": num_red, + } + + return key, record + + +def save_results( + workflow_id: int, + workflow_run_attempt: int, + all_tests: Dict[str, Dict[str, int]], +) -> None: + """ + Save the result to S3, so it can go to Rockset + """ + should_be_enabled_tests = { + name: stats + for name, stats in all_tests.items() + if "num_green" in stats + and stats["num_green"] + and "num_red" in stats + and stats["num_red"] == 0 + } + still_flaky_tests = { + name: stats + for name, stats in all_tests.items() + if name not in should_be_enabled_tests + } + + records = {} + for test_id, stats in all_tests.items(): + num_green = stats.get("num_green", 0) + num_red = stats.get("num_red", 0) + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + + key, record = prepare_record( + workflow_id=workflow_id, + workflow_run_attempt=workflow_run_attempt, + name=name, + classname=classname, + filename=filename, + flaky=test_id in still_flaky_tests, + num_green=num_green, + num_red=num_red, + ) + records[key] = record + + # Log the results + print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:") + for test_id, stats in should_be_enabled_tests.items(): + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + print(f" {disabled_test_name} from {filename}") + + print(f"The following {len(still_flaky_tests)} are still flaky:") + for test_id, stats in still_flaky_tests.items(): + num_green = stats.get("num_green", 0) + num_red = stats.get("num_red", 0) + + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + print( + f" {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}" + ) + + upload_to_s3( + workflow_id, + workflow_run_attempt, + "rerun_disabled_tests", + list(records.values()), + ) + + +def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None: + """ + Find the list of all disabled tests that should be re-enabled + """ + # Aggregated across all jobs + all_tests: Dict[str, Dict[str, int]] = {} + + for report in get_test_reports( + args.repo, args.workflow_run_id, args.workflow_run_attempt + ): + tests = process_report(report) + for name, stats in tests.items(): + if name not in all_tests: + all_tests[name] = stats.copy() + else: + all_tests[name]["num_green"] += stats.get("num_green", 0) + all_tests[name]["num_red"] += stats.get("num_red", 0) + + save_results( + workflow_run_id, + workflow_run_attempt, + all_tests, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3") + parser.add_argument( + "--workflow-run-id", + type=int, + required=True, + help="id of the workflow to get artifacts from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + parser.add_argument( + "--repo", + type=str, + required=True, + help="which GitHub repo this workflow run belongs to", + ) + + args = parser.parse_args() + main(args.repo, args.workflow_run_id, args.workflow_run_attempt) diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index 972d0dbea038b..b45979451507a 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -30,11 +30,22 @@ def get_per_process_cpu_info() -> List[Dict[str, Any]]: "cmd": " ".join(p.cmdline()), "cpu_percent": p.cpu_percent(), "rss_memory": p.memory_info().rss, - "uss_memory": p.memory_full_info().uss, } - if "pss" in p.memory_full_info(): - # only availiable in linux - info["pss_memory"] = p.memory_full_info().pss + + # https://psutil.readthedocs.io/en/latest/index.html?highlight=memory_full_info + # requires higher user privileges and could throw AccessDenied error, i.e. mac + try: + memory_full_info = p.memory_full_info() + + info["uss_memory"] = memory_full_info.uss + if "pss" in memory_full_info: + # only availiable in linux + info["pss_memory"] = memory_full_info.pss + + except psutil.AccessDenied as e: + # It's ok to skip this + pass + per_process_info.append(info) return per_process_info diff --git a/tools/stats/upload_artifacts.py b/tools/stats/upload_artifacts.py new file mode 100644 index 0000000000000..eb0fde7f38ac2 --- /dev/null +++ b/tools/stats/upload_artifacts.py @@ -0,0 +1,61 @@ +import argparse +import os +import re +from tempfile import TemporaryDirectory + +from tools.stats.upload_stats_lib import download_gha_artifacts, upload_file_to_s3 + +ARTIFACTS = [ + "sccache-stats", + "test-jsons", + "test-reports", + "usage-log", +] +BUCKET_NAME = "gha-artifacts" +FILENAME_REGEX = r"-runattempt\d+" + + +def get_artifacts(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None: + with TemporaryDirectory() as temp_dir: + print("Using temporary directory:", temp_dir) + os.chdir(temp_dir) + + for artifact in ARTIFACTS: + artifact_paths = download_gha_artifacts( + artifact, workflow_run_id, workflow_run_attempt + ) + + for artifact_path in artifact_paths: + # GHA artifact is named as follows: NAME-runattempt${{ github.run_attempt }}-SUFFIX.zip + # and we want remove the run_attempt to conform with the naming convention on S3, i.e. + # pytorch/pytorch/WORKFLOW_ID/RUN_ATTEMPT/artifact/NAME-SUFFIX.zip + s3_filename = re.sub(FILENAME_REGEX, "", artifact_path.name) + upload_file_to_s3( + file_name=str(artifact_path.resolve()), + bucket=BUCKET_NAME, + key=f"{repo}/{workflow_run_id}/{workflow_run_attempt}/artifact/{s3_filename}", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3") + parser.add_argument( + "--workflow-run-id", + type=int, + required=True, + help="id of the workflow to get artifacts from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + parser.add_argument( + "--repo", + type=str, + required=True, + help="which GitHub repo this workflow run belongs to", + ) + args = parser.parse_args() + get_artifacts(args.repo, args.workflow_run_id, args.workflow_run_attempt) diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index 1cba78f68da1e..c91075225a628 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -2,6 +2,7 @@ import io import json import os +import xml.etree.ElementTree as ET import zipfile from pathlib import Path from typing import Any, Dict, List @@ -12,6 +13,7 @@ PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch" S3_RESOURCE = boto3.resource("s3") +TARGET_WORKFLOW = "--rerun-disabled-tests" def _get_request_headers() -> Dict[str, str]: @@ -136,6 +138,22 @@ def upload_to_s3( print("Done!") +def upload_file_to_s3( + file_name: str, + bucket: str, + key: str, +) -> None: + """ + Upload a local file to S3 + """ + print(f"Upload {file_name} to s3://{bucket}/{key}") + boto3.client("s3").upload_file( + file_name, + bucket, + key, + ) + + def unzip(p: Path) -> None: """Unzip the provided zipfile to a similarly-named directory. @@ -149,3 +167,16 @@ def unzip(p: Path) -> None: with zipfile.ZipFile(p, "r") as zip: zip.extractall(unzipped_dir) + + +def is_rerun_disabled_tests(root: ET.ElementTree) -> bool: + """ + Check if the test report is coming from rerun_disabled_tests workflow + """ + skipped = root.find(".//*skipped") + # Need to check against None here, if not skipped doesn't work as expected + if skipped is None: + return False + + message = skipped.attrib.get("message", "") + return TARGET_WORKFLOW in message or "num_red" in message diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index 01647264705bb..23695933c704b 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -9,6 +9,7 @@ from tools.stats.upload_stats_lib import ( download_gha_artifacts, download_s3_artifacts, + is_rerun_disabled_tests, unzip, upload_to_s3, ) @@ -35,9 +36,18 @@ def parse_xml_report( job_id = get_job_id(report) print(f"Found job id: {job_id}") + test_cases: List[Dict[str, Any]] = [] + root = ET.parse(report) + # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops + # includes skipped messages multiple times (50 times by default). This slows down + # this script too much (O(n)) because it tries to gather all the stats. This should + # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun + # disabled test is only few MB, but will balloon up to a much bigger XML file after + # extracting from a dozen to few hundred MB + if is_rerun_disabled_tests(root): + return test_cases - test_cases = [] for test_case in root.iter(tag): case = process_xml_element(test_case) case["workflow_id"] = workflow_id @@ -118,10 +128,16 @@ def process_xml_element(element: ET.Element) -> Dict[str, Any]: def get_pytest_parallel_times() -> Dict[Any, Any]: - pytest_parallel_times = {} + pytest_parallel_times: Dict[Any, Any] = {} for report in Path(".").glob("**/python-pytest/**/*.xml"): invoking_file = report.parent.name + root = ET.parse(report) + # TODO: Skip test reports from rerun disabled tests, same reason as mentioned + # above + if is_rerun_disabled_tests(root): + continue + assert len(list(root.iter("testsuite"))) == 1 for test_suite in root.iter("testsuite"): pytest_parallel_times[ diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py index d58e2ccc90671..33f9fb293edc4 100644 --- a/tools/test/gen_oplist_test.py +++ b/tools/test/gen_oplist_test.py @@ -4,7 +4,7 @@ import unittest from unittest.mock import MagicMock -from gen_oplist import throw_if_any_op_includes_overloads +from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads class GenOplistTest(unittest.TestCase): diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index 8bcecbb26e32e..4a9585708890c 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -217,6 +217,14 @@ def setUp(self) -> None: loc=torchgen.model.Location(__file__, 1), valid_tags=set(), ) + ( + self.fragment_custom_native_function, + _, + ) = torchgen.model.NativeFunction.from_yaml( + {"func": "quantized_decomposed::func() -> bool"}, + loc=torchgen.model.Location(__file__, 1), + valid_tags=set(), + ) def test_default_namespace_schema_registration_code_valid(self) -> None: native_functions = [DEFAULT_NATIVE_FUNCTION] @@ -237,6 +245,23 @@ def test_custom_namespace_schema_registration_code_valid(self) -> None: TORCH_LIBRARY(custom, m) { m.def("func() -> bool", {}); +};""", + ) + + def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None: + """Sometimes we want to extend an existing namespace, for example quantized + namespace, which is already defined in native/quantized/library.cpp + """ + _, registrations = get_native_function_schema_registrations( + native_functions=[self.fragment_custom_native_function], + schema_selector=self.selector, + ) + self.assertEqual( + registrations, + """ +TORCH_LIBRARY_FRAGMENT(quantized_decomposed, m) { + m.def("func() -> bool", {}); + };""", ) diff --git a/tools/test/test_vulkan_codegen.py b/tools/test/test_vulkan_codegen.py new file mode 100644 index 0000000000000..8b0b4b3a13cde --- /dev/null +++ b/tools/test/test_vulkan_codegen.py @@ -0,0 +1,100 @@ +import os +import tempfile +import unittest + +from tools.gen_vulkan_glsl import GLSLGenerator +from yaml.constructor import ConstructorError + + +class TestGLSLCodegen(unittest.TestCase): + def test_assert_on_duplicate_key_yaml(self) -> None: + yaml_with_duplicate_keys = """ +conv2d_pw: + parameter_names_with_default_values: + TILE_SIZE_X: 1 + TILE_SIZE_Y: 1 + parameter_values: + - TILE_SIZE_X: 2 + TILE_SIZE_Y: 2 + - TILE_SIZE_X: 2 + TILE_SIZE_Y: 4 + - TILE_SIZE_X: 4 + TILE_SIZE_Y: 2 + - TILE_SIZE_X: 4 + TILE_SIZE_Y: 4 +conv2d_pw: + parameter_names_with_default_values: + - TILE_SIZE_X: 1 + - TILE_SIZE_Y: 1 + parameter_values: + - TILE_SIZE_X: 2 + TILE_SIZE_Y: 2 + - TILE_SIZE_X: 2 + TILE_SIZE_Y: 4 + - TILE_SIZE_X: 4 + TILE_SIZE_Y: 2 + - TILE_SIZE_X: 4 + TILE_SIZE_Y: 4 +""" + + generator = GLSLGenerator() # type: ignore[no-untyped-call] + with tempfile.NamedTemporaryFile(mode="w") as fp: + fp.write(yaml_with_duplicate_keys) + fp.flush() + with self.assertRaisesRegex( + ConstructorError, r"while constructing a mapping" + ): + generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call] + + def test_assert_keys_mismatch(self) -> None: + yaml_with_key_mismatch = """ +conv2d_pw: + parameter_names_with_default_values: + TILE_SIZE_X: 1 + TILE_SIZE_Y: 1 + parameter_values: + - TILE_SIZE_X: 2 + TILE_SIZE_Z: 2 +""" + + generator = GLSLGenerator() # type: ignore[no-untyped-call] + with tempfile.NamedTemporaryFile(mode="w") as fp: + fp.write(yaml_with_key_mismatch) + fp.flush() + with self.assertRaisesRegex(KeyError, r"Invalid keys {'TILE_SIZE_Z'}"): + generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call] + + def test_missing_key_default_val(self) -> None: + yaml_with_key_mismatch = """ +conv2d_pw: + parameter_names_with_default_values: + TILE_SIZE_X: 1 + TILE_SIZE_Y: 1 + parameter_values: + - TILE_SIZE_Y: 2 +""" + file_content = """ +x = $TILE_SIZE_X + $TILE_SIZE_Y +""" + + generator = GLSLGenerator() # type: ignore[no-untyped-call] + with tempfile.NamedTemporaryFile(mode="w") as fp: + fp.write(yaml_with_key_mismatch) + fp.flush() + generator.add_params_yaml(fp.name) # type: ignore[no-untyped-call] + with tempfile.TemporaryDirectory() as tmp_dir: + template_file_name = os.path.join(tmp_dir, "conv2d_pw.glslt") + with open(template_file_name, "w") as template_file: + template_file.write(file_content) + template_file.flush() + generator.generate(template_file.name, tmp_dir) # type: ignore[no-untyped-call] + file_name_1 = os.path.join(tmp_dir, "conv2d_pw_1x1.glsl") + file_name_2 = os.path.join(tmp_dir, "conv2d_pw_1x2.glsl") + self.assertTrue(os.path.exists(file_name_1)) + self.assertTrue(os.path.exists(file_name_2)) + with open(file_name_1, "r") as f: + contents = f.read() + self.assertTrue("1 + 1" in contents) + with open(file_name_2, "r") as f: + contents = f.read() + self.assertTrue("1 + 2" in contents) diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 3b33281781894..950d686d8dacc 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -5,7 +5,7 @@ from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests -NUM_PROCS = 2 +NUM_PROCS = 1 if os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1" else 2 class ShardJob: @@ -45,14 +45,14 @@ def calculate_shards( ] for test in sorted_tests: if must_serial(test): - min_sharded_job = sorted(sharded_jobs, key=lambda j: j.get_total_time())[0] + min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time()) min_sharded_job.serial.append(test) else: - min_sharded_job = sorted(sharded_jobs, key=lambda j: j.get_total_time())[0] + min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time()) min_sharded_job.parallel.append(test) # Round robin the unknown jobs starting with the smallest shard - index = sorted(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())[0] + index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time()) for test in unknown_tests: sharded_jobs[index].serial.append(test) index = (index + 1) % num_shards diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 70248d1325274..e2c2d554cdb0e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -169,19 +169,7 @@ class Future(object): def _jit_set_num_profiled_runs(num: _size) -> _size: ... -class SymIntNode(object): - def get_pyobj(self) -> Any: ... - - @staticmethod - def new_symint(obj) -> SymIntNode: ... - -class SymFloatNode(object): - def get_pyobj(self) -> Any: ... - - @staticmethod - def new_symfloat(obj) -> SymFloatNode: ... - -# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h +# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h class MobileOptimizerType: ... @@ -190,6 +178,7 @@ INSERT_FOLD_PREPACK_OPS: MobileOptimizerType REMOVE_DROPOUT: MobileOptimizerType FUSE_ADD_RELU: MobileOptimizerType HOIST_CONV_PACKED_PARAMS: MobileOptimizerType +VULKAN_AUTOMATIC_GPU_TRANSFER: MobileOptimizerType def fork(*args: Any, **kwargs: Any) -> Future: ... def wait(fut: Future) -> Any: ... @@ -227,6 +216,7 @@ def _clone_module_with_class(module: 'torch.jit.ScriptModule', ignored_methods: List[AnyStr], ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule', + optimization_blocklist: Set[MobileOptimizerType], preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule', preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... @@ -525,8 +515,8 @@ class Value: # Defined in torch/csrc/jit/ir/ir.h class Block: - def inputs(self) -> List[Value]: ... - def outputs(self) -> List[Value]: ... + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... def nodes(self) -> Iterator[Node]: ... def paramNode(self) -> Node: ... def returnNode(self) -> Node: ... @@ -540,11 +530,11 @@ class Node: def __getitem__(self, key: str) -> Any: ... def schema(self) -> str: ... def input(self) -> Value: ... - def inputs(self) -> List[Value]: ... + def inputs(self) -> Iterator[Value]: ... def inputsAt(self, idx: _int) -> Value: ... def inputsSize(self) -> _int: ... def output(self) -> Value: ... - def outputs(self) -> List[Value]: ... + def outputs(self) -> Iterator[Value]: ... def outputsAt(self, idx: _int) -> Value: ... def outputsSize(self) -> _int: ... def hasMultipleOutputs(self) -> _bool: ... @@ -620,12 +610,12 @@ class Node: # Defined in torch/torch/csrc/jit/ir/ir.h class Graph: - def inputs(self) -> List[Value]: ... - def outputs(self) -> List[Value]: ... + def inputs(self) -> Iterator[Value]: ... + def outputs(self) -> Iterator[Value]: ... def nodes(self) -> Iterator[Node]: ... def param_node(self) -> Node: ... def return_node(self) -> Node: ... - def addInput(self, name: str) -> Value: ... + def addInput(self, name: str = "") -> Value: ... def eraseInput(self, i: _int) -> None: ... def registerOutput(self, n: Value) -> _int: ... def eraseOutput(self, i: _int) -> None: ... @@ -641,6 +631,7 @@ class Graph: def insertPoint(self) -> Node: ... def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ... def makeMultiOutputIntoTuple(self) -> None: ... + def copy(self) -> Graph: ... ... @@ -831,6 +822,8 @@ def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash +def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP +def _set_sdp_use_mem_efficient(arg: _bool) -> None: ... # THPModule_setSDPUseMemEfficient def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn @@ -856,6 +849,8 @@ def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... def _meta_in_tls_dispatch_include() -> _bool: ... +def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... +def _conv_determine_backend_memory_format(input: Tensor, weight: Tensor, backend: ConvBackend) -> memory_format: ... def _has_storage(x: Tensor) -> _bool: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... # NB: There is no Capsule type in typing, see @@ -884,12 +879,14 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ... def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ... -def _is_mps_available() -> _bool: ... class _LinalgBackend: Default: _LinalgBackend Cusolver: _LinalgBackend Magma: _LinalgBackend +class ConvBackend(Enum): + ... + # Defined in `valgrind.h` and `callgrind.h` respecitively. def _valgrind_supported_platform() -> _bool: ... # NVALGRIND def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT @@ -923,6 +920,8 @@ def autocast_increment_nesting() -> _int: ... def autocast_decrement_nesting() -> _int: ... def is_autocast_cache_enabled() -> _bool: ... def set_autocast_cache_enabled(enabled: _bool) -> None: ... +def _set_autograd_function_extension_enabled(enabled: _bool) -> None: ... +def _is_autograd_function_extension_enabled() -> _bool: ... def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ... def is_anomaly_enabled() -> _bool: ... def is_anomaly_check_nan_enabled() -> _bool: ... @@ -975,11 +974,14 @@ class AggregationType(Enum): AVG = 1 class FileCheck(object): - # TODO (add more FileCheck signature) - def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... def run(self, test_string: str) -> None: ... def check(self, test_string: str) -> 'FileCheck': ... def check_not(self, test_string: str) -> 'FileCheck': ... + def check_same(self, test_string: str) -> 'FileCheck': ... + def check_next(self, test_string: str) -> 'FileCheck': ... + def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ... + def check_dag(self, test_string: str) -> 'FileCheck': ... + def check_source_highlighted(self, test_string: str) -> 'FileCheck': ... ... # Defined in torch/csrc/jit/python/init.cpp @@ -1015,6 +1017,9 @@ def _jit_pass_lint(Graph) -> None: ... # Defined in torch/csrc/jit/python/python_custome_class.cpp def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ... +# Defined in torch/csrc/Module.cpp +def _rename_privateuse1_backend(backend: str) -> None: ... + # Defined in torch/csrc/Generator.cpp class Generator(object): device: _device @@ -1037,7 +1042,6 @@ class _DispatchModule: def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ... def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ... def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ... - def impl_tt_t(self, name: str, dispatch: str, debug: str = "impl_tt_t") -> _DispatchModule: ... def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ... def define(self, schema: str, alias: str = "") -> _DispatchModule: ... def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ... @@ -1055,10 +1059,13 @@ def _dispatch_find_dangling_impls() -> List[str]: ... def _dispatch_get_all_op_names() -> List[str]: ... def _dispatch_tls_set_dispatch_key_excluded(dispatch: _dispatchkey, val: _bool) -> None: ... def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ... +def _dispatch_tls_set_dispatch_key_included(dispatch: _dispatchkey, val: _bool) -> None: ... +def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ... def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ... def _dispatch_key_name(dispatch: _dispatchkey) -> str: ... def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ... def _dispatch_num_backends() -> _int: ... +def _functionalization_reapply_views_tls() -> _bool: ... class DispatchKey(Enum): ${dispatch_key_hints} @@ -1164,6 +1171,11 @@ class _TensorBase(metaclass=_TensorMeta): # Defined in torch/csrc/multiprocessing/init.cpp def _multiprocessing_init() -> None: ... +# Defined in torch/csrc/mps/Module.cpp +def _mps_synchronize() -> None: ... +def _mps_init() -> None: ... +def _is_mps_available() -> _bool: ... + # Defined in torch/csrc/cuda/Module.cpp def _cuda_getCurrentStream(device: _int) -> _int: ... def _cuda_getCurrentRawStream(device: _int) -> _int: ... @@ -1194,6 +1206,13 @@ def _cuda_resetPeakMemoryStats(device: _int) -> None: ... def _cuda_memorySnapshot() -> Dict[str, Any]: ... def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ... def _cuda_getAllocatorBackend() -> str: ... + +class _cuda_CUDAAllocator: + ... + +def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ... +def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ... +def _cuda_getAllocator() -> _cuda_CUDAAllocator: ... def _cuda_lock_mutex() -> None: ... def _cuda_unlock_mutex() -> None: ... def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... @@ -1293,6 +1312,9 @@ class _CUDAGraph: def replay(self) -> None: ... def reset(self) -> None: ... def pool(self) -> Tuple[_int, _int]: ... + def enable_debug_mode(self) -> None: ... + def debug_dump(self, + debug_path: str) -> None: ... def _cuda_isCurrentStreamCapturing() -> _bool: ... @@ -1338,6 +1360,8 @@ class JitType: def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ... def kind(self) -> str: ... def scalarType(self) -> Optional[str]: ... + def getElementType(self) -> JitType: ... + def dtype(self) -> Optional[_dtype]: ... class InferredType: def __init__(self, arg: Union[JitType, str]): ... @@ -1503,3 +1527,6 @@ def _current_graph_task_id() -> _int: ... class _OutOfMemoryError: pass + +class _DistBackendError(RuntimeError): + pass diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index aad37d6a8c5ae..f16a8ec362f50 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,8 +1,9 @@ from datetime import timedelta from enum import Enum -from typing import Optional, List, Any, Tuple, overload, Union +from typing import Any, Dict, List, Optional, overload, Tuple, Union from torch import Tensor +from torch.futures import Future # This module is defined in torch/csrc/distributed/c10d/init.cpp @@ -32,13 +33,34 @@ class Reducer: self, params: List[Tensor], bucket_indices: List[List[int]], + per_bucket_size_limits: List[int], process_group: ProcessGroup, - expect_sparse_gradients: List[bool], - bucket_bytes_cap: int, - find_unused_parameters: bool, - gradient_as_bucket_view: bool, + expect_sparse_gradients: List[bool] = ..., + bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp + find_unused_parameters: bool = ..., + gradient_as_bucket_view: bool = ..., + param_to_name_mapping: Dict[int, str] = ..., + first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp ): ... - ... + def prepare_for_forward(self) -> None: ... + def prepare_for_backward(self, output: List[Tensor]) -> None: ... + def get_backward_stats(self) -> List[int]: ... + def _install_post_backward_futures(self, futures: List[Future]) -> None: ... + def _rebuild_buckets(self) -> bool: ... + def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ... + def _push_all_rebuilt_params(self) -> None: ... + def _set_forward_pass_work_handle( + self, work: Work, use_static_world_size: bool + ): ... + def _get_local_used_map(self) -> Tensor: ... + def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ... + def _set_static_graph(self) -> None: ... + def _run_comm_hook(self, bucket: GradBucket) -> Future: ... + def set_logger(self, logger: Logger) -> None: ... + +class DDPLoggingData: + strs_map: Dict[str, str] + ints_map: Dict[str, int] class Logger: def __init__(self, reducer: Reducer): ... @@ -49,8 +71,14 @@ class Logger: output_device: int, broadcast_buffers: bool, has_sync_bn: bool, + static_graph: bool, ): ... - ... + def set_runtime_stats_and_log(self) -> None: ... + def set_error_and_log(self, error: str) -> None: ... + def _get_ddp_logging_data(self) -> DDPLoggingData: ... + def _set_comm_hook_name(self, comm_hook: str) -> None: ... + def _set_uneven_input_join(self) -> None: ... + def _set_static_graph(self) -> None: ... def get_debug_level(): ... def set_debug_level(): ... @@ -63,7 +91,8 @@ class DebugLevel(Enum): class ReduceOp: - # note(crcrpar): These values are populated from Kind + def __init__(self, op: "RedOpType"): ... + SUM = ... PRODUCT = ... MIN = ... @@ -74,7 +103,7 @@ class ReduceOp: PREMUL_SUM = ... UNUSED = ... - class Kind(Enum): ... + class RedOpType(Enum): ... class BroadcastOptions: rootRank: int @@ -119,7 +148,9 @@ class Store: def set(self, key: str, value: str): ... def get(self, key: str) -> bytes: ... def add(self, key: str, value: int) -> int: ... - def compare_set(self, key: str, expected_value: str, desired_value: str) -> bytes: ... + def compare_set( + self, key: str, expected_value: str, desired_value: str + ) -> bytes: ... def delete_key(self, key: str) -> bool: ... def num_keys(self) -> int: ... def set_timeout(self, timeout: timedelta): ... @@ -143,7 +174,7 @@ class TCPStore(Store): is_master: bool = ..., timeout: timedelta = ..., wait_for_workers: bool = ..., - multi_tenant: bool = ... + multi_tenant: bool = ..., ): ... @property def host(self) -> str: ... @@ -168,6 +199,7 @@ class Work: class ProcessGroup: class Options: ... + def __init__(self): ... def rank(self) -> int: ... def size(self) -> int: ... @@ -236,7 +268,7 @@ class ProcessGroup: self, output: Tensor, input: Tensor, - opts = AllGatherOptions(), + opts=AllGatherOptions(), ) -> Work: ... def allgather_coalesced( self, @@ -344,6 +376,7 @@ def _round_robin_process_groups( class ProcessGroupGloo(ProcessGroup): class Device: ... class Options: ... + def __init__( self, store: Store, @@ -359,16 +392,12 @@ class ProcessGroupGloo(ProcessGroup): ... class _ProcessGroupWrapper(ProcessGroup): - def __init__( - self, - pg: ProcessGroup, - gloo_pg: ProcessGroupGloo - ): ... + def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ... wrapped_pg: ProcessGroup - class ProcessGroupNCCL(ProcessGroup): class Options: ... + def __init__( self, store: Store, @@ -403,9 +432,9 @@ class ProcessGroupMPI(ProcessGroup): def _compute_bucket_assignment_by_size( tensors: List[Tensor], - bucket_size: int, - expect_sparse_gradient: List[bool], - tensor_indices: List[int], + bucket_size_limits: List[int], + expect_sparse_gradient: List[bool] = ..., + tensor_indices: List[int] = ..., ) -> Tuple[List[List[int]], List[int]]: ... def _broadcast_coalesced( process_group: ProcessGroup, diff --git a/torch/_C/_dynamo/__init__.pyi b/torch/_C/_dynamo/__init__.pyi new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi new file mode 100644 index 0000000000000..3428342750cc0 --- /dev/null +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -0,0 +1,10 @@ +import types +from typing import Union +from torch._dynamo.types import DynamoCallback, DynamoGuardHook + +def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... +def reset_code(code: types.CodeType) -> None: ... +def unsupported(obj1: object, obj2: object) -> object: ... +def skip_code(code: types.CodeType) -> None: ... +def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ... +def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index 6ab5f91b78f1e..d07c39d9413ac 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -1,4 +1,5 @@ from torch import Tensor +from enum import Enum # Defined in torch/csrc/functorch/init.cpp @@ -10,3 +11,39 @@ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ... def is_gradtrackingtensor(tensor: Tensor) -> bool: ... def maybe_get_bdim(tensor: Tensor) -> int: ... def maybe_get_level(tensor: Tensor) -> int: ... +def unwrap_if_dead(tensor: Tensor) -> Tensor: ... +def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... +def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ... + +def set_autograd_function_allowed(allowed: bool) -> None: ... +def get_autograd_function_allowed() -> bool: ... + +# Defined in aten/src/ATen/functorch/Interpreter.h +class TransformType(Enum): + Torch: TransformType = ... + Vmap: TransformType = ... + Grad: TransformType = ... + Jvp: TransformType = ... + Functionalize: TransformType = ... + +class CInterpreter: + def key(self) -> TransformType: ... + def level(self) -> int: ... + +class CGradInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def lift(self, Tensor) -> Tensor: ... + def prevGradMode(self) -> bool: ... + +class CVmapInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def batchSize(self) -> int: ... + +class DynamicLayer: + pass + +def peek_interpreter_stack() -> CInterpreter: ... +def pop_dynamic_layer_stack() -> DynamicLayer: ... +def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ... diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index 6d6c2893f4554..4a1fe23cec614 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union from torch._C import device, dtype, layout +from typing_extensions import Literal + # defined in torch/csrc/profiler/python/init.cpp class RecordScope(Enum): @@ -38,11 +40,12 @@ class ProfilerActivity(Enum): CUDA = ... class _EventType(Enum): - Allocation = ... + TorchOp = ... Backend = ... + Allocation = ... + OutOfMemory = ... PyCall = ... PyCCall = ... - TorchOp = ... Kineto = ... class _ExperimentalConfig: @@ -71,6 +74,8 @@ class _ProfilerEvent: start_tid: int start_time_ns: int children: List[_ProfilerEvent] + + # TODO(robieta): remove in favor of `self.typed` extra_fields: Union[ _ExtraFields_TorchOp, _ExtraFields_Backend, @@ -81,6 +86,18 @@ class _ProfilerEvent: _ExtraFields_Kineto, ] + @property + def typed( + self, + ) -> Union[ + Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp], + Tuple[Literal[_EventType.Backend], _ExtraFields_Backend], + Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation], + Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory], + Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall], + Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall], + Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto], + ]: ... @property def name(self) -> str: ... @property @@ -96,30 +113,34 @@ class _ProfilerEvent: @property def duration_time_ns(self) -> int: ... -class _Inputs: - shapes: List[List[int]] - dtypes: List[str] - strides: List[List[int]] - ivalues: List[Union[int, float, bool, complex]] - tensor_metadata: List[Optional[_TensorMetadata]] - class _TensorMetadata: impl_ptr: Optional[int] storage_data_ptr: Optional[int] id: Optional[int] + @property + def allocation_id(self) -> Optional[int]: ... @property def layout(self) -> layout: ... @property def device(self) -> device: ... @property def dtype(self) -> dtype: ... + @property + def sizes(self) -> List[int]: ... + @property + def strides(self) -> List[int]: ... + +Scalar = Union[int, float, bool, complex] +Input = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]] class _ExtraFields_TorchOp: - inputs: _Inputs + name: str sequence_number: int allow_tf32_cublas: bool + @property + def inputs(self) -> List[Input]: ... @property def scope(self) -> RecordScope: ... @@ -132,6 +153,8 @@ class _ExtraFields_Allocation: total_allocated: int total_reserved: int + @property + def allocation_id(self) -> Optional[int]: ... @property def device(self) -> device: ... @@ -146,17 +169,46 @@ class _PyFrameState: class _NNModuleInfo: @property - def params(self) -> List[Tuple[str, int]]: ... + def self_ptr(self) -> int: ... + @property + def cls_ptr(self) -> int: ... @property def cls_name(self) -> str: ... + @property + def parameters( + self, + ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ... + +class _OptimizerInfo: + @property + def parameters( + self, + ) -> List[ + Tuple[ + # Parameter + _TensorMetadata, + # + # Gradient (if present during optimizer.step()) + Optional[_TensorMetadata], + # + # Optimizer state for Parameter as (name, tensor) pairs + List[Tuple[str, _TensorMetadata]], + ] + ]: ... class _ExtraFields_PyCCall: - callsite: _PyFrameState - caller: _PyFrameState - module: Optional[_NNModuleInfo] + @property + def caller(self) -> _PyFrameState: ... class _ExtraFields_PyCall: - caller: _PyFrameState + @property + def callsite(self) -> _PyFrameState: ... + @property + def caller(self) -> _PyFrameState: ... + @property + def module(self) -> Optional[_NNModuleInfo]: ... + @property + def optimizer(self) -> Optional[_OptimizerInfo]: ... class _ExtraFields_Kineto: ... diff --git a/torch/__init__.py b/torch/__init__.py index 8a824642ab57d..c8543057c7474 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2,7 +2,7 @@ r""" The torch package contains data structures for multi-dimensional tensors and defines mathematical operations over these tensors. -Additionally, it provides many utilities for efficient serializing of +Additionally, it provides many utilities for efficient serialization of Tensors and arbitrary types, and other useful utilities. It has a CUDA counterpart, that enables you to run your tensor computations @@ -29,7 +29,7 @@ from ._six import string_classes as _string_classes -from typing import Set, Type, TYPE_CHECKING, Union, Callable, Any +from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union import builtins __all__ = [ @@ -47,7 +47,8 @@ 'is_deterministic_algorithms_warn_only_enabled', 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', 'set_float32_matmul_precision', 'get_float32_matmul_precision', - 'set_warn_always', 'is_warn_always_enabled', + 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', + 'compile', ] ################################################################################ @@ -141,20 +142,46 @@ kernel32.SetErrorMode(prev_error_mode) +def _preload_cuda_deps(): + """ Preloads cudnn/cublas deps if they could not be found otherwise """ + # Should only be called on Linux if default path resolution have failed + assert platform.system() == 'Linux', 'Should only be called on Linux' + for path in sys.path: + nvidia_path = os.path.join(path, 'nvidia') + if not os.path.exists(nvidia_path): + continue + cublas_path = os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.11') + cudnn_path = os.path.join(nvidia_path, 'cudnn', 'lib', 'libcudnn.so.8') + if not os.path.exists(cublas_path) or not os.path.exists(cudnn_path): + continue + break + + ctypes.CDLL(cublas_path) + ctypes.CDLL(cudnn_path) + + # See Note [Global dependencies] def _load_global_deps(): - if platform.system() == 'Windows' or sys.executable == 'torch_deploy': + if sys.executable == 'torch_deploy' or platform.system() == 'Windows': return lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so') here = os.path.abspath(__file__) lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name) - ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + try: + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + except OSError as err: + # Can only happen of wheel with cublas as PYPI deps + # As PyTorch is not purelib, but nvidia-cublas-cu11 is + if 'libcublas.so.11' not in err.args[0]: + raise err + _preload_cuda_deps() + ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \ - platform.system() != 'Windows': + (sys.executable == "torch_deploy" or platform.system() != 'Windows'): # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a # few circumstances: # @@ -196,6 +223,92 @@ def _load_global_deps(): if TYPE_CHECKING: import torch._C as _C +class SymInt: + """ + Like an int (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + """ + + def __init__(self, node): + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __bool__(self): + return self.node.bool_() + + def __int__(self): + return self.node.int_() + + # Magic methods installed by torch.fx.experimental.symbolic_shapes + + def __eq__(self, other: object) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __sym_float__(self): + raise AssertionError("type stub not overridden") + + def __repr__(self): + return str(self.node) + + # For BC; direct access of node is OK too + def get_pyobj(self): + return self.node + +class SymFloat: + """ + Like an float (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + """ + + def __init__(self, node): + from torch.fx.experimental.symbolic_shapes import SymNode + assert isinstance(node, SymNode) + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __bool__(self): + return self.node.bool_() + + # Magic methods installed by torch.fx.experimental.symbolic_shapes + + def __eq__(self, other: object) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __repr__(self): + return self.node.str() + + # For BC; direct access of node is OK too + def get_pyobj(self): + return self.node + # Check to see if we can load C extensions, and if not provide some guidance # on what the problem might be. try: @@ -647,7 +760,7 @@ def is_warn_always_enabled(): ################################################################################ from ._tensor import Tensor -from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage +from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal # NOTE: New Storage classes should never be added. When adding a new # dtype, use torch.storage.TypedStorage directly. @@ -655,86 +768,171 @@ def is_warn_always_enabled(): class ByteStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.uint8 class DoubleStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.double class FloatStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.float class HalfStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.half class LongStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.long class IntStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int class ShortStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.short class CharStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.int8 class BoolStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bool class BFloat16Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.bfloat16 class ComplexDoubleStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cdouble class ComplexFloatStorage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.cfloat class QUInt8Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint8 class QInt8Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.qint8 class QInt32Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.qint32 class QUInt4x2Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint4x2 class QUInt2x4Storage(_LegacyStorage): @classproperty def dtype(self): + _warn_typed_storage_removal() + return self._dtype + + @classproperty + def _dtype(self): return torch.quint2x4 _storage_classes = { @@ -758,7 +956,7 @@ def dtype(self): ################################################################################ def manager_path(): - if platform.system() == 'Windows' or sys.executable == 'torch_deploy': + if sys.executable == 'torch_deploy' or platform.system() == 'Windows': return b"" path = get_file_path('torch', 'bin', 'torch_shm_manager') prepare_multiprocessing_environment(get_file_path('torch')) @@ -941,6 +1139,74 @@ def compiled_with_cxx11_abi(): lstsq, ) +def compile(model: Optional[Callable] = None, *, + fullgraph: builtins.bool = False, + dynamic: builtins.bool = False, + backend: Union[str, Callable] = "inductor", + mode: Union[str, None] = None, + passes: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, + **kwargs) -> Callable: + """ + Optimizes given model/function using Dynamo and specified backend + + Args: + model (Callable): Module/function to optimize + fullgraph (bool): Whether it is ok to break model into several subgraphs + dynamic (bool): Use dynamic shape tracing + backend (str or Callable): backend to be used + mode (str): Can be either "default", "reduce-overhead" or "max-autotune" + passes (dict): A dictionary of passes to the backend. Passes currently recognized by inductor backend: + - static-memory + - matmul-tune + - matmul-padding + - triton-autotune + - triton-bmm + - triton-mm + - triton-convolution + - rematerialize-threshold + - rematerialize-acc-threshold + + Example:: + + @torch.compile(passes={"matmul-padding": True}, fullgraph=True) + def foo(x): + return torch.sin(x) + torch.cos(x) + + """ + _C._log_api_usage_once("torch.compile") + # Decorator mode + if model is None: + def fn(model: Callable): + if model is None: + raise RuntimeError("Model can't be None") + return compile(model, + fullgraph=fullgraph, + dynamic=dynamic, + backend=backend, + mode=mode, + passes=passes, + **kwargs) + return fn + + import torch._dynamo + from torch._dynamo.eval_frame import lookup_backend + from torch._inductor.config import InductorConfigContext + if mode is not None and passes is not None: + raise RuntimeError("Either mode or passes can be specified, but both can't be specified at the same time.") + if mode is None and passes is None: + mode = "default" + if backend == "inductor": + compile_fn = lookup_backend(backend) + cm = InductorConfigContext(mode if mode is not None else passes) + + def _compile_fn(model_, inputs_): + with cm: + return compile_fn(model_, inputs_) + + _compile_fn._torchdynamo_orig_callable = compile_fn # type: ignore[attr-defined] + backend = _compile_fn + return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, **kwargs)(model) + def _register_device_module(device_type, module): r"""Register an external runtime module of the specific :attr:`device_type` @@ -961,13 +1227,15 @@ def _register_device_module(device_type, module): # expose return_types from . import return_types -if sys.executable != 'torch_deploy' and os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0": - from . import library - if not TYPE_CHECKING: - from . import _meta_registrations +from . import library +if not TYPE_CHECKING: + from . import _meta_registrations # Enable CUDA Sanitizer if 'TORCH_CUDA_SANITIZER' in os.environ: import torch.cuda._sanitizer as csan csan.enable_cuda_sanitizer() + +# Populate magic methods on SymInt and SymFloat +import torch.fx.experimental.symbolic_shapes diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 2dcda014cea30..d50f33933da49 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -26,16 +26,34 @@ pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"] meta_table = global_decomposition_table["meta"] -meta_lib = torch.library.Library("aten", "IMPL", "Meta") - -# decompositions which have been disabled as meta kernel implementations, -# usually due to mismatching strides, aliasing, or other inconsistent property -_disabled_meta_decomps = set() +def _add_op_to_registry(registry, op, fn): + """ + This is an internal API for adding an op to the decomposition table. -def register_decomposition( - aten_op, registry=None, *, type="post_autograd", disable_meta: bool = False -): + If op is OpOverload, it will be added to the registry directly. + If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. + """ + overloads = [] + if isinstance(op, OpOverload): + overloads.append(op) + else: + assert isinstance(op, OpOverloadPacket) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) + + for op_overload in overloads: + if op_overload in registry: + raise RuntimeError(f"duplicate registrations for {op_overload}") + + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out, e.g aten.add.float_int + if torch._C._dispatch_has_kernel(op_overload.name()): + registry[op_overload] = fn + + +def register_decomposition(aten_op, registry=None, *, type="post_autograd"): """ A decorator to register a function as a decomposition to the Python decomposition table. Use it like this:: @@ -52,9 +70,8 @@ def clamp_min(x): autograd) and not just backend tracing, where we then need to know if a decomposition can be used to simulate a transform. - By default, if the decomposition is for an operator that doesn't have - a Meta implementation, we will register it to the dispatcher. Use - `disable_meta` to disable this behavior. + By default, we also will register it to the Meta key of dispatcher, + and replace the c++ Meta implementation if there is already one. """ assert type in {"post_autograd", "pre_autograd", "meta"} @@ -106,62 +123,11 @@ def _fn(*args, **kwargs): if registry is None: registry = global_decomposition_table[type] - def add_op_to_table(aten_op): - overloads = [] - if isinstance(aten_op, OpOverload): - overloads.append(aten_op) - else: - assert isinstance(aten_op, OpOverloadPacket) - for ol in aten_op.overloads(): - overloads.append(getattr(aten_op, ol)) - for op_overload in overloads: - if op_overload in registry: - raise RuntimeError(f"duplicate registrations for {op_overload}") - registry[op_overload] = fn - op_overload.py_impl(torch._C.DispatchKey.Meta)(fn) - # TODO: factor this logic into OpOverload or Library API - name = op_overload._schema.name - if op_overload._schema.overload_name: - name += "." + op_overload._schema.overload_name - - if disable_meta: - global _disabled_meta_decomps - _disabled_meta_decomps.add(op_overload) - - if ( - not disable_meta - # TorchScript dumps a bunch of extra nonsense overloads - # which don't have corresponding dispatcher entries, we need - # to filter those out - and torch._C._dispatch_has_kernel(name) - # Don't register a python meta kernel to any operator that has - # should already work with meta tensors today. - # We can check that by seeing if the "computed table" for the operator - # has a registration to Meta; - # either through a direct registration, or an indirect one through - # an alias dispatch key (e.g. CompositeImplicitAutograd) - and not torch._C._dispatch_has_computed_kernel_for_dispatch_key( - name, "Meta" - ) - ): - if any( - a.alias_info is not None and not a.alias_info.is_write - for a in op_overload._schema.arguments - ): - raise RuntimeError( - f""" -Attempting to register a python meta kernel for a view operator: {str(op_overload)}. -We shouldn't do this, because the output will report as not having aliased storages. -All view ops have meta kernels in C++ today, so we should use those instead. - -If you're registering an operator through the `@register_decomposition` decorator, -Please set `disable_meta=True`. - """ - ) - meta_lib.impl(op_overload, fn) + def register(op): + _add_op_to_registry(registry, op, fn) # To handle allowing multiple aten_ops at once - tree_map(add_op_to_table, aten_op) + tree_map(register, aten_op) return fn return decomposition_decorator diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 4f61dc9b26f8a..1a8335dc292a1 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4,15 +4,22 @@ from enum import Enum from functools import partial, reduce from itertools import product -from typing import Callable, cast, Iterable, List, Optional, Tuple +from typing import Callable, cast, Iterable, List, Optional, Tuple, Union import torch +import torch._prims as prims import torch._prims_common as utils import torch.nn.functional as F from torch import Tensor from torch._decomp import register_decomposition -from torch._prims_common import NumberType, TensorLike, TensorSequenceType -from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper +from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) +from torch.fx.experimental.symbolic_shapes import guard_int, sym_float, sym_int from torch.utils._pytree import tree_flatten, tree_map DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] @@ -110,19 +117,6 @@ def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) -@register_decomposition(aten.elu) -@pw_cast_for_opmath -def elu( - self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1 -) -> Tensor: - negcoef = alpha * scale - poscoef = scale - negiptcoef = input_scale - return torch.where( - self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef - ) - - @register_decomposition(aten.elu_backward) @pw_cast_for_opmath def elu_backward( @@ -696,7 +690,12 @@ def _softmax_backward_data( grad_input = new_grad_output - output * torch.sum( new_grad_output, dim=dim, keepdim=True ) - return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() @register_decomposition(aten._log_softmax_backward_data) @@ -912,9 +911,17 @@ def check_positive(param, param_name, strict=True): @register_decomposition(aten.native_dropout_backward) -@pw_cast_for_opmath def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): - return grad_output * (mask.type_as(grad_output) * scale) + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! + # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r @register_decomposition(aten.unfold_backward) @@ -1025,22 +1032,19 @@ def embedding( sparse: bool = False, ) -> Tensor: assert weight.dim() == 2, "'weight' must be 2-D" - # TODO: Assert not ported over yet - # auto indices_arg = TensorArg(indices, "indices", 1); - # checkScalarTypes("embedding", indices_arg, {kLong, kInt}); - - if indices.dim() == 1: - return weight.index_select(0, indices) - - size = list(indices.shape) - for d in weight.shape[1:]: - size.append(d) - - return weight.index_select(0, indices.reshape(-1)).view(size) + # Nb. scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] -# TODO: Correct the type promotion semantics @register_decomposition(aten.embedding_dense_backward) +@pw_cast_for_opmath def embedding_dense_backward( grad_output: Tensor, indices: Tensor, @@ -1048,22 +1052,20 @@ def embedding_dense_backward( padding_idx: int, scale_grad_by_freq: bool, ): - numel = indices.numel() - grad = grad_output.reshape(numel, grad_output.size(-1)) - grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1])) - indices_rank1 = indices.reshape(numel) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] if scale_grad_by_freq: counts = indices.new_zeros((num_weights,)) - ones = indices.new_ones((numel,)) - counts = counts.index_put([indices_rank1], ones, accumulate=True) - grad_weights_scale = counts[indices_rank1] - grad = grad / grad_weights_scale.unsqueeze(1) - skip_padding = (indices_rank1 != padding_idx).unsqueeze(1) - skip_padding = skip_padding.expand_as(grad) - zero_grad = torch.full_like(grad, 0) - return grad_weight.index_put( - [indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True + ones = torch.ones_like(indices) + counts = counts.index_put([indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] ) + return grad_weight.index_put([indices], grad, accumulate=True) def prod(x: List[int]): @@ -1073,7 +1075,7 @@ def prod(x: List[int]): return r -@register_decomposition(aten.split_with_sizes, disable_meta=True) +@register_decomposition(aten.split_with_sizes) def split_with_sizes( self: Tensor, split_sizes: List[int], dim: int = 0 ) -> List[Tensor]: @@ -1087,7 +1089,7 @@ def split_with_sizes( return splits -@register_decomposition(aten.split.Tensor, disable_meta=True) +@register_decomposition(aten.split.Tensor) def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: input_sizes = self.shape dim_size = input_sizes[dim] @@ -1095,8 +1097,9 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: assert dim_size == 0 return [self] chunks = (dim_size + split_size - 1) // split_size + chunks = guard_int(chunks) split_sizes = [split_size for i in range(chunks)] - split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + split_sizes[-1] = split_size - (split_size * chunks - dim_size) return torch.split(self, split_sizes, dim) @@ -1111,7 +1114,14 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = out = alpha * torch.mm(mat1, mat2) if beta == 0: return out - return beta * self + out + + # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition. + # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided. + # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition. + # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input. + # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases. + # This implementation is not ideal, and we should revisit this when we have a better solution. + return out + beta * self # This computes the mean and variance along the specifized normalization dims, @@ -1131,37 +1141,6 @@ def normalize(input, norm_dims, eps): return out, mean, rstd -@register_decomposition(aten.native_group_norm.default, disable_meta=True) -def native_group_norm( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - N: int, - C: int, - HxW: int, - group: int, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: - orig_shape = input.shape - input = input.view(N, group, C // group, HxW) - reduction_dims = [2, 3] - out, mean, rstd = normalize(input, reduction_dims, eps) - mean = _squeeze_multiple(mean, reduction_dims) - rstd = _squeeze_multiple(rstd, reduction_dims) - out = out.view(orig_shape) - if weight is not None: - weight = _unsqueeze_to_dim(weight, out.dim() - 1) - out = out * weight - if bias is not None: - bias = _unsqueeze_to_dim(bias, out.dim() - 1) - out = out + bias - - out = out.to(dtype=input.dtype) - mean = mean.to(dtype=input.dtype) - rstd = rstd.to(dtype=input.dtype) - return (out, mean, rstd) - - @register_decomposition(aten.native_group_norm_backward) @pw_cast_for_opmath def native_group_norm_backward( @@ -1334,8 +1313,7 @@ def native_layer_norm_backward( ) -@register_decomposition(aten.native_batch_norm) -def native_batch_norm( +def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], bias: Optional[Tensor], @@ -1344,16 +1322,21 @@ def native_batch_norm( training: bool, momentum: float, eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: + functional: bool, +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: reduction_dims = [0] + list(range(2, input.dim())) computation_dtype = utils.get_computation_dtype(input.dtype) + new_running_mean = running_mean + new_running_var = running_var if training: output, mean, rstd = normalize(input, reduction_dims, eps) save_mean = _squeeze_multiple(mean, reduction_dims) save_rstd = _squeeze_multiple(rstd, reduction_dims) if running_mean is not None: - running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean) + new_running_mean = momentum * save_mean + (1 - momentum) * running_mean + if not functional: + running_mean.copy_(new_running_mean) if running_var is not None: n = input.numel() / input.shape[1] # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction @@ -1362,11 +1345,15 @@ def native_batch_norm( unbiased_var = torch.var(input, reduction_dims, unbiased=False) * ( n / (n - 1) ) - running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var) + new_running_var = momentum * unbiased_var + (1 - momentum) * running_var + if not functional: + running_var.copy_(new_running_var) else: assert running_mean is not None and running_var is not None running_mean = running_mean.to(dtype=computation_dtype, copy=True) + new_running_mean = running_mean running_var = running_var.to(dtype=computation_dtype, copy=True) + new_running_var = running_var mean = running_mean invstd = 1 / (torch.sqrt(running_var + eps)) # Very annoying inconsistency where CPU and CUDA give different shapes @@ -1392,7 +1379,127 @@ def native_batch_norm( if input.device.type == "cpu": save_mean = save_mean.to(dtype=input.dtype) save_rstd = save_rstd.to(dtype=input.dtype) - return output.to(dtype=input.dtype), save_mean, save_rstd + return ( + output.to(dtype=input.dtype), + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) + + +@register_decomposition(aten.native_batch_norm) +def native_batch_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +# TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm +# with our new correctly schema'd _native_batch_norm_legit and its variants, but +# we cannot do that immediately in the C++ because it would be forwards incompatible +# with some mobile use cases. +# +# Since this change is most impactful for aot autograd/functionalization, we simply +# register this decomposition on the Autograd key for the python dispatcher (which is +# currently only used by aot autograd/functionalization and no one else, really). +# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm +# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). +@torch.ops.aten.native_batch_norm.default.py_impl(DispatchKey.Autograd) +def native_batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + if running_mean is None and running_var is None: + return aten._native_batch_norm_legit( + input, weight, bias, training, momentum, eps + ) + if running_mean is None: + raise RuntimeError( + "running_mean is None, but running_var is provided. " + "They should both be None or both be provided." + ) + if running_var is None: + raise RuntimeError( + "running_var is None, but running_mean is provided. " + "They should both be None or both be provided." + ) + return aten._native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + + +@register_decomposition(aten._native_batch_norm_legit.default) +def _native_batch_norm_legit( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit.no_stats) +def _native_batch_norm_legit_no_stats( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, None, None, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit_functional.default) +def _native_batch_norm_legit_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, True + ) + assert new_running_mean is not None, "new_running_mean should not be None" + assert new_running_var is not None, "new_running_var should not be None" + return output, save_mean, save_rstd, new_running_mean, new_running_var @register_decomposition(aten._fused_dropout) @@ -1433,7 +1540,6 @@ def _to_copy( return x -@register_decomposition(aten.xlogy.Tensor) @pw_cast_for_int_to_real def xlogy(self: Tensor, other: Tensor) -> Tensor: return aten.where( @@ -1447,51 +1553,11 @@ def xlogy(self: Tensor, other: Tensor) -> Tensor: ) -@register_decomposition(aten.var.correction) -@reduction_complex_to_real -def var_correction( - x: Tensor, - dim: Optional[List[int]], - correction: Optional[int] = None, - keepdim: bool = False, -): - dims: List[int] = [] if dim is None else dim - - if x.is_complex(): - # For complex, calculate variance of real and imaginary components - # separately then add to get overall variance. - real_in = x.real - var_real = torch.var(real_in, dims, correction=correction, keepdim=keepdim) - imag_in = x.imag - var_imag = torch.var(imag_in, dims, correction=correction, keepdim=keepdim) - return var_real + var_imag - - if correction is None: - correction = 1 - - if len(dims) == 0: - n = prod(x.shape) # type: ignore[arg-type] - else: - n = 1 - for d in dims: - n *= x.shape[d] - - mean = torch.mean(x, dims, True) - sub = x - mean - sq = sub * sub - sum = torch.sum(sq, dims, keepdim) - - if correction: - n = n - correction - - return sum / n - - @register_decomposition(aten.std.correction) @reduction_complex_to_real def std_decomposition( x: Tensor, - dim: Optional[List[int]], + dim: Optional[List[int]] = None, correction: Optional[int] = None, keepdim: bool = False, ): @@ -1501,11 +1567,14 @@ def std_decomposition( # Questionable decompositions # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. # Note that this decomposition causes issues with in-place ops -@register_decomposition([aten.detach, aten.lift, aten.lift_fresh], disable_meta=True) +@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) def nop_decomposition(x): return aten.alias(x) +# Also register to the Autograd dispatch key, so this decomp can run above autograd. +# native_batch_norm needs to decompose into other ops before autograd. +@torch.ops.aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) @register_decomposition(aten.cudnn_batch_norm) def cudnn_batch_norm( input: Tensor, @@ -1559,6 +1628,10 @@ def native_batch_norm_backward( output_mask: List[bool], ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: input_dtype = input.dtype + if weight is not None: + weight_dtype = weight.dtype + else: + weight_dtype = input_dtype computation_dtype = utils.get_computation_dtype(input.dtype) ( grad_out_cast, @@ -1636,8 +1709,8 @@ def native_batch_norm_backward( return ( grad_input.to(input_dtype), - _maybe_cast(grad_weight, input_dtype), - _maybe_cast(grad_bias, input_dtype), + _maybe_cast(grad_weight, weight_dtype), + _maybe_cast(grad_bias, weight_dtype), ) @@ -1667,7 +1740,7 @@ def cudnn_batch_norm_backward( ) -@register_decomposition(aten._adaptive_avg_pool2d, disable_meta=True) +@register_decomposition(aten._adaptive_avg_pool2d) @pw_cast_for_opmath def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]): # Preconditions @@ -1740,7 +1813,7 @@ def compute_idx(in_size, out_size): return torch.mean(vals, dim=(-3, -1)) def maybe_mask(vals, length, range_max, adaptive, dim): - if isinstance(length, int): + if isinstance(length, IntLike): return vals, length else: # zero-out the things we didn't really want to select @@ -1838,7 +1911,6 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]: @register_decomposition(aten.norm) @out_wrapper() -@reduction_complex_to_real def norm( self: Tensor, p: Optional[float] = None, @@ -1846,34 +1918,101 @@ def norm( keepdim: bool = False, dtype: Optional[torch.dtype] = None, ): - if p is None: - p = 2.0 - return torch.linalg.vector_norm(self, p, dim, keepdim, dtype=dtype) + p = p if p is not None else 2.0 + if dtype: + return torch.linalg.vector_norm(self.to(dtype), p, dim, keepdim, dtype=dtype) + + computation_dtype, result_dtype = utils.elementwise_dtypes( + self, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT + ) + return torch.linalg.vector_norm( + self.to(computation_dtype), p, dim, keepdim, dtype=dtype + ).to(result_dtype) + + +@register_decomposition(aten.uniform) +def uniform( + x: Tensor, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, +): + return prims._uniform_helper( + x.shape, + low=sym_float(low), + high=sym_float(high), + dtype=x.dtype, + device=x.device, + ) + + +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + utils.check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + utils.check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(scale_factors) == spatial_dimensions, lambda: "") + return [ + # Returning output_size as float. We cannot convert it to int directly, + # as latter computation of scale_factor is relying output size being float + sym_float(input_size[i + 2] * scale_factors[i]) + for i in range(spatial_dimensions) + ] + utils.check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] @register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) -@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec, type="pre_autograd") +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + + # NB: osize could be a list of float when scale_factors is float + # so we cannot redispatch to aten.upsample_bilinear2d.default here + return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w) + + +@register_decomposition(torch.ops.aten.upsample_bilinear2d.default) +@torch.ops.aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) @pw_cast_for_opmath -def upsample_bilinear2d_vec( +def upsample_bilinear2d( input: Tensor, - output_size: Optional[List[int]], + output_size: List[Union[int, float]], align_corners: bool, - scale_factors: Optional[List[float]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, ) -> Tensor: # get dimensions of original image n_batch, n_channels, in_h, in_w = input.shape - if output_size is not None: - out_h = float(output_size[0]) - out_w = float(output_size[1]) - elif scale_factors is not None: - out_h = in_h * scale_factors[0] - out_w = in_w * scale_factors[1] + out_h = sym_float(output_size[0]) + out_w = sym_float(output_size[1]) # Calculate horizontal and vertical scaling factor + # TODO: Figure out if scales_h/scales_w matters here if out_h > 1: if align_corners: - h_scale_factor = (in_h - 1) / (int(out_h) - 1) + h_scale_factor = (in_h - 1) / (sym_int(out_h) - 1) else: h_scale_factor = in_h / out_h else: @@ -1881,14 +2020,14 @@ def upsample_bilinear2d_vec( if out_w > 1: if align_corners: - w_scale_factor = (in_w - 1) / (int(out_w) - 1) + w_scale_factor = (in_w - 1) / (sym_int(out_w) - 1) else: w_scale_factor = in_w / out_w else: w_scale_factor = 0.0 - i = torch.arange(int(out_h), dtype=input.dtype, device=input.device) - j = torch.arange(int(out_w), dtype=input.dtype, device=input.device) + i = torch.arange(sym_int(out_h), dtype=input.dtype, device=input.device) + j = torch.arange(sym_int(out_w), dtype=input.dtype, device=input.device) if align_corners: x = h_scale_factor * i @@ -1920,6 +2059,16 @@ def upsample_bilinear2d_vec( q1 = torch.mul(v1, xscale1) + torch.mul(v2, xscale2) q2 = torch.mul(v3, xscale1) + torch.mul(v4, xscale2) result = torch.mul(q1, yscale1) + torch.mul(q2, yscale2) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + + # following "heuristic: only use channels_last path when it's faster than the contiguous path" + if input.device.type == "cuda" and n_channels < 16: + memory_format = torch.contiguous_format + + result = result.contiguous(memory_format=memory_format) + return result @@ -1929,7 +2078,7 @@ def is_same_size(a: Tensor, b: Tensor) -> bool: return a.shape == b.shape -@register_decomposition([aten._reshape_alias, aten._unsafe_view], disable_meta=True) +@register_decomposition([aten._reshape_alias, aten._unsafe_view]) def _reshape_alias(x, shape, *args): return aten.view(x, shape) @@ -2195,7 +2344,7 @@ def mv(self, vec): return (self * vec).sum(dim=1) -@register_decomposition(aten.dot, disable_meta=True) +@register_decomposition(aten.dot) @out_wrapper() @pw_cast_for_opmath def dot(self, other): @@ -2321,9 +2470,7 @@ def matmul(tensor1, tensor2): t2_is_matrix = t2.dim() == 2 if t2_is_matrix: output_shape.append(t2.shape[1]) - # HACK: We need reshape with symint support - t1 = t1.contiguous() - t1_folded = t1.view(folded_dim1, sizes_1[-1]) + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) if t2_is_matrix: # FIXME This path always does an unnecessary copy when transpose == True as the returned # result from BLAS is already C-transposed @@ -2356,15 +2503,11 @@ def matmul(tensor1, tensor2): expand_batch_product = prod(expand_batch_portion) # HACK: We need reshape with symint support - tensor1_expanded = ( - tensor1.expand(tensor1_expand_size) - .contiguous() - .view(expand_batch_product, n, m1) + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 ) - tensor2_expanded = ( - tensor2.expand(tensor2_expand_size) - .contiguous() - .view(expand_batch_product, m2, p) + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p ) output_shape = expand_batch_portion diff --git a/torch/_deploy.py b/torch/_deploy.py index 53769538b6c11..30c022eac8793 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -23,7 +23,7 @@ def persistent_id(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - storage = obj._storage + storage = obj._untyped_storage dtype = obj.dtype else: storage = obj diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index 95b7fa05bfe2c..f0814889ba2d2 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -1,5 +1,9 @@ import torch._C from contextlib import contextmanager +import unittest.mock +import torch +import torch.utils._pytree as pytree +import itertools __all__ = ['enable_python_dispatcher', 'no_python_dispatcher'] @@ -18,3 +22,121 @@ def enable_python_dispatcher(): yield finally: del g + +CROSSREF_FUNCTIONALIZE = False + +def all_known_overloads(): + for ns in torch.ops: + packets = getattr(torch.ops, ns) + for op_name in packets: + packet = getattr(packets, op_name) + for overload in packet: + yield getattr(packet, overload) + +@contextmanager +def suspend_functionalization(): + f_tls = torch._C._dispatch_tls_is_dispatch_key_included(torch._C.DispatchKey.Functionalize) + f_rv = torch._C._functionalization_reapply_views_tls() + if f_tls: + torch._disable_functionalization() + try: + yield + finally: + if f_tls: + torch._enable_functionalization(reapply_views=f_rv) + +def check_tensor_metadata_matches(nv, rv, desc): + assert callable(desc) + assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" + same_strides, idx = torch._prims_common.check_significant_strides(nv, rv, only_cuda=False) + assert same_strides, f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + +def check_metadata_matches(n, r, desc): + assert callable(desc) + n_vals, n_spec = pytree.tree_flatten(n) + r_vals, r_spec = pytree.tree_flatten(r) + # TODO: test the specs match; empirically sometimes we have a tuple + # on one side and a list on the other + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + +def _fmt(a: object) -> object: + if isinstance(a, torch.Tensor): + return Lit(f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})") + else: + return a + +def make_crossref_functionalize(op, final_key): + from torch._subclasses.fake_tensor import FakeTensorMode + # This case is pretty weird, suppress it for now + if op == torch.ops.aten.lift_fresh.default: + return final_key + + def handler(*args, **kwargs): + fake_mode = FakeTensorMode() + + def fakeify_defun(t): + if isinstance(t, torch.Tensor): + if torch._is_functional_tensor(t): + r = torch._from_functional_tensor(t) + # NB: This assumes that the inner tensor sizes/strides match + # the outer tensor sizes/strides. This doesn't necessarily have to + # be the case, see discussion at + # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 + assert t.size() == r.size() + assert t.stride() == r.stride() + else: + r = t + # TODO: suppress guards + return fake_mode.from_tensor(r) + return t + + def maybe_detach(t): + if isinstance(t, torch.Tensor): + return t.detach() + else: + return t + + with suspend_functionalization(): + f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map(maybe_detach, (f_args, f_kwargs)) + with fake_mode: + f_r = op(*f_args, **f_kwargs) + r = op._op_dk(final_key, *args, **kwargs) + + def desc(): + fmt_args = ", ".join( + itertools.chain( + (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), + (f"{k}={pytree.tree_map(_fmt, v)}" for k, v in orig_f_kwargs.items()), + ) + ) + return f"{op}({fmt_args})" + check_metadata_matches(f_r, r, desc) + return r + return handler + +# NB: enabling this is slow, don't do it in a hot loop. This is purely +# for debugging purposes. +@contextmanager +def enable_crossref_functionalize(): + for op in all_known_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) + try: + with enable_python_dispatcher(), unittest.mock.patch( + 'torch._dispatch.python.CROSSREF_FUNCTIONALIZE', True): + yield + finally: + for op in all_known_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 6b49ce5104ca4..57df92d75f4fa 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,6 +7,7 @@ export, optimize, optimize_assert, + OptimizedModule, reset_code, run, skip, @@ -14,7 +15,10 @@ from .utils import compilation_metrics, guard_failures, orig_code_map __all__ = [ + "allow_in_graph", "assume_constant_result", + "disallow_in_graph", + "graph_break", "optimize", "optimize_assert", "export", @@ -25,6 +29,7 @@ "reset", "list_backends", "skip", + "OptimizedModule", ] @@ -45,8 +50,9 @@ def reset(): def list_backends(): """ - Return valid strings that can be passed to: - @torchdynamo.optimize() + Return valid strings that can be passed to:: + + @torch._dynamo.optimize() def foo(...): .... """ @@ -58,11 +64,12 @@ def foo(...): def allow_in_graph(fn): """ Customize which functions TorchDynamo will include in the generated - graph. Similar to torch.fx.wrap(). + graph. Similar to `torch.fx.wrap()`. + :: - torchdynamo.allow_in_graph(my_custom_function) + torch._dynamo.allow_in_graph(my_custom_function) - @torchdynamo.optimize(...) + @torch._dynamo.optimize(...) def fn(a): x = torch.add(x, 1) x = my_custom_function(x) @@ -71,7 +78,7 @@ def fn(a): fn(...) - Will capture a single graph containing my_custom_function(). + Will capture a single graph containing `my_custom_function()`. """ if isinstance(fn, (list, tuple)): return [allow_in_graph(x) for x in fn] @@ -85,10 +92,11 @@ def disallow_in_graph(fn): """ Customize which functions TorchDynamo will exclude in the generated graph and force a graph break on. + :: - torchdynamo.disallow_in_graph(torch.sub) + torch._dynamo.disallow_in_graph(torch.sub) - @torchdynamo.optimize(...) + @torch._dynamo.optimize(...) def fn(a): x = torch.add(x, 1) x = torch.sub(x, 1) @@ -97,8 +105,8 @@ def fn(a): fn(...) - Will break the graph on torch.sub, and give two graphs each with a - single torch.add() op. + Will break the graph on `torch.sub`, and give two graphs each with a + single `torch.add()` op. """ if isinstance(fn, (list, tuple)): return [disallow_in_graph(x) for x in fn] diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py index 42a6580ac1c86..67daafc5adac7 100644 --- a/torch/_dynamo/allowed_functions.py +++ b/torch/_dynamo/allowed_functions.py @@ -18,6 +18,24 @@ from . import config from .utils import is_safe_constant +""" +A note on allowed functions: + +Dynamo consults this file to determine if a particular function/module +is allowed to appear as a node in its fx output. + +If a function is disallowed, it may either be traced-through, or skipped. + +Trace-through means dynamo will continue to trace the interior code for +the function/module rather than stopping at its boundary and recording it +as a node in the fx graph. Whether tracing through or allowing, the functionality +of the function/module is part of the dynamo graph. Caveat: if tracing through, +any interior operation could trigger its own graph-break. + +Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on +skipfiles" there. +""" + def make_function_id_set(lazy_initializer): """ @@ -130,6 +148,7 @@ def _is_allowed_module_prefix(obj): "torch._inductor.", "torch._C.inductor.", "torch.fx.", + "torch.distributed.fsdp.", ) allowed_modules_dot = tuple([x + "." for x in allowed_modules]) module = inspect.getmodule(obj) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 2ba29981c3668..e469ce02ebd64 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -14,6 +14,7 @@ from .variables.base import VariableTracker from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -95,6 +96,7 @@ def __call__(self, value, allow_cache=True): value, ( TensorVariable, + DynamicShapeVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 18d1af0a743b4..56e5f24b2642e 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -17,19 +17,24 @@ # log level (levels print what it says + all levels listed below it) # logging.DEBUG print full traces <-- lowest level + print tracing of every instruction -# torchdynamo.logging.CODE print compiled functions + graphs -# logging.INFO print the steps that dynamo is running +# logging.INFO print the steps that dynamo is running and optionally, compiled functions + graphs # logging.WARN print warnings (including graph breaks) # logging.ERROR print exceptions (and what user code was being processed when it occurred) # NOTE: changing log_level will automatically update the levels of all torchdynamo loggers log_level = logging.WARNING +output_code = False + # the name of a file to write the logs to log_file_name = None # Verbose will print full stack traces on warnings and errors verbose = False +# If true, traced graph outputs will be outputted as Python GraphModule code. +# If false, traced graph outputs will be outputted in tabular form. +output_graph_code = False + # verify the correctness of optimized backend verify_correctness = False @@ -64,36 +69,53 @@ # Run the FX graph as it is created to get better type information dynamic_propagation = True -# Run the FX graph with FakeTensors -fake_tensor_propagation = True - # run FX normalization passes in optimizer normalize_ir = False -# If a tensor subclass type is in this set, torchdynamo will inline the -# __torch_function__ logic of the subclass. +# This feature doesn't really work. We offer this flag for experimental +# purposes / if you want to help us build out support. +# +# torchdynamo has very limited support for tensor subclasses that implement +# __torch_function__. Our current support is limited to tensor subclasses +# that DO NOT store metadata on the tensor (in general, dynamo does not +# support Python code that stores extra attributes on tensors at present). +# If your tensor subclass purely changes function call behavior via +# __torch_function__, you can allow torchdynamo to trace into it by +# adding it to traceable_tensor_subclasses. We don't do any safety checks, +# so it is up to you to ensure that your subclass is well behaved. See also +# https://github.com/pytorch/torchdynamo/issues/1948 +# +# We do NOT currently support __torch_dispatch__. The implementation is +# currently buggy, the main show stopper for nontrivial use is +# https://github.com/pytorch/torchdynamo/issues/1952 traceable_tensor_subclasses = set() -# Raise torchdynamo internal assertions -raise_on_assertion_error = False - -# Propagate backend exceptions up to torchdynamo.optimize -raise_on_backend_error = True +# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. +# This is a good way to get your model to work one way or another, but you may +# lose optimization opportunities this way. Devs, if your benchmark model is failing +# this way, you should figure out why instead of suppressing it. +suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) # Record and write an execution record of the current frame to a file # if an exception is encountered replay_record_enabled = False -replay_record_dir_name = "./torchdynamo_error_records" + +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True # Show a warning on every graph break print_graph_breaks = False +# Disable dynamo +disable = os.environ.get("TORCH_COMPILE_DISABLE", False) + # If a PyTorch module is in this allowlist, torchdynamo will be allowed # to inline objects from it or its children. skipfiles_inline_module_allowlist = { torch.nn, torch.distributions, torch.testing, + torch.ao.nn, } if HAS_REFS_PRIMS: skipfiles_inline_module_allowlist |= { @@ -126,9 +148,6 @@ # 4: Dumps a minifier_launcher.py if the accuracy fails. repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2)) -# Specify the directory where to save the repro artifacts -repro_dir = os.environ.get("TORCHDYNAMO_REPRO_DIR", None) - # Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type. # When this flag is set to False, we introduce a graph break instead of capturing. capture_scalar_outputs = False @@ -138,8 +157,11 @@ enforce_cond_guards_match = True # Automatically split model graph into pieces to match DDP bucket sizes -# to allow DDP comm/compute overlap -optimize_ddp = False +# to allow DDP comm/compute overlap. Disable to allow DDP models to +# run without graph-breaks, but also without comm/compute overlap. +# set torch._dynamo.config.log_level to INFO or DEBUG for more info +# about optimize_ddp behavior. +optimize_ddp = True # If True, raises exception if TorchDynamo is called with a context manager raise_on_ctx_manager_usage = True @@ -147,18 +169,28 @@ # If True, raise when aot autograd is unsafe to use raise_on_unsafe_aot_autograd = False -# How to import torchdynamo, either torchdynamo or torch.dynamo +# How to import torchdynamo, either torchdynamo or torch._dynamo dynamo_import = __name__.replace(".config", "") # How to import torchinductor, either torchinductor or torch.inductor inductor_import = dynamo_import.replace("dynamo", "inductor") +# If true, error with a better message if we symbolically trace over a +# dynamo-optimized function. If false, silently suppress dynamo. +error_on_nested_fx_trace = True + # root folder of the project if "torch." in dynamo_import: base_dir = dirname(dirname(dirname(abspath(__file__)))) else: base_dir = dirname(dirname(abspath(__file__))) +debug_dir_root = os.path.join(os.getcwd(), "torchdynamo_debug") + +# this is to resolve a import problem in fbcode, we will be deleting +# this very shortly +DO_NOT_USE_legacy_non_fake_example_inputs = False + class _AccessLimitingConfig(ModuleType): def __setattr__(self, name, value): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index d4afed9f63e37..a2105cd10743a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -4,23 +4,18 @@ import os import traceback import types -import typing import weakref -from typing import Callable +from traceback import FrameSummary +from typing import cast, Dict, List, Optional, Set import torch from torch.fx.graph_module import _forward_from_src as original_forward_from_src -from . import config, exc, logging as torchdynamo_logging +from . import config, exc from .allowed_functions import is_allowed from .bytecode_analysis import remove_dead_code, remove_pointless_jumps from .bytecode_transformation import is_generator, transform_code_object -from .eval_frame import ( - always_optimize_code_objects, - skip_code, - TorchPatcher, - WrapperBackend, -) +from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher from .exc import ( BackendCompilerFailed, InternalTorchDynamoError, @@ -29,6 +24,8 @@ Unsupported, ) from .guards import CheckFunctionManager, GuardedCode +from .hooks import Hooks +from .output_graph import CompilerFn, OutputGraph from .replay_record import ExecutionRecord from .symbolic_convert import InstructionTranslator from .utils import ( @@ -86,18 +83,6 @@ def fx_forward_from_src_skip_result(*args, **kwargs): return result -def wrap_compiler_fn(compiler_fn): - """WrapperBackend if config.verify_correctness is True""" - if config.verify_correctness: - # wrap backend if verify_correctness is True - wrapper_backend_compiler_fn = WrapperBackend(compiler_fn) - - wrapper_backend_compiler_fn._torchdynamo_orig_callable = compiler_fn - return wrapper_backend_compiler_fn - - return compiler_fn - - def wrap_convert_context(fn): """ Context manager to: @@ -123,14 +108,14 @@ def _fn(*args, **kwargs): torch.cuda.set_rng_state(cuda_rng_state) torch.fx.graph_module._forward_from_src = prior_fwd_from_src - _fn._torchdynamo_orig_callable = fn + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] return _fn @TorchPatcher.suppress_torch_distributed_warnings def has_tensor_in_frame(frame): """Check if the frame has torch.* related bits""" - # Check if the function was decorated using torchdynamo.optimize + # Check if the function was decorated using torch._dynamo.optimize if frame.f_code in always_optimize_code_objects: return True @@ -140,7 +125,7 @@ def has_tensor_in_frame(frame): if is_allowed(frame.f_globals[co_name]): return True - seen_ids = dict() + seen_ids: Dict[int, bool] = dict() def has_tensor(obj): """Recursively check if the obj has a tensor""" @@ -156,7 +141,11 @@ def has_tensor(obj): seen_ids[obj_id] = any([has_tensor(v) for v in obj]) return seen_ids[obj_id] elif istype(obj, dict): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) + # Some packages like pytest can be updated during runtime. So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any([has_tensor(v) for v in values]) return seen_ids[obj_id] elif istype(obj, (str, int, float, type(None), bool)): seen_ids[obj_id] = False @@ -164,9 +153,6 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) - return seen_ids[obj_id] else: # if config.debug: # print( @@ -190,17 +176,6 @@ def has_tensor(obj): def format_error_msg(exc, code, record_filename=None, frame=None): msg = os.linesep * 2 - def replay_record_msg(): - if ( - config.replay_record_enabled - and hasattr(exc, "exec_record") - and record_filename is not None - ): - return f"\nLast frame execution written to {record_filename}. To run only this frame while debugging, run\ - {config.dynamo_import}.replay('{record_filename}').\n" - else: - return "" - if config.verbose: msg = format_bytecode( "WON'T CONVERT", code.co_name, code.co_filename, code.co_firstlineno, code @@ -221,26 +196,59 @@ def replay_record_msg(): msg += "".join( traceback.format_list( - stack_above_dynamo + list(reversed(exc.real_stack)) + stack_above_dynamo + list(reversed(get_real_stack(exc))) ) ) - - msg += replay_record_msg() + msg += "\n" + msg += "=" * 10 else: msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\ line {code.co_firstlineno} \ndue to: \n{traceback.format_exc(limit=-1)}" - if hasattr(exc, "real_stack"): - msg += f"\nfrom user code:\n {''.join(traceback.format_list([exc.real_stack[-1]]))}" + return msg + + +def get_real_stack(exc) -> List[FrameSummary]: + assert hasattr(exc, "real_stack") + return cast(List[FrameSummary], exc.real_stack) - msg += replay_record_msg() +def augment_exc_message(exc, msg="\n"): + if ( + hasattr(exc, "real_stack") + and len(exc.real_stack) > 0 + and not (config.verbose and config.suppress_errors) + ): + msg += f"\nfrom user code:\n {''.join(traceback.format_list(list(reversed(get_real_stack(exc)[0:2]))))}" + + if config.replay_record_enabled and hasattr(exc, "record_filename"): + msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\ + {config.dynamo_import}.replay('{exc.record_filename}').\n" + + if not config.verbose: msg += ( f"\nSet {config.dynamo_import}.config.verbose=True for more information\n" ) - msg += "=" * 10 - return msg + + if hasattr(exc, "inner_exception") and hasattr( + exc.inner_exception, "minifier_path" + ): + msg += ( + f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run " + "this script to find the smallest traced graph which reproduces this error.\n" + ) + + if not config.suppress_errors: + msg += ( + "\n\n" + "You can suppress this exception and fall back to eager by setting:\n" + " torch._dynamo.config.suppress_errors = True\n" + ) + + old_msg = "" if len(exc.args) == 0 else exc.args[0] + new_msg = old_msg + msg + exc.args = (new_msg,) + exc.args[1:] def exception_handler(e, code, frame=None): @@ -248,20 +256,25 @@ def exception_handler(e, code, frame=None): if hasattr(e, "exec_record"): record_filename = gen_record_file_name(e, code) write_record_to_file(record_filename, e.exec_record) + e.record_filename = record_filename - log.error(format_error_msg(e, code, record_filename, frame)) + augment_exc_message(e) + # Only log the exception if we are going to suppress it + # if aren't suppressing it, a higher level except block will handle it + if config.suppress_errors: + log.error(format_error_msg(e, code, record_filename, frame)) def convert_frame_assert( - compiler_fn: Callable, guard_export_fn=None, one_graph=True, export=False + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, ): """Fully convert a frame into an FX graph""" init_logging() - compiler_fn = wrap_compiler_fn(compiler_fn) - @dynamo_timed - def _convert_frame_assert(frame: types.FrameType, cache_size: int): + def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks): code = frame.f_code input_codes.add(code) if code in output_codes: @@ -280,6 +293,7 @@ def _convert_frame_assert(frame: types.FrameType, cache_size: int): # setattr could be tricky to handle generally, # but also not likely useful to compile- skip the whole frame return None + # Check if the frame is generated by an exec builtin call # TODO - Running exec generated frame seems propagates f_globals to the # next frames. @@ -330,7 +344,7 @@ def format_guard_failures(code): compiler_fn, one_graph, export, - guard_export_fn, + hooks, frame, ) @@ -339,17 +353,19 @@ def format_guard_failures(code): def _compile( - code, - globals, - locals, - builtins, - compiler_fn, - one_graph, - export, - guard_export_fn=None, - frame=None, -): - output = None + code: types.CodeType, + globals: Dict[str, object], + locals: Dict[str, object], + builtins: Dict[str, object], + compiler_fn: CompilerFn, + one_graph: bool, + export: bool, + hooks: Hooks, + frame: Optional[types.FrameType] = None, +) -> Optional[GuardedCode]: + output: Optional[OutputGraph] = None + # This is shared across restarts + mutated_closure_cell_contents: Set[str] = set() # from .utils import print_once; print_once(code.co_filename) def transform(instructions, code_options): @@ -364,9 +380,11 @@ def transform(instructions, code_options): compiler_fn, one_graph, export, + mutated_closure_cell_contents, ) tracer.run() output = tracer.output + assert output is not None assert output.output_instructions instructions[:] = output.output_instructions code_options.update(output.code_options) @@ -394,39 +412,48 @@ def transform(instructions, code_options): return None output_codes.add(out_code) - log.log( - torchdynamo_logging.CODE, - format_bytecode( - "ORIGINAL BYTECODE", - code.co_name, - code.co_filename, - code.co_firstlineno, - code, - ), - ) - log.log( - torchdynamo_logging.CODE, - format_bytecode( - "MODIFIED BYTECODE", - code.co_name, - code.co_filename, - code.co_firstlineno, - out_code, - ), - ) + if config.output_code: + log.info( + format_bytecode( + "ORIGINAL BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + code, + ), + ) + log.info( + format_bytecode( + "MODIFIED BYTECODE", + code.co_name, + code.co_filename, + code.co_firstlineno, + out_code, + ), + ) + assert output is not None assert output.guards is not None CleanupManager.instance[out_code] = output.cleanups - check_fn = CheckFunctionManager(output.guards, locals, globals) + check_fn = CheckFunctionManager( + output, + output.guards, + locals, + globals, + hooks.guard_fail_fn if hooks else None, + ) guarded_code = GuardedCode(out_code, check_fn.check_fn) - guard_str = "GUARDS:\n" - guard_str += "\n".join([f" - {str(guard)}" for guard in sorted(output.guards)]) - log.log(torchdynamo_logging.CODE, guard_str) + if config.output_code: + guard_str = "GUARDS:\n" + guard_str += "\n".join( + [f" - {str(guard)}" for guard in sorted(output.guards)] + ) + log.info(guard_str) - if guard_export_fn is not None: - guard_export_fn(output.guards) + if hooks.guard_export_fn is not None: + hooks.guard_export_fn(output.guards) return guarded_code except ( @@ -439,29 +466,27 @@ def transform(instructions, code_options): raise except Exception as e: exception_handler(e, code, frame) - raise InternalTorchDynamoError() + raise InternalTorchDynamoError() from e -def convert_frame(compiler_fn: typing.Callable, guard_export_fn=None): +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): """Try to convert a frame into an FX graph, if error leave frame unmodified""" - inner_convert = convert_frame_assert(compiler_fn, guard_export_fn, one_graph=False) + inner_convert = convert_frame_assert(compiler_fn, one_graph=False) - def _convert_frame(frame: types.FrameType, cache_size: int): + def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks): counters["frames"]["total"] += 1 try: - result = inner_convert(frame, cache_size) + result = inner_convert(frame, cache_size, hooks) counters["frames"]["ok"] += 1 return result - except AssertionError: - if config.raise_on_assertion_error: - raise - except BackendCompilerFailed: - raise - except Exception: + except (NotImplementedError, Unsupported): pass + except Exception: + if not config.suppress_errors: + raise return None - _convert_frame._torchdynamo_orig_callable = compiler_fn + _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] return _convert_frame @@ -484,11 +509,11 @@ def replay(filename): record.globals, record.locals, record.builtins, - eager, - False, # one_graph - None, # export_fn - None, # frame - False, # Export + compiler_fn=eager, + one_graph=False, + export=False, + hooks=Hooks(), + frame=None, ) except Exception: pass diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 7a2466637b767..1db28caee6b8b 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -16,13 +16,13 @@ from . import config from .optimizations.backends import register_backend -from .utils import clone_inputs +from .utils import clone_inputs, get_debug_dir log = logging.getLogger(__name__) def minifier_dir(): - path = config.repro_dir + path = os.path.join(get_debug_dir(), "minifier") if path is None: path = f"/tmp/minifier_{getpass.getuser()}" if not os.path.exists(path): @@ -84,6 +84,11 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -95,12 +100,16 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" + if param.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. @@ -128,7 +137,7 @@ def _cuda_system_info_comment(): ) model_str += f"{cuda_version_out}\n" except FileNotFoundError: - model_str += "nvcc not found\n" + model_str += "# nvcc not found\n" gpu_names = subprocess.run( ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], @@ -145,6 +154,9 @@ def _cuda_system_info_comment(): return model_str +TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" + + def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -155,6 +167,8 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx + {TEST_REPLACEABLE_COMMENT} + """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -170,7 +184,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' + model_str += "mod = make_fx(Repro())(*args)\n" return model_str @@ -179,11 +193,6 @@ def generate_compiler_repro_string(gm, args): from {config.dynamo_import}.debug_utils import same_two_models """ -NVFUSER_IMPORT = """ -from torch.fx.passes.backends.nvfuser import NvFuserBackend -nvfuser = NvFuserBackend() -""" - COMPILER_REPRO_OPTIONS = { "inductor": (INDUCTOR_IMPORT, "compile_fx_inner", "inductor_fails"), "inductor_accuracy": ( @@ -191,7 +200,6 @@ def generate_compiler_repro_string(gm, args): "compile_fx_inner", "inductor_accuracy_fails", ), - "nvfuser": (NVFUSER_IMPORT, "nvfuser", "nvfuser_fails"), } @@ -203,7 +211,8 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - repro_path = os.path.join(config.base_dir, "repro.py") + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -213,6 +222,12 @@ def dump_compiler_graph_state(gm, args, compiler_name): def save_graph_repro(fd, gm, args, compiler_name): + sync_line = "" + for arg in args: + if arg.is_cuda: + sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced" + break + if "inductor" in compiler_name: fd.write(f"import {config.inductor_import}.overrides\n") fd.write(generate_compiler_repro_string(gm, args)) @@ -222,7 +237,10 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" + class AccuracyError(Exception): + pass + if not same_two_models(mod, compiled, args, only_fwd=True): + raise AccuracyError("Bad accuracy detected") """ ) ) @@ -231,21 +249,25 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - compiled(args) + ref = compiled(args) + {sync_line} """ ) ) -def isolate_fails(fx_g, args, compiler_name: str, env=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): if env is None: env = {} - subdir = f"{minifier_dir()}/isolate" + subdir = os.path.join(os.getcwd(), "isolate") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - fd.write(generate_compiler_repro_string(fx_g, args)) + repro_code = generate_compiler_repro_string(fx_g, args) + if patch_code is not None: + repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + fd.write(repro_code) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -269,6 +291,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], + cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -280,43 +303,41 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): stderr.seek(0) print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> ")) print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> ")) + # print(f"Isolated test failed - {file_name}") return True return False def inductor_fails(fx_g, args, check_str=None): + has_cuda = False + for arg in args: + if arg.is_cuda: + has_cuda = True + break + + def sync(): + if has_cuda: + # Ensures that segfaults are surfaced + torch.cuda.synchronize() + compile_fx_inner = import_module( f"{config.inductor_import}.compile_fx" ).compile_fx_inner - import_module(f"{config.inductor_import}.config").triton.autotune = False - try: result = fx_g(*args) assert isinstance(result, (tuple, list)) assert not any([isinstance(x, (tuple, list)) for x in result]) except Exception: return False + result = None + + sync() try: compile_mod = compile_fx_inner(fx_g, args) compile_mod(args) - except Exception as e: - if check_str is not None and check_str not in repr(e): - return False - print(repr(e)) - return True - return False - - -def nvfuser_fails(fx_g, args, check_str=None): - from torch.fx.passes.backends.nvfuser import NvFuserBackend - - nvfuser = NvFuserBackend() - - try: - compile_mod = nvfuser(fx_g, args) - compile_mod = compile_mod(*args) + sync() except Exception as e: if check_str is not None and check_str not in repr(e): return False @@ -326,29 +347,24 @@ def nvfuser_fails(fx_g, args, check_str=None): def inductor_accuracy_fails(fx_g, args, check_str=None): - from torchinductor.compile_fx import compile_fx_inner + from torch._inductor.compile_fx import compile_fx_inner return backend_aot_accuracy_fails(fx_g, args, compile_fx_inner) +def get_minifier_repro_path(): + return os.path.join(minifier_dir(), "minifier_launcher.py") + + def helper_for_dump_minify(contents): - minified_repro_path = os.path.join(minifier_dir(), "minifier_launcher.py") + minified_repro_path = get_minifier_repro_path() log.warning(f"Writing minified repro to {minified_repro_path}") try: with open(minified_repro_path, "w") as fd: fd.write(contents) except OSError as e: log.exception(e) - raise NotImplementedError("Could not write to {minified_repro_path}") - - local_path = os.path.join(config.base_dir, "minifier_launcher.py") - try: - shutil.copyfile(minified_repro_path, local_path) - log.warning( - f"Copying minified repro from {minified_repro_path} to {local_path} for convenience" - ) - except OSError: - log.warning(f"Don't have write permissions for {local_path}") + raise NotImplementedError("Could not write to {minified_repro_path}") from e def dump_to_minify(gm, args, compiler_name: str): @@ -356,6 +372,8 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" +isolate_fails_code_str = None + {generate_compiler_repro_string(gm, args)} from functools import partial @@ -370,7 +388,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -378,6 +396,10 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) +class AccuracyError(Exception): + pass + + def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both @@ -437,7 +459,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise ValueError("Bad accuracy detected") + raise AccuracyError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -462,7 +484,8 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - raise e + log.error("CompilerError") + raise if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -479,7 +502,7 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False): """ Runs a forward and possibly backward iteration for a given mod and args. """ - from functorch._src.aot_autograd import make_boxed_func + from torch._functorch.aot_autograd import make_boxed_func from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass @@ -506,15 +529,23 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False): if requires_bwd_pass(out): loss = reduce_to_scalar_loss(out) loss.backward() - return collect_results(gm, out, None, []) + return collect_results(gm, out, None, args) def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ + from .eval_frame import OptimizedModule + from .testing import named_parameters_for_optimized_module from .utils import same + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: @@ -526,7 +557,19 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): log.warning("Could not generate fp64 outputs") fp64_ref = None - res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) + try: + res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) + except Exception as e: + # This means that the the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return True. + log.warning( + ( + "While minifying the program in accuracy minification mode," + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph." + ) + ) + return True passing = same(ref, res, fp64_ref, tol=0.001, equal_nan=True) return passing @@ -571,9 +614,14 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() + +class AccuracyError(Exception): + pass + with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - assert same_two_models(mod, opt_mod, args), "Dynamo failed" + if not same_two_models(mod, opt_mod, args): + raise AccuracyError("Dynamo failed") """ ) @@ -588,12 +636,14 @@ def generate_dynamo_fx_repro_string( from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -605,10 +655,11 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): """ Saves the repro to a repro.py file """ - subdir = os.path.join(minifier_dir()) + curdir = os.getcwd() + subdir = os.path.join(os.getcwd(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) - file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py") + file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py") log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") model_str = NNModuleToString.convert(gm) @@ -618,19 +669,10 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): model_str, args, compiler_name, check_accuracy ) ) - latest_repro = os.path.join(subdir, "repro.py") + latest_repro = os.path.join(curdir, "repro.py") log.warning(f"Copying {file_name} to {latest_repro} for convenience") shutil.copyfile(file_name, latest_repro) - local_path = os.path.join(config.base_dir, "repro.py") - try: - shutil.copyfile(file_name, local_path) - log.warning( - f"Copying minified repro from {file_name} to {local_path} for convenience" - ) - except OSError: - log.warning("No write permissions for {local_path}") - # TODO - Commented because we are assuming that nn.Modules can be safely repr'd # If that does not work, we might have to bring this code back. So, keeping it @@ -701,12 +743,28 @@ def dump_backend_state(gm, args, compiler_name, check_accuracy=False): def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False): - compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs)) + try: + compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs)) + except Exception as e: + # This means that the the minified graph is bad/exposes a different problem. + # As we are checking accuracy here, lets log the exception and return False. + log.warning( + ( + "While minifying the program in accuracy minification mode," + "ran into a runtime exception which is likely an unrelated issue." + " Skipping this graph" + ) + ) + return False + return not same_two_models(gm, compiled_gm, example_inputs, only_fwd) backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True) +# Please see NOTE: [Real Tensors in Accuracy Evaluation] +MINIFIER_SPAWNED = False + def backend_fails(gm, example_inputs, compiler_fn, orig_failure): """ @@ -740,6 +798,21 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" + custom_compiler_error = ( + textwrap.dedent( + """\ + raise RuntimeError( + 'Compiler name is None - this likely means that a custom compiler ' + 'was called by torchdynamo. Please remove this error, import your ' + 'custom compiler function, and replace the compiler_name="None" ' + 'line below to compiler_name=' + ) + """ + ) + if compiler_name is None + else "" + ) + contents = textwrap.dedent( f""" import os @@ -753,16 +826,18 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided -{config.dynamo_import}.config.repro_dir = \"{minifier_dir()}\" +{TEST_REPLACEABLE_COMMENT} args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() # Setup debug minifier compiler +torch._dynamo.debug_utils.MINIFIER_SPAWNED = True compiler_fn = BACKENDS["{minifier_backend}"] +{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -790,32 +865,37 @@ def wrap_backend_debug(compiler_fn, compiler_name: str): def debug_wrapper(gm, example_inputs, **kwargs): assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": - # Ensure that we fail when backend fails - config.raise_on_backend_error = True if config.repro_level == 3: dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) # Check for either accuracy (level 4) or other type of failures. if config.repro_level == 4: # Check Accuracy - compiled_gm = compiler_fn(gm, example_inputs, **kwargs) + compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs, **kwargs) if backend_accuracy_fails(gm, example_inputs, compiler_fn): - log.warning("Accuracy failed for the TorchDyanmo produced graph") + log.warning( + "Accuracy failed for the TorchDyanmo produced graph. Creating script to minify the error." + ) dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, compiler_name, ) - raise ValueError("Bad accuracy detected") + exc = AccuracyError("Bad accuracy detected.") + exc.minifier_path = os.path.join( + minifier_dir(), "minifier_launcher.py" + ) + raise exc else: try: - compiled_gm = compiler_fn(gm, example_inputs, **kwargs) + compiled_gm = compiler_fn( + copy.deepcopy(gm), example_inputs, **kwargs + ) run_fwd_maybe_bwd(compiled_gm, example_inputs) except Exception as exc: log.warning( - "Compiled Fx GraphModule failed with following error. Setting up minifier." + "Compiled Fx GraphModule failed. Creating script to minify the error." ) - log.exception(exc) if config.repro_level == 1: dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name @@ -829,7 +909,10 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - raise ValueError("Issue deteced. Repro at minifier_launcher.py.") + exc.minifier_path = os.path.join( + minifier_dir(), "minifier_launcher.py" + ) + raise else: compiled_gm = compiler_fn(gm, example_inputs, **kwargs) @@ -855,9 +938,8 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): except Exception as exc: orig_failure = str(exc) log.warning( - "Compiled Fx GraphModule failed with following error. Starting minifier." + "Compiled Fx GraphModule failed. Creating script to minify the error." ) - log.exception(exc) dump_state_fn = functools.partial( dump_backend_state, compiler_name=compiler_name ) @@ -880,10 +962,10 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier - from torchdynamo.optimizations.backends import BACKENDS + from torch._dynamo.optimizations.backends import BACKENDS if compiler_name == "inductor": - from torchinductor.compile_fx import compile_fx + from torch._inductor.compile_fx import compile_fx compiler_fn = compile_fx else: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index bf9a230a420b8..773e62e3f9e3e 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,15 +1,17 @@ import contextlib -import copy import functools import inspect import logging import os import sys +import textwrap import threading import traceback import types import warnings +from enum import Enum from importlib import import_module +from typing import Optional, Tuple, TYPE_CHECKING, Union from unittest.mock import patch import torch @@ -17,31 +19,67 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.nn.parallel.distributed import DistributedDataParallel +from .hooks import Hooks + +if TYPE_CHECKING: + from torch._C._dynamo.eval_frame import ( # noqa: F401 + reset_code, + set_eval_frame, + set_guard_error_hook, + set_guard_fail_hook, + skip_code, + unsupported, + ) +else: + for name in dir(torch._C._dynamo.eval_frame): + if name.startswith("__"): + continue + globals()[name] = getattr(torch._C._dynamo.eval_frame, name) + from . import config, convert_frame, skipfiles, utils from .exc import ResetRequired from .mutation_guard import install_generation_tagging_init -from .optimizations.distributed import DDPOptimizer -from .utils import checkpoint_params, clone_inputs, compile_times, same +from .output_graph import CompilerFn +from .types import DynamoCallback +from .utils import compile_times log = logging.getLogger(__name__) -try: - from torch.fx.experimental import proxy_tensor -except ImportError: - proxy_tensor = None - -_eval_frame = torch._C._dynamo.eval_frame -set_eval_frame = _eval_frame.set_eval_frame -reset_code = _eval_frame.reset_code -unsupported = _eval_frame.unsupported -skip_code = _eval_frame.skip_code -set_guard_fail_hook = _eval_frame.set_guard_fail_hook -set_guard_error_hook = _eval_frame.set_guard_error_hook +from torch.fx.experimental import proxy_tensor + always_optimize_code_objects = utils.ExactWeakKeyDictionary() null_context = contextlib.nullcontext -unset = object() + +# See https://github.com/python/typing/pull/240 +class Unset(Enum): + token = 0 + + +unset = Unset.token + compile_lock = threading.RLock() -most_recent_backend = None +most_recent_backend: Optional[CompilerFn] = None + + +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. + """ + + def __init__(self, mod, dynamo_ctx): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + self.dynamo_ctx = dynamo_ctx + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def forward(self, *args, **kwargs): + return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs) def remove_from_cache(f): @@ -78,39 +116,58 @@ def innermost_fn(fn): return unaltered_fn +@contextlib.contextmanager +def enable_dynamic(enable: bool = True): + if not enable: + yield + return + with patch("torch._dynamo.config.dynamic_shapes", True), patch( + "torch._functorch.config.use_dynamic_shapes", True + ): + yield + + class _TorchDynamoContext: def __init__( self, - callback, + callback: DynamoCallback, on_enter=nothing, backend_ctx_ctor=null_context, patch_fn=nothing, first_ctx=False, + *, + dynamic=False, ): super().__init__() assert callable(callback) or callback is False or callback is None - self.callback = callback - self.prior = unset + self.callback: DynamoCallback = callback + self.prior: Union[Unset, DynamoCallback] = unset self.on_enter = on_enter self.extra_ctx_ctor = backend_ctx_ctor self.first_ctx = first_ctx + self.dynamic = dynamic patch_fn() def __enter__(self): if config.raise_on_ctx_manager_usage: raise RuntimeError( - "torchdynamo.optimize(...) is used with a context manager. " + "torch._dynamo.optimize(...) is used with a context manager. " "Please refer to https://github.com/pytorch/torchdynamo#usage-example " - "to use torchdynamo.optimize(...) as an annotation/decorator. " + "to use torch._dynamo.optimize(...) as an annotation/decorator. " ) self.on_enter() self.prior = set_eval_frame(self.callback) self.backend_ctx = self.extra_ctx_ctor() self.backend_ctx.__enter__() + self.dynamic_ctx = enable_dynamic(self.dynamic) + self.dynamic_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): + assert self.prior is not unset set_eval_frame(self.prior) self.prior = unset + # TODO: This is totally not the right way to chain contexts manually + self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb) self.backend_ctx.__exit__(exc_type, exc_val, exc_tb) def __call__(self, fn): @@ -118,67 +175,99 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - optimized_forward = self(mod.forward) - - class TorchDynamoNNModuleWrapper: - """ - A wrapper that redirects the forward call to the optimized - forward, while for rest it redirects the calls to the original - module. - """ - - def __getattr__(self, name): - return getattr(mod, name) - - def forward(self, *args, **kwargs): - return optimized_forward(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - new_mod = TorchDynamoNNModuleWrapper() + new_mod = OptimizedModule(mod, self) # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod + new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) + callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @functools.wraps(fn) def _fn(*args, **kwargs): + if ( + not isinstance(self, DisableContext) + and torch.fx._symbolic_trace.is_fx_tracing() + ): + if config.error_on_nested_fx_trace: + raise RuntimeError( + "Detected that you are using FX to symbolically trace " + "a dynamo-optimized function. This is not supported at the moment." + ) + else: + return fn(*args, **kwargs) + on_enter() prior = set_eval_frame(callback) backend_ctx = backend_ctx_ctor() backend_ctx.__enter__() + dynamic_ctx = enable_dynamic(self.dynamic) + dynamic_ctx.__enter__() try: return fn(*args, **kwargs) finally: set_eval_frame(prior) + dynamic_ctx.__exit__(None, None, None) backend_ctx.__exit__(None, None, None) # hooks to properly handle inlining if isinstance(self, DisableContext): - _fn._torchdynamo_disable = True + _fn._torchdynamo_disable = True # type: ignore[attr-defined] else: - _fn._torchdynamo_inline = fn + _fn._torchdynamo_inline = fn # type: ignore[attr-defined] # Save the function pointer to find the original callable while nesting # of decorators. - _fn._torchdynamo_orig_callable = fn + _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] - # If the function is called using torchdynamo.optimize decorator, we + # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please wrap the relevant code into a function and optimize the + wrapper function. + + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function and other code, wrap that up in a function + + >> def wrapper_fn(x): + >> y = mod(x) + >> return y.sum() + + and then optimize the wrapper_fn + + >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn) + """ + ) + ) always_optimize_code_objects[fn.__code__] = True return _fn class OptimizeContext(_TorchDynamoContext): - def __init__(self, callback, backend_ctx_ctor, first_ctx=False): + def __init__(self, callback, backend_ctx_ctor, first_ctx=False, *, dynamic=False): def on_enter(): global most_recent_backend if ( @@ -196,6 +285,7 @@ def on_enter(): backend_ctx_ctor=backend_ctx_ctor, patch_fn=TorchPatcher.patch, first_ctx=first_ctx, + dynamic=dynamic, ) @@ -209,92 +299,53 @@ def __init__(self): super().__init__(callback=None) -def catch_errors_wrapper(callback): +def catch_errors_wrapper(callback, hooks: Hooks): @functools.wraps(callback) def catch_errors(frame, cache_size): - try: - if frame.f_lasti >= 0 or skipfiles.check(frame.f_code.co_filename): - log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}") - return None - if ( - frame.f_code.co_filename == "" - and frame.f_code.co_name == "__new__" - ): - # nametuple constructor - return None - if config.optimize_ddp: - ddp_module = DistributedDataParallel._get_active_ddp_module() - if ddp_module: - with compile_lock: - ddp_optimizer = DDPOptimizer( - bucket_bytes_cap=ddp_module.bucket_bytes_cap, - parameters_to_ignore=ddp_module.parameters_to_ignore, - backend_compile_fn=callback._torchdynamo_orig_callable, - ) - hijacked_callback = convert_frame.convert_frame( - ddp_optimizer.compile_fn, guard_export_fn=None - ) - return hijacked_callback(frame, cache_size) - - with compile_lock: - return callback(frame, cache_size) - except Exception: - log.exception("Error while processing frame") - raise - - catch_errors._torchdynamo_orig_callable = callback + if ( + frame.f_lasti >= 0 + or skipfiles.check(frame.f_code.co_filename) + or config.disable + ): + log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}") + return None + if frame.f_code.co_filename == "" and frame.f_code.co_name == "__new__": + # nametuple constructor + return None + if config.optimize_ddp: + ddp_module = DistributedDataParallel._get_active_ddp_module() + if ddp_module: + with compile_lock: + from .optimizations.distributed import DDPOptimizer + + ddp_optimizer = DDPOptimizer( + bucket_bytes_cap=ddp_module.bucket_bytes_cap, + backend_compile_fn=callback._torchdynamo_orig_callable, + ) + hijacked_callback = convert_frame.convert_frame( + ddp_optimizer.compile_fn, + hooks=hooks, + ) + return hijacked_callback(frame, cache_size, hooks) + + with compile_lock: + return callback(frame, cache_size, hooks) + + catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] return catch_errors -def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context): +def _optimize_catch_errors( + compile_fn, hooks: Hooks, backend_ctx_ctor=null_context, dynamic=False +): return OptimizeContext( - catch_errors_wrapper(compile_fn), + catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, first_ctx=True, + dynamic=dynamic, ) -class WrapperBackend: - def __init__(self, backend=None): - self.backend = backend - - @property - def example_inputs(self): - return clone_inputs(self.original_example_inputs) - - def __call__(self, gm: torch.fx.GraphModule, example_inputs): - - self.restore = checkpoint_params(gm) - self.original_example_inputs = clone_inputs(example_inputs) - self.gm = gm - copy_gm = copy.deepcopy(self.gm) - self.candidate = self.backend(copy_gm, self.original_example_inputs) - - if self.candidate is None or self.candidate is self.gm.forward: - return self.gm.forward - - if not config.verify_correctness: - return self.candidate - - # if verify_correctness=True - try: - correct = self.gm.forward(*self.example_inputs) - result = self.candidate(*self.example_inputs) - - # TODO: replace `same` function with the one in testing - if same(correct, result): - return self.candidate - - raise RuntimeError(f"incorrect results of backend {self}") - return self.gm.forward - - except Exception: - log.exception("error in verify_correctness") - raise - finally: - self.restore() - - def get_compiler_fn(compiler_fn): from .debug_utils import wrap_backend_debug @@ -307,6 +358,16 @@ def get_compiler_fn(compiler_fn): def lookup_backend(compiler_fn): """Expand backend strings to functions""" if compiler_fn == "inductor": + if torch.cuda.is_available(): + if ( + torch.backends.cuda.matmul.allow_tf32 is False + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled." + "Consider setting `torch.set_float32_matmul_precision('high')`" + ) + compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx elif isinstance(compiler_fn, str): from .optimizations import BACKENDS @@ -315,14 +376,20 @@ def lookup_backend(compiler_fn): return compiler_fn -class _NullDecorator(contextlib.nullcontext): +class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] def __call__(self, fn): assert callable(fn) return fn def optimize( - backend="inductor", *, nopython=False, guard_export_fn=None, disable=False + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + guard_fail_fn=None, + disable=False, + dynamic=False, ): """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -336,17 +403,25 @@ def optimize( One can also provide additional context for the backend, like torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. See AOTAutogradMemoryEfficientFusionWithContext for the usage. - - Or, a string backend name in `torchdynamo.list_backends()` + - Or, a string backend name in `torch._dynamo.list_backends()` nopython: If True, graph breaks will be errors and there will be a single whole-program graph. disable: If True, turn this decorator into a no-op + dynamic: If True, turn on dynamic shapes support - Example Usage: + Example Usage:: - @torchdynamo.optimize() + @torch._dynamo.optimize() def toy_example(a, b): ... """ + # Note: The hooks object could be global instead of passed around, *however* that would make + # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. + # There is some prior art around this, w/r/t nesting backend calls are enforced to be the same + # compiler, however, this feels onerous for callback and hooks, and it feels better to give our users an + # easier to understand UX at the cost of a little more plumbing on our end. + hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn) + torch._C._log_api_usage_once("torch._dynamo.optimize") if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": return _NullDecorator() if sys.platform == "win32": @@ -368,14 +443,21 @@ def toy_example(a, b): backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) if nopython: - return optimize_assert(backend, guard_export_fn=guard_export_fn) + return optimize_assert( + backend, + dynamic=dynamic, + hooks=hooks, + ) return _optimize_catch_errors( - convert_frame.convert_frame(backend, guard_export_fn=guard_export_fn), + convert_frame.convert_frame(backend, hooks=hooks), + hooks, backend_ctx_ctor, + dynamic=dynamic, ) -@patch("torchdynamo.symbolic_convert.explain", True) +# TODO(voz): Consider making "explain" output alongside a run / part of a run +@patch("torch._dynamo.symbolic_convert.explain", True) def explain(f, *args, **kwargs): # TODO(voz): Do we want a decorator for this? from . import reset @@ -433,20 +515,29 @@ def guard_export_print(guards): msg = f"{break_reason.reason}\n{formatted_stack}" formatted_list += f"{idx + 1}. {msg} \n" - explanation = f"Dynamo produced {graph_count} graphs" + explanation = f"Dynamo produced {graph_count} graphs " explanation += f"with {graph_count - 1} graph break and {op_count} ops" - explanation += f"\n Break reasons: \n\n{formatted_list}" + explanation_verbose = explanation + explanation_verbose += f"\n Break reasons: \n\n{formatted_list}" - explanation += compile_times() + explanation_verbose += compile_times() # TODO(voz): Do we want a decorator for this? reset() - return explanation, out_guards, graphs, ops_per_graph, break_reasons + return ( + explanation, + out_guards, + graphs, + ops_per_graph, + break_reasons, + explanation_verbose, + ) def export( f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs ): + torch._C._log_api_usage_once("torch._dynamo.export") if decomposition_table is not None or tracing_mode != "real": assert ( aten_graph @@ -456,7 +547,7 @@ def export( graph = None out_guards = None graph_captured_input = None - graph_captured_result = None + graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None def produce_matching(source_args, candidate_args): matched_elements_positions = [] @@ -505,6 +596,7 @@ def result_capturing_wrapper(*graph_inputs): nonlocal graph_captured_input graph_captured_input = graph_inputs + assert graph is not None graph_captured_result = graph(*graph_inputs) return graph_captured_result @@ -517,7 +609,7 @@ def result_capturing_wrapper(*graph_inputs): with patch(f"{__name__}.most_recent_backend", None): opt_f = optimize_assert( dynamo_normalization_capturing_compiler, - guard_export_fn=guard_export_print, + hooks=Hooks(guard_export_fn=guard_export_print, guard_fail_fn=None), export=True, )(f) # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject. @@ -531,6 +623,7 @@ def result_capturing_wrapper(*graph_inputs): flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced) + assert graph_captured_result is not None flat_both = list(graph_captured_result) + flat_args matched_output_elements_positions = produce_matching(flat_both, flat_results_traced) @@ -550,7 +643,10 @@ def __init__( ) def placeholder(self, target, args, kwargs): - return next(self.old_args_gen) + arg = next(self.old_args_gen) + if "val" in self.current_node.meta: + arg.node.meta["val"] = self.current_node.meta["val"] + return arg def output(self, target, args, kwargs): dynamo_result_flat = args[0] @@ -560,6 +656,10 @@ def output(self, target, args, kwargs): return super().output(target, (new_result,), {}) + def run_node(self, n): + self.current_node = n + return super().run_node(n) + if aten_graph: # Running graph with interpreter is needed for propagating the stack_trace def graph_with_interpreter(*args): @@ -581,15 +681,12 @@ def graph_with_interpreter(*args): def assume_constant_result(fn): fn._dynamo_marked_constant = True - assert ( - not config.fake_tensor_propagation - ), "Constant result capture is not supported with fake tensors." return fn -def optimize_assert(backend, *, guard_export_fn=None, export=False): +def optimize_assert(backend, *, hooks=Hooks(None, None), export=False, dynamic=False): """ - The same as `torchdynamo.optimize(backend, nopython=True)` + The same as `torch._dynamo.optimize(backend, nopython=True)` """ backend = get_compiler_fn(backend) @@ -597,8 +694,10 @@ def optimize_assert(backend, *, guard_export_fn=None, export=False): backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) return _optimize_catch_errors( - convert_frame.convert_frame_assert(backend, guard_export_fn, export=export), + convert_frame.convert_frame_assert(backend, export=export), + hooks, backend_ctx_ctor, + dynamic=dynamic, ) @@ -651,8 +750,7 @@ def patch(): torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string) torch.distributions.Distribution.set_default_validate_args(False) - if proxy_tensor is not None: - proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace) + proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace) optimizers = [ opt @@ -671,6 +769,11 @@ def patch(): opt._cuda_graph_capture_health_check = disable( opt._cuda_graph_capture_health_check ) + opt.zero_grad = disable(opt.zero_grad) + + if hasattr(opt, "_init_group"): + opt._init_group = disable(opt._init_group) + # disable any currently set hooks # Note: we only want to disable the profiling hook # which is the *last* hook applied, we want to keep the no_grad hook diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 3001c8c823924..41a9f68351aa9 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -29,8 +29,8 @@ def __init__(self): super(ResetRequired, self).__init__( textwrap.dedent( """ - Must call `torchdynamo.reset()` before changing backends. Detected two calls to - `torchdynamo.optimize(...)` with a different backend compiler arguments. + Must call `torch._dynamo.reset()` before changing backends. Detected two calls to + `torch._dynamo.optimize(...)` with a different backend compiler arguments. """ ) ) @@ -40,12 +40,8 @@ class BackendCompilerFailed(TorchDynamoException): def __init__(self, backend_fn, inner_exception): self.backend_name = getattr(backend_fn, "__name__", "?") self.inner_exception = inner_exception - super().__init__( - f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}" - "\n\n" - "You can suppress this exception and fall back to eager by setting:\n" - " torchdynamo.config.raise_on_backend_error = False" - ) + msg = f"{self.backend_name} raised {type(inner_exception).__name__}: {inner_exception}" + super().__init__(msg) class Unsupported(TorchDynamoException): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 8f94714784d73..36d628a7ece5c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -5,19 +5,23 @@ import math import os import re -import textwrap import types import weakref from inspect import currentframe, getframeinfo -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from weakref import ReferenceType import numpy as np +import sympy + import torch +from torch.fx.experimental.symbolic_shapes import FloorDiv from . import config, convert_frame, mutation_guard from .eval_frame import set_guard_error_hook, set_guard_fail_hook from .exc import unimplemented +from .types import GuardedCode, GuardFail, GuardFn # noqa: F401 from .utils import ( dict_const_keys, dict_param_key_ids, @@ -57,6 +61,8 @@ class GuardSource(enum.Enum): LOCAL_NN_MODULE = 2 GLOBAL_NN_MODULE = 3 CONSTANT = 4 + RANDOM_VALUE = 5 + SHAPE_ENV = 6 def select(self, locals_, globals_): if self in (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE): @@ -65,7 +71,7 @@ def select(self, locals_, globals_): return globals_ raise NotImplementedError() - def is_nn_module(self): + def is_nn_module(self) -> bool: return self in (GuardSource.GLOBAL_NN_MODULE, GuardSource.LOCAL_NN_MODULE) def is_local(self): @@ -74,15 +80,30 @@ def is_local(self): @dataclasses.dataclass class Guard: + # The name of a Guard specifies what exactly it is the guard is guarding + # on. The meaning of the name is dependent on the create_fn; you must + # look at the use-site inside create_fn to know what name means. + # + # That being said, although you might think this is just a "name", name is + # usually an arbitrary Python expression that will be evaluated with all + # globals (and locals, if you create a LOCAL guard) to extract the Python + # object that we want to perform guard tests on. This evaluation + # typically happens in GuardBuilder.eval. In these cases, name is + # typically produced by Source.name() (not to be confused with + # GuardSource)--morally, we could have stored a Source here. + # + # Occasionally, name is not a valid Python expression; sometimes + # it is meaningless. Example create_fns that are like this include + # GRAD_MODE and SYMBOL_MATCH. name: str source: GuardSource - create_fn: Callable + create_fn: Callable[["GuardBuilder", "Guard"], None] is_volatile: bool = False # Export only. These values are written to at time of guard check_fn creation. guard_types: Optional[List[str]] = None code_list: Optional[List[str]] = None - obj_weakref: Optional[Any] = None + obj_weakref: Optional[object] = None guarded_class_weakref: Optional[type] = None def __hash__(self): @@ -90,7 +111,7 @@ def __hash__(self): def sort_key(self): return ( - self.source.value, + self.source.value if self.source else -1, len(self.name), self.name, self.create_fn.__code__.co_firstlineno, @@ -99,13 +120,38 @@ def sort_key(self): def __lt__(self, other): return self.sort_key() < other.sort_key() + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + def __str__(self): s = f""" - {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__} + {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.create_fn.__name__} {{ 'guard_types': {self.guard_types}, 'code': {self.code_list}, - 'obj_weakref': {self.obj_weakref} + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} 'guarded_class': {self.guarded_class_weakref} }} """ @@ -164,7 +210,11 @@ def strip_getattr_getitem(name): class GuardBuilder: def __init__( - self, id_ref: Callable, scope: Dict[str, Any], guarded_code, renames=True + self, + id_ref: Callable[[Type[object]], str], + scope: Optional[Dict[str, object]], + guarded_code: "CheckFunctionManager", + renames=True, ): self.id_ref = id_ref if scope: @@ -172,18 +222,34 @@ def __init__( scope = {rename_implicit(k): v for k, v in scope.items()} else: scope = dict() - self.scope = scope + self.scope: Dict[str, object] = scope self.argnames: List[str] = [] # Code is python expression strings generated for each guard self.code: List[str] = [] - self.tensor_check_names = [] - self.tensor_check_examples = [] - self.guarded_code = guarded_code - def get(self, name: str): + # Most of the time, we generate Python code in a guard to directly + # check various properties. However, tensors are a bit special; + # it is too slow to check their properties one-by-one in Python. + # Instead, there is a C++ function TensorGuards.check which takes + # all of the tensor arguments and checks them all against compile-time + # examples entirely in C++. Thus, every time we process a + # TENSOR_MATCH guard, we just add another entry to + # tensor_check_names/tensor_check_examples, saying "for this local, + # check it against this example", and it all ends up getting + # swept up into a single call to ___check_tensors. Invariant: + # len(tensor_check_names) == len(tensor_check_examples). + self.tensor_check_names: List[str] = [] + self.tensor_check_examples: List[torch.Tensor] = [] + + self.tensor_check_ids: Dict[str, int] = {} + # TODO: tf is this naming + self.guarded_code: CheckFunctionManager = guarded_code + + def get(self, name: str) -> Any: return eval(name, self.scope, CLOSURE_VARS) - def arg_ref(self, guard: Guard): + def arg_ref(self, guard: Union[str, Guard]) -> str: + name: str if isinstance(guard, str): name = guard else: @@ -208,7 +274,9 @@ def ID_MATCH(self, guard: Guard): m = re.match(r"^type\((.+)\)$", guard.name) if m: # optional optimization to produce cleaner/faster guard code - return self.TYPE_MATCH(Guard(m.group(1), guard.source, None)) + return self.TYPE_MATCH( + Guard(m.group(1), guard.source, GuardBuilder.TYPE_MATCH) + ) code = f"___check_obj_id({self.arg_ref(guard)}, {self.id_ref(self.get(guard.name))})" self._produce_guard_code(guard, [code]) @@ -266,8 +334,8 @@ def EQUALS_MATCH(self, guard: Guard): ), t.__name__ if istype(val, (torch.device, torch.dtype)): # TODO(jansel): is this slow? perhaps optimize it - code = f"str({ref}) == {str(val)!r}" - self._produce_guard_code(guard, [code]) + code = [f"str({ref}) == {str(val)!r}"] + self._produce_guard_code(guard, code) return # Special case for nan because float("nan") == float("nan") evaluates to False @@ -410,14 +478,25 @@ def GRAD_MODE(self, guard: Guard): code = "not ___is_grad_enabled()" self._produce_guard_code(guard, [code]) + # This is a bit of a crutch for export case for symbolic shape guards. + # SYMBOL_MATCH is only ever, and must only ever, be used for setting this value on + # the create_fn field for tracking guards in export. + def SYMBOL_MATCH(self, guard: Guard): + raise AssertionError("this should not actually be called") + def TENSOR_MATCH(self, guard: Guard): if guard.is_nn_module(): self.ID_MATCH(guard) else: value = self.get(guard.name) - self.tensor_check_names.append(self.arg_ref(guard)) + assert isinstance(value, torch.Tensor) + tensor_name = self.arg_ref(guard) + self.tensor_check_names.append(tensor_name) self.tensor_check_examples.append(value) + # STOP - DO NOT USE id_ref FOR TENSORS - TENSOR INVALIDATION RULES DIFFER + self.tensor_check_ids[tensor_name] = id(value) + # Note: Guard code produced for tensor_match is a little different. # We accumulate tensor names, then do a single install of `___check_tensors`. # See _guards.cpp and TensorGuard for more information. @@ -432,8 +511,16 @@ def TENSOR_MATCH(self, guard: Guard): # A util that appends guarded code, or, in the case of export, adds data onto guards def _produce_guard_code(self, guard, code_list, provided_guarded_object=None): - caller = currentframe().f_back + # WARNING: It is important that cur_frame/caller do NOT stay in + # the current frame, because they will keep things live longer + # than they should. See TestMisc.test_release_module_memory + cur_frame = currentframe() + assert cur_frame is not None + caller = cur_frame.f_back + del cur_frame + assert caller is not None func_name = getframeinfo(caller)[2] + del caller # We use func_name for export, so might as well get a nice defensive check out of it assert func_name in dir( self.__class__ @@ -464,10 +551,68 @@ def _produce_guard_code(self, guard, code_list, provided_guarded_object=None): ) +from sympy.printing.str import StrPrinter + + @dataclasses.dataclass -class GuardedCode: - code: types.CodeType - check_fn: Callable +class TensorReference(object): + """ + TensorReference objects are entirely optional. They are created to give us hints + into where the symbolic shape came from. + + ref_id: The id of the tensor + kind: A string tracking where in the tensor this value came from ("size","stride", etc) + idx: An index in the structure + + NOTE - A symbolic shape coming from tensor at id 12345's shape dim 2, would be + TensorReference(ref_id=12345, kind="size", idx=2) + """ + + ref_id: Optional[int] = None + kind: Optional[str] = None + idx: Optional[int] = None + # Note - this is untyped because of TypeError: '_SpecialForm' object does not support item assignment + # But it is a Optional[Union["sympy.Expr", int]] + expr: Optional[object] = None # Populated after association + + def __hash__(self): + return hash((self.ref_id, self.kind, self.idx)) + + +class DynamoGuardPrinter(StrPrinter): + @staticmethod + def tensor_ref_as_str(tensor_ref, id_to_name_map): + if tensor_ref.kind in ("size", "stride"): + return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()[{tensor_ref.idx}]" + return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()" + + def __init__( + self, + expr_to_tensor_ref: Dict[sympy.Symbol, Dict[TensorReference, None]], + id_to_name_map, + shape_env, + intermediary_symbols, + ): + super().__init__() + self.expr_to_tensor_ref = expr_to_tensor_ref + self.id_to_name_map = id_to_name_map + self.shape_env = shape_env + self.intermediary_symbols = intermediary_symbols + + def _print_Symbol(self, expr) -> str: + assert isinstance(expr, sympy.Symbol) + if expr == 0: + return "0" + if expr == 1: + return "1" + assert expr in (self.expr_to_tensor_ref) or (expr in self.intermediary_symbols) + refs = self.expr_to_tensor_ref[expr] + if len(refs) == 0: + return super()._print_Symbol(expr) + tensor_ref = next( + iter(refs) + ) # Any is fine here, because we install equality guards later + return DynamoGuardPrinter.tensor_ref_as_str(tensor_ref, self.id_to_name_map) # NB: Naively, you'd expect this to only be a function that produces @@ -483,13 +628,16 @@ class GuardedCode: class CheckFunctionManager: def __init__( self, + output_graph=None, guards: Optional[Set[Guard]] = None, - f_locals: Optional[Dict] = None, - f_globals: Optional[Dict] = None, + f_locals: Optional[Dict[str, object]] = None, + f_globals: Optional[Dict[str, object]] = None, + guard_fail_fn: Optional[Callable[[Tuple[str, str]], None]] = None, ): self.valid = True - self._weakrefs = [] - self._seen_ids = set() + self._weakrefs: List["ReferenceType[object]"] = [] + self._seen_ids: Set[int] = set() + self.output_graph = output_graph # Note: right overrides left def combine_scopes(left, right): @@ -509,16 +657,99 @@ def combine_scopes(left, right): if not config.guard_nn_modules and guard.is_nn_module(): continue guard.create(local_builder, global_builder) - self.check_fn = self.compile_check_fn(local_builder, global_builder) + self.check_fn = self.compile_check_fn( + local_builder, global_builder, guards, guard_fail_fn + ) self._seen_ids.clear() - def compile_check_fn(self, local_builder, global_builder): + """ + This is a complex bit of logic. The outline here is brief. For a line by line breakdown, see + the code comments below. + + The role of this function is to take the current state of symbolic shape guards, tensor ids in the + CURRENT dynamo frame, and tensor names (dynamo's frame agnostic tensor reference mechanism, see TensorCheck and + guards.cpp for more info) - and produce executable python expressions for addition to our guarded code components + that make their way into check_fn. + + We DO NOT create guards based on ids. The IDs act as a lookup for the following mapping: + + dynamo: tensor_name <> tensor_id + shape_env: tensor_id <> shape_expr + + This allows us to then create a tensor_name <> shape_expr association for the current frames guards. + """ + + def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids): + # Pre join output + finished_expressions: List[str] = [] + + # A mapping of tensor_ids to tensor names + id_to_name_map: Dict[int, str] = {} + + # We should not have a shape env, or guards if we are not in config.dynamic shapes + # But check it anyway. + if not config.dynamic_shapes: + return None + + expr_to_tensor_ref: Dict[sympy.Symbol, Dict[TensorReference, None]] = {} + guard_printer = DynamoGuardPrinter( + expr_to_tensor_ref, + id_to_name_map, + self.output_graph.shape_env, + self.output_graph.intermediary_symbols, + ) + + # tensor_check_names is the primary tensor association mechanism in dynamo. + # All other guards installations are driven off of it, so these ones will too. + for name in tensor_check_names: + tensor_id = tensor_check_ids[name] + id_to_name_map[tensor_id] = name + + if tensor_id in self.output_graph.tensor_id_to_sym_shape_ref: + # If we made it here, this tensor_id is relevant to dynamo guard installation + # AND was found in the shape_env + tensor_ref_set = self.output_graph.tensor_id_to_sym_shape_ref[tensor_id] + for tensor_ref in tensor_ref_set: + obj_expr = tensor_ref.expr + if obj_expr not in expr_to_tensor_ref: + expr_to_tensor_ref[obj_expr] = {} + expr_to_tensor_ref[obj_expr][tensor_ref] = None + + guard_expression = self.output_graph.shape_env.get_guard_expr() + expr_as_str = guard_printer.doprint(guard_expression) + # We may get into a state where symbolic shape keys (all should be found in replacements) + # Have not been removed from the expression. This is a serious enough error state that we need to assert. + for key in self.output_graph.shape_env.var_to_val.keys(): + assert str(key) not in expr_as_str, f"Unknown shape symbol {key}. " + finished_expressions.append(expr_as_str) + + for expr in expr_to_tensor_ref.keys(): + tensor_refs = expr_to_tensor_ref[expr].keys() + equality_candidates = [ + DynamoGuardPrinter.tensor_ref_as_str(x, id_to_name_map) + for x in tensor_refs + ] + + if len(equality_candidates) > 1: + equality_expr = " == ".join(equality_candidates) + finished_expressions.append(equality_expr) + + # Redundant with code_parts, but allows us to wrap it with parens nicely. + if len(finished_expressions) == 0: + return None + + expression = " and ".join(finished_expressions) + return f"({expression})" + + def compile_check_fn( + self, local_builder, global_builder, guards_out, guard_fail_fn + ): assert not (set(local_builder.argnames) & set(global_builder.argnames)) # see parallel handling of ".0" / "___implicit0" in _eval_frame.c - args = [a for a in local_builder.scope.keys() if a == "___implicit0"] - args += [a for a in local_builder.argnames if a != "___implicit0"] - args += ["**___kwargs_ignored"] - args = ",".join(args) + largs = [a for a in local_builder.scope.keys() if a == "___implicit0"] + largs += [a for a in local_builder.argnames if a != "___implicit0"] + largs += ["**___kwargs_ignored"] + args = ",".join(largs) code_parts = ( ["___guarded_code.valid"] + local_builder.code + global_builder.code @@ -531,9 +762,16 @@ def compile_check_fn(self, local_builder, global_builder): tensor_check_names = ( local_builder.tensor_check_names + global_builder.tensor_check_names ) + + tensor_check_ids = local_builder.tensor_check_ids.copy() + tensor_check_ids.update(global_builder.tensor_check_ids) + check_tensors_fn = None check_tensors_verbose_fn = None if tensor_check_names: + symbolic_shape_expression = self._parse_symbolic_shape_expressions( + tensor_check_names, tensor_check_ids + ) tensor_check_examples = ( local_builder.tensor_check_examples + global_builder.tensor_check_examples @@ -548,28 +786,49 @@ def compile_check_fn(self, local_builder, global_builder): tensor_check_names + ["tensor_check_names=tensor_check_names"] ) verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})") + if symbolic_shape_expression: + code_parts.append(symbolic_shape_expression) + verbose_code_parts.append(symbolic_shape_expression) + guards_out.add( + Guard( + name="symbolic_shape_expression", + source=GuardSource.SHAPE_ENV, + create_fn=GuardBuilder.SYMBOL_MATCH, + code_list=symbolic_shape_expression, + ) + ) - code = " and ".join(unique(code_parts)) + def direct_equality(a, b): + return a == b + + def direct_negation(a, b): + return not direct_equality(a, b) + code = " and ".join(unique(code_parts)) closure_vars = collections.OrderedDict( [ ("___guarded_code", self), ("___check_tensors", check_tensors_fn), ("___check_tensors_verbose", check_tensors_verbose_fn), ("tensor_check_names", tensor_check_names), + ("floor", math.floor), + ("ceiling", math.ceil), + ("Eq", direct_equality), + ("Ne", direct_negation), + ("Mod", sympy.Mod), + ("FloorDiv", FloorDiv), ] ) closure_vars.update(CLOSURE_VARS) - py_code = textwrap.dedent( - f""" - def ___make_guard_fn({','.join(closure_vars.keys())}): - return lambda {args}: {code} - """ - ) + py_code = f"""\ +def ___make_guard_fn({','.join(closure_vars.keys())}): + return lambda {args}: {code} +""" if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1": print("GUARDS", code) set_guard_fail_hook(guard_fail_hook) - out = dict() + out: Dict[str, Any] = dict() + # print("RUNNING PY CODE", py_code) exec(py_code, global_builder.scope, out) guard_fn = out["___make_guard_fn"](*closure_vars.values()) guard_fn.closure_vars = closure_vars @@ -577,6 +836,7 @@ def ___make_guard_fn({','.join(closure_vars.keys())}): guard_fn.code_parts = code_parts guard_fn.verbose_code_parts = verbose_code_parts guard_fn.global_scope = global_builder.scope + guard_fn.guard_fail_fn = guard_fail_fn return guard_fn def invalidate(self, ref): @@ -595,31 +855,43 @@ def id_ref(self, obj): def guard_fail_hook( - guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool -): + guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool +) -> None: """ called whenever a guard fails. """ - if not last: + if not guard_fn.guard_fail_fn and not last: return scope = {rename_implicit(k): v for k, v in f_locals.items()} scope.update(guard_fn.closure_vars) - reasons = [] + reason = None for part in guard_fn.verbose_code_parts: fail_reason = eval(part, guard_fn.global_scope, scope) # TODO(whc) hacky for now as not every 'part' in guard_fn.verbose_code_parts # is updated to return a string explaining the failure. if isinstance(fail_reason, str): - reasons.append(fail_reason) + reason = fail_reason break elif isinstance(fail_reason, bool) and not fail_reason: - reasons.append(part) + reason = part break - guard_failures[orig_code_map[code]].append(reasons) + try: + if guard_fn.guard_fail_fn is not None: + guard_fn.guard_fail_fn( + GuardFail(reason or "unknown reason", orig_code_map[code]) + ) + except Exception as e: + log.error( + "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval", + exc_info=True, + ) + + if last: + guard_failures[orig_code_map[code]].append(reason) def guard_error_hook( - guard_fn: Callable, code: types.CodeType, f_locals: Dict[str, Any], last: bool + guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool ): print( f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" diff --git a/torch/_dynamo/hooks.py b/torch/_dynamo/hooks.py new file mode 100644 index 0000000000000..6a3f64c9ccaf0 --- /dev/null +++ b/torch/_dynamo/hooks.py @@ -0,0 +1,9 @@ +import dataclasses + +from typing import Callable, Optional, Set, Tuple + + +@dataclasses.dataclass +class Hooks: + guard_export_fn: Optional[Callable[[Set["Guard"]], None]] = None + guard_fail_fn: Optional[Callable[[Tuple["GuardFail"]], None]] = None diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 0705e77a7c7d5..61000481580f1 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -2,10 +2,14 @@ import logging import os +from torch.hub import Faketqdm, tqdm + # logging level for dynamo generated graphs/bytecode/guards -CODE = 15 -logging.addLevelName(CODE, "CODE") +logging.CODE = 15 +logging.addLevelName(logging.CODE, "CODE") +# Disable progress bar by default, not in dynamo config because otherwise get a circular import +disable_progress = True # Return all loggers that torchdynamo/torchinductor is responsible for def get_loggers(): @@ -78,8 +82,26 @@ def init_logging(log_level, log_file_name=None): _step_counter = itertools.count(1) +# Update num_steps if more phases are added: Dynamo, AOT, Backend +# This is very inductor centric +# _inductor.utils.has_triton() gives a circular import error here + +if not disable_progress: + try: + import triton # noqa: F401 + + num_steps = 3 + except ImportError: + num_steps = 2 + pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0) + def get_step_logger(logger): + if not disable_progress: + pbar.update(1) + if not isinstance(pbar, Faketqdm): + pbar.set_postfix_str(f"{logger.name}") + step = next(_step_counter) def log(level, msg): diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py index b7557a82d744a..f732fb322438f 100644 --- a/torch/_dynamo/optimizations/analysis.py +++ b/torch/_dynamo/optimizations/analysis.py @@ -1,21 +1,18 @@ -import copy import functools import itertools import operator import torch + +from torch._subclasses import FakeTensorMode # noqa: F401 from torch.fx.node import map_aggregate from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._pytree import tree_map from .. import config -from ..utils import clone_inputs, fake_tensors_available - -if fake_tensors_available: - from torch._subclasses import FakeTensorMode # noqa: F401 - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor +from ..utils import deepcopy_to_fake_tensor class ShapeAliasingAndMutationProp(ShapeProp): @@ -24,10 +21,11 @@ def __init__(self, *args, **kwargs): self.input_alias_groups = set() self.storage_to_alias_group = dict() self.make_alias_group = itertools.count(1) + self.name = "ShapeAliasingAndMutation" def tensor_alias_group(self, value: torch.Tensor): """Assign a unique identifier to the storage of a given tensor""" - storage = StorageWeakRef(value.storage()) + storage = StorageWeakRef(value._typed_storage()) alias_group = self.storage_to_alias_group.get(storage) if alias_group is None: alias_group = next(self.make_alias_group) @@ -121,21 +119,27 @@ def has_mutation(gm, example_inputs, inputs_only=False): true, we only check for mutation of inputs""" # TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way. - # Clone the inputs such that intermediate tensors (not leaf tensors) with - # requires_grad to True are now converted to False to avoid Runtime Error - # like "leaf variable that requires grad is inplace modified" - example_inputs = clone_inputs(example_inputs) - if fake_tensors_available and config.fake_tensor_propagation: - with FakeTensorMode() as fake_mode: - pass - fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=fake_mode) - example_inputs = tree_map(fake_wrapper, example_inputs) - new_gm = deepcopy_to_fake_tensor(gm, fake_mode) - with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode: - ShapeAliasingAndMutationProp(new_gm).run(*example_inputs) - else: - new_gm = copy.deepcopy(gm) - example_inputs = copy.deepcopy(example_inputs) + def _wrap_to_fake_tensor(t, *, f_mode): + if isinstance(t, torch.Tensor): + # TODO: it probably doesn't matter if we're dynamic shapes or not + static_shapes_ = config.dynamic_shapes is False + return fake_mode.from_tensor( + t, static_shapes=config.dynamic_shapes is not False + ) + else: + return t + + # Our analysis pass should use dynamic shape tensor inputs + # when dynamic shapes are enabled. + # We don't actually care about the guards that are created + # on those shapes though, so just create a fresh ShapeEnv here. + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + fake_mode = FakeTensorMode(shape_env=ShapeEnv() if config.dynamic_shapes else None) + fake_wrapper = functools.partial(_wrap_to_fake_tensor, f_mode=fake_mode) + example_inputs = tree_map(fake_wrapper, example_inputs) + new_gm = deepcopy_to_fake_tensor(gm, fake_mode) + with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode: ShapeAliasingAndMutationProp(new_gm).run(*example_inputs) for node in new_gm.graph.nodes: diff --git a/torch/_dynamo/optimizations/backends.py b/torch/_dynamo/optimizations/backends.py index abcb4290e7826..90620ccd6cddf 100644 --- a/torch/_dynamo/optimizations/backends.py +++ b/torch/_dynamo/optimizations/backends.py @@ -6,15 +6,18 @@ import subprocess import tempfile +from typing import Dict + import numpy as np import torch +from ..output_graph import CompilerFn from ..utils import identity from .subgraph import SubGraph log = logging.getLogger(__name__) -BACKENDS = dict() +BACKENDS: Dict[str, CompilerFn] = dict() _NP_DTYPE = { torch.float16: np.float16, torch.float32: np.float32, @@ -53,9 +56,6 @@ def inner(model, example_inputs=None, **kwargs): return fn(model, **kwargs) except KeyboardInterrupt: raise - except Exception: - log.exception(f"{fn.__name__} error") - return None BACKENDS[fn.__name__] = inner return inner @@ -100,13 +100,13 @@ def nnc_ofi(subgraph): @create_backend -def nvfuser(subgraph): +def ts_nvfuser(subgraph): with torch.jit.fuser("fuser2"): return reload_jit_model(subgraph) @create_backend -def nvfuser_ofi(subgraph): +def ts_nvfuser_ofi(subgraph): with torch.jit.fuser("fuser2"): return reload_jit_model_ofi(subgraph) @@ -133,7 +133,7 @@ def static_runtime(subgraph): def onnxrt_common(subgraph, provider, onnx_filename=None): - import onnxruntime + import onnxruntime # type: ignore[import] assert provider in onnxruntime.get_available_providers() session = onnxruntime.InferenceSession( @@ -144,9 +144,9 @@ def onnxrt_common(subgraph, provider, onnx_filename=None): create_outputs = subgraph.empty_outputs_factory() is_cpu = subgraph.is_cpu - def _call(*args): + def _call(*initial_args): binding = session.io_binding() - args = [a.contiguous() for a in args] + args = [a.contiguous() for a in initial_args] for name, value in zip(input_names, args): dev = value.device binding.bind_input( @@ -231,7 +231,7 @@ def onnxrt(subgraph): @functools.lru_cache(None) def _init_tensorflow(): - import tensorflow as tf + import tensorflow as tf # type: ignore[import] # prevent tensorflow from eating all the GPU memory gpus = tf.config.list_physical_devices("GPU") @@ -242,8 +242,8 @@ def _init_tensorflow(): @create_backend def onnx2tf(subgraph): - import onnx - from onnx_tf.backend import prepare + import onnx # type: ignore[import] + from onnx_tf.backend import prepare # type: ignore[import] tf = _init_tensorflow() filename = subgraph.filename("tensorflow") @@ -256,8 +256,8 @@ def onnx2tf(subgraph): tf_module = tf.saved_model.load(filename) tf_module = tf.function(tf_module, jit_compile=True) - def run(*args): - args = [a.contiguous() for a in args] + def run(*i_args): + args = [a.contiguous() for a in i_args] with tf.device(device): outs = tf_module( **{ @@ -295,7 +295,7 @@ def taso(subgraph): @create_backend def ipex(subgraph, **kwargs): - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # type: ignore[import] inputs = subgraph.example_inputs model = subgraph.model @@ -324,12 +324,20 @@ def fx2trt(subgraph, **kwargs): # TensorRT fails violently with an abort() on this return None - from torch_tensorrt.fx.fx2trt import InputTensorSpec, TRTInterpreter - from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem - from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting - from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer - from torch_tensorrt.fx.trt_module import TRTModule - from torch_tensorrt.fx.utils import LowerPrecision + from torch_tensorrt.fx.fx2trt import ( # type: ignore[import] + InputTensorSpec, + TRTInterpreter, + ) + from torch_tensorrt.fx.passes.lower_basic_pass import ( # type: ignore[import] + transform_setitem, + ) + from torch_tensorrt.fx.tools.trt_splitter import ( # type: ignore[import] + TRTSplitter, + TRTSplitterSetting, + ) + from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer # type: ignore[import] + from torch_tensorrt.fx.trt_module import TRTModule # type: ignore[import] + from torch_tensorrt.fx.utils import LowerPrecision # type: ignore[import] from .normalize import normalize_ir @@ -417,7 +425,7 @@ def torch2trt(subgraph): # TensorRT fails violently with an abort() on this return None - from torch2trt import torch2trt + from torch2trt import torch2trt # type: ignore[import] inputs = subgraph.example_inputs trt_mod = torch2trt( @@ -441,45 +449,6 @@ def tensorrt(subgraph): return model -@create_backend -def onnx2tensorrt_alt(subgraph): - if subgraph.will_tensorrt_barf(): - # TensorRT fails violently with an abort() on this - return None - - import tensorrt as trt - - from torch.fx.experimental.fx2trt.trt_module import TRTModule - - inputs = subgraph.example_inputs - - logger = trt.Logger(trt.Logger.ERROR) - builder = trt.Builder(logger) - config = builder.create_builder_config() - assert isinstance(inputs, (list, tuple)) - inputs = tuple(inputs) - input_names = subgraph.input_names - output_names = subgraph.output_names - network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - ) - parser = trt.OnnxParser(network, logger) - success = parser.parse(open(subgraph.onnx_filename, "rb").read()) - for idx in range(parser.num_errors): - print(parser.get_error(idx)) - assert success - - config.max_workspace_size = 1 << 25 - config.set_flag(trt.BuilderFlag.STRICT_TYPES) - builder.max_batch_size = len(inputs[0]) - - engine = builder.build_engine(network, config) - assert engine - - trt_mod = TRTModule(engine, input_names, output_names) - return subgraph.wrap_returns(trt_mod) - - @create_backend def cudagraphs(subgraph): model = subgraph.model @@ -548,22 +517,6 @@ def run(*new_inputs): return run -@create_backend -def aot_autograd(subgraph, **kwargs): - def _wrapped_bw_compiler(*args, **kwargs): - # stop TorchDynamo from trying to compile our generated backwards pass - return disable(bw_compiler(*args, **kwargs)) - - bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] - kwargs["bw_compiler"] = _wrapped_bw_compiler - - from functorch.compile import aot_module_simplified - - from .. import disable - - return aot_module_simplified(subgraph.model, **kwargs) - - def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs): if jit_mod is None: return None @@ -631,9 +584,9 @@ def tvm_compile_inner( jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False ): try: - import tvm - from tvm import relay - from tvm.contrib import graph_executor + import tvm # type: ignore[import] + from tvm import relay # type: ignore[import] + from tvm.contrib import graph_executor # type: ignore[import] shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)] mod, params = relay.frontend.from_pytorch(jit_mod, shape_list) @@ -681,7 +634,7 @@ def tvm_compile_inner( elif tuning_option == "meta_schedule": from os import path as osp - from tvm.contrib.torch import optimize_torch + from tvm import meta_schedule as ms with tempfile.TemporaryDirectory() as work_dir: if log_file is not None: @@ -689,14 +642,28 @@ def tvm_compile_inner( log_file ), "TVM's meta_schedule requires a directory for storing log files." work_dir = log_file - - lib = optimize_torch( - jit_mod, - example_inputs, - max_trials_global=20000, + if not cuda: + # meta_schedule needs num-cores to be specified + # here we use the maximum core count + target = tvm.target.Target( + f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}" + ) + # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch + # once USE_PT_TVMDSOOP is updated and turned on by default in TVM. + database = ms.relay_integration.tune_relay( + mod=mod, + target=target, work_dir=work_dir, + max_trials_global=20000, + num_trials_per_iter=64, + params=params, + strategy="evolutionary", + ) + lib = ms.relay_integration.compile_relay( + database=database, + mod=mod, target=target, - max_trials_per_task=64, + params=params, ) elif tuning_option is None: @@ -708,41 +675,41 @@ def tvm_compile_inner( "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. " "There are three available options including None, auto_scheduler and meta_schedule." ) - if tune_option != "meta_schedule": - m = graph_executor.GraphModule(lib["default"](dev)) - - def to_torch_tensor(nd_tensor): - """A helper function to transfer a NDArray to torch.tensor.""" - if nd_tensor.dtype == "bool": - # DLPack does not support boolean so it can't be handled by - # torch.utils.dlpack.from_pack. Workaround by going through - # numpy, although this brings additional data copy overhead. - return torch.from_numpy(nd_tensor.numpy()) - return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) - - def exec_tvm(*args): - args = [a.contiguous() for a in args] - for idx, arg in enumerate(args, 0): - if arg.dim() != 0: - if arg.requires_grad: - arg = arg.detach() - m.set_input( - f"inp_{idx}", - tvm.nd.array(arg.numpy(), dev), - ) - m.run() - return [ - to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs()) - ] - - else: - - def exec_tvm(*args): - args = [a.contiguous() for a in args] - return lib(*args) + m = graph_executor.GraphModule(lib["default"](dev)) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if nd_tensor.dtype == "bool": + # DLPack does not support boolean so it can't be handled by + # torch.utils.dlpack.from_pack. Workaround by going through + # numpy, although this brings additional data copy overhead. + return torch.from_numpy(nd_tensor.numpy()) + return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack()) + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if torch_tensor.dtype == torch.bool: + # same reason as above, fallback to numpy conversion which + # could introduce data copy overhead + return tvm.nd.array(torch_tensor.cpu().numpy()) + return tvm.nd.from_dlpack(torch_tensor) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + for idx, arg in enumerate(args, 0): + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + m.set_input( + f"inp_{idx}", + to_tvm_tensor(arg), + ) + m.run() + return [ + to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs()) + ] return exec_tvm - except Exception: log.exception("tvm error") return jit_mod # explicit fall back to eager @@ -785,6 +752,29 @@ def ltc_model(*inputs): return ltc_model +@create_backend +def torchxla_trivial(subgraph): + return subgraph.model + + +@create_backend +def torchxla_trace_once(subgraph): + import torch._dynamo.optimizations.torchxla_integration as integration + + compiled_graph = None + model = subgraph.model + + def fwd(*args): + nonlocal subgraph + nonlocal compiled_graph + if compiled_graph is None: + compiled_graph = integration.extract_compiled_graph(model, args) + del subgraph + return compiled_graph(*args) + + return fwd + + def ipex_fp32(gm: torch.fx.GraphModule, example_inputs): kwargs_ipex = {"datatype": "fp32"} return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex) diff --git a/torch/_dynamo/optimizations/distributed.py b/torch/_dynamo/optimizations/distributed.py index f65c16483aec6..f48ba500be59f 100644 --- a/torch/_dynamo/optimizations/distributed.py +++ b/torch/_dynamo/optimizations/distributed.py @@ -1,9 +1,14 @@ -from typing import Any, List +import logging +from dataclasses import dataclass, field +from typing import Any, List, Optional import torch import torch.fx.traceback as fx_traceback from torch import fx from torch.fx.node import Node +from ..utils import deepcopy_to_fake_tensor, fake_mode_from_tensors + +log = logging.getLogger(__name__) def args_str(args): @@ -18,102 +23,199 @@ def args_str(args): return str(args) +@dataclass +class Bucket: + size: int = 0 + params: List[str] = field(default_factory=list) + nodes: List[fx.Node] = field(default_factory=list) + + # param_ids is just used for unit testing + param_ids: List = field(default_factory=list) + + +def pretty_print_buckets(buckets: List[Bucket]): + headers = ("Index", "Size (b)", "Param Names") + rows = [] + for idx, bucket in enumerate(reversed(buckets)): + if len(bucket.params) > 0: + rows.append((idx, bucket.size, bucket.params[0])) + for param in bucket.params[1:]: + rows.append((None, None, param)) + try: + from tabulate import tabulate + + log.info( + "\nDDPOptimizer bucket assignments\n" + + tabulate(rows, headers=headers, tablefmt="simple_grid") + ) + except ImportError: + log.info( + "Please `pip install tabulate` in order to pretty-print ddp bucket sizes" + ) + + class DDPOptimizer: + """ + DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP), + breaking the dynamo graph into chunks to compile separately, with the breaks aligning to + the boundaries of gradient-allreduce buckets chosen by DDP. + + Background/Motivation + - DDP uses allreduce collectives to synchronize partial gradients computed on different workers + - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce + - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready + at around the same time during backward and thus can share the same allreduce efficently + - Allreduces must overlap with backward compute for optimal training performance + - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which + operates when individual grads become 'ready' + - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the + autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole + fused backward function executes, preventing any overlap of compute and communication + + Algorithm + - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse + this graph in reverse order to determine the true order that gradients will become ready during backward. + - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started + and a graph break introduced + - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together + into an outer module that is returned to the user + + Notes + - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP, + and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does + in eager. + - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently + produce splits that do not necessarily align with the buckets used by DDP. This should result in performance + degradation approaching the baseline case where graph-splits are not used, but not worse. + - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the + subgraphs being compiled + - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers + left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are + also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP, + it is not catastrophic but could impact performance by choosing sub-optimal bucket splits. + - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients, + and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by + DDPOptimizer) + + Args: + bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be + set to match the equivalent parameter on the original DDP module. + + backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph. + + first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP + special-cases the first bucket size since it is sometimes optimal to start a small allreduce early. + + """ + def __init__( self, bucket_bytes_cap: int, - parameters_to_ignore: List[str], backend_compile_fn, - debug=False, + first_bucket_cap: Optional[int] = None, ): + if first_bucket_cap is not None: + self.first_bucket_cap = first_bucket_cap + elif torch.distributed.is_available(): + # this constant comes from C10D lib which is not always built + self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES + else: + self.first_bucket_cap = bucket_bytes_cap + self.bucket_bytes_cap = bucket_bytes_cap - self.parameters_to_ignore = parameters_to_ignore + assert ( + self.first_bucket_cap <= self.bucket_bytes_cap + ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP" + self.backend_compile_fn = backend_compile_fn - self.debug = debug + + def _ignore_parameter(self, parameter): + return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]): """ - TODO: - - handle params_and_buffers_to_ignore - - handle kwargs + Implements graph splitting, first determining a set of of buckets by counting + parameter sizes in reverse graph order, then invoking the user/backend compiler + to compile each subgraph. Finally, stiches compiled graphs into one graphmodule + and returns its callable. """ + fake_mode = fake_mode_from_tensors(example_inputs) + if fake_mode is None: + fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() # 1: compute the partition map according to DDP bucket logic - bucket_bytes = 0 - bucket_actual_sizes = [] - node_splits = [[]] + buckets = [Bucket()] # (size, param_names) for node in reversed(gm.graph.nodes): - if node.op == "output" or node.op == "placeholder": + if node.op in ("output", "placeholder"): continue - if bucket_bytes >= self.bucket_bytes_cap: - bucket_actual_sizes.insert(0, bucket_bytes) - bucket_bytes = 0 - node_splits.insert(0, []) + if ( + buckets[0].size >= self.bucket_bytes_cap + or len(buckets) == 1 + and buckets[0].size >= self.first_bucket_cap + ): + buckets.insert(0, Bucket()) - elif node.op == "call_module": + if node.op == "call_module": target = gm.get_submodule(node.target) - params_size_b = sum( - [ - p.storage().nbytes() - for p in target.parameters() - if p.requires_grad - ] - ) - bucket_bytes += params_size_b - # print(f"accumulated {params_size_b} b from {node}") + for name, p in target.named_parameters(): + param = target.get_parameter(name) + if p.requires_grad and not self._ignore_parameter(param): + buckets[0].size += p._storage().nbytes() + buckets[0].params.append(f"{node.target}_{name}") + buckets[0].param_ids.append(id(param)) elif node.op == "get_attr": maybe_param = getattr(gm, node.target) - if maybe_param.requires_grad: - bucket_bytes += maybe_param.storage().nbytes() - else: - # TODO(whc) confirm this: - # (e.g. call_method, call_function aren't expected to 'have' parameters) - pass - - node_splits[0].append(node) - - if len(node_splits) == 1: - if self.debug: - print( - "DDPOptimizer did not split graphs." - f" Accumulated {bucket_bytes} bytes, and bucket cap is {self.bucket_bytes_cap}" - ) - return self.backend_compile_fn(gm, example_inputs) + if maybe_param.requires_grad and not self._ignore_parameter( + maybe_param + ): + buckets[0].size += maybe_param._storage().nbytes() + buckets[0].params.append(node.target) + buckets[0].param_ids.append(id(maybe_param)) + + # All nodes have to be mapped to a bucket, even if they don't have their own params + # Ignored params still end up in buckets, we just don't count them towards the capacity + buckets[0].nodes.append(node) - if len(bucket_actual_sizes) < len(node_splits): - bucket_actual_sizes.insert(0, bucket_bytes) + # stash buckets for testing/debugging purposes + self.buckets = buckets + log.info( + f"DDPOptimizer used bucket cap {self.bucket_bytes_cap} and produced the following buckets:" + ) + pretty_print_buckets(buckets) - if self.debug: - print( - f"DDPOptimizer used bucket cap {self.bucket_bytes_cap}" - f" and split graphs into parameter sizes {', '.join([str(b) for b in bucket_actual_sizes])}" - ) + if len(buckets) == 1: + # bypass split/fuse logic if there is only one bucket + return self.backend_compile_fn(gm, example_inputs) # 2: partition the graphmodule according to bucket capacity partition_map = {} - for p, nodes in enumerate(node_splits): - for node in nodes: - partition_map[node] = p + for idx, b in enumerate(buckets): + for node in b.nodes: + partition_map[node] = idx split_gm = fx.passes.split_module.split_module( gm, None, lambda node: partition_map[node] ) - if self.debug: - with open("debug_ddp_optimizer.log", "w") as dump_file: - dump_file.write("---orig graph---") - dump_file.write(str(gm.graph)) - dump_file.write("\n---split graph---") - dump_file.write(str(split_gm.graph)) + + debug_str = ( + f"\n---orig graph---\n{gm.graph}\n" + + f"\n---split graph---\n{split_gm.graph}\n" + ) + for name, module in split_gm.named_modules(): + if "." not in name and len(name): + # only print the submod graphs, not their children + debug_str += f"\n---{name} graph---\n{module.graph}\n" + debug_str += "\n---------------\n" + log.debug(debug_str) # 3: compile each of the partitioned submodules using the user-provided compiler class SubmodCompiler(torch.fx.interpreter.Interpreter): - def __init__(self, module, compiler, debug=False): + def __init__(self, module, compiler): super().__init__(module) self.compiler = compiler - self.debug = debug - def compile_submod(self, submod, args, kwargs): + def compile_submod(self, input_mod, args, kwargs): """ Compile the submodule, using a wrapper to make sure its output is always a tuple, @@ -122,13 +224,13 @@ def compile_submod(self, submod, args, kwargs): assert len(kwargs) == 0, "We assume only args for these modules" class WrapperModule(torch.nn.Module): - def __init__(self, compiled_submod, unwrap_singleton_tuple): + def __init__(self, submod, unwrap_singleton_tuple): super().__init__() - self.compiled_submod = compiled_submod + self.submod = submod self.unwrap_singleton_tuple = unwrap_singleton_tuple def forward(self, *args): - x = self.compiled_submod(*args) + x = self.submod(*args) # TODO(whc) # for some reason the isinstance check is necessary if I split one node per submod # - even though I supposedly wrapped the output in a tuple in those cases, the real @@ -138,50 +240,81 @@ def forward(self, *args): return x unwrap_singleton_tuple = False - for sn in submod.graph.nodes: + for sn in input_mod.graph.nodes: if sn.op == "output": if not isinstance(sn.args[0], tuple): unwrap_singleton_tuple = True sn.args = (sn.args,) - submod.recompile() + input_mod.recompile() wrapper = WrapperModule( - self.compiler(submod, args), + self.compiler(input_mod, args), unwrap_singleton_tuple, ) return wrapper + # Note: + # + # The way distributed works today around fake tensors can be somehwat confusing. + # Some of these codepaths are shared in both runtime, and compile time. The presence + # of a fake_mode, read off of fake tensor inputs, dictates how we will operate. + # + # A few things to keep in mind: + # + # 1) We invoke `compile_submod` with a real module. The output of that gets stored + # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`. + # + # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the + # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it. + # + # 3) Fake tensors should always be around during compile time. + # + # 4) Fake tensors should never be around at runtime. + # + # 5) We end up with a compilation mode that takes a real submodule and fake tensors, + # to match what aot_autograd exepcts. See Note: [Fake Modules and AOTAutograd] def run_node(self, n: Node) -> Any: with fx_traceback.append_stack_trace(n.stack_trace): args, kwargs = self.fetch_args_kwargs_from_env(n) - if self.debug: - print(f"run_node {n.op}, {n.target} got args {args_str(args)}") + new_args = [] + assert fake_mode + for arg in args: + if isinstance(arg, torch.Tensor) and not isinstance( + arg, torch._subclasses.FakeTensor + ): + new_args.append(fake_mode.from_tensor(arg)) + else: + new_args.append(arg) + + log.debug(f"run_node {n.op}, {n.target} got args {args_str(args)}") assert isinstance(args, tuple) assert isinstance(kwargs, dict) # modify the currently running FX graph # maybe this isn't sound in general, but only changing the target of a node might be ok? if n.op == "call_module": - submod = self.fetch_attr(n.target) - if self.debug: - with open("debug_ddp_optimizer.log", "a") as dump_file: - dump_file.write(f"\n---{n.target} graph---") - dump_file.write(str(submod.graph)) - compiled_submod = self.compile_submod(submod, args, kwargs) + real_mod = self.fetch_attr(n.target) + if fake_mode: + curr_submod = deepcopy_to_fake_tensor(real_mod, fake_mode) + else: + curr_submod = real_mod + + log.debug( + f"\n---{n.target} graph---\n" + str(curr_submod.graph) + ) + compiled_submod_real = self.compile_submod( + real_mod, new_args, kwargs + ) self.module.delete_submodule(n.target) n.target = "compiled_" + n.target - self.module.add_submodule(n.target, compiled_submod) - + self.module.add_submodule(n.target, compiled_submod_real) + return curr_submod(*new_args, **kwargs) # then we execute the modified node using the usual logic - return getattr(self, n.op)(n.target, args, kwargs) + return getattr(self, n.op)(n.target, new_args, kwargs) - submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, self.debug) + submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn) submod_compiler.run(*example_inputs) split_gm.recompile() - if self.debug: - with open("debug_ddp_optimizer.log", "a") as dump_file: - dump_file.write("\n---final graph---") - dump_file.write(str(split_gm.graph)) - + log.debug("\n---final graph---\n" + str(split_gm.graph) + "\n---------------\n") return split_gm diff --git a/torch/_dynamo/optimizations/log_args.py b/torch/_dynamo/optimizations/log_args.py index caa0a9a83ce66..111da69d4a8fe 100644 --- a/torch/_dynamo/optimizations/log_args.py +++ b/torch/_dynamo/optimizations/log_args.py @@ -34,7 +34,6 @@ def run(self, *args): def run_node(self, n: torch.fx.Node): result = super().run_node(n) - if n.op == "call_function": if n.target == aten.convolution.default: args, kwargs = self.fetch_args_kwargs_from_env(n) @@ -67,8 +66,8 @@ def run_node(self, n: torch.fx.Node): def conv_args_analysis(gm: torch.fx.GraphModule, example_inputs): - # lowering graph - gm = make_fx(gm)(*example_inputs) - # use Interpreter to logs the args of conv - ConvArgsAnalysis(gm).run(*example_inputs) - return gm + def conv_arg_inner(*args): + fx_g = make_fx(gm)(*args) + return ConvArgsAnalysis(fx_g).run(*args) + + return conv_arg_inner diff --git a/torch/_dynamo/optimizations/torchxla_integration.py b/torch/_dynamo/optimizations/torchxla_integration.py new file mode 100644 index 0000000000000..f93e4d385ad82 --- /dev/null +++ b/torch/_dynamo/optimizations/torchxla_integration.py @@ -0,0 +1,189 @@ +import dataclasses + +import functools +import itertools +import os +import time +from typing import Any, Dict, List + +import torch + +debug = os.environ.get("debug_extract_compiled_graph") == "1" + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_xla_values list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: Dict[int, int] + graph_input_tensor_ids: List[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_xla_values: List[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_xla_value in zip( + self.graph_input_tensor_ids, self.graph_input_xla_values + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_xla_value + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +@functools.lru_cache(None) +def import_torchxla(): + """ + CI will run test_circular_dependencies in test/test_testing.py + which tries to import all modules found. + Enclosing the imports in a function so CI that does not have torch_xla + installed will not break. + """ + global torch_xla, xm, metrics + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as metrics + + +def is_xla_tensor(tensor: torch.Tensor) -> bool: + return tensor.device.type == "xla" + + +def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): + import_torchxla() + + assert all( + map( + is_xla_tensor, + filter( + lambda x: isinstance(x, torch.Tensor), + itertools.chain(xla_model.parameters(), xla_args), + ), + ) + ), "All tensors should be on xla" + + # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids + xm.mark_step() + args_tensor_ids = [ + torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args + ] + + if debug: + print(f"args_tensor_ids {args_tensor_ids}") + + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + xla_out = xla_model(*xla_args) + + fallback_ops = get_fallback_ops() + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(xla_out, (tuple, list)): + xla_out = (xla_out,) + + # If a arg is being in place updated by model, we need to include arg as part of the graph result. + xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( + xla_args + ) + xla_args_need_update = [] + arg_index_to_need_update_index = {} + for i, need_update in enumerate(xla_args_need_update_bool): + if need_update: + arg_index_to_need_update_index[i] = len(xla_args_need_update) + xla_args_need_update.append(xla_args[i]) + + args_and_out = tuple(xla_args_need_update) + tuple(xla_out) + + if debug: + print(f"XLA IR Text: {torch_xla._XLAC._get_xla_tensors_text(args_and_out)}") + print(f"XLA IR HLO: {torch_xla._XLAC._get_xla_tensors_hlo(args_and_out)}") + + # calculate graph hash + graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out) + if debug: + print("graph_hash", graph_hash) + + ( + graph_input_tensor_ids, + graph_input_xla_values, + ) = torch_xla._XLAC._get_tensors_xla_device_data_node(args_and_out) + if debug: + print(f"graph_input_tensor_ids {graph_input_tensor_ids}") + assert len(graph_input_tensor_ids) == len( + graph_input_xla_values + ), f"{len(graph_input_tensor_ids)} v.s. {len(graph_input_xla_values)}" + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_xla_values + ) + + # compiles+runs graph rooted at tensors in 'args_and_out' + torch_xla._XLAC._xla_sync_multi(args_and_out, []) + torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) + + # input all cpu tensors + def optimized_mod(*args): + torch_xla._XLAC._xla_sync_multi(args, []) + enter_ts = time.time() + if len(args_and_out) == 0: + return () + + assert len(args) > 0 # can not handle no args case for now + graph_input = graph_input_matcher(args) + start_ts = time.time() + res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input) + if debug: + print( + f"torchxla reuse compiled graph run_cached_graph takes {time.time() - start_ts} seconds" + ) + + args_inplace_update_ts = time.time() + assert len(res) == len(args_and_out) + ncopy = 0 + + for arg_index, res_index in arg_index_to_need_update_index.items(): + args[arg_index].copy_(res[res_index]) + + if debug: + print( + f"Copy {ncopy} args takes {time.time() - args_inplace_update_ts} seconds" + ) + + # First few elements might be xla_args that needs to be in place updated + result = res[len(xla_args_need_update) :] + if debug: + print(f"optimized_mod takes {time.time() - enter_ts} seconds overall") + + xm.mark_step() + return result + + return optimized_mod diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py index bec450bd37430..7be7c1723b1b4 100644 --- a/torch/_dynamo/optimizations/training.py +++ b/torch/_dynamo/optimizations/training.py @@ -6,177 +6,132 @@ from importlib import import_module from typing import Set +from functorch.compile import ( + aot_module_simplified, + min_cut_rematerialization_partition, + nop, + ts_compile, +) + import torch + +from torch._functorch.compilers import debug_nop from torch.fx import GraphModule from torch.fx.passes.backends.cudagraphs import partition_cudagraphs from torch.multiprocessing.reductions import StorageWeakRef from torch.nn import Module from torch.utils._pytree import tree_map -from .. import config -from ..debug_utils import wrap_compiler_debug -from ..utils import clone_inputs, count_calls, counters -from .analysis import has_mutation +from .. import config, eval_frame +from ..utils import clone_inputs, counters from .backends import BACKENDS from .normalize import normalize_ir log = logging.getLogger(__name__) -def is_aot_autograd_safe_to_run(gm, example_inputs): - """ - There are some known issues with Aot Autograd. This is a workaround to catch - such cases, and fallback to eager. We should fix these quickly. - - Issues - 1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147 - 2) LSTM - https://github.com/pytorch/functorch/issues/586 - 3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301 - """ - - def raise_or_warn(reason): - msg = f"Unable to use Aot Autograd because of presence of {reason}" - if config.raise_on_unsafe_aot_autograd: - raise NotImplementedError(msg) - else: - log.warning(msg) - return False - - import functorch.compile - - # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586 - for submod in gm.modules(): - if submod.__class__.__name__ == "LSTM": - return raise_or_warn("LSTM") - - # 2) Mutation in the graph - mutated = False - try: - if functorch.compile.config.use_functionalize: - # There are two problematic classes we still exclude for now with - # functionalization: - # - data mutation of inputs (fixed when we stop recording the - # copy_ directly into the graph) - # - metadata mutation of inputs (fixed if we do an extra partition - # to avoid AotAutograd on the mutated inputs, or if we some how - # get custom autograd function to reflect metadata changes to the - # original tensor) - mutated = has_mutation(gm, example_inputs, inputs_only=True) - else: - mutated = has_mutation(gm, example_inputs) - except NotImplementedError as e: - if "SparseTensorImpl" not in str(e): - # TODO - TorchDynamo mutation analysis cannot handle sparse tensors. - # So, there is a chance that we could call Aot Autograd when it is - # unsafe. - # The exception is fairly guarded with string check, so any other - # mutation analysis bugs will raise exceptions and will be caught. - raise e - pass - - if mutated: - return raise_or_warn("mutation") - - return True - - -class AotAutogradStrategy(object): - """Base class for backend strategies that use AOT Autograd""" - - @classmethod - def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs): - if count_calls(gm.graph) < 2: - return gm # no point for tiny graphs - return cls(gm, example_inputs).verified_candidate() - - def __init__(self, gm: torch.fx.GraphModule, example_inputs): +def aot_autograd(**kwargs): + def compiler_fn(gm: torch.fx.GraphModule, example_inputs): import functorch.compile + # Hack to get around circular import problems with aot_inductor_debug + if callable(kwargs.get("decompositions")): + kwargs["decompositions"] = kwargs["decompositions"]() + + # TODO: stop monkeypatching here (without even cleaning up, UGH!) functorch.compile.config.use_functionalize = True functorch.compile.config.use_fake_tensor = True - super(AotAutogradStrategy, self).__init__() counters["aot_autograd"]["total"] += 1 - self.use_fallback = False - self.original_example_inputs = example_inputs - self.gm = gm + use_fallback = False if not functorch.compile.config.use_functionalize and config.normalize_ir: try: - self.gm = normalize_ir(gm, self.example_inputs) + gm = normalize_ir(gm, clone_inputs(example_inputs)) except Exception: log.debug("TorchDynamo unable to remove mutation") - self.use_fallback = True - pass + use_fallback = True + # NB: no clone here on example inputs if not is_aot_autograd_safe_to_run(gm, example_inputs): - self.use_fallback = True + use_fallback = True - @property - def example_inputs(self): - return clone_inputs(self.original_example_inputs) - - def verified_candidate(self): - if self.use_fallback: + if use_fallback: log.debug("Unable to use AOT Autograd because graph has mutation") counters["aot_autograd"]["not_ok"] += 1 - return self.gm - cg = self.candidate() - if cg is None: - counters["aot_autograd"]["not_ok"] += 1 - raise RuntimeError("AOT Autograd failed to compile") - counters["aot_autograd"]["ok"] += 1 - return cg - - def candidate(self): - raise NotImplementedError() - + return gm -class AotNop(AotAutogradStrategy): - """Useful for debugging purpose""" + # OK attempt to compile - def candidate(self): - from functorch.compile import nop + def _wrapped_bw_compiler(*args, **kwargs): + # stop TorchDynamo from trying to compile our generated backwards pass + return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs)) - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, fw_compiler=nop) + bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] + kwargs["bw_compiler"] = _wrapped_bw_compiler + try: + # NB: NOT cloned! + cg = aot_module_simplified(gm, example_inputs, **kwargs) + counters["aot_autograd"]["ok"] += 1 + return eval_frame.disable(cg) + except Exception: + counters["aot_autograd"]["not_ok"] += 1 + raise -aot_eager = AotNop.compile_fn + return compiler_fn -class AotTorchscript(AotAutogradStrategy): - """ - AOT Autograd with torchscript backend. Default partitioner. +def is_aot_autograd_safe_to_run(gm, example_inputs): """ + There are some known issues with Aot Autograd. This is a workaround to catch + such cases, and fallback to eager. We should fix these quickly. - def candidate(self): - from functorch.compile import ts_compile - - return BACKENDS["aot_autograd"]( - self.gm, self.example_inputs, fw_compiler=ts_compile - ) + Issues + 1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147 + 2) LSTM - https://github.com/pytorch/functorch/issues/586 + 3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301 + """ + def raise_or_warn(reason): + msg = f"Unable to use Aot Autograd because of presence of {reason}" + if config.raise_on_unsafe_aot_autograd: + raise NotImplementedError(msg) + else: + log.warning(msg) + return False -aot_ts = AotTorchscript.compile_fn + # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586 + for submod in gm.modules(): + if submod.__class__.__name__ == "LSTM": + return raise_or_warn("LSTM") -# Global counter to differentiate between different graphs. -graph_idx = 0 + # 2) Mutation in the graphs are now always handled by AOT Autograd. + return True -class AotPrint(AotNop): - """Saves all the gm models so that we can run them separately""" +DEBUG = False - def candidate(self): - global graph_idx - module_idx = "module_" + str(graph_idx) - self.gm.to_folder(module_idx, "Bar") - for idx, x in enumerate(self.example_inputs): - torch.save(x, module_idx + "_tensor" + str(idx) + ".pt") - graph_idx += 1 - return super(AotPrint, self).candidate() +# Useful for debugging purpose +aot_eager = aot_autograd(fw_compiler=debug_nop if DEBUG else nop) +# AOT Autograd with torchscript backend. Default partitioner. +aot_ts = aot_autograd(fw_compiler=ts_compile) -aot_print = AotPrint.compile_fn +# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs +# inductor problems. +aot_inductor_debug = aot_autograd( + # these are taken from memory_efficient_fusion() + fw_compiler=nop, + bw_compiler=nop, + # NB: lambda here is to delay import of inductor + decompositions=lambda: import_module( + f"{config.inductor_import}.compile_fx" + ).select_decomp_table(), + partition_fn=functools.partial( + min_cut_rematerialization_partition, compiler="inductor" + ), +) def mem_efficient_fusion_kwargs(use_decomps): @@ -199,164 +154,87 @@ def mem_efficient_fusion_kwargs(use_decomps): return kwargs -class AotMemEfficientFusion(AotAutogradStrategy): - """Use Min cut rematerilization and NVFuser with AOT Autograd""" - - def candidate(self): - kwargs = mem_efficient_fusion_kwargs(use_decomps=True) - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) - - -class AotMemEfficientFusionNoDecomps(AotAutogradStrategy): - """Use Min cut rematerilization and NVFuser with AOT Autograd""" - - def candidate(self): - kwargs = mem_efficient_fusion_kwargs(use_decomps=False) - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) - - -class AotInductorDebug(AotAutogradStrategy): - """ - Uses TorchInductor Aot Autograd decopms and partitioner to isolate aot vs - inductor problems. - """ - - def candidate(self): - from functorch.compile import min_cut_rematerialization_partition, nop - - decompositions = import_module( - f"{config.inductor_import}.compile_fx" - ).select_decomp_table() - - kwargs = { - # these are taken from memory_efficient_fusion() - "fw_compiler": nop, - "bw_compiler": nop, - "decompositions": decompositions, - "partition_fn": functools.partial( - min_cut_rematerialization_partition, compiler="inductor" - ), - } - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs) - - -aot_inductor_debug = AotInductorDebug.compile_fn - - -class AOTMemEfficientFusionWithContext: - """Pass nvfuser context to TorchDynamo""" +# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd +aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True)) +aot_mem_efficient_fusion_no_decomp = aot_autograd( + **mem_efficient_fusion_kwargs(use_decomps=False) +) - def __init__(self, use_decomps=True): - self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") - self.use_decomps = use_decomps +# Pass TorchScript+nvFuser context to TorchDynamo +aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") +aot_mem_efficient_fusion_no_decomp.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") - def __call__(self, gm: torch.fx.GraphModule, example_inputs): - if self.use_decomps: - return AotMemEfficientFusion.compile_fn(gm, example_inputs) - else: - return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs) +def prims_executor(gm, inputs, *, executor): + from functorch.compile import make_boxed_func -aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True) -aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False) + # This function is called once per forward/backward pass of a graph in AOT + # Autograd. We use it to set up the nvFuser-specific FX graph and return + # execute function. + from torch._prims.context import TorchRefsNvfuserCapabilityMode + from torch._prims.executor import execute + from torch.fx.experimental.proxy_tensor import make_fx + # AOT Autograd might not use the partitioner, so we need to make sure that + # the graph is transformed to use nvFuser-compatible nodes. + if not getattr(gm, "_nvprim_transformed", False): + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(gm)(*inputs) -class AotPrimsNvfuser(AotAutogradStrategy): - """ - Use FX graph partitioner + Aten2Prims ref + trace executor + nvFuser - """ + # Then we return a callable that executes the "gm" graph + return make_boxed_func(partial(execute, gm, executor=executor)) - def __init__(self, gm: torch.fx.GraphModule, example_inputs): - super(AotPrimsNvfuser, self).__init__(gm, example_inputs) - - from functorch.compile import min_cut_rematerialization_partition - - from torch.fx.passes.backends.nvfuser import NvFuserBackend - - self.nvfuser = NvFuserBackend() - self.min_cut_rematerialization_partition = min_cut_rematerialization_partition - self.populate_aten2aten_decomps() - - def populate_aten2aten_decomps(self): - from torch._decomp import get_decompositions - - aten = torch.ops.aten - default_decompositions = { - aten.detach, - aten.gelu_backward, - aten.leaky_relu_backward, - aten.sigmoid_backward, - aten.threshold_backward, - aten.hardtanh_backward, - aten.hardsigmoid_backward, - aten.hardswish_backward, - aten.tanh_backward, - aten.silu_backward, - aten.elu_backward, - aten.cudnn_batch_norm, - aten.cudnn_batch_norm_backward, - aten.masked_fill.Scalar, - aten.masked_fill.Tensor, - aten.elu, - aten.leaky_relu, - aten.hardtanh, - aten.hardswish, - aten.hardsigmoid, - aten.rsub, - aten.native_batch_norm_backward, - } - - self.aten2aten_decompositions = get_decompositions(default_decompositions) - - def candidate(self): - return BACKENDS["aot_autograd"]( - self.gm, - self.example_inputs, - fw_compiler=wrap_compiler_debug(self.nvfuser, "nvfuser"), - partition_fn=self.min_cut_rematerialization_partition, - decompositions=self.aten2aten_decompositions, - ) - - -aot_prims_nvfuser = AotPrimsNvfuser.compile_fn +def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs): + # This function is called once per forward+backward pass of a graph in AOT + # Autograd. We use it to set up the nvFuser-specific FX graph that is later + # passed to the executor. + from functorch.compile import min_cut_rematerialization_partition -def prims_executor(gm, inputs, *, executor): - # This function is called once per forward/backward pass of a graph in AOT - # Autograd. We use it to set up the nvFuser-specific FX graph and return - # execute function. from torch._prims.context import TorchRefsNvfuserCapabilityMode - from torch._prims.executor import execute from torch.fx.experimental.proxy_tensor import make_fx + # AOT Autograd expects arguments of the traced function to be named exactly + # "primals, tangents" + def func(primals, tangents): + return joint_module(primals, tangents) + # First we trace the graph conditionally decomposing nodes # that can be sent to the nvfuser executor with TorchRefsNvfuserCapabilityMode(): - prim_gm = make_fx(gm)(*inputs) + prim_gm = make_fx(func)(*joint_inputs) + + # all nvprims for now + recomputable_ops = { + getattr(torch.ops.nvprims, prim) + for prim in dir(torch.ops.nvprims) + if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket) + and getattr(torch.ops.nvprims, prim).is_recomputable + } - # Then we return a callable that executes the "prim_gm" graph - return partial(execute, prim_gm, executor=executor) + fw_gm, bw_gm = min_cut_rematerialization_partition( + prim_gm, + joint_inputs, + recomputable_ops=recomputable_ops, + num_fwd_outputs=num_fwd_outputs, + ) + # AOT Autograd might not use the partitioner, so we need to make sure that + # the graph is marked as already transformed to use nvFuser-compatible nodes + fw_gm._nvprim_transformed = True + bw_gm._nvprim_transformed = True + return fw_gm, bw_gm def create_nvprims_backend(*, executor): - class NvPrims(AotAutogradStrategy): - def __init__(self, gm: torch.fx.GraphModule, example_inputs): - super(NvPrims, self).__init__(gm, example_inputs) - self.executor = executor - - def candidate(self): - return BACKENDS["aot_autograd"]( - self.gm, - self.example_inputs, - fw_compiler=partial(prims_executor, executor=self.executor), - bw_compiler=partial(prims_executor, executor=self.executor), - ) - - return NvPrims + return aot_autograd( + fw_compiler=partial(prims_executor, executor=executor), + bw_compiler=partial(prims_executor, executor=executor), + partition_fn=nvprims_fw_bw_partition_fn, + ) -aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn -aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn +aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser") +aot_nvprims_aten = create_nvprims_backend(executor="aten") def cloner(t): @@ -435,7 +313,7 @@ def meta_fk(meta): mutated_inputs = set() for n in g.nodes: if n.op == "placeholder": - inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx) + inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) input_idx += 1 elif n.op == "call_function": if n.target is operator.getitem: @@ -456,7 +334,7 @@ def meta_fk(meta): # TODO: not correct for args that contain tensors in a struct # like list mutated_inputs |= inputs[ - StorageWeakRef(meta_fk(argument.meta).storage()) + StorageWeakRef(meta_fk(argument.meta)._typed_storage()) ] # TODO: error on unrecognized nodes return mutated_inputs @@ -480,33 +358,7 @@ def cudagraphs(model, inputs): return model -def raw_aot_autograd_cudagraphs(model, inputs): - kwargs = { - # these are taken from memory_efficient_fusion() - "fw_compiler": cudagraphs, - "bw_compiler": cudagraphs, - } - - def _wrapped_bw_compiler(*args, **kwargs): - # stop TorchDynamo from trying to compile our generated backwards pass - return disable(bw_compiler(*args, **kwargs)) # type: ignore[operator] - - bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"] - kwargs["bw_compiler"] = _wrapped_bw_compiler - - from functorch.compile import aot_module_simplified # type: ignore[import] - - from .. import disable - - return aot_module_simplified(model, **kwargs) - - -class AotAutogradCudaGraphs(AotAutogradStrategy): - def candidate(self): - return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs) - - -aot_cudagraphs = AotAutogradCudaGraphs.compile_fn +aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs) def create_aot_backends(): @@ -516,36 +368,26 @@ def create_aot_backends(): # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. BACKENDS["aot_eager"] = aot_eager - # aot_eager uses AOT Autograd backend with print compiler. It prints the - # graphs and also saves the graph modules that are sent to AOT Autograd. - # This is helpful for debugging. - BACKENDS["aot_print"] = aot_print - # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser # by using the relevant fuser with torch.jit.fuser(...) BACKENDS["aot_ts"] = aot_ts - # prims_nvfuser uses the prims and AOT-Autograd to get FX-aten IR. And then - # directly lowers to NVFuser without relying no Torchscript. - BACKENDS["prims_nvfuser"] = aot_prims_nvfuser - # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch. BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser # This is useful for debugging. Can be removed later. BACKENDS["nvprims_aten"] = aot_nvprims_aten - # aot_nvfuser uses the memory efficient fusion algorithm from AOT Autograd. - # It uses min cut rematerialization algorithm, and uses nvfuser as the - # compiler backend. This is the most optimized setting with nvfuser for - # training. - BACKENDS["aot_nvfuser"] = aot_mem_efficient_fusion + # aot_ts_nvfuser uses the memory efficient fusion algorithm from AOT Autograd. + # It uses min cut rematerialization algorithm, uses nvFuser as the + # compiler backend, and TorchScript as the frontend. + BACKENDS["aot_ts_nvfuser"] = aot_mem_efficient_fusion - # Similar to aot_nvfuser, but disables the decompositions. Decompositions + # Similar to aot_ts_nvfuser, but disables the decompositions. Decompositions # can cause accuracy deviations. This setting allows us to compare accuracy # without worrying about the impact of decomposisitons. More details at # https://github.com/pytorch/torchdynamo/issues/611 - BACKENDS["aot_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp + BACKENDS["aot_ts_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful # for debugging and can serve as a perf baseline. diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index f9b75b782aa00..7c3a1782b0f86 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1,4 +1,5 @@ import collections +import copy import functools import itertools import logging @@ -6,29 +7,49 @@ import re import traceback from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + OrderedDict, + Set, + Tuple, + Union, +) + +import sympy +from typing_extensions import Protocol import torch.nn from torch import fx +from torch.fx.experimental.symbolic_shapes import ShapeEnv from . import config, logging as torchdynamo_logging, variables from .bytecode_transformation import create_instruction, Instruction, unique_id from .codegen import PyCodegen from .exc import BackendCompilerFailed, unimplemented -from .guards import GuardBuilder +from .guards import Guard, GuardBuilder, TensorReference from .mutation_guard import is_dynamic_nn_module from .side_effects import SideEffects from .source import ConstantSource, LocalSource, Source from .utils import ( + assert_no_fake_params_or_buffers, + checkpoint_params, CleanupHook, + clone_inputs, count_calls, counters, - fake_tensors_available, format_graph_tabular, + same, ) -from .variables.builder import VariableBuilder +from .variables.base import VariableTracker +from .variables.builder import GraphArg, VariableBuilder, wrap_fx_proxy from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, @@ -37,6 +58,38 @@ log = logging.getLogger(__name__) +# TODO: I think this accepts int arguments too +class CompiledFn(Protocol): + def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: + ... + + +CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn] + + +class OutputGraphState(NamedTuple): + graphargs: List[GraphArg] + guards: Set[Guard] + nn_modules: Optional[Dict[str, torch.nn.Module]] + side_effects: SideEffects + timestamp: int + name_to_input: OrderedDict[str, Optional[fx.Proxy]] + + def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]: + for k in self._fields: + if k == "side_effects": + r = self.side_effects.diff(other.side_effects) + if r is not None: + return r + continue + + sv = getattr(self, k) + ov = getattr(other, k) + if sv != ov: + return f"{prefix}{k} mismatch: {sv} != {ov}" + return None + + @functools.lru_cache(None) def _step_logger(): return torchdynamo_logging.get_step_logger(log) @@ -60,7 +113,7 @@ def _gen_rand_values(): class FakeRootModule(torch.nn.Module): """Trick the constructor of fx.GraphModule""" - def __init__(self, nn_modules: dict): + def __init__(self, nn_modules: Dict[str, torch.nn.Module]): super(FakeRootModule, self).__init__() for k, v in nn_modules.items(): setattr(self, k, v) @@ -69,6 +122,47 @@ def __repr__(self): return "FakeRootModule(...)" +class WrapperBackend: + def __init__(self, backend: CompilerFn, original_example_inputs): + self.backend: CompilerFn = backend + self.original_example_inputs = original_example_inputs + + @property + def example_inputs(self): + return clone_inputs(self.original_example_inputs) + + def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + + self.restore = checkpoint_params(gm) + self.gm = gm + copy_gm = copy.deepcopy(self.gm) + self.candidate = self.backend(copy_gm, self.original_example_inputs) + + if self.candidate is None or self.candidate is self.gm.forward: + return self.gm.forward + + if not config.verify_correctness: + return self.candidate + + # if verify_correctness=True + try: + correct = self.gm.forward(*self.example_inputs) + result = self.candidate(*self.example_inputs) + + # TODO: replace `same` function with the one in testing + if same(correct, result): + return self.candidate + + raise RuntimeError(f"incorrect results of backend {self}") + return self.gm.forward + + except Exception: + log.exception("error in verify_correctness") + raise + finally: + self.restore() + + class OutputGraph(fx.Tracer): """ Wrapper class to hold outputs of InstructionTranslator. Mainly the @@ -79,31 +173,47 @@ def __init__( self, f_globals: Dict[str, Any], code_options: Dict[str, Any], - compiler_fn: Callable, + compiler_fn: CompilerFn, root_tx, ): super(OutputGraph, self).__init__() - # Mutable state checkpointed by copy_graphstate() self.graph = torch.fx.Graph() - self.graphargs = [] - self.guards = set() - self.nn_modules = dict() + self.graphargs: List[GraphArg] = [] + self.guards: Set[Guard] = set() + self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict() self.side_effects = SideEffects() self.code_options = dict(code_options) - self.output_instructions = [] - # Node => computed real value (see TensorVariable.get_real_value) - self.real_value_cache = {} + self.output_instructions: List[Instruction] = [] + # used to track nodes that are added between calls of copy_graphstate + # and restore_graphstate + self.timestamp = 0 + # Node => computed real value (see utils.get_real_value) + self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} # Not checkpointed - self.compiler_fn = compiler_fn + self.compiler_fn: CompilerFn = compiler_fn self.root_globals = f_globals self.root_tx = root_tx - self.cleanups = [] + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + + self._current_tx: List[InstructionTranslatorBase] = [] + self.cleanups: List[CleanupHook] = [] self.should_exit = False self.random_values_var = None self.initial_random_state = () - self.unspec_variable_map = {} + self.unspec_variable_map: Dict[ + str, Union[UnspecializedNumpyVariable, UnspecializedPythonVariable] + ] = {} + self.shape_env = ShapeEnv() if config.dynamic_shapes else None + self.tensor_id_to_sym_shape_ref: Dict[int, Set[TensorReference]] = {} + self.intermediary_symbols: Dict[sympy.Expr, None] = {} + + # Enables creating unique node names by tracking + # all current placeholder node names + self.name_to_input: OrderedDict[ + str, Optional[fx.Proxy] + ] = collections.OrderedDict() @property def output(self): @@ -113,34 +223,48 @@ def output(self): def fake_mode(self): return self.root_tx.fake_mode - def copy_graphstate(self): + def push_tx(self, tx): + self._current_tx.append(tx) + + def pop_tx(self): + return self._current_tx.pop() + + @property + def current_tx(self): + return self.root_tx if not self._current_tx else self._current_tx[-1] + + def copy_graphstate(self) -> OutputGraphState: """Create a checkpoint of the current state by copying everything""" - graph_nodes = set(self.graph.nodes) - return ( - graph_nodes, + assert self.nn_modules is not None + state = OutputGraphState( list(self.graphargs), set(self.guards), dict(self.nn_modules), self.side_effects.clone(), + self.timestamp, + self.name_to_input.copy(), ) + self.timestamp += 1 + return state - def restore_graphstate(self, state): + def restore_graphstate(self, state: OutputGraphState): """Restore a checkpoint created by self.copy_graphstate()""" ( - graph_nodes, self.graphargs, self.guards, self.nn_modules, self.side_effects, + self.timestamp, + self.name_to_input, ) = state # FX deepcopy doesn't work for a partially created graph, so just remove new nodes for node in reversed(list(self.graph.nodes)): - if node not in graph_nodes: + if node.meta["creation_timestamp"] > self.timestamp: # Erasing node alone does not remove the meta information # So, remove the help tensor explicitly if "example_value" in node.meta: del node.meta["example_value"] - self.graph.erase_node(node) + self.remove_node(node) self.real_value_cache.pop(node, None) def count_calls(self): @@ -157,22 +281,22 @@ def get_submodule(self, keys): return obj def create_graph_input(self, name, type_expr=None): - placeholders = [n for n in self.graph.nodes if n.op == "placeholder"] - # unique - used_names = {n.target for n in placeholders} - if name in used_names: + if name in self.name_to_input: for i in itertools.count(): - if f"{name}_{i}" not in used_names: + if f"{name}_{i}" not in self.name_to_input: name = f"{name}_{i}" break - if placeholders: - ctx = self.graph.inserting_after(placeholders[-1]) + if self.name_to_input: + prev_name = next(reversed(self.name_to_input)) + ctx = self.graph.inserting_after(self.name_to_input[prev_name]) else: ctx = self.graph.inserting_before(None) with ctx: - return self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + self.name_to_input[name] = proxy.node + return proxy def new_var(self, name="tmp"): existing = set(self.code_options["co_varnames"]) @@ -191,43 +315,64 @@ def update_co_names(self, name): name, ) - def register_attr_or_module(self, mod: torch.nn.Module, *names, **options): - if is_dynamic_nn_module(mod): - return variables.UnspecializedNNModuleVariable(mod, **options) + def register_attr_or_module( + self, target: Union[torch.nn.Module, torch.Tensor, Any], *names, **options + ): + if is_dynamic_nn_module(target): + return variables.UnspecializedNNModuleVariable(target, **options) options = dict(options) options["guards"] = set(options.get("guards", [])) source: Source = options.get("source", None) - if isinstance(mod, torch.Tensor): + if isinstance(target, torch.Tensor): if source: options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH)) def wrap_name(module_key): - return TensorVariable.create( + return wrap_fx_proxy( self, self.create_proxy("get_attr", module_key, tuple(), {}), - example_value=mod, + example_value=target, **options, ) - elif isinstance(mod, torch.nn.Module): - assert isinstance(mod, torch.nn.Module) + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE)) def wrap_name(module_key): - return NNModuleVariable(type(mod), module_key, **options) + return NNModuleVariable(type(target), module_key, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + self.intermediary_symbols.update({target.get_pyobj().expr: None}) + def wrap_name(module_key): + return DynamicShapeVariable.create( + self, + self.create_proxy("get_attr", module_key, tuple(), {}), + dyn_shape=target, + **options, + ) + + # HACKY CODE REGION END else: def wrap_name(module_key): self.output.update_co_names(module_key) - self.root_globals[module_key] = mod + self.root_globals[module_key] = target return VariableBuilder(self, ConstantSource(source_name=module_key))( - mod + target ) + assert self.nn_modules is not None for k, v in self.nn_modules.items(): - if v is mod: + if v is target: # it already exists return wrap_name(k) @@ -243,7 +388,7 @@ def wrap_name(module_key): base = name for i in itertools.count(): if name not in self.nn_modules: - self.nn_modules[name] = mod + self.nn_modules[name] = target return wrap_name(name) name = f"{base}_{i}" @@ -269,11 +414,14 @@ def compile_subgraph( tx.prune_dead_locals() stack_values = list(tx.stack) + assert self.nn_modules is not None root = FakeRootModule(self.nn_modules) # Add all the local vars to the "stack" so restore at the end restore_vars = [] - val_to_names = collections.OrderedDict() + val_to_names: OrderedDict[ + VariableTracker, List[str] + ] = collections.OrderedDict() if stack_values: val_to_names[stack_values[-1]] = list() for k, v in tx.symbolic_locals.items(): @@ -323,6 +471,7 @@ def compile_subgraph( and len(set(stack_values)) == len(stack_values) and self.side_effects.is_empty() ): + # optimization to generate better code in a common case self.add_output_instructions( self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) @@ -394,17 +543,25 @@ def compile_and_call_fx_graph(self, tx, rv, root): gm.recompile() gm.compile_subgraph_reason = self.compile_subgraph_reason name = unique_id("__compiled_fn") + + assert_no_fake_params_or_buffers(gm) compiled_fn = self.call_user_compiler(gm) compiled_fn = disable(compiled_fn) + counters["stats"]["unique_graphs"] += 1 self.install_global(name, compiled_fn) try: # the call to tabulate can cause a lot of memory to be allocated - if config.log_level <= logging.INFO: + if config.log_level <= logging.INFO and config.output_code: + graph_str = ( + gm.print_readable() + if config.output_graph_code + else format_graph_tabular(gm.graph) + ) log.log( - torchdynamo_logging.CODE, - f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n", + logging.CODE, # type: ignore[attr-defined] + f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n", ) except ImportError: log.warning( @@ -417,7 +574,7 @@ def compile_and_call_fx_graph(self, tx, rv, root): cg.make_call_generated_code(name) return cg.get_instructions() - def call_user_compiler(self, gm): + def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: try: name = ( self.compiler_fn.__name__ @@ -425,32 +582,78 @@ def call_user_compiler(self, gm): else "" ) _step_logger()(logging.INFO, f"calling compiler function {name}") - compiled_fn = self.compiler_fn(gm, self.example_inputs()) + compiler_fn = self.compiler_fn + # WrapperBackend needs real inputs, for now, to verify correctness + if config.verify_correctness: + compiler_fn = WrapperBackend(compiler_fn, self.example_inputs()) + + # NOTE: [Real Tensors in Accuracy Evaluation] + # + # Today, tensors are passed to backends as fake at compile time. See the .fake_example_inputs() + # call to compiler_fn below. At runtime, backends use real tensors. + # + # This should be a strong invariant we hold across all backends, + # and generally, it is. However, for accuracy evaluation, we need real tensors at compile time, + # for now, due to the unfortunate setup described below. + # + # Due to the nature of how we invoke comparison as a backend in two different ways: + # + # (1) Less bad, but still worth rewriting, WrapperBackend above, which takes + # real inputs for its ctor. see the config.verify_correctnes above. + # + # (2) More bad, and very worth rewriting, the minifier installs accuracy comparison as + # a true backend, and therefore needs to be compiled with real inputs. This is made trickier + # by the fact that the minifier will spawn new processes during minification. As such, we have + # created a global flag, MINIFIER_SPAWNED, that should be set IF AND ONLY IF this run was spawned + # as part of accuracy minification. This flag is not a contract, and ideally will not be here long. + # + # The longer term PoR is to: + # (A) Rewrite the minifier accuracy evaluation and verify_correctness code to share the same + # correctness and accuracy logic, so as not to have two different ways of doing the same thing. + # + # (B) Refactor minifier accuracy backend to do its comparison fully at runtime, so as not to need to + # pass real tensors to it at compile time. + is_top_level_minifying = ( + config.repro_after is not None and config.repro_level == 4 + ) + if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying: + compiled_fn = compiler_fn(gm, self.example_inputs()) + elif config.DO_NOT_USE_legacy_non_fake_example_inputs: + compiled_fn = compiler_fn(gm, self.example_inputs()) + else: + compiled_fn = compiler_fn(gm, self.fake_example_inputs()) _step_logger()(logging.INFO, f"done compiler function {name}") assert callable(compiled_fn), "compiler_fn did not return callable" except Exception as e: - log.warning("-" * 40 + "\n") - log.warning("TORCHDYNAMO: backend compiler failed\n") - log.warning(e, exc_info=True) - log.warning("-" * 40 + "\n") compiled_fn = gm.forward - if config.raise_on_backend_error: - raise BackendCompilerFailed(self.compiler_fn, e) from e + raise BackendCompilerFailed(self.compiler_fn, e) from e return compiled_fn - def example_inputs(self): + def fake_example_inputs(self) -> List[torch.Tensor]: + result = [] + for arg in self.graphargs: + example = arg.get_fake_examples() + if example is not None: + result.extend(example) + else: + # Fallback, in case fake_tensor was not set + # Particularly for graph args that are not tensors + result.extend(arg.get_examples()) + return result + + def example_inputs(self) -> List[torch.Tensor]: result = [] for arg in self.graphargs: result.extend(arg.get_examples()) return result - def remove_unused_graphargs(self): + def remove_unused_graphargs(self) -> None: for node in reversed(list(self.graph.nodes)): if len(list(node.users)) == 0: if node.op == "get_attr": - self.graph.erase_node(node) + self.remove_node(node) elif node.op == "call_function" and node.target is operator.getitem: - self.graph.erase_node(node) + self.remove_node(node) expanded_graphargs = [] for arg in self.graphargs: @@ -465,12 +668,12 @@ def remove_unused_graphargs(self): if arg.uses == 0: if "example_value" in node.meta: del node.meta["example_value"] - self.graph.erase_node(node) + self.remove_node(node) self.real_value_cache.pop(node, None) self.graphargs = [arg for arg in self.graphargs if arg.uses > 0] - def add_output_instructions(self, prefix: List[Instruction]): + def add_output_instructions(self, prefix: List[Instruction]) -> None: """ We call this on the creation of a new compiled subgraph that is inserted before user code. @@ -478,16 +681,15 @@ def add_output_instructions(self, prefix: List[Instruction]): self.output_instructions.extend(prefix) self.should_exit = True - def install_global(self, name, value): + def install_global(self, name, value) -> None: self.cleanups.append(CleanupHook.create(self.root_globals, name, value)) - def cleanup(self): + def cleanup(self) -> None: # There is a reference cycle between tracer and OutputGraph, causing # some of the tensor objects to be held alive for longer than necessary. # Clear cache for conversion of real -> fake tensors - if fake_tensors_available: - self.root_tx.fake_mode.fake_tensor_converter = None + self.root_tx.fake_mode.fake_tensor_converter = None self.root_tx = None # Note: generated fx graph will hold a reference to the nn_module, @@ -502,6 +704,7 @@ def cleanup(self): if "example_value" in node.meta: del node.meta["example_value"] self.real_value_cache.clear() + self.name_to_input.clear() def create_proxy( self, @@ -512,14 +715,13 @@ def create_proxy( name=None, type_expr=None, proxy_factory_fn=None, - current_tx=None, ): rv = super().create_proxy( kind, target, args, kwargs, name, type_expr, proxy_factory_fn ) # append stack trace to fx node - tx = current_tx if current_tx else self.root_tx + tx = self.current_tx nn_module_stack = tx.nn_module_stack if nn_module_stack: @@ -530,10 +732,22 @@ def create_proxy( frame_summaries.append(tx.frame_summary()) tx = getattr(tx, "parent", None) - msgs = traceback.StackSummary.from_list(frame_summaries).format() + # official from_list stub doesn't have new-style type + msgs = traceback.StackSummary.from_list(frame_summaries).format() # type: ignore[arg-type] # Carry module_stack along with node.stack_trace for reusing stacktrace propagation infra nn_module_stack_str = f"Module stack: {nn_module_stack}\n" rv.node.stack_trace = nn_module_stack_str + " | ".join(msgs) return rv + + def create_node(self, *args, **kwargs): + node = super().create_node(*args, **kwargs) + node.meta["creation_timestamp"] = self.timestamp + return node + + # Note: we did not override erase_node since + # we call self.graph.erase_node elsewhere + def remove_node(self, node): + self.graph.erase_node(node) + self.name_to_input.pop(node.name, None) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 1f8675ae1c9e3..46c5cd115e052 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,7 +1,7 @@ import collections import dataclasses import inspect -from typing import Any +from typing import Any, Dict, List, Optional import torch.nn @@ -59,18 +59,48 @@ def __eq__(self, other): return self is other -class SideEffects(object): +class SideEffects: """ Track side effects (list mutation, setattr, etc) that need to be applied after an FX graph is run. """ + id_to_variable: Dict[int, VariableTracker] + store_attr_mutations: Dict[AttributeMutation, Dict[str, VariableTracker]] + keepalive: List[Any] + def __init__(self, id_to_variable=None, store_attr_mutations=None, keepalive=None): - super(SideEffects, self).__init__() + super().__init__() self.id_to_variable = id_to_variable or collections.OrderedDict() self.store_attr_mutations = store_attr_mutations or collections.OrderedDict() self.keepalive = keepalive or [] + def __eq__(self, other: object) -> bool: + assert isinstance(other, SideEffects) + # NB: do NOT test keepalive + return ( + self.id_to_variable == other.id_to_variable + and self.store_attr_mutations == other.store_attr_mutations + ) + + def diff(self, other: "SideEffects") -> Optional[str]: + if self.id_to_variable != other.id_to_variable: + sk_itv = self.id_to_variable.keys() + ok_itv = other.id_to_variable.keys() + if sk_itv != ok_itv: + return f"id_to_variable keys: {sk_itv} != {ok_itv}" + # Feel free to augment this with more fancy diffing logic + # if needed for debugging + return "id_to_variable: unknown diff" + elif self.store_attr_mutations != other.store_attr_mutations: + sk_sam = self.store_attr_mutations.keys() + ok_sam = other.store_attr_mutations.keys() + if sk_sam != ok_sam: + return f"store_attr_mutations keys: {sk_sam} != {ok_sam}" + return "store_attr_mutations: unknown diff" + else: + return None + def clone(self): """Create a shallow copy""" return self.__class__( @@ -82,16 +112,16 @@ def clone(self): keepalive=list(self.keepalive), ) - def apply(self, fn, cache=None): + def apply(self, fn, cache=None, skip_fn=lambda _: False): if cache is None: cache = dict() self.id_to_variable = collections.OrderedDict( - (k, VariableTracker.apply(fn, v, cache)) + (k, VariableTracker.apply(fn, v, cache, skip_fn)) for k, v in self.id_to_variable.items() ) self.store_attr_mutations = collections.OrderedDict( - (k, VariableTracker.apply(fn, v, cache)) + (k, VariableTracker.apply(fn, v, cache, skip_fn)) for k, v in self.store_attr_mutations.items() ) diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py index 2b6fbb3959c8d..41a04626756d2 100644 --- a/torch/_dynamo/skipfiles.py +++ b/torch/_dynamo/skipfiles.py @@ -49,6 +49,22 @@ from . import config +""" +A note on skipfiles: + +Dynamo consults this file to determine whether code should be compiled or skipped. + +A skip applies at the frame boundary, meaning dynamo either triggers a graph break +at the beginning of the frame or attempts to trace the whole frame. When skipping +a frame, recursively called frames are still traced by dynamo unless also skipped. + +Skipfiles (skipped at the file level instead of function level) still apply on a +frame-by-frame boundary as dynamo traces, but apply to all functions in that file. + +@skip is a helper decorator that can be applied to your function to cause it to be +included here. +""" + def _strip_init_py(s): return re.sub(r"__init__.py$", "", s) @@ -106,16 +122,6 @@ def _module_dir(m: types.ModuleType): torch.set_rng_state.__code__.co_filename, } -# Include optimizer code for tracing -FILENAME_ALLOWLIST |= set( - [ - inspect.getfile(obj) - for obj in torch.optim.__dict__.values() - if inspect.isclass(obj) - ] -) - -FILENAME_ALLOWLIST |= {torch.optim._functional.__file__} if HAS_PRIMS_REFS: FILENAME_ALLOWLIST |= { @@ -127,7 +133,6 @@ def _module_dir(m: types.ModuleType): torch._refs.nn.functional.__file__, } -FILENAME_ALLOWLIST |= {torch.optim._functional.__file__} SKIP_DIRS_RE = None diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 6b5d63ab850e1..626bdb4b7826c 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -75,6 +75,9 @@ def name(self): class RandomValueSource(Source): random_call_index: int + def guard_source(self): + return GuardSource.RANDOM_VALUE + def reconstruct(self, codegen): return [ codegen.create_load(codegen.tx.output.random_values_var), diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 0b5cfae69363c..2064e12497a38 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -12,7 +12,8 @@ import types import typing import weakref -from typing import Any, Dict, Iterable, List +from collections.abc import Sized +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple from unittest.mock import patch import torch @@ -36,9 +37,9 @@ unique_id, ) from .codegen import PyCodegen -from .exc import unimplemented, Unsupported +from .exc import BackendCompilerFailed, unimplemented, Unsupported from .guards import GuardBuilder -from .output_graph import GraphCompileReason, OutputGraph +from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState from .replay_record import DummyModule, ExecutionRecorder from .resume_execution import ContinueExecutionCache, ReenterWith from .source import ( @@ -48,14 +49,9 @@ GlobalWeakRefSource, LocalSource, ) -from .utils import ( - counters, - fake_tensors_available, - graph_break_dup_warning_checker, - istype, -) +from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs from .variables.base import MutableLocal, typestr, VariableTracker -from .variables.builder import VariableBuilder +from .variables.builder import VariableBuilder, wrap_fx_proxy from .variables.builtin import BuiltinVariable from .variables.constant import ConstantVariable from .variables.dicts import ConstDictVariable @@ -81,7 +77,7 @@ WithExitFunctionVariable, ) from .variables.nn_module import NNModuleVariable -from .variables.tensor import TensorVariable +from .variables.tensor import DynamicShapeVariable, TensorVariable from .variables.torch import TorchVariable from .variables.user_defined import UserDefinedVariable @@ -96,7 +92,7 @@ def _step_logger(): @dataclasses.dataclass class BlockStackEntry: target: Instruction - stack_index: int = None + stack_index: Optional[int] = None with_context: ContextWrappingVariable = None def can_restore(self): @@ -110,7 +106,28 @@ def exit(self, tx): return self.with_context.exit(tx) -def stack_op(fn: typing.Callable): +class InstructionTranslatorGraphState(NamedTuple): + output: OutputGraphState + symbolic_locals: Dict[str, VariableTracker] + stack: List[VariableTracker] + block_stack: List[BlockStackEntry] + instruction_pointer: Optional[int] + current_instruction: Instruction + next_instruction: Optional[Instruction] + lineno: int + + def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]: + for k in self._fields: + if k == "output": + return self.output.diff(other.output, prefix=f"{k}.") + sv = getattr(self, k) + ov = getattr(other, k) + if sv != ov: + return f"{k} mismatch: {sv} != {ov}" + return None + + +def stack_op(fn: typing.Callable[..., object]): nargs = len(inspect.signature(fn).parameters) fn_var = BuiltinVariable(fn) @@ -121,16 +138,120 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction): return impl -def generic_jump(truth_fn: typing.Callable, push: bool): +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", + truth_fn: typing.Callable[[object], bool], + push: bool, +): + # Detect if this jump instruction is assert and normalize the assert + # by pushing dummy error message when nothing is given. + # + # Python 3.9 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # Python 3.8 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + assert isinstance(self.instruction_pointer, int) + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if sys.version_info < (3, 9): + if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": + return False + else: + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + if current_instruction_pointer >= len(self.instructions): + return False + + inst = self.instructions[current_instruction_pointer] + has_error_msg = False + # DETECT RAISE_VARARGS or LOAD CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + self.LOAD_CONST(inst) + has_error_msg = True + + # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + if inst.opname != "CALL_FUNCTION": + return False + + # CALL_FUNCTION should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + if not has_error_msg: + # Push dummy value instead of error message + self.push(ConstantVariable("assertion error")) + + return True + + +def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool): def inner(self: "InstructionTranslatorBase", inst: Instruction): value: VariableTracker = self.pop() self.output.guards.update(value.guards) + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + self.output.guards.update(error_msg.guards) + # Skip over things like `assert True` + if value.is_python_constant() and bool(value.as_python_constant()): + self.jump(inst) + return + + # Manually insert torch._assert instead of python assert and jump over + # assert related instructions as we don't need them anymore. + self.output.create_proxy( + "call_function", + torch._assert, + *proxy_args_kwargs((value, error_msg), {}), + ) + self.jump(inst) + return + if value.is_python_constant(): if truth_fn(value.as_python_constant()): push and self.push(value) self.jump(inst) - elif isinstance(value, TensorVariable) and self.should_compile_partial_graph(): + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): # compile a partial subgraph prefix then jump into user code + if self.has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop" + ) + log.debug(msg) + raise exc.SkipFrame(msg) + self.push(value) self.output.compile_subgraph( self, @@ -149,12 +270,22 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): + if_next + if_jump ) + elif isinstance(value, NNModuleVariable): + # Equivant of "self.nn_module is not None" + if truth_fn(value): + push and self.push(value) + self.jump(inst) elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence( self ): if truth_fn(len(value.unpack_var_sequence(self))): push and self.push(value) self.jump(inst) + elif isinstance(value, DynamicShapeVariable): + eval_result = value.evaluate_expr(self.output) + if truth_fn(eval_result): + push and self.push(value) + self.jump(inst) else: unimplemented(f"generic_jump {typestr(value)}") @@ -172,13 +303,18 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): reason = None try: return inner_fn(self, inst) - except Unsupported as exc: + except Unsupported as excp: + if self.has_backedge(): + msg = "Skipping frame because there is a graph break in a for/while loop" + log.debug(msg) + raise exc.SkipFrame(msg) from excp + if not self.should_compile_partial_graph(): raise - user_stack = [self.frame_summary()] + list(reversed(exc.real_stack)) + user_stack = [self.frame_summary()] + list(reversed(excp.real_stack)) user_stack_formatted = "".join(traceback.format_list(user_stack)) frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) - # torchdynamo.explain() formats this a little nicer, and presents a slightly + # torch._dynamo.explain() formats this a little nicer, and presents a slightly # more actionable user code pointer if ( config.print_graph_breaks @@ -186,12 +322,12 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): and graph_break_dup_warning_checker.add(frame_loc) ): log.warning( - f"Graph break: {exc} from user code at {user_stack_formatted}" + f"Graph break: {excp} from user code at {user_stack_formatted}" ) - exc.remove_from_stats() - exc.add_to_stats("graph_break") - reason = GraphCompileReason(exc.msg, user_stack) + excp.remove_from_stats() + excp.add_to_stats("graph_break") + reason = GraphCompileReason(excp.msg, user_stack) self.restore_graphstate(state) self.output.compile_subgraph(self, reason=reason) self.popn(push - dis.stack_effect(inst.opcode, inst.arg)) @@ -230,6 +366,36 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): class InstructionTranslatorBase(object): + output: OutputGraph + symbolic_locals: Dict[str, VariableTracker] + symbolic_globals: Dict[str, VariableTracker] + stack: List[VariableTracker] + instruction_pointer: Optional[int] + current_instruction: Instruction + next_instruction: Optional[Instruction] + block_stack: List[BlockStackEntry] + lineno: int + mutated_closure_cell_contents: Set[str] + + checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]] + random_calls: List[ + Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] + ] + + def has_backedge(self): + cur_offset = self.current_instruction.offset + assert self.instruction_pointer is not None + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ( + "JUMP_ABSOLUTE", + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + ): + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + def cell_and_freevars(self): if not hasattr(self, "_cell_and_freevars"): self._cell_and_freevars = tuple( @@ -269,11 +435,18 @@ def repl(v: VariableTracker): return newvar return v - cache = dict() - self.output.side_effects.apply(repl, cache) - self.stack = [VariableTracker.apply(repl, x, cache) for x in self.stack] + def skip(v: VariableTracker): + return oldvar.mutable_local not in v.recursively_contains + + cache: Dict[int, Tuple[object, object]] = dict() + self.output.side_effects.apply(repl, cache, skip_fn=skip) + self.stack = [ + VariableTracker.apply(repl, x, cache, skip_fn=skip) for x in self.stack + ] for k, x in self.symbolic_locals.items(): - self.symbolic_locals[k] = VariableTracker.apply(repl, x, cache) + self.symbolic_locals[k] = VariableTracker.apply( + repl, x, cache, skip_fn=skip + ) def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker): if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects): @@ -299,6 +472,7 @@ def inline_user_function_return(self, fn, args, kwargs): def step(self): """Process exactly one instruction, return False we should exit""" + assert isinstance(self.instruction_pointer, int) inst = self.instructions[self.instruction_pointer] self.current_instruction = inst self.instruction_pointer += 1 @@ -320,7 +494,10 @@ def step(self): if not hasattr(self, inst.opname): unimplemented(f"missing: {inst.opname}") getattr(self, inst.opname)(inst) + return inst.opname != "RETURN_VALUE" + except BackendCompilerFailed: + raise except Unsupported as exc: exc.real_stack.append(self.frame_summary()) if self.empty_checkpoint(): @@ -328,11 +505,12 @@ def step(self): except Exception as exc: real_stack = getattr(exc, "real_stack", []) real_stack.append(self.frame_summary()) - exc.real_stack = real_stack + exc.real_stack = real_stack # type: ignore[attr-defined] raise # generate code from checkpoint assert not self.output.output_instructions + assert self.checkpoint is not None continue_inst, state = self.checkpoint self.restore_graphstate(state) self.output.compile_subgraph(self, partial_convert=True) @@ -343,18 +521,21 @@ def step(self): def run(self): try: + self.output.push_tx(self) while ( self.instruction_pointer is not None and not self.output.should_exit and self.step() ): pass + except BackendCompilerFailed: + raise except Exception as e: if config.replay_record_enabled: - e.exec_record = self.exec_recorder.get_record() - + e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] raise finally: + self.output.pop_tx() # Cleanup the outputGraph to delete the held tensors. We perform the # cleanup only for InstructionTranslator and not # InliningInstructionTranslator. The InliningInstructionTranslator @@ -363,20 +544,20 @@ def run(self): if isinstance(self, InstructionTranslator): self.output.cleanup() - def push(self, val): + def push(self, val: Optional[VariableTracker]): assert val is None or isinstance( val, VariableTracker ), f"push expects VariableTracker, got {typestr(val)}" self.stack.append(val) - def push_many(self, vals: List[TensorVariable]): + def push_many(self, vals: List[VariableTracker]): for val in vals: self.push(val) - def pop(self) -> TensorVariable: + def pop(self) -> VariableTracker: return self.stack.pop() - def popn(self, n: int) -> List[TensorVariable]: + def popn(self, n: int) -> List[VariableTracker]: assert n >= 0 return list(reversed([self.pop() for _ in range(n)])) @@ -505,7 +686,7 @@ def calc_package(self): f"({package!r} != {spec.parent!r})", ImportWarning, stacklevel=3, - ) + ) # type: ignore[call-arg] return package elif spec is not None: return spec.parent @@ -515,7 +696,7 @@ def calc_package(self): "falling back on __name__ and __path__", ImportWarning, stacklevel=3, - ) + ) # type: ignore[call-arg] package = self.f_globals["__name__"] if "__path__" not in self.f_globals: package = package.rpartition(".")[0] @@ -696,6 +877,7 @@ def COMPARE_OP(self, inst): left, ( TensorVariable, + DynamicShapeVariable, NNModuleVariable, BaseListVariable, UserDefinedVariable, @@ -713,16 +895,6 @@ def COMPARE_OP(self, inst): supported_is_const[op](object(), right.value), **options ) ) - elif ( - isinstance(left, TensorVariable) or isinstance(right, TensorVariable) - ) and op in supported_tensors: - self.push( - TensorVariable.create( - self, - supported_tensors[op](left.as_proxy(), right.as_proxy()), - **options, - ) - ) elif ( left.is_python_constant() and right.is_python_constant() @@ -737,10 +909,40 @@ def COMPARE_OP(self, inst): **options, ) ) + elif ( + isinstance(left, TensorVariable) or isinstance(right, TensorVariable) + ) and op in supported_tensors: + self.push( + wrap_fx_proxy( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + **options, + ) + ) + elif ( + isinstance(left, DynamicShapeVariable) + or isinstance(right, DynamicShapeVariable) + ) and op in supported_tensors: + self.push( + DynamicShapeVariable.create( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + dyn_shape=None, + **options, + ) + ) elif op in ("in", "not in"): self.push(right.call_method(self, "__contains__", [left], {})) if op == "not in": self.UNARY_NOT(inst) + elif ( + isinstance(left, UserFunctionVariable) + and isinstance(right, UserFunctionVariable) + and op in supported_is_const + ): + self.push( + ConstantVariable(supported_is_const[op](left.fn, right.fn), **options) + ) else: unimplemented(f"COMPARE_OP {typestr(left)} {op} {typestr(right)}") @@ -797,8 +999,8 @@ def CALL_FUNCTION_KW(self, inst): fn = self.pop() assert isinstance(argnames, ConstantVariable) argnames = argnames.value - args, kwargs = args[: -len(argnames)], args[-len(argnames) :] - kwargs = dict(zip(argnames, kwargs)) + args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :] + kwargs = dict(zip(argnames, kwargs_list)) assert len(kwargs) == len(argnames) self.call_function(fn, args, kwargs) @@ -824,6 +1026,14 @@ def LOAD_ATTR(self, inst): def STORE_ATTR(self, inst): prior = self.copy_graphstate() val, obj = self.popn(2) + + if isinstance(obj, NNModuleVariable): + # We don't allow side effects during export + # https://github.com/pytorch/torchdynamo/issues/1475 + assert ( + not self.export + ), f"Mutating module attribute {inst.argval} during export." + try: self.output.guards.update( BuiltinVariable(setattr) @@ -848,6 +1058,16 @@ def STORE_ATTR(self, inst): self.create_call_resume_at(self.next_instruction) ) + def create_call_resume_at(self, offset): + raise AssertionError( + f"create_call_resume_at not overridden by subclass {type(self)}" + ) + + def should_compile_partial_graph(self) -> bool: + raise AssertionError( + f"should_compile_partial_graph not overridden by subclass {type(self)}" + ) + @break_graph_if_unsupported(push=0) def STORE_SUBSCR(self, inst): val, obj, key = self.popn(3) @@ -950,10 +1170,20 @@ def LIST_APPEND(self, inst): obj = self.stack[-inst.arg] assert isinstance(obj, ListVariable) assert obj.mutable_local + # only copy if the new obj contains other mutables + new_rec_contains = obj.recursively_contains + if v.recursively_contains or v.mutable_local: + new_rec_contains = obj.recursively_contains.union(v.recursively_contains) + + if v.mutable_local: + new_rec_contains.add(v.mutable_local) + self.replace_all( obj, ListVariable( obj.items + [v], + recursively_contains=new_rec_contains, + regen_guards=False, **VariableTracker.propagate([obj, v]), ), ) @@ -993,30 +1223,24 @@ def MAKE_FUNCTION(self, inst): ) def UNPACK_SEQUENCE(self, inst): - # TODO(jansel): rewrite this using unpack_var_sequence seq = self.pop() - options = VariableTracker.propagate([seq]) if isinstance(seq, BaseListVariable): - assert len(seq.items) == inst.argval self.output.guards.update(seq.guards) - for i in reversed(seq.items): - self.push(i) + val = seq.unpack_var_sequence(self) elif seq.is_python_constant() and isinstance(seq, ConstantVariable): - val = seq.as_python_constant() - assert len(val) == inst.argval - for i in reversed(val): - self.push(ConstantVariable(i, **options)) + val = seq.unpack_var_sequence(self) elif isinstance(seq, TensorVariable): - proxy = seq.as_proxy() - for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) - for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + options = VariableTracker.propagate(self) + val = [wrap_fx_proxy(self, proxy[i], **options) for i in range(inst.argval)] else: unimplemented(f"UNPACK_SEQUENCE {seq}") + assert len(val) == inst.argval + for i in reversed(val): + self.push(i) def UNPACK_EX(self, inst): assert 0 <= inst.argval <= 0xFFFF @@ -1089,7 +1313,8 @@ def FORMAT_VALUE(self, inst): fmt_spec = ConstantVariable("") value = self.pop() - + if isinstance(value, DynamicShapeVariable): + value = ConstantVariable(str(value.dyn_shape)) if (flags & 0x03) == 0x01: value = BuiltinVariable(str).call_function(self, [value], {}) elif (flags & 0x03) == 0x02: @@ -1224,9 +1449,9 @@ def MATCH_KEYS(self, inst): INPLACE_XOR = stack_op(operator.ixor) INPLACE_OR = stack_op(operator.ior) - def copy_graphstate(self): + def copy_graphstate(self) -> InstructionTranslatorGraphState: """Create a checkpoint of the current state by copying everything""" - return ( + return InstructionTranslatorGraphState( self.output.copy_graphstate(), collections.OrderedDict(self.symbolic_locals), list(self.stack), @@ -1237,7 +1462,7 @@ def copy_graphstate(self): self.lineno, ) - def restore_graphstate(self, state): + def restore_graphstate(self, state: InstructionTranslatorGraphState): """Restore a checkpoint created by self.copy_graphstate()""" ( output_state, @@ -1258,7 +1483,7 @@ def empty_checkpoint(self): graphstate = self.checkpoint[1][1:] state = (*output_graphstate, *graphstate) for obj in state: - if isinstance(obj, Iterable): + if isinstance(obj, Sized): if len(obj) != 0: return False return True @@ -1308,19 +1533,20 @@ def __init__( symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], f_code: types.CodeType, + export: bool, ): - super(InstructionTranslatorBase, self).__init__() + super().__init__() # Mutable state checkpointed by copy_graphstate() - self.output: OutputGraph = output - self.symbolic_locals: Dict[str, VariableTracker] = symbolic_locals - self.symbolic_globals: Dict[str, VariableTracker] = symbolic_globals - self.stack: List[VariableTracker] = [] - self.instruction_pointer: int = 0 - self.current_instruction: Instruction = create_instruction("NOP") - self.next_instruction: typing.Optional[Instruction] = None - self.block_stack: List[BlockStackEntry] = [] - self.lineno: int = code_options.get("co_firstlineno") + self.output = output + self.symbolic_locals = symbolic_locals + self.symbolic_globals = symbolic_globals + self.stack = [] + self.instruction_pointer = 0 + self.current_instruction = create_instruction("NOP") + self.next_instruction = None + self.block_stack = [] + self.lineno = code_options["co_firstlineno"] # Properties of the input/output code self.instructions: List[Instruction] = instructions @@ -1337,16 +1563,16 @@ def __init__( self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options) # Stack of module being parsed, current nn.module is at the end of ordered dict self.nn_module_stack: Dict[str, str] = {} + # Flag to indicate whether tracing is used for export. + self.export = export - if fake_tensors_available: - with torch._subclasses.FakeTensorMode( - throw_on_data_dependent_ops=True - ) as fake_mode: - pass - self._fake_mode = fake_mode + self._fake_mode = torch._subclasses.FakeTensorMode( + throw_on_data_dependent_ops=True, + shape_env=output.shape_env, + ) self.checkpoint = None - self.random_calls: List[tuple] = [] + self.random_calls = [] if sys.version_info >= (3, 10): from .resume_execution import ( @@ -1374,6 +1600,7 @@ def __init__( compiler_fn, one_graph, export, + mutated_closure_cell_contents: Set[str], ): super(InstructionTranslator, self).__init__( output=OutputGraph(f_globals, code_options, compiler_fn, self), @@ -1386,9 +1613,11 @@ def __init__( # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals=collections.OrderedDict(), f_code=f_code, + export=export, ) self.one_graph: bool = one_graph self.export = export + self.mutated_closure_cell_contents = mutated_closure_cell_contents if self.export: assert ( self.one_graph @@ -1513,6 +1742,8 @@ def RETURN_VALUE(self, inst): class InliningInstructionTranslator(InstructionTranslatorBase): """Trace and inline a called method""" + symbolic_result: Optional[TensorVariable] + @classmethod def inline_call(cls, parent, func, args, kwargs): with patch.dict(counters, {"unimplemented": counters["inline_call"]}): @@ -1561,6 +1792,7 @@ def inline_call_(parent, func, args, kwargs): log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n") + tracer: InliningInstructionTranslator if is_generator(code): tracer = InliningGeneratorInstructionTranslator( parent, code, sub_locals, parent.symbolic_globals, closure_cells, func @@ -1581,6 +1813,7 @@ def inline_call_(parent, func, args, kwargs): log.debug(f"DONE INLINING {code}") if is_generator(code): + assert isinstance(tracer, InliningGeneratorInstructionTranslator) assert tracer.symbolic_result.as_python_constant() is None return ListIteratorVariable( tracer.generated_items, @@ -1613,6 +1846,7 @@ def __init__( instructions=cleaned_instructions(code), code_options={k: getattr(code, k) for k in dir(code)}, f_code=code, + export=parent.export, ) self.parent = parent self.symbolic_result = None @@ -1632,14 +1866,30 @@ def STORE_DEREF(self, inst): else: self.output.side_effects.store_cell(cell, val) else: + maybe_cell = self.symbolic_locals.get(inst.argval) if isinstance( - self.symbolic_locals.get(inst.argval), + maybe_cell, variables.NewCellVariable, ): self.output.side_effects.store_cell( self.symbolic_locals[inst.argval], self.pop() ) else: + if ( + maybe_cell is not None + and maybe_cell.source.name() + not in self.parent.mutated_closure_cell_contents + ): + # Why is the source name here unique? + # mutated_closure_cell_contents is a per-frame + # concept, and sources identify, e.g., particular + # locals from the frame. If you had two locals, + # they'll get different source names, and therefore + # differ here. + self.parent.mutated_closure_cell_contents.add( + maybe_cell.source.name() + ) + raise exc.RestartAnalysis() unimplemented("write to __closure__ while inlining") def LOAD_DEREF(self, inst): @@ -1663,9 +1913,9 @@ def LOAD_CLOSURE(self, inst): def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker): newvar = super().replace_all(oldvar, newvar) # recursively check and update parent's locals and stack in case oldvar is from parent - translator = self + translator: InstructionTranslatorBase = self while hasattr(translator, "parent"): - translator = translator.parent + translator = translator.parent # type: ignore[attr-defined] translator.update_locals_and_stack(oldvar, newvar) return newvar @@ -1681,6 +1931,8 @@ def RETURN_VALUE(self, inst): class InliningGeneratorInstructionTranslator(InliningInstructionTranslator): + generated_items: List[VariableTracker] + def __init__(self, *args, **kwargs): super(InliningGeneratorInstructionTranslator, self).__init__(*args, **kwargs) self.generated_items = [] diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index 089e5053d0625..39eda31646d2a 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -51,9 +51,6 @@ def tearDownClass(cls): def setUpClass(cls): super().setUpClass() cls._exit_stack = contextlib.ExitStack() - cls._exit_stack.enter_context( - patch.object(config, "raise_on_backend_error", True) - ) cls._exit_stack.enter_context( patch.object(config, "raise_on_ctx_manager_usage", True) ) diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py new file mode 100644 index 0000000000000..947a45f2fcdf1 --- /dev/null +++ b/torch/_dynamo/test_minifier_common.py @@ -0,0 +1,129 @@ +import os +import re +import subprocess +import tempfile +import unittest + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT + + +class MinifierTestBase(torch._dynamo.test_case.TestCase): + _debug_dir_obj = tempfile.TemporaryDirectory() + DEBUG_DIR = _debug_dir_obj.name + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + unittest.mock.patch.object( + torch._dynamo.config, + "debug_dir_root", + cls.DEBUG_DIR, + ) + ) + os.makedirs(cls.DEBUG_DIR, exist_ok=True) + + @classmethod + def tearDownClass(cls): + cls._debug_dir_obj.cleanup() + cls._exit_stack.close() + + def setUp(self): + super().setUp() + + def tearDown(self): + super().tearDown() + + # Search for the name of the first function defined in a code string. + def _get_fn_name(self, code): + fn_name_match = re.search(r"def (\w+)\(", code) + if fn_name_match is not None: + return fn_name_match.group(1) + return None + + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. + def _run_test_code(self, code): + proc = subprocess.run( + ["python3", "-c", code], capture_output=True, cwd=self.DEBUG_DIR + ) + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + # Print repro directory for debugging generated code. + # Make sure to comment out `shutil.rmtree...` above as well. + print("repro dir:", repro_dir_match.group(1)) + return proc, repro_dir_match.group(1) + return proc, None + + # Patch generated files with testing patches + def _inject_code(self, patch_code, filename): + patch_code = f"""\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +""" + with open(filename, "r") as f: + code = f.read() + code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + with open(filename, "w") as f: + f.write(code) + return code + + # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. + def _run_minifier_launcher(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + launch_file = os.path.join(repro_dir, "minifier_launcher.py") + self.assertTrue(os.path.exists(launch_file)) + launch_code = self._inject_code(patch_code, launch_file) + + launch_proc = subprocess.run( + ["python3", launch_file], + capture_output=True, + cwd=repro_dir, + ) + + return launch_proc, launch_code + + # Runs the repro script in `repro_dir`, patched with `patch_code` + def _run_repro(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + repro_file = os.path.join(repro_dir, "repro.py") + self.assertTrue(os.path.exists(repro_file)) + repro_code = self._inject_code(patch_code, repro_file) + + repro_proc = subprocess.run( + ["python3", repro_file], capture_output=True, cwd=repro_dir + ) + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file. + def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): + return f"""\ +import torch +import torch._dynamo +{patch_code} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code (in a separate process since it could segfault) + # 2. Run the generated minifier launcher script + # 3. Run the generated repro script + def _run_full_test(self, run_code, repro_after, repro_level, patch_code): + test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) + test_proc, repro_dir = self._run_test_code(test_code) + self.assertIsNotNone(repro_dir) + launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) + repro_proc, repro_code = self._run_repro(patch_code, repro_dir) + return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d6082ce48acf8..53ea6251bd4cf 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,18 +32,32 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def remove_optimized_module_prefix(name): + prefix = "_orig_mod." + assert name.startswith(prefix) + name = name[len(prefix) :] + return torch.distributed.fsdp._common_utils.clean_tensor_name(name) + + def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) results.append(loss) - if isinstance(loss, torch.Tensor) and loss.item() > 1: - log.warning( - f"High loss value alert - {loss:.2f}. Can result in unstable gradients." - ) + # if isinstance(loss, torch.Tensor) and loss.item() > 1: + # log.warning( + # f"High loss value alert - {loss:.2f}. Can result in unstable gradients." + # ) grads = dict() params = dict() for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same @@ -110,7 +124,7 @@ def debug_dump(name, code: types.CodeType, extra=""): ) -def debug_insert_nops(frame, cache_size): +def debug_insert_nops(frame, cache_size, hooks): """used to debug jump updates""" def insert_nops(instructions, code_options): @@ -222,7 +236,7 @@ def rand_strided(size, stride, dtype=torch.float32, device="cpu"): if dtype.is_floating_point: buffer = torch.randn(needed_size, dtype=dtype, device=device) else: - buffer = torch.ones(size=[needed_size], dtype=dtype, device=device) + buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device) return torch.as_strided(buffer, size, stride) diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py new file mode 100644 index 0000000000000..67a81a765bca8 --- /dev/null +++ b/torch/_dynamo/types.py @@ -0,0 +1,51 @@ +import dataclasses +import types +from typing import Callable, Dict, List, NamedTuple, Optional, OrderedDict, Union + +from typing_extensions import Protocol + + +class GuardFail(NamedTuple): + # A string repr of the piece of failed guard code we eval-ed + reason: str + # A code object where we failed a guard + orig_code: types.CodeType + + +class GuardFn(Protocol): + closure_vars: OrderedDict[str, object] + code_parts: List[str] + verbose_code_parts: List[str] + global_scope: Dict[str, object] + guard_fail_fn: Optional[Callable[[GuardFail], None]] + + # maps locals of user function to bool + def __call__(self, *maybe_dotzero: object, **f_locals: object) -> bool: + ... + + +@dataclasses.dataclass +class GuardedCode: + code: types.CodeType + check_fn: GuardFn + + +class DynamoCallbackFn(Protocol): + def __call__( + self, frame: types.FrameType, cache_size: int + ) -> Optional[GuardedCode]: + ... + + +DynamoCallback = Union[DynamoCallbackFn, None, bool] + + +class DynamoGuardHook(Protocol): + def __call__( + self, + guard_fn: GuardFn, + code: types.CodeType, + f_locals: Dict[str, object], + last: bool, + ) -> None: + ... diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index b66c240e0f04d..3d0f1bf34e363 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3,6 +3,7 @@ import copy import cProfile import dataclasses +import datetime import dis import functools import gc @@ -18,16 +19,20 @@ import sys import time import types +import typing import weakref from contextlib import contextmanager from functools import lru_cache -from typing import Any, Dict +from typing import Any, Dict, List import numpy as np +import sympy import torch from torch import fx +from torch._dispatch.python import enable_python_dispatcher from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._pytree import tree_flatten, tree_map from . import config, logging as torchdynamo_logging @@ -83,7 +88,9 @@ def time_wrapper(*args, **kwargs): compilation_metrics[key] = [] t0 = time.time() r = func(*args, **kwargs) - compilation_metrics[key].append(time.time() - t0) + latency = time.time() - t0 + # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") + compilation_metrics[key].append(latency) return r return time_wrapper @@ -197,7 +204,7 @@ def format_bytecode(prefix, name, filename, line_no, code): def gen_record_file_name(exc, code): - return f"{config.replay_record_dir_name}/\ + return f"{get_debug_dir()}/error_recordings/\ {code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" @@ -271,6 +278,13 @@ def istype(obj, allowed_types): return type(obj) is allowed_types +def is_typing(value): + if sys.version_info < (3, 9): + return isinstance(value, typing._GenericAlias) + else: + return isinstance(value, typing._SpecialGenericAlias) + + def is_numpy_int_type(value): return istype( value, @@ -305,8 +319,7 @@ def istensor(obj): torch.nn.Parameter, *config.traceable_tensor_subclasses, ) - if fake_tensors_available: - tensor_list = tensor_list + (torch._subclasses.FakeTensor,) + tensor_list = tensor_list + (torch._subclasses.FakeTensor,) return istype(obj, tensor_list) @@ -335,13 +348,13 @@ def proxy_args_kwargs(args, kwargs): proxy_args = tuple(arg.as_proxy() for arg in args) proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} return proxy_args, proxy_kwargs - except NotImplementedError: + except NotImplementedError as e: from .exc import unimplemented from .variables.base import typestr raise unimplemented( f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}" - ) + ) from e @dataclasses.dataclass @@ -385,7 +398,24 @@ def clone_tensor(x): def clone_input(x): """copy while preserving strides""" + # TODO: this is questionable + if isinstance(x, torch._subclasses.FakeTensor): + # this func fails on fake tensors in __torch_dispatch__ + return x + + def torch_clone(x): + y = torch.clone(x) + if x.is_leaf: + y.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = clone_input(x.grad) + return y + with torch.no_grad(): + if x.device.type == "xla": + # Access data_ptr() for a xla tensor will cause crash + return torch_clone(x) + needed_size = sum( (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) ) @@ -407,12 +437,7 @@ def clone_input(x): # RuntimeError: unsupported operation: more than one element of the written-to # tensor refers to a single memory location. Please clone() the tensor before # performing the operation. - y = torch.clone(x) - if x.is_leaf: - y.requires_grad_(x.requires_grad) - if x.is_leaf and x.grad is not None: - y.grad = clone_input(x.grad) - return y + return torch_clone(x) return result @@ -581,7 +606,19 @@ def is_safe_constant(v): if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return istype( - v, (types.CodeType, int, float, bool, str, bytes, type(None), slice, type(type)) + v, + ( + types.CodeType, + int, + float, + bool, + str, + bytes, + type(None), + slice, + type(type), + torch.device, + ), ) @@ -656,37 +693,95 @@ def rename_implicit(v): return v -# FakeTensors were introduced after pytorch 1.12, so gate their use -# to allow pytorch 1.12 to work -fake_tensors_available = True -try: - from torch._subclasses import ( # noqa: F401 - FakeTensorMode, - UnsupportedFakeTensorException, +from torch._subclasses import ( # noqa: F401 + FakeTensorMode, + UnsupportedFakeTensorException, +) + + +def make_fake_tensor(e, fake_mode, static_shapes=False, tx=None, ignore_subclass=False): + fake_tensor = fake_mode.from_tensor( + e, static_shapes=static_shapes, ignore_subclass=ignore_subclass ) + if tx is not None: + from torch._dynamo.guards import TensorReference - def wrap_fake_exception(fn): - try: - return fn() - except UnsupportedFakeTensorException as e: - from .exc import unimplemented + def _record(tensor_ref): + if tensor_ref.ref_id not in tx.output.tensor_id_to_sym_shape_ref: + tx.output.tensor_id_to_sym_shape_ref[tensor_ref.ref_id] = set() + tx.output.tensor_id_to_sym_shape_ref[tensor_ref.ref_id].add(tensor_ref) - msg = f"Unsupported: {e.reason} with fake tensor propagation. Run with config.fake_tensor_propagation=False" - log.warning(msg) - raise unimplemented(msg) + def _extract(symbol): + if isinstance(symbol, int): + return None + sym_expr = symbol.get_pyobj().expr + if not isinstance(sym_expr, sympy.Symbol): + return None + return sym_expr - def wrap_to_fake_tensor(e, fake_mode): - if type(e) in (torch.Tensor, torch.nn.Parameter): - return wrap_fake_exception(lambda: fake_mode.from_tensor(e)) - else: - return e + def _record_ref(e, index, symbol, kind): + sym_expr = _extract(symbol) + if sym_expr: + tensor_ref = TensorReference(id(e), kind, index, sym_expr) + _record(tensor_ref) + + for index, symbol in enumerate(fake_tensor.size()): + _record_ref(e, index, symbol, "size") + + for index, symbol in enumerate(fake_tensor.stride()): + _record_ref(e, index, symbol, "stride") + + offset = fake_tensor.storage_offset() + _record_ref(e, None, offset, "storage_offset") + + return fake_tensor - def deepcopy_to_fake_tensor(obj, fake_mode): - with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): - return wrap_fake_exception(lambda: copy.deepcopy(obj)) -except ImportError: - fake_tensors_available = False +def wrap_fake_exception(fn): + try: + return fn() + except UnsupportedFakeTensorException as e: + from .exc import unimplemented + + msg = f"Unsupported: {e.reason} with fake tensor propagation." + log.warning(msg) + raise unimplemented(msg) from e + + +def wrap_to_fake_tensor(e, fake_mode): + if type(e) in (torch.Tensor, torch.nn.Parameter): + return wrap_fake_exception( + lambda: make_fake_tensor( + e, fake_mode, static_shapes=config.dynamic_shapes is False + ) + ) + else: + return e + + +def wrap_to_fake_tensor_and_record(e, tx, ignore_subclass=False): + # The not fake tensor check here is annoying - ideally, fake tensors never call this during wrapping. + # However, get_fake_value takes args and passes them through this, which may include fake tensors. + # see tree_map(fake_wrapper, args) in get_fake_value. + # TODO: Check if we should remove FakeTensor isinstance check when + # ignore_subclass + if isinstance(e, torch.Tensor) and not isinstance(e, torch._subclasses.FakeTensor): + static_shapes = config.dynamic_shapes is False + if type(e) is torch.nn.Parameter: + # Always static for params + static_shapes = True + return wrap_fake_exception( + lambda: make_fake_tensor( + e, tx.fake_mode, static_shapes, tx, ignore_subclass=ignore_subclass + ) + ) + else: + return e + + +def deepcopy_to_fake_tensor(obj, fake_mode): + with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): + return wrap_fake_exception(lambda: copy.deepcopy(obj)) def rmse(ref, res): @@ -735,13 +830,18 @@ def same( return False return True elif isinstance(ref, torch.Tensor): + assert not isinstance(ref, torch._subclasses.FakeTensor) + assert not isinstance(res, torch._subclasses.FakeTensor) + if ref.is_sparse: assert res.is_sparse ref = ref.to_dense() res = res.to_dense() assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" if exact_dtype: - assert ref.dtype == res.dtype, f"dtype mismatch {ref.dtype}, {res.dtype}" + if ref.dtype != res.dtype: + log.error(f"dtype mismatch {ref.dtype}, {res.dtype}") + return False if ref.dtype == torch.bool: # triton stores bool as int8, so add this for more accurate checking return torch.allclose( @@ -758,10 +858,11 @@ def same( # early exit that handles zero/nan better # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 return True - res = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) - if res < 0.99: - log.warning(f"Similarity score={res.cpu().detach().item()}") - return res >= 0.99 + score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) + if score < 0.99: + breakpoint() + log.warning(f"Similarity score={score.cpu().detach().item()}") + return score >= 0.99 else: if not exact_dtype: ref = ref.to(res.dtype) @@ -776,8 +877,11 @@ def same( res_error = rmse(fp64_ref, res).item() multiplier = 2.0 - if fp64_ref.numel() < 1000 or ( - ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1 + if ( + fp64_ref.numel() < 1000 + or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) + # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE + or tol >= 2 * 1e-2 ): # In the presence of noise, noise might dominate our error # metric for smaller tensors. @@ -928,3 +1032,173 @@ def recompile_reasons(code): rpt += "No cache-limited recompilations detected.\n" return rpt + + +# return same dir unless user changes config between calls +@functools.lru_cache(None) +def _get_debug_dir(root_dir): + dir_name = "run_" + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + return os.path.join(root_dir, dir_name) + + +def get_debug_dir(): + debug_root = config.debug_dir_root + return _get_debug_dir(debug_root) + + +def get_fake_value(node, tx): + """ + Run the computation represented by `node` using fake tensors and return the result. + """ + from .exc import TorchRuntimeError, unimplemented, Unsupported + + op = node.op + fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) + + def visit(n: torch.fx.Node): + return n.meta["example_value"] + + args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) + args = tree_map(fake_wrapper, args) + kwargs = tree_map(fake_wrapper, kwargs) + + nnmodule = None + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if not is_lazy_module(nnmodule): + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + if op == "call_module" and is_lazy_module(nnmodule): + assert nnmodule is not None + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nnmodule(*args, **kwargs) + try: + with tx.fake_mode, enable_python_dispatcher(): + return wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + cause = e + if e.__cause__ is not None: + cause = e.__cause__ + if isinstance( + cause, torch._subclasses.fake_tensor.DataDependentOutputException + ): + if config.capture_scalar_outputs and node.target == "item": + return torch.zeros(size=(), dtype=args[0].dtype).item() + else: + unimplemented(f"data dependent operator: {cause.func}") + elif isinstance( + cause, torch._subclasses.fake_tensor.DynamicOutputShapeException + ): + unimplemented(f"dynamic shape operator: {cause.func}") + raise TorchRuntimeError() from e + + +def run_node(output_graph, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dicatated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The output_graph arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. + """ + op = node.op + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return output_graph.get_submodule(node.target) + except Exception as e: + raise RuntimeError( + f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)" + ) from e + raise AssertionError(op) + + +def get_real_value(node, output_graph): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. + """ + cache = output_graph.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( + (node.args, node.kwargs), + lambda n: get_real_value(n, output_graph), + ) + + if op == "call_module": + nn_module = output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(output_graph, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError() from e + return real_value + + +def assert_no_fake_params_or_buffers(gm): + from torch._subclasses.fake_tensor import FakeTensorConfig + + def stack_or_hint(t): + if FakeTensorConfig.debug: + import traceback + + return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" + else: + return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." + + for name, buffer in gm.named_buffers(): + assert not isinstance( + buffer, torch._subclasses.FakeTensor + ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" + for name, param in gm.named_parameters(): + assert not isinstance( + param, torch._subclasses.FakeTensor + ), f"Unexpected fake param {name} {stack_or_hint(param)}" + + +def fake_mode_from_tensors(inputs: List[Any]): + """ + Takes a list of anything, unflattened is fine, returns a fake_mode + if any are fake. All fake modes on all fake tensors must be identical. + Returns None if no fake_mode is fine + """ + flat_inputs, _ = tree_flatten(inputs) + fake_mode = None + for flat_input in flat_inputs: + if isinstance(flat_input, torch._subclasses.FakeTensor): + if fake_mode is None: + fake_mode = flat_input.fake_mode + else: + assert fake_mode is flat_input.fake_mode + return fake_mode diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 8c80557e3fd01..2305afc226ac2 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -35,6 +35,7 @@ ) from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, FakeItemVariable, TensorVariable, UnspecializedNumpyVariable, diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 62cddfff0cb29..52161a8dbdcb6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -21,7 +21,15 @@ def __eq__(self, other): return self is other -class VariableTracker: +# metaclass to call post_init +class HasPostInit(type): + def __call__(cls, *args, **kwargs): + obj = type.__call__(cls, *args, **kwargs) + obj.__post_init__(*args, **kwargs) + return obj + + +class VariableTracker(object, metaclass=HasPostInit): """ Base class for tracked locals and stack values @@ -41,13 +49,6 @@ def visit(var): if type(var) in (list, tuple, dict_values, odict_values): for i in var: visit(i) - elif isinstance(var, variables.BaseListVariable): - guards.update(var.guards) - for i in var.items: - visit(i) - elif isinstance(var, variables.ConstDictVariable): - guards.update(var.guards) - visit(var.items.values()) else: assert isinstance(var, VariableTracker), typestr(var) guards.update(var.guards) @@ -70,7 +71,11 @@ def copy(cls, value): @classmethod def apply( - cls, fn: Callable[["VariableTracker"], "VariableTracker"], value, cache=None + cls, + fn: Callable[["VariableTracker"], "VariableTracker"], + value, + cache=None, + skip_fn=lambda _: False, # Whether we should skip applying to this var ): """ Walk this object and call fn on all the VariableTracker @@ -84,21 +89,29 @@ def apply( return cache[idx][0] if isinstance(value, VariableTracker): - updated_dict = dict(value.__dict__) - for key in updated_dict.keys(): - if key not in value._nonvar_fields: - updated_dict[key] = cls.apply(fn, updated_dict[key], cache) - result = fn(value.clone(**updated_dict)) + if not skip_fn(value): + updated_dict = dict(value.__dict__) + for key in updated_dict.keys(): + if key not in value._nonvar_fields: + updated_dict[key] = cls.apply( + fn, updated_dict[key], cache, skip_fn + ) + result = fn(value.clone(**updated_dict)) + else: + result = fn(value) + elif istype(value, list): - result = [cls.apply(fn, v, cache) for v in value] + result = [cls.apply(fn, v, cache, skip_fn) for v in value] elif istype(value, tuple): - result = tuple(cls.apply(fn, v, cache) for v in value) + result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value) elif istype(value, collections.OrderedDict): result = collections.OrderedDict( - cls.apply(fn, v, cache) for v in value.items() + cls.apply(fn, v, cache, skip_fn) for v in value.items() ) elif istype(value, dict): - result = {k: cls.apply(fn, v, cache) for k, v in list(value.items())} + result = { + k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items()) + } else: result = value @@ -244,11 +257,32 @@ def __init__( guards: Optional[Set] = None, source: Source = None, mutable_local: MutableLocal = None, + recursively_contains: Optional[Set] = None, ): super(VariableTracker, self).__init__() self.guards = guards or set() self.source = source self.mutable_local = mutable_local + self.recursively_contains = ( + recursively_contains # provides hint to replace_all when replacing vars + ) + + def __post_init__(self, *args, **kwargs): + if self.recursively_contains is None: + self.recursively_contains = set() + + def aggregate_mutables(var): + self.recursively_contains.update(var.recursively_contains) + if var.mutable_local is not None: + self.recursively_contains.add(var.mutable_local) + + return var + + VariableTracker.apply( + aggregate_mutables, self, skip_fn=lambda var: var is not self + ) + + assert None not in self.recursively_contains def typestr(*objs): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d3c5140fa4a97..90fdaa143a66c 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3,15 +3,19 @@ import enum import functools import inspect +import math +import numbers +import operator import re import types from abc import ABCMeta -from typing import Any, List +from typing import Any, Optional, Union import numpy as np from functorch.experimental.ops import PyOperator import torch +from torch.fx.immutable_collections import immutable_list from .. import config, mutation_guard, replay_record, skipfiles from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy @@ -31,18 +35,24 @@ TupleIteratorGetItemSource, ) from ..utils import ( + clone_input, + get_fake_value, getfile, global_key_name, is_namedtuple, is_numpy_int_type, + is_typing, istensor, istype, odict_values, + preserve_rng_state, tuple_iterator, tuple_iterator_getitem, tuple_iterator_len, + wrap_to_fake_tensor_and_record, ) -from .base import MutableLocal + +from .base import MutableLocal, typestr from .builtin import BuiltinVariable from .constant import ConstantVariable, EnumVariable from .dicts import ( @@ -57,6 +67,7 @@ ListVariable, NamedTupleVariable, RangeVariable, + SizeVariable, SliceVariable, TupleVariable, ) @@ -72,6 +83,8 @@ ) from .nn_module import UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, + FakeItemVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -86,13 +99,22 @@ from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable +class _missing: + pass + + @dataclasses.dataclass class GraphArg: source: Source example: Any is_unspecialized: bool + fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor] def __post_init__(self): + if isinstance(self.example, torch.Tensor): + assert isinstance( + self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor + ) if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor): raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs") @@ -102,6 +124,13 @@ def load(self, tx): def get_examples(self): return [self.example] + def get_fake_examples(self): + if self.fake_tensor is not None: + assert isinstance( + self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor + ) + return [self.fake_tensor] + def __len__(self): return 1 @@ -187,6 +216,8 @@ def make_guards(self, *guards): def _wrap(self, value): make_guards = self.make_guards + if istype(value, (torch.SymInt, torch.SymFloat)): + return self.wrap_sym(value) if istensor(value): return self.wrap_tensor(value) elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value): @@ -221,9 +252,19 @@ def _wrap(self, value): return ListIteratorVariable( output, mutable_local=MutableLocal(), guards=guards ) - elif istype(value, range): - guards = self.make_guards(GuardBuilder.EQUALS_MATCH) - return RangeVariable(value=value, guards=guards) + elif istype(value, (slice, range)): + items = [ + VariableBuilder(self.tx, AttrSource(self.get_source(), k))( + getattr(value, k) + ) + for k in ("start", "stop", "step") + ] + if isinstance(value, slice): + return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH)) + else: + return RangeVariable( + items, guards=make_guards(GuardBuilder.EQUALS_MATCH) + ) elif istype( value, (dict, collections.defaultdict, collections.OrderedDict) ) and all( @@ -278,9 +319,18 @@ def index_source(key): return self.tx.output.side_effects.track_object_existing( self.source, value, result ) - elif issubclass( + elif getattr(value, "_is_fsdp_managed_module", False) or issubclass( value.__class__, torch.nn.parallel.distributed.DistributedDataParallel ): + if getattr(value, "_is_fsdp_managed_module", False): + # Note: we can't do this assert inside FSDP constructor, + # since we don't know yet whether dynamo will be used + assert getattr( + value, "_fsdp_use_orig_params", False + ), "Dynamo only supports FSDP with use_orig_params=True" + + # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + # in fully_sharded_data_parallel.py for more information return UnspecializedNNModuleVariable( value, guards=make_guards(GuardBuilder.TYPE_MATCH) ) @@ -341,7 +391,8 @@ def index_source(key): value, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) - elif value is List: + elif is_typing(value): + # typing.List, typing.Mapping, etc. return TypingVariable( value, guards=make_guards(GuardBuilder.ID_MATCH), @@ -404,7 +455,7 @@ def index_source(key): value, guards=make_guards(GuardBuilder.FUNCTION_MATCH) ) elif ( - isinstance(value, types.BuiltinFunctionType) + isinstance(value, types.MethodType) and type(getattr(value, "__self__", None)) is torch.autograd.function.FunctionMeta and getattr(value, "__name__", "") == "apply" @@ -427,14 +478,6 @@ def index_source(key): return HFPretrainedConfigVariable( value, guards=make_guards(GuardBuilder.TYPE_MATCH) ) - elif isinstance(value, slice): - items = [ - VariableBuilder(self.tx, AttrSource(self.get_source(), k))( - getattr(value, k) - ) - for k in ("start", "stop", "step") - ] - return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH)) elif isinstance(value, PyOperator): return TorchPyOperator( value, @@ -490,6 +533,28 @@ def tensor_should_specialize(self): ) ) + def wrap_sym(self, value: Union[torch.SymInt, torch.SymFloat]): + if not is_constant_source(self.get_source()): + self.tx.output.graphargs.append( + GraphArg(self.get_source(), value, False, None) + ) + elif is_constant_source(self.get_source()): + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=None, + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + return DynamicShapeVariable.create( + tx=self.tx, + proxy=self.tx.output.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) + ), + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + def wrap_tensor(self, value: torch.Tensor): if self.get_source().guard_source().is_nn_module(): return self.tx.output.register_attr_or_module( @@ -499,77 +564,126 @@ def wrap_tensor(self, value: torch.Tensor): # Guards are done inside register_attr_or_module # guards=self.make_guards(GuardBuilder.TENSOR_MATCH), ) + + if is_constant_source(self.get_source()): + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=None, + # Guards are added inside register_attr_or_module + ) + + if type(value) in config.traceable_tensor_subclasses: + # Ordinarily, we would fakeify a tensor so that it can get dynamic + # shapes and be computed on without triggering actual operations. + # However, how can we fakeify a tensor subclass? Ordinary + # inheritance (nor multiple inheritance) won't work work. + # + # Instead, our plan is to *manually simulate* the tensor subclass + # inheriting from a fake tensor with dynamo. This means our + # data representation for a tensor subclass will be a fake tensor + # + tensor subclass type + any extra data the subclass may have + # been storing on the tensor. Because all Python accesses are + # mediated through TensorWithTFOverrideVariable, we can ensure + # that we dispatch differently, e.g., according to + # __torch_function__ + # + # To simplify things for now, the __dict__ tracking bits haven't + # been implemented yet, but they can be added into this design at + # a later point in time. + ignore_subclass = True else: - if not is_constant_source(self.get_source()): - self.tx.output.graphargs.append( - GraphArg(self.get_source(), value, False) - ) - # Disable __torch_function__ to prevent cloning of `value` to hit - # us - with torch._C.DisableTorchFunction(): - if is_constant_source(self.get_source()): - return self.tx.output.register_attr_or_module( - value, - re.sub(r"[^a-zA-Z0-9]+", "_", self.name), - source=None, - # Guards are added inside register_attr_or_module - ) - tensor_variable = TensorVariable.create( - tx=self.tx, - proxy=self.tx.output.create_graph_input( - re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) - ), - example_value=value, - guards=self.make_guards(GuardBuilder.TENSOR_MATCH), - should_specialize=self.tensor_should_specialize(), - ) - if torch.overrides.has_torch_function_unary(value): - subclass_torch_function__func = value.__torch_function__.__func__ - subclass_type = type(value) - return TensorWithTFOverrideVariable( - tensor_variable, - self.get_source(), - subclass_torch_function__func, - subclass_type, - ) - return tensor_variable + assert type(value) in (torch.Tensor, torch.nn.Parameter) + ignore_subclass = False + + tensor_variable = wrap_fx_proxy( + tx=self.tx, + proxy=self.tx.output.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) + ), + example_value=value, + guards=self.make_guards(GuardBuilder.TENSOR_MATCH), + should_specialize=self.tensor_should_specialize(), + ignore_subclass=ignore_subclass, + ) + + # TODO: I think the result is guaranteed to be fake with + # ignore_subclass changes + fake_tensor_value = None + example_value = tensor_variable.proxy.node.meta["example_value"] + if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor): + fake_tensor_value = example_value + + self.tx.output.graphargs.append( + GraphArg(self.get_source(), value, False, fake_tensor_value) + ) + + if type(value) in config.traceable_tensor_subclasses: + subclass_torch_function__func = value.__torch_function__.__func__ + subclass_type = type(value) + # NB: This is slightly misnamed, a tensor subclass might not have + # any explicit __torch_function__ implementation and is relying + # on the default inherited from torch.Tensor + return TensorWithTFOverrideVariable( + tensor_variable, + self.get_source(), + subclass_torch_function__func, + subclass_type, + ) + + return tensor_variable def wrap_unspecialized_primitive(self, value): if self.name in self.tx.output.unspec_variable_map: return self.tx.output.unspec_variable_map[self.name] else: - wrapped_value = torch.tensor(value) - if not is_constant_source(self.get_source()): - self.tx.output.graphargs.append( - GraphArg(self.get_source(), wrapped_value, True) + if config.dynamic_shapes and isinstance(value, int): + shape_env = self.tx.output.shape_env + wrapped_value = shape_env.create_symintnode( + shape_env.create_symbol(value) ) + # TODO: Do float + else: + # TODO: Eliminate this case entirely + wrapped_value = torch.tensor(value) if not isinstance(self.get_source(), RandomValueSource): guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)} options = {"guards": guards} else: options = {} options.update({"source": self.get_source()}) - options.update({"raw_value": value}) + if isinstance(wrapped_value, torch.Tensor): + options.update({"raw_value": value}) proxy = self.tx.output.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(wrapped_value) ) if isinstance(value, np.number): - unspec_var = UnspecializedNumpyVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedNumpyVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, **options, ) else: - unspec_var = UnspecializedPythonVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, **options, ) self.tx.output.unspec_variable_map[self.name] = unspec_var + if not is_constant_source(self.get_source()): + fake_tensor_value = None + example_value = unspec_var.proxy.node.meta["example_value"] + if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor): + fake_tensor_value = example_value + self.tx.output.graphargs.append( + GraphArg(self.get_source(), wrapped_value, True, fake_tensor_value) + ) return unspec_var @@ -589,3 +703,193 @@ def _dataclasses_fields_lambda(obj): ) items.append(UserDefinedObjectVariable(field, source=source).add_options(obj)) return TupleVariable(items).add_options(obj) + + +def wrap_fx_proxy(tx, proxy, example_value=None, **options): + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +def wrap_fx_proxy_cls( + target_cls, tx, proxy, example_value=None, ignore_subclass=False, **options +): + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta + if not config.dynamic_propagation: + # TODO: This probably doesn't handle subclass correctly + if isinstance(example_value, torch.Tensor): + options.update(target_cls.specialize(example_value)) + return target_cls(proxy, **options) + + initial_example_value = example_value + + def _clone_input(value): + if isinstance(value, torch.Tensor): + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not isinstance(value, torch._subclasses.fake_tensor.FakeTensor): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + with preserve_rng_state(): + if example_value is None: + example_value = get_fake_value(proxy.node, tx) + else: + # Note: Unfortunately, this can happen during tracing, and is valid enough for now to allow. + # TODO(voz): Find all the callsites and burn this down. + # Flipping it to an assert fails dozens of tests. + # TODO(ezyang): should attempt this burndown again + if not isinstance(example_value, torch._subclasses.FakeTensor): + # We shouldn't be doing this at all, see + # https://github.com/pytorch/torchdynamo/issues/1950 + # But assuming we're doing it, the legacy behavior for + # subclasses was to perform a clone WITHOUT preserving + # the subclass. It's not clear to me that's what you actually + # want, but whatever, I wouldn't have this cache at all. + with torch._C.DisableTorchFunction(): + proxy.tracer.real_value_cache[proxy.node] = _clone_input( + example_value + ) + # NB: If we're ignoring subclass, then the expectation is you will + # take the returned TensorVariable and wrap it into a more + # accurate TensorVariable that is able to track subclass-ness; + # otherwise this is wrong! + example_value = wrap_to_fake_tensor_and_record( + example_value, tx=tx, ignore_subclass=ignore_subclass + ) + + if isinstance(example_value, torch.Tensor): + is_parameter = isinstance(example_value, torch.nn.Parameter) + should_specialize = options.pop("should_specialize", False) + if is_parameter or should_specialize: + specialized_value = initial_example_value + else: + specialized_value = None + + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value) + proxy.node.meta["example_value"] = example_value + specialized_props = target_cls.specialize(example_value) + if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor): + # NB: This will be wrong for ignore_subclass; fix it up later! + specialized_props["class_type"] = ( + torch.nn.Parameter if is_parameter else torch.Tensor + ) + + specialized_props["specialized_value"] = specialized_value + + options.update(specialized_props) + return target_cls(proxy, **options) + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target == torch.random.set_rng_state + ): + from . import TorchVariable + + return TorchVariable(proxy.node.target) + elif ( + proxy.node.target == torch._C._DisableFuncTorch + or proxy.node.target == torch.cuda._is_in_bad_fork + ): + from . import UserDefinedObjectVariable + + return UserDefinedObjectVariable(example_value) + elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + elif istype(example_value, torch.Size) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + sizes = [] + for i, v in enumerate(example_value): + proxy_i = proxy[i] + sizes.append(DynamicShapeVariable.create(tx, proxy_i, v, **options)) + return SizeVariable(sizes, proxy, **options) + elif istype(example_value, int) and proxy.node.target in ( + torch.seed, + operator.mod, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + ): + if config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + else: + return ConstantVariable(example_value, **options) + elif istype(example_value, torch.Size) and all( + [isinstance(x, int) for x in example_value] + ): + sizes = [ConstantVariable(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable(None, **options), + ) + else: + unpacked.append( + wrap_fx_proxy( + tx, + proxy.tracer.create_proxy( + "call_function", operator.getitem, (proxy, i), {} + ), + example_value=val, + **options, + ) + ) + if istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, mutable_local=MutableLocal(), **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ("namedtuple?") + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable(None, **options) + elif ( + isinstance(example_value, int) + and proxy.node.target is torch._utils._element_size + ): + proxy.node.meta["example_value"] = example_value + return ConstantVariable(example_value, **options) + elif ( + isinstance(example_value, numbers.Number) + and (proxy.node.target == "item" or proxy.node.target in {math.sqrt, math.pow}) + and config.capture_scalar_outputs + ): + # item raw value should not be accessed + return wrap_fx_proxy_cls( + FakeItemVariable, + tx=tx, + proxy=proxy, + example_value=torch.tensor(example_value), + **options, + ) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat)): + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable(proxy, example_value, **options) + else: + raise AssertionError( + "torch.* op returned non-Tensor " + + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" + ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 53fdb95aca8bb..083da86a1e19b 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -10,6 +10,7 @@ import numpy as np import torch +from torch.fx.experimental.symbolic_shapes import sym_float, sym_int from .. import config, variables from ..allowed_functions import is_allowed @@ -26,7 +27,7 @@ ) from .base import MutableLocal, VariableTracker from .dicts import ConstDictVariable -from .tensor import DynamicShapeVariable, FakeItemVariable +from .tensor import DynamicShapeVariable, FakeItemVariable, UnspecializedPythonVariable log = logging.getLogger(__name__) @@ -226,6 +227,7 @@ def unwrap_unspec_args_kwargs(args, kwargs): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls constant_args = check_constant_args(args, kwargs) tensor_args = self.tensor_args(*args, **kwargs) @@ -234,7 +236,7 @@ def call_function( has_constant_handler = self.can_constant_fold_through() and ( constant_args or unspec_python_args ) - assert isinstance(args, list) + assert isinstance(args, (list, tuple)) assert isinstance(kwargs, dict) if ( @@ -271,10 +273,13 @@ def call_function( fn, args = operator.add, [args[1], args[0]] proxy = tx.output.create_proxy( - "call_function", fn, *proxy_args_kwargs(args, kwargs), current_tx=tx + "call_function", + fn, + *proxy_args_kwargs(args, kwargs), ) if any([isinstance(arg, FakeItemVariable) for arg in args]): - return variables.FakeItemVariable.create( + return wrap_fx_proxy_cls( + FakeItemVariable, tx, proxy, **options, @@ -282,7 +287,8 @@ def call_function( elif self.unspec_numpy_args(*args, **kwargs): _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) raw_value = self.fn(*_args, **_kwargs) - return variables.UnspecializedNumpyVariable.create( + return wrap_fx_proxy_cls( + variables.UnspecializedNumpyVariable, tx, proxy, raw_value=raw_value, @@ -298,7 +304,8 @@ def call_function( if isinstance(x, variables.UnspecializedPythonVariable) ) - return variables.UnspecializedPythonVariable.create( + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx, proxy, raw_value=raw_value, @@ -312,21 +319,36 @@ def call_function( args[0], variables.UnspecializedPythonVariable ): args[0] = args[0].convert_to_constant(tx) - return variables.TensorVariable.create(tx, proxy, **options) + return wrap_fx_proxy(tx, proxy, **options) except NotImplementedError: unimplemented(f"partial tensor op: {self} {args} {kwargs}") # Handle cases like int(torch.seed()) - if self.fn is int and isinstance(args[0], DynamicShapeVariable): - return args[0] + # Also handle sym_float to sym_int cases + if self.fn in (int, float) and isinstance(args[0], DynamicShapeVariable): + fn_ = sym_int if self.fn is int else sym_float + out = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (args[0].as_proxy(),), + {}, + ), + **options, + ) + return out handler = getattr(self, f"call_{self.fn.__name__}", None) if handler: try: inspect.signature(handler).bind(tx, *args, **kwargs) except TypeError as exc: - log.warning(f"incorrect arg count {handler} {exc}") + if not has_constant_handler: + log.warning( + f"incorrect arg count {handler} {exc} and no constant handler" + ) handler = None if handler: @@ -350,7 +372,6 @@ def call_function( ), **options, ) - return super().call_function(tx, args, kwargs) def _call_min_max(self, tx, a, b): @@ -359,11 +380,24 @@ def _call_min_max(self, tx, a, b): a, b = b, a assert isinstance(a, variables.TensorVariable) - # 1. result of an item call is a scalar convert to a tensor - # 2. dynamic shape should be resolved to tensor - if isinstance(a, (FakeItemVariable, DynamicShapeVariable)): + # result of an item call is a scalar convert to a tensor + if isinstance(a, FakeItemVariable): a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {}) + # Dynamic input does not get resolved, rather, gets stored as call_function + if isinstance(a, DynamicShapeVariable): + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.fn, + *proxy_args_kwargs([a, b], {}), + ), + **VariableTracker.propagate(self, [a, b]), + ) + # convert min/max to torch ops if b.is_python_constant(): kwargs = {"min": b} if (self.fn is max) else {"max": b} @@ -422,28 +456,62 @@ def _call_min_max(self, tx, a, b): return variables.ConstantVariable(max(a.value, b.value)) else: return variables.ConstantVariable(min(a.value, b.value)) + elif isinstance(a, DynamicShapeVariable) or isinstance(b, DynamicShapeVariable): + proxy = tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs([a, b], {}) + ) + return DynamicShapeVariable.create(tx, proxy, None) else: + unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}") call_min = _call_min_max call_max = _call_min_max - def call_range(self, tx, *args, **kwargs): - if self.unspec_python_args(*args, **kwargs) or self.constant_args( - *args, **kwargs - ): - args, kwargs = specialize_args_kwargs(tx, args, kwargs) - return variables.RangeVariable( - value=range( - *[x.value for x in args], - **{k: v.value for k, v in kwargs.items()}, - ), - ) + def call_range(self, tx, *args): + if self.unspec_python_args(*args) or self.constant_args(*args): + args, _ = specialize_args_kwargs(tx, args, {}) + return variables.RangeVariable(args) + elif self._dynamic_args(*args): + + def guard_if_dyn(arg): + if isinstance(arg, DynamicShapeVariable): + return arg.evaluate_expr(tx.output) + return arg + + args = [variables.ConstantVariable(guard_if_dyn(arg)) for arg in args] + return variables.RangeVariable(args) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args, **kwargs): + return any([isinstance(x, DynamicShapeVariable) for x in args]) or any( + [isinstance(x, DynamicShapeVariable) for x in kwargs.values()] + ) def call_slice(self, tx, *args): return variables.SliceVariable(args) - def _call_iter_tuple_list(self, tx, obj=None): + def _dyn_proxy(self, tx, *args, **kwargs): + assert self._dynamic_args(*args, **kwargs) + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + **options, + ) + + def call_mod(self, tx, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -508,6 +576,11 @@ def call_mul(self, tx, a, b): return b.__class__( items=b.items * a.as_python_constant(), mutable_local=MutableLocal() ).add_options(self, a, b) + # TODO this doesn't generalize in other builtin operators. + elif isinstance(a, variables.ConstantVariable) and isinstance( + b, DynamicShapeVariable + ): + return b.call_method(tx, "__rmul__", [a], {}) else: return a.call_method(tx, "__mul__", [b], {}) @@ -536,6 +609,7 @@ def call_getitem(self, tx, *args, **kwargs): def call_isinstance(self, tx, arg, isinstance_type): arg_type = arg.python_type() + isinstance_type = isinstance_type.as_python_constant() if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index d3366448e3799..d42760fc26864 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -13,6 +13,8 @@ class ConstantVariable(VariableTracker): def __init__(self, value, **kwargs): super(ConstantVariable, self).__init__(**kwargs) assert not isinstance(value, torch.Tensor) + assert not isinstance(value, torch.SymInt) + assert not isinstance(value, torch.SymFloat) self.value = value def as_proxy(self): @@ -54,8 +56,8 @@ def unpack_var_sequence(self, tx): try: options = VariableTracker.propagate([self]) return [ConstantVariable(x, **options) for x in self.as_python_constant()] - except TypeError: - raise NotImplementedError() + except TypeError as e: + raise NotImplementedError from e def const_getattr(self, tx, name): member = getattr(self.value, name) @@ -70,6 +72,8 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": + from .tensor import DynamicShapeVariable + options = VariableTracker.propagate(self, args, kwargs.values()) if istype(self.value, tuple): @@ -78,6 +82,20 @@ def call_method( items=self.unpack_var_sequence(tx), source=self.source, **options ).call_method(tx, name, args, kwargs) + if any([isinstance(x, DynamicShapeVariable) for x in args]): + # NOTE! DANGER! THIS ONLY WORKS FOR COMMUTATIVE OPS + # we are relying on add to have arg[0] be a DynamicShapeVariable + # because we are in ConstantVariable land + # This transforms + # constant + dynamic + # into + # dynamic + constant + # Which already has infra built for writing to the graph + if name == "__add__": + assert len(args) == 1 + return args[0].call_method(tx, name, [self], {}) + # Unfortunate constant + return super(ConstantVariable, self).call_method(tx, name, args, kwargs) try: const_args = [a.as_python_constant() for a in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} @@ -98,7 +116,19 @@ def has_arith_binop(num_ty): return ConstantVariable(method(*const_args, **const_kwargs), **options) elif has_arith_binop(int) or has_arith_binop(float): op = getattr(operator, name) - return ConstantVariable(op(self.value, const_args[0]), **options) + add_target = const_args[0] + if isinstance(add_target, (torch.SymInt, torch.SymFloat)): + from .tensor import DynamicShapeVariable + + # Addition between a non sym and sym makes a sym + # dyn_shape = tx.output.register_attr_or_module( + # add_target, f"sym_shape_{add_target}", source=None + # ) + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return DynamicShapeVariable.create(tx, proxy, add_target, **options) + return ConstantVariable(op(self.value, add_target), **options) elif name == "__len__" and not (args or kwargs): return ConstantVariable(len(self.value), **options) elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 30df18f6d6e92..e05eecffc7e61 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -16,8 +16,12 @@ class ConstDictVariable(VariableTracker): - def __init__(self, items, user_cls, **kwargs): - super(ConstDictVariable, self).__init__(**kwargs) + def __init__(self, items, user_cls, recursively_contains=None, **kwargs): + super(ConstDictVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) + + self.guards.update(VariableTracker.propagate(items.values())["guards"]) self.items = items self.user_cls = user_cls @@ -112,7 +116,17 @@ def call_method( tx.store_dict_key(global_key_name(k), k) newval = collections.OrderedDict(val) newval[k] = args[1] - return tx.replace_all(self, self.modifed(newval, **options)) + + new_rec_contains = self.recursively_contains.union( + args[1].recursively_contains + ) + if args[1].mutable_local is not None: + new_rec_contains.add(args[1].mutable_local) + + return tx.replace_all( + self, + self.modifed(newval, new_rec_contains, **options), + ) elif ( name in ("pop", "get") and args @@ -130,7 +144,7 @@ def call_method( ): newval = collections.OrderedDict(val) result = newval.pop(ConstDictVariable.get_key(args[0])) - tx.replace_all(self, self.modifed(newval, **options)) + tx.replace_all(self, self.modifed(newval, None, **options)) return result.add_options(options) elif ( name == "update" @@ -140,7 +154,12 @@ def call_method( ): newval = collections.OrderedDict(val) newval.update(args[0].items) - result = self.modifed(newval, **options) + new_rec_contains = self.recursively_contains.union( + args[0].recursively_contains + ) + result = self.modifed( + newval, recursively_contains=new_rec_contains, **options + ) return tx.replace_all(self, result) elif ( name in ("get", "__getattr__") @@ -159,9 +178,11 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) - def modifed(self, items, **options): + def modifed(self, items, recursively_contains, **options): """a copy of self with different items""" - return self.clone(items=items, **options) + return self.clone( + items=items, recursively_contains=recursively_contains, **options + ) def unpack_var_sequence(self, tx): options = VariableTracker.propagate([self]) @@ -237,7 +258,14 @@ def call_method( f"defaultdict with default_factory = {self.default_factory}" ) new_val[k] = default_var - tx.replace_all(self, self.modifed(new_val, **options)) + new_rec_contains = self.recursively_contains.union( + default_var.recursively_contains + ) + if default_var.mutable_local is not None: + new_rec_contains.add(default_var.mutable_local) + tx.replace_all( + self, self.modifed(new_val, new_rec_contains, **options) + ) return default_var else: return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 8f1e29bc7e55a..b0259731772e4 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1,3 +1,4 @@ +import abc import enum import functools import inspect @@ -23,8 +24,12 @@ def wrap_bound_arg(val, options): return cls([wrap_bound_arg(x, options) for x in val], **options) elif variables.ConstantVariable.is_literal(val): return variables.ConstantVariable(val, **options) + elif isinstance(val, types.FunctionType): + return variables.UserFunctionVariable(val, **options) elif isinstance(val, enum.Enum): return variables.EnumVariable(val, **options) + elif isinstance(val, (type, abc.ABCMeta)): + return variables.UserDefinedClassVariable(val, **options) else: assert isinstance(val, VariableTracker), typestr(val) return val @@ -84,7 +89,7 @@ def __init__(self, fn, is_constant=False, **kwargs): assert isinstance( fn, types.FunctionType ), f"expected FunctionType found {typestr(fn)} {fn}" - # unpack @torchdynamo.optimize()(fn) wrapped function + # unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) # unpack torch.jit.script_if_tracing if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False): @@ -114,6 +119,8 @@ def bind_args(self, parent, args, kwargs): options = VariableTracker.propagate([self]) wrap = functools.partial(wrap_bound_arg, options=options) + tx = parent.output.root_tx + fn: types.FunctionType = self.fn fake_func = types.FunctionType( fn.__code__, @@ -141,7 +148,7 @@ def bind_args(self, parent, args, kwargs): if name == "__class__": result[name] = variables.UserDefinedClassVariable(cell.cell_contents) else: - var = parent.output.root_tx.match_nested_cell(name, cell) + var = tx.match_nested_cell(name, cell) if var is not None: # optimization for cleaner codegen result[name] = var @@ -158,15 +165,31 @@ def bind_args(self, parent, args, kwargs): closure_cell_contents = AttrSource( closure_cell, "cell_contents" ) + contents_var = VariableBuilder(parent, closure_cell_contents)( + cell.cell_contents + ) + + if ( + closure_cell_contents.name() + not in tx.mutated_closure_cell_contents + ): + # Optimistically don't allocate the cell, to + # reduce the number of side effects. This is + # important for cond, as without it, any accesses + # to closures create side effects and cond doesn't + # support side effects. If we're wrong and this + # closure cell gets written to, we will restart + # the analysis with this cell's name in the + # mutated list here + result[name] = contents_var + continue # cells are written to with "cell_contents", # so the source should just be the closure_cell, not its contents out = side_effects.track_cell_existing(closure_cell, cell) side_effects.store_cell( out, - VariableBuilder(parent, closure_cell_contents)( - cell.cell_contents - ), + contents_var, ) result[name] = out diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index e1c0d584073e4..82a7d79a1c367 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -7,7 +7,7 @@ from ..bytecode_transformation import create_instruction from ..exc import unimplemented from ..source import GetItemSource -from ..utils import namedtuple_fields +from ..utils import namedtuple_fields, proxy_args_kwargs from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -23,10 +23,23 @@ def cls_for(obj): tuple: TupleVariable, }[obj] - def __init__(self, items: List[VariableTracker], **kwargs): - super(BaseListVariable, self).__init__(**kwargs) + def __init__( + self, + items: List[VariableTracker], + recursively_contains=None, + regen_guards=True, + **kwargs, + ): + super(BaseListVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) assert isinstance(items, list) assert all(isinstance(x, VariableTracker) for x in items) + + # Sometimes, we know that we have passed in the guards from the items in the list + if regen_guards: + self.guards.update(VariableTracker.propagate(items)["guards"]) + self.items: List[VariableTracker] = items def _as_proxy(self): @@ -89,42 +102,51 @@ def call_method( class RangeVariable(BaseListVariable): - def __init__(self, value, items=None, guards=None, **kwargs): - if items is None: - items = [variables.ConstantVariable(x, guards=guards) for x in value] - super().__init__(items, guards=guards, **kwargs) - self.value = value + def __init__(self, items, **kwargs): + items_to_map = items + start = variables.ConstantVariable(0) + stop = None + step = variables.ConstantVariable(1) + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map + else: + raise AssertionError() + + assert stop is not None + super().__init__([start, stop, step], **kwargs) def python_type(self): return range def as_python_constant(self): - return self.value + return range(*[x.as_python_constant() for x in self.items]) - def reconstruct(self, codegen): - assert "range" not in codegen.tx.f_globals - range_fn = codegen.create_load_global("range", add=True) - if self.value.step == 1: - if self.value.start == 0: - return [ - range_fn, - codegen.create_load_const(self.value.stop), - create_instruction("CALL_FUNCTION", 1), - ] - return [ - range_fn, - codegen.create_load_const(self.value.start), - codegen.create_load_const(self.value.stop), - create_instruction("CALL_FUNCTION", 2), - ] + def as_proxy(self): + return self.python_type()(*self._as_proxy()) + + def unpack_var_sequence(self, tx): return [ - range_fn, - codegen.create_load_const(self.value.start), - codegen.create_load_const(self.value.stop), - codegen.create_load_const(self.value.step), - create_instruction("CALL_FUNCTION", 3), + variables.ConstantVariable(x).add_options(self) + for x in self.as_python_constant() ] + def reconstruct(self, codegen): + assert "range" not in codegen.tx.f_globals + codegen.append_output(codegen.create_load_python_module(range)) + codegen.foreach(self.items) + return [create_instruction("CALL_FUNCTION", 3)] + + def var_getattr(self, tx, name): + fields = ["start", "stop", "step"] + if name not in fields: + unimplemented(f"range.{name}") + return self.items[fields.index(name)].add_options(self) + class ListVariable(BaseListVariable): def python_type(self): @@ -145,9 +167,17 @@ def call_method( if name == "append" and self.mutable_local: assert not kwargs (arg,) = args + new_rec_contains = self.recursively_contains.union(arg.recursively_contains) + if arg.mutable_local is not None: + new_rec_contains.add(arg.mutable_local) tx.replace_all( self, - ListVariable(self.items + [arg], **options), + ListVariable( + self.items + [arg], + recursively_contains=new_rec_contains, + regen_guards=False, + **options, + ), ) return ConstantVariable(None) elif ( @@ -162,6 +192,7 @@ def call_method( self, ListVariable( list(self.items) + list(arg.unpack_var_sequence(tx)), + regen_guards=False, **options, ), ) @@ -172,7 +203,7 @@ def call_method( items.insert(idx.as_python_constant(), value) return tx.replace_all( self, - ListVariable(items, **options), + ListVariable(items, regen_guards=False, **options), ) elif name == "pop" and self.mutable_local: assert not kwargs @@ -180,14 +211,14 @@ def call_method( result = items.pop(*[a.as_python_constant() for a in args]) tx.replace_all( self, - ListVariable(items, **options), + ListVariable(items, regen_guards=False, **options), ) return result elif name == "clear" and self.mutable_local: assert not kwargs and not args return tx.replace_all( self, - ListVariable([], **options), + ListVariable([], regen_guards=False, **options), ) elif ( name == "__setitem__" @@ -202,7 +233,7 @@ def call_method( items[key.as_python_constant()] = list(value.items) else: items[key.as_python_constant()] = value - result = ListVariable(items, **options) + result = ListVariable(items, regen_guards=False, **options) return tx.replace_all(self, result) else: return super().call_method(tx, name, args, kwargs) @@ -308,6 +339,57 @@ def reconstruct(self, codegen): ] return build_torch_size + def unpack_var_sequence(self, tx): + return [x.add_options(self) for x in self.items] + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + options = VariableTracker.propagate(self, args, kwargs.values()) + if name == "__getitem__": + assert not kwargs and len(args) == 1 + if config.dynamic_shapes: + out = self.get_item_dyn(tx, args[0]) + else: + out = self.getitem_const(args[0]) + return out + return super(SizeVariable, self).call_method(tx, name, args, kwargs) + + def get_item_dyn(self, tx, arg: VariableTracker): + from .tensor import DynamicShapeVariable + + index = arg.as_python_constant() + if isinstance(index, slice): + + def _dynamo_get_item_lambda(target, index): + return torch.Size.__getitem__(target, index) + + parent_proxy = self.as_proxy() + proxy = tx.output.create_proxy( + "call_function", + _dynamo_get_item_lambda, + *proxy_args_kwargs([self, arg], {}), + ) + items = self.items[index] + + def _unpack_into_example(item): + if isinstance(item, DynamicShapeVariable): + return item.dyn_shape + return item.as_python_constant() + + # Mirror the indexing into example_value for downstream correctness + proxy.node.meta["example_value"] = parent_proxy.node.meta["example_value"][ + index + ] + return SizeVariable(items, proxy=proxy).add_options(arg, self) + else: + assert isinstance(index, int) + return self.items[index].add_options(arg, self) + class ShapeVariable(TupleVariable): """ @@ -326,6 +408,9 @@ def __init__(self, items, tuple_cls, **kwargs): def python_type(self): return self.tuple_cls + def as_python_constant(self): + return self.python_type()(*[x.as_python_constant() for x in self.items]) + def reconstruct(self, codegen): create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls) codegen.append_output(codegen._create_load_const(create_fn)) @@ -349,13 +434,20 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": class SliceVariable(BaseListVariable): def __init__(self, items, **kwargs): + from .tensor import DynamicShapeVariable + + if any([isinstance(x, DynamicShapeVariable) for x in items]): + unimplemented("Dynamic slicing not supported") + + items_to_map = items start, stop, step = [variables.ConstantVariable(None)] * 3 - if len(items) == 1: - (stop,) = items - elif len(items) == 2: - start, stop = items - elif len(items) == 3: - start, stop, step = items + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map else: raise AssertionError() @@ -366,7 +458,7 @@ def __init__(self, items, **kwargs): # more complete support for breaking on data dependent operators. if not config.capture_scalar_outputs: for limit in (start, stop, step): - if isinstance(limit, variables.TensorVariable): + if isinstance(limit, (variables.TensorVariable, DynamicShapeVariable)): unimplemented("Dynamic slicing not supported") super().__init__([start, stop, step], **kwargs) @@ -392,10 +484,15 @@ def var_getattr(self, tx, name): class ListIteratorVariable(VariableTracker): - def __init__(self, items, index: int = 0, **kwargs): - super(ListIteratorVariable, self).__init__(**kwargs) + def __init__(self, items, index: int = 0, recursively_contains=None, **kwargs): + super(ListIteratorVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) assert isinstance(items, list) - assert all(isinstance(x, VariableTracker) for x in items) + # Removing this check as it slows things down too much + # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 + + # assert all(isinstance(x, VariableTracker) for x in items) self.items = items self.index = index @@ -407,6 +504,7 @@ def next_variables(self): self.items, self.index + 1, mutable_local=MutableLocal(), + recursively_contains=self.recursively_contains, **VariableTracker.propagate([self]), ) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 8dd3478114396..58db779178f53 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -107,6 +107,10 @@ def __init__(self, target_values, initial_values=None, **kwargs): super(ContextWrappingVariable, self).__init__(**kwargs) self.target_values = target_values self.initial_values = initial_values + self.recursively_contains = ( + set() + ) # This var doesn't contain any child vars and doesn't support clone() properly, + # so don't populate this automatically def enter(self, tx): self._call_func(tx, self.target_values) @@ -116,6 +120,9 @@ def exit(self, tx, *args): self._call_func(tx, self.initial_values) return variables.ConstantVariable(None, **VariableTracker.propagate(self)) + def module_name(self): + return "torch" + def reconstruct(self, codegen, target_inst=None): """ Generate following Python Bytecode, with a `torch._C._set_grad_enable` call @@ -274,7 +281,7 @@ def enter(self, tx): def _call_func(self, tx, values): assert len(values) == 1 value = values[0] - tx.output.graph.create_node( + tx.output.create_node( "call_function", torch._C._set_grad_enabled, (value,), {} ), torch._C._set_grad_enabled(value) @@ -283,7 +290,7 @@ def _func_name(self): return "_C._set_grad_enabled" def fn_name(self): - if self.target_values: + if self.target_values[0]: return "enable_grad" else: return "no_grad" @@ -291,7 +298,7 @@ def fn_name(self): class AutocastModeVariable(ContextWrappingVariable): @staticmethod - def create(tx, target_values, kwargs): + def create(target_values, kwargs): values = target_values # device_type : str, # dtype : Optional[_dtype] = None, @@ -319,10 +326,10 @@ def create(tx, target_values, kwargs): else: values.append(variables.ConstantVariable(None)) - var = AutocastModeVariable(tx, target_values, initial_values=None, **kwargs) + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) return var - def __init__(self, tx, target_values, initial_values=None, **kwargs): + def __init__(self, target_values, initial_values=None, **kwargs): super(AutocastModeVariable, self).__init__( target_values=target_values, initial_values=initial_values, **kwargs ) @@ -330,12 +337,12 @@ def __init__(self, tx, target_values, initial_values=None, **kwargs): self.mode = None def exit(self, tx, *args): - tx.output.graph.create_node( + tx.output.create_node( "call_function", exit_functional_autocast, (self.mode,), {} ) def enter(self, tx): - self.mode = tx.output.graph.create_node( + self.mode = tx.output.create_node( "call_function", enter_functional_autocast, (*self.target_values,), {} ) @@ -356,11 +363,15 @@ def exit_functional_autocast(mode): mode.__exit__(None, None, None) -class ProfilerContextWrapperVariable(ContextWrappingVariable): +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + It's used as a placeholder for other context managers that Dynamo doesn't + support yet, e.g, torch.autograd.profiler.record_function. + """ + def __init__(self, target_values=None, **kwargs): - super(ProfilerContextWrapperVariable, self).__init__( - target_values=target_values, **kwargs - ) + super(NullContextVariable, self).__init__(target_values=target_values, **kwargs) def enter(self, tx): return variables.ConstantVariable(None, **VariableTracker.propagate(self)) @@ -368,8 +379,11 @@ def enter(self, tx): def exit(self, tx, *args): return variables.ConstantVariable(None, **VariableTracker.propagate(self)) + def module_name(self): + return "contextlib" + def fn_name(self): - return "autograd.profiler.profile" + return "nullcontext" class WithExitFunctionVariable(VariableTracker): @@ -389,7 +403,7 @@ def reconstruct(self, codegen): # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. output = AttrSource( - codegen.tx.import_source("torch"), self.ctx.fn_name() + codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name() ).reconstruct(codegen) if codegen.tx.output.partial_convert: @@ -445,9 +459,19 @@ def visit(node): args = [BlackHoleVariable()] + list(args) options = VariableTracker.propagate(self, args, kwargs.values()) - return variables.UserFunctionVariable( - self.fn_cls.forward, **options - ).call_function(tx, args, kwargs) + fn = self.fn_cls.forward + if isinstance(fn, types.FunctionType): + return variables.UserFunctionVariable(fn, **options).call_function( + tx, args, kwargs + ) + elif isinstance(fn, types.MethodType): + return variables.UserMethodVariable( + fn.__func__, variables.UserDefinedClassVariable(self.fn_cls), **options + ).call_function(tx, args, kwargs) + else: + unimplemented( + f"non-function or method in subclass of torch.autograd.Function: {fn}" + ) def call_function(self, tx, args, kwargs): options = VariableTracker.propagate(self, args, kwargs.values()) @@ -513,6 +537,7 @@ def reconstruct(self, codegen): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy # This variable is True when it corresponds to user code such as # @@ -530,7 +555,7 @@ def call_function( if is_original_tensor_torch_function: # Instead of tracing inside torch.Tensor.__torch_function__, # record the `call_function` or `call_method` call into the graph. - from . import TensorVariable, TorchVariable + from . import TorchVariable original_torch_or_getattr_variable = args[0] new_args = args[2].items @@ -540,24 +565,22 @@ def call_function( # example tensor from going into the override. with torch._C.DisableTorchFunction(): if isinstance(args[0], TorchVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", original_torch_or_getattr_variable.value, *proxy_args_kwargs(new_args, new_kwargs), - current_tx=tx, ), **options, ) elif isinstance(args[0], GetAttrVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", original_torch_or_getattr_variable.name, *proxy_args_kwargs(new_args, new_kwargs), - current_tx=tx, ), **options, ) @@ -643,6 +666,12 @@ def call_method( ) unimplemented("typing") + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + class NumpyVariable(VariableTracker): """ diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 4da389bbd8c47..c9f0f7792ec91 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -1,7 +1,6 @@ import functools import inspect import itertools -import re import types from contextlib import contextmanager from typing import Dict, List @@ -14,7 +13,13 @@ from ..guards import GuardBuilder from ..mutation_guard import GenerationTracker from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource -from ..utils import is_lazy_module, istype, proxy_args_kwargs +from ..utils import ( + is_lazy_module, + is_safe_constant, + istensor, + istype, + proxy_args_kwargs, +) from .base import MutableLocal, typestr, VariableTracker from .functions import invoke_and_store_as_constant from .lists import SliceVariable @@ -139,6 +144,9 @@ def var_getattr(self, tx, name): return variables.UserFunctionVariable(subobj.__get__(base), **options) elif istype(subobj, types.FunctionType): return variables.UserMethodVariable(subobj, self, **options) + elif is_safe_constant(subobj) or istensor(subobj): + # Support possibly common cases of class members + return VariableBuilder(tx, NNModuleSource(source))(subobj) else: unimplemented(f"class property {typestr(base)} {typestr(subobj)}") @@ -156,7 +164,7 @@ def call_function( @contextmanager def record_nn_module_stack(): try: - tx.nn_module_stack[self.module_key] = type(mod) + tx.nn_module_stack[self.module_key] = str(type(mod)) yield finally: del tx.nn_module_stack[self.module_key] @@ -188,14 +196,14 @@ def record_nn_module_stack(): # The module type will change after it is called if is_lazy: self.module_type = mod.cls_to_become + from .builder import wrap_fx_proxy - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_module", self.module_key, *proxy_args_kwargs(args, kwargs), - current_tx=tx, ), **options, ) @@ -273,18 +281,15 @@ def get_kwargs(*names): bound_args = bound_args.arguments return {k: bound_args[k] for k in names} - def wrap_values(items, getsource=AttrSource): + def wrap_values(items): result = [] for name, submod in items: - # layer.0.foo => layer[0].foo - name = re.sub(r"[.]([0-9]+)([.]|$)", r"[\1]\2", name) - src = NNModuleSource(getsource(self.source, name)) result.append( tx.output.register_attr_or_module( submod, key, name, - source=src, + source=NNModuleSource(gen_source(self.source, name)), **options, ) ) @@ -298,12 +303,21 @@ def named_embed(name, obj): obj, key, name, - source=NNModuleSource(GetItemSource(self.source, name)), + source=NNModuleSource(gen_source(self.source, name)), **options, ), ] ) + def gen_source(source, name): + name_split = name.split(".") + if name_split[0] == "": + return source + while len(name_split) > 0: + x = name_split.pop(0) + source = AttrSource(source, x) + return source + if name == "children": assert not (args or kwargs) return wrap_values(module.named_children()) @@ -314,6 +328,13 @@ def named_embed(name, obj): ): result.append(named_embed(name, param)) return ListIteratorVariable(result, mutable_local=MutableLocal(), **options) + elif name == "named_buffers": + result = [] + for name, buffer in module.named_buffers( + **get_kwargs("prefix", "recurse", "remove_duplicate") + ): + result.append(named_embed(name, buffer)) + return ListIteratorVariable(result, mutable_local=MutableLocal(), **options) elif name == "named_modules": result = [] for name, submod in module.named_modules( @@ -321,11 +342,13 @@ def named_embed(name, obj): ): result.append(named_embed(name, submod)) return ListIteratorVariable(result, mutable_local=MutableLocal(), **options) + elif name == "modules": + return wrap_values(module.named_modules()) elif name == "parameters": return wrap_values(module.named_parameters(**get_kwargs("recurse"))) elif name == "values": assert not (args or kwargs) - return wrap_values(module.items(), GetItemSource) + return wrap_values(module.items()) elif name == "items": assert not (args or kwargs) result = [] @@ -436,14 +459,15 @@ def make_attr(name): proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", name, args=(proxy_for_mod, *proxy_args), kwargs=proxy_kwargs, - current_tx=tx, ), **options, ) @@ -480,8 +504,8 @@ def unpack_var_sequence(self, tx): try: fn = inspect.getattr_static(self.value_type, "__iter__") - except AttributeError: - raise NotImplementedError() + except AttributeError as e: + raise NotImplementedError from e if fn in ( torch.nn.ModuleList.__iter__, @@ -546,7 +570,12 @@ def call_method( return variables.ListIteratorVariable( items, mutable_local=MutableLocal(), **options ) - + elif isinstance(method, staticmethod): + return tx.inline_user_function_return( + variables.UserFunctionVariable(method.__func__, **options), + args, + kwargs, + ) if id(method.__code__) in self._nn_module_method_ids(): unimplemented(f"UnspecializedNNModuleVariable missing {name}") diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index a8db819cb272d..f1b30f9212423 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1,161 +1,27 @@ -import copy -import functools import itertools -import math -import numbers import operator from typing import Dict, List import torch.fx import torch.random -from ..utils import fake_tensors_available - -if fake_tensors_available: - from torch._subclasses import FakeTensor - from torch._subclasses.fake_tensor import ( - DataDependentOutputException, - DynamicOutputShapeException, - ) - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor - -import torch.utils._python_dispatch as py_dispatch -from torch.fx.immutable_collections import immutable_list -from torch.utils._pytree import tree_map - from .. import config, variables -from ..exc import TorchRuntimeError, unimplemented, Unsupported +from ..exc import unimplemented from ..guards import GuardBuilder from ..source import AttrSource + from ..utils import ( - clone_input, - is_lazy_module, - istype, - preserve_rng_state, + get_fake_value, + get_real_value, product, proxy_args_kwargs, tensortype_to_dtype, ) -from .base import MutableLocal, typestr, VariableTracker +from .base import VariableTracker from .constant import ConstantVariable from .lists import ShapeVariable, SizeVariable -class _missing: - pass - - -def _run_node(output_graph, node, args, kwargs, nnmodule): - op = node.op - if op == "call_function": - return node.target(*args, **kwargs) - elif op == "call_method": - return getattr(args[0], node.target)(*args[1:], **kwargs) - elif op == "call_module": - assert nnmodule is not None - return nnmodule(*args, **kwargs) - elif op == "get_attr": - return output_graph.get_submodule(node.target) - raise AssertionError(op) - - -def _get_real_value(node, output_graph): - """ - Run the actual computation represented by `node` and return the result. - This will execute any dependent nodes in the graph as well. - """ - cache = output_graph.real_value_cache - if node in cache: - return cache[node] - - op = node.op - args, kwargs = torch.fx.node.map_arg( - (node.args, node.kwargs), - lambda n: _get_real_value(n, output_graph), - ) - - if op == "call_module": - nn_module = output_graph.nn_modules[node.target] - if not is_lazy_module(nn_module): - nn_module = copy.deepcopy(nn_module) - else: - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nn_module(*args, **kwargs) - else: - nn_module = None - - try: - real_value = _run_node(output_graph, node, args, kwargs, nn_module) - cache[node] = real_value - except RuntimeError as e: - raise TorchRuntimeError() from e - return real_value - - -def _get_fake_value(node, tx): - """ - Run the computation represented by `node` using fake tensors and return the result. - """ - op = node.op - fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode) - from ..utils import wrap_fake_exception - - def visit(n: torch.fx.Node): - return n.meta["example_value"] - - args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) - args = tree_map(fake_wrapper, args) - kwargs = tree_map(fake_wrapper, kwargs) - - nnmodule = None - if op == "call_module": - nnmodule = tx.output.nn_modules[node.target] - - if not is_lazy_module(nnmodule): - nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) - - def context(): - if hasattr(py_dispatch, "enable_torch_dispatch_mode"): - return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode) - else: - return tx.fake_mode - - if op == "call_module" and is_lazy_module(nnmodule): - assert nnmodule is not None - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nnmodule(*args, **kwargs) - try: - with context(): - return wrap_fake_exception( - lambda: _run_node(tx.output, node, args, kwargs, nnmodule) - ) - except Unsupported: - raise - except RuntimeError as e: - if isinstance(e, DataDependentOutputException): - if config.capture_scalar_outputs and node.target == "item": - return torch.zeros(size=(), dtype=args[0].dtype).item() - else: - unimplemented(f"data dependent operator: {e.func}") - elif isinstance(e, DynamicOutputShapeException): - unimplemented(f"dynamic shape operator: {e.func}") - else: - raise TorchRuntimeError() from e - - -def _clone_input(value): - if isinstance(value, torch.Tensor): - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - # tensor subclasses will not be converted to FakeTensors and need to be cloned - if not use_fake_tensors or not isinstance(value, FakeTensor): - # NB: ensure strides are preserved - value = clone_input(value) - - return value - - class TensorVariable(VariableTracker): """A torch.Tensor input or an intermediate value in the FX graph""" @@ -163,6 +29,7 @@ class TensorVariable(VariableTracker): "proxy", "dtype", "device", + "layout", "ndim", "size", "stride", @@ -178,176 +45,14 @@ def get_real_value(self): NOTE: this runs actual tensor computation and may be slow and memory-intensive. """ - return _get_real_value(self.proxy.node, self.proxy.tracer) - - @classmethod - def create(cls, tx, proxy, example_value=None, **options): - if "guards" in options and options["guards"] is not None: - tx.output.guards.update(options["guards"]) - - assert "example_value" not in proxy.node.meta - if not config.dynamic_propagation: - if isinstance(example_value, torch.Tensor): - options.update(cls.specialize(example_value)) - return cls(proxy, **options) - - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - - initial_example_value = example_value - - with preserve_rng_state(): - if example_value is None: - if use_fake_tensors: - example_value = _get_fake_value(proxy.node, tx) - else: - example_value = _get_real_value(proxy.node, tx.output) - - else: - proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) - if use_fake_tensors: - fake_wrapper = functools.partial( - wrap_to_fake_tensor, fake_mode=tx.fake_mode - ) - example_value = fake_wrapper(example_value) - - if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - should_specialize = options.pop("should_specialize", False) - if is_parameter or should_specialize: - specialized_value = initial_example_value - else: - specialized_value = None - - example_value = _clone_input(example_value) - proxy.node.meta["example_value"] = example_value - specialized_props = cls.specialize(example_value) - if use_fake_tensors and isinstance(example_value, FakeTensor): - specialized_props["class_type"] = ( - torch.nn.Parameter if is_parameter else torch.Tensor - ) - - specialized_props["specialized_value"] = specialized_value - - options.update(specialized_props) - return cls(proxy, **options) - elif ( - hasattr(proxy.node.target, "__name__") - and proxy.node.target.__name__ == "set_state" - and isinstance(proxy.node.target.__self__, torch._C.Generator) - or proxy.node.target == torch.random.set_rng_state - ): - from . import TorchVariable - - return TorchVariable(proxy.node.target) - elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, type(example_value), **options) - elif istype(example_value, torch.Size) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - sizes = [] - for i, v in enumerate(example_value): - proxy_i = proxy[i] - proxy_i.node.meta["example_value"] = v - sizes.append(DynamicShapeVariable(proxy_i, int)) - return SizeVariable(sizes, proxy, **options) - elif istype(example_value, int) and proxy.node.target in ( - torch.seed, - operator.mod, - # some mac builds are missing torch.distributed.get_rank() - getattr(torch.distributed, "get_rank", _missing), - getattr(torch.distributed, "get_world_size", _missing), - ): - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, type(example_value), **options) - elif istype(example_value, torch.Size) and all( - [isinstance(x, int) for x in example_value] - ): - sizes = [variables.ConstantVariable(x) for x in example_value] - return SizeVariable(sizes, **options) - elif isinstance(example_value, (tuple, list)): - unpacked = [] - for i, val in enumerate(example_value): - if val is None: - # nn.MultiheadAttention() can return None, see issue #175 - unpacked.append( - variables.ConstantVariable(None, **options), - ) - else: - unpacked.append( - cls.create( - tx, - proxy.tracer.create_proxy( - "call_function", operator.getitem, (proxy, i), {} - ), - example_value=val, - **options, - ) - ) - if istype(example_value, tuple): - return variables.TupleVariable(unpacked, **options) - elif istype(example_value, (list, immutable_list)): - return variables.ListVariable( - unpacked, mutable_local=MutableLocal(), **options - ) - else: - assert ( - example_value.__class__.__module__ == "torch.return_types" - or hasattr(example_value, "_fields") - ), "namedtuple?" - return variables.NamedTupleVariable( - unpacked, example_value.__class__, **options - ) - elif example_value is None or proxy.node.target is torch.manual_seed: - return variables.ConstantVariable(None, **options) - elif ( - isinstance(example_value, int) - and proxy.node.target is torch._utils._element_size - ): - proxy.node.meta["example_value"] = example_value - return variables.ConstantVariable(example_value, **options) - elif ( - isinstance(example_value, numbers.Number) - and ( - proxy.node.target == "item" - or proxy.node.target in {math.sqrt, math.pow} - ) - and config.capture_scalar_outputs - ): - if use_fake_tensors: - # item raw value should not be accessed - return FakeItemVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - **options, - ) - else: - return UnspecializedPythonVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - raw_value=None if use_fake_tensors else example_value, - need_unwrap=False, - **options, - ) - elif ( - proxy.node.target == torch._C._DisableFuncTorch - or proxy.node.target == torch.cuda._is_in_bad_fork - ): - from . import UserDefinedObjectVariable - - return UserDefinedObjectVariable(example_value) - else: - raise AssertionError( - "torch.* op returned non-Tensor " - + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" - ) + return get_real_value(self.proxy.node, self.proxy.tracer) def __init__( self, proxy: torch.fx.Proxy, dtype=None, device=None, + layout=None, ndim=None, size=None, stride=None, @@ -363,6 +68,7 @@ def __init__( self.proxy = proxy self.dtype = dtype self.device = device + self.layout = layout self.ndim = ndim self.size = size self.stride = stride @@ -397,6 +103,7 @@ def specialize(value: torch.Tensor): props = { "dtype": value.dtype, "device": value.device, + "layout": value.layout, "ndim": int(value.ndim), "requires_grad": value.requires_grad, "is_quantized": value.is_quantized, @@ -406,7 +113,13 @@ def specialize(value: torch.Tensor): if not config.dynamic_shapes: props["size"] = tuple(value.size()) props["stride"] = tuple(value.stride()) - props["is_contiguous"] = value.is_contiguous() + props["is_contiguous"] = tuple( + [ + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ] + ) return props def var_getattr(self, tx, name): @@ -420,6 +133,8 @@ def var_getattr(self, tx, name): result = TorchVariable(self.dtype, **options) elif name == "device" and self.device is not None: result = TorchVariable(self.device, **options) + elif name == "layout" and self.layout is not None: + result = TorchVariable(self.layout, **options) elif name == "is_cuda" and self.device is not None: result = ConstantVariable(self.device.type == "cuda", **options) elif name == "shape" and self.size is not None: @@ -435,6 +150,11 @@ def var_getattr(self, tx, name): result = self.call_method(tx, "size", [], {}) elif name == "ndim" and self.ndim is None: result = self.call_method(tx, "dim", [], {}) + elif name == "data": + result = self.call_method(tx, "detach", [], {}) + elif name == "T": + args = [variables.ConstantVariable(i) for i in range(self.ndim - 1, -1, -1)] + result = self.call_method(tx, "permute", args, {}) if name == "__class__": return TorchVariable(self.python_type(), **options) @@ -450,17 +170,16 @@ def var_getattr(self, tx, name): return result - def unpack_var_sequence(self, tx): - options = VariableTracker.propagate(self) - if self.size: - return [ - variables.BuiltinVariable(operator.getitem, **options).call_function( - tx, [self, variables.ConstantVariable(i)], {} - ) - for i in range(self.size[0]) - ] + def unpack_var_sequence(self, tx, idxes=None): + from .builder import wrap_fx_proxy - return super(TensorVariable, self).unpack_var_sequence(tx) + if idxes is None: + if self.size: + idxes = range(self.size[0]) + else: + return super(TensorVariable, self).unpack_var_sequence(tx) + options = VariableTracker.propagate(self) + return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes] def call_method( self, @@ -469,31 +188,60 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": - from . import ConstantVariable, TupleVariable + from . import ConstantVariable, TorchVariable, TupleVariable + from .builder import wrap_fx_proxy kwargs = dict(kwargs) - options = VariableTracker.propagate(self, args, kwargs.values()) - if name == "stride" and self.stride is not None: constant_result = ConstantVariable(self.stride, **options) elif name == "size" and self.size is not None: sizes = [variables.ConstantVariable(x) for x in self.size] constant_result = SizeVariable(sizes, **options) - elif name == "numel" and self.size is not None: + elif name == "size" and self.size is None and config.dynamic_shapes: + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + list(args), kwargs), + ), + **options, + ) + elif name in ("numel", "nelement") and self.size is not None: constant_result = ConstantVariable(product(self.size), **options) elif name in ("ndimension", "dim") and self.ndim is not None: constant_result = ConstantVariable(self.ndim, **options) elif name == "is_floating_point" and self.dtype is not None: constant_result = ConstantVariable(self.dtype.is_floating_point, **options) elif name == "is_contiguous" and self.is_contiguous is not None: - if ( - "memory_format" in kwargs - and kwargs["memory_format"].as_python_constant() - == torch.contiguous_format - ): - kwargs.pop("memory_format") - constant_result = ConstantVariable(self.is_contiguous, **options) + if "memory_format" in kwargs: + memory_format = kwargs.pop("memory_format").as_python_constant() + else: + memory_format = torch.contiguous_format + constant_result = ConstantVariable( + memory_format in self.is_contiguous, **options + ) + elif ( + name == "type" + and self.dtype is not None + and len(args) == 0 + and isinstance(self.device, torch.device) + ): + tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][ + 0 + ] + if self.device.type == "cuda": + constant_result = ConstantVariable( + f"torch.cuda.{tensortype.__name__}", **options + ) + else: + constant_result = ConstantVariable( + f"torch.{tensortype.__name__}", **options + ) + elif name == "get_device" and isinstance(self.device, torch.device): + index = self.device.index if self.device.type != "cpu" else -1 + constant_result = ConstantVariable(index, **options) else: constant_result = None @@ -514,17 +262,22 @@ def call_method( and not config.dynamic_shapes ): unimplemented("dynamic Tensor.repeat") - elif name in ("tolist", "numpy", "backward"): + elif name in ("tolist", "numpy", "backward", "data_ptr"): unimplemented(f"Tensor.{name}") elif name == "nonzero" and not config.dynamic_shapes: unimplemented(f"Tensor.{name}") elif name == "item": if config.capture_scalar_outputs: - return self.__class__.create( + example_value = get_fake_value(self.proxy.node, tx) + return wrap_fx_proxy( tx, tx.output.create_proxy( - "call_method", "item", (self.as_proxy(),), {}, current_tx=tx + "call_method", + "item", + (self.as_proxy(),), + {}, ), + example_value=example_value, **options, ) else: @@ -534,10 +287,13 @@ def call_method( assert not config.dynamic_shapes return ConstantVariable(self.size[0], **options) else: - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( - "call_function", len, (self.as_proxy(),), {}, current_tx=tx + "call_function", + len, + (self.as_proxy(),), + {}, ), **options, ) @@ -546,10 +302,59 @@ def call_method( tx.output.create_proxy( "call_function", operator.setitem, - *proxy_args_kwargs([self] + args, kwargs), - current_tx=tx, + *proxy_args_kwargs([self] + list(args), kwargs), ) return ConstantVariable(None, **options) + elif name in ("resize_", "resize_as_"): + if "memory_format" in kwargs: + memory_format = kwargs["memory_format"].as_python_constant() + else: + memory_format = torch.contiguous_format + + if name == "resize_": + self.size = args[0].as_python_constant() + self.is_contiguous = (memory_format,) + else: + assert isinstance(args[0], TensorVariable) + if self.size and args[0].size: + if ( + self.size == args[0].size + or memory_format is torch.preserve_format + ): + self.is_contiguous = args[0].is_contiguous + else: + self.size = args[0].size + self.stride = args[0].stride + self.ndim = args[0].ndim + self.is_contiguous = (memory_format,) + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + list(args), kwargs), + ), + **options, + ) + elif ( + name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs + ): + result = TorchVariable(torch.mul, **options).call_function( + tx, args + [kwargs["alpha"]], {} + ) + return self.call_method(tx, "add_", [result], {}) + elif ( + name == "addcdiv_" + and len(args) == 2 + and len(kwargs) == 1 + and "value" in kwargs + ): + result = TorchVariable(torch.div, **options).call_function(tx, args, {}) + result = TorchVariable(torch.mul, **options).call_function( + tx, [result, kwargs["value"]], {} + ) + return self.call_method(tx, "add_", [result], {}) else: # Convert x.new(torch.Size) into x.new_empty(torch.Size), # as Tensor.new acts differently with a Size input versus a tuple input. @@ -560,34 +365,71 @@ def call_method( and not config.dynamic_shapes ): name = "new_empty" - - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", name, - *proxy_args_kwargs([self] + args, kwargs), - current_tx=tx, + *proxy_args_kwargs([self] + list(args), kwargs), ), **options, ) -class DynamicShapeVariable(TensorVariable): +class DynamicShapeVariable(VariableTracker): """ Represents a symbolic size, e.g., as returned by tensor.size(0) """ - def __init__(self, proxy, dyn_shape_cls, **kwargs): - super(DynamicShapeVariable, self).__init__(proxy, **kwargs) - self.dyn_shape_cls = dyn_shape_cls + @classmethod + def create(cls, tx, proxy, dyn_shape, **options): + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == dyn_shape + if dyn_shape is None: + dyn_shape = get_fake_value(proxy.node, tx) + proxy.node.meta["example_value"] = dyn_shape + return DynamicShapeVariable(proxy, dyn_shape, **options) + + def __init__(self, proxy, dyn_shape, **kwargs): + super(DynamicShapeVariable, self).__init__(**kwargs) + self.proxy = proxy + self.dyn_shape = dyn_shape def python_type(self): - return self.dyn_shape_cls + return type(self.dyn_shape) def unpack_var_sequence(self, tx): super(DynamicShapeVariable, self).unpack_var_sequence(tx) + def as_proxy(self): + return self.proxy + + def evaluate_expr(self, output_graph): + if not isinstance(self.dyn_shape, torch.SymInt): + return self.dyn_shape + return output_graph.shape_env.evaluate_expr(self.dyn_shape.get_pyobj().expr) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + list(args), kwargs), + ), + **options, + ) + class TensorWithTFOverrideVariable(VariableTracker): """ @@ -621,6 +463,11 @@ def call_method( options = VariableTracker.propagate(self, args, kwargs.values()) # insert unwrapped version of self as the first argument + # TODO: This is wrong! When you call the internal __torch_function__, + # you still get the wrapped version of self, and if you call functions + # inside __torch_function__, they should come back here. If we unwrap + # the tensor immediately, that will not happen. + # See https://github.com/pytorch/torchdynamo/issues/1951 args = list(args) args.insert(0, self.tensor_variable) func_var = GetAttrVariable(self.tensor_variable, name) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1ecfbe1a70b2c..31ad83cb648a3 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,11 +1,15 @@ import logging + +import math import re import types +from collections import OrderedDict from typing import Dict, List import numpy import torch._C +import torch.fx import torch.nn import torch.onnx.operators @@ -24,8 +28,7 @@ ) from .base import VariableTracker from .lists import ListVariable, TupleVariable -from .misc import AutocastModeVariable, ProfilerContextWrapperVariable -from .nn_module import NNModuleVariable +from .misc import AutocastModeVariable, NullContextVariable from .tensor import TensorWithTFOverrideVariable log = logging.getLogger(__name__) @@ -161,8 +164,9 @@ def can_constant_fold_through(self): torch.finfo, torch.iinfo, torch.is_floating_point, - torch.is_tensor, - torch.overrides.is_tensor_like, + torch.cuda.is_available, + torch.nn.functional._Reduction.get_enum, + torch._utils._get_device_index, ): return True return getattr(self.value, "__module__", None) == "math" @@ -170,7 +174,15 @@ def can_constant_fold_through(self): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": - from . import ConstantVariable, GradModeVariable, TensorVariable + from . import ( + ConstantVariable, + DynamicShapeVariable, + GradModeVariable, + TensorVariable, + UserDefinedObjectVariable, + ) + + from .builder import wrap_fx_proxy constant_args = check_constant_args(args, kwargs) unspec_python_args = check_unspec_python_args(args, kwargs) @@ -196,21 +208,26 @@ def call_function( return self._call_cross_entropy_loss(tx, args, kwargs, options) else: unimplemented(f"construct nn.Module: {self.value.__name__}") + elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like): + assert len(args) == 1 + if isinstance(args[0], TensorVariable) or ( + self.value is torch.overrides.is_tensor_like + and isinstance(args[0], UserDefinedObjectVariable) + and hasattr(args[0].value, "__torch_function__") + ): + return ConstantVariable(True, **options) + else: + return ConstantVariable(False, **options) elif ( self.value in ( - torch.is_tensor, torch.is_floating_point, torch.is_complex, - torch.overrides.is_tensor_like, - torch.is_complex, ) and isinstance(args[0], TensorVariable) and args[0].dtype is not None ): - if self.value in (torch.is_tensor, torch.overrides.is_tensor_like): - return ConstantVariable(True, **options) - elif self.value is torch.is_floating_point: + if self.value is torch.is_floating_point: return ConstantVariable(args[0].dtype.is_floating_point, **options) elif self.value is torch.is_complex: return ConstantVariable(args[0].dtype.is_complex, **options) @@ -279,7 +296,7 @@ def call_function( tensor_with_tf_override.subclass_type, ) elif self.value is torch.amp.autocast_mode.autocast: - return AutocastModeVariable.create(tx, target_values=args, kwargs=kwargs) + return AutocastModeVariable.create(target_values=args, kwargs=kwargs) elif self.value in ( torch.profiler.profile, torch.profiler.record_function, @@ -287,7 +304,9 @@ def call_function( torch.autograd.profiler.record_function, ): log.warning("Profiler will be ignored") - return ProfilerContextWrapperVariable(**options) + return NullContextVariable(**options) + elif self.value is torch.autograd._profiler_enabled: + unimplemented("torch.autograd._profiler_enabled not supported yet") elif self.value is torch.jit.annotate: assert len(args) == 2 return args[1] @@ -300,13 +319,12 @@ def call_function( def get_state_from_generator(): return self.value() - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", get_state_from_generator, *proxy_args_kwargs(args, kwargs), - current_tx=tx, ), example_value=self.value(), **options, @@ -319,32 +337,88 @@ def get_state_from_generator(): assert len(args) == 1 assert isinstance(args[0], TensorVariable) - if config.fake_tensor_propagation: - # In fake tensor case, this state doesn't matter, but - # it needs to be valid to not segfault. Pull a real tensor out. - # The value won't matter since we are running with fake tensors anyway, so rng doesn't matter. - # However, it is imperative to record the call_function in the graph with the true args - # (Not the fake example_value) - for the sake of graph correctness. - if self.value == torch.random.set_rng_state: - example_value = torch.random.get_rng_state() - else: - example_value = self.value.__self__.get_state() + unimplemented( + "TODO: make torch.random.set_rng_state work with FakeTensor/aot_autograd" + ) + # In fake tensor case, this state doesn't matter, but + # it needs to be valid to not segfault. Pull a real tensor out. + # The value won't matter since we are running with fake tensors anyway, so rng doesn't matter. + # However, it is imperative to record the call_function in the graph with the true args + # (Not the fake example_value) - for the sake of graph correctness. + if self.value == torch.random.set_rng_state: + example_value = torch.random.get_rng_state() else: - example_value = args[0].proxy.node.meta["example_value"] + example_value = self.value.__self__.get_state() self.value.__module__ = self.__module__ - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", self.value, *proxy_args_kwargs(args, kwargs), - current_tx=tx, ), example_value=example_value, **options, ) + elif ( + self.value == torch.numel + and len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + # TODO(voz): This is rewritten as a call_method because + # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_method", + "numel", + *proxy_args_kwargs(args, kwargs), + ), + **options, + ) + elif ( + self.value == torch.addcdiv + and len(args) == 3 + and "value" in kwargs + and len(kwargs) == 1 + ): + # decompose addcdiv into constituent ops, prevents a graph break due to converting + # value to a scalar + result = TorchVariable(torch.div, **options).call_function(tx, args[1:], {}) + result = TorchVariable(torch.mul, **options).call_function( + tx, [result, kwargs["value"]], {} + ) + return TorchVariable(torch.add, **options).call_function( + tx, [args[0], result], {} + ) else: + any_symints_or_symfloats = any( + [isinstance(x, DynamicShapeVariable) for x in args] + ) + all_ints_or_floats = all( + [ + isinstance( + x, (variables.ConstantVariable, variables.DynamicShapeVariable) + ) + for x in args + ] + ) + bin_ops = set(["add", "sub", "mul", "div", "sqrt"]) + if ( + self.value.__module__ == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ +Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. +To support this behavior, we need to allow const-propping tensors that store symint data. +For now, dynamo will explicitly graph break when it encounters user code with this behavior. +""" + log.warning(msg) + raise unimplemented(msg) # Handle sth like torch.LongTensor(list(np.int64, np.int64, ...)), # as FX symbolic trace doesn't support numpy int/float as base types. if ( @@ -357,13 +431,22 @@ def get_state_from_generator(): if isinstance(x.value, numpy.generic): x.value = x.value.item() - tensor_variable = TensorVariable.create( + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any([isinstance(x, DynamicShapeVariable) for x in args]): + if self.value == math.sqrt: + from torch.fx.experimental.symbolic_shapes import sym_sqrt + + fn_ = sym_sqrt + + tensor_variable = wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", - self.value, + fn_, *proxy_args_kwargs(args, kwargs), - current_tx=tx, ), **options, ) @@ -427,13 +510,14 @@ def _call_softmax(self, tx, args, kwargs, options): dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None)) def fake_softmax(input): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", torch.nn.functional.softmax, *proxy_args_kwargs([input, dim], {}), - current_tx=tx, ), **VariableTracker.propagate([self, dim, input]), ) @@ -479,7 +563,9 @@ def normalize_args( ) = normalize_args(*args, **kwargs) def fake_cross_entropy_loss(input, target): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -497,7 +583,6 @@ def fake_cross_entropy_loss(input, target): ], {}, ), - current_tx=tx, ), **VariableTracker.propagate( [ @@ -553,71 +638,30 @@ def __init__(self, value, **kwargs): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": - from . import ListVariable, TensorVariable, UserFunctionVariable + from . import ( + ListVariable, + NestedUserFunctionVariable, + TensorVariable, + UserFunctionVariable, + ) + from .builder import wrap_fx_proxy assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet" - def unwrap_real(arg): - if isinstance(arg, TensorVariable): - return arg.get_real_value() - if isinstance(arg, UserFunctionVariable): - return arg.fn - if isinstance(arg, NNModuleVariable): - return tx.output.get_submodule(arg.module_key) - if arg.has_unpack_var_sequence(tx): - return [ - unwrap_real(arg_inner) for arg_inner in arg.unpack_var_sequence(tx) - ] - return arg - - def make_attr(name, proxy_args=None): + def make_attr(name): node = tx.output.create_proxy( "get_attr", name, - tuple(proxy_args) if proxy_args else tuple(), + (), {}, ) return node - # Get values - u_args = [unwrap_real(arg) for arg in args] - - def unwrap_proxy(arg): - try: - if isinstance(arg, TensorVariable): - return arg.as_proxy() - if isinstance(arg, NNModuleVariable): - name = arg.module_key - mod = unwrap_real(arg) - options = VariableTracker.propagate(self, args, kwargs.values()) - tx.output.register_attr_or_module( - mod, - name, - name, - source=NNModuleSource( - GetItemSource(self.source, arg.module_key) - ), - **options, - ) - return make_attr(name) - if arg.has_unpack_var_sequence(tx): - return [ - unwrap_proxy(arg_inner) - for arg_inner in arg.unpack_var_sequence(tx) - ] - return arg.as_proxy() - except NotImplementedError: - return arg - - def register_as_subgraph(fn, name, args): - from .. import export - - gm, guards = export(fn, *args) - + def add_subgraph(name, gm): next_name = None i = 0 while not next_name: - candidate = f"name_{i}" + candidate = f"cond_{name}_{i}" if candidate in tx.output.nn_modules: i += 1 else: @@ -627,52 +671,152 @@ def register_as_subgraph(fn, name, args): src = NNModuleSource(GetItemSource(self.source, next_name)) gm.torchdynamo_force_dynamic = False tx.output.register_attr_or_module(gm, next_name, source=src) - return next_name, gm, guards + return next_name - # Get args as proxies - p_args = [unwrap_proxy(arg) for arg in args] if self.value.__name__ == "cond": # TODO(voz): Support fake tensor dispatch for recursive # ops - see torch/dispatch/_dispatcher.py - from .. import config - if config.fake_tensor_propagation: - unimplemented("Fake tensor mode not yet supported for cond") + assert len(args) == 4 + assert type(args[0]) is TensorVariable, str(type(args[0])) # predicate + assert isinstance( + args[1], (UserFunctionVariable, NestedUserFunctionVariable) + ), str( + type(args[1]) + ) # true_fn + assert isinstance( + args[2], (UserFunctionVariable, NestedUserFunctionVariable) + ), str( + type(args[2]) + ) # false_fn + assert type(args[3]) is ListVariable, str(type(args[3])) # args + + # Our strategy for tracing the true/false branches of cond + # are to checkpoint our graphstate, run the true branch, + # roll it back to the checkpoint, and run the false + # branch, and then merge the graphstates. Well, perhaps + # "merge" is too strong a word: we mostly assert that + # the resulting graphstates have to be the same. + # + # We only permit guards to diverge (we union the guards from + # both branches). In particular, this means that side + # effects are NOT permitted inside true/false branches; this + # would be difficult to implement, because of the path + # explosion problem. + + graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate() + + sub_args = args[3].unpack_var_sequence(tx) + + def speculate_branch(branch): + # Setup the subgraph we're going to capture into + tx.output.graph = torch.fx.Graph() + tx.output.graphargs = [] + tx.output.name_to_input.clear() + + # One argument to graph per sub_args + for a in sub_args: + assert isinstance(a, TensorVariable) + tx.output.create_graph_input(a.as_proxy().node.name) + # NB: we don't bother populating graphargs, as + # they won't actually get used by anything + + # NB: 0 is predicate + ix = 1 if branch else 2 + + output = args[ix].call_function(tx, sub_args, {}) + + # Register output to graph + # Modeled off of compile_and_call_fx_graph + # TODO: support non single Tensor output + assert isinstance(output, TensorVariable) + tx.output.guards.update(output.guards) + tx.output.create_node( + "output", "output", (tx.output.create_arg((output.as_proxy(),))), {} + ) - assert len(p_args) == 4 - assert type(args[0]) is TensorVariable # predicate - assert type(p_args[1]) is UserFunctionVariable # true_fn - assert type(p_args[2]) is UserFunctionVariable # false_fn - assert type(args[3]) is ListVariable # args + tx.output.side_effects.prune_dead_object_new(tx) + state = tx.copy_graphstate() + + guards = state.output.guards + nn_modules = state.output.nn_modules + + # Nub out bits of state that we don't require to be + # equal + comparable_state = state._replace( + output=state.output._replace( + guards=set(), + nn_modules=None, + # Timestamp is monotonically increasing so we don't + # care about divergence + timestamp=0, + # Meh (problem is the nodes don't compare equal; + # maybe nub out outputs only) + name_to_input=OrderedDict(), + ) + ) - node_args = [unwrap_real(x) for x in args[3].unpack_var_sequence(tx)] - proxy_args = [unwrap_proxy(x) for x in args[3].unpack_var_sequence(tx)] - true_name, true_graph, true_guards = register_as_subgraph( - p_args[1].get_function(), "true", node_args + graph = tx.output.graph + tx.output.graph = graph_checkpoint + tx.restore_graphstate(checkpoint) + + return output, graph, guards, nn_modules, comparable_state + + ( + true_r, + true_graph, + true_guards, + true_nn_modules, + true_cmp, + ) = speculate_branch(True) + ( + false_r, + false_graph, + false_guards, + false_nn_modules, + false_cmp, + ) = speculate_branch(False) + + if true_cmp != false_cmp: + unimplemented(true_cmp.diff(false_cmp)) + + # Add guards + tx.output.guards |= false_guards + tx.output.guards |= true_guards + + true_name = add_subgraph( + "true", torch.fx.GraphModule(true_nn_modules, true_graph) ) - false_name, false_graph, false_guards = register_as_subgraph( - p_args[2].get_function(), "false", node_args + false_name = add_subgraph( + "false", torch.fx.GraphModule(false_nn_modules, false_graph) ) - if config.enforce_cond_guards_match: - assert ( - true_guards == false_guards - ), "Guards for true and false path must be equal." + # Apply side effects (guaranteed to be equal) + tx.output.side_effects = true_cmp.output.side_effects - true_node = make_attr(true_name, proxy_args) - false_node = make_attr(false_name, proxy_args) - p_args[1] = true_node - p_args[2] = false_node + true_node = make_attr(true_name) + false_node = make_attr(false_name) + + p_args = ( + args[0].as_proxy(), + true_node, + false_node, + tuple(a.as_proxy() for a in sub_args), + ) + # TODO: assert that the true/false return values are + # consistent + example_value = true_r.as_proxy().node.meta["example_value"] + else: + unimplemented(f"PyOperator {self.value.__name__}") # Store the invocation as a call - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", self.value, args=tuple(p_args), kwargs={}, - current_tx=tx, ), - example_value=self.value(*u_args), + example_value=example_value, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2d33c8328268a..d86969d83774d 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1,4 +1,5 @@ import collections +import contextlib import dataclasses import functools import importlib @@ -11,11 +12,11 @@ from .. import variables from ..exc import unimplemented -from ..guards import Guard, GuardBuilder +from ..guards import GuardBuilder from ..source import AttrSource, ODictGetItemSource, RandomValueSource from ..utils import is_namedtuple_cls, namedtuple_fields from .base import MutableLocal, VariableTracker -from .misc import ProfilerContextWrapperVariable +from .misc import NullContextVariable class UserDefinedVariable(VariableTracker): @@ -68,7 +69,7 @@ def call_method( return variables.ListVariable(subs_as_vars, **options) - return super().call_method(tx, args, kwargs) + return super().call_method(tx, name, args, kwargs) def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" @@ -77,8 +78,11 @@ def call_function( options = VariableTracker.propagate(self, args, kwargs.values()) - if self.value is torch.autograd.profiler.profile: - return ProfilerContextWrapperVariable() + if self.value in ( + contextlib.nullcontext, + torch.autograd.profiler.profile, + ): + return NullContextVariable(**options) elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) items = list(args) @@ -174,13 +178,7 @@ def call_method( assert all(map(ConstantVariable.is_literal, keys)) return TupleVariable( [ConstantVariable(k, **options) for k in keys], **options - ).add_guard( - Guard( - self.source.name(), - self.source.guard_source(), - GuardBuilder.ODICT_KEYS, - ) - ) + ).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS)) if ( method is collections.OrderedDict.items diff --git a/torch/_functorch/__init__.py b/torch/_functorch/__init__.py new file mode 100644 index 0000000000000..10a55772ab58b --- /dev/null +++ b/torch/_functorch/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py new file mode 100644 index 0000000000000..2cd12e4f883fc --- /dev/null +++ b/torch/_functorch/aot_autograd.py @@ -0,0 +1,2117 @@ +import collections +import dataclasses +import warnings +import itertools +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from enum import Enum +from functools import wraps, partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from torch.fx.experimental.proxy_tensor import is_sym_node +import logging + +import torch +import torch.fx.traceback as fx_traceback +import torch.nn as nn +import torch.utils._pytree as pytree +import torch.utils.dlpack +from torch import Tensor +from torch._dynamo.utils import dynamo_timed +from torch._subclasses import FakeTensorMode, CrossRefFakeMode, FakeTensor +from torch.fx import immutable_collections, Interpreter +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.multiprocessing.reductions import StorageWeakRef +from torch.nn.utils import stateless + +from functorch import make_fx +from torch._dispatch.python import enable_python_dispatcher +from . import config +from .named_members_polyfill import _named_buffers, _named_parameters +from .partitioners import default_partition + +MutationType = Enum("MutationType", ("none", "metadata_only", "data")) +OutputType = Enum("OutputType", ("non_alias", "alias_of_input", "alias_of_intermediate")) + +pytree._register_pytree_node( + immutable_collections.immutable_list, + lambda x: (list(x), None), + lambda x, c: immutable_collections.immutable_list(x), +) +pytree._register_pytree_node( + immutable_collections.immutable_dict, + lambda x: (list(x.values()), list(x.keys())), + lambda x, c: immutable_collections.immutable_dict( + {key: value for key, value in zip(c, x)} + ), +) + +aten = torch.ops.aten + +# This global counter increments every time we compile a graph with +# AOTAutograd. You can use this to correlate runtime error messages +# with compile time (e.g., if you get an error at runtime saying +# compiled graph 3 failed, you can set a breakpoint at compile time +# for this graph number to investigate further at compile time.) +# +# NB: this is different from get_aot_compilation_context, which tracks +# each underlying graph that is compiled. In contrast, AOT_COUNTER +# corresponds to top-level invocations of aot_module/aot_function; +# one counter is allocated per entire compiled block (but this block +# may involve compiling multiple subgraphs; e.g., for forwards/backwards) +AOT_COUNTER = itertools.count() + +KNOWN_TYPES = [torch.Tensor, int, str, float, bool, torch.SymInt, torch.SymFloat] + +@contextmanager +def preserve_rng_state(): + rng_state = torch.clone(torch.random.get_rng_state()) + if torch.cuda.is_available(): + cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) + try: + yield + finally: + torch.random.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) + + +# Set up hooks so that during backward the fx's stack_trace is properly set +callback_set = False + + +def setup_stacktrace_preservation_hooks(roots: List): + def iter_graph(roots): + if not roots: + return + seen = set() + q = collections.deque() + for node in roots: + if node is not None: + seen.add(node) + q.append(node) + + while q: + node = q.popleft() + for fn, _idx in node.next_functions: + if fn in seen or fn is None: + continue + seen.add(fn) + q.append(fn) + + yield node + + def get_callback(saved_stack_): + def callback(): + global callback_set + fx_traceback.set_stack_trace(saved_stack_) + callback_set = False + + return callback + + def get_prehook(stack_): + def prehook(grad_output): + global callback_set + + if not callback_set: + torch.autograd.variable.Variable._execution_engine.queue_callback( + get_callback(fx_traceback.format_stack()) + ) + callback_set = True + + fx_traceback.set_stack_trace(stack_) + + return prehook + + def get_posthook(special_stack_): + def posthook(grad_input, grad_output): + fx_traceback.set_stack_trace(special_stack_) + + return posthook + + for node in iter_graph(roots): + forward_node_stack = node.metadata.get("traceback_", []) + node.register_prehook(get_prehook(forward_node_stack)) + + special_stack = forward_node_stack.copy() + special_stack.append( + "Gradient addition node due to multiple use of tensor around:" + ) + node.register_hook(get_posthook(special_stack)) + +# This class tells us about a user's forward output that is an alias. +# It can be an alias of either a user forward input, of of a graph intermediate. +@dataclass(frozen=True) +class OutputAliasInfo: + # Tells us if this output is: + # (1) a regular (non-aliased) output + # (2) an alias of a forward input + # (2) an alias of an intermediate (aka an alias of an output of the inner traced forward) + output_type: OutputType + # If (1) above, then + # - Tells us that the base of this alias is user_fwd_input[base_idx] + # (This is an index into the inputs *before* we make synthetic bases) + # If (2) above, then + # - Tells us that the base of this alias is traced_fwd_outputs[base_idx] + # here, this refers to the index of the *direct* traced + base_idx: int + # sizes, strides and storage offset of the aliased output are all returned as actual (sym)ints + # in the compiled forward. These indices tell us where in the forward outputs to grab them. + sizes_idx: Optional[int] + strides_idx: Optional[int] + storage_offset_idx: Optional[int] + # We store the actual output alias that we traced in the forward (should be a fake tensor) + # to grab any other non-symbolic properties on the output alias, like requires_grad. + # It's optional here, for cases where the user directly returns an input as an output. + # If output_type == non_alias, then these fields are also always None. + tensor_meta: Optional[Tensor] + +# This class tells us about how to perform a metadata mutation on forward inputs. +# it only applies to forward inputs that experience metadata-only mutations +@dataclass(frozen=True) +class InputAliasInfo: + # This object gives us information about how to perform a metadata-mutation + # on original_fwd_inputs[base_idx] + # (This is an index into the inputs *before* we make synthetic bases) + base_idx: int + # sizes, strides and storage offset of the aliased output are all returned as actual (sym)ints + # in the compiled forward. These indices tell us where in the forward outputs to grab them. + sizes_idx: int + strides_idx: int + storage_offset_idx: int + # We store the actual output alias that we traced in the forward (should be a fake tensor) + # to grab any other non-symbolic properties on the output alias, like requires_grad. + tensor_meta: Tensor + +# This class encapsulates all aliasing + mutation info we need about the forward graph +# See a more detailed overview of the edge case handling at +# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit +@dataclass(frozen=True) +class ViewAndMutationMeta: + # length: # user forward inputs + # For every input, tells us whether the input: + # (a) is not mutated + # (b) only metadata is mutated + # (c) data (and maybe metadta) is mutated + mutated_input_info: List[MutationType] + # length: (# inputs of the user forward) + # metadata_mutation_input_info[i] is not None <====> mutated_input_info[i] == MutationType.metadata_only + # We stash the updated FakeTensor that we traced with in the forward in here, + # that way we can use it to replay the metadata mutation + metadata_mutation_input_info: List[Optional[InputAliasInfo]] + # length: # outputs in the compiled forward (not including output alias symints). Equal to: + # length: (# inputs w data mutations) + (# outputs that don't alias inputs) + # For every output *and* mutated input returned from the forward, + # tells us whether or not the output should require gradients or not + requires_grad_out_info: List[bool] + # length: # fw outputs + aliased_output_info: List[OutputAliasInfo] + +def gen_alias_from_base(aliased_base_tensor, size, stride, storage_offset, target_meta_tensor): + # handle R2C and C2R + if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex(): + aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(size, stride, storage_offset) + elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex(): + aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(size, stride, storage_offset) + else: + aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset) + # For outputs aliasing inputs, we need to check if the requires-gradness has changed. + if aliased_base_tensor.requires_grad and not target_meta_tensor.requires_grad: + aliased_out = aliased_out.detach() + elif not aliased_base_tensor.requires_grad and target_meta_tensor.requires_grad: + aliased_out.requires_grad_(True) + return aliased_out + +# This is a version of functionalization that is specifically designed +# for the AOTAutograd use case. +# +# Unlike functorch's variant, this doesn't use the functorch level system, +# instead it directly uses PyTorch's conventional dispatcher to hit the +# functionalization key. In particular, this means that FunctionalTensorWrapper +# can have autograd data stored directly on it. +# +# In typical AOTAutograd usage, the dispatch key order will look like: +# +# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor +# outer tensor inner tensor +# +# TODO: Provide a faster version of this that assumes flat arguments +# (so no pytree necessary) +def run_functionalized_fw_and_collect_metadata(f): + memo = {} + + def to_fun(t): + if isinstance(t, Tensor): + if t in memo: + return memo[t] + r = torch._to_functional_tensor(t, mirror_autograd_meta=True) + memo[t] = r + return r + else: + return t + + def from_fun(t): + if not isinstance(t, Tensor) or not torch._is_functional_tensor(t): + return t + torch._sync(t) + return torch._from_functional_tensor(t) + + @wraps(f) + def inner(*args): + # This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args. + assert all(isinstance(a, torch.Tensor) or type(a) in KNOWN_TYPES for a in args) + + collect_mutated_input_info: List[MutationType] = [] + collect_requires_grad_out_info: List[bool] = [] + collect_aliased_output_info: List[OutputAliasInfo] = [] + collect_metadata_mutation_input_info: List[Optional[InputAliasInfo]] = [] + + f_args = pytree.tree_map(to_fun, args) + + torch._enable_functionalization(reapply_views=True) + try: + outs = f(*f_args) + finally: + torch._disable_functionalization() + + flat_args, _ = pytree.tree_flatten(args) + flat_f_args, _ = pytree.tree_flatten(f_args) + flat_outs, _ = pytree.tree_flatten(outs) + + # Inspect the state of the input tensor functional wrapper to detect input mutation info + inputs_with_mutated_data = [] + # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version + maybe_inputs_with_mutated_metadata: List[Optional[torch.Tensor]] = [] + for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)): + if not isinstance(arg, Tensor): + continue + torch._sync(f_arg) + new_arg = torch._from_functional_tensor(f_arg) + if arg is not new_arg: + # Note [Input mutation handling in aot autograd] + # We use functionalization to detect two types in input mutations: + # (1) metadata-only input mutations, like input.t_() + # (2) data input mutations, like input.add_(1) + # inputs that have both data and metadata mutated get lumped into (2). + # + # Why do we distinguish these two cases? aot autograd needs to handle them very differently. + # For data mutations, we return the updated inputs *directly* in the compiled forward graph. + # e.g. + # def f(x): + # x.mul_(2) + # out = x.mul(3) + # return out + # + # // This function gets compiled and dumped inside of an autograd.Function.forward() + # def traced_forward(x): + # x_updated = x.mul(2) + # out = x_updated.mul(3) + # return x_updated, out + # + # // The returned function will call the compiled forward, and apply input mutations afterwards + # def compiled_fn(x): + # x_updated, out = traced_forward(x) + # x.copy_(x_updated) + # return out + # + # For input metadata mutations, though, we cannot return the "updated input" in the forward graph, + # Because it is an alias of an input. autograd.Function.forward can't handle arbitrary outputs that alias inputs. + # Instead, we stash the "updated input metadata" during tracing + # e.g. + # def f(x): + # x.t_() + # out = x.mul(3) + # return out + # + # // This function gets compiled and dumped inside of an autograd.Function.forward() + # // (We don't return x_updated. Just return the original fw out) + # def traced_forward(x): + # x_updated = x.t() + # out = x_updated.mul(3) + # return out + # + # // The returned function will call the compiled forward, and apply input mutations afterwards + # def compiled_fn(x): + # out = traced_forward(x) + # _x_updated_metadata = CompiledFunction.fw_metadata.metadata_mutation_input_info[0] + # x.as_strided_(_x_updated_metadata.size(), _x_updated_metadata.stride(), _x_updated_metadata.storage_offset()) + # return out + if StorageWeakRef(arg._storage()) == StorageWeakRef(new_arg._storage()): + # We can use the storage aliasing of the inputs and updated inputs + # to detect when an input was actually updated, or just inplace-viewed. + collect_mutated_input_info.append(MutationType.metadata_only) + else: + collect_mutated_input_info.append(MutationType.data) + # Only return mutated inputs that mutate *data*, not metadata + # Note [Input mutation handling in aot autograd] + inputs_with_mutated_data.append(new_arg) + # For every mutated input, we ALSO need to return info on + # whether than mutated input requires gradients. Why? + # Our custom autograd.Function.forward returns updated inputs as outputs, + collect_requires_grad_out_info.append(f_arg.requires_grad) + else: + collect_mutated_input_info.append(MutationType.none) + + maybe_inputs_with_mutated_metadata.append( + new_arg if collect_mutated_input_info[-1] == MutationType.metadata_only else None) + + def collect_grad_info(t): + # Collect info on which output tensors require gradients, + # so we can mark them properly in the returned autograd.Function. + # We only collect requires_grad info on real forward outputs, and not on inputs. + collect_requires_grad_out_info.append(isinstance(t, torch.Tensor) and t.requires_grad) + + # Note [output alias handling in aot autograd] + # Given a function to compile where one of its outputs aliases an input, + # we need to remove that output from the compiled graph and generate it off to the side. + # e.g. + # def f(x): + # return x.view(-1) + # + # Why? Two reasons: + # (1) If your autograd.Function returns a view on an input in the forward, autograd.Function + # will not allow you to mutate it (This original came from arbitrary user code where the user might want to mutate) + # (2) There's no reason to compile views anyway. We can just regenerate the view of the input off to the side, + # + # Another interesting case is when you have both mutation and aliasing: + # def f(x): + # x.mul_(2) + # return x.view(-1) + # + # You could imagine that this output is now *safe* to compile and return in the autograd.Function, + # because after functionalization runs, it will technically not alias an input: + # def f_functionalized(x): + # x_updated = x.mul(2) + # return x_updated, x_updated.view(-1) + # + # However, this is still wrong: we can't return x_updated.view(-1) to the user. We are on the hook to return: + # def traced_forward(x): + # x_updated = x.mul(2) + # return x_updated + # + # def compiled_fn(x) + # x_updated = traced_forward(x) + # x.copy_(x_updated) + # return x.view(-1) + # + # Why can't we return x_updated.view(-1) to the user? + # It can have different metadata from x.view(-1)! Specifically, the input x could be a non-memory-dense tensor, + # But the intermediate created by our graph, x_updated, will always be memory-dense. + def filter_and_record_aliased_outs(outputs): + # NOTE: this dict will clobber keys if we have multiple inputs that alias. + # Let's say inpA and inpB alias, and the user generated an output using out = inpA.view(...) + # For now, since we're not handling the case with multiple _base's sharing a storage, + # it is actually fine to arbitrarily pick which input to regenerate the aliased output from. + # e.g. out_new = inpB.as_strided(out.size(), out.stride(), out.storage_offset()) + # + # This will be more complicated when you have multiple _base tensors aliasing the same + # underlying storage, when we eventually handle that. + # We'll need to ensure that we generate the view off of the right base. + inp_storage_refs = {StorageWeakRef(inpt._storage()): idx for idx, inpt in enumerate(flat_f_args)} + inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)} + inp_storage_refs_set = set(inp_storage_refs) + + non_aliased_input_outs = [] + # For a given output tensor that alias an input, tells us: + # (1) the index of the input that we alias + # (2) Whether or not the output is a view of the input, or if `output is input` + # (so we don't need to generate a view, and can return the input directly) + # Note: if the function returns an output that *is* an input, we still cannot return it in the graph. + # e.g. + # def f(x): + # x.add_(1) + # return x + # Our compiled fw will return an "x_updated", but it is *not* ok to return that to the user. + # We need to manually do x.copy_(x_updated), and return the original x to the user. + # Why? for example, the metadata between x and x_updated might be different (e.g. _is_leaf()) + aliased_out_idx: Dict[torch.Tensor, Tuple[int, bool]] = {} + + for o in outputs: + # Note: When detecting input/output aliasing, we NEED to do it using the outer FunctionalTensorWrapper objects. + # In the case where we mutate an input *and* return a view of it, the outer wrappers will still alias, + # but the inner tensors no longer alias. + if isinstance(o, torch.Tensor) and StorageWeakRef(o._storage()) in inp_storage_refs: + aliased_inp_idx = inp_storage_refs[StorageWeakRef(o._storage())] + is_exact_input = id(o) in inp_tensor_ids + aliases_intermediate_and_not_input = False + aliased_out_idx[o] = (aliased_inp_idx, aliases_intermediate_and_not_input, is_exact_input) + else: + # Only return outputs that are not aliases of inputs. + non_aliased_input_outs.append(o) + # If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediiate, + # We need to make sure our graph returns the _base as a graph output, and we manually recreate the view + # to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad + # on the base tensor, but we are obligated to properly set requires-gradness on the real output. + non_aliased_outs = [] + for i, o in enumerate(non_aliased_input_outs): + non_aliased_outs.append(o) + + return non_aliased_outs, aliased_out_idx + + non_aliased_outs, aliased_out_to_inp_idx = filter_and_record_aliased_outs(outs) + + pytree.tree_map(collect_grad_info, non_aliased_outs) + + # Calling convention: the output is (mutated_input_values, original_outs) + # We return all mutated inputs + outputs here, **except** for any mutated inputs or outputs + # that alias original inputs. + # See Note [Input mutation handling in aot autograd] + mutated_inps_and_outs = inputs_with_mutated_data + list(non_aliased_outs) + + # Our compiled forward function will return: + # (1) non-aliased updated inputs + # (2) non-aliased fw outputs + # (3) size/stride/storage_offset metadata for updated aliased inputs + # (4) size/stride/storage_offset metadata for aliased outputs + + start_idx_for_aliased_output_metadata = 0 + + # First, gather the metadata info on mutated inputs (this only applies to inputs with metadata-only mutations)) + for i, maybe_aliased_updated_inp in enumerate(maybe_inputs_with_mutated_metadata): + if maybe_aliased_updated_inp is None: + collect_metadata_mutation_input_info.append(None) + continue + # Figure out where the sizes/strides/storage_offset are in the compiled fw output. + sizes_idx = start_idx_for_aliased_output_metadata + strides_idx = sizes_idx + len(maybe_aliased_updated_inp.size()) + storage_offset_idx = strides_idx + len(maybe_aliased_updated_inp.stride()) + # update our offset for the next tensor + start_idx_for_aliased_output_metadata = storage_offset_idx + 1 + inp_info = InputAliasInfo( + base_idx=i, + sizes_idx=sizes_idx, + strides_idx=strides_idx, + storage_offset_idx=storage_offset_idx, + tensor_meta=maybe_aliased_updated_inp, + ) + collect_metadata_mutation_input_info.append(inp_info) + + # Next, gather the metadata info on the user's outputs that alias (either inputs or graph outputs) + num_non_input_aliased_outputs = 0 + for o in outs: + maybe_alias_info = aliased_out_to_inp_idx.get(o, None) if isinstance(o, torch.Tensor) else None + if maybe_alias_info is None: + output_type = OutputType.non_alias + # Here, alias_idx will tell us which output from the inner forward this corresponds to. + alias_idx = num_non_input_aliased_outputs + sizes_idx = None + strides_idx = None + storage_offset_idx = None + tensor_meta = None + else: + input_alias_idx, is_alias_of_intermediate_not_input, is_exact_input = maybe_alias_info + if is_exact_input: + assert not is_alias_of_intermediate_not_input + output_type = OutputType.alias_of_input + alias_idx = input_alias_idx + sizes_idx = None + strides_idx = None + storage_offset_idx = None + tensor_meta = None + else: + if is_alias_of_intermediate_not_input: + output_type = OutputType.alias_of_intermediate + alias_idx = num_non_input_aliased_outputs + else: + output_type = OutputType.alias_of_input + alias_idx = input_alias_idx + tensor_meta = o + # Figure out where the sizes/strides/storage_offset are in the compiled fw output. + sizes_idx = start_idx_for_aliased_output_metadata + strides_idx = sizes_idx + len(tensor_meta.size()) + storage_offset_idx = strides_idx + len(tensor_meta.stride()) + # update our offset for the next tensor + start_idx_for_aliased_output_metadata = storage_offset_idx + 1 + + if output_type != OutputType.alias_of_input: + num_non_input_aliased_outputs += 1 + + inp_info = OutputAliasInfo( + output_type=output_type, + base_idx=alias_idx, + sizes_idx=sizes_idx, + strides_idx=strides_idx, + storage_offset_idx=storage_offset_idx, + tensor_meta=tensor_meta + ) + collect_aliased_output_info.append(inp_info) + + # This is the total number of size/stride/storage_offset metadata outputs that we return in the forward, + # used for regenerating aliases later. + num_aliasing_metadata_outs = start_idx_for_aliased_output_metadata + + assert len(collect_metadata_mutation_input_info) == len(collect_mutated_input_info) + + assert len([x for x in collect_metadata_mutation_input_info if x is not None]) == len([ + x for x in collect_mutated_input_info if x == MutationType.metadata_only + ]) + assert len(collect_aliased_output_info) == len(outs) + assert len([x for x in collect_aliased_output_info if x.output_type != OutputType.alias_of_input]) == len(non_aliased_outs) + + + # Our autograd.Function.forward returns both mutated inputs and outputs, + # so we need grad info on all of them. + assert len(collect_requires_grad_out_info) == len(mutated_inps_and_outs) + + metadata = ViewAndMutationMeta( + mutated_input_info=collect_mutated_input_info, + metadata_mutation_input_info=collect_metadata_mutation_input_info, + requires_grad_out_info=collect_requires_grad_out_info, + aliased_output_info=collect_aliased_output_info, + ) + return metadata, pytree.tree_map(from_fun, mutated_inps_and_outs), num_aliasing_metadata_outs + return inner + + +# This creates a functionalized joint forwards-backwards function given both +# the primals (to run forwards) and tangents (to run backwards). +# +# It uses the metadata that was created earlier to figure out what all of the outputs to the autograd.Function.forward are: +# (1) Which inputs received data mutations (and need to be passed as outputs into autograd.grad()) +# (2) Which outputs are aliases of inputs (and should *not* be passed as outputs into autograd.grad()) +def create_joint_forward_backward_functionalized( + fn, + *, + meta: ViewAndMutationMeta, + synthetic_base_info: Optional[List[Union[int, Tuple[int, List[Any]]]]], +): + # NOTE: when we have synthetic base inputs, we need to clone them *before* creating views off of them. + # This means that "idx" here represents the index of the (potentially) synthetic base. + # What we need to do is: + # (1) map the current (post-synthetic-base calling convention) input argument index + # to int index pre-synthetic-base-calling-convention. + # (2) There could be multiple, if this index corresponds to a synthetic base + # that has multiple input aliases. + # (3) If any of those corresponding inputs get metadata mutations, then we clone the base. + def maybe_to_fresh_input(idx, t): + if not isinstance(t, Tensor): + return t + + if synthetic_base_info is None: + outer_aliased_indices_of_current_base_arg = [idx] + else: + outer_aliased_indices_of_current_base_arg = [ + # For every argument index in the outer calling convention (before synthetic bases) + # find its index in the inner calling convention. + # if it matches the index of our current arg (idx), track the outer argument's index (i) + i for i, outer_idx_or_lambda in enumerate(synthetic_base_info) + if (isinstance(outer_idx_or_lambda, int) and outer_idx_or_lambda == idx) + or (isinstance(outer_idx_or_lambda, tuple) and outer_idx_or_lambda[0] == idx) + ] + if any(meta.mutated_input_info[i] == MutationType.data for i in outer_aliased_indices_of_current_base_arg): + # Make sure the primal we pass to autograd.grad() + # seees the tensor before the mutation + out = t.clone() + elif any(meta.mutated_input_info[i] == MutationType.metadata_only for i in outer_aliased_indices_of_current_base_arg): + # Make sure the primal we pass to autograd.grad() + # seees the tensor before the metadata mutation + out = t.view(t.shape) + else: + out = t + return out + + def unpack_synthetic_bases(primals: List[Any]) -> List[Any]: + # This is only not None if our graph mutates a graph input that aliases another graph input. + if synthetic_base_info is None: + return primals + + f_args_inner = [] + for outer_idx_or_lambda in synthetic_base_info: + if isinstance(outer_idx_or_lambda, int): + f_args_inner.append(primals[outer_idx_or_lambda]) + else: + outer_base_idx, strided_args = outer_idx_or_lambda + outer_base = primals[outer_base_idx] + # TODO: we could consider storing and executing view replay logic here, + # instead of a general as_strided() call. + # This could also improve perf, since today this will cause + # more as_strided_scatter() ops in the graph. + view_arg = outer_base.as_strided(*strided_args) + f_args_inner.append(view_arg) + return f_args_inner + + def joint_forward_backward( + primals: List[Any], tangents: List[Any] + ) -> Tuple[List[Any], List[Any]]: + # Call the forward pass, making sure to clone any inputs that are mutated first. + # We need to ensure that the inputs we pass to autograd.grad() are the *original* + # inputs, and not their mutated values. + primals_no_input_mutations = [maybe_to_fresh_input(i, t) for i, t in enumerate(primals)] + # This is also where we handle the calling convention around synthetic bases. + # We need to make sure that we convert any synthetic base arguments into views + # *after* we do the cloning above, to preserve the view relationship. + primals_ = unpack_synthetic_bases(primals_no_input_mutations) + assert len(meta.mutated_input_info) == len(primals_) + all_outs = fn(*primals_) + assert len(meta.aliased_output_info) == len(all_outs) + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + outputs_for_grad = [ + x + # TODO: support ._base + # x._base if meta.aliased_output_info[i].output_type == OutputType.alias_of_intermediate else x + for (i, x) in enumerate(all_outs) if meta.aliased_output_info[i].output_type != OutputType.alias_of_input + ] + # Pass any (non-aliased) mutated inputs in as tangents, since they'll be returned as outputs in the fw + # Important: the traced joint fw/bw will return updated inputs with data mutations, + # but *not* with metadata mutations. + # Instead, we shunt the updated metadata around externally + # and update the input's metadata outside of the autograd.Function + mutated_inputs_for_grad = [x for (i, x) in enumerate(primals_) if meta.mutated_input_info[i] == MutationType.data] + mutated_inputs_and_outs_to_grad = mutated_inputs_for_grad + outputs_for_grad + + metadata_mutated_inps = [x for (i, x) in enumerate(primals_) if meta.mutated_input_info[i] == MutationType.metadata_only] + # for user outputs that are aliases (either of inputs, or of graph intermediates) + # figure out what metadata to return in the forward, which is needed to regenerate the output aliases + aliased_outs = [x for (i, x) in enumerate(all_outs) if meta.aliased_output_info[i].output_type != OutputType.non_alias + and meta.aliased_output_info[i].tensor_meta is not None] + output_metadata_for_fw = [] + for curr_alias in metadata_mutated_inps + aliased_outs: + size_ = curr_alias.size() + stride_ = curr_alias.stride() + storage_offset_ = curr_alias.storage_offset() + # FX IR doesn't know about tuples, so we flatten the metadata into individual ints/symints, + # and index into the final output list later. + output_metadata_for_fw += (size_ + stride_ + (storage_offset_,)) + + # Take care to grab and sync the updated inputs from primals_ (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + for i, arg in enumerate(primals_): + if not isinstance(arg, Tensor): + continue + torch._sync(arg) + + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + # Note that we're not using primals_ here, being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + assert len(tangents) == len(mutated_inputs_and_outs_to_grad) + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(mutated_inputs_and_outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + needed_outs.append(out if out.shape == tangent.shape else out.view(tangent.shape)) + needed_tangents.append(tangent.requires_grad_(True)) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + backward_out = [] + # Call the backwards pass + if grad_primals: + with fx_traceback.override_stack_trace(): + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + all_fw_outs = mutated_inputs_and_outs_to_grad + output_metadata_for_fw + return all_fw_outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads + ] + + def to_fun(t): + if isinstance(t, Tensor): + return torch._to_functional_tensor(t, mirror_autograd_meta=True) + else: + return t + + def from_fun(t): + if not isinstance(t, Tensor) or not torch._is_functional_tensor(t): + return t + torch._sync(t) + return torch._from_functional_tensor(t) + + def functionalized_joint( + primals: List[Any], tangents: List[Any] + ) -> Tuple[List[Any], List[Any]]: + + # Wrap inputs into functional wrappers + f_primals, f_tangents = pytree.tree_map(to_fun, (primals, tangents)) + torch._enable_functionalization(reapply_views=True) + try: + # Run the joint + outs = joint_forward_backward(f_primals, f_tangents) + finally: + torch._disable_functionalization() + + # Syncing of inputs/outputs was already done directly in the joint call + return pytree.tree_map(from_fun, outs) + + return functionalized_joint + + +def normalize_as_list(x): + if isinstance(x, tuple): + return list(x) + elif isinstance(x, list): + return x + return [x] + + +aot_autograd_decompositions = {} + + +# This is a list since looking forward, we can have this arbitrarily nested. +graph_being_compiled: List[str] = [] +# TODO: It would be nice to reset the numbering every time aot_id goes +# up, but this is annoying to do right now (because we don't know if +# an aot_id will come back from the dead), so right now this also happens +# to be a globally unique number too (at the cost of wobbling if you change +# how the graphs compile) +nth_graph: int = 0 +model_name: str = "model" + + +def set_model_name(name): + global model_name + model_name = name + + +def get_aot_compilation_context() -> Tuple[List[str], str, int]: + return list(graph_being_compiled), model_name, nth_graph + + +def get_aot_graph_name() -> str: + """ + Returns the name of the graph being compiled. + """ + global model_name, graph_being_compiled, nth_graph + return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}" + + +get_graph_being_compiled = get_aot_graph_name + + +@contextmanager +def track_graph_compiling(aot_config, graph_name): + global graph_being_compiled + # TODO: Don't shove the aot_id in here; set it in the context + graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"] + yield + global nth_graph + nth_graph += 1 + graph_being_compiled = [] + + +def make_boxed_func(f): + def g(args): + return f(*args) + + g._boxed_call = True + return g + + +def make_boxed_compiler(compiler): + @wraps(compiler) + def f(fx_g, inps): + out_f = compiler(fx_g, inps) + fx_g = make_boxed_func(out_f) + return fx_g + + return f + + +def call_func_with_args(f, args, steal_args=False, disable_amp=False): + if not steal_args: + args = list(args) + assert isinstance(args, list) + + if disable_amp: + guard = torch._C._DisableAutocast() + try: + if hasattr(f, "_boxed_call"): + out = normalize_as_list(f(args)) + else: + # TODO: Please remove soon + # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 + warnings.warn( + "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. " + "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " + "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale." + ) + out = normalize_as_list(f(*args)) + finally: + if disable_amp: + del guard + return out + + +@dataclasses.dataclass +class AOTConfig: + """ + Configuration for AOTDispatcher + """ + + fw_compiler: Callable + bw_compiler: Callable + partition_fn: Callable + decompositions: Dict[Callable, Callable] + num_params_buffers: int + aot_id: int + + +def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig): + fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args) + if config.debug_graphs: + print("====== Forward (only) graph {aot_config.aot_id} ======") + fw_module.print_readable() + + + disable_amp = torch._C._is_any_autocast_enabled() + context = disable_autocast_manager if disable_amp else nullcontext + + with context(), track_graph_compiling(aot_config, "inference"): + compiled_fw = aot_config.fw_compiler(fw_module, flat_args) + + @wraps(compiled_fw) + def new_fn(args): + fw_outs = call_func_with_args(compiled_fw, args, disable_amp=disable_amp) + return fw_outs + new_fn._boxed_call = True + + return new_fn + + +@contextmanager +def disable_autocast_manager(): + guard = torch._C._DisableAutocast() + try: + yield + finally: + del guard + +def are_differentiable_views(view1, view2): + if view1 is view2: + return True + if view1._base is None and view2._base is None: + return False + if view1._base is view2._base or view1._base is view2 or view1 is view2._base: + return True + return False + +def same_dtype_views(view1, view2): + if view1.dtype != view2.dtype: + return False + if view1._base is not None and view1.dtype != view1._base.dtype: + return False + if view2._base is not None and view2.dtype != view2._base.dtype: + return False + return True + +# Note [Handling mutations on an input that aliases other inputs] +# The easiest example to show-case this edge case is here: +# +# def f(a, b): +# a.mul_(2) +# out = a + b +# return out +# +# In this situation, if a and b happened to be aliased, we need to trace something different! +# Suppose we had b = a.view(-1) +# (In this case, that means that `a._base is b`) +# +# We need to ensure that the aliasing relationship between a and b is preserved. +# We do that detecting the specific situation above (mutate an input that aliases another input), +# and when we do that, we create a synthetic base argument. Then inside of the traced forward, +# we regenerate a and b off of that base. +# The complete example of the transformed function looks like this: +# +# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views +# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph +# def traced_forward(base): +# a = base.as_strided(...) +# b = base.as_strided(...) +# a_updated = a.mul(2) +# base_updated = torch.as_strided_scatter(base, a_updated, ...) +# b_updated = base_updated.as_strided(...) +# out = a_updated + b_updated +# return a_updated, out +# +# def compiled_fn(a, b): +# // we detect that a is the "differentiable base" here +# base = a +# // In other situations, we might do either: +# // (1) a and b are both views off of some larger differentiable base +# // assert a._base is b._base and a._base is not None +# // base = a._base +# // (2) a and b both don't require gradients. Create a base from the storage +# // assert a._base is None and b._base is None +# // base = torch.Tensor(a.storage()) +# a_updated, out = traced_forward(base) +# a.copy_(a_updated) +# return out +# +# This function: +# (1) Merges input views into a synthetic base argument, when any of those input views are mutated +# (2) Returns metadata telling the autograd.Function how to modify their arguments properly, +# to respect the new calling convention. +# +# The calling convention is as follows. +# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base. +# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN], +# Where the ordering of the bases is determined from the ordering of the original view args. +# baseA will come before baseB if the earliest original argument coming from baseA +# showed up earlier in the argument list than the earliest original argument coming from baseB. +# +# Example, given some tensors a, b, c, d +# call site: +# f(a, c.view(-1), b.view(-1), b, c, d) +# Modified argument list: +# c_base comes first because the first c view came earlier in arg list than the first b view +# b_base = torch.Tensor(b.storage()) +# c_base = torch.Tensor(c.storage()) +# f(c_base, b_base, a, d) +def merge_view_inputs( + fwd_inputs: List[Any], + mutated_input_info: List[MutationType] +) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, Tuple[Any]]]]]]: + assert len(fwd_inputs) == len(mutated_input_info) + storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list) + for i, inpt in enumerate(fwd_inputs): + if isinstance(inpt, Tensor): + storage_ref = StorageWeakRef(inpt._storage()) + storage_ref_to_idx[storage_ref].append(i) + base_args = [] + other_args = [] + # This list contains metadata that tells you what the i'th argument in the inner calling convention should be. + # It's either: + # - another int (corresponding to the index in the argument list of the element from the outer calling convention) + # - idx, *args, where we can generate the new output with old_args[idx].as_strided(*args) + # idx corresponds to which synthetic base from the outer calling context to view + inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, List[Any]]]] = {} + for aliased_input_indices in storage_ref_to_idx.values(): + if len(aliased_input_indices) > 1 and any( + # We only care about mutations that affect all aliases, + # so metadata mutations on an input doesn't require us to do synthetic base handling. + mutated_input_info[inpt_idx] == MutationType.data for inpt_idx in aliased_input_indices + ): + # We detected an input that was mutated, AND aliases with another input. + # we need to replace this set of aliased inputs with a single synthetic base. + # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases + # and error out. We can fix them later. + for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]): + view1 = fwd_inputs[idx1] + view2 = fwd_inputs[idx2] + # The "inputs that are aliased but have different differentiable bases" case + # is more complicated and hopefully pretty rare. Not currently handled. + assert are_differentiable_views(view1, view2), \ + "aot_autograd() does not yet handle non-differentiable view input mutations." + # Regenerating views when reinterpreting complex / real tensors seems non-trivial, + # not handling for now + assert same_dtype_views(view1, view2), \ + "aot_autograd() does not yet handle input mutations on views with different dtypes." + non_none_bases = [fwd_inputs[i]._base for i in aliased_input_indices if fwd_inputs[i]._base is not None] + aliases_with_none_bases = [fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None] + if len(non_none_bases) == 0: + # Case where none of the aliases require gradients + example_idx = aliased_input_indices[0] + synthetic_base = torch.Tensor(fwd_inputs[example_idx]._storage()) + else: + # Case where all of the aliases require gradients, and have the same _base. + synthetic_base = non_none_bases[0] + for other_base in non_none_bases[1:]: + assert other_base is synthetic_base, \ + "aot_autograd() does not yet handle non-differentiable view input mutations." + for alias in aliases_with_none_bases: + assert alias is synthetic_base, "aot_autograd() does not yet handle non-differentiable view input mutations." + base_args.append(synthetic_base) + for curr_view_idx in aliased_input_indices: + curr_view = fwd_inputs[curr_view_idx] + base_idx = len(base_args) - 1 + size_ = curr_view.size() + stride_ = curr_view.stride() + storage_offset_ = curr_view.storage_offset() + # We store just enough info here so that we can regenerate the view later. + # Regeneration: args[base_idx].as_strided(size_, stride_, storage_offset_) + # If we want view replay instead of as_strided() calls, this will need to change. + inner_calling_convention_meta[curr_view_idx] = (base_idx, (size_, stride_, storage_offset_)) + else: + for curr_idx in aliased_input_indices: + other_args.append(fwd_inputs[curr_idx]) + if len(base_args) == 0: + assert len(other_args) == len(fwd_inputs) + # If no synthetic bases are necessary, just return the original inputs. + return fwd_inputs, None + else: + # Otherwise, return: + # (1) The new args according to the updated calling convention: (synthetic_bases, other_args) + # (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention. + # We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention. + args_to_functionalization = base_args + other_args + arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)} + for i, other_arg in enumerate(other_args): + new_idx = len(base_args) + i + old_idx = arg_to_old_idx_map[other_arg] + inner_calling_convention_meta[old_idx] = new_idx + # post process into a list + post_processed_calling_convention_meta: List[Union[int, Callable]] = [-1 for _ in range(len(inner_calling_convention_meta))] + for k, v in inner_calling_convention_meta.items(): + post_processed_calling_convention_meta[k] = v + # Quick assert: every argument in the inner calling convention should be accounted for. + for x in post_processed_calling_convention_meta: + assert x != -1 + return args_to_functionalization, post_processed_calling_convention_meta + + +def format_guard_bug_msg(aot_config, expected): + return ( + f"At compilation time, graph {aot_config.aot_id} was compiled under the " + f"assumption that {expected}, but at runtime this was not the case. " + "This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch." + ) + + +# MOTIVATION: +# +# When tracing functions for future execution, one must be careful not to pass +# in the same input tensor multiple times (e.g., f(x, x), as this can result +# in graphs that are ONLY valid if you later pass a new tensor in exactly the +# same way (e.g., f(y, y)). (NB: we really mean duplicate; two distinct +# tensors that alias each other is a different situation that is covered by +# aot_dispatch_deduplicated_autograd). Here are two examples: +# +# (1) Suppose you have a function: +# +# def f(x, y): +# return x + y +# +# If you make_fx(f)(x, x), you will trace out: +# +# def f(x, y): +# return y + y +# +# Oops! +# +# (2) For most tensors x and y, you can compute f's gradient with respect to +# these to inputs by saying torch.autograd.grad(f(x, y), (x, y)). However, +# if x is y, you will trace out a program that gets incorrect gradients: +# +# >>> x = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + x, (x, x)) +# (tensor([2.]), tensor([2.])) +# +# In other words, the gradient is double-counted. Deduplicating the arguments +# gives you an appropriate gradient: +# +# >>> y = torch.randn(1, requires_grad=True) +# >>> torch.autograd.grad(x + y, (x, y)) +# (tensor([1.]), tensor([1.])) +# +# HOW TO DEDUPLICATE: +# +# There are a few strategies, in order of preference: +# +# 1. For every duplicate argument to the function, detach it into +# a separate leaf tensor, so that it is no longer duplicated. +# +# PRO: The resulting compiled graph works for any configuration +# of duplicated arguments. +# +# CON: It does not (naively) work if you mutate the metadata of inputs: +# +# def f(x, y): +# x.transpose_(0, 1) +# y.transpose_(0, 2) +# +# x = torch.randn(2, 3, 4) +# f(x, x) +# +# The ordering of the transposes inside f dictates whether or not +# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute +# what metadata mutations should get applied to each input; you need to +# assume they aren't duplicates (what we do today) or preserve +# the original metadata mutations exactly in order, so that they work +# for any duplicate configuration. +# +# CON: It does not (naively) work if you mutate the data of inputs. +# In particular, leaf tensors that require grad cannot be mutated, +# this makes it impossible to differentiate with respect to the original +# base. +# +# 2. For every duplicate argument to the function, remove it, so it is +# no longer part of the "true" signature: +# +# PRO: Implemented naively, it still works for metadata/data mutation. +# +# CON: The resulting compiled graph is duplicate-specialized: it only +# works if future calls duplicate arguments in exactly the same way. +# Horribly, Dynamo doesn't guard on this at the moment. But even if +# it did, you could still end up recompiling a bunch of each duplicate. +# +# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if +# Dynamo's guards are not enough. In practice, this seems to cover +# everything. +# +def aot_wrapper_dedupe(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, compiler_fn): + # Get information about whether or not flat_fn mutates its arguments + # or not + try: + with enable_python_dispatcher(): + fw_metadata, _out, _num_aliasing_metadata_outs = run_functionalized_fw_and_collect_metadata( + flat_fn + )(*flat_args) + except RuntimeError as e: + logging.warning( + "Failed to collect metadata on function, produced code may be suboptimal. " + "Known situations this can occur are inference mode only compilation involving " + "resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); " + "if your situation looks different please file a bug to PyTorch.", + exc_info=True + ) + # Analysis failed, fall back to duplicate specialize + # TODO: Known analysis problems: + # - resize_: TestInductorOpInfoCPU.test_comprehensive_resize__cpu_bool + # - prims: test_tmp_not_defined_issue1_cpu + pass + else: + # Strategy 1: For any input that is not mutated, we can leafify it if we + # need to remove a duplicate. + leaf_flat_args = [] + args_set = set() + ok = True + + for i, a in enumerate(flat_args): + if a not in args_set: + args_set.add(a) + leaf_flat_args.append(a) + elif fw_metadata.mutated_input_info[i] == MutationType.none: + leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad)) + else: + ok = False + break + + if ok: + return compiler_fn(flat_fn, leaf_flat_args, aot_config) + + # Strategy 2: Duplicate specialize. + # + # In Haskell types, suppose you have: + # + # add_dupe_args :: DedupedArgs -> Args + # remove_dupe_args :: Args -> DedupedArgs + # + # compiler_fn + # :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R) + # deped_compiler_fn + # :: (Args -> R) -> Args -> AOTConfig -> (Args -> R) + # + # Then the code below can be written in point-free style as: + # + # deduped_compiler_fn f a c = + # compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args + # + # Suppose you have: + # + # [a, b, a, c] + # + # We want: + # + # remove_dupe_args([a, b, a, c]) == [a, b, c] + # add_dupe_args([a, b, c]) == [a, b, a, c] + # + # This is done via (respectively): + # + # seen_args = {a: 0, b: 1, c: 2} + # add_dupe_map = { # how to get args from the deduped list + # 0: 0, + # 1: 1, + # 2: 0, + # 3: 2, + # } + # keep_arg_mask = [True, True, False, True] + + seen_args = {} + keep_arg_mask = [] + add_dupe_map = {} + duped_arg_len = len(flat_args) + + j = 0 # index into deduped_flat_args + for i, t in enumerate(flat_args): + if t in seen_args: + keep_arg_mask.append(False) + add_dupe_map[i] = seen_args[t] + continue + keep_arg_mask.append(True) + seen_args[t] = j + add_dupe_map[i] = j + j += 1 + + unique_args = j + + # NB: Hot path, avoid set lookups here + # TODO: Can avoid the zip here too, probably + def remove_dupe_args(args): + return [t for t, keep in zip(args, keep_arg_mask) if keep] + + def add_dupe_args(args): + return [args[add_dupe_map[i]] for i in range(duped_arg_len)] + + deduped_flat_args = remove_dupe_args(flat_args) + + @wraps(flat_fn) + def wrapped_flat_fn(*args): + return flat_fn(*add_dupe_args(args)) + + compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config) + + if not hasattr(compiled_fn, "_boxed_call"): + compiled_fn = make_boxed_func(compiled_fn) + + @wraps(compiled_fn) + def wrapped_compiled_fn(args): + deduped_args = remove_dupe_args(args) + args.clear() + return compiled_fn(deduped_args) + wrapped_compiled_fn._boxed_call = True + + # This can be uncommented when we properly guard for duplicates, + # but right now we must not do it. + # if not config.debug_assert: + # return wrapped_compiled_fn + + @wraps(wrapped_compiled_fn) + def debugged_compiled_fn(args): + # Test that the computed remove/add arg functions are an inverse + new_args = add_dupe_args(remove_dupe_args(args)) + seen = {} + for i, (x, y) in enumerate(zip(new_args, args)): + seen[y] = None + assert x is y, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would be a duplicate of " + f"{describe_input(add_dupe_map[i], aot_config)}" + ) + # This is only an error if there is metadata mutation on both of + # the duped arguments; in this case, we need to know what order + # the metadata mutation applies in. You'll get the correct result + # otherwise, because a graph that assumes distinct inputs works if + # you dupe the inputs (the gradient contributions from each input + # will get summed up appropriately.) + # + # TODO: work out how to setup this assert correctly + """ + assert len(seen) == unique_args, format_guard_bug_msg(aot_config, + f"there would be {unique_args} distinct arguments" + ) + """ + return wrapped_compiled_fn(args) + debugged_compiled_fn._boxed_call = True + + return debugged_compiled_fn + + +def describe_input(i, aot_config): + if i < aot_config.num_params_buffers: + return f"parameter/buffer {i}" + else: + return f"input {i - aot_config.num_params_buffers}" + + +# Has the precondition that there +# are no duplicate arguments in flat_args (e.g., the same Tensor +# object never shows up twice. However, two tensor inputs MAY alias +# the same storage, so long as they have separate TensorImpls.) +def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig): + + with enable_python_dispatcher(): + _fw_metadata, out, _num_aliasing_metadata_outs = run_functionalized_fw_and_collect_metadata( + flat_fn + )(*flat_args) + + # pre-compute, so we can bail out quickly in the hotpath + _num_outputs_aliased_to_inputs = len([ + x for x in _fw_metadata.aliased_output_info if x.output_type == OutputType.alias_of_input]) + _num_outputs_aliased_to_intermediates = len([ + x for x in _fw_metadata.aliased_output_info if x.output_type == OutputType.alias_of_intermediate]) + _num_mutated_data_inputs = len([x for x in _fw_metadata.mutated_input_info if x == MutationType.data]) + _num_mutated_metadata_only_inputs = len([x for x in _fw_metadata.metadata_mutation_input_info if x is not None]) + _num_mutated_inputs = _num_mutated_data_inputs + _num_mutated_metadata_only_inputs + + if isinstance(out, (list, tuple)): + _num_non_aliased_outs = len(out[_num_mutated_data_inputs:]) + else: + _num_non_aliased_outs = 1 + assert len(_fw_metadata.requires_grad_out_info) == _num_mutated_data_inputs + _num_non_aliased_outs + + # out here corresponds to the set of outputs that should be returned by the traced forward call. + # It includes outputs of the original forward, *and* any updated inputs due to input mutations. + # However, it does *not* include any outputs that are aliases of inputs, or any metadata-only input mutations. + out = pytree.tree_map( + lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, + out, + ) + + # This code only executes if we have graph inputs that alias each other, and one of those inputs + # gets its data mutated. + # When that happens, we replace the aliased inputs with a synthetic base, and in the traced forward + # we later generate the input views + flat_args_with_views_handled, _synthetic_base_info = merge_view_inputs( + flat_args, _fw_metadata.mutated_input_info) + + joint_forward_backward = create_joint_forward_backward_functionalized( + flat_fn, + meta=_fw_metadata, + synthetic_base_info=_synthetic_base_info, + ) + + joint_inputs = (flat_args_with_views_handled, out) + + disable_amp = torch._C._is_any_autocast_enabled() + + if config.use_functionalize: + with enable_python_dispatcher(): + flattened_joints, _ = pytree.tree_flatten(joint_inputs) + fx_g = make_fx( + joint_forward_backward, aot_config.decompositions + )(*joint_inputs) + + # Redudant with the check above, but worth having in case tracing introduced + # a fake tensor. Unlikely. + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g) + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + else: + # joint_forward_backward() now always runs with functionalization, and factoring it out + # to make that toggleable is a bit painful. + # aot autograd without functionalization is wrong anyway, so we error. + raise AssertionError("Graph partitioning without functionalization is not sound, we may introduce errors") + + if config.debug_joint: + print(f"====== Joint graph {aot_config.aot_id} ======") + fx_g.print_readable() + + with torch.no_grad(): + with track_graph_compiling(aot_config, "joint"): + num_inner_fwd_outputs = _num_mutated_data_inputs + _num_non_aliased_outs + _num_aliasing_metadata_outs + fw_module, bw_module = aot_config.partition_fn( + fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs) + fw_outs = [n for n in fw_module.graph.nodes if n.op == "output"][0].args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + symint_outs_saved_for_bw = [n for n in fw_outs_saved_for_bw if is_sym_node(n)] + _num_symints_saved_for_bw = len(symint_outs_saved_for_bw) + + if config.debug_graphs: + print("====== Forward graph {aot_config.aot_id} ======") + fw_module.print_readable() + print("====== Backward graph {aot_config.aot_id} ======") + bw_module.print_readable() + + with track_graph_compiling(aot_config, "forward"): + compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args_with_views_handled) + + class CompiledFunction(torch.autograd.Function): + compiled_fw = compiled_fw_func + compiled_bw = None + # Corresponds to number of outs (not including updated inputs returns as outs), + # *and* not including outs that are aliases of inputs + num_non_aliased_outs = _num_non_aliased_outs + num_symints_saved_for_bw = _num_symints_saved_for_bw + # Corresponds to number of inputs that are mutated (both metadata only, and data) + num_mutated_inputs = _num_mutated_inputs + # Corresponds to number of inputs that only have their metadata mutated + num_mutated_data_inputs = _num_mutated_data_inputs + # Corresponds to number of inputs that get their metadata (but not data) mutated + # We don't return these in the compiled fw, and instead we stash enough info + # to replay the metadata mutations later. + num_mutated_metadata_only_inputs = _num_mutated_metadata_only_inputs + # Corresponds to number of outputs in the original fw that are aliases of inputs + # (These are all not returned by the compiled forward, and instead they are manually + # created in the epilogue) + num_outputs_aliased_to_inputs = _num_outputs_aliased_to_inputs + # Corresponds to the number of user outputs that alias intermediates (aka graph outputs). + num_outputs_aliased_to_intermediates = _num_outputs_aliased_to_intermediates + # For every output that aliases and input, and every input that gets only its metadata mutated, + # we return that tensor's size/stride/storage_offset directly at the end of the compiled forward, + # as a big list of ints. + # The number is tracked here. + num_aliasing_metadata_outs = _num_aliasing_metadata_outs + synthetic_base_info = _synthetic_base_info + fw_metadata = _fw_metadata + + @staticmethod + def forward(ctx, *deduped_flat_tensor_args): + + # There is a pretty complicated calling convention around what the compiled fw returns. + # The full list of outputs and their relative order is: + # (*mutated_data_inputs, *non_aliased_fw_outs, *saved_tensors, *saved_symints) + # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version + # of the original view, and not the synthetic base + fw_outs = call_func_with_args( + CompiledFunction.compiled_fw, deduped_flat_tensor_args, disable_amp=disable_amp + ) + + num_non_aliased_outs = CompiledFunction.num_non_aliased_outs + num_aliasing_metadata_outs = CompiledFunction.num_aliasing_metadata_outs + num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw + num_mutated_data_inputs = CompiledFunction.num_mutated_data_inputs + # Our forward() returns both (mutated_inputs, outputs, output_alias_meta, saved_tensors, saved_symints) + num_forward_returns = num_mutated_data_inputs + num_non_aliased_outs + num_aliasing_metadata_outs + num_forward_returns_not_including_alias_meta = num_mutated_data_inputs + num_non_aliased_outs + + # Partitioners must put symint arguments at the end separate from tensor arguments + if num_symints_saved_for_bw > 0: + tensors_saved_for_backwards = fw_outs[num_forward_returns:-num_symints_saved_for_bw] + assert all([isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]) + ctx.save_for_backward(*tensors_saved_for_backwards) + symint_outs = fw_outs[-num_symints_saved_for_bw:] + assert all([isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) for x in symint_outs]) + ctx.symints = symint_outs + else: + ctx.save_for_backward(*fw_outs[num_forward_returns:]) + ctx.symints = [] + + fw_outs_not_requiring_grad = [ + x for (i, x) in enumerate(fw_outs[:num_forward_returns_not_including_alias_meta]) + if isinstance(x, torch.Tensor) and not CompiledFunction.fw_metadata.requires_grad_out_info[i] + ] + fw_out_ids_requiring_grad = [ + id(x) for (i, x) in enumerate(fw_outs[:num_forward_returns_not_including_alias_meta]) + if isinstance(x, torch.Tensor) and CompiledFunction.fw_metadata.requires_grad_out_info[i] + ] + + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + + return tuple(fw_outs[0:num_forward_returns]) + + @staticmethod + def backward(ctx, *all_flat_args): + # Calling convention: we expect a grad_out passed to the backward: + # - for every output of the fw that does *not* alias an input + # - for every updated_input generated by the fw that does *not* alias an input + # - for every size/stride metadata value for aliased outputs. + # These are returned by the forward, but we just drop them in the backward. + # We need to return them in the forward, but unfortunately there's no way to specify + # in autograd.Function that certain non-tensor forward outputs shouldn't show up in the backward. + expected_grad_outs = CompiledFunction.num_non_aliased_outs + CompiledFunction.num_mutated_data_inputs + if CompiledFunction.num_aliasing_metadata_outs > 0: + flat_args = all_flat_args[:-CompiledFunction.num_aliasing_metadata_outs] + metadata_args = all_flat_args[-CompiledFunction.num_aliasing_metadata_outs:] + # metadata args are all ints/symints, which autograd will send Nones for as grad_outputs in the bw + assert all([x is None for x in metadata_args]) + # delete + # for out_idx, (base_sizes, base_strides, base_storage_offset) in CompiledFunctions.fw_out_base_metadata.items(): + + else: + flat_args = all_flat_args + + assert len(flat_args) == expected_grad_outs + contiguous_args = [t.contiguous() if torch.is_tensor(t) else t for t in flat_args] + all_args = list(ctx.symints) + list(ctx.saved_tensors) + list(contiguous_args) + del contiguous_args + if CompiledFunction.compiled_bw is None: + # TODO - pass in fake tensors ? + context = disable_autocast_manager if disable_amp else nullcontext + with context(), track_graph_compiling(aot_config, "backward"): + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, all_args + ) + + ctx.maybe_clear_saved_tensors() + out = call_func_with_args( + CompiledFunction.compiled_bw, all_args, steal_args=True, disable_amp=disable_amp + ) + return tuple(out) + + @wraps(CompiledFunction.apply) + def compiled_function(*args): + # Step 2: remove aliased inputs that are mutated, replace with synthetic bases + # Only happens if our graph mutates an input that aliases another input. + if CompiledFunction.synthetic_base_info is not None: + # Given: the original args, including at least one pair of inputs that are aliased + # and get subsequently mutated. + # Generate: the updated args, including (potentially multiple) synthetic bases + # that replace the views. The input views are regenerated manually in the compiled function. + # TODO: think harder about what happens if (a view of) one of these mutated input views is ALSO returned + new_inputs, metadata = merge_view_inputs(args, CompiledFunction.fw_metadata.mutated_input_info) + # We're just re-running the original-args-to-synthetic-base transformation + # that we ran during compilation. + # This returns metadata that we use during tracing to recover the input views, + # which we don't actually need at runtime. + assert metadata is not None + args_with_synthetic_bases = new_inputs + else: + args_with_synthetic_bases = args + + all_outs = CompiledFunction.apply(*args_with_synthetic_bases) + if CompiledFunction.num_aliasing_metadata_outs > 0: + outs = all_outs[:-CompiledFunction.num_aliasing_metadata_outs] + aliasing_metadata_outs = all_outs[-CompiledFunction.num_aliasing_metadata_outs:] + else: + outs = all_outs + aliasing_metadata_outs = [] + + assert len(all_outs) == CompiledFunction.num_mutated_data_inputs + CompiledFunction.num_non_aliased_outs \ + + CompiledFunction.num_aliasing_metadata_outs + + # Step 3: After running the compiled fw, apply updates to mutated inputs + if CompiledFunction.num_mutated_inputs > 0: + # Calling convention: (mutated_inputs, real_outs, aliasing_metadata) + + if CompiledFunction.num_mutated_data_inputs > 0: + updated_inputs = outs[:CompiledFunction.num_mutated_data_inputs] + fw_outs = outs[CompiledFunction.num_mutated_data_inputs:] + else: + updated_inputs = [] + fw_outs = outs + + curr_mutated_inpt_idx = 0 + for inpt_idx, (mutation_type, metadata_mutation_info) in enumerate(zip( + # TODO: I should merge these two pieces of state + CompiledFunction.fw_metadata.mutated_input_info, + CompiledFunction.fw_metadata.metadata_mutation_input_info, + )): + if mutation_type == MutationType.none: + continue + original_inpt = args[inpt_idx] + if mutation_type == MutationType.metadata_only: + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to mutate the metadata of the input + expected_meta = CompiledFunction.fw_metadata.metadata_mutation_input_info[inpt_idx] + assert expected_meta is not None + fake_meta = expected_meta.tensor_meta + size_len = len(fake_meta.size()) + stride_len = len(fake_meta.stride()) + size_ = aliasing_metadata_outs[expected_meta.sizes_idx:expected_meta.sizes_idx + size_len] + stride_ = aliasing_metadata_outs[expected_meta.strides_idx:expected_meta.strides_idx + stride_len] + storage_offset_ = aliasing_metadata_outs[expected_meta.storage_offset_idx] + original_inpt.as_strided_(size_, stride_, storage_offset_) + else: + updated_inpt = updated_inputs[curr_mutated_inpt_idx] + curr_mutated_inpt_idx += 1 + # TODO: handle resize_() on inputs to a larger size. + # This is actually non-trivial to detect, so we should probably just handle it + # (or make dynamo detect). + # We can't just check of original_inpt.storage_size != updated_inpt.storage_size, + # Because the original_inpt might be a view of some larger tensor, + # and updated_inpt is always densely packed. + if original_inpt.size() != updated_inpt.size() \ + or original_inpt.stride() != updated_inpt.stride() \ + or original_inpt.storage_offset() != updated_inpt.storage_offset(): + # Functionalization can't easily tell us if an input had BOTH its metadata actual data mutated. + # So we check if metadata needs to be mutated here manually. + original_inpt.as_strided_(updated_inpt.size(), updated_inpt.stride(), updated_inpt.storage_offset()) + original_inpt.copy_(updated_inpt) + else: + fw_outs = outs + + # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of + # compiling them. + if CompiledFunction.num_outputs_aliased_to_inputs > 0 or CompiledFunction.num_outputs_aliased_to_intermediates > 0: + assert CompiledFunction.num_outputs_aliased_to_inputs + len(fw_outs) == \ + len(CompiledFunction.fw_metadata.aliased_output_info) + fw_outs_including_aliases = [] + for aliased_out_metadata in CompiledFunction.fw_metadata.aliased_output_info: + if aliased_out_metadata.output_type == OutputType.non_alias: + fw_outs_including_aliases.append(fw_outs[aliased_out_metadata.base_idx]) + else: + if aliased_out_metadata.output_type == OutputType.alias_of_input: + aliased_base_tensor = args[aliased_out_metadata.base_idx] + else: + assert aliased_out_metadata.output_type == OutputType.alias_of_intermediate + aliased_base_tensor = fw_outs[aliased_out_metadata.base_idx] + # Note: here, we manually regenerate the output, using an as_strided() call, + # OR if the aliased output came from a custom autograd.function, we replay it. + # The as_strided() in the normal case is good for perf (this is hot-path code, + # and we're consolidating potential chains of views into a single view op). + # But we might need to figure out view replaying for e.g. XLA. + # TODO: handle the custom autograd function case here. + # We need a way to check whether a tensor came from a custom autograd fn from python, + # AND a way to replay that custom view fn. + fake_meta = aliased_out_metadata.tensor_meta + if fake_meta is None: + # This handles the specific case where the user returns an output that *was* an input. Don't create a view. + fw_outs_including_aliases.append(aliased_base_tensor) + else: + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to create a view off of the right input + fake_meta = aliased_out_metadata.tensor_meta + size_len = len(fake_meta.size()) + stride_len = len(fake_meta.stride()) + size_ = aliasing_metadata_outs[aliased_out_metadata.sizes_idx:aliased_out_metadata.sizes_idx + size_len] + stride_ = aliasing_metadata_outs[ + aliased_out_metadata.strides_idx:aliased_out_metadata.strides_idx + stride_len] + storage_offset_ = aliasing_metadata_outs[aliased_out_metadata.storage_offset_idx] + # Create the output alias + aliased_out = gen_alias_from_base(aliased_base_tensor, size_, stride_, storage_offset_, fake_meta) + fw_outs_including_aliases.append(aliased_out) + + for inner_out, user_out in zip(fw_outs, fw_outs_including_aliases): + # Sanity check assert + assert type(inner_out) == type(user_out) + return fw_outs_including_aliases + else: + return fw_outs + + if not config.debug_assert: + return compiled_function + + flat_requires_grad = [a.requires_grad if isinstance(a, Tensor) else None for a in flat_args] + + @wraps(compiled_function) + def debug_compiled_function(*args): + # TODO: Check aliasing relationships + # TODO: Check strides for metadata mutation + # (NB: ideally, this logic is factored out of this function and + # you move these debug checks there) + + # Check requires grad. Bad case is when we compiled with + # requires_grad = False, but input requires_grad = True + # (vice versa is OK; we compute a gradient and then throw + # it away when it hits the input.) + for i, a in enumerate(args): + can_require_grad = flat_requires_grad[i] + if can_require_grad is None: + assert not isinstance(a, Tensor) + elif not can_require_grad: + assert not a.requires_grad, format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would not require grad" + ) + + return compiled_function(*args) + + return debug_compiled_function + + +@dynamo_timed +def create_aot_dispatcher_function( + flat_fn, flat_args: List[Tensor], aot_config: AOTConfig +): + """ + Traces the forward and backward graphs of the attr:`flat_fn` to generate a + joint graph. The joint graph is an Fx graph with Aten ops. Please refer to + the tracing mechanism to understand the graph capturing details. + + The joint graph is then passed through attr:`partition_fn` to isolate the + forward and backward portions, which are then respectively compiled via the + provided attr:`fw_compiler` and attr:`bw_compiler`. + + The resulting compiled forward and backward graphs are then wrapped up in a + ``torch.autograd.Function`` object. + + The calling convention here is that the first aot_config.num_params_buffers + inputs in flat_args are parameters and buffers, and the rest are inputs. + + We use this to assume that parameters/buffer's shapes don't change. + """ + + # This is the main entry point. + # TODO: Chillee argues that dynamo itself should pass in fake tensors to + # the list of arguments when compiling; at the moment we do not do this + + if aot_config.decompositions is None: + aot_config.decompositions = {} + + aot_config.decompositions = { + **aot_autograd_decompositions, + **aot_config.decompositions, + } + # NB: don't bother setting allow_fallback_kernels; this should not actually + # be configurable in fake tensor, we should automatically do the right + # thing + if config.debug_fake_cross_ref: + # This is a little messy but TorchDynamo directly changes `use_fake_tensor` + # so it's not enough for user to change the config manually + # TODO: have TorchDynamo read in `use_fake_tensor` from os environ / + # coordinate flags + config.use_fake_tensor = False + + if config.use_dynamic_shapes: + assert config.use_fake_tensor, "Dynamic shapes only works with fake tensor" + + # Check flat_args to see if they're already fake. If so, use that fake + # mode instead. + + for x in flat_args: + if isinstance(x, FakeTensor): + fake_mode = x.fake_mode + break + else: + shape_env = ShapeEnv() if config.use_dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) if config.use_fake_tensor else nullcontext() + + cross_ref = CrossRefFakeMode() if config.debug_fake_cross_ref else nullcontext() + python_dispatcher_mode = enable_python_dispatcher() if config.use_dynamic_shapes else nullcontext() + + with torch.autograd.set_multithreading_enabled(False), preserve_rng_state(), cross_ref, fake_mode, python_dispatcher_mode: + + def process_inputs(flat_args): + if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode): + def convert(idx, x): + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if idx < aot_config.num_params_buffers and config.static_weight_shapes: + return fake_mode.from_tensor(x, static_shapes=True) + return fake_mode.from_tensor(x, static_shapes=False) + + return [convert(idx, x) for idx, x in enumerate(flat_args)] + else: + return flat_args + + fake_flat_tensor_args = process_inputs(flat_args) + + needs_autograd = ( + any( + [ + x.requires_grad + for x in fake_flat_tensor_args + if isinstance(x, Tensor) + ] + ) + and torch.is_grad_enabled() + ) + # crappy version of dispatcher + # TODO: Do this properly + if needs_autograd: + compiler_fn = aot_dispatch_autograd + else: + compiler_fn = aot_dispatch_base + + compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn) + # You can put more passes here + + compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config) + + if not hasattr(compiled_fn, '_boxed_call'): + compiled_fn = make_boxed_func(compiled_fn) + + return compiled_fn + + +# Inspired by autodidax (thanks!) +class PytreeThunk: + spec = None + # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. + is_simple = ( + None # if the output spec is a tuple/list, we won't bother unflattening it. + ) + is_really_simple = None # if the output spec is a LeafSpec + + def set(self, spec): + assert self.spec is None or self.spec == spec + self.spec = spec + if type(self.spec) in [tuple, list] and all( + isinstance(i, pytree.LeafSpec) for i in spec.children_specs + ): + self.is_simple = True + if isinstance(self.spec, pytree.LeafSpec): + self.is_really_simple = True + + def unflatten(self, x): + if self.is_really_simple: + return x[0] + if self.is_simple: + return x + return pytree.tree_unflatten(x, self.spec) + + +def aot_function( + fn: Callable, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + num_params_buffers: int = 0, + hasher_type=None, # deprecated + static_argnums: Optional[Tuple[int]] = None, # deprecated +) -> Callable: + """ + Traces the forward and backward graph of :attr:`fn` using torch dispatch + mechanism, and then compiles the generated forward and backward graphs + through :attr:`fw_compiler` and :attr:`bw_compiler`. + + :func:`aot_function` traces the forward and backward graph ahead of time, + and generates a joint forward and backward graph. :attr:`partition_fn` is + then used to separate out forward and backward graphs. The partitioner + function can be used to perform optimizations such as recomputation. One can + set `decompositions` dictionary to decompose the operators into a sequence + of core or simpler operators supported by the backend compilers. + + :func:`aot_function` uses a compilation cache, based on input tensor + properties, to detect when there is a need of recompilation. + + .. warning:: + This API is experimental and likely to change. + + Args: + fn (Callable): A Python function that takes one ore more arguments. Must + return one or more Tensors. + fw_compiler (Callable): A Python function that accepts an Fx graph with + Aten ops and input args, and returns a Callable that semantically is + equivalent to the input Fx graph. + bw_compiler (Optional[Callable]): A Python function that accepts an + Fx graph with Aten ops and input args, and returns a Callable that + semantically is equivalent to the input Fx graph. Default: None + (when None, it defaults to the :attr:`fw_compiler`) + partition_fn (Callable): A Python function that takes a joint forward + and backward graph, and partitions it into separate forward and + backward graphs. + decompositions (Dict): A dictionary to define the decomposition of + larger Aten ops into simpler or core Aten ops. + + Returns: + Returns a ``Callable`` that retains the eager behavior of the original + :attr:`fn`, but with forward and backward graph compiled via + :attr:`fw_compile` and :attr:`bw_compile`. + + A simple example usage of :func:`aot_function` is as follows. This example + will print the forward and backward graphs of the function ``fn`` + + >>> fn = lambda x : x.sin().cos() + >>> def print_compile_fn(fx_module, args): + >>> print(fx_module) + >>> return fx_module + >>> aot_fn = aot_function(fn, print_compile_fn) + >>> x = torch.randn(4, 5, requires_grad=True) + >>> aot_fn(x) + """ + if static_argnums is not None: + raise RuntimeError("static_argnums has been deprecated - manually wrap your function or use torchdynamo.") + + if bw_compiler is None: + bw_compiler = fw_compiler + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=num_params_buffers, + aot_id=next(AOT_COUNTER), + ) + cached_res = None + + @wraps(fn) + def returned_function(*args, **kwargs): + nonlocal cached_res + # Now flatten the tensor args + flat_args, _ = pytree.tree_flatten((args, kwargs)) + + # Compile the function and save it in the cache + if cached_res is None: + # Save the args_spec for flat_tensor_args to unflatten while tracing + _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + out_spec = PytreeThunk() + + def flat_fn(*flat_args): + # The input are flattened tensor args. Prepare the args in the + # order that original function expects. Add static args as well. + # They will appear as tensor constants in the traced graph. + nonlocal out_spec + args, kwargs = pytree.tree_unflatten( + flat_args, tensor_args_spec + ) + tree_out = fn(*args, **kwargs) + flat_out, spec = pytree.tree_flatten(tree_out) + for i in flat_out: + is_known_type = False + for j in KNOWN_TYPES: + if isinstance(i, j): + is_known_type = True + break + if not is_known_type: + raise RuntimeError( + f"Found {type(i)} in output, which is not a known type. " + "If this type holds tensors, you need to register a pytree for it. " + "See https://github.com/pytorch/functorch/issues/475 for a brief " + "explanation why. If you don't need to register a pytree, please " + "leave a comment explaining your use case and we'll make this more " + "ergonomic to deal with" + ) + out_spec.set(spec) + return flat_out + + compiled_fn = create_aot_dispatcher_function( + flat_fn, + flat_args, + aot_config, + ) + cached_res = (compiled_fn, out_spec) + + cached_fn, out_spec = cached_res + out = cached_fn(flat_args) + return out_spec.unflatten(out) + + return returned_function + + +def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module: + """ + Traces the forward and backward graph of :attr:`mod` using torch dispatch + tracing mechanism. It is wrapper function, that underneath uses + :func:`aot_function` to perform tracing and compilation. + + :func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs + to a new callable which is then compiled through :func:`aot_function`. + + .. warning:: + This API is experimental and likely to change. + + Args: + mod (Callable): A ``nn.Module`` module. + args : args to be passed to :func:`aot_function` + kwargs : kwargs to be passed to :func:`aot_function` + + Returns: + Returns a ``nn.Module`` that retains the eager behavior of the original + :attr:`mod`, but with forward and backward graph compiled. + + """ + # See Note: [Fake Modules and AOTAutograd] + torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) + + def functional_call(named_params, named_buffers, *args, **kwargs): + params_and_buffers = {**named_params, **named_buffers} + return stateless.functional_call(mod, params_and_buffers, args, kwargs) + + named_params = dict(_named_parameters(mod, remove_duplicate=False)) + named_buffers = dict(_named_buffers(mod, remove_duplicate=False)) + num_params_buffers = len(named_params) + len(named_buffers) + compiled_f = aot_function(functional_call, num_params_buffers=num_params_buffers, *args, **kwargs) + + class AOTModule(nn.Module): + def __init__(self): + super(AOTModule, self).__init__() + self.orig_module = mod + + def forward(self, *args, **kwargs): + return compiled_f( + named_params, + named_buffers, + *args, + **kwargs, + ) + + return AOTModule() + + +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: Callable, + bw_compiler: Optional[Callable] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[Dict] = None, + hasher_type=None, + static_argnums=None +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + ######################################################### + + # Redudant with dynamo, but worth having in case this gets invoked elsewhere. + + # Note [Fake Modules and AOTAutograd] + # + # A simple heuristic for when to use fake versus real tensors is that fake tensors are for compile time + # (when we don't want to actually run the compute, but we do want to know about metadata), + # and real tensors are for runtime (when we actually want to do the compute.) However, in AOTAutograd, + # modules are the exception: we always pass AOTAutograd modules with real tensors. + # This is because AOTAutograd will produce a compiled function which needs to directly access any + # parameters the compiled function may need, but these parameters will NOT be passed in by the caller (aka Dynamo). + # So at compile time, the compiled function we produce must close over any parameters, and those parameters must be + # real parameters, and we cannot do this unless at compile time we get a module with real tensors. + + # Even if Dynamo did pass all parameters explicitly at runtime, which would eliminate the need to close over + # the parameters, it would still be profitable to pass real tensor parameters to the compiler at compile time, + # because some compilation strategies like CUDA graphs want to burn in the pointer addresses where the parameter data live, + # and of course we can't do that unless we give the backend a real tensor. + torch._dynamo.utils.assert_no_fake_params_or_buffers(mod) + + params = { + **dict(_named_parameters(mod, remove_duplicate=False)), + **dict(_named_buffers(mod, remove_duplicate=False)), + } + params_flat, params_spec = pytree.tree_flatten(params) + params_flat = tuple(params_flat) + params_len = len(params_flat) + + def functional_call(*args, **kwargs): + with stateless._reparametrize_module( + mod, pytree.tree_unflatten(args[:params_len], params_spec) + ): + if isinstance(mod, torch.fx.GraphModule): + with fx_traceback.override_stack_trace(), warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Anomaly Detection has been enabled." + ) + with torch.autograd.detect_anomaly(check_nan=False): + out = Interpreter(mod).run(*args[params_len:], **kwargs) + else: + out = mod(*args[params_len:], **kwargs) + + if not isinstance(out, (tuple, list)): + raise RuntimeError( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the ouputs. Please change the module to " + "have tuple outputs or use aot_module instead." + ) + return out + + assert static_argnums is None + if bw_compiler is None: + bw_compiler = fw_compiler + aot_config = AOTConfig( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + decompositions=decompositions, + num_params_buffers=params_len, + aot_id=next(AOT_COUNTER), + ) + + full_args = [] + full_args.extend(params_flat) + full_args.extend(args) + + compiled_fn = create_aot_dispatcher_function( + functional_call, + full_args, + aot_config, + ) + + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... + def forward(*runtime_args): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) + + # Just for convenience + forward.zero_grad = mod.zero_grad + forward.named_parameters = mod.named_parameters + + return forward + + +compiled_function = aot_function +compiled_module = aot_module diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py new file mode 100644 index 0000000000000..d80a826e7ff51 --- /dev/null +++ b/torch/_functorch/autograd_function.py @@ -0,0 +1,188 @@ +import torch +from torch._ops import PyOperator +from torch._C._functorch import TransformType +from torch._functorch.utils import enable_autograd_function +from torch.autograd.function import _SingleLevelFunction +import torch.utils._pytree as pytree +from torch._C._functorch import ( + _wrap_for_grad, + _unwrap_for_grad, +) + +# autograd.Function technically runs before the regular PyTorch dispatcher. +# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot) +# work with it. One day we might decide to change this, but until then, +# we need to give the illusion that autograd.Function runs before those things. +# +# We do this by using creating a custom PyOperator that only functorch +# dispatches specially. +class CustomFunctionPyOperator(PyOperator): + def __init__(self): + super().__init__('custom_function_call') + + def __call__(self, *args, **kwargs): + # When custom_function_call is done dispatching through functorch, + # it should just invoke the autograd.Function. This is consistent + # with the autograd.Function behavior of being invoked before the + # PyTorch dispatcher. + # + # This will lead us into trouble later down the line, but this is + # pre-existing. There is an invariant that a function traced by + # make_fx should have the same behavior when provided the same + # Tensor. However, make_fx sees autograd.Function as a composite + # (because autograd.Function happens before the Python dispatch key) + # and only traces the forward pass. + if torch._C._are_functorch_transforms_active(): + return super().__call__(*args, **kwargs) + autograd_function = args[0] + return autograd_function.apply(*args[1:], **kwargs) + + +# "custom_function_call" +# This is the mechanism for an autograd.Function that works with functorch transforms. +# It wraps an autograd.Function; interactions with functorch transforms are defined +# via PyDispatcher and PyOperator rather than through the traditional PyTorch +# dispatcher. +custom_function_call = CustomFunctionPyOperator() + + +# The grad rule for custom_function_call is to construct a new _SingleLevelFunction +# (autograd.Function that only works with a single layer (level) of functorch) that: +# - unwraps the inputs +# - redispatches to custom_function_call +# - wraps the outputs +# and whose backward pass calls the original autograd.Function's backward. +# +# Why do we need to redispatch to custom_function_call? +# ----------------------------------------------------- +# This is consistent with how ATen operators work with functorch's grad transform: +# they always redispatch to the original operator. +# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x) +# +# grad1 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin (*) +# - rewrap the outputs on the return +# +# On the redispatch in (*), grad0 will: +# - set up the autograd graph +# - unwrap the inputs +# - redispatch to at::sin +# - rewrap the outputs on the return +# +# To "set up the autograd graph", we generate a _SingleLevelFunction +# and apply it. +@custom_function_call.py_impl(TransformType.Grad) +def custom_function_call_grad(interpreter, autograd_function, *operands): + maybe_interpreter = interpreter + level = maybe_interpreter.level() + + # TODO: The name of the grad_fn is GeneratedBackward. This isn't a great UX, + # but in theory functorch users shouldn't be peeking at the grad_fn. + # We should try to generate a better name for this. + # https://github.com/pytorch/pytorch/issues/90224 + class Generated(_SingleLevelFunction): + @staticmethod + def forward(*operands): + unwrapped_operands = pytree.tree_map_only( + torch.Tensor, + lambda x: _unwrap_for_grad(x, level), + operands) + with torch.enable_grad(), maybe_interpreter.lower(): + output = custom_function_call(autograd_function, *unwrapped_operands) + + return pytree.tree_map_only( + torch.Tensor, + lambda x: _wrap_for_grad(x, level), + output) + + @staticmethod + def setup_context(ctx, outputs, *operands): + ctx.mark_dirty = mark_dirty_error + return autograd_function.setup_context(ctx, outputs, *operands) + + @staticmethod + def backward(ctx, *grads): + result = autograd_function.backward(ctx, *grads) + return result + + with enable_autograd_function(): + flat_out = Generated.apply(*operands) + return flat_out + + +# https://github.com/pytorch/pytorch/issues/90225 +# If an input was marked as dirty, and the autograd.Function returns the input +# from the forward, then the grad rule for custom_function_call must also +# return the corresponding input from the forward() of the Generated autograd.Function +# +# We haven't figured out how to do this yet. One possibility is to rely +# on if the return from the redispatched custom_function_call in Generated.forward +# has the same object id as one of the inputs, +# but https://github.com/pytorch/pytorch/issues/90209 means we cannot rely on +# that property. +def mark_dirty_error(*args, **kwargs): + raise RuntimeError( + 'NYI: we do not yet support ctx.mark_dirty with functorch transforms. ' + 'Please try to avoid modifying inputs to the autograd.Function in-place ' + 'by using out-of-place operations or by cloning the inputs. ' + 'Please see https://github.com/pytorch/pytorch/issues/90209 for more details' + ) + + +# NOTE: [functorch vjp and autograd interaction] +# There's an edge case with the functorch vjp and autograd interaction +# that will eventually be fixed by mode-only functorch. +# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper, +# so we (the framework) need to do it manually. Regular PyTorch operators +# automatically do so this is consisent. +# +# class MyExp(torch.autograd.Function): +# @staticmethod +# def forward(x): +# return x.exp() +# +# @staticmethod +# def setup_context(ctx, outputs, x): +# y = outputs +# ctx.save_for_backward(y) +# +# @staticmethod +# def backward(gy): +# y, = ctx.saved_tensors() +# return MyMul.apply(gy, y) +# +# x = torch.randn([], requires_grad=True) +# gy = torch.randn([], requires_grad=True) +# _, vjp_fn = vjp(MySin.apply, x) +# result = vjp_fn(gy) +# +# MyMul is an autograd.Function that is not shown here. +# It saves a `y` for backward (since gy requires grad). +# +# in vjp_fn(gy), we get: +# > MyMul.apply(gy, GradTensorWrapper(y, level=dead)) +# Because the y that is saved for backward by MyExp is a GradTensorWrapper +# but is now dead since we are outside the vjp context. +# +# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper, +# will automatically unwrap the GradTensorWrapper when applied. +# But since autograd.Function technically sits above the regular PyTorch +# dispatcher, it doesn't get this treatment. So we manually do +# the unwrapping to be consistent with regular PyTorch dispatcher operations. + + +@custom_function_call.py_impl(TransformType.Vmap) +def custom_function_call_vmap(interpreter, autograd_function, *operands): + raise RuntimeError("NYI: vmap rule for custom_function_call") + + +@custom_function_call.py_impl(TransformType.Jvp) +def custom_function_call_jvp(interpreter, autograd_function, *operands): + raise RuntimeError("NYI: jvp rule for custom_function_call") + + +@custom_function_call.py_impl(TransformType.Functionalize) +def custom_function_call_functionalize(interpreter, autograd_function, *operands): + raise RuntimeError("NYI: Functionalize rule for custom_function_call") diff --git a/functorch/_src/benchmark_utils.py b/torch/_functorch/benchmark_utils.py similarity index 100% rename from functorch/_src/benchmark_utils.py rename to torch/_functorch/benchmark_utils.py diff --git a/functorch/_src/compile_utils.py b/torch/_functorch/compile_utils.py similarity index 100% rename from functorch/_src/compile_utils.py rename to torch/_functorch/compile_utils.py diff --git a/functorch/_src/compilers.py b/torch/_functorch/compilers.py similarity index 88% rename from functorch/_src/compilers.py rename to torch/_functorch/compilers.py index 18deafa244695..da723e5cbcb18 100644 --- a/functorch/_src/compilers.py +++ b/torch/_functorch/compilers.py @@ -19,6 +19,8 @@ draw_graph, min_cut_rematerialization_partition, ) +import torch.utils._pytree as pytree + # These canonicalizations are needed here (and not decompositions), as the ops @@ -85,7 +87,8 @@ def ts_compile(fx_g: fx.GraphModule, inps) -> Callable: f = torch.jit.freeze(f.eval()) f = torch.jit.optimize_for_inference(f) - f(*inps) + if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps): + f(*inps) return f @@ -112,6 +115,34 @@ def nop(fx_g: fx.GraphModule, _) -> Callable: """ return fx_g +class DebugInterpreter(fx.Interpreter): + def run_node(self, n): + # TODO: This will fail once we start caching in AOTAutograd + # again, because we need to remap SymInts to their new values + # in the presence of dynamism + r = super().run_node(n) + if 'val' in n.meta: + n_vals, n_spec = pytree.tree_flatten(n.meta['val']) + r_vals, r_spec = pytree.tree_flatten(r) + assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + assert nv.size() == rv.size(), f"output {i}: {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"output {i}: {nv.dtype} != {rv.dtype}" + assert torch._prims_common.check_significant_strides(nv, rv), f"output {i}: {nv.stride()} != {rv.stride()}" + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) + """ + return DebugInterpreter(fx_g).run @make_boxed_compiler def simple_ts_compile(fx_g, _): @@ -349,6 +380,7 @@ def graph_saver_joint(gm, joint_args): return aot_module_simplified( gm, + example_inputs, fw_compiler=graph_saver_forward, bw_compiler=graph_saver_backward, partition_fn=graph_saver_joint, @@ -356,6 +388,7 @@ def graph_saver_joint(gm, joint_args): ) +# WARNING: This isn't tested anywhere!! def graph_dumper_aot(current_name, folder_name, dump_example_input=False): """ Dump the forward, backward, and joint computation graph. diff --git a/functorch/_src/config.py b/torch/_functorch/config.py similarity index 77% rename from functorch/_src/config.py rename to torch/_functorch/config.py index 2dacdd38fa37c..53fa5b28a86bb 100644 --- a/functorch/_src/config.py +++ b/torch/_functorch/config.py @@ -14,6 +14,11 @@ # TODO Benchmark use_fake_tensor = False +# Enables optional asserts in hotpath code to check for errors. If +# you are seeing weird accuracy problems, try turning this on. +# For now, to more easily identify bugs, this is turned on by default. +debug_assert = True + debug_fake_cross_ref = os.environ.get('AOT_FAKE_CROSSREF', False) debug_partitioner = os.environ.get('AOT_PARTITIONER_DEBUG', False) diff --git a/functorch/_src/eager_transforms.py b/torch/_functorch/eager_transforms.py similarity index 99% rename from functorch/_src/eager_transforms.py rename to torch/_functorch/eager_transforms.py index 209a738060bed..3f02b7fa3a1ed 100644 --- a/functorch/_src/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -233,7 +233,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False): >>> x = torch.randn([5]) >>> def f(x, scale=4.): - >>> return x * 4. + >>> return x * scale >>> >>> (_, vjpfunc) = functorch.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) @@ -1060,7 +1060,7 @@ def hessian(func, argnums=0): >>> return x.sin().sum() >>> >>> x = torch.randn(5) - >>> hess = jacfwd(jacrev(f))(x) + >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin())) """ diff --git a/functorch/_src/fx_minifier.py b/torch/_functorch/fx_minifier.py similarity index 100% rename from functorch/_src/fx_minifier.py rename to torch/_functorch/fx_minifier.py diff --git a/functorch/_src/make_functional.py b/torch/_functorch/make_functional.py similarity index 99% rename from functorch/_src/make_functional.py rename to torch/_functorch/make_functional.py index 7b8c15196e23b..abb3f07ca597f 100644 --- a/functorch/_src/make_functional.py +++ b/torch/_functorch/make_functional.py @@ -44,7 +44,7 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: if len(names) == 1: return getattr(obj, names[0]) else: - _get_nested_attr(getattr(obj, names[0]), names[1:]) + return _get_nested_attr(getattr(obj, names[0]), names[1:]) def raise_parameter_tying_error(): diff --git a/functorch/_src/named_members_polyfill.py b/torch/_functorch/named_members_polyfill.py similarity index 100% rename from functorch/_src/named_members_polyfill.py rename to torch/_functorch/named_members_polyfill.py diff --git a/functorch/_src/partitioners.py b/torch/_functorch/partitioners.py similarity index 86% rename from functorch/_src/partitioners.py rename to torch/_functorch/partitioners.py index 1077904528efe..bcbaaca7b0ef8 100644 --- a/functorch/_src/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -11,6 +11,7 @@ from typing import Tuple from .compile_utils import fx_graph_cse, get_aten_target from . import config +import functools AOT_PARTITIONER_DEBUG = config.debug_partitioner @@ -84,16 +85,15 @@ def _is_tangent(node): return node.op == "placeholder" and "tangents" in node.target -def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule): - num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves +def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): outputs = pytree.tree_flatten([node.args for node in joint_module.graph.nodes if node.op == 'output'])[0] fwd_outputs = outputs[:num_fwd_outputs] bwd_outputs = outputs[num_fwd_outputs:] return fwd_outputs, bwd_outputs -def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=()): - fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module) +def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_sym_nodes=(), *, num_fwd_outputs): + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes)) # Construct the forward module @@ -125,7 +125,7 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values, saved_s def default_partition( - joint_module: fx.GraphModule, _joint_inputs + joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs ) -> Tuple[fx.GraphModule, fx.GraphModule]: """ Partitions the :attr:`joint_module` in a manner that closely resembles the @@ -151,7 +151,7 @@ def default_partition( Returns the generated forward and backward Fx graph modules. """ primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module) + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != 'output'} saved_values = [] @@ -178,7 +178,7 @@ def default_partition( saved_values = list(set(saved_values)) saved_sym_nodes = list(set(saved_sym_nodes)) - return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes) + return _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs) def _prod(x): @@ -209,7 +209,7 @@ def _tensor_nbytes(numel, dtype): def _size_of(node: fx.Node) -> int: def to_size_hint(s): - if isinstance(s, torch.SymIntNode): + if isinstance(s, torch.SymInt): py_s = s.get_pyobj() return py_s.shape_env.size_hint(py_s.expr) assert isinstance(s, int) @@ -247,8 +247,27 @@ def _count_ops(graph): print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) +@functools.lru_cache(None) +def pointwise_ops(): + ops = [] + for attr_name in dir(torch.ops.aten): + opoverloadpacket = getattr(torch.ops.aten, attr_name) + if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): + continue + + for overload in opoverloadpacket.overloads(): + op_overload = getattr(opoverloadpacket, overload) + if torch.Tag.pointwise in op_overload.tags: + # currently aot autograd uses packet not overload + ops.append(opoverloadpacket) + break + + return ops + + def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, compiler="nvfuser", recomputable_ops=None, + *, num_fwd_outputs ) -> Tuple[fx.GraphModule, fx.GraphModule]: """ Partitions the joint graph such that the backward recomputes the forward. @@ -270,17 +289,20 @@ def min_cut_rematerialization_partition( recomputable_ops: This is an optional set of recomputable ops. If this is not None, then this set of ops will be used instead of the default set of ops. + num_fwd_outputs: The number of outputs from the forward graph. Returns: Returns the generated forward and backward Fx graph modules. """ try: import networkx as nx - except ImportError: - raise RuntimeError("Need networkx installed to perform smart recomputation heuristics") + except ImportError as e: + raise RuntimeError("Need networkx installed to perform smart recomputation " + "heuristics") from e joint_module.graph.eliminate_dead_code() joint_module.recompile() + fx_g = joint_module.graph # add the CSE pass @@ -302,15 +324,30 @@ def classify_nodes(joint_module): required_bw_nodes.add(user) primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module) + fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs) required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes if node.op != 'output'} unclaimed_nodes = {node for node in joint_module.graph.nodes if node not in required_fw_nodes and node not in required_bw_nodes} - return required_fw_nodes, required_bw_nodes, unclaimed_nodes + return fwd_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes + + orig_fw_outputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module) + + def is_tensor_node(x): + # When dynamic shapes are not enabled, fw outputs can be raw ints and not fx nodes + if not isinstance(x, fx.Node): + return False + # It would be nice if we could guarantee that all fx nodes from make_fx get a 'val' + # key in their meta dict, but that isn't always true today (see proxy_tensor.py) + return 'tensor_meta' in x.meta or ('val' in x.meta and isinstance(x.meta['val'], torch.Tensor)) + + # networkx blows up on graphs with no tensor outputs. + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. + if not any(is_tensor_node(x) for x in orig_fw_outputs): + return default_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) - required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module) for node in reversed(joint_module.graph.nodes): if node not in required_fw_nodes: node.dist_from_bw = 0 @@ -325,14 +362,17 @@ def classify_nodes(joint_module): # compiler == "nvfuser" is the default set of recomputable ops default_recomputable_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward, aten.alias, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax, aten.to, aten.type_as, operator.getitem, aten.squeeze, aten.unsqueeze, aten.rsub, aten._to_copy] # noqa: E501 if compiler == "inductor": - default_recomputable_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.minimum, aten.arange, aten.bitwise_and, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.isnan, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum, aten.bitwise_or, aten.logical_and, aten.logical_or] # noqa: E501 + default_recomputable_ops += [prims.div, prims.convert_element_type, aten.clone, aten._to_copy, aten.full_like, prims.var, prims.sum, aten.var, aten.std, prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors, aten.scalar_tensor, aten.ones, aten.new_zeros, aten.lift_fresh_copy, aten.arange, aten.triu, aten.var_mean, aten.isinf, aten.any, aten.full, aten.as_strided, aten.zeros, aten.argmax, aten.maximum] # noqa: E501 # Natalia said that we should allow recomputing indexing :) default_recomputable_ops += [aten.index] + # add more generally ? + default_recomputable_ops += pointwise_ops() + recomputable_ops = set(recomputable_ops) if recomputable_ops is not None else set(default_recomputable_ops) random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] - compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward] # noqa: E501 + compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d, aten._softmax, aten._softmax_backward_data, aten.native_layer_norm, aten.native_layer_norm_backward, aten.native_batch_norm, aten.native_batch_norm_backward, aten._native_batch_norm_legit] # noqa: E501 unrecomputable_ops = random_ops + compute_intensive_ops @@ -443,7 +483,8 @@ def get_node_weight(node) -> int: # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(lambda n: is_sym_node(n), saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) - fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values, saved_sym_nodes=saved_sym_nodes) + fw_module, bw_module = _extract_fwd_bwd_modules( + joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, num_fwd_outputs=num_fwd_outputs) if AOT_PARTITIONER_DEBUG: print("Theoretical Activations Stored: ", sum([_size_of(i) for i in saved_values]) / 1e9) fw_module_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function']) diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py new file mode 100644 index 0000000000000..1ada5b4e19771 --- /dev/null +++ b/torch/_functorch/pyfunctorch.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +import contextlib +from typing import Any +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + TransformType, + CInterpreter, + CGradInterpreterPtr, + CVmapInterpreterPtr, + pop_dynamic_layer_stack, + push_dynamic_layer_stack, +) + +""" +This file contains the functorch integration with PyDispatcher. + +PyDispatcher does not understand functorch's DynamicLayerStack dispatching +logic because it is entirely implemented in C++ in the fallbacks for two +dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable +to directly reuse C++ boxed fallbacks). + +Instead of trying to hammer PyDispatcher into understanding those fallbacks, +we re-implement the logic of peeking the top of the stack for an interpreter, +selecting the interpreter to dispatch on, etc, in Python. This leads to a +simpler design. + +The main difference between C++ functorch and PyDispatcher's functorch logic +is that: +- C++ functorch needs to manually tweak dispatch keys to ping-pong between + DynamicLayerFrontMode and DynamicLayerBackMode. +- PyDispatcher's functorch logic pops an Interpreter from the top of the stack + and asks it to execute the rule associated with the Interpreter. + +In C++ we do the ping-pong because e.g. vmap rules are associated with the +batched DispatchKey, but in PyDispatcher we are able to avoid this by asking +the user to register a batching rule directly to a transform that an +interpreter then invokes. +""" + + +# FuncTorchInterpreter is the Python version of Interpreter (recall that +# the DynamicLayerStack is a stack of interpreters). +# It is a wrapper around the actual C++ Interpreter object. +# +# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h +class FuncTorchInterpreter(ABC): + def __init__(self, cptr: Any): + self._cptr = cptr + + # Process an operation. eg for vmap, this is invoking a batching rule. + # Conceptually this is analogous to Interpreter::process in C++ + @abstractmethod + def process(self, op, args, kwargs): + pass + + # lower an operation from this Interpreter to the next Interpreter on the stack. + # Concretely, this involves temporarily popping the current Interpreter. + # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++ + def lower(self): + return temporarily_pop_interpreter_stack() + + def level(self): + return self._cptr.level() + + def key(self): + return self._cptr.key() + + +@contextlib.contextmanager +def temporarily_pop_interpreter_stack(): + try: + saved = pop_dynamic_layer_stack() + yield + finally: + push_dynamic_layer_stack(saved) + + +class VmapInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Vmap + # NOTE: [Interpreter cdata vs cptr] + # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr + # so that we can access methods specific to the vmap interpreter + self._cdata = cdata + self._cptr = CVmapInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Vmap] + return kernel(self, *args, **kwargs) + + def batch_size(self): + return self._cptr.batchSize() + + +class GradInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Grad + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + self._cptr = CGradInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs]) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Grad] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # GradInterpreter has custom lower because of the no_grad interaction + # See NOTE [grad and vjp interaction with no_grad] + # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_grad_mode = self.prev_grad_mode() + if not self.prev_grad_mode: + return contextlib.nested(torch.no_grad(), super().lower()) + return super().lower() + + def prev_grad_mode(self): + return self._cptr.prevGradMode() + + +def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter: + key = cinterpreter.key() + if key == TransformType.Grad: + return GradInterpreter(cinterpreter) + if key == TransformType.Vmap: + return VmapInterpreter(cinterpreter) + raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}") + + +def retrieve_current_functorch_interpreter(): + interpreter = torch._C._functorch.peek_interpreter_stack() + assert interpreter is not None + return coerce_cinterpreter(interpreter) + + +def dispatch_functorch(op, args, kwargs): + interpreter = retrieve_current_functorch_interpreter() + return interpreter.process(op, args, kwargs) diff --git a/functorch/_src/python_key.py b/torch/_functorch/python_key.py similarity index 100% rename from functorch/_src/python_key.py rename to torch/_functorch/python_key.py diff --git a/functorch/_src/pytree_hacks.py b/torch/_functorch/pytree_hacks.py similarity index 100% rename from functorch/_src/pytree_hacks.py rename to torch/_functorch/pytree_hacks.py diff --git a/functorch/_src/top_operators_github_usage.py b/torch/_functorch/top_operators_github_usage.py similarity index 100% rename from functorch/_src/top_operators_github_usage.py rename to torch/_functorch/top_operators_github_usage.py diff --git a/torch/_functorch/utils.py b/torch/_functorch/utils.py new file mode 100644 index 0000000000000..2e98c4ba8fd1d --- /dev/null +++ b/torch/_functorch/utils.py @@ -0,0 +1,24 @@ +import contextlib +import torch +from torch._C._functorch import ( + set_autograd_function_allowed, + get_autograd_function_allowed, + unwrap_if_dead, +) + +@contextlib.contextmanager +def enable_autograd_function(): + try: + prev_state = get_autograd_function_allowed() + set_autograd_function_allowed(True) + yield + finally: + set_autograd_function_allowed(prev_state) + +def unwrap_dead_wrappers(args): + # NB: doesn't use tree_map_only for performance reasons + result = tuple( + unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg + for arg in args + ) + return result diff --git a/functorch/_src/vmap.py b/torch/_functorch/vmap.py similarity index 100% rename from functorch/_src/vmap.py rename to torch/_functorch/vmap.py diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c4400a35cce85..2b7481acda41c 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,23 +1,29 @@ import base64 +import dataclasses import functools import getpass import hashlib import logging +import multiprocessing import os import re import shutil +import signal import subprocess +import sys import sysconfig import tempfile import types from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from ctypes import cdll -from time import time -from typing import Any, Dict +from threading import Thread +from time import sleep, time +from typing import Any, Callable, Dict, List import torch -from torch.utils import cpp_extension +from torch.hub import Faketqdm, tqdm +from torch.utils import cpp_extension from . import config, cuda_properties, exc LOCK_TIMEOUT = 600 @@ -46,8 +52,11 @@ def _compile_end(): logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO) +@functools.lru_cache(None) def cache_dir(): - return f"/tmp/torchinductor_{getpass.getuser()}" + return os.environ.get( + "TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}" + ) def get_lock_dir(): @@ -66,12 +75,17 @@ def code_hash(code): ) -def write(source_code, ext, extra=""): +def get_code_path(source_code, ext, extra): basename = code_hash(source_code + extra) subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{ext}") + return basename, subdir, path + + +def write(source_code, ext, extra=""): + basename, subdir, path = get_code_path(source_code, ext, extra) if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) - path = os.path.join(subdir, f"{basename}.{ext}") if not os.path.exists(path): # use a temp file for thread safety fd, tmp_path = tempfile.mkstemp(dir=subdir) @@ -94,6 +108,13 @@ def cpp_compiler_search(search): for cxx in search: try: if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue from filelock import FileLock lock_dir = get_lock_dir() @@ -139,11 +160,205 @@ def is_gcc(): return re.search(r"(gcc|g\+\+)", cpp_compiler()) -def cpp_compile_command(input, output, include_pytorch=False): - if include_pytorch: +class VecISA(object): + _bit_width: int + _macro: str + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +#include +#include +#endif + +__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self): + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float): + return self._dtype_nelements[dtype] + + def build_macro(self): + return self._macro + + def build_arch_flags(self): + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + @functools.lru_cache(None) + def __bool__(self): + key, input_path = write(VecISA._avx_code, "cpp", extra="") + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[:-3] + "so" + build_cmd = cpp_compile_command( + input_path, output_path, warning_all=False, vec_isa=self + ).split(" ") + try: + # Check build result + subprocess.check_output(build_cmd, stderr=subprocess.STDOUT) + subprocess.check_call( + [ + "python", + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + stderr=subprocess.DEVNULL, + ) + except Exception as e: + return False + + return True + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = "CPU_CAPABILITY_AVX512" + _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = "CPU_CAPABILITY_AVX2" + _arch_flags = "-mavx2 -mfma" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = "" + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self): + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAVX512(), VecAVX2()] + + +# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.lru_cache(None) +def valid_vec_isa_list(): + if sys.platform != "linux": + return [] + + isa_list = [] + with open("/proc/cpuinfo") as _cpu_info: + _cpu_info_content = _cpu_info.read() + for isa in supported_vec_isa_list: + if str(isa) in _cpu_info_content and isa: + isa_list.append(isa) + return isa_list + + +def pick_vec_isa(): + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, it indicates determin the vectroization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa + + +def get_shared(shared=True): + return "-shared -fPIC" if shared else "" + + +def get_warning_all_flag(warning_all=True): + return "-Wall" if warning_all else "" + + +def cpp_flags(): + return "-std=c++17 -Wno-unused-variable" + + +def optimization_flags(): + return "-march=native -O3 -ffast-math -fno-finite-math-only -fopenmp" + + +def use_custom_generated_macros(): + return "-D C10_USING_CUSTOM_GENERATED_MACROS" + + +def get_include_and_linking_paths( + include_pytorch=False, vec_isa: VecISA = invalid_vec_isa +): + if sys.platform == "linux" and ( + include_pytorch + or vec_isa != invalid_vec_isa + or config.cpp.enable_kernel_profile + ): + # Note - We include pytorch only on linux right now. There is more work + # to do to enable OMP build on darwin where PyTorch is built with IOMP + # and we need a way to link to what PyTorch links. ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")] lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")] libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"] + macros = vec_isa.build_macro() + if macros: + macros = f"-D{macros}" else: # Note - this is effectively a header only inclusion. Usage of some header files may result in # symbol not found, if those header files require a library. @@ -152,17 +367,34 @@ def cpp_compile_command(input, output, include_pytorch=False): ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")] lpaths = [] libs = ["gomp"] + macros = "" ipaths = " ".join(["-I" + p for p in ipaths]) lpaths = " ".join(["-L" + p for p in lpaths]) libs = " ".join(["-l" + p for p in libs]) + return ipaths, lpaths, libs, macros + + +def cpp_compile_command( + input, + output, + warning_all=True, + shared=True, + include_pytorch=False, + vec_isa: VecISA = invalid_vec_isa, +): + ipaths, lpaths, libs, macros = get_include_and_linking_paths( + include_pytorch, vec_isa + ) + return re.sub( r"[ \n]+", " ", f""" - {cpp_compiler()} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable - {ipaths} {lpaths} {libs} - -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp - -o{output} {input} + {cpp_compiler()} {input} {get_shared(shared)} {get_warning_all_flag(warning_all)} {cpp_flags()} + {ipaths} {lpaths} {libs} {macros} + {optimization_flags()} + {use_custom_generated_macros()} + -o{output} """, ).strip() @@ -171,9 +403,26 @@ class CppCodeCache: cache = dict() clear = staticmethod(cache.clear) + @staticmethod + def _load_library(path): + try: + return cdll.LoadLibrary(path) + except OSError as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + return cdll.LoadLibrary(path) + raise + @classmethod def load(cls, source_code): - key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o")) + picked_vec_isa = pick_vec_isa() + key, input_path = write( + source_code, + "cpp", + extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa), + ) if key not in cls.cache: from filelock import FileLock @@ -183,14 +432,14 @@ def load(cls, source_code): output_path = input_path[:-3] + "so" if not os.path.exists(output_path): cmd = cpp_compile_command( - input=input_path, output=output_path + input=input_path, output=output_path, vec_isa=picked_vec_isa ).split(" ") try: subprocess.check_output(cmd, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: - raise exc.CppCompileError(cmd, e.output) + raise exc.CppCompileError(cmd, e.output) from e - cls.cache[key] = cdll.LoadLibrary(output_path) + cls.cache[key] = cls._load_library(output_path) cls.cache[key].key = key return cls.cache[key] @@ -225,7 +474,7 @@ def patch_triton_dir(): class TritonCodeCache: @staticmethod def get_name(mod): - (name,) = [n for n in dir(mod) if n.startswith("kernel")] + (name,) = [n for n in dir(mod) if n.startswith("triton_")] return name @classmethod @@ -247,17 +496,30 @@ def _load_kernel(source_code): return kernel +def _load_kernel_name(source_code): + return TritonCodeCache.get_name(PyCodeCache.load(source_code)) + + class TritonFuture: def __init__(self, source_code, future): self.source_code = source_code self.future = future + # @dynamo_utils.dynamo_timed def result(self): + t0 = time() if hasattr(self, "kernel"): return self.kernel # If the worker failed this will throw an exception. self.future.result() kernel = self.kernel = _load_kernel(self.source_code) + latency = time() - t0 + if latency > 50: + name = _load_kernel_name(self.source_code) + log.warning( + f"Detected long compilation time of {latency} seconds for kernel name {name}" + ) + log.warning(self.source_code) del self.source_code, self.future return kernel @@ -279,7 +541,37 @@ def process_pool(): # are forked cuda_properties._properties() assert config.compile_threads > 1 - return ProcessPoolExecutor(config.compile_threads) + orig_ppid = os.getpid() + + # if this process dies abnormally (e.g. segfault) + # it will not shut down the workers. Instead + # the workers will have their parent reassigned to the + # init process. This launches a separate thread to + # watch for the worker getting reassigned, + # and cleans it up in this case. + def init(): + def run(): + while True: + sleep(1) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + + # we rely on 'fork' because we cannot control whether users + # have an `if __name__ == '__main__'` in their main process. + fork_context = multiprocessing.get_context("fork") + pool = ProcessPoolExecutor( + config.compile_threads, mp_context=fork_context, initializer=init + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + return pool @classmethod def warm_pool(cls): @@ -305,7 +597,7 @@ def warm_pool(cls): if hasattr(pool, "_start_queue_management_thread"): pool._start_queue_management_thread() else: - for i in range(config.compile_threads): + for _ in range(config.compile_threads): pool._adjust_process_count() pool._start_executor_manager_thread() _compile_end() @@ -346,10 +638,26 @@ def task(): return self.submit(task) def wait(self, scope: Dict[str, Any]): + num_kernels = len( + [ + value + for key, value in scope.items() + if isinstance(value, (Future, TritonFuture)) + ] + ) + pbar = tqdm( + total=num_kernels, + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) if config.compile_threads > 1: - for key, result in list(scope.items()): + for key, result in scope.items(): + if config.verbose_progress and not isinstance(pbar, Faketqdm): + pbar.set_postfix_str(key) if isinstance(result, (Future, TritonFuture)): scope[key] = result.result() + pbar.update(1) _compile_end() diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a949effb26793..7f467cff4d488 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,7 +34,8 @@ class ExprPrinter(Printer): @staticmethod def paren(string): if ( - re.match(r"^[a-z0-9_.]+$", string, re.I) + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.I) or re.match(r"^\([^)]*\)$", string, re.I) or string == "" ): @@ -47,7 +48,12 @@ def _print_Pow(self, expr): base = self._print(base) assert exp.is_integer exp = int(exp) - return "*".join([self.paren(base)] * exp) + if exp > 0: + return "*".join([self.paren(base)] * exp) + elif exp < 0: + return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + return "1" def _print_Mul(self, expr): return "*".join(map(self.paren, map(self._print, expr.args))) @@ -89,7 +95,9 @@ def square(x): @staticmethod def sign(x): - return ops.where(f"{x} == 0", "0", ops.where(f"{x} < 0", "-1", "1")) + left = ops.where(ops.lt("0", x), "1", "0") + right = ops.where(ops.lt(x, "0"), "1", "0") + return ops.sub(left, right) @staticmethod def bitwise_not(x): @@ -283,6 +291,8 @@ def input(self, name): assert name not in V.graph.removed_buffers, name if name in self.output_buffers: return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name if name.startswith("seed"): return self._lookup("seed", self.input_buffers, name) return self._lookup("in_ptr", self.input_buffers, name) @@ -290,6 +300,8 @@ def input(self, name): def output(self, name): name = V.graph.scheduler.mutation_real_name.get(name, name) assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name, output_name): @@ -317,6 +329,12 @@ def call_names(self): self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys() ) + def wrap_ptr_arg(self, buf, dtype): + return f"c_void_p({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"c_long({size})" + def cpp_argdefs(self): from .cpp import DTYPE_TO_CPP, INDEX_TYPE @@ -331,28 +349,36 @@ def cpp_argdefs(self): call_args = [] arg_defs = [] + arg_types = [] for inplaced in unique(self.inplace_buffers.values()): outer = inplaced.other_names[-1] inner = inplaced.inner_name dtype = buffer_types[outer] - arg_defs.append(f"{DTYPE_TO_CPP[dtype]}* __restrict__ {inner}") - call_args.append(f"c_void_p({outer}.data_ptr())") + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* __restrict__ {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") for outer, inner in self.input_buffers.items(): if outer in self.inplace_buffers: continue dtype = buffer_types[outer] - arg_defs.append(f"const {DTYPE_TO_CPP[dtype]}* __restrict__ {inner}") - call_args.append(f"c_void_p({outer}.data_ptr())") + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"const {cpp_dtype}* __restrict__ {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"const {cpp_dtype}*") for outer, inner in self.output_buffers.items(): if outer in self.inplace_buffers or inner == "REMOVED": continue dtype = buffer_types[outer] - arg_defs.append(f"{DTYPE_TO_CPP[dtype]}* __restrict__ {inner}") - call_args.append(f"c_void_p({outer}.data_ptr())") + cpp_dtype = DTYPE_TO_CPP[dtype] + arg_defs.append(f"{cpp_dtype}* __restrict__ {inner}") + call_args.append(self.wrap_ptr_arg(outer, dtype)) + arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): arg_defs.append(f"const {INDEX_TYPE} {inner}") - call_args.append(f"c_long({outer})") - return arg_defs, call_args + call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") + return arg_defs, call_args, arg_types def python_argdefs(self): arg_defs = [] @@ -392,6 +418,45 @@ def aliases(self): if other in self.output_buffers: yield self.output_buffers[other], inplaced.inner_name + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or buffers[name] == "REMOVED" + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + + +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + The backends can inherit from this class and overload the "create_cse_var" Kernel to do that. + The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py.""" + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + + def update_on_args(self, args, kwargs): + pass + + +class CppWrapperKernelArgs(KernelArgs): + def wrap_ptr_arg(self, buf, dtype): + from .cpp import DTYPE_TO_CPP + + return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" + + def wrap_size_arg(self, size): + return f"{size}" + class CSE: """Common subexpression elimination""" @@ -413,6 +478,7 @@ def __init__( self.reduction_cache = reduction_cache or {} self.iter_buffer_ids = iter_buffers or itertools.count() self.invalidated_stores = set() + self.varname_map = {} def invalidate(self, keep_vars: typing.Set[str]): for name, tmp in list(self.store_cache.items()): @@ -430,9 +496,11 @@ def clone(self): self.store_cache, ) - def generate(self, buffer: IndentedBuffer, expr: str, write=True): - assert isinstance(expr, str), expr - if expr.startswith(self.name_prefix) and re.match(r"^[a-z0-9]+$", expr): + def generate( + self, buffer: IndentedBuffer, expr: typing.Union[str, CSEVariable], write=True + ) -> CSEVariable: + assert isinstance(expr, (str, CSEVariable)), type(expr) + if isinstance(expr, CSEVariable): return expr if expr not in self.cache: var = self.newvar() @@ -442,8 +510,11 @@ def generate(self, buffer: IndentedBuffer, expr: str, write=True): buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}") return self.cache[expr] - def newvar(self): - return f"{self.name_prefix}{next(self.iter_buffer_ids)}" + def newvar(self) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name) + self.varname_map[var_name] = var + return var class CodeGen: @@ -524,12 +595,16 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value): def __enter__(self): class CSEProxy: + self.name = "CSEProxy" + @staticmethod def __getattr__(name): def inner(*args, **kwargs): - return self.cse.generate( + csevar = self.cse.generate( self.compute, getattr(parent_handler, name)(*args, **kwargs) ) + csevar.update_on_args(args, kwargs) + return csevar return inner @@ -586,3 +661,6 @@ def rename_indexing(self, index) -> sympy.Expr: x: self.args.size(x) for x in sorted_symbols if x.name.startswith("s") } return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 39dd6519d926c..18c1e3f14fadf 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1,6 +1,9 @@ import contextlib import dataclasses import functools +import math +import sys +from copy import deepcopy from pathlib import Path from typing import Dict, List @@ -9,11 +12,13 @@ import torch from torch._prims_common import is_float_dtype -from .. import codecache, config -from ..utils import sympy_product, sympy_symbol +from .. import codecache, config, ir, metrics +from ..codegen.wrapper import WrapperCodeGen +from ..utils import sympy_product, sympy_subs, sympy_symbol from ..virtualized import ops, V from .common import ( BracesBuffer, + CppWrapperKernelArgs, DeferredIndentedBuffer, ExprPrinter, IndentedBuffer, @@ -34,6 +39,20 @@ torch.bool: "bool", torch.bfloat16: "bfloat16", } + +DTYPE_TO_ATEN = { + torch.float32: "at::ScalarType::Float", + torch.float64: "at::ScalarType::Double", + torch.float16: "at::ScalarType::Half", + torch.int64: "at::ScalarType::Long", + torch.int32: "at::ScalarType::Int", + torch.int16: "at::ScalarType::Short", + torch.int8: "at::ScalarType::Char", + torch.uint8: "at::ScalarType::Byte", + torch.bool: "at::ScalarType::Bool", + torch.bfloat16: "at::ScalarType::BFloat16", +} + INDEX_TYPE = "long" RTYPE_TO_CPP = { @@ -72,6 +91,17 @@ def reduction_combine(reduction_type, var, next_value): return f"{var} = std::{reduction_type}({var}, {next_value})" +def reduction_combine_vec(reduction_type, var, next_value): + if reduction_type == "max": + return f"{var} = at::vec::maximum({var}, {next_value})" + elif reduction_type == "min": + return f"{var} = at::vec::minimum({var}, {next_value})" + elif reduction_type == "sum": + return f"{var} += {next_value}" + else: + raise NotImplementedError() + + index_value_name_counter = 1 @@ -120,6 +150,13 @@ def float16_reduction_prefix(rtype): return prefix +def parallel_num_threads(): + threads = config.cpp.threads + if threads < 1: + threads = torch.get_num_threads() + return threads + + @functools.lru_cache() def cpp_prefix(): path = Path(__file__).parent / "cpp_prefix.h" @@ -151,6 +188,181 @@ def _print_IndexingDiv(self, expr): cexpr = CppPrinter().doprint +class CppVecOverrides(OpOverrides): + """Map element-wise ops to aten vectorization C++""" + + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def div(a, b): + return f"{a} / {b}" + + @staticmethod + def abs(x): + return f"{x}.abs()" + + @staticmethod + def sin(x): + return f"{x}.sin()" + + @staticmethod + def cos(x): + return f"{x}.cos()" + + @staticmethod + def exp(x): + return f"{x}.exp()" + + @staticmethod + def erf(x): + return f"{x}.erf()" + + @staticmethod + def sqrt(x): + return f"{x}.sqrt()" + + @staticmethod + def rsqrt(x): + return f"{x}.rsqrt()" + + @staticmethod + def pow(a, b): + return f"{a}.pow({b})" + + @staticmethod + def log(x): + return f"{x}.log()" + + @staticmethod + def round(x): + return f"{x}.round()" + + @staticmethod + def floor(x): + return f"{x}.floor()" + + @staticmethod + def ceil(x): + return f"{x}.ceil()" + + @staticmethod + def trunc(x): + return f"{x}.trunc()" + + @staticmethod + def fmod(a, b): + return f"{a}.fmod({b})" + + @staticmethod + def lgamma(x): + return f"{x}.lgamma()" + + @staticmethod + def logical_and(a, b): + return f"{a} && {b}" + + @staticmethod + def logical_or(a, b): + return f"{a} || {b}" + + @staticmethod + def tanh(a): + return f"{a}.tanh()" + + @staticmethod + def reciprocal(a): + return f"{a}.reciprocal()" + + @staticmethod + def constant(val, dtype): + if val == float("inf"): + quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif val == float("-inf"): + quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif math.isnan(val): + quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" + elif val is True or val is False: + quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})" + else: + quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})" + return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})" + + @staticmethod + def relu(x): + return f"at::vec::clamp_min({x}, decltype({x})(0))" + + @staticmethod + def sigmoid(x): + return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())" + + @staticmethod + def neg(x): + return f"{x}.neg()" + + @staticmethod + def floordiv(a, b): + # a and b are integer type + _t = f"decltype({a})" + quot = f"{a} / {b}" + rem = f"{a} % {b}" + return f"(({a} < {_t}(0)) != ({b} < {_t}(0)) ? ({rem} != {_t}(0) ? {quot} - {_t}(1) : {quot}) : {quot})" + + @staticmethod + def truncdiv(a, b): + # a and b are integer type + return f"{a} / {b}" + + @staticmethod + def minimum(a, b): + return f"at::vec::minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"at::vec::maximum({a}, {b})" + + @staticmethod + def square(a): + return f"{a}.pow(2)" + + @staticmethod + def where(a, b, c): + return f"decltype({b})::blendv({c}, {b}, {a})" + + @staticmethod + def sign(x): + code = BracesBuffer() + # auto tmp5 = tmp4 < 0 ? -1 : 1; + vec_zero = f"decltype({x})(0)" + vec_one = f"decltype({x})(1)" + blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})" + left = V.kernel.cse.newvar() + code.writeline(f"auto {left} = {blendv};") + + # auto tmp6 = tmp4 == 0 ? 0 : tmp5; + blendv = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})" + right = V.kernel.cse.newvar() + code.writeline(f"auto {right} = {blendv};") + result = V.kernel.cse.newvar() + code.writeline(f"auto {result} = {left} - {right};") + V.kernel.compute.splice(code) + return result + + @staticmethod + def to_dtype(x, dtype): + assert dtype in [torch.bool], f"{__name__} does not support {dtype}" + return f"({x})" + + class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -176,6 +388,10 @@ def exp(x): # return f"Sleef_expf_u10({x})" return f"std::exp({x})" + @staticmethod + def erf(x): + return f"std::erf({x})" + @staticmethod def sqrt(x): return f"std::sqrt({x})" @@ -184,6 +400,14 @@ def sqrt(x): def rsqrt(x): return f"1 / std::sqrt({x})" + @staticmethod + def log1p(x): + return f"std::log1p({x})" + + @staticmethod + def expm1(x): + return f"std::expm1({x})" + @staticmethod def signbit(x): return f"std::signbit({x})" @@ -246,11 +470,11 @@ def relu(x): @staticmethod def minimum(a, b): - return f"std::min({a}, {b})" + return f"({b} != {b}) ? {b} : std::min({a}, {b})" @staticmethod def maximum(a, b): - return f"std::max({a}, {b})" + return f"({b} != {b}) ? {b} : std::max({a}, {b})" @staticmethod def where(a, b, c): @@ -266,6 +490,8 @@ def constant(val, dtype): return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif val == float("-inf"): return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif math.isnan(val): + return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" elif val is True or val is False: return ops.to_dtype(str(val).lower(), dtype) return ops.to_dtype(repr(val), dtype) @@ -282,6 +508,8 @@ def masked(mask, body, other): code.writeline(f"float {var} = -std::numeric_limits::infinity();") elif other == float("inf"): code.writeline(f"float {var} = std::numeric_limits::infinity();") + elif isinstance(other, float): + code.writeline(f"float {var} = {other};") else: code.writeline(f"auto {var} = {other!r};") code.writeline(f"if({mask})") @@ -312,6 +540,19 @@ def sigmoid(x): x = ops.exp(f"-{x}") return f"1 / (1 + {x})" + @staticmethod + def sign(x): + code = BracesBuffer() + # auto tmp5 = tmp4 < 0 ? -1 : 1; + left = V.kernel.cse.newvar() + right = V.kernel.cse.newvar() + result = V.kernel.cse.newvar() + code.writeline(f"auto {left} = {x} > 0 ? 1 : 0;") + code.writeline(f"auto {right} = {x} < 0 ? 1 : 0;") + code.writeline(f"auto {result} = {left} - {right};") + V.kernel.compute.splice(code) + return result + class CppKernel(Kernel): overrides = CppOverrides @@ -413,9 +654,7 @@ def size_hint(self): return V.graph.sizevars.size_hint(sympy_product(self.call_ranges)) def codegen_loops(self, code, worksharing): - threads = config.cpp.threads - if threads < 1: - threads = torch.get_num_threads() + threads = parallel_num_threads() loops = [LoopLevel(var, size) for var, size in zip(self.itervars, self.ranges)] loops, reductions = LoopNest(loops[: self.reduction_depth]), LoopNest( @@ -423,11 +662,11 @@ def codegen_loops(self, code, worksharing): ) reductions.mark_reduction(self.reduction_vars) - if config.cpp.simdlen: + if codecache.pick_vec_isa(): # TODO(jansel): detect stride-1 dimension and vectorize that if reductions: reductions.loops[-1].simd = True - else: + elif loops: loops.loops[-1].simd = True par_depth = 0 @@ -509,14 +748,506 @@ def write_to_suffix(self): (self.loads, self.compute, self.stores, self.cse) = prior +class CppVecKernel(CppKernel): + overrides = CppVecOverrides + + def __init__(self, args, num_threads): + super(CppVecKernel, self).__init__(args, num_threads) + assert codecache.pick_vec_isa() + self.simd_nelements = codecache.pick_vec_isa().nelements() + self.reduction_omp_dec: Dict[str, str] = {} + metrics.generated_cpp_vec_kernel_count += 1 + + def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr): + replacement = {var: var + 1} + new_index = sympy_subs(index, replacement) + delta = sympy.simplify(new_index - index) + return delta == 1 + + def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr): + expanded_index = sympy.expand(index) + return not expanded_index.has(var) + + def transform_index(self, index: sympy.Expr): + expanded_index = sympy.expand(index) + assert self.simd_nelements + assert self.simd_nelements >= 1 + most_inner_var = self.itervars[-1] + replacement = {most_inner_var: most_inner_var * self.simd_nelements} + new_index = sympy_subs(expanded_index, replacement) + return new_index + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + index = self.rename_indexing(index) + + expanded_index = sympy.expand(index) + new_index = self.transform_index(index) + + if expanded_index == new_index: + line = f"at::vec::Vectorized({var}[{cexpr(index)}])" + else: + if V.graph.get_dtype(name) in [torch.bool, torch.uint8]: + g_tmp_buf = f"g_tmp_buffer_{var}" + nelements = codecache.pick_vec_isa().nelements() + self.loads.writeline(f"float {g_tmp_buf}[{nelements}] = {{0}};") + self.loads.writeline( + f"flag_to_float({var} + {cexpr(new_index)}, {g_tmp_buf}, {nelements});" + ) + line = f"at::vec::Vectorized::loadu({g_tmp_buf})" + else: + line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" + + return self.cse.generate(self.loads, line) + + def store(self, name, index, value, mode=None): + assert "buf" in name + var = self.args.output(name) + index = self.rename_indexing(index) + assert mode is None + + expanded_index = sympy.expand(index) + new_index = self.transform_index(index) + assert new_index != expanded_index + line = f"{value}.store({var} + {cexpr(new_index)});" + self.stores.writeline(name, line) + + def reduction(self, name, dtype, src_dtype, reduction_type, index, value): + assert reduction_type in {"max", "min", "sum"} + assert dtype == torch.float + assert src_dtype == torch.float + reduce_map = {"max": "maximum", "min": "minimum"} + + vec_ns = "at::vec" + vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>" + + if reduction_type not in self.reduction_omp_dec: + vec_reduc_prefix = "#pragma omp declare reduction(" + vec_reduc_prefix += f"{RTYPE_TO_CPP[reduction_type]}:{vec}:" + if reduction_type == "sum": + vec_reduc_prefix += "omp_out += omp_in" + else: + vec_reduc_prefix += ( + f"omp_out = {vec_ns}::{reduce_map[reduction_type]}(omp_out, omp_in)" + ) + vec_reduc_prefix += ")" + vec_reduc_prefix += " initializer(" + vec_reduc_prefix += "omp_priv={{" + vec_reduc_prefix += f"{reduction_init(reduction_type, dtype)}" + vec_reduc_prefix += "}})" + self.reduction_omp_dec[reduction_type] = RTYPE_TO_CPP[reduction_type] + self.reduction_prefix.writeline(vec_reduc_prefix) + + tmpvar = self.cse.generate( + self.loads, f"reduction {name} {cexpr(index)}", write=False + ) + tmpvar_vec = f"{tmpvar}_vec" + + index = self.rename_indexing(index) + self.reduction_vars[tmpvar] = reduction_type + self.reduction_prefix.writeline( + f"{DTYPE_TO_CPP[dtype]} {tmpvar} = {reduction_init(reduction_type, dtype)};" + ) + self.reduction_prefix.writeline( + f"auto {tmpvar_vec} = at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({tmpvar});" + ) + self.stores.writeline( + None, f"{reduction_combine_vec(reduction_type, tmpvar_vec, value)};" + ) + + reduce_all_body = "{" + if reduction_type == "sum": + reduce_all_body += "return x + y;" + else: + reduce_all_body += f"return {vec_ns}::{reduce_map[reduction_type]}(x, y);" + reduce_all_body += "}" + vec_reduce_all_func = f"{vec_ns}::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>" + self.reduction_suffix.writeline( + name, + f"{tmpvar} = {vec_reduce_all_func}([]({vec}& x, {vec}&y) {reduce_all_body}, {tmpvar_vec});", + ) + self.cse.store_cache[name] = tmpvar + + +class CppVecKernelChecker(CppVecKernel): + def __init__(self, args, num_threads): + super(CppVecKernelChecker, self).__init__(args, num_threads) + + # Since this kernel is only for checker but does not genreate any + # code, so we need to decrease the kernel count. + metrics.generated_kernel_count -= 1 + metrics.generated_cpp_vec_kernel_count -= 1 + + # Used to recorde the graph wrapper code as the wrapper_code status could be + # changed during graph run. + self._orig_wrapper_code = None + + self.simd_vec = True + self.fast_vec_list = [] + for k, v in CppVecOverrides.__dict__.items(): + if isinstance(v, staticmethod): + self.fast_vec_list.append(k) + self.exit_stack = contextlib.ExitStack() + + def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr): + return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index) + + def could_vec(self, name: str, index: sympy.Expr): + assert self.itervars is not None + # Not a loop + if len(self.itervars) == 0: + return False + + most_inner_var = self.itervars[-1] + return self.is_legal_data_access(most_inner_var, index) + + def load(self, name: str, index: sympy.Expr): + if not V.graph.get_dtype(name) in [ + torch.float, + torch.float32, + torch.bool, + torch.uint8, + ]: + self.simd_vec = False + return self.simd_vec + + index = self.rename_indexing(index) + self.simd_vec = self.simd_vec and self.could_vec(name, index) + return self.simd_vec + + def store(self, name, index, value, mode=None): + if not V.graph.get_dtype(name) in [torch.float, torch.float32]: + self.simd_vec = False + return self.simd_vec + + assert "buf" in name + index = self.rename_indexing(index) + + if mode: + self.simd_vec = False + return False + + self.simd_vec = self.simd_vec and self.could_vec(name, index) + return self.simd_vec + + def reduction(self, name, dtype, src_dtype, reduction_type, index, value): + if ( + dtype == torch.float + and src_dtype == torch.float + and reduction_type in ["max", "min", "sum"] + ): + pass + else: + self.simd_vec = False + return self.simd_vec + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._orig_wrapper_code is not None + # Restore the wrapper_code + V.graph.wrapper_code = self._orig_wrapper_code + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def __enter__(self): + # Recorde the graph wrapper code. The wrapper_code status could be + # changed during graph run. Regarding this checker, we also need to + # run the graph but we don't expect to change any status that would + # impact the code generation. Hence, we record the graph wapper code + # and replace it with a dummy warpper_code and then restore to the + # original one as long as the checker is finished. + self._orig_wrapper_code = V.graph.wrapper_code + V.graph.wrapper_code = WrapperCodeGen() + + class VecCheckerProxy: + @staticmethod + def __getattr__(name): + def inner(*args, **kwargs): + if not (name in self.fast_vec_list): + self.simd_vec = False + return self.simd_vec + + return inner + + @staticmethod + def load(name: str, index: sympy.Expr): + return self.load(name, index) + + @staticmethod + def store(name, index, value, mode=None): + return self.store(name, index, value, mode=mode) + + @staticmethod + def reduction(name, dtype, src_dtype, reduction_type, index, value): + return self.reduction( + name, dtype, src_dtype, reduction_type, index, value + ) + + @staticmethod + def constant(val, dtype): + supported_dtype = (torch.float32, torch.int32) + is_supported_dtype = dtype in (supported_dtype) + if not is_supported_dtype: + self.simd_vec = False + return is_supported_dtype + + @staticmethod + def index_expr(expr, dtype): + self.simd_vec = False + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def indirect_indexing(index_var): + self.simd_vec = False + return sympy.Symbol(str(index_var)) + + @staticmethod + def masked(mask, body, other): + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def to_dtype(x, dtype): + if dtype != torch.bool: + self.simd_vec = False + return x + + self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + return self + + +class CppKernelProxy(CppKernel): + def __init__(self, args=None, num_threads=None): + super(CppKernelProxy, self).__init__(args, num_threads) + self.simd_vec_kernel: CppVecKernel = None + self.simd_omp_kernel: CppKernel = None + self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + + def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): + assert self.picked_vec_isa + nelements = self.picked_vec_isa.nelements(dtype) + loop_nest.split_most_inner_loop(nelements) + loop_with_tail = loop_nest.loops[-1] + assert isinstance(loop_with_tail, LoopLevelWithTail) + + loop_with_tail.main_loop.simd_vec = True + + loop_with_tail.tail_loop.simd_omp = True + # We chope the loop into two cubes by the nelements - main loop and tail loop. + # Regarding the main loop, it is straightforward that it could be vectorized with + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized + # as 4(128bits). + loop_with_tail.tail_loop.simd_nelements = int(nelements / 2) + loop_with_tail.tail_loop.simd_vec = False + + loop_with_tail.main_loop_body = self.simd_vec_kernel + loop_with_tail.tail_loop_body = self.simd_omp_kernel + return loop_nest + + def codegen_loops(self, code, worksharing): + threads = parallel_num_threads() + + if self.simd_vec_kernel is None or not self.picked_vec_isa: + assert self.simd_omp_kernel + return self.simd_omp_kernel.codegen_loops(code, worksharing) + + assert self.simd_vec_kernel.itervars == self.simd_omp_kernel.itervars + assert self.simd_vec_kernel.ranges == self.simd_omp_kernel.ranges + assert ( + self.simd_vec_kernel.reduction_vars == self.simd_omp_kernel.reduction_vars + ) + + itervars = self.simd_vec_kernel.itervars + rangs = self.simd_vec_kernel.ranges + loops = [LoopLevel(var, size) for var, size in zip(itervars, rangs)] + assert ( + self.simd_vec_kernel.reduction_depth == self.simd_omp_kernel.reduction_depth + ) + reduction_depth = self.simd_vec_kernel.reduction_depth + loops_nest_non_reduce, loops_nest_reduce = LoopNest( + loops[:reduction_depth] + ), LoopNest(loops[reduction_depth:]) + loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars) + + assert self.picked_vec_isa + # Do not apply vectorization since the range of most inner is too small. Meanwhile, + # If the range of the most inner is less then the codecache.pick_vec_isa().nelements(), + # the generated code for some reduction will be as follows that leads to incrrect result. + # + # LINE01: float tmp1 = 0; + # LINE02: auto tmp1_vec = at::vec::Vectorized(tmp1); + # LINE03: for(long i1=0; i1<2; i1+=1) + # LINE04: { + # LINE05: for(long i2=0; i2<0; i2+=1) + # LINE06: { + # LINE07: auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1)); + # LINE08: tmp1_vec += tmp0; + # LINE09: } + # LINE10: tmp1 = vec_reduce_all([](Vectorized& x, Vectorized&y) {return x + y;}, tmp1_vec); + # LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1) + # LINE12: for(long i2=0; i2<8; i2+=1) + # LINE13: { + # LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)]; + # LINE15: tmp1 += tmp0; + # LINE16: } + # LINE17: } + # LINE18: out_ptr3[i0] = tmp1; + # + # tmp1_vec(LINE02) will always be zero as it is initialized with tmp1 value and the range(LINE05) + # is 0. Hence, the LINE10 will always reset tmp1 to 0. But tmp1(LINE01) is global value. So the result + # will be incorrect. We skip thie case. + most_inner_loop = ( + loops_nest_reduce.loops[-1] + if loops_nest_reduce + else loops_nest_non_reduce.loops[-1] + ) + main_loop_range = ir.IndexingDiv( + most_inner_loop.size, self.picked_vec_isa.nelements() + ) + loop_interval = sympy.simplify(main_loop_range) + # TODO(Eikan): To support dynamic shape. + if not loop_interval.is_integer or loop_interval <= 0: + metrics.generated_cpp_vec_kernel_count -= 1 + return self.simd_omp_kernel.codegen_loops(code, worksharing) + + # TODO(jansel): detect stride-1 dimension and vectorize that + if loops_nest_reduce: + loops_nest_reduce.loops[-1].simd = True + elif loops_nest_non_reduce: + loops_nest_non_reduce.loops[-1].simd = True + + par_depth = 0 + reduction_par_depth = 0 + if loops_nest_non_reduce: + par_depth = self.simd_vec_kernel.decide_parallel_depth( + self.simd_vec_kernel.call_ranges[:reduction_depth], threads + ) + else: + reduction_par_depth = self.simd_vec_kernel.decide_parallel_depth( + self.simd_vec_kernel.call_ranges[reduction_depth:], threads + ) + + # If the most inner loop of the reduction will be vectorized, the vectorization + # will add a vec variable for reduction. Take the code snippet as an example: + # float tmp1 = 0; + # for(long i1=0; i1<8; i1+=1) { + # auto tmp0 = in_ptr0[i1]; + # tmp1 += tmp0; + # } + # The vectorization will add tmp1_vec for reduction and then the loop will be transformed + # as follows. + # float tmp1 = 0; + # auto tmp1_vec = at::vec::Vectorized(tmp1); + # for(long i1=0; i1<1; i1+=1) { + # auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i1)); + # tmp1_vec += tmp0; + # } + # tmp1 = at::vec::vec_reduce_all([] + # (at::vec::Vectorized& x, at::vec::Vectorized&y) {return x + y;}, + # tmp1_vec); + # for(long i1=8; i1<8; i1+=1) { + # auto tmp0 = in_ptr0[i1]; + # tmp1 += tmp0; + # } + # It means that the vectorization introduce another reduction variable(tmp1_vec). + # If the most inner loop of the reduction is not a parallelized but its parent reduction + # loop is parallized, the new added reduction variable(tmp1_vec) could not be added + # to the parallelized loop reduction. So we skip this case and does not vectorize it. + if reduction_par_depth > 0 and reduction_par_depth != len( + loops_nest_reduce.loops + ): + metrics.generated_cpp_vec_kernel_count -= 1 + return self.simd_omp_kernel.codegen_loops(code, worksharing) + + with contextlib.ExitStack() as stack: + if par_depth: + worksharing.parallel(threads) + loops_nest_non_reduce.mark_parallel(par_depth) + elif reduction_par_depth: + # need to close the worksharing scope to define reduction vars outside it + worksharing.close() + loops_nest_reduce.mark_parallel(reduction_par_depth) + elif threads > 1: + if worksharing.single(): + stack.enter_context(code.indent()) + + non_reduce_loops = loops_nest_non_reduce.loops + reduce_loops = loops_nest_reduce.loops + loop_with_tail: LoopLevelWithTail = None + + if loops_nest_reduce: + self.vectorize_most_inner_loop(loops_nest_reduce) + loop_with_tail = loops_nest_reduce.loops[-1] + # The most inner loop will be vectorized + reduce_loops = reduce_loops[0:-1] + else: + self.vectorize_most_inner_loop(loops_nest_non_reduce) + loop_with_tail = loops_nest_non_reduce.loops[-1] + # The most inner loop will be vectorized + non_reduce_loops = non_reduce_loops[0:-1] + + # The reductions loops are always the loop body of non-reduction loops + for loop in non_reduce_loops: + code.writelines(loop.lines()) + stack.enter_context(code.indent()) + + with contextlib.ExitStack() as stack_outer: + if self.simd_vec_kernel.reduction_prefix: + stack_outer.enter_context(code.indent()) + code.splice(self.simd_vec_kernel.reduction_prefix) + + if reduction_par_depth: + worksharing.parallel(threads) + + with contextlib.ExitStack() as stack: + for loop in reduce_loops: + code.writelines(loop.lines()) + stack.enter_context(code.indent()) + + def gen_vectorized_loop(loop, kernel, write_reduction_suffix=False): + code.writelines(loop.lines()) + with contextlib.ExitStack() as stack: + stack.enter_context(code.indent()) + code.splice(kernel.loads) + code.splice(kernel.compute) + code.splice(kernel.stores) + if write_reduction_suffix: + code.splice(kernel.reduction_suffix) + + # Regarding the vectorized reduction loop, we need to call reduce_all to to reduce + # the vectorize as a single scalar. Hence, we set write_reduction_suffix to True to + # gen the code. + gen_vectorized_loop( + loop_with_tail.main_loop, loop_with_tail.main_loop_body, True + ) + + gen_vectorized_loop( + loop_with_tail.tail_loop, loop_with_tail.tail_loop_body, False + ) + + if reduction_par_depth: + worksharing.close() + + code.splice(loop_with_tail.tail_loop_body.reduction_suffix) + + class CppScheduling: def __init__(self, scheduler): self.scheduler = scheduler - self.kernel_group = KernelGroup() + self.get_kernel_group() def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) + def get_kernel_group(self): + from .wrapper import CppWrapperCodeGen + + if isinstance(V.graph.wrapper_code, CppWrapperCodeGen): + self.kernel_group = CppWrapperKernelGroup() + else: + self.kernel_group = KernelGroup() + @staticmethod def can_fuse_horizontal(node1, node2): _, (vars1, reduce1) = node1.group @@ -532,42 +1263,116 @@ def can_fuse_horizontal(node1, node2): def can_fuse_vertical(cls, node1, node2): return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction() - def codegen_nodes(self, nodes): - """ - Turn an set of pre-fused nodes into a C++ kernel. - """ - kernel_group = self.kernel_group - scheduler = self.scheduler + def can_vec(self, nodes): + if not codecache.pick_vec_isa(): + return False + _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group - in_suffix = False - - with kernel_group.new_kernel() as kernel: - vars, reduction_vars = kernel.set_ranges(group, reduction_group) + with CppVecKernelChecker( + deepcopy(self.kernel_group.args), parallel_num_threads() + ) as kernel_checker: + vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group) for node in nodes: if node.group[1] in [ (group, reduction_group), (group + reduction_group, ()), ]: - assert not in_suffix node.run(vars, reduction_vars) else: - in_suffix = True assert node.group[1] == ( group, (), ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" - # we can fuse in some extra pointwise into the suffix - with kernel.write_to_suffix(): - node.run(vars, ()) + node.run(vars, ()) + + return kernel_checker.simd_vec + + def _codegen_nodes_impl(self, nodes, is_simd_vec=False): + """ + Turn an set of pre-fused nodes into a C++ kernel. + """ + kernel_group = self.kernel_group + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + + def create_kernel(_is_simd_vec): + in_suffix = False + + with kernel_group.new_kernel(_is_simd_vec) as kernel: + vars, reduction_vars = kernel.set_ranges(group, reduction_group) + + for node in nodes: + if node.group[1] in [ + (group, reduction_group), + (group + reduction_group, ()), + ]: + assert not in_suffix + node.run(vars, reduction_vars) + else: + in_suffix = True + assert node.group[1] == ( + group, + (), + ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}" + # we can fuse in some extra pointwise into the suffix + with kernel.write_to_suffix(): + node.run(vars, ()) + return kernel + + org_inplace_buffers_flag = config.inplace_buffers + if is_simd_vec: + # Create vectorization kernel + cpp_vec_kernel = create_kernel(True) + + # Since a kernel is divided into two parts - vectorization and non-vectorization. + # And the two parts share the same global contexts like V.graph.wrapper_code, + # V.kernel.args. But the vectorization kernel generation has updated these global + # contexts. Hence, the non-vectorization kernel should not do this again to avoid + # conext conflict. By now, we only control the config.inplace_buffers. In the future, + # we could maintain more contexts. + config.inplace_buffers = False + + # Create non-vectorization kernel + cpp_kernel = create_kernel(False) + + # Restore the inplace_buffers flag + config.inplace_buffers = org_inplace_buffers_flag + return (cpp_vec_kernel, cpp_kernel) + else: + return (None, create_kernel(False)) + + def codegen_nodes(self, nodes): + """ + Turn an set of pre-fused nodes into a C++ kernel. + """ + kernel_group = self.kernel_group - kernel_group.finalize_kernel(kernel, scheduler) + can_be_simd_vec = self.can_vec(nodes) + simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl( + nodes, can_be_simd_vec + ) + + assert simd_omp_kernel + metrics.generated_kernel_count -= 1 + # Maitain the metrics kernel count + if simd_vec_kernel: + metrics.generated_kernel_count -= 1 + + cpp_kernel_proxy = CppKernelProxy( + kernel_group.args, kernel_group.ws.num_threads + ) + cpp_kernel_proxy.simd_vec_kernel = simd_vec_kernel + cpp_kernel_proxy.simd_omp_kernel = simd_omp_kernel + + kernel_group.finalize_kernel(cpp_kernel_proxy, None) def flush(self): self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = KernelGroup() + self.get_kernel_group() class KernelGroup: @@ -580,8 +1385,11 @@ def __init__(self): self.stack.enter_context(self.ws) self.count = 0 - def new_kernel(self): - return CppKernel(self.args, self.ws.num_threads) + def new_kernel(self, simd_vec=False): + if simd_vec: + return CppVecKernel(self.args, parallel_num_threads()) + else: + return CppKernel(self.args, parallel_num_threads()) def finalize_kernel(self, new_kernel, scheduler): self.count += 1 @@ -594,11 +1402,27 @@ def codegen_define_and_call(self, wrapper): if self.count == 0: return - arg_defs, call_args = self.args.cpp_argdefs() + kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix() + arg_defs, call_args, arg_types = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) + arg_types = ",".join(arg_types) code = BracesBuffer() + # TODO: support kernel profile on other platforms + enable_kernel_profile = ( + config.cpp.enable_kernel_profile and sys.platform == "linux" + ) + if enable_kernel_profile: + code.writelines(["#include "]) code.writelines([cpp_prefix(), "" f'extern "C" void kernel({arg_defs})']) with code.indent(): + if enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + code.writelines( + [ + f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef({{}}));' + ] + ) for old, new in self.args.aliases(): code.writeline(f"auto {old} = {new};") code.splice(self.loops_code) @@ -608,17 +1432,20 @@ def codegen_define_and_call(self, wrapper): codecache_def.splice(code) codecache_def.writeline("''')") - kernel_name = wrapper.next_kernel_name() codecache_str = codecache_def.getvalue() # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. codecache_str = codecache_str.replace("#pragma CMT", "//") wrapper.define_kernel(kernel_name, codecache_str) - + wrapper.load_kernel(kernel_name, code, arg_types) # generate the code to call this - wrapper.writeline( - "{}({})".format(kernel_name, ", ".join(call_args)), - ) + wrapper.generate_kernel_call(kernel_name, call_args) + + +class CppWrapperKernelGroup(KernelGroup): + def __init__(self): + super().__init__() + self.args = CppWrapperKernelArgs() class WorkSharing: @@ -660,41 +1487,65 @@ def __exit__(self, exc_type, exc_val, exc_tb): @dataclasses.dataclass class LoopLevel: - var: sympy.Expr - size: sympy.Expr + var: sympy.Expr = None + size: sympy.Expr = None + offset: sympy.Expr = sympy.Integer(0) + steps: sympy.Expr = sympy.Integer(1) parallel: int = 0 - simd: bool = False + simd_omp: bool = False + picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + simd_vec: bool = False collapsed: bool = False reduction_vars: Dict[str, str] = None def lines(self): if self.reduction_vars: + suffix = "_vec" if self.simd_vec else "" reduction = " " + " ".join( - f"reduction({RTYPE_TO_CPP[rtype]}:{var})" + f"reduction({RTYPE_TO_CPP[rtype]}:{var}{suffix})" for var, rtype in self.reduction_vars.items() ) else: reduction = "" - simd = f"simd simdlen({config.cpp.simdlen})" + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) if self.parallel: # TODO(jansel): look into chunk size and other schedules line1 = f"#pragma omp for{reduction} " if self.parallel > 1: line1 += f" collapse({self.parallel})" - if self.simd: + if self.simd_omp: line1 = line1.replace(" for ", f" for {simd}") - elif self.simd: + elif self.simd_vec: + line1 = "" + elif self.simd_omp: line1 = f"#pragma omp {simd}{reduction}" elif not self.reduction_vars and codecache.is_gcc(): line1 = "#pragma GCC ivdep" else: line1 = "" - line2 = f"for({INDEX_TYPE} {self.var}=0; {self.var}<{cexpr(self.size)}; ++{self.var})" + line2 = f"for({INDEX_TYPE} {self.var}={cexpr(self.offset)}; {self.var}<{cexpr(self.size)}; {self.var}+={cexpr(self.steps)})" if self.collapsed or not line1: return [line2] return [line1, line2] +class LoopLevelWithTail(LoopLevel): + def __init__(self, main_loop: LoopLevel, tail_loop: LoopLevel): + super().__init__() + self.main_loop = main_loop + self.tail_loop = tail_loop + self.main_loop_body = None + self.tail_loop_body = None + + def lines(self): + raise AssertionError("Not Implemented") + + @dataclasses.dataclass class LoopNest: loops: List[LoopLevel] @@ -711,7 +1562,38 @@ def mark_parallel(self, par_depth): loops[0].parallel = par_depth for i in range(1, par_depth): loops[i].collapsed = True - loops[0].simd = loops[par_depth - 1].simd + + def split_most_inner_loop(self, factor): + sympy_factor = sympy.Integer(factor) + + most_inner_loop = self.loops[-1] + + # If the most inner loop needs to be collapsed, we need to + # exclude it since we need to split it into two loops. Meanwhile, + # we still mark it as parallized. + if most_inner_loop.collapsed: + assert self.loops[0].parallel == len(self.loops) + self.loops[0].parallel -= 1 + + main_loop_range = ir.IndexingDiv(most_inner_loop.size, sympy_factor) + + main_loop = LoopLevel(most_inner_loop.var, main_loop_range) + main_loop.parallel = most_inner_loop.parallel + main_loop.collapsed = False + main_loop.reduction_vars = most_inner_loop.reduction_vars + + offset = main_loop_range * sympy_factor + tail_loop = LoopLevel(most_inner_loop.var, most_inner_loop.size) + tail_loop.offset = offset + tail_loop.parallel = most_inner_loop.parallel + tail_loop.collapsed = False + tail_loop.reduction_vars = most_inner_loop.reduction_vars + + loop_with_tail = LoopLevelWithTail(main_loop, tail_loop) + loop_with_tail.parallel = 0 + loop_with_tail.collapsed = False + + self.loops[-1] = loop_with_tail def codegen(self, code, stack): for loop in self.loops: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index d9c0a99f5f42c..c1c9c3bae112d 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -5,9 +5,13 @@ #include #include -#include "ATen/core/PhiloxRNGEngine.h" -#include +#include +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +#include +#include +#endif #include +#include typedef at::Half half; typedef at::BFloat16 bfloat16; @@ -53,3 +57,15 @@ template void atomic_add(volatile T *addr, T offset) { } while (!atomic_addr->compare_exchange_weak(expected, desired, std::memory_order_relaxed)); } + +// This function is used to convert bool or uint8 to float mask for +// vectorization. The caller needs to make sure the src represents TRUE/FALSE +// correctly. +template +void flag_to_float(const T* src, float* dst, int64_t n) { +#pragma unroll + for (int64_t i = 0; i < n; i++) { + uint32_t* dst_u32 = (uint32_t*)dst; + dst_u32[i] = *(src + i) ? 0xFFFFFFFF : 0; + } +} diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 832a0e6c82b4c..6c5130b829566 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -12,11 +12,12 @@ import torch +from ..._dynamo import config as dynamo_config from .. import config, ir, scheduler from ..ir import ReductionHint from ..utils import ( - dynamo_logging, free_symbol_startswith, + get_fused_kernel_name, instance_descriptor, sympy_product, sympy_subs, @@ -24,6 +25,7 @@ ) from ..virtualized import ops, V from .common import ( + CSEVariable, DeferredLine, ExprPrinter, IndentedBuffer, @@ -109,6 +111,17 @@ def triton_constant(value): return repr(value) +class TritonCSEVariable(CSEVariable): + def __init__(self, name): + super().__init__(name) + self.is_scalar = False + + def update_on_args(self, args, kwargs): + self.is_scalar = all( + not (isinstance(arg, TritonCSEVariable)) or arg.is_scalar for arg in args + ) + + class TritonOverrides(OpOverrides): """Map element-wise ops to Triton""" @@ -152,26 +165,14 @@ def relu(x): @staticmethod def minimum(a, b): - return f"tl.minimum({a}, {b})" + return f"tl.where({a} != {a}, {a}, tl.where({a} < {b}, {a}, {b}))" @staticmethod def maximum(a, b): - return f"tl.maximum({a}, {b})" + return f"tl.where({a} != {a}, {a}, tl.where({a} > {b}, {a}, {b}))" @staticmethod def where(a, b, c): - if not config.triton.simple_where: - # wonkyness to work around https://github.com/openai/triton/issues/532 - # identity calls to force new triton variables (and get access to .shape/.dtype/.numel - a = ops.identity(a) - b = ops.identity(b) - c = ops.identity(c) - a = ops.identity( - f"{a} | tl.zeros({b}.shape, {a}.dtype) if {b}.numel > 1 else {a}" - ) - a = ops.identity( - f"{a} | tl.zeros({c}.shape, {a}.dtype) if {c}.numel > 1 else {a}" - ) return f"tl.where({a}, {b}, {c})" @staticmethod @@ -206,6 +207,10 @@ def masked(mask, body, other): def lgamma(x): return f"tl.libdevice.lgamma({x})" + @staticmethod + def erf(x): + return f"tl.libdevice.erf({x})" + @staticmethod def logical_and(a, b): return f"{a} & {b}" @@ -226,6 +231,14 @@ def randn(seed, offset, _): # _ here to keep the contract identical to CPU rand def rsqrt(x): return f"tl.libdevice.rsqrt({x})" + @staticmethod + def log1p(x): + return f"tl.libdevice.log1p({x})" + + @staticmethod + def expm1(x): + return f"tl.libdevice.expm1({x})" + @staticmethod def sigmoid(x): return f"tl.sigmoid({x})" @@ -237,11 +250,11 @@ def libdevice_sigmoid(x): @staticmethod def signbit(x): # XX: This is wrong for the value -0.0 in floating point - return f"tl.libdevice.signbitf({x}) if ({x}).dtype is tl.float32 else {x} < 0" + return f"tl.libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0" @staticmethod def fmod(a, b): - return f"tl.libdevice.fmod({a}, ({b}).to(tl.float32))" + return f"tl.libdevice.fmod({a}, {b})" @staticmethod def pow(a, b): @@ -257,11 +270,11 @@ def libdevice_log(x): @staticmethod def isinf(x): - return f"tl.libdevice.isinfd({x}) if ({x}).dtype is tl.float64 else tl.libdevice.isinff({x})" + return f"tl.libdevice.isinf({x})" @staticmethod def isnan(x): - return f"tl.libdevice.isnand({x}) if ({x}).dtype is tl.float64 else tl.libdevice.isnanf({x})" + return f"tl.libdevice.isnan({x})" @staticmethod def round(x): @@ -499,11 +512,18 @@ class TritonKernel(Kernel): overrides = TritonOverrides sexpr = texpr - def __init__(self, *groups, pid_cache=None, reduction_hint=ReductionHint.DEFAULT): + def __init__( + self, + *groups, + mutations=None, + pid_cache=None, + reduction_hint=ReductionHint.DEFAULT, + ): if pid_cache is None: pid_cache = {} super(TritonKernel, self).__init__() self.numels = [V.graph.sizevars.simplify(s) for s in groups] + self.mutations = mutations self.range_trees = [] self.range_tree_nodes = {} self.iter_vars_count = itertools.count() @@ -740,6 +760,12 @@ def indexing( mask = dense_mask index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)" elif indirect_indexing: + # Use dense mask for indirect_indexing + # See https://github.com/pytorch/torchdynamo/issues/1654 + # TODO - An optimization could be to hoist this load outside of + # reduction loop, if it is independent of rmask. Such example can be found in + # https://github.com/pytorch/torchdynamo/issues/1654 + index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)" mask = dense_mask if self._load_mask: @@ -752,7 +778,13 @@ def indexing( # https://github.com/openai/triton/issues/633 mask = ["None"] - return index_str, " & ".join(mask) + if ( + index_str in self.cse.varname_map + and self.cse.varname_map[index_str].is_scalar + ): + mask = ["None"] + + return index_str, " & ".join(map(str, mask)) def var_ranges(self): return dict( @@ -982,24 +1014,35 @@ def codegen_kernel(self, name=None): import triton import triton.language as tl from {config.inductor_import}.ir import ReductionHint + from {config.inductor_import}.ir import TileHint from {config.inductor_import}.triton_ops.autotune import {heuristics} from {config.inductor_import}.utils import instance_descriptor """ ) argdefs, _, signature = self.args.python_argdefs() + + mutated_args = set() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if mutation in self.args.inplace_buffers: + mutated_args.add(self.args.inplace_buffers[mutation].inner_name) + if mutation in self.args.output_buffers: + mutated_args.add(self.args.output_buffers[mutation]) + triton_meta = { "signature": dict(enumerate(map(signature_of, signature))), "device": V.graph.scheduler.current_device.index, - "configs": [config_of(signature)], "constants": {}, + "mutated_arg_names": mutated_args, } for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: - triton_meta["signature"][len(argdefs)] = signature_of( - SizeArg(f"{tree.prefix}numel", tree.numel) - ) + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + triton_meta["signature"][len(argdefs)] = signature_of(sizearg) argdefs.append(f"{tree.prefix}numel") # constexpr version causes issues, see # https://github.com/pytorch/torchdynamo/pull/1362 @@ -1007,6 +1050,7 @@ def codegen_kernel(self, name=None): # tree.numel # ) # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + triton_meta["configs"] = [config_of(signature)] for tree in self.range_trees: if tree.prefix != "r" or self.inside_reduction: @@ -1022,8 +1066,14 @@ def codegen_kernel(self, name=None): @triton.jit """ else: + tile_hint = "" + if len(size_hints) == 2: + if len(signature) == 4: # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," heuristics_line = f""" - @{heuristics}(size_hints={size_hints!r}, filename=__file__, meta={triton_meta!r}) + @{heuristics}(size_hints={size_hints!r}, {tile_hint}filename=__file__, meta={triton_meta!r}) @triton.jit """ code.splice(heuristics_line) @@ -1099,6 +1149,9 @@ def call_kernel(self, code, name: str): f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})" ) + def create_cse_var(self, *args, **kwargs): + return TritonCSEVariable(*args, **kwargs) + class TritonScheduling: def __init__(self, scheduler): @@ -1226,7 +1279,8 @@ def end_current_reduction_loop(): f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" ) - log.log(dynamo_logging.CODE, "schedule: %s", node_schedule) + if dynamo_config.output_code: + log.info("schedule: %s", node_schedule) return self.codegen_node_schedule(node_schedule, numel, rnumel) @staticmethod @@ -1257,7 +1311,15 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel): reduction_hint_val = ReductionHint.DEFAULT else: reduction_hint_val = ReductionHint.DEFAULT - with TritonKernel(*tiled_groups, reduction_hint=reduction_hint_val) as kernel: + + mutations = set() + for node in node_schedule: + if hasattr(node, "get_mutations"): + mutations.update(node.get_mutations()) + + with TritonKernel( + *tiled_groups, reduction_hint=reduction_hint_val, mutations=mutations + ) as kernel: stack = contextlib.ExitStack() for node in node_schedule: if node not in (EnableReduction, DisableReduction): @@ -1275,9 +1337,14 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel): if src_code in wrapper.kernels: kernel_name = wrapper.kernels[src_code] else: - kernel_name = wrapper.next_kernel_name() + fused_name = ( + get_fused_kernel_name(node_schedule) + if config.triton.descriptive_kernel_names + else "" + ) + kernel_name = "_".join(["triton", fused_name, wrapper.next_kernel_suffix()]) wrapper.kernels[src_code] = kernel_name - subs_name = kernel_name if config.triton.ordered_kernel_names else "kernel" + subs_name = kernel_name if config.triton.ordered_kernel_names else "triton_" src_code = src_code.replace("KERNEL_NAME", subs_name) # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py index 4d86feeccec86..cd1c2bed6bb7c 100644 --- a/torch/_inductor/codegen/triton_template.py +++ b/torch/_inductor/codegen/triton_template.py @@ -330,12 +330,12 @@ def template_codegen(scheduler, scheduler_node, epilogue): kernel_buf_replace_name = None if could_remove_kernel_buf: for node in epilogue: - if kernel.args.output_buffers[node.get_name()] != "REMOVED": + if not kernel.args.is_removed(node.get_name()): kernel_buf_replace_name = node.get_name() break assert kernel_buf_replace_name is not None - kernel_name = wrapper.next_kernel_name() + kernel_name = "triton_template_" + wrapper.next_kernel_suffix() # code gen kernel wrapper.header.splice( kernel.codegen_kernel( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 7efc1cf1aa8c6..63fcb745a189a 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1,4 +1,5 @@ import collections +import contextlib import dataclasses import functools import hashlib @@ -6,6 +7,7 @@ from typing import Any, Dict, List from .. import codecache, config, ir +from ..codecache import cpp_compile_command, get_code_path from ..utils import dynamo_utils, has_triton, sympy_dot, sympy_product from ..virtualized import V from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel @@ -27,18 +29,18 @@ def buffer_reuse_key(node: ir.Buffer): ) -def make_buffer_reuse(old, new): +def make_buffer_reuse(old, new, del_func, declare, ending, as_strided): assert old.get_dtype() == new.get_dtype() del_line = "" if old.get_name() not in V.graph.get_output_names(): - del_line = f"; del {old.get_name()}" + del_line = del_func(old.get_name()) if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): - return f"{new.get_name()} = {old.get_name()}{del_line}" + return f"{declare}{new.get_name()} = {old.get_name()}{del_line}{ending}" return ( - f"{new.get_name()} = as_strided({old.get_name()}, " + f"{declare}{new.get_name()} = {as_strided}({old.get_name()}, " f"{V.graph.sizevars.codegen_shape_tuple(new.get_size())}, " - f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}" + f"{V.graph.sizevars.codegen_shape_tuple(new.get_stride())}){del_line}{ending}" ) @@ -55,6 +57,21 @@ def make_buffer_allocation(buffer): ) +def make_cpp_buffer_allocation(buffer): + from .cpp import DTYPE_TO_ATEN + + # TODO: map layout and device here + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + stride = tuple(buffer.get_stride()) + return ( + f"auto {buffer.get_name()} = at::empty_strided(" + f"{V.graph.sizevars.codegen_shape_tuple(shape)}, " + f"{V.graph.sizevars.codegen_shape_tuple(stride)}, " + f"{DTYPE_TO_ATEN[dtype]}); " + ) + + class MemoryPlanningState: def __init__(self): super().__init__() @@ -107,6 +124,27 @@ def codegen(self, code: IndentedBuffer): code.writeline(make_buffer_allocation(self.node)) +@dataclasses.dataclass +class CppAllocateLine(AllocateLine): + def plan(self, state: MemoryPlanningState): + if self.node.get_name() in V.graph.removed_buffers: + return NullLine() + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + + if key in state: + free_line = state.pop(key) + free_line.is_reused = True + return CppReuseLine(free_line.node, self.node) + + return self + + def codegen(self, code: IndentedBuffer): + assert self.node.get_name() not in V.graph.removed_buffers + code.writeline(make_cpp_buffer_allocation(self.node)) + + @dataclasses.dataclass class FreeIfNotReusedLine(MemoryPlanningLine): node: ir.Buffer @@ -125,6 +163,17 @@ def codegen(self, code: IndentedBuffer): code.writeline(f"del {self.node.get_name()}") +@dataclasses.dataclass +class CppFreeIfNotReusedLine(FreeIfNotReusedLine): + node: ir.Buffer + is_reused: bool = False + + def codegen(self, code: IndentedBuffer): + assert (self.node.get_name()) not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(f"{self.node.get_name()}.reset();") + + @dataclasses.dataclass class ReuseLine(MemoryPlanningLine): node: ir.Buffer @@ -138,7 +187,38 @@ def plan(self, state: MemoryPlanningState): def codegen(self, code: IndentedBuffer): assert self.node.get_name() not in V.graph.removed_buffers assert self.reused_as.get_name() not in V.graph.removed_buffers - code.writeline(make_buffer_reuse(self.node, self.reused_as) + " # reuse") + code.writeline( + make_buffer_reuse( + self.node, + self.reused_as, + del_func=lambda name: f"; del {name}", + declare="", + ending="", + as_strided="as_strided", + ) + + " # reuse" + ) + + +@dataclasses.dataclass +class CppReuseLine(ReuseLine): + node: ir.Buffer + reused_as: ir.Buffer + + def codegen(self, code: IndentedBuffer): + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + make_buffer_reuse( + self.node, + self.reused_as, + del_func=lambda name: f"; {name}.reset()", + declare="auto ", + ending=";", + as_strided="at::as_strided", + ) + + " // reuse" + ) @dataclasses.dataclass @@ -169,6 +249,7 @@ def __init__(self): self._names_iter = count() self.header = IndentedBuffer() self.prefix = IndentedBuffer() + self.wrapper_call = IndentedBuffer() self.kernels = {} self.lines = [] self.header.splice( @@ -218,6 +299,20 @@ def __init__(self): f"from {config.inductor_import}.triton_ops.batched_matmul import bmm_out as triton_bmm_out" ) + self.write_prefix() + + for name, value in V.graph.constants.items(): + # include a hash so our code cache gives different constants different files + hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest() + self.header.writeline(f"{name} = None # {hashed}") + + self.allocated = set() + self.freed = set() + self.write_get_cuda_stream = functools.lru_cache(None)( + self.write_get_cuda_stream + ) + + def write_prefix(self): self.prefix.splice( """ @@ -227,36 +322,33 @@ def __init__(self): def call(args): """ ) - with self.prefix.indent(): + with self.wrapper_call.indent(): inp_len = len(V.graph.graph_inputs.keys()) if inp_len != 0: lhs = f"{', '.join(V.graph.graph_inputs.keys())}{'' if inp_len != 1 else ','}" - self.prefix.writeline(f"{lhs} = args") - self.prefix.writeline("args.clear()") + self.wrapper_call.writeline(f"{lhs} = args") + self.wrapper_call.writeline("args.clear()") for name in V.graph.randomness_seeds: - self.prefix.writeline( + self.wrapper_call.writeline( f"torch.randint(2**31, size=(), dtype=torch.int64, out={name})" ) - V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs) - - for name, value in V.graph.constants.items(): - # include a hash so our code cache gives different constants different files - hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest() - self.header.writeline(f"{name} = None # {hashed}") - - self.allocated = set() - self.freed = set() - self.write_get_cuda_stream = functools.lru_cache(None)( - self.write_get_cuda_stream - ) + V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs) def write_get_cuda_stream(self, index): name = f"stream{index}" self.writeline(f"{name} = get_cuda_stream({index})") return name - def next_kernel_name(self): - return f"kernel{next(self._names_iter)}" + def next_kernel_suffix(self): + return f"{next(self._names_iter)}" + + def write_allocate_line(self, buffer): + self.writeline(AllocateLine(buffer)) + + def get_deferred_line(self, name, layout): + return DeferredLine( + name, f"{name} = {layout.view.codegen_reference()} # alias" + ) def codegen_allocation(self, buffer): name = buffer.get_name() @@ -278,20 +370,24 @@ def codegen_allocation(self, buffer): if not layout.maybe_guard_aligned(): V.graph.unaligned_buffers.add(name) self.codegen_allocation(layout.view.data) - allocation = DeferredLine( - name, f"{name} = {layout.view.codegen_reference()} # alias" - ) + allocation = self.get_deferred_line(name, layout) self.writeline(allocation) return - self.writeline(AllocateLine(buffer)) + self.write_allocate_line(buffer) + + def write_del_line(self, name): + self.writeline(f"del {name}") + + def write_free_if_not_reused_line(self, buffer): + self.writeline(FreeIfNotReusedLine(buffer)) def codegen_free(self, buffer): name = buffer.get_name() # can be freed but not reused if isinstance(buffer, ir.InputBuffer): - self.writeline(f"del {name}") + self.write_del_line(name) return if not self.can_reuse(buffer): @@ -300,10 +396,10 @@ def codegen_free(self, buffer): layout = buffer.get_layout() if isinstance(layout, (ir.AliasedLayout, ir.MultiOutputLayout)): - self.writeline(f"del {name}") + self.write_del_line(name) return - self.writeline(FreeIfNotReusedLine(buffer)) + self.write_free_if_not_reused_line(buffer) def can_reuse(self, buffer): name = buffer.get_name() @@ -316,12 +412,24 @@ def can_reuse(self, buffer): return False return True + def write_reuse_line(self, input_buffer, output_buffer): + self.writeline(ReuseLine(input_buffer, output_buffer)) + def codegen_inplace_reuse(self, input_buffer, output_buffer): assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer) self.codegen_allocation(input_buffer) self.freed.add(input_buffer.get_name()) self.allocated.add(output_buffer.get_name()) - self.writeline(ReuseLine(input_buffer, output_buffer)) + self.write_reuse_line(input_buffer, output_buffer) + + def generate_return(self, output_refs): + if output_refs: + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_end(self, result): + return @dynamo_utils.dynamo_timed def generate(self): @@ -330,7 +438,16 @@ def generate(self): result.splice(self.prefix) out_names = V.graph.get_output_names() - with result.indent(): + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.wrapper_call.writeline( + "from torch.profiler import record_function" + ) + self.wrapper_call.writeline( + "with record_function('inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) while ( self.lines and isinstance(self.lines[-1], MemoryPlanningLine) @@ -347,15 +464,17 @@ def generate(self): for line in self.lines: if isinstance(line, MemoryPlanningLine): - line.codegen(result) + line.codegen(self.wrapper_call) else: - result.writeline(line) + self.wrapper_call.writeline(line) output_refs = [x.codegen_reference() for x in V.graph.graph_outputs] - if output_refs: - result.writeline("return (" + ", ".join(output_refs) + ", )") - else: - result.writeline("return ()") + self.generate_return(output_refs) + + with result.indent(): + result.splice(self.wrapper_call) + + self.generate_end(result) self.add_benchmark_harness(result) @@ -371,8 +490,8 @@ def add_benchmark_harness(self, output): def add_fake_input(name, shape, stride, device, dtype): output.writeline( f"{name} = rand_strided(" - f"{V.graph.sizevars.codegen_shape_tuple(shape)}, " - f"{V.graph.sizevars.codegen_shape_tuple(stride)}, " + f"{V.graph.sizevars.codegen_benchmark_shape_tuple(shape)}, " + f"{V.graph.sizevars.codegen_benchmark_shape_tuple(stride)}, " f"device='{device.type}', dtype={dtype})" ) @@ -405,6 +524,17 @@ def add_fake_input(name, shape, stride, device, dtype): def define_kernel(self, name: str, kernel: str): self.header.splice(f"\n\n{name} = {kernel}") + def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None): + return + + def wrap_kernel_call(self, name, call_args): + return "{}({})".format(name, ", ".join(call_args)) + + def generate_kernel_call(self, name, call_args): + self.writeline( + self.wrap_kernel_call(name, call_args), + ) + def call_kernel(self, name: str, kernel: Kernel): tmp = IndentedBuffer() kernel.call_kernel(self, tmp, name) @@ -415,3 +545,150 @@ def call_kernel(self, name: str, kernel: Kernel): def writeline(self, line): self.lines.append(line) + + +class CppWrapperCodeGen(WrapperCodeGen): + """ + The outer wrapper that calls the kernels. + """ + + call_func_id = count() + + def __init__(self): + self._call_func_id = next(CppWrapperCodeGen.call_func_id) + super().__init__() + + def write_prefix(self): + self.prefix.splice( + """ + async_compile.wait(globals()) + del async_compile + from torch.utils.cpp_extension import load_inline + wrapper = ( + ''' + #include + #include + """ + ) + with self.wrapper_call.indent(): + inputs_len = len(V.graph.graph_inputs.keys()) + output_refs = [x.codegen_reference() for x in V.graph.graph_outputs] + if output_refs: + if len(output_refs) == 1: + output_types = "at::Tensor" + else: + output_types = "std::vector" + else: + output_types = "void" + + if inputs_len != 0: + inputs_args = ["at::Tensor&"] * len(V.graph.graph_inputs.keys()) + inputs_args = ", ".join(inputs_args) + inputs_args = f"std::tuple<{inputs_args}>" + + self.wrapper_call.writeline( + f"{output_types} call_{self._call_func_id}({inputs_args} args) {{" + ) + inputs_keys_str = ", ".join(V.graph.graph_inputs.keys()) + self.wrapper_call.writeline(f"at::Tensor {inputs_keys_str};") + self.wrapper_call.writeline(f"std::tie({inputs_keys_str}) = args;") + else: + self.wrapper_call.writeline( + f"{output_types} call_{self._call_func_id}(std::tuple<> args) {{" + ) + for name in V.graph.randomness_seeds: + self.wrapper_call.writeline(f"at::Tensor {name};") + self.wrapper_call.writeline( + f"{name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);" + ) + V.graph.sizevars.codegen(self.wrapper_call, V.graph.graph_inputs) + + def write_allocate_line(self, buffer): + self.writeline(CppAllocateLine(buffer)) + + def write_del_line(self, name): + self.writeline(f"{name}.reset();") + return + + def write_free_if_not_reused_line(self, buffer): + self.writeline(CppFreeIfNotReusedLine(buffer)) + return + + def write_reuse_line(self, input_buffer, output_buffer): + self.writeline(CppReuseLine(input_buffer, output_buffer)) + + def get_deferred_line(self, name, layout): + return DeferredLine( + name, f"auto {name} = {layout.view.codegen_reference()}; // alias" + ) + + def get_kernel_path(self, code): + from ..codecache import pick_vec_isa + + picked_vec_isa = pick_vec_isa() + ext = "so" + extra = cpp_compile_command("i", "o", vec_isa=picked_vec_isa) + # \n is required to match with the CodeCache behavior + source_code = "\n" + code.getvalue() + _, _, kernel_path = get_code_path(source_code, ext, extra) + return kernel_path + + def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None): + kernel_path = self.get_kernel_path(kernel) + + self.writeline(f'auto {name}_lib = dlopen("{kernel_path}", RTLD_NOW);') + self.writeline(f"assert({name}_lib != nullptr);") + self.writeline(f"void (*{name})({arg_types});") + self.writeline(f'*(void **) (&{name}) = dlsym({name}_lib, "kernel");') + + def wrap_kernel_call(self, name, call_args): + return "{}({});".format(name, ", ".join(call_args)) + + def generate_return(self, output_refs): + if output_refs: + if len(output_refs) == 1: + self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )") + else: + self.wrapper_call.writeline( + "return std::vector({" + + ", ".join(output_refs) + + "}); }''' )" + ) + else: + self.wrapper_call.writeline("return; }''' )") + + def generate_end(self, result): + shared = codecache.get_shared() + warning_all_flag = codecache.get_warning_all_flag() + cpp_flags = codecache.cpp_flags() + ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths() + optimization_flags = codecache.optimization_flags() + use_custom_generated_macros = codecache.use_custom_generated_macros() + + extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}" + extra_ldflags = f"{shared} {lpaths} {libs}" + extra_include_paths = f"{ipaths}" + + # get the hash of the wrapper code to name the extension + wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue()) + result.splice( + f""" + module = load_inline( + name='inline_extension_{wrapper_call_hash}', + cpp_sources=[wrapper], + functions=['call_{self._call_func_id}'], + extra_cflags=['{extra_cflags}'], + extra_ldflags=['{extra_ldflags}'], + extra_include_paths=['{extra_include_paths}']) + """ + ) + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + return f(args) + return g + call = _wrap_func(module.call_{self._call_func_id}) + """ + ) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index e6b27420a941a..3b4ce9c202e5c 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -6,14 +6,14 @@ from typing import List import functorch -from functorch._src.aot_autograd import make_boxed_func from functorch.compile import min_cut_rematerialization_partition import torch.fx +from torch._dynamo.utils import fake_mode_from_tensors +from torch._functorch.aot_autograd import make_boxed_func from torch._subclasses.fake_tensor import FakeTensor -from torch.utils._mode_utils import no_dispatch -from . import config, overrides +from . import config, metrics, overrides from .debug import DebugContext from .decomposition import select_decomp_table from .graph import GraphLowering @@ -28,7 +28,7 @@ log = logging.getLogger(__name__) ALIGNMENT = 16 -aot_autograd = dynamo_optimizations.backends.aot_autograd +aot_autograd = dynamo_optimizations.training.aot_autograd normalize_ir = dynamo_optimizations.normalize.normalize_ir is_aot_autograd_safe_to_run = dynamo_optimizations.training.is_aot_autograd_safe_to_run count_calls = dynamo_utils.count_calls @@ -85,7 +85,20 @@ def _step_logger(): @DebugContext.wrap -@no_dispatch() +def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs): + shape_env = _shape_env_from_inputs(example_inputs) + + graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) + with V.set_graph_handler(graph): + graph.run(*example_inputs) + num_bytes, nodes_num_elem = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.nodes_num_elem += nodes_num_elem + return make_boxed_func(gm.forward) + + +@DebugContext.wrap +@torch.utils._python_dispatch._disable_current_modes() def compile_fx_inner( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], @@ -112,12 +125,12 @@ def compile_fx_inner( if cudagraphs is None: cudagraphs = config.triton.cudagraphs - shape_env = None - for inp in example_inputs: - if isinstance(inp, FakeTensor) and inp.fake_mode.shape_env is not None: - shape_env = inp.fake_mode.shape_env - graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) + shape_env = _shape_env_from_inputs(example_inputs) + + graph = GraphLowering( + gm, shape_env=shape_env, num_static_inputs=num_fixed, graph_id=graph_id + ) with V.set_graph_handler(graph): graph.run(*example_inputs) compiled_fn = graph.compile_to_fn() @@ -327,7 +340,11 @@ def is_not_gradout(x): _graph_counter = itertools.count(0) -def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): +def compile_fx( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile=compile_fx_inner, +): """Main entrypoint to a compile given FX graph""" if not is_aot_autograd_safe_to_run(model_, example_inputs_): @@ -340,6 +357,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor] with overrides.patch_functions(): model_ = normalize_ir(model_, example_inputs_) model_ = overrides.replace_fx(model_) + model_ = overrides.fuse_fx(model_, example_inputs_) num_example_inputs = len(example_inputs_) cudagraphs = BoxedBool(config.triton.cudagraphs and not config.dynamic_shapes) @@ -348,7 +366,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor] @dynamo_utils.dynamo_timed def fw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = len(example_inputs) - num_example_inputs - return compile_fx_inner( + return inner_compile( model, example_inputs, num_fixed=fixed, @@ -359,7 +377,7 @@ def fw_compiler(model: torch.fx.GraphModule, example_inputs): @dynamo_utils.dynamo_timed def bw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = count_tangents(model) - return compile_fx_inner( + return inner_compile( model, example_inputs, num_fixed=fixed, @@ -371,15 +389,29 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs): with overrides.patch_functions(): # TODO: can add logging before/after the call to create_aot_dispatcher_function - # in functorch/_src/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func + # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func # once torchdynamo is merged into pytorch return aot_autograd( - model_, - example_inputs_, fw_compiler=fw_compiler, bw_compiler=bw_compiler, decompositions=select_decomp_table(), partition_fn=functools.partial( min_cut_rematerialization_partition, compiler="inductor" ), - ) + )(model_, example_inputs_) + + +def _shape_env_from_inputs(inputs): + shape_env = None + fake_mode = fake_mode_from_tensors(inputs) + + # TODO(voz): It would be nice to enable this assert, but there are lots of tests that + # pass in real inputs for now. + # if len(inputs) > 0: + # assert fake_mode is not None, breakpoint() + + if fake_mode is not None: + return fake_mode.shape_env + + # TODO(voz): Should we always have one anyway? + return None diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index cabaa3e7ce0ba..19f350ee6f0cf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1,8 +1,18 @@ import os +import sys # add some debug printouts debug = False +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# use cpp wrapper instead of python wrapper +cpp_wrapper = False + # dead code elimination dce = False @@ -27,9 +37,14 @@ benchmark_harness = True # control store vs recompute heuristic +# For fanouts, rematearialization can lead to exponential blowup. So, have +# smaller threshold realize_reads_threshold = 4 realize_bytes_threshold = 2000 +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 + # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False @@ -53,7 +68,27 @@ comment_origin = False -compile_threads = min(32, os.cpu_count()) + +def is_fbcode(): + import torch + + return not hasattr(torch.version, "git_version") + + +compile_threads = ( + 1 + if sys.platform == "win32" or is_fbcode() + else min( + 32, + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count(), + ) +) + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 # How to import torchinductor, either torchinductor or torch.inductor inductor_import = __name__.replace(".config", "") @@ -61,6 +96,15 @@ # How to import torchdynamo, either torchdynamo or torch.dynamo dynamo_import = inductor_import.replace("inductor", "dynamo") +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1" +alignment_size = 4 + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False # config specific to codegen/cpp.pp class cpp: @@ -77,12 +121,15 @@ class cpp: min_chunk_size = 4096 cxx = ( None, # download gcc12 from conda-forge if conda is installed - "g++-12", - "g++-11", - "g++-10", - "clang++", + # "g++-12", + # "g++-11", + # "g++-10", + # "clang++", "g++", + # "g++.par", ) + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = False # config specific to codegen/triton.py @@ -117,8 +164,8 @@ class triton: tiling_prevents_reduction_fusion = True # should we give different names to kernels ordered_kernel_names = False - # should we use natural codegen for where, needs newer triton version - simple_where = True + # should we put op names in kernel names + descriptive_kernel_names = True # create a directory containing lots of debug information @@ -153,3 +200,92 @@ class trace: # Upload the .tar.gz file # Needs to be overriden based on specific environment needs upload_tar = None + + +class InductorConfigContext: + static_memory: bool + matmul_tune: str + matmul_padding: bool + triton_autotune: bool + triton_bmm: bool + triton_mm: str + triton_convolution: str + rematerialize_threshold: int + rematerialize_acc_threshold: int + + def _save(self): + self.static_memory = triton.cudagraphs + self.matmul_tune = triton.mm + self.matmul_padding = shape_padding + self.triton_autotune = triton.autotune + self.triton_bmm = triton.use_bmm + self.triton_mm = triton.mm + self.triton_convolution = triton.convolution + self.rematerialize_threshold = realize_reads_threshold + self.rematerialize_acc_threshold = realize_acc_reads_threshold + + def _apply(self): + triton.cudagraphs = self.static_memory + triton.mm = self.matmul_tune + shape_padding = self.matmul_padding + triton.autotune = self.triton_autotune + triton.use_bmm = self.triton_bmm + triton.mm = self.triton_mm + triton.convolution = self.triton_convolution + realize_reads_threshold = self.rematerialize_threshold + realize_acc_reads_threshold = self.rematerialize_acc_threshold + + def __init__(self, arg=None): + self._save() + if arg is None: + return + # Handle mode + if type(arg) is str: + + def default(): + self.static_memory = False + + def reduce_overhead(): + self.static_memory = True + + def max_autotune(): + self.static_memory = False + self.matmul_padding = True + self.triton_convolution = "autotune" + self.triton_mm = "autotune" + self.matmul_padding = True + + modes = { + x.__name__.replace("_", "-"): x + for x in [default, reduce_overhead, max_autotune] + } + if arg not in modes: + raise RuntimeError( + f"Unrecognized mode {arg}, should be one of {', '.join(modes.keys())}" + ) + modes[arg]() + return + # Handle passes + for (name, val) in arg.items(): + attr_name = name.replace("-", "_") + if not hasattr(self, attr_name): + known_passes = ", ".join( + [x.replace("_", "-") for x in dir(self) if not x.startswith("_")] + ) + raise RuntimeError( + f"Unexpected optimization pass {name}, known passes are {known_passes}" + ) + if type(val) != type(getattr(self, attr_name)): + val_type_str = type(val).__name__ + expected_type_str = type(getattr(self, attr_name)).__name__ + raise RuntimeError( + f"Unexpected type of attr {name}, got {val_type_str} should be {expected_type_str}" + ) + setattr(self, attr_name, val) + + def __enter__(self): + self._prev = InductorConfigContext() + self._apply() + + def __exit__(self, exc_type, exc_val, exc_tb): + self._prev._apply() diff --git a/torch/_inductor/cuda_properties.py b/torch/_inductor/cuda_properties.py index de5349b568971..e42b2c5b5c676 100644 --- a/torch/_inductor/cuda_properties.py +++ b/torch/_inductor/cuda_properties.py @@ -11,10 +11,15 @@ @functools.lru_cache(None) def _properties(): - r = { - i: torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count()) - } - return r + if not torch.cuda.is_available(): + return {} + try: + return { + i: torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + } + except RuntimeError: + return {} _compile_worker_current_device = None diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index d2bc9bcd73344..67e75d1a73294 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -19,7 +19,6 @@ from torch.fx.passes.tools_common import legalize_graph from . import config, ir -from .codecache import cache_dir from .scheduler import ( BaseSchedulerNode, ExternKernelSchedulerNode, @@ -182,7 +181,11 @@ def inner(*args, **kwargs): @staticmethod def create_debug_dir(): for n in DebugContext._counter: - dirname = os.path.join(cache_dir(), f"debug.{os.getpid()}.{n}") + dirname = os.path.join( + dynamo_utils.get_debug_dir(), + "torchinductor", + f"debug.{os.getpid()}.{n}", + ) if not os.path.exists(dirname): os.makedirs(dirname) return dirname @@ -303,9 +306,12 @@ def __init__(self, handler): self.handler = handler def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]): - with self.fopen("fx_graph.py") as fd: + with self.fopen("fx_graph_runnable.py") as fd: dynamo_debug_utils.save_graph_repro(fd, gm, inputs, "inductor") + with self.fopen("fx_graph_readable.py") as fd: + fd.write(gm.print_readable(print_output=False)) + def ir_pre_fusion(self, nodes: SchedulerNodeList): self._write_ir("ir_pre_fusion.txt", nodes) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 5e67bfe6ef29e..f8fedcc786015 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -8,8 +8,9 @@ from torch import Tensor from torch._decomp import get_decompositions from torch._prims_common import is_boolean_dtype, is_integer_dtype +from torch.utils._mode_utils import no_dispatch -from . import config +from . import config, utils log = logging.getLogger(__name__) aten = torch.ops.aten @@ -65,6 +66,8 @@ aten.mv, aten.narrow, aten.native_batch_norm, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, aten.native_batch_norm_backward, aten.native_dropout_backward, aten.native_group_norm, @@ -81,24 +84,29 @@ aten._reshape_alias, aten.select_backward, aten.select_scatter, + aten.sgn, aten.sigmoid_backward, + aten.silu, aten.silu_backward, aten.slice_backward, - aten.sgn, - aten.std_mean.correction, aten._softmax, aten._softmax_backward_data, + aten.softplus, + aten.softplus_backward, aten.stack, + aten.std_mean.correction, aten.t, aten.tanh_backward, aten.threshold_backward, aten.transpose.int, aten.tril.default, + aten.unfold, + aten.unfold_backward, aten.upsample_bilinear2d.vec, aten.upsample_nearest2d_backward, aten.softplus, aten.softplus_backward, - aten.silu, + aten.bucketize, ] ) @@ -107,7 +115,7 @@ def register_decomposition(ops): for op in [ops] if callable(ops) else ops: if op in decompositions: log.warning(f"duplicate decomp: {ops}") - return decomp.register_decomposition(ops, decompositions, disable_meta=True) + return decomp.register_decomposition(ops, decompositions) @register_decomposition([aten.clamp]) @@ -131,6 +139,26 @@ def floordiv(a, b): return aten.div.Tensor_mode(a, b, rounding_mode="floor") +def get_padded_length(x): + if x % config.alignment_size == 0: + return 0 + return int((x // config.alignment_size + 1) * config.alignment_size) - x + + +def pad_dim(x, padded_length, dim): + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def check_device_dtype(a: Tensor, b: Tensor): + return ( + a.is_cuda + and b.is_cuda + and a.dtype in (torch.float32, torch.float16, torch.bfloat16) + and b.dtype in (torch.float32, torch.float16, torch.bfloat16) + ) + + @register_decomposition([aten.addmm]) def addmm(input, mat1, mat2, *, beta=1, alpha=1): if config.triton.mm != "aten": @@ -140,13 +168,191 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1): if not isinstance(beta, numbers.Number) or beta != 1: input = input * beta return input + out - else: - return NotImplemented # go directly to lowering - -@register_decomposition([aten.rsqrt]) -def rsqrt(x): - return torch.reciprocal(torch.sqrt(x)) + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input) + ): + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + + if input is not None and k_padded_length == 0: + if m_padded_length != 0 and input.dim() == 2: + input = pad_dim(input, m_padded_length, 0) + elif n_padded_length != 0: + if input.dim() == 2: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1: + input = pad_dim(input, n_padded_length, 0) + + if k_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + elif m_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[ + :-m_padded_length, : + ] + elif n_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[ + :, :-n_padded_length + ] + + return NotImplemented # go directly to lowering + + +def should_pad_bench(mat1, mat2, op, input=None): + assert utils.has_triton() + from triton.testing import do_bench + + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + elif op is torch.ops.aten.bmm: + m_padded_length = get_padded_length(mat1.shape[1]) + k_padded_length = get_padded_length(mat1.shape[2]) + n_padded_length = get_padded_length(mat2.shape[2]) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + mat1 = torch.randn_like(mat1) + mat2 = torch.randn_like(mat2) + warmup = 5 + rep = 100 + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + ori_time = do_bench( + lambda: op(mat1, mat2), warmup=warmup, rep=rep, fast_flush=True + )[0] + else: + if input is not None: + input = torch.randn_like(input) + ori_time = do_bench( + lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True + )[0] + + mat1_pad = mat1.new_empty([get_padded_length(i) + i for i in mat1.shape]) + mat2_pad = mat2.new_empty([get_padded_length(i) + i for i in mat2.shape]) + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda and input.dtype == torch.float32: + input_pad = input.new_empty( + [get_padded_length(i) + i for i in input.shape] + ) + pad_time = do_bench( + lambda: op(input_pad, mat1_pad, mat2_pad), + warmup=warmup, + rep=rep, + fast_flush=True, + )[0] + else: + pad_time = do_bench( + lambda: op(mat1_pad, mat2_pad), warmup=warmup, rep=rep, fast_flush=True + )[0] + + # Shape padding introduces addtional memory ops. Based on microbenchmarks, 1.3x for + # aten.mm and aten.addmm and 2x for aten.bmm represent a reasonable tradeoff between + # performance improvement from shape padding and overhead from addtional memory ops + # TODO: Build a learned model which would be better than this heuristic + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + return ori_time > pad_time * 1.3 + else: + return ori_time > pad_time * 2 + + +@register_decomposition([aten.mm]) +def mm_decomp(mat1, mat2): + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.mm) + ): + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :] + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length] + + return NotImplemented # go directly to lowering + + +@register_decomposition([aten.bmm]) +def bmm_decomp(mat1, mat2): + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.bmm) + ): + m_padded_length = get_padded_length(mat1.shape[1]) + k_padded_length = get_padded_length(mat1.shape[2]) + n_padded_length = get_padded_length(mat2.shape[2]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 2) + mat2 = pad_dim(mat2, k_padded_length, 1) + return torch.ops.aten.bmm(mat1, mat2) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 1) + return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous() + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 2) + return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous() + + return NotImplemented # go directly to lowering + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + if not output_mask[2] or grad_output.device.type != "cuda": + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) @register_decomposition([aten.log2]) @@ -160,30 +366,6 @@ def round_dec(x, decimals=0): return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) -@register_decomposition([aten.special_erf, aten.erf]) -def special_erf(x): - # TODO(jansel): this might be crazy slow. Triton doesn't have the - # cuda ::erf() builtin. I've made a feature request for this, - # so it may be coming soon. - - # from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/ - a1 = 0.254829592 - a2 = -0.284496736 - a3 = 1.421413741 - a4 = -1.453152027 - a5 = 1.061405429 - p = 0.3275911 - - sign = torch.sign(x) - x = torch.abs(x) - - # A & S 7.1.26 - t = 1.0 / (1.0 + p * x) - y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * torch.exp(-x * x) - - return sign * y - - @register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar]) def rsub(a, b): if isinstance(b, numbers.Number): @@ -234,6 +416,17 @@ def all_dim(input, dim, keeepdim=False): return torch.logical_not(torch.any(torch.logical_not(input), dim, keeepdim)) +# NB: this decomposition is not stride accurate, do not put it in the main +# library +@register_decomposition(aten.copy) +def copy(self, src, non_blocking=False): + intermediate = src.to(self, non_blocking) + if self.size() != intermediate.size(): + return aten.expand_copy.default(intermediate, self.size()) + else: + return intermediate + + @register_decomposition(aten.hardswish_) def hardswish_(x): return x.copy_(aten.hardswish(x)) @@ -259,11 +452,6 @@ def masked_fill_(x, mask, value): return x.copy_(aten.masked_fill(x, mask, value)) -@register_decomposition([aten.log1p]) -def log1p(x): - return torch.log(x + 1) - - @register_decomposition([aten.baddbmm]) def baddbmm(self, batch1, batch2, beta=1, alpha=1): result = torch.bmm(batch1, batch2) @@ -302,6 +490,12 @@ def bernoulli(self, *, generator=None): return torch.rand_like(self, dtype=torch.float32) < self +@register_decomposition([aten.bernoulli.p]) +def bernoulli_p(self, p=0.5, *, generator=None): + assert generator is None + return torch.rand_like(self, dtype=torch.float32) < p + + """ Some decomps result in differences from eager related to randomness. We put these decomps in a separate table `extra_random_decomps` to allow @@ -309,13 +503,13 @@ def bernoulli(self, *, generator=None): """ extra_random_decomps = get_decompositions([aten.native_dropout]) register_extra_random_decomp = functools.partial( - decomp.register_decomposition, registry=extra_random_decomps, disable_meta=True + decomp.register_decomposition, registry=extra_random_decomps ) @register_extra_random_decomp([aten.bernoulli_]) def bernoulli_(self, p=0.5): - return self.copy_(torch.rand_like(self) < p) + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) @functools.lru_cache(None) diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 6eee943b60074..5434d7addfa9a 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -7,8 +7,16 @@ import sympy +from . import config from .codegen.common import index_prevent_reordering -from .utils import sympy_product, sympy_str, sympy_subs, sympy_symbol, VarRanges +from .utils import ( + get_dtype_size, + sympy_product, + sympy_str, + sympy_subs, + sympy_symbol, + VarRanges, +) from .virtualized import V log = logging.getLogger(__name__) @@ -68,11 +76,18 @@ def rename(self, renames: Dict[str, str]) -> "MemoryDep": return MemoryDep(renames[self.name], self.index, self.size) return self - def numel_hint(self): + def numbytes_hint(self): vars = set(self.index.free_symbols) + size_vars_used = [] + for var in vars: + if var.name.startswith(canonicalization_prefix()): + # Sometimes with indirect indexing we have very weird symbol names + assert " " not in var.name + size_vars_used.append(int(var.name[len(canonicalization_prefix()) :])) + return V.graph.sizevars.size_hint( - sympy_product([s for s in self.size if s in vars]) - ) + sympy_product([self.size[i] for i in size_vars_used]) + ) * get_dtype_size(V.graph.get_dtype(self.name)) def is_contiguous(self) -> bool: return isinstance(self.index, (sympy.Symbol, sympy.Integer)) @@ -87,8 +102,21 @@ def rename(self, renames: Dict[str, str]) -> "StarDep": return StarDep(renames[self.name]) return self - def numel_hint(self): - return 1 + def numbytes_hint(self): + from .ir import MultiOutputLayout + + if self.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[self.name] + elif self.name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[self.name] + else: + return 1 + if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout): + # NB: Too annoying to acquire, should only be used for instrumentation + return 1 + return V.graph.sizevars.size_hint( + sympy_product(buf.get_size()) + ) * get_dtype_size(buf.get_dtype()) def is_contiguous(self) -> bool: return False @@ -146,6 +174,15 @@ def __init__(self, var_ranges: VarRanges, normalize: bool): self._var_ranges: VarRanges = var_ranges self._normalize: bool = normalize + # Truncate the expr str by a threshold to prevent it's too long + # and cause process hanging. The result is not used. + # https://github.com/pytorch/torchdynamo/issues/1352 + @staticmethod + def truncate_expr(expr): + if len(expr) > config.realize_bytes_threshold: + expr = f"{expr[:config.realize_bytes_threshold]}..." + return expr + def canonicalize( self, index: sympy.Expr ) -> Tuple[sympy.Expr, Tuple[sympy.Expr, ...]]: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2a1619a822451..cdc40a114840f 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1,6 +1,8 @@ import logging import operator import os +import re +import sys import time import sympy @@ -11,17 +13,24 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv from torch.utils._mode_utils import no_dispatch +from .._dynamo import config as dynamo_config + from . import config, ir -from .codegen.wrapper import WrapperCodeGen +from .codegen.wrapper import CppWrapperCodeGen, WrapperCodeGen from .exc import ( LoweringException, MissingOperatorWithDecomp, MissingOperatorWithoutDecomp, ) -from .ir import Constant, FixedLayout, InputBuffer, TensorBox -from .lowering import lowerings, make_fallback, needs_realized_inputs -from .sizevars import SizeVarAllocator -from .utils import dynamo_logging, dynamo_utils +from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox +from .lowering import ( + layout_constraints, + lowerings, + make_fallback, + needs_realized_inputs, +) +from .sizevars import CppSizeVarAllocator, SizeVarAllocator +from .utils import dynamo_utils, gather_origins, get_dtype_size, sympy_product from .virtualized import V log = logging.getLogger(__name__) @@ -40,11 +49,9 @@ def symbolic_sizes_strides(self, ex: torch.Tensor): else: size, stride = self._shape_env.create_symbolic_sizes_strides(ex) - size = [ - i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in size - ] + size = [i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in size] stride = [ - i.get_pyobj().expr if isinstance(i, torch.SymIntNode) else i for i in stride + i.get_pyobj().expr if isinstance(i, torch.SymInt) else i for i in stride ] return size, stride @@ -57,7 +64,11 @@ def static_sizes_strides(self, ex: torch.Tensor): return size, stride def __init__( - self, gm: torch.fx.GraphModule, shape_env=None, num_static_inputs=None + self, + gm: torch.fx.GraphModule, + shape_env=None, + num_static_inputs=None, + graph_id=None, ): super().__init__(gm) if shape_env is None: @@ -84,6 +95,9 @@ def __init__( self.randomness_seeds = [] self.name_to_buffer = {} self.creation_time = time.time() + self.name = "GraphLowering" + self._can_use_cpp_wrapper = config.cpp_wrapper + self.graph_id = graph_id def get_dtype(self, buffer_name): if buffer_name in self.constants: @@ -92,6 +106,9 @@ def get_dtype(self, buffer_name): return self.name_to_buffer[buffer_name].get_dtype() if buffer_name in self.graph_inputs: return self.graph_inputs[buffer_name].get_dtype() + m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) raise KeyError(f"could not find {buffer_name}") def random_seed_buffer(self, device: torch.device): @@ -130,7 +147,21 @@ def increment_randomness_offset(self, numel): def run(self, *args): return super().run(*args) + def disable_cpp_wrapper(self, cond): + self._can_use_cpp_wrapper = False + log.debug("Set _can_use_cpp_wrapper to False due to %s", cond) + + def check_buffer_for_cpp_wrapper(self, buffer: ir.ComputedBuffer): + if isinstance(buffer, ir.ExternKernel): + self.disable_cpp_wrapper("ExternKernel") + if isinstance(buffer, ir.ComputedBuffer): + if buffer.data.get_reduction_type(): + self.disable_cpp_wrapper("Reduction") + def register_buffer(self, buffer: ir.ComputedBuffer): + if config.cpp_wrapper: + self.check_buffer_for_cpp_wrapper(buffer) + name = f"buf{len(self.buffers)}" self.buffers.append(buffer) self.name_to_buffer[name] = buffer @@ -214,34 +245,35 @@ def placeholder(self, target, args, kwargs): return tensor def call_function(self, target, args, kwargs): - if target is operator.getitem and isinstance(args[0], (list, tuple)): - return super().call_function(target, args, kwargs) - - if target not in lowerings: - if config.implicit_fallbacks: - error = ( - MissingOperatorWithDecomp - if get_decompositions([target]) - else MissingOperatorWithoutDecomp - ) - log.warning( - "Creating implicit fallback for:\n%s", - error.operator_str(target, args, kwargs), - ) - make_fallback(target) - elif get_decompositions([target]): - # There isn't a good way to dynamically patch this in - # since AOT Autograd already ran. The error message tells - # the user how to fix it. - raise MissingOperatorWithDecomp(target, args, kwargs) - else: - raise MissingOperatorWithoutDecomp(target, args, kwargs) + with ir.IRNode.current_origins(gather_origins(args, kwargs)): + if target is operator.getitem and isinstance(args[0], (list, tuple)): + return super().call_function(target, args, kwargs) + + if target not in lowerings: + if config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.warning( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target) + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) - try: - out = lowerings[target](*args, **kwargs) - return out - except Exception as e: - raise LoweringException(e, target, args, kwargs) from e + try: + out = lowerings[target](*args, **kwargs) + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs) from e def get_attr(self, target, args, kwargs): # this is a constant @@ -268,7 +300,15 @@ def output(self, target, args, kwargs): assert isinstance(result, (tuple, list)), type(result) assert all( isinstance( - x, (TensorBox, ir.Constant, type(None), ir.ConstantBuffer, sympy.Expr) + x, + ( + TensorBox, + ir.Constant, + type(None), + ir.ConstantBuffer, + sympy.Expr, + int, + ), ) for x in result ), result @@ -298,25 +338,150 @@ def finalize(self): def run_node(self, n: torch.fx.Node): with ir.IRNode.current_origins({n}): - result = super().run_node(n) + if n.op == "call_function" and n.target in layout_constraints: + args, kwargs = self.fetch_args_kwargs_from_env(n) + args, kwargs = layout_constraints[n.target](n, *args, **kwargs) + result = self.call_function(n.target, args, kwargs) + else: + result = super().run_node(n) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. num_users = len(set(n.users)) if num_users > 1 and isinstance(result, TensorBox): for user in n.users: - if user.target in needs_realized_inputs or user.op == "output": + if user.target in needs_realized_inputs: result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometime result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. + if user.target in ( + torch.ops.aten.convolution.default, + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + ): + result = ir.ExternKernel.require_stride_order( + result, ir.get_stride_order(n.meta["val"].stride()) + ) + if user.op == "output": + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() # TODO(jansel): introduce a store vs inline choice result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches meach with small number of memory + # reads, but they converge to a user. + result.realize_hint() return result + def check_platform(self): + if sys.platform != "linux": + self.disable_cpp_wrapper("platform not linux") + + def check_profiler_mark_wrapper_call(self): + if config.profiler_mark_wrapper_call: + self.disable_cpp_wrapper("profiler not supported") + + def check_device_for_cpp_buffer(self): + if len(self.device_types) == 1: + device = self.device_types.pop() + if device == "cpu": + return + self.disable_cpp_wrapper("device not CPU") + + def check_input_for_cpp_buffer(self): + for _, value in self.graph_inputs.items(): + if value.get_dtype() != torch.float32: + self.disable_cpp_wrapper("inputs not FP32") + + def check_output_for_cpp_buffer(self): + for item in self.graph_outputs: + if isinstance(item, ir.NoneAsConstantBuffer): + self.disable_cpp_wrapper("NoneAsConstantBuffer") + + def check_constant_for_cpp_buffer(self): + if self.constants: + self.disable_cpp_wrapper("Constants") + + def check_cpp_wrapper(self): + self.check_platform() + self.check_profiler_mark_wrapper_call() + self.check_device_for_cpp_buffer() + self.check_input_for_cpp_buffer() + self.check_output_for_cpp_buffer() + self.check_constant_for_cpp_buffer() + + def init_wrapper_code(self): + if config.cpp_wrapper: + self.check_cpp_wrapper() + if self._can_use_cpp_wrapper: + self.sizevars = CppSizeVarAllocator(self._shape_env) + self.wrapper_code = CppWrapperCodeGen() + return + self.wrapper_code = WrapperCodeGen() + return + def codegen(self): from .scheduler import Scheduler - self.wrapper_code = WrapperCodeGen() + self.init_wrapper_code() + self.scheduler = Scheduler(self.buffers) self.scheduler.codegen() return self.wrapper_code.generate() + def count_bytes(self): + from .scheduler import FusedSchedulerNode, NopKernelSchedulerNode, Scheduler + + scheduler = Scheduler(self.buffers) + + def get_read_write_buffers_sizes(node): + if isinstance(node, NopKernelSchedulerNode): + return 0 + reads = set(dep.name for dep in node.read_writes.reads) + writes = set(dep.name for dep in node.read_writes.writes) + + def is_materialized(buf): + buf_uses = set( + [user.node for user in scheduler.name_to_node[buf].users] + ) + return len(buf_uses - set(node.snodes)) > 0 + + if isinstance(node, FusedSchedulerNode): + writes = set([dep for dep in writes if is_materialized(dep)]) + node_bytes = 0 + for buf in reads | writes: + if buf in self.name_to_buffer: + buf = self.name_to_buffer[buf] + elif buf in self.graph_inputs: + buf = self.graph_inputs[buf] + else: + continue + + node_bytes += V.graph.sizevars.size_hint( + sympy_product(buf.get_size()) + ) * get_dtype_size(buf.get_dtype()) + return node_bytes + + total_bytes = 0 + node_counts = [] + for node in scheduler.nodes: + num_bytes = get_read_write_buffers_sizes(node) + node_counts.append((node, num_bytes // 4)) + total_bytes += num_bytes + return total_bytes, node_counts + @dynamo_utils.dynamo_timed def compile_to_module(self): from .codecache import PyCodeCache @@ -329,7 +494,8 @@ def compile_to_module(self): for name, value in self.constants.items(): setattr(mod, name, value) - log.log(dynamo_logging.CODE, "Output code: %s", mod.__file__) + if dynamo_config.output_code: + log.info("Output code: %s", mod.__file__) V.debug.output_code(mod.__file__) V.debug.rename(os.path.splitext(mod.__file__)[0] + ".debug") return mod diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 13cf5d771a0c8..e7a50f58c0b14 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,8 +6,10 @@ import re import textwrap from collections import OrderedDict +from contextlib import nullcontext from enum import Enum from functools import partial +from inspect import signature from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union from unittest.mock import patch @@ -17,17 +19,31 @@ import torch.fx import torch.utils._pytree as pytree -from torch._prims_common import is_boolean_dtype, is_float_dtype +from torch._prims_common import ( + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, +) +from torch._subclasses.fake_tensor import FakeTensorMode from . import config, dependencies from .codegen.common import index_prevent_reordering from .cuda_properties import get_device_properties from .dependencies import extract_read_writes, var_builder -from .utils import cache_on_self, sympy_dot, sympy_product, sympy_subs, sympy_symbol +from .utils import ( + argsort, + cache_on_self, + sympy_dot, + sympy_product, + sympy_subs, + sympy_symbol, +) from .virtualized import ops, V log = logging.getLogger(__name__) indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten def inverse_reorder(order): @@ -66,6 +82,17 @@ def stride_order2fill_order(order): return fill_order +def get_stride_order(seq): + """ + Convert strides to stride order + """ + sorted_idx = argsort(seq) + out = [None for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + def reads_from_conv(buf, var_ranges): """ return: @@ -101,6 +128,25 @@ def reads_from_conv(buf, var_ranges): return False, None +def ir_node_to_tensor(x, guard_shape=True): + shape_fn = ( + V.graph.sizevars.guard_static_shape + if guard_shape + else V.graph.sizevars.size_hint + ) + size = [shape_fn(s) for s in x.get_size()] + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] + else: + stride = make_contiguous_strides_for(size) + dtype = x.get_dtype() + device = x.get_device() + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + def layout_priority_idx(reads_bufs, memory_addrs, var_ranges): """ if reads from conv that needs to use specific layout @@ -144,10 +190,24 @@ def eval(cls, base, divisor, modulus): if isinstance(base, sympy.Add): new_terms = [] + all_positive = True for term in base.args: if sympy.gcd(term, modulus * divisor) != modulus * divisor: - new_terms.append(term) - if len(new_terms) != len(base.args): + if (isinstance(term, sympy.Integer) and term < 0) or ( + isinstance(term, sympy.Mul) + and isinstance(term.args[0], sympy.Integer) + and term.args[0] < 0 + ): + # workaround for https://github.com/openai/triton/issues/619, + # if there are negative terms, // produces wrong result + # TODO if https://github.com/openai/triton/issues/619 is fixed + # this optimization would become valid + all_positive = False + break + else: + new_terms.append(term) + + if len(new_terms) != len(base.args) and all_positive: return ModularIndexing(sum(new_terms), divisor, modulus) if isinstance(base, IndexingDiv): @@ -302,7 +362,7 @@ def inner_fn_str(self): with V.set_ops_handler(V.MockHandler()), patch.object( FlexibleLayout, "allow_indexing", True ): - return self.inner_fn(self._index(self.ranges)) + return str(self.inner_fn(self._index(self.ranges))) except Exception as e: return f"inner_fn(): {e}" @@ -379,6 +439,11 @@ class ReductionHint(Enum): DEFAULT = 3 +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + @dataclasses.dataclass class Reduction(Loops): reduction_ranges: List[Expr] @@ -419,8 +484,11 @@ def inner_fn_str(self): with V.set_ops_handler(V.MockHandler()), patch.object( FlexibleLayout, "allow_indexing", True ): - return self.inner_fn( - self._index(self.ranges), self._index(self.reduction_ranges, "r") + return str( + self.inner_fn( + self._index(self.ranges), + self._index(self.reduction_ranges, "r"), + ) ) except Exception as e: return f"inner_fn(): {e}" @@ -680,12 +748,47 @@ def create( reduction_hint: ReductionHint = ReductionHint.DEFAULT, ): reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: + + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + if reduction_numel == 1: # this reduction is actually a pointwise op if reduction_type in ("argmin", "argmax"): def fn(index): - assert len(index) <= 1 return 0 else: @@ -909,11 +1012,11 @@ def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=No x.data.decide_layout() return x, x.data.layout if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretedView either, so dont pass along those arguments buffer, _ = as_storage_and_layout( x.data, freeze=freeze, - want_contiguous=want_contiguous, - stride_order=stride_order, ) return buffer, x.layout raise NotImplementedError @@ -948,6 +1051,9 @@ def get_name(self): def mark_reuse(self, users): return self.data.mark_reuse(users) + def has_exceeded_max_reads(self): + return self.data.has_exceeded_max_reads() + def realize(self): return self.data.realize() @@ -1310,6 +1416,10 @@ class ReinterpretView(BaseView): layout: "Layout" + def __post_init__(self): + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + def __str__(self): return self.str_helper( [ @@ -1422,6 +1532,9 @@ def get_device(self): def mark_reuse(self, users): pass + def has_exceeded_max_reads(self): + return False + def get_reads(self): return () @@ -1457,11 +1570,23 @@ def loader(index): @dataclasses.dataclass class Layout(IRNode): - device: torch.device - dtype: torch.dtype - size: List[Expr] - stride: List[Expr] - offset: Expr = Integer(0) + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: List[Expr], + offset: Expr = Integer(0), + ): + self.device = device + self.dtype = dtype + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride def __str__(self): offset = "" @@ -1677,6 +1802,15 @@ def __init__(self, target: IRNode): ) self.target = target + @Layout.stride.getter + def stride(self): + return self.real_layout().stride + + def real_layout(self): + if isinstance(self.target, MutationLayout): + return self.target.real_layout() + return self.target.data.layout + @classmethod def realize_into(cls, src, dst): dst.realize() @@ -2225,6 +2359,54 @@ def copy_input(x): pw.realize() return pw + @classmethod + def process_kernel(cls, kernel, *args, **kwargs): + binded_args = signature(kernel).bind(*args, **kwargs).arguments + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + tensor_args = [] + non_tensor_args = [] + for arg in args_flat: + is_arg_tensor.append(isinstance(arg, IRNode)) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + non_tensor_args.append(arg) + + def unflatten_args(new_tensor_args, new_non_tensor_args): + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + result = pytree.tree_unflatten(result, args_spec) + return result.get("args", []), result.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # We don't have generic shape formulas, so just burn in the + # shapes and run an example input. + # TODO(jansel): replace this with dynamic shape formulas + example_args = [] + + for x in tensor_args: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + return example_output, tensor_args, non_tensor_args, unflatten_args + @classmethod def convert_to_reinterpret_view(cls, x): """ @@ -2271,7 +2453,7 @@ def convert_to_reinterpret_view(cls, x): def realize_input(cls, x): if x is None: return NoneAsConstantBuffer() - if isinstance(x, sympy.Expr): + if isinstance(x, (sympy.Expr, int)): return ShapeAsConstantBuffer(x) if isinstance(x, Constant): return V.graph.add_tensor_constant( @@ -2300,43 +2482,53 @@ def realize_input(cls, x): @classmethod def require_stride1(cls, x): - if len(x.get_stride()) == 0: - return x - for stride in x.get_stride(): - if stride == 1: + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: return x + for stride in x.get_stride(): + if stride == 1: + return x return cls.copy_input(x) @classmethod - def require_contiguous(cls, x): - if is_contiguous_storage_and_layout(x): - as_contiguous_storage_and_layout(x, freeze=True) + def require_stride_order(cls, x, order): + if x.get_numel() == 0: # Layout doesn't matter return x - x = cls.copy_input(x) - assert is_contiguous_storage_and_layout(x) - as_contiguous_storage_and_layout(x, freeze=True) - return x - @classmethod - def require_stride_order(cls, x, order): # require x to have the layout as strided_ordered as order - if isinstance( - x.get_layout(), FlexibleLayout - ) and is_stride_order_storage_and_layout(x, order): - # fix flexiblelayout to be FixedLayout with stride_order - as_storage_and_layout( - x, freeze=True, want_contiguous=False, stride_order=order - ) - return x - elif isinstance(x.get_layout(), FixedLayout) and x.layout.is_stride_ordered( - order - ): + if is_storage_and_layout(x): + if isinstance(x.get_layout(), FlexibleLayout): + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, freeze=True, want_contiguous=False, stride_order=order + ) + return x + elif isinstance( + x.get_layout(), FixedLayout + ) and x.get_layout().is_stride_ordered(order): + return x + elif isinstance(x.get_layout(), MutationLayout): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayout's real layout shouldn't be FlexibleLayout" + ) + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): return x x = cls.copy_input(x) as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) assert is_stride_order_storage_and_layout(x, order) return x + @classmethod + def require_contiguous(cls, x): + return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) + def apply_constraint(self): pass @@ -2752,43 +2944,6 @@ def get_reads(self): return () -class AdaptiveAvgPool2d(ExternKernelAlloc): - kernel = "aten._adaptive_avg_pool2d" - - @classmethod - def create(cls, x, target_size): - # x = cls.require_stride1(cls.realize_input(x)) - x = cls.realize_input(x) - output_size = [ - *x.get_size()[: -len(target_size)], - *map(sympy.Integer, target_size), - ] - # contigouse stride order - stride_order = list(reversed(range(len(output_size)))) - return cls( - FlexibleLayout( - x.get_device(), - x.get_dtype(), - output_size, - # TODO(jansel): fix channels last case - # FlexibleLayout.contiguous_strides(output_size), - stride_order, - ), - (x,), - (tuple(target_size),), - ) - - def apply_constraint(self): - x = self.inputs[0] - if isinstance(x.get_layout(), FixedLayout): - # fix self's layout to be the same order as x - self.freeze_layout_with_same_order(x.get_layout().stride) - else: - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - - @dataclasses.dataclass class FallbackKernel(ExternKernelAlloc): def __init__( @@ -2824,104 +2979,73 @@ class Shim: def __repr__(self): return self.ref - tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] - constant_args = [Shim(repr(x)) for x in self.constant_args] - def gen_kwarg(k, v): return f"{k}={repr(v)}" - kwargs = list(gen_kwarg(k, v) for k, v in self.kwargs.items()) - - return list(map(repr, self.unflatten_args(tensor_args, constant_args))) + kwargs + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + constant_args = [Shim(repr(x)) for x in self.constant_args] + args, kwargs = self.unflatten_args(tensor_args, constant_args) + return list(map(repr, args)) + list(gen_kwarg(k, v) for k, v in kwargs.items()) @classmethod def create(cls, kernel, *args, **kwargs): - args_flat, args_spec = pytree.tree_flatten(args) - - is_arg_tensor = [] - tensor_args = [] - non_tensor_args = [] - for arg in args_flat: - is_arg_tensor.append(isinstance(arg, IRNode)) - if is_arg_tensor[-1]: - tensor_args.append(arg) - else: - non_tensor_args.append(arg) - - def unflatten_args(new_tensor_args, new_non_tensor_args): - new_args = [] - it_tensors = iter(new_tensor_args) - it_non_tensors = iter(new_non_tensor_args) - for is_tensor in is_arg_tensor: - if is_tensor: - new_args.append(next(it_tensors)) - else: - new_args.append(next(it_non_tensors)) - return pytree.tree_unflatten(new_args, args_spec) - - tensor_args = [ - cls.require_contiguous(cls.realize_input(x)) for x in tensor_args - ] - - # We don't have generic shape formulas, so just burn in the - # shapes and run an example input. - # TODO(jansel): replace this with dynamic shape formulas - example_args = [] - for x in tensor_args: - size = [V.graph.sizevars.guard_static_shape(s) for s in x.get_size()] - stride = [ - V.graph.sizevars.guard_static_shape(s) for s in x.get_layout().stride - ] - dtype = x.get_dtype() - device = x.get_device() - arg = torch.empty_strided( - size=size, stride=stride, dtype=dtype, device=device - ).zero_() - example_args.append(arg) - - example_output = kernel( - *unflatten_args(example_args, non_tensor_args), **kwargs - ) - - if isinstance(example_output, (list, tuple)): - packed = FallbackKernel( - MultiOutputLayout(tensor_args[0].get_device()), - kernel, + fake_incorrect_kernels = ( + aten._fft_r2c.default, + aten._fft_r2c.out, + aten._fft_c2r.default, + aten._fft_c2c.default, + aten._fft_c2c.out, + aten._linalg_svd.default, + aten._linalg_svd.U, + aten._fused_moving_avg_obs_fq_helper_functional, + ) + context = ( + FakeTensorMode if kernel not in fake_incorrect_kernels else nullcontext + ) + with context(): + ( + example_output, tensor_args, non_tensor_args, unflatten_args, - ) - return [ - ( - MultiOutput( - FixedLayout( - example_output[i].device, - example_output[i].dtype, - [sympy.Integer(s) for s in example_output[i].size()], - [sympy.Integer(s) for s in example_output[i].stride()], - ), - packed, - i, - ) - if example_output[i] is not None - else None + ) = cls.process_kernel(kernel, *args, **kwargs) + + assert tensor_args or isinstance( + example_output, torch.Tensor + ), "Not sure where to find device info" + packed = FallbackKernel( + MultiOutputLayout( + tensor_args[0].get_device() if tensor_args else example_output.device + ), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + kwargs, + ) + + def generate_output(output, index=""): + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], f"{index}[{i}]") + for i in range(len(output)) ) - for i in range(len(example_output)) - ] - else: - return FallbackKernel( - FixedLayout( - example_output.device, - example_output.dtype, - [sympy.Integer(s) for s in example_output.size()], - [sympy.Integer(s) for s in example_output.stride()], - ), - kernel, - tensor_args, - non_tensor_args, - unflatten_args, - kwargs, - ) + elif isinstance(output, torch.Tensor): + return MultiOutput( + FixedLayout( + output.device, + output.dtype, + [sympy.Integer(s) for s in output.size()], + [sympy.Integer(s) for s in output.stride()], + ), + packed, + index, + ) + else: + assert output is None, "FallbackKernel output type is not supported" + return None + + return generate_output(example_output) def apply_constraint(self): return super().apply_constraint() @@ -2935,11 +3059,11 @@ class MultiOutputLayout(IRNode): class MultiOutput(ExternKernel): def codegen(self, wrapper): wrapper.writeline( - f"{self.get_name()} = {self.inputs[0].get_name()}[{self.index}]" + f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}" ) self.codegen_size_asserts(wrapper) - def __init__(self, layout, input, index): + def __init__(self, layout, input, index: str): super().__init__(None, layout, [input], ()) self.name = V.graph.register_buffer(self) self.index = index @@ -2987,8 +3111,32 @@ def create( output_padding_: List[int], groups: int, ): - x = cls.require_stride1(cls.realize_input(x)) - weight = cls.require_stride1(cls.realize_input(weight)) + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride_, + padding_, + dilation_, + transposed, + output_padding_, + groups, + ) + req_stride_order = get_stride_order(output.stride()) + + if config.triton.convolution == "aten": + weight = cls.require_stride_order(weight, req_stride_order) + x = cls.require_stride_order(x, req_stride_order) + else: + x = cls.require_stride1(cls.realize_input(x)) + weight = cls.require_stride1(cls.realize_input(weight)) + stride = tuple(stride_) padding = tuple(padding_) dilation = tuple(dilation_) @@ -2996,70 +3144,13 @@ def create( output_padding = tuple(output_padding_) assert isinstance(groups, int) + output_size = output.shape + weight_shape = [ sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size() ] - - out_channels, in_channels1, *kernel_size = weight_shape - in_channels1 = in_channels1 * groups - if transposed: - out_channels, in_channels1 = in_channels1, out_channels - - if bias is not None: - bias = cls.require_stride1(cls.realize_input(bias)) - (bias_shape,) = [ - sympy.Integer(V.graph.sizevars.guard_static_shape(s)) - for s in bias.get_size() - ] - assert bias_shape == out_channels, f"{bias_shape} == {out_channels}" - - if len(x.get_size()) == 1 + len(kernel_size): - in_channels2, *input_size = x.get_size() - in_channels_stride, *_ = x.get_stride() - output_size = [] - else: - assert len(x.get_size()) == 2 + len(kernel_size) - batch, in_channels2, *input_size = x.get_size() - _, in_channels_stride, *_ = x.get_stride() - output_size = [batch] - - V.graph.sizevars.guard_equals(in_channels1, in_channels2) - - output_size.append(out_channels) - - assert ( - len(stride) - == len(padding) - == len(dilation) - == len(output_padding) - == len(kernel_size) - == len(input_size) - ) - for i in range(len(stride)): - if transposed: - output_size.append( - (input_size[i] - 1) * stride[i] - - 2 * padding[i] - + dilation[i] * (kernel_size[i] - 1) - + output_padding[i] - + 1 - ) - else: - output_size.append( - IndexingDiv( - input_size[i] - + 2 * padding[i] - - dilation[i] * (kernel_size[i] - 1) - - 1 - + stride[i], - stride[i], - ) - + 2 * output_padding[i] - ) - output_size[-1] = sympy.Integer( - V.graph.sizevars.guard_static_shape(output_size[-1]) - ) + _, _, *kernel_size = weight_shape # choose runtime kernel config_conv = config.triton.convolution @@ -3097,6 +3188,7 @@ def create( # for conv2d or conv3d, prefer channels last format if kernel == "triton_ops.conv": output_layout_str = "torch.channels_last" + elif config.tune_layout and len(x.get_size()) == 4: from .codegen.autotuner import tuned_conv_layout @@ -3113,47 +3205,32 @@ def create( x.get_device(), x.get_dtype(), ) - else: - output_layout_str = "torch.contiguous_format" - # If x or weight have one channels_last(2d or 3d) format, it will call channels_last path, - # which align with aten.convolutuion path(cpu only support 2d case now). - # TODO: after cpu 3d convolution support channels_last path, the size check can be removed. - - # CUDA channels_last path depend on cudnn version, see - # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvUtils.h. - valid_cudnn = False - if ( - torch.backends.cudnn.is_available() - and torch.backends.cudnn.version() >= 7603 - ): - valid_cudnn = True - valid_device = x.get_device().type == "cpu" or ( - x.get_device().type == "cuda" and valid_cudnn + else: + output_layout_str = ( + "torch.contiguous_format" + if output.is_contiguous() + else "torch.channels_last" ) - if ( - valid_device - and len(x.get_size()) == 4 - and ( - x.get_layout().is_channels_last_stride_ordered() - or weight.get_layout().is_channels_last_stride_ordered() - ) - ): - output_layout_str = "torch.channels_last" if output_layout_str == "torch.channels_last": stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1))) if len(stride_order) < len(output_size): # add batch dim if it exists stride_order = [len(stride_order)] + stride_order + strides = make_channels_last_strides_for(output_size) else: stride_order = list(reversed(range(len(output_size)))) + strides = make_contiguous_strides_for(output_size) - output_layout = FlexibleLayout( + if config.triton.convolution != "aten": + x = cls.require_stride_order(x, stride_order) + + output_layout = FixedLayout( x.get_device(), x.get_dtype(), output_size, - stride_order, + strides, ) if bias is not None: @@ -3173,13 +3250,6 @@ def create( kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - def map_args(self): # x, w, bias in_args = [x.codegen_reference() for x in self.inputs] @@ -3279,6 +3349,394 @@ def get_template_tiling(self): ) +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU device since conv post-op fusion kernel is only + supported on CPU right now. + """ + stride = tuple(stride_) + padding = tuple(padding_) + dilation = tuple(dilation_) + assert isinstance(groups, int) + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride, + padding, + dilation, + False, + [0, 0], + groups, + ) + output_size = output.size() + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + output_stride = make_channels_last_strides_for(output_size) + + x = cls.require_stride_order(x, req_stride_order) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] + + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output.size(), + output_stride, + ) + constant_args = [padding, stride, dilation, groups] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order + + +class ConvolutionUnary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._convolution_pointwise" + + def __init__( + self, + layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._convolution_pointwise", + ): + super().__init__(layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + attr, + scalars, + algorithm, + ): + kernel = "torch.ops.mkldnn._convolution_pointwise" + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [attr, scalars, algorithm] + return ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + +class ConvolutionBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._convolution_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._convolution_pointwise.binary", + ): + super().__init__(layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List], + unary_algorithm: Optional[str], + ): + kernel = "torch.ops.mkldnn._convolution_pointwise.binary" + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ] + return ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._convolution_pointwise_.binary", + ): + super().__init__(kernel_layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + def get_mutation_names(self): + assert isinstance(self.layout, MutationLayout) + return (self.layout.target.get_name(),) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List], + unary_algorithm: Optional[str], + ): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + (inputs, constant_args, _, _) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.realize_input(other) + V.graph.realize_users_of(other.get_name()) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ] + return ConvolutionBinaryInplace( + kernel_layout=MutationLayout(inputs[1]), + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + +class MKLPackedLinear(ExternKernelAlloc): + kernel = "torch.ops.mkl._mkl_linear" + + def __init__( + self, + layout, + inputs, + constant_args=(), + kernel="torch.ops.mkl._mkl_linear", + ): + super().__init__(layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + @classmethod + def create(cls, x, packed_w, orig_w, bias, batch_size): + kernel = "torch.ops.mkl._mkl_linear" + + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(orig_w, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.linear( + x_fake, + weight_fake, + bias_fake, + ) + output_size = output.size() + req_stride_order = list(reversed(range(len(output_size)))) + output_stride = output.stride() + x = cls.require_stride_order(x, req_stride_order) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + + return MKLPackedLinear( + layout=FixedLayout( + x.get_device(), x.get_dtype(), output_size, output_stride + ), + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + +class LinearUnary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise" + + def __init__( + self, + layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._linear_pointwise", + ): + super().__init__(layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + @classmethod + def create(cls, x, w, b, attr, scalars, algorithm): + kernel = "torch.ops.mkldnn._linear_pointwise" + x = cls.require_stride1(cls.realize_input(x)) + w = cls.require_stride1(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, w] + constant_args = [attr, scalars, algorithm] + if b is not None: + b = cls.require_stride1(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, b) + + return LinearUnary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._linear_pointwise.binary", + ): + super().__init__(layout, inputs, constant_args) + self.kernel = kernel + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + @classmethod + def create(cls, x, y, w, b, attr): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + x = cls.require_stride1(cls.realize_input(x)) + y = cls.require_stride1(cls.realize_input(y)) + w = cls.require_stride1(cls.realize_input(w)) + + *m, ic = x.get_size() + oc, ic = w.get_size() + + inputs = [x, y, w] + constant_args = [attr] + if b is not None: + b = cls.require_stride1(cls.realize_input(b)) + inputs.append(b) + else: + constant_args.insert(0, b) + + return LinearBinary( + layout=FlexibleLayout( + device=x.get_device(), + dtype=x.get_dtype(), + size=list(m) + [oc], + ), + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + def apply_constraint(self): + pass + + @dataclasses.dataclass class MutableBox(IRNode): """ @@ -3341,6 +3799,7 @@ def realize(self): data=self.data, ) self.data.name = V.graph.register_buffer(self.data) + self.data.origins = self.origins return self.data.name def realize_hint(self): @@ -3350,6 +3809,12 @@ def realize_hint(self): if isinstance(self.data, (Pointwise, Reduction)) and self.num_reads() > 1: self.realize() + def has_exceeded_max_reads(self): + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or len(self.inner_fn_str()) > config.realize_bytes_threshold + ) + def mark_reuse(self, users): """ A heuristic to decide if we should realize a tensor @@ -3499,6 +3964,8 @@ def add_index(expr, category, buf_name=None): ) class CaptureIndexing(V.WrapperHandler): + self.name = "CaptureIndexing" + def load(self, name: str, index: sympy.Expr): index = add_index(index, "reads", name) return self._inner.load(name, index) @@ -3581,6 +4048,7 @@ def __init__(self): self.garbage_collect_values = False self.env = {} self.fetch_attr = submodules.__getitem__ + self.name = V.get_ops_handler().name return InterpreterShim().run(V.get_ops_handler()) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 49a136b440ed2..45f1772e197ad 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -9,6 +9,7 @@ import torch import torch.fx +import torch.utils._pytree as pytree from torch._prims_common import ( elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, @@ -36,6 +37,7 @@ log = logging.getLogger(__name__) lowerings = {} +layout_constraints = {} fallbacks = set() aten = torch.ops.aten prims = torch.ops.prims @@ -51,6 +53,14 @@ def add_needs_realized_inputs(fn): needs_realized_inputs.add(getattr(fn, overload)) +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + layout_constraints[getattr(fn, overload)] = constraint + else: + layout_constraints[fn] = constraint + + add_needs_realized_inputs( [ aten.as_strided, @@ -164,8 +174,15 @@ def wrapped(*args, **kwargs): args = args[0] # Only look at args that are Tensors indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] - # kwargs tensors not supported yet - assert not any(isinstance(x, TensorBox) for x in kwargs.values()) + + # explicitly assert for "out=" ops for better error messages + assert not any( + x == "out" for x in kwargs.keys() + ), "out= ops aren't yet supported" + # kwargs tensors not supported yet unless it's a fallback op + assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all( + fn in fallbacks for fn in aten_fn + ) if (type_promotion_kind or convert_input_to_bool) and indices: if convert_input_to_bool: @@ -386,33 +403,6 @@ def _to_copy( return x -@register_lowering(aten.to) -def to( - x, - device_or_dtype=None, - non_blocking=False, - copy=False, - memory_format=None, - device=None, - dtype=None, - layout=None, -): - assert not memory_format, "TODO" - assert layout in (None, torch.strided) - if isinstance(device_or_dtype, torch.dtype): - return to_dtype(x, device_or_dtype) - elif isinstance(device_or_dtype, torch.device): - return to_device(x, device_or_dtype) - else: - assert device_or_dtype is None, device_or_dtype - - if device is not None: - x = to_device(x, device) - if dtype is not None: - x = to_dtype(x, dtype) - return x - - def ops_wrapper(name): assert isinstance(name, str) @@ -886,20 +876,150 @@ def bmm(a: TensorBox, b: TensorBox): return TensorBox.create(ir.BatchMatrixMultiply.create(a, b)) +def register_onednn_fusion_ops(): + if torch._C.has_mkldnn: + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm + ): + return TensorBox.create( + ir.LinearUnary.create(x, w, b, attr, scalars, algorithm) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): + return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr)) + + if torch._C.has_mkl: + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: TensorBox, + batch_size, + ): + return TensorBox.create( + ir.MKLPackedLinear.create(x, packed_w, orig_w, b, batch_size) + ) + + else: + pass + + +register_onednn_fusion_ops() + + def fallback_handler(kernel): fallbacks.add(kernel) def handler(*args, **kwargs): - result = ir.FallbackKernel.create(kernel, *args, **kwargs) - if isinstance(result, (list, tuple)): - return list(map(TensorBox.create, result)) - else: - return TensorBox.create(result) + return pytree.tree_map( + TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) return handler -def make_fallback(kernel): +def make_fallback(kernel, layout_constraint=None): assert ( kernel not in decompositions ), f"both a fallback and a decomp for same kernel: {kernel}" @@ -909,6 +1029,8 @@ def make_fallback(kernel): ) add_needs_realized_inputs(kernel) + if layout_constraint is not None: + add_layout_constraint(kernel, layout_constraint) return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel)) @@ -918,12 +1040,10 @@ def native_dropout(x, p, train): config.fallback_random ), "this should be handled in decomps unless config.fallback_random" if train: - return list( - map( - TensorBox.create, - ir.FallbackKernel.create(aten.native_dropout, x, p, train), - ) + return pytree.tree_map( + TensorBox.create, ir.FallbackKernel.create(aten.native_dropout, x, p, train) ) + return x, ones_like(x, dtype=torch.bool) @@ -1062,32 +1182,59 @@ def inner_fn(index): ) +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs) + ) + return args, kwargs + + if has_torchvision_roi_align(): make_fallback(torch.ops.torchvision.roi_align) + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)] + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + # TODO(jansel): we should implement decomps or lowerings for these # https://github.com/pytorch/torchdynamo/issues/327 -make_fallback(aten._adaptive_avg_pool2d_backward) -make_fallback(aten.as_strided_scatter) -make_fallback(aten.convolution_backward) -make_fallback(aten._cudnn_rnn) -make_fallback(aten._cudnn_rnn_backward) -make_fallback(aten.cumsum) -make_fallback(aten._embedding_bag) -make_fallback(aten._embedding_bag_forward_only) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) +make_fallback(aten.cumsum, require_dense) +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) make_fallback(aten._fused_moving_avg_obs_fq_helper) make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) -make_fallback(aten.grid_sampler_2d_backward) +make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) -make_fallback(aten._thnn_fused_lstm_cell) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) make_fallback(aten.topk) -make_fallback(aten.unfold) -make_fallback(aten.unfold_backward) -make_fallback(aten.upsample_bicubic2d_backward) -make_fallback(aten.upsample_bilinear2d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_bilinear2d_backward, require_dense) + + +add_layout_constraint(aten.convolution, constrain_to_fx_strides) @register_lowering(aten.convolution) @@ -1652,19 +1799,26 @@ def fn(idx): ) -def check_and_broadcast_indices(indices): +def check_and_broadcast_indices(indices, device): assert all( i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) for i in indices if i is not None ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" - assert all( - [i.get_dtype() in (torch.int32, torch.int64) for i in indices if i is not None] - ), "bool indices are not supported yet" + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] assert len(valid_idxs) > 0, "requires at least 1 non-None index" new_indices = [None] * len(indices) for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") new_indices[i] = x output_dim = len(x.get_size()) start_offset = 0 @@ -1675,9 +1829,10 @@ def check_and_broadcast_indices(indices): while tmp and tmp[0] is None: tmp.pop(0) start_offset += 1 - assert all((i is not None) for i in tmp) - end_offset = output_dim + start_offset + if any((i is None) for i in tmp): + raise NotImplementedError("Fallback when None is in the middle of indices") + end_offset = output_dim + start_offset return new_indices, start_offset, end_offset @@ -1685,10 +1840,18 @@ def check_and_broadcast_indices(indices): def index(x, indices): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() - indices, start_offset, end_offset = check_and_broadcast_indices(indices) + try: + indices, start_offset, end_offset = check_and_broadcast_indices( + indices, x.get_device() + ) + except NotImplementedError: + x.realize() + return fallback_handler(aten.index)(x, indices) + indices_sizes = [i.get_size() for i in indices if i is not None] indices_loaders = [i.make_loader() for i in indices if i is not None] # no guards on output size, all the guards are set in broadcast_tensors + output_size = list(indices_sizes[0]) x_size = x.get_size() @@ -1769,7 +1932,12 @@ def index_put_(self, indices, values, accumulate=False): return self values = to_dtype(values, self.get_dtype()) - indices, start_offset, end_offset = check_and_broadcast_indices(indices) + try: + indices, start_offset, end_offset = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) indices_sizes = [i.get_size() for i in indices if i is not None] indices_loaders = [i.make_loader() for i in indices if i is not None] @@ -1820,19 +1988,50 @@ def output_indexer(index): return self +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + @register_lowering(aten.scatter, type_promotion_kind=None) def scatter(x, dim: int, index, src, **kwargs): return scatter_(clone(x), dim, index, src, **kwargs) +def scatter_fallback( + fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True +): + + if reduce not in {None, "sum"} or ( + reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64} + ): + self.realize() + return fallback_handler(fn)( + self, dim, index, src, reduce=reduce, include_self=include_self + ) + + return None + + @register_lowering(aten.scatter_, type_promotion_kind=None) def scatter_(self, dim: int, index, src, *, reduce: str = None): + if reduce == "add": reduce = "sum" elif reduce == "multiply": reduce = "prod" else: assert reduce is None + + fallback_result = scatter_fallback( + aten.scatter_, self, dim, index, src, reduce=reduce + ) + + if fallback_result: + return fallback_result return scatter_reduce_(self, dim, index, src, reduce) @@ -1858,15 +2057,18 @@ def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} - # TODO: Need to support more reduction type - # For reduction of "sum", tl.atomic_add doesn't support bool or int64 - if reduce not in {None, "sum"} or ( - reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64} - ): - self.realize() - return fallback_scatter_reduce_( - self, dim, index, src, reduce, include_self=include_self - ) + fallback_result = scatter_fallback( + aten.scatter_reduce_, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result assert isinstance(self, TensorBox) assert "int" in str(index.get_dtype()) @@ -1948,20 +2150,20 @@ def backend_reduce_str(reduce): return self -def upsample_nearestnd(x, output_size=None, scale_factors=None, n=2): +def upsample_nearestnd(x, output_size, scales_x: Tuple[float] = None, n: int = 2): x.realize_hint() # elements are reused x_loader = x.make_loader() i_sizes = x.get_size()[-n:] batch = x.get_size()[:-n] i_sizes = [V.graph.sizevars.guard_static_shape(i) for i in i_sizes] - if scale_factors: - assert not output_size - o_sizes = [int(i * s) for i, s in zip(i_sizes, scale_factors)] - else: - o_sizes = output_size + assert len(scales_x) == n + o_sizes = output_size scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales): + if scale: + scales[i] = scale def scale(x, scale): x = ops.index_expr(x, torch.float32) @@ -1982,9 +2184,27 @@ def fn(idx): ) -register_lowering(aten.upsample_nearest1d)(functools.partial(upsample_nearestnd, n=1)) -register_lowering(aten.upsample_nearest2d)(functools.partial(upsample_nearestnd, n=2)) -register_lowering(aten.upsample_nearest3d)(functools.partial(upsample_nearestnd, n=3)) +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) @register_lowering(aten.upsample_bicubic2d.default) @@ -2102,26 +2322,6 @@ def get_x_interp(y): ) -@register_lowering(aten.upsample_bicubic2d.vec) -def upsample_bicubic2d_vec( - a, - output_size, - align_corners: bool, - scale_factors: Optional[Tuple[float, float]] = None, -): - _, _, iH, iW = a.get_size() - iH = V.graph.sizevars.guard_static_shape(iH) - iW = V.graph.sizevars.guard_static_shape(iW) - - if bool(output_size) + bool(scale_factors) != 1: - raise RuntimeError("Must specify exactly one of output_size and scale_factor.") - if output_size is None: - assert scale_factors is not None - output_size = (int(iH * scale_factors[0]), int(iW * scale_factors[1])) - scale_h, scale_w = scale_factors if scale_factors else (None, None) - return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) - - @register_lowering(aten.reflection_pad2d) def reflection_pad2d(x, padding): assert len(padding) == 4 @@ -2352,6 +2552,9 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): return x_out, ceil_mode +fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices) + + @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) def max_pool2d_with_indices( x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False @@ -2380,6 +2583,13 @@ def max_pool2d_with_indices( x_loader = x.make_loader() new_size = list(batch) + [h_out, w_out] + window_size = kernel_size[0] * kernel_size[1] + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode + ) def fn(idx, return_index): *prefix, bh, bw = idx @@ -2389,12 +2599,15 @@ def fn(idx, return_index): ih = bh * stride[0] + ih - padding[0] iw = bw * stride[1] + iw - padding[1] val = x_loader([*prefix, ih, iw]) - index = ops.index_expr(ih * w + iw, torch.int64) + if return_index: + index = ops.index_expr(ih * w + iw, torch.int64) + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) if maxval is None: - maxindex = index maxval = val else: - maxindex = ops.where(ops.gt(val, maxval), index, maxindex) maxval = ops.maximum(val, maxval) if return_index: return maxindex @@ -2417,6 +2630,11 @@ def fn(idx, return_index): return r1, r2 +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward +) + + @register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) def max_pool2d_with_indices_backward( grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices @@ -2457,6 +2675,14 @@ def max_pool2d_with_indices_backward( ] ) + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + def fn(idx): *prefix, h, w = idx index_test = ops.index_expr(h * width + w, torch.int32) @@ -2579,6 +2805,9 @@ def fn_sum(idx, loader): return fn_sum +fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d) + + @register_lowering(aten._adaptive_avg_pool2d) def _adaptive_avg_pool2d(x, output_size): assert isinstance(x, TensorBox) @@ -2618,6 +2847,11 @@ def end_index(index, out_dim, inp_dim): w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + fn_sum = _adaptive_pooling_idx_sum( [h_kernel_max, w_kernel_max], [h_start_index, w_start_index], @@ -2639,9 +2873,9 @@ def fn(idx): return rv -@register_lowering(aten.upsample_nearest2d_backward.vec) +@register_lowering(aten.upsample_nearest2d_backward.default) def upsample_nearest2d_backward( - x, output_size=None, input_size=None, scale_factors=None + x, output_size=None, input_size=None, scales_h=None, scales_w=None ): x.realize_hint() @@ -2688,6 +2922,9 @@ def fn(idx): return rv +fallback_avg_pool2d = fallback_handler(aten.avg_pool2d) + + @register_lowering(aten.avg_pool2d, type_promotion_kind=None) def avg_pool2d( x, @@ -2725,6 +2962,19 @@ def avg_pool2d( new_size = list(batch) + [h_out, w_out] dtype = x.get_dtype() + window_size = kernel_size[0] * kernel_size[1] + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + def fn_sum(idx, loader): *prefix, bh, bw = idx total = None @@ -2764,6 +3014,9 @@ def fn(idx): return rv +fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward) + + @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) def avg_pool2d_backward( grad_output, @@ -2817,6 +3070,20 @@ def avg_pool2d_backward( ] ) + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + def compute_pool_size_without_padding(ph, pw): """ This computes the scaling factor that we will divide an element @@ -2922,7 +3189,7 @@ def _validate_reduction_axis(x, axis): axis = list(axis) for i in range(len(axis)): if axis[i] < 0: - axis[i] += len(size) + axis[i] += len(size) if len(size) else 1 assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) assert len(set(axis)) == len(axis), "reduction axis not unique" return axis @@ -3019,7 +3286,7 @@ def mean(x, axis=None, keepdim=False, *, dtype=None): @register_lowering([aten.var, prims.var]) -def var_(x, axis, correction=1, keepdim=False): +def var_(x, axis=None, correction=1, keepdim=False): size = x.get_size() axis = _validate_reduction_axis(x, axis) diffs = square(sub(x, mean(x, axis, keepdim=True))) @@ -3034,7 +3301,7 @@ def var_(x, axis, correction=1, keepdim=False): @register_lowering(aten.var_mean) -def var_mean(x, dim, unbiased=True, keepdim=False, correction=None): +def var_mean(x, dim=None, unbiased=True, keepdim=False, correction=None): if correction is None: correction = int(unbiased) return [ @@ -3044,7 +3311,7 @@ def var_mean(x, dim, unbiased=True, keepdim=False, correction=None): @register_lowering(aten.std) -def std(x, axis, correction=1, keepdim=False): +def std(x, axis=None, correction=1, keepdim=False): return sqrt(var_(x, axis, correction, keepdim=keepdim)) @@ -3085,6 +3352,8 @@ def pow(a, b): ), "Pow input must be floating point." if isinstance(b, float) and b == int(b): return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) elif isinstance(b, int) and b == 1: return a elif isinstance(b, int) and -32 < b < 32: @@ -3160,7 +3429,7 @@ def truncdiv(a, b): return ops.truncdiv(a, b) -@register_lowering(aten.div.Tensor_mode) +@register_lowering(aten.div, broadcast=True) def div_mode(a, b, rounding_mode=None): both_integer = is_integer_type(a) and is_integer_type(b) both_boolean = is_boolean_type(a) and is_boolean_type(b) @@ -3176,23 +3445,6 @@ def div_mode(a, b, rounding_mode=None): return div(a, b) -@register_lowering([aten.div], broadcast=True) -def div(a, b): - def fn(*args): - return ops.div(*args) - - dtype = get_promoted_dtype( - a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ) - # truediv produces a float tensor even if both operands are integer types - if is_integer_type(a) and is_integer_type(b): - dtype = torch.get_default_dtype() - return make_pointwise(fn, override_return_dtype=dtype)( - a if isinstance(a, Number) else to_dtype(a, dtype), - b if isinstance(b, Number) else to_dtype(b, dtype), - ) - - @register_lowering([aten.mul], broadcast=True) def mul(a, b): both_bool = is_boolean_type(a) and is_boolean_type(b) @@ -3203,32 +3455,55 @@ def mul(a, b): return make_pointwise(fn)(a, b) -# TODO(lezcano) I believe the casting behaviour of prims.div is wrong -# https://github.com/pytorch/pytorch/issues/84412 -# div prim performs truncation division on integer inputs -# and true division for floating and complex inputs +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. @register_lowering([prims.div], broadcast=True) def div_prim(a, b): is_integral = is_boolean_type(a) or is_integer_type(a) if is_integral: - return div_mode(a, b, rounding_mode="floor") + return truncdiv(a, b) + + def fn(*args): + return ops.div(*args) + + return make_pointwise(fn)(a, b) + + +div = register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +)(div_prim) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + else: - return div(a, b) + def fn(a, b): + return ops.fmod(a, b) -# TODO - enable builtin and disable decomp to lower to ptx instruction -# Causes compilation to not complete on timm_vision_transformers inference -# @register_lowering(aten.rsqrt) -# def rsqrt(x): -# dtype = x.get_dtype() -# if is_integer_dtype(dtype) or is_boolean_dtype(dtype): -# x = to_dtype(x, torch.get_default_dtype()) -# -# def _rsqrt(x): -# return ops.rsqrt(x) -# -# return make_pointwise(_rsqrt)(x) + return make_pointwise(fn)(a, b) + + +@register_lowering(aten.rsqrt) +def rsqrt(x): + dtype = x.get_dtype() + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + x = to_dtype(x, torch.get_default_dtype()) + + def _rsqrt(x): + return ops.rsqrt(x) + + return make_pointwise(_rsqrt)(x) @register_lowering([aten.sum, prims.sum]) @@ -3294,6 +3569,23 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None): register_pointwise( aten.lgamma, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ) +erf = register_pointwise( + aten.erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise( + aten.log1p, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) + +register_pointwise( + aten.expm1, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) + register_pointwise( aten.log, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, @@ -3309,7 +3601,6 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None): register_pointwise(aten.remainder) register_pointwise(aten.sign, override_fn_when_input_bool="identity") register_pointwise(aten.ceil) -register_pointwise(aten.fmod) register_pointwise(aten.signbit, override_return_dtype=torch.bool) register_pointwise(aten.le, type_promotion_kind=None, override_return_dtype=torch.bool) @@ -3347,7 +3638,8 @@ def fn(*args, **kwargs): register_inplace(aten.add_, add) register_inplace(aten.mul_, mul) -register_inplace(aten.div_, div) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) register_inplace(aten.sub_, sub) register_inplace(aten.relu_, relu) register_inplace(aten.sigmoid_, sigmoid) @@ -3381,3 +3673,9 @@ def op_floordiv(a, b): @register_lowering(aten._foobar) def foobar(self, *args, **kwargs): raise NotImplementedError("Helpful for debugging") + + +@register_lowering(aten._test_inductor_realize) +def _realize(x): + x.realize() + return clone(x) diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index b94badf93289e..fe4fe07529de5 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -1,8 +1,22 @@ # counter for tracking how many kernels have been generated generated_kernel_count = 0 +generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 # reset all counters def reset(): global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + ir_nodes_pre_fusion = 0 diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 85a0e0c1c2459..bf66e68fed624 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -1,12 +1,30 @@ +import copy +import itertools import logging +import operator import random import weakref +from typing import Optional + +import numpy import torch +import torch.nn as nn from torch import _prims +from torch._dynamo.utils import fake_mode_from_tensors +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.modules.utils import _pair +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights from torch.overrides import TorchFunctionMode +from . import config + log = logging.getLogger(__name__) @@ -26,6 +44,11 @@ def replace_fx(gm: torch.fx.GraphModule): # Sometimes patch_functions() misses things already in the graph for node in reversed(list(gm.graph.nodes)): if node.op == "call_function" and node.target in replacements: + if ( + config.fallback_random + and replacements[node.target] in replacements_using_triton_random + ): + continue with gm.graph.inserting_before(node): node.replace_all_uses_with( gm.graph.call_function( @@ -37,6 +60,684 @@ def replace_fx(gm: torch.fx.GraphModule): return gm +class UnaryAttr(object): + def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + super(UnaryAttr, self).__init__() + + def __call__(self, unary_module: nn.Module): + if type(unary_module) is nn.ReLU6: + unary_module = nn.Hardtanh(min_val=0, max_val=6) + assert all(hasattr(unary_module, item) for item in self.scalars_attr) + scalars = [getattr(unary_module, item) for item in self.scalars_attr] + + algorithm = "" + if self.algorithm_attr: + assert hasattr(unary_module, self.algorithm_attr) + algorithm = getattr(unary_module, self.algorithm_attr) + + return self.op_name, scalars, algorithm + + +class ConvUnary2d(nn.Conv2d): + def __init__( + self, + conv: nn.Module, + unary: Optional[nn.Module], + input_size: list, + ): + super(ConvUnary2d, self).__init__( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + conv.weight.device, + conv.weight.dtype, + ) + self._update_module_params(conv, unary, input_size) + + def _update_module_params(self, conv, unary, input_size): + self.__dict__ = copy.deepcopy(conv.__dict__) + self.attr = "none" + self.scalars = [] + self.algorithm = "" + if unary is not None: + self.attr, self.scalars, self.algorithm = unary_modules_map[ + unary.__class__ + ](unary) + self.weight = torch.nn.Parameter( + torch._C._nn.mkldnn_reorder_conv2d_weight( + self.weight.to_mkldnn(), + self.padding, + self.stride, + self.dilation, + self.groups, + input_size, + ), + requires_grad=self.weight.requires_grad, + ) + + def _conv_forward(self, input, weight, bias): + if self.padding_mode != "zeros": + return torch.ops.mkldnn._convolution_pointwise( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + _pair(0), + self.stride, + self.dilation, + self.groups, + self.attr, + self.scalars, + self.algorithm, + ) + return torch.ops.mkldnn._convolution_pointwise( + input, + weight, + bias, + self.padding, + self.stride, + self.dilation, + self.groups, + self.attr, + self.scalars, + self.algorithm, + ) + + def forward(self, input): + return self._conv_forward(input, self.weight, self.bias) + + +class ConvBinary2d(nn.Conv2d): + def __init__( + self, + conv: nn.Module, + binary_op_name: str, + input_size: list, + ): + super(ConvBinary2d, self).__init__( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + conv.weight.device, + conv.weight.dtype, + ) + self._update_module_params(conv, binary_op_name, input_size) + + def _update_module_params(self, conv, binary_op_name, input_size): + self.__dict__ = copy.deepcopy(conv.__dict__) + self.binary_attr = binary_op_name + self.binary_alpha = None + self.unary_attr = None + self.unary_scalars = [] + self.unary_algorithm = None + self.weight = torch.nn.Parameter( + torch._C._nn.mkldnn_reorder_conv2d_weight( + self.weight.to_mkldnn(), + self.padding, + self.stride, + self.dilation, + self.groups, + input_size, + ), + requires_grad=self.weight.requires_grad, + ) + + def _update_unary_params(self, unary): + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) + + def _conv_forward(self, input, other, weight, bias): + if self.padding_mode != "zeros": + return torch.ops.mkldnn._convolution_pointwise( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + other, + weight, + bias, + _pair(0), + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + return torch.ops.mkldnn._convolution_pointwise( + input, + other, + weight, + bias, + self.padding, + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + + def forward(self, input, other): + return self._conv_forward(input, other, self.weight, self.bias) + + +class ConvBinaryInplace2d(nn.Conv2d): + def __init__( + self, + conv: nn.Module, + binary_op_name: str, + input_size: list, + ): + super(ConvBinaryInplace2d, self).__init__( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + conv.weight.device, + conv.weight.dtype, + ) + self._update_module_params(conv, binary_op_name, input_size) + + def _update_module_params(self, conv, binary_op_name, input_size): + self.__dict__ = copy.deepcopy(conv.__dict__) + self.binary_attr = binary_op_name + self.binary_alpha = None + self.unary_attr = None + self.unary_scalars = [] + self.unary_algorithm = None + self.weight = torch.nn.Parameter( + torch._C._nn.mkldnn_reorder_conv2d_weight( + self.weight.to_mkldnn(), + self.padding, + self.stride, + self.dilation, + self.groups, + input_size, + ), + requires_grad=self.weight.requires_grad, + ) + + def _update_unary_params(self, unary): + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) + + def _conv_forward(self, input, other, weight, bias): + if self.padding_mode != "zeros": + return torch.ops.mkldnn._convolution_pointwise_( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + other, + weight, + bias, + _pair(0), + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + return torch.ops.mkldnn._convolution_pointwise_( + input, + other, + weight, + bias, + self.padding, + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + + def forward(self, input, other): + return self._conv_forward(input, other, self.weight, self.bias) + + +class PackedLinear(nn.Linear): + def __init__(self, linear: nn.Module, input_size: list): + super(PackedLinear, self).__init__( + linear.in_features, + linear.out_features, + linear.bias is not None, + linear.weight.device, + linear.weight.dtype, + ) + self._update_module_params(linear, input_size) + + def _update_module_params(self, linear, input_size): + self.__dict__ = copy.deepcopy(linear.__dict__) + self.batch_size = int(numpy.prod(input_size) / input_size[-1]) + self.packed_weight = torch.nn.Parameter( + torch.ops.mkl._mkl_reorder_linear_weight( + self.weight.to_mkldnn(), self.batch_size + ), + requires_grad=self.weight.requires_grad, + ) + + def forward(self, input): + y = torch.ops.mkl._mkl_linear( + input, self.packed_weight, self.weight, self.bias, self.batch_size + ) + return y + + +class LinearUnary(nn.Linear): + def __init__( + self, + linear: nn.Module, + unary: nn.Module, + ): + super(LinearUnary, self).__init__( + linear.in_features, + linear.out_features, + linear.bias is not None, + linear.weight.device, + linear.weight.dtype, + ) + self._update_module_params(linear, unary) + + def _update_module_params(self, linear, unary): + self.__dict__ = copy.deepcopy(linear.__dict__) + self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__]( + unary + ) + + def forward(self, input): + y = torch.ops.mkldnn._linear_pointwise( + input, self.weight, self.bias, self.attr, self.scalars, self.algorithm + ) + return y + + +class LinearBinary(nn.Linear): + def __init__(self, linear: nn.Module, binary_op_name: str): + super(LinearBinary, self).__init__( + linear.in_features, + linear.out_features, + linear.bias is not None, + linear.weight.device, + linear.weight.dtype, + ) + self._update_module_params(linear, binary_op_name) + + def _update_module_params(self, linear, binary_op_name): + self.__dict__ = copy.deepcopy(linear.__dict__) + + self.attr = binary_op_name + + def forward(self, input, other): + y = torch.ops.mkldnn._linear_pointwise( + input, other, self.weight, self.bias, self.attr + ) + return y + + +def packed_conv_eval(conv: nn.Module, input_size: list): + assert not (conv.training), "Fusion only for eval!" + return ConvUnary2d( + conv, + None, + input_size, + ) + + +def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list): + assert not (conv.training), "Fusion only for eval!" + return ConvUnary2d( + conv, + unary, + input_size, + ) + + +def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list): + assert not (conv.training), "Fusion only for eval!" + return ConvBinary2d( + conv, + binary_op_name, + input_size, + ) + + +def fused_conv_binary_inplace_eval( + conv: nn.Module, binary_op_name: str, input_size: list +): + assert not (conv.training), "Fusion only for eval!" + return ConvBinaryInplace2d( + conv, + binary_op_name, + input_size, + ) + + +def fused_conv_binary_unary_eval( + conv_binary: nn.Module, unary: nn.Module, input_size: list +): + assert not (conv_binary.training), "Fusion only for eval!" + # reuse origin conv module, and just update its' unary attr. + conv_binary._update_unary_params(unary) + return conv_binary + + +def is_bfloat16_module(m): + weight_is_bf16 = m.weight.dtype == torch.bfloat16 + bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16 + return weight_is_bf16 and bias_is_bf16 + + +def packed_linear_eval(linear: nn.Module, input_size: list): + assert not (linear.training), "Fusion only for eval!" + return PackedLinear(linear, input_size) + + +def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list): + assert not (linear.training), "Fusion only for eval!" + return LinearUnary( + linear, + unary, + ) + + +def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list): + assert not (linear.training), "Fusion only for eval!" + linear_binary = LinearBinary( + linear, + attr, + ) + return linear_binary + + +def check_node_kind(current_node, modules, node_kind): + if not isinstance(current_node, torch.fx.Node): + return False + if current_node.op != "call_module": + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not node_kind: + return False + return True + + +def check_node_is_binary(node): + return ( + (node.op == "call_function" and node.target in [torch.add, torch.sub]) + or ( + node.op == "call_function" + and node.target + in [operator.add, operator.iadd, operator.sub, operator.isub] + ) + or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"]) + ) + + +def check_binary_op_kwargs_is_default(node): + # For binary op, we hope the kwargs values are the default value: + # torch.sub(add)(input, other, *, alpha=1, out=None). + if len(node.args) > 2: + return False + if len(node.kwargs) > 0: + if "out" in node.kwargs and node.kwargs["out"] is not None: + return False + if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0: + return False + return True + + +def check_node_is_add_inplace(node): + return (node.op == "call_function" and node.target in [operator.iadd]) or ( + node.op == "call_method" and node.target in ["add_"] + ) + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs): + is_cpu = all( + example_input.device == torch.device("cpu") for example_input in example_inputs + ) + + fake_mode = fake_mode_from_tensors(example_inputs) + + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + gm = linear_permute_fusion(gm) + gm = permute_linear_fusion(gm) + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled(): + return gm + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + if not is_cpu: + return gm + gm = remove_identity(gm) + gm = fuse_conv_bn(gm) + # For binary fusion, we need to check inputs info to make sure + # the binary inputs have same tensor info(device, dtype, and layout). + + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + gm = fuse_unary(gm) + gm = fuse_binary_inplace(gm) + gm = fuse_binary(gm) + # why re-run fuse_unary? we want to enable conv+binary+unary fusion, + # such as conv+add+relu for vision model. + gm = fuse_unary(gm) + gm = pack_module(gm) + return gm + + +# check the pattern: (nn.module, F.function) matched. +def matches_module_function_pattern(pattern, node, modules): + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function + if node.op != "call_function": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule): + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False): + """ + Fuses Convolution/BN layers for inference purposes. + """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + for pattern in modules_patterns: + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + fused_conv = fuse_conv_bn_eval(conv, bn) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + gm.graph.lint() + for pattern in module_function_patterns: + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. + if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + bn_args_is_constant = all( + n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn_running_mean, + bn_running_var, + bn_eps, + bn_weight, + bn_bias, + ) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + gm.graph.lint() + gm.recompile() + + return gm + + +def fuse_unary(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + + for (unary_module, _), (computation_module, fuse_func,) in itertools.product( + unary_modules_map.items(), computation_op_unary_op_fusion_map.items() + ): + pattern = (computation_module, unary_module) + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if ( + len(node.args[0].users) > 1 + ): # Output of computation_node is used by other nodes + continue + computation_node = modules[node.args[0].target] + unary_node = modules[node.target] + eval_mode = all(not n.training for n in [computation_node, unary_node]) + if not eval_mode: + continue + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue + # only fuse for linear when the dtype is bf16 + if type(computation_node) in [nn.Linear] and not is_bfloat16_module( + computation_node + ): + continue + computation_node_input_size = ( + node.args[0].args[0].meta.get("tensor_meta").shape + ) + fused_module = fuse_func( + computation_node, unary_node, computation_node_input_size + ) + replace_node_module(node.args[0], modules, fused_module) + + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + gm.graph.lint() + gm.recompile() + return gm + + def _philox_rand_like_meta(input, seed, offset): return _prims.TensorMeta(input) @@ -46,6 +747,345 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["weight"] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] + else: + return self.node.kwargs["bias"] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["other"] + + +def check_permute(node: torch.fx.Node): + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if ( + node.op == "call_method" + and node.target == "permute" + and check_permute(node) + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and node.target == torch.nn.functional.linear: + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target == torch.bmm or node.target == torch.matmul + ): + normalized = NormalizedMatmulNode(node) + A = normalized.get_input() + B = normalized.get_other() + Atrans = Btrans = False + if A.op == "call_method" and A.target == "permute" and check_permute(A): + Atrans = True + if len(A.args) > 0: + A = A.args[0] + else: + A = A.kwargs["input"] + + if B.op == "call_method" and B.target == "permute" and check_permute(B): + Btrans = True + if len(B.args) > 0: + B = B.args[0] + else: + B = B.kwargs["input"] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(A, B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) + + +def replace_and_fuse_for_binary( + computation_node, node, fuse_func, attr, modules, index_node, index_pointwise +): + computation_node_input_size = ( + node.args[index_node].args[0].meta.get("tensor_meta").shape + ) + fused_module = fuse_func(computation_node, attr, computation_node_input_size) + replace_node_module(node.args[index_node], modules, fused_module) + node.args[index_node].args = node.args[index_node].args + ( + node.args[index_pointwise], + ) + node.replace_all_uses_with(node.args[index_node]) + + +def binary_inputs_meta_is_same(binary_node): + tensor0_meta = binary_node.args[0].meta.get("tensor_meta") + tensor1_meta = binary_node.args[1].meta.get("tensor_meta") + if not tensor0_meta or not tensor1_meta: + return False + if ( + tensor0_meta.shape != tensor1_meta.shape + or tensor0_meta.stride != tensor1_meta.stride + or tensor0_meta.dtype != tensor1_meta.dtype + ): + return False + + return True + + +def fuse_binary(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node): + for node_kind, fuse_func in computation_op_binary_op_fusion_map.items(): + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node.args[1], torch.fx.Node + ): + continue + if not binary_inputs_meta_is_same(node): + continue + attr = binary_attr[node.target] + index_list = supported_index_list[attr] + for index_dict in index_list: + index_node = index_dict["index_computation"] + index_pointwise = index_dict["index_pointwise"] + if check_node_kind(node.args[index_node], modules, node_kind): + if len(node.args[index_node].users) > 1: + continue + computation_node = modules[node.args[index_node].target] + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue + # only fuse for linear when the dtype is bf16 + if type(computation_node) in [ + nn.Linear + ] and not is_bfloat16_module(computation_node): + continue + replace_and_fuse_for_binary( + computation_node, + node, + fuse_func, + attr if attr != "iadd" else "add", + modules, + index_node, + index_pointwise, + ) + # Make sure the fused node is post node of node's inputs nodes. + node.append(node.args[index_node]) + gm.graph.erase_node(node) + gm.graph.lint() + break + + gm.recompile() + return gm + + +def fuse_binary_inplace(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node): + for ( + node_kind, + fuse_func, + ) in computation_op_binary_op_fusion_inplace_map.items(): + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node.args[1], torch.fx.Node + ): + continue + if not binary_inputs_meta_is_same(node): + continue + if check_node_kind(node.args[1], modules, node_kind): + if len(node.args[1].users) > 1: + continue + # make sure the output and input are not same tensor. + if node.args[1].args[0] == node.args[0]: + continue + computation_node = modules[node.args[1].target] + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue + replace_and_fuse_for_binary( + computation_node, + node, + fuse_func, + "add", + modules, + 1, # conv module index + 0, # binary op index + ) + # Make sure the fused node is post node of node's inputs nodes. + node.append(node.args[1]) + gm.graph.erase_node(node) + gm.graph.lint() + break + + gm.recompile() + return gm + + +def pack_module(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in computation_op_packed_map: + computation_node_input_meta = node.args[0].meta.get("tensor_meta") + if computation_node_input_meta.dtype != torch.float32: + continue + if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl: + continue + computation_node_input_size = computation_node_input_meta.shape + if type(cur_module) in [nn.Conv2d] and isinstance( + cur_module.padding, str + ): + continue + new_module = computation_op_packed_map[type(cur_module)]( + cur_module, computation_node_input_size + ) + assert isinstance(new_module, nn.Module) + replace_node_module(node, modules, new_module) + gm.graph.lint() + gm.recompile() + return gm + + philox_rand_like = _prims._make_prim( schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor", return_type=_prims.RETURN_TYPE.NEW, @@ -135,7 +1175,7 @@ def backward(ctx, grad_output): @torch.fx.wrap -def lowmem_dropout(input, p, training=True, inplace=False): +def lowmem_dropout(input, p=0.5, training=True, inplace=False): if isinstance(input, torch.fx.Proxy): # double check we don't FX trace this return input.tracer.create_proxy( @@ -163,3 +1203,70 @@ def rand_like(x, **kwargs): replacements = {torch.nn.functional.dropout: lowmem_dropout, torch.rand_like: rand_like} +# Keep track of any replacement functions that use triton random, +# so they can be avoided when fallback_random is set +replacements_using_triton_random = {lowmem_dropout, rand_like} + +computation_op_unary_op_fusion_map = { + nn.Conv2d: fused_conv_unary_eval, + nn.Linear: fused_linear_unary_eval, + ConvBinary2d: fused_conv_binary_unary_eval, + ConvBinaryInplace2d: fused_conv_binary_unary_eval, +} + + +unary_modules_map = { + nn.ReLU: UnaryAttr("relu"), + nn.Sigmoid: UnaryAttr("sigmoid"), + nn.Tanh: UnaryAttr("tanh"), + nn.Hardswish: UnaryAttr("hardswish"), + nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]), + nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), + nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"), + nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), + nn.SiLU: UnaryAttr("swish"), +} + + +binary_attr = { + torch.add: "add", # node.op == "call_function" + "add": "add", # node.op == "call_method" + "add_": "iadd", # node.op == "call_method" + operator.add: "add", # node.op == "call_function" + operator.iadd: "iadd", # node.op == "call_function" + torch.sub: "sub", # node.op == "call_function" + "sub": "sub", # node.op == "call_method" + "sub_": "sub", # node.op == "call_method" + operator.sub: "sub", # node.op == "call_function" + operator.isub: "sub", # node.op == "call_function" +} + + +computation_op_binary_op_fusion_map = { + nn.Conv2d: fused_conv_binary_eval, + nn.Linear: fused_linear_binary_eval, +} + + +computation_op_binary_op_fusion_inplace_map = { + nn.Conv2d: fused_conv_binary_inplace_eval, +} + + +computation_op_packed_map = { + nn.Linear: packed_linear_eval, + nn.Conv2d: packed_conv_eval, +} + + +# For add: we support conv/linear + other and other + conv +# For sub/add_/sub_, we only support conv/linear - other +# or conv/linear +(-)= other +supported_index_list = { + "add": [ + {"index_computation": 0, "index_pointwise": 1}, + {"index_computation": 1, "index_pointwise": 0}, + ], + "iadd": [{"index_computation": 0, "index_pointwise": 1}], + "sub": [{"index_computation": 0, "index_pointwise": 1}], +} diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2cbf80c29566a..0bb8eb3f27f01 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -13,10 +13,10 @@ import torch -from . import config, dependencies, ir +from . import config, dependencies, ir, metrics from .dependencies import MemoryDep, StarDep from .sizevars import SimplifyIndexing -from .utils import cache_on_self, cmp, dynamo_utils +from .utils import cache_on_self, cmp, dynamo_utils, has_triton from .virtualized import V log = logging.getLogger(__name__) @@ -347,6 +347,12 @@ def allocate(self): V.kernel.args.make_inplace( input_node.get_name(), self.get_name() ) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.triton.TritonKernel + ): + V.kernel.mutations.add(input_node.get_name()) + V.kernel.mutations.add(self.get_name()) return super().allocate() @@ -593,6 +599,7 @@ def __init__(self, nodes): self.compute_predecessors() self.dead_node_elimination() + metrics.ir_nodes_pre_fusion += len(self.nodes) V.debug.ir_pre_fusion(self.nodes) self.num_orig_nodes = len(self.nodes) self.name_to_fused_node = {n.get_name(): n for n in self.nodes} @@ -922,9 +929,22 @@ def can_fuse_vertical(self, node1, node2): be scheduled before the fusion of node1 and node2. """ node1_names = node1.get_names() - remaining_deps = { - dep.name for dep in node2.unmet_dependencies - node1.read_writes.writes - } + computed_deps = set() + for rd in node2.unmet_dependencies: + for cd in node1.read_writes.writes: + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + if ( + rd.name == cd.name + and type(rd) == type(cd) + and rd.index == cd.index + and len(rd.size) >= len(cd.size) + and rd.size[: len(cd.size)] == cd.size + ): + computed_deps.add(rd) + + remaining_deps = {dep.name for dep in node2.unmet_dependencies - computed_deps} if remaining_deps & node1_names: # MemoryDeps didn't match and read different locations of the same buffer. # Examples here include: @@ -964,7 +984,7 @@ def score_fusion_memory(self, node1, node2): common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( node2.read_writes.reads | node2.read_writes.writes ) - return sum(dep.numel_hint() for dep in common_memory_deps) + return sum(dep.numbytes_hint() for dep in common_memory_deps) def score_fusion_key(self, nodes): """ @@ -1065,6 +1085,16 @@ def create_backend(self, device: torch.device): return CppScheduling(self) else: + if not has_triton(): + device_props = torch.cuda.get_device_properties(device) + if device_props.major < 7: + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + else: + raise RuntimeError( + "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) from .codegen.triton import TritonScheduling return TritonScheduling(self) diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 67902bb23b2d3..13a0f5b6bc2be 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -50,6 +50,8 @@ def __init__(self, shape_env=None): self.stride_vars = self.make_stride_vars_cache() self.simplify_with_ranges = self.make_simplify_with_ranges_cache() self._simplify_loops = self.make_simplify_loops_cache() + self.declare = "" + self.ending = "" def seed(self): """ @@ -144,6 +146,11 @@ def visit_modular_indexing(base, divisor, modulus): base_s = base.args[2] - 1 elif not base.has(ModularIndexing): # actual iteration range is to size-1 + iter_ranges_zero = {k: 0 for k, v in var_ranges.items()} + base_lowest = sympy_subs(base, iter_ranges_zero) + if self.maybe_guard_lt(base_lowest, 0): + # can't replace with indexing div if base can be negative + return ModularIndexing(base, divisor, modulus) iter_ranges = {k: v - 1 for k, v in var_ranges.items()} base_s = sympy_subs(base, iter_ranges) else: @@ -446,12 +453,14 @@ def codegen(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]): @functools.lru_cache(None) def sizeof(name): - code.writeline(f"{name}_size = {name}.size()") + code.writeline(f"{self.declare}{name}_size = {name}.size(){self.ending}") return f"{name}_size" @functools.lru_cache(None) def strideof(name): - code.writeline(f"{name}_stride = {name}.stride()") + code.writeline( + f"{self.declare}{name}_stride = {name}.stride(){self.ending}" + ) return f"{name}_stride" # Assign all symbolic shapes needed to local variables @@ -465,7 +474,9 @@ def strideof(name): if shape in needed: needed.remove(shape) added.add(shape) - code.writeline(f"{shape} = {sizeof(name)}[{dim}]") + code.writeline( + f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}" + ) elif isinstance(shape, sympy.Symbol): assert shape in added, f"{shape} is needed but not added" @@ -475,7 +486,9 @@ def strideof(name): shape = self.simplify(shape) if shape in needed: needed.remove(shape) - code.writeline(f"{shape} = {strideof(name)}[{dim}]") + code.writeline( + f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}" + ) elif isinstance(shape, sympy.Symbol): assert shape in added, f"{shape} is needed but not added" assert not needed @@ -493,6 +506,9 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: return f"({parts[0]}, )" return f"({', '.join(parts)})" + def codegen_benchmark_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + return self.codegen_shape_tuple(shape) + def join_dimensions(expr: Expr) -> Expr: from .ir import ModularIndexing @@ -559,6 +575,24 @@ def _join_dimensions_cached(expr: Expr) -> Expr: return expr +class CppSizeVarAllocator(SizeVarAllocator): + def __init__(self, shape_env=None): + super().__init__(shape_env) + self.declare = "auto " + self.ending = ";" + + def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + parts = list(map(self.codegen_sizevar, shape)) + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def codegen_benchmark_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: + return super().codegen_shape_tuple(shape) + + class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] """ A wrapper around .virtualize.ops that uses var range information to @@ -567,6 +601,7 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] def __init__(self, inner, var_ranges: VarRanges): super().__init__(inner) + self.name = "SimplifyIndexing" self._simplify: Callable[ [Expr], Expr ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py index 5d53b3522a25c..a61927eb01885 100644 --- a/torch/_inductor/triton_ops/autotune.py +++ b/torch/_inductor/triton_ops/autotune.py @@ -11,9 +11,9 @@ import torch from .. import config -from ..ir import ReductionHint +from ..ir import ReductionHint, TileHint from ..triton_ops.mm_perf_model import estimate_matmul_time -from ..utils import conditional_product, has_triton +from ..utils import conditional_product, dynamo_utils, has_triton from .conv_perf_model import ( early_config_prune as conv_early_config_prune, estimate_conv_time, @@ -42,11 +42,12 @@ class CachingAutotuner(KernelInterface): configs, and does not rely on the Triton JIT. """ - def __init__(self, fn, meta, configs, save_cache_hook): + def __init__(self, fn, meta, configs, save_cache_hook, mutated_arg_names): super().__init__() self.fn = fn self.meta = meta self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names self.configs = configs self.launchers = [] self.lock = threading.Lock() @@ -134,18 +135,24 @@ def kernel_call(): from triton.testing import do_bench - return do_bench(kernel_call) + return do_bench(kernel_call, rep=40, fast_flush=True) + @dynamo_utils.dynamo_timed def autotune_to_one_config(self, *args, **kwargs): """Do the actual autotuning""" from ..compile_fx import clone_preserve_strides - # clone the input args to avoid autotune contaminating them if - # the kernel does in-place stores - cloned_args = [ - clone_preserve_strides(arg) if isinstance(arg, torch.Tensor) else arg - for arg in args - ] + # clone inplace buffers to avoid autotune contaminating them if + # the kernel does in-place stores. avoid cloning other buffers because + # it leads to increase memory use + cloned_args = [] + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.mutated_arg_names: + assert isinstance(arg, torch.Tensor) + cloned_args.append(clone_preserve_strides(arg)) + else: + cloned_args.append(arg) + timings = { launcher: self.bench(launcher, *cloned_args, **kwargs) for launcher in self.launchers @@ -177,7 +184,7 @@ def run(self, *args, grid, stream): raise RuntimeError( """Consider updating Triton with `pip install -U "git+https://github.com/openai/triton@af76c989eb4799b015f8b288ccd8421558772e56#subdirectory=python"`""" - ) + ) from e else: raise e @@ -207,7 +214,8 @@ def load_cached_autotuning( if not os.path.exists(cache_filename): return None - best_config = json.loads(open(cache_filename).read()) + with open(cache_filename, "r") as fd: + best_config = json.loads(fd.read()) if best_config.get("configs_hash") != configs_hash: return None @@ -249,9 +257,15 @@ def save_cache_hook(cfg): else: save_cache_hook = None + mutated_arg_names = meta.pop("mutated_arg_names", ()) + def decorator(fn): return CachingAutotuner( - fn, meta=meta, configs=configs, save_cache_hook=save_cache_hook + fn, + meta=meta, + configs=configs, + save_cache_hook=save_cache_hook, + mutated_arg_names=mutated_arg_names, ) return decorator @@ -343,7 +357,7 @@ def triton_config_reduction(size_hints, x, r, num_stages=2) -> Config: r *= 2 cfg = {"XBLOCK": x, "RBLOCK": r} - num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 1), 8)) + num_warps = next_power_of_2(min(max(conditional_product(x, r) // 128, 2), 8)) return Config(cfg, num_warps=num_warps, num_stages=num_stages) @@ -376,15 +390,15 @@ def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=2): return Config(cfg, num_warps=num_warps, num_stages=num_stages) -def pointwise(size_hints, meta, filename=None): +def pointwise(size_hints, meta, tile_hint=None, filename=None): """ Construct @triton.heuristics() based on size_hints. """ if len(size_hints) == 1: return cached_autotune([triton_config(size_hints, 1024)], meta=meta) if len(size_hints) == 2: - if not config.triton.autotune: - return cached_autotune([triton_config(size_hints, 64, 64)], meta=meta) + if not config.triton.autotune or tile_hint == TileHint.SQUARE: + return cached_autotune([triton_config(size_hints, 32, 32)], meta=meta) return cached_autotune( [ triton_config(size_hints, 32, 32), diff --git a/torch/_inductor/triton_ops/conv.py b/torch/_inductor/triton_ops/conv.py index 62d7123174a5b..a2098bce1995a 100644 --- a/torch/_inductor/triton_ops/conv.py +++ b/torch/_inductor/triton_ops/conv.py @@ -465,7 +465,7 @@ def _call( shape_w = w.shape shape_bias = bias.shape if bias is not None else None - # indicies for the layeout + # indicies for the layout xn, xc, xh, xw = 0, 1, 2, 3 yn, yc, yh, yw = 0, 1, 2, 3 wn, wc, wh, ww = 0, 1, 2, 3 diff --git a/torch/_inductor/triton_ops/conv1x1.py b/torch/_inductor/triton_ops/conv1x1.py index c7b79f004a5a9..fca5dc3f1d323 100644 --- a/torch/_inductor/triton_ops/conv1x1.py +++ b/torch/_inductor/triton_ops/conv1x1.py @@ -26,7 +26,7 @@ def _call( shape_w = w.shape shape_bias = bias.shape if bias is not None else None - # indicies for the layeout + # indicies for the layout xn, xc, xh, xw = 0, 1, 2, 3 yn, yc, yh, yw = 0, 1, 2, 3 wn, wc, wh, ww = 0, 1, 2, 3 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e970f6acbe5d8..ff2fae775220d 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -178,6 +178,39 @@ def wrapper(self): return wrapper +def get_fused_kernel_name(node_schedule): + return "_".join( + ["fused"] + + sorted( + [ + str(origin.name) + for origin in functools.reduce( + operator.or_, + [ + node.node.origins + for node in node_schedule + if hasattr(node, "node") + ], + ) + if origin.op == "call_function" + ] + )[0 : config.kernel_name_max_ops] + ) + + +def gather_origins(args, kwargs): + import itertools + + from .ir import ComputedBuffer, IRNode + + def is_unrealized_node(n): + return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return set(itertools.chain(*arg_origins, *kwarg_origins)) + + def sympy_str(expr: sympy.Expr): """ Normal sympy str is very slow, this is a lot faster. The result are @@ -242,23 +275,38 @@ def has_incompatible_cudagraph_ops(gm): @contextlib.contextmanager -def fresh_triton_cache(cache_entries=None): +def fresh_inductor_cache(cache_entries=None): """ - Contextmanager that provides a clean tmp cachedir for triton. + Contextmanager that provides a clean tmp cachedir for inductor. Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes generated with this cache instance. """ - with tempfile.TemporaryDirectory() as tmpdirname: - with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}): - yield - if isinstance(cache_entries, dict): - assert len(cache_entries) == 0, "expected empty cache_entries dict" - files = os.listdir(tmpdirname) - cache_entries.update( - { - f: os.path.getsize(os.path.join(tmpdirname, f)) - for f in files - if ".lock" not in f - } - ) + with tempfile.TemporaryDirectory() as inductor_cache_dir: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + + +def argsort(seq): + # preserve original order for equal strides + return list(reversed(sorted(range(len(seq)), key=seq.__getitem__, reverse=True))) + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 64c221895a91b..cff6770997371 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -57,20 +57,27 @@ def _arg_str(a): class MockHandler: def __getattr__(self, name): + if name == "name": + return "MockHandler" + def inner(*args, **kwargs): fargs = [_arg_str(a) for a in args] fargs.extend(f"{k}={v}" for k, v in kwargs.items()) - return f"{name}({', '.join(fargs)})" + return self.truncate_expr(f"{name}({', '.join(fargs)})") return inner @staticmethod - def masked(mask, body, other): - return f"masked({mask}, {body()}, {other})" + def truncate_expr(expr): + return expr + + @classmethod + def masked(cls, mask, body, other): + return cls.truncate_expr(f"masked({mask}, {body()}, {other})") @staticmethod def indirect_indexing(index_var): - return sympy_symbol(str(index_var)) + return sympy_symbol(f"({str(index_var)})") @classmethod def _init_cls(cls): diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 76b8ab532fcda..bdd22f395d2da 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -76,12 +76,7 @@ def qform(A: Optional[Tensor], S: Tensor): def basis(A): """Return orthogonal basis of A columns.""" - if A.is_cuda: - # torch.orgqr is not available in CUDA - Q = torch.linalg.qr(A).Q - else: - Q = torch.orgqr(*torch.geqrf(A)) - return Q + return torch.linalg.qr(A).Q def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 273c93d038158..032783c2d24e4 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -399,7 +399,7 @@ def lobpcg( A (Tensor): the input tensor of size :math:`(*, m, m)` B (Tensor, optional): the input tensor of size :math:`(*, m, - m)`. When not specified, `B` is interpereted as + m)`. When not specified, `B` is interpreted as identity matrix. X (tensor, optional): the input tensor of size :math:`(*, m, n)` diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 2e1c728c582dc..b6d0214f0df47 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -4,13 +4,18 @@ import torch import torch._prims_common as utils from torch import Tensor -from torch._decomp import meta_table as meta_table +from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table +from torch._ops import OpOverload +from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND from torch._prims_common import ( check, corresponding_complex_dtype, corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + IntLike, + make_contiguous_strides_for, ) from torch._prims_common.wrappers import out_wrapper @@ -19,27 +24,19 @@ from torch._subclasses.fake_tensor import check_no_bool_index_tensors from torch.utils._pytree import tree_map + aten = torch.ops.aten _meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta") -def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): - meta_table[op] = f - if register_dispatcher: - name = ( - op.__name__ - if op._overloadname != "default" - else op.overloadpacket.__name__ - ) - _meta_lib_dont_use_me_use_register_meta.impl(name, f) - - op.py_impl(torch._C.DispatchKey.Meta)(f) +def register_meta(op): + def wrapper(fn): + def register(op): + _add_op_to_registry(meta_table, op, fn) - tree_map(add_func, op) - return f + tree_map(register, op) + return fn return wrapper @@ -81,13 +78,28 @@ def meta_randperm(n, *, generator=None, out): @register_meta(aten.randint.default) -def meta_randint(high, size, *, dtype=torch.long, **kwargs): - return torch.empty(size, dtype=dtype, **kwargs) +def meta_randint( + high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None +): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) @register_meta(aten.randint.low) -def meta_randint_low(low, high, size, *, dtype=torch.long, **kwargs): - return torch.empty(size, dtype=dtype, **kwargs) +def meta_randint_low( + low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None +): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_meta(aten.rand.default) +def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) @register_meta([aten._fft_c2r.default, aten._fft_c2r.out]) @@ -99,11 +111,28 @@ def meta_fft_c2r(self, dim, normalization, lastdim): return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype)) -@register_meta(aten.copy_.default, register_dispatcher=False) +@register_meta(aten.copy_.default) def meta_copy_(self, src, non_blocking=False): return self +def inferUnsqueezeGeometry(tensor, dim): + result_sizes = list(tensor.size()) + result_strides = list(tensor.stride()) + new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim] + result_sizes.insert(dim, 1) + result_strides.insert(dim, new_stride) + return result_sizes, result_strides + + +@register_meta(aten.unsqueeze_.default) +def meta_unsqueeze_(self, dim): + dim = maybe_wrap_dim(dim, self.dim() + 1) + g_sizes, g_strides = inferUnsqueezeGeometry(self, dim) + self.as_strided_(g_sizes, g_strides) + return self + + # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py @register_meta(aten.index_select.default) def meta_index_select(self, dim, index): @@ -125,6 +154,16 @@ def meta_max(self): return self.new_empty(()) +@register_meta(aten.max.dim) +def meta_max_dim(self, dim, keepdim=False): + dim = utils.reduction_dims(self.shape, (dim,)) + output_shape = _compute_reduction_shape(self, dim, keepdim) + return ( + self.new_empty(output_shape), + self.new_empty(output_shape, dtype=torch.long), + ) + + @register_meta([aten.min.default]) def meta_min(self): return self.new_empty(()) @@ -138,7 +177,7 @@ def meta_angle(self): _, result_dtype = elementwise_dtypes( self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ) - return self.new_empty(self.size(), dtype=result_dtype) + return torch.empty_like(self, dtype=result_dtype) @register_meta(aten.angle.out) @@ -147,7 +186,8 @@ def meta_angle_out(self, out): return out.copy_(torch.angle(self)) -def squareCheckInputs(self, f_name): +# From aten/src/ATen/native/LinearAlgebraUtils.h +def squareCheckInputs(self: Tensor, f_name: str): assert ( self.dim() >= 2 ), f"{f_name}: The input tensor must have at least 2 dimensions." @@ -156,6 +196,22 @@ def squareCheckInputs(self, f_name): ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkFloatingOrComplex( + t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True +): + dtype = t.dtype + check( + t.is_floating_point() or t.is_complex(), + lambda: f"{f_name}, : Expected a floating point or complex tensor as input. Got , {dtype}", + ) + if allow_low_precision_dtypes: + check( + dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble), + lambda: f"{f_name} : Low precision dtypes not supported. Got {dtype}", + ) + + def checkUplo(uplo: str): uplo_uppercase = uplo.upper() assert ( @@ -175,6 +231,25 @@ def meta_linalg_eigh(self, uplo="L"): return (values, vectors) +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_cholesky_ex.default) +def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False): + squareCheckInputs(A, "linalg.cholesky") + checkFloatingOrComplex(A, "linalg.cholesky") + + A_shape = A.shape + ndim = len(A_shape) + + # L + L_strides = make_contiguous_strides_for(A_shape, False) + L = A.new_empty(A_shape) + L.as_strided_(A_shape, L_strides) + + # infos + infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32) + return L, infos + + # From aten/src/ATen/native/ReflectionPad.cpp @register_meta( [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default] @@ -239,11 +314,24 @@ def meta_pad2d(self, padding): return self.new_empty((nbatch, nplane, output_h, output_w)) -@register_meta(aten.bernoulli_.float, register_dispatcher=False) +@register_meta([aten.bernoulli.default, aten.bernoulli.out]) +@out_wrapper() +def meta_bernoulli(self, *, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self).contiguous() + + +@register_meta(aten.bernoulli_.float) def meta_bernoulli_(self, p=0.5, generator=None): return self +@register_meta(aten.bernoulli.p) +def meta_bernoulli_p(self, p=0.5, generator=None): + # https://github.com/pytorch/pytorch/issues/88612 + return torch.empty_like(self).contiguous() + + @register_meta(aten._fused_moving_avg_obs_fq_helper.default) def meta__fused_moving_avg_obs_fq_helper( self, @@ -257,15 +345,15 @@ def meta__fused_moving_avg_obs_fq_helper( quant_min, quant_max, ch_axis, - per_row_fake_quant, - symmetric_quant, + per_row_fake_quant=False, + symmetric_quant=False, ): check( ch_axis < self.dim(), lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()", ) - mask = self.empty_like(dtype=torch.bool) - return (self.empty_like(), mask) + mask = torch.empty_like(self, dtype=torch.bool) + return (torch.empty_like(self), mask) def dot_check(self, other): @@ -281,7 +369,7 @@ def meta_dot(self, tensor): return self.new_empty(()) -@register_meta([aten.mm.default], register_dispatcher=False) +@register_meta([aten.mm.default]) def meta_mm(a, b): check(a.dim() == 2, lambda: "a must be 2D") check(b.dim() == 2, lambda: "b must be 2D") @@ -298,10 +386,15 @@ def _compute_reduction_shape(self, dims, keepdim): return utils.compute_reduction_output_shape(self.shape, dims) -@register_meta(aten.bernoulli.out) -def meta_bernoulli(self, *, generator=None, out): - torch._resize_output_(out, self.size(), self.device) - return out +# FakeTensors (meta tensors with a device) will report device as meta +# when running meta kernels. Here, access the "fake device" of FakeTensor if it +# exists so meta kernels which have diverge per device will be more +# accurate when run with FakeTensors +def device_hint(tensor) -> "str": + if isinstance(tensor, torch._subclasses.FakeTensor): + return tensor.fake_device.type + else: + return "cuda" # default to cuda @register_meta(aten.convolution.default) @@ -361,24 +454,24 @@ def calc_conv_nd_return_shape( output_padding: Optional[Union[List[int], int]] = None, ): ret_shape = [] - if isinstance(stride, int): + if isinstance(stride, IntLike): stride = [stride] * len(dims) elif len(stride) == 1: stride = [stride[0]] * len(dims) - if isinstance(padding, int): + if isinstance(padding, IntLike): padding = [padding] * len(dims) elif len(padding) == 1: padding = [padding[0]] * len(dims) - if isinstance(dilation, int): + if isinstance(dilation, IntLike): dilation = [dilation] * len(dims) elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) output_padding_list: Optional[List[int]] = None if output_padding: - if isinstance(output_padding, int): + if isinstance(output_padding, IntLike): output_padding_list = [output_padding] * len(dims) elif len(output_padding) == 1: output_padding_list = [output_padding[0]] * len(dims) @@ -409,8 +502,8 @@ def calc_conv_nd_return_shape( def is_channels_last(ten): return torch._prims_common.suggest_memory_format(ten) == torch.channels_last - def pick_memory_format(device_hint): - if device_hint == "cuda": + def pick_memory_format(): + if device_hint(input_tensor) == "cuda": if is_channels_last(input_tensor) or is_channels_last(weight): return torch.channels_last else: @@ -444,15 +537,7 @@ def pick_memory_format(device_hint): ) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) - from torch._subclasses.fake_tensor import FakeTensor - - if isinstance(input_tensor, FakeTensor): - device_hint = input_tensor.fake_device.type - else: - device_hint = "cuda" # default to cuda - - mem_fmt = pick_memory_format(device_hint) - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload] return out @@ -465,7 +550,7 @@ def check_dim_size(tensor, dim, dim_size, size): ) -@register_meta(aten.avg_pool2d.default, register_dispatcher=False) +@register_meta(aten.avg_pool2d.default) def meta_avg_pool2d( input, kernel_size, @@ -584,7 +669,7 @@ def avg_pool2d_backward_shape_check( # Don't override the C++ registration. -@register_meta(aten.avg_pool2d_backward.default, register_dispatcher=False) +@register_meta(aten.avg_pool2d_backward.default) def meta_avg_pool2d_backward( gradOutput_, input, @@ -660,7 +745,13 @@ def meta_adaptive_avg_pool2d(self, output_size): self.ndim == 3 or self.ndim == 4, lambda: f"Expected 3D or 4D tensor, but got {self.shape}", ) - return self.new_empty(self.shape[:-2] + tuple(output_size)) + output_shape = self.shape[:-2] + tuple(output_size) + memory_format = utils.suggest_memory_format(self) + # need to set memory_format to preserve the memory format of the input + # channel last input should have channel last output + return torch.empty( + output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format + ) @register_meta(aten._adaptive_avg_pool3d.default) @@ -729,7 +820,7 @@ def vdot(self, other): # of indexing shape inference is useful, # but not registering it to the dispatcher because we already # get shape inference through structured kernels -@register_meta(aten.index.Tensor, register_dispatcher=False) +@register_meta(aten.index.Tensor) def meta_index_Tensor(self, indices): check_no_bool_index_tensors(aten.index.Tensor, self, indices) check(indices, lambda: "at least one index must be provided") @@ -912,8 +1003,8 @@ def meta_cdist_forward(x1, x2, p, compute_mode): ) check(p >= 0, lambda: "cdist only supports non-negative p values") check( - compute_mode >= 0 and compute_mode <= 2, - lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}", + compute_mode in (None, 1, 2), + lambda: f"possible modes: None, 1, 2, but was: {compute_mode}", ) r1 = x1.size(-2) r2 = x2.size(-2) @@ -999,7 +1090,7 @@ def is_fast_path(src, scale, output, padding_idx): else: return is_fast_path_index_select(src, output, padding_idx) - if offsets.device.type != "cpu": + if device_hint(offsets) != "cpu": offset2bag = indices.new_empty(indices.size(0)) bag_size = indices.new_empty(offsets.size()) if mode == MODE_MAX: @@ -1017,28 +1108,12 @@ def is_fast_path(src, scale, output, padding_idx): return output, offset2bag, bag_size, max_indices -@register_meta([aten.diag.default, aten.diag.out]) -@out_wrapper() -def meta_diag(self, dim=0): - check(self.dim() in (1, 2), lambda: "matrix or a vector expected") - if self.dim() == 1: - sz = self.size(0) + abs(dim) - return self.new_empty((sz, sz)) - - # case: dim is 2 - if dim >= 0: - sz = min(self.size(0), self.size(1) - dim) - else: - sz = min(self.size(0) + dim, self.size(1)) - return self.new_empty((sz,)) - - @register_meta(aten._embedding_bag_forward_only.default) def meta_embedding_bag_forward_only(weight, indices, offsets, *args): output, offset2bag, bag_size, max_indices = meta_embedding_bag( weight, indices, offsets, *args ) - if offsets.device.type == "cpu": + if device_hint(offsets) == "cpu": bag_size = offsets.new_empty(offsets.size()) return output, offset2bag, bag_size, max_indices @@ -1104,40 +1179,81 @@ def meta_repeat(self, repeats): return self.new_empty(target_size) -@register_meta(aten.zero_.default, register_dispatcher=False) +@register_meta(aten.zero_.default) def meta_zero_(self): return self @register_meta( - [aten.fill.Tensor, aten.fill.Scalar, aten.fill_.Tensor, aten.fill_.Scalar], - register_dispatcher=False, + [ + aten.mul_.Scalar, + aten.div_.Scalar, + aten.mul_.Tensor, + aten.div_.Tensor, + aten.logical_and_.default, + aten.logical_or_.default, + aten.logical_xor_.default, + ], +) +def meta_binop_inplace(self, other): + return self + + +@register_meta( + [ + aten.add_.Scalar, + aten.sub_.Scalar, + aten.add_.Tensor, + aten.sub_.Tensor, + ], ) +def meta_binop_inplace_alpha(self, other, alpha=1): + return self + + +@register_meta([aten.round.default, aten.round.decimals]) +def meta_round(self, **kwargs): + return _elementwise_meta( + self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.zero.default) +def meta_zero(self): + return self.new_empty(self.shape) + + +@register_meta([aten.fill_.Tensor, aten.fill_.Scalar]) def meta_fill_(self, val): return self -@register_meta(aten.relu_.default, register_dispatcher=False) +@register_meta([aten.fill.Tensor, aten.fill.Scalar]) +def meta_fill(self, val): + return torch.empty_like(self) + + +@register_meta(aten.relu_.default) def meta_relu_(self): return self -@register_meta(aten.index_put.default, register_dispatcher=False) +@register_meta(aten.index_put.default) def meta_index_put(self, indices, values, accumulate=False): - return self.new_empty(self.size()) + return torch.empty_like(self) -@register_meta(aten.masked_fill_.Scalar, register_dispatcher=False) +@register_meta(aten.masked_fill_.Scalar) def meta_masked_fill_(self, mask, value): return self -@register_meta(aten.index_put_.default, register_dispatcher=False) +@register_meta(aten.index_put_.default) def meta_index_put_(self, indices, values, accumulate=False): return self -@register_meta(aten.alias.default, register_dispatcher=False) +@register_meta(aten.alias.default) def meta_alias(self): return self.view(self.shape) @@ -1175,7 +1291,7 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): return output -@register_meta(aten.bmm.default, register_dispatcher=False) +@register_meta(aten.bmm.default) def meta_bmm(self, mat2): return common_meta_baddbmm_bmm(self, mat2, True) @@ -1285,9 +1401,8 @@ def pool2d_shape_check( ) -@register_meta(aten.max_pool2d_with_indices.default, register_dispatcher=False) -def meta_max_pool2d_with_indices( - input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +def max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode ): # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp def unpack(name, val): @@ -1312,6 +1427,9 @@ def unpack(name, val): padH, padW = unpack("padding", padding) dilationH, dilationW = unpack("dilation", dilation) + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) memory_format = utils.suggest_memory_format(input) if memory_format == torch.channels_last: @@ -1330,11 +1448,6 @@ def unpack(name, val): lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous", ) - nbatch = input.size(-4) if input.dim() == 4 else 1 - nInputPlane = input.size(-3) - inputHeight = input.size(-2) - inputWidth = input.size(-1) - outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) @@ -1356,6 +1469,49 @@ def unpack(name, val): memory_format, ) + return nInputPlane, outputHeight, outputWidth + + +@register_meta(aten.max_pool2d_with_indices_backward.default) +def meta_max_pool2d_with_indices_backward( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + check( + self.dtype == grad_output.dtype, + lambda: "expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", + ) + + nOutputPlane = nInputPlane + ndim = self.ndim + + def _check_dim_size(t): + check_dim_size(t, ndim, ndim - 3, nOutputPlane) + check_dim_size(t, ndim, ndim - 2, outputHeight) + check_dim_size(t, ndim, ndim - 1, outputWidth) + + _check_dim_size(grad_output) + _check_dim_size(indices) + + memory_format = utils.suggest_memory_format(self) + return torch.empty( + self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format + ) + + +@register_meta(aten.max_pool2d_with_indices.default) +def meta_max_pool2d_with_indices( + input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = utils.suggest_memory_format(input) if input.dim() == 3: size = [nInputPlane, outputHeight, outputWidth] else: @@ -1370,6 +1526,25 @@ def unpack(name, val): ) +@register_meta(aten.grid_sampler_2d_backward.default) +def grid_sampler_2d_backward_meta( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format) + return (grad_input, grad_grid) + + @register_meta([aten.full.default]) def full(size, fill_value, *args, **kwargs): return torch.empty(size, *args, **kwargs) @@ -1382,7 +1557,6 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, - aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1390,14 +1564,52 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + if layout == torch.sparse_coo: + check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 0) + + res._coalesced_(True) + return res + return aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): - if isinstance(end, float): - end = math.ceil(end) + if isinstance(end, FloatLike): + end = math.ceil(end) # type: ignore[arg-type] def is_integral(x): - return isinstance(x, int) or isinstance(x, bool) + return isinstance(x, IntLike) or isinstance(x, bool) set_to_integral_dtype = kwargs.get("dtype", None) is None and is_integral(end) if set_to_integral_dtype: @@ -1450,13 +1662,14 @@ def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1): return torch.empty_like(self) +# TODO: Deduplicate this with canonicalize_dim def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True): if dim_post_expr <= 0: assert wrap_scalar dim_post_expr = 1 min = -dim_post_expr max = dim_post_expr - 1 - assert not (dim < min or dim > max) + assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})" if dim < 0: dim += dim_post_expr return dim @@ -1483,14 +1696,14 @@ def gather_shape_check(self, dim, index): ) -@register_meta(aten.gather.default, register_dispatcher=False) +@register_meta(aten.gather.default) def meta_gather(self, dim, index, sparse_grad=False): wrapped_dim = maybe_wrap_dim(dim, self.dim()) is_index_empty = index.numel() == 0 if not is_index_empty: check( index.dtype == torch.long, - lambda: "gather(): Expected dtype int64 for index", + lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}", ) gather_shape_check(self, wrapped_dim, index) return self.new_empty(index.shape) @@ -1599,40 +1812,180 @@ def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options= get_operator_enum(reduce_, use_new_options) -@register_meta(aten.scatter_add.default, register_dispatcher=False) +@register_meta(aten.scatter_add.default) def meta_scatter_add(self, dim, index, src): scatter_meta_impl(self, dim, index, src, "add") return self.new_empty(self.shape) -@register_meta(aten.upsample_nearest2d.vec) -def upsample_nearest2d_vec(input, output_size, scale_factors): - mem_format = utils.suggest_memory_format(input) - spatial_dimensions = input.dim() - 2 +@register_meta(aten.scatter_add_) +def meta_scatter_add_(self, dim, index, src): + scatter_meta_impl(self, dim, index, src, "add") + return self + + +@register_meta( + [ + aten.scatter.src, + aten.scatter.value, + aten.scatter.reduce, + aten.scatter.value_reduce, + ] +) +@out_wrapper() +def meta_scatter(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self.new_empty(self.shape) + + +@register_meta( + [ + aten.scatter_.src, + aten.scatter_.value, + aten.scatter_.reduce, + aten.scatter_.value_reduce, + ] +) +def meta_scatter_(self, dim, index, src_or_value, reduce=None): + src = src_or_value if isinstance(src_or_value, torch.Tensor) else None + scatter_meta_impl(self, dim, index, src, reduce) + return self + + +@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out]) +@out_wrapper() +def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self.new_empty(self.shape) + + +@register_meta(aten.scatter_reduce_.two) +def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True): + scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True) + return self + - input_shape = input.shape - if output_size is not None: - assert scale_factors is None - out_size = output_size - elif scale_factors is not None: - assert output_size is None - out_size = [] - for i in range(spatial_dimensions): - sym_float = (input_shape[i + 2] / 1) * scale_factors[i] - assert sym_float >= 0 - out_size.append(math.floor(sym_float)) +@register_meta([aten.sort.default, aten.sort.stable]) +def meta_sort(self, stable=None, dim=-1, descending=False): + return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) - output_height = out_size[0] - output_width = out_size[1] - nbatch = input_shape[0] - channels = input_shape[1] - return input.new_empty((nbatch, channels, output_height, output_width)).to( - memory_format=mem_format + +def zero_numel_check_dims(self, dim, fn_name): + if self.ndim == 0: + check( + dim == 0 or dim == -1, + lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", + IndexError, + ) + else: + check( + self.size(dim) != 0, + lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", + IndexError, + ) + + +# From aten/src/ATen/native/ReduceOps.cpp +def check_argmax_argmin(name, self, dim): + if dim is not None: + dim = maybe_wrap_dim(dim, self.dim()) + zero_numel_check_dims(self, dim, name) + else: + check( + self.numel() != 0, + lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", + ) + + +@register_meta([aten.argmax.default, aten.argmin.default]) +def argmax_argmin_meta(self, dim=None, keepdim=False): + check_argmax_argmin("argmax", self, dim) + dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) + shape = _compute_reduction_shape(self, dims, keepdim) + return self.new_empty(shape, dtype=torch.int64) + + +@register_meta(aten.scalar_tensor.default) +def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory ) +@register_meta(aten.topk.default) +def topk_meta(self, k, dim=-1, largest=True, sorted=True): + # From aten/src/ATen/native/Sorting.cpp + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + check( + k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1), + lambda: "selected index k out of range", + ) + sliceSize = 1 if self.dim() == 0 else self.size(dim) + check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") + + topKSize = list(self.shape) + if len(topKSize) > 0: + topKSize[dim] = k + return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs import torch._refs.nn.functional import torch._refs.special + + +def activate_meta(): + + activate_meta_table = {} + + # For a given op, we pick the most specific decomp function from + # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd + for type in ["meta", "post_autograd", "pre_autograd"]: + registry = global_decomposition_table[type] + + for opo in registry: + if opo not in activate_meta_table: + activate_meta_table[opo] = registry[opo] + + for op_overload, fn in activate_meta_table.items(): + assert isinstance(op_overload, OpOverload) + + op_overload.py_impl(torch._C.DispatchKey.Meta)(fn) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_overload.name(), "CompositeImplicitAutograd" + ): + # Internally, we shouldn't be registering meta kernels for any operators that + # have CompositeImplicitAutograd kernels. + # Instead, we should be letting those decompositions run, and writing meta kernels + # only for the base operators. + if op_overload in global_decomposition_table["meta"]: + raise RuntimeError( + f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't " + "register meta function for it. Instead, we should let the decomposition run and write " + "meta kernels for the base operators." + ) + pass + elif op_overload.is_view: + # Attempting to register a python meta kernel for a view operator. + # We shouldn't do this, because the output will report as not having aliased storages. + # All view ops have meta kernels in C++ today, so we should use those instead. + pass + elif op_overload.name() in { + "aten::empty_strided", # causing infinite recursion, test_meta.py + "aten::clone", # causing infinite recursion + "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950 + "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950 + "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950 + "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950 + "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950 + }: + pass + else: + _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn) + + +activate_meta() diff --git a/torch/_ops.py b/torch/_ops.py index b3ebd401ab8a2..033d8f361eed7 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -10,6 +10,7 @@ import torch.jit from torch import _utils_internal +from torch._functorch.pyfunctorch import dispatch_functorch # Query `hasattr` only once. @@ -103,7 +104,7 @@ def resolve_key(op: PyOperatorABC, k: DispatchKey): # type: ignore[valid-type] # The dispatch key itself will implicitly route to backend fallback. # This is probably not great for the pure Python implementation. return k - raise RuntimeError("could not find kernel") + raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") pyop_namespace = {} @@ -114,6 +115,7 @@ def __init__(self, name): self._name = name self.table = {} self.python_key_mode_table = {} + self.functorch_table = {} # Make _OPNamespace not scream, this whole name based association needs a good hard look self.__name__ = name @@ -122,18 +124,26 @@ def __init__(self, name): def fallthrough(self, dispatch_key): self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key) - def py_impl(self, dispatch_key_or_mode): + def py_impl(self, dispatch_key_or_mode_or_transform): def inner(fn): - if inspect.isclass(dispatch_key_or_mode) and issubclass( - dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode + if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass( + dispatch_key_or_mode_or_transform, + torch.utils._python_dispatch.TorchDispatchMode, ): - mode = dispatch_key_or_mode + mode = dispatch_key_or_mode_or_transform assert mode not in self.python_key_mode_table # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys? self.python_key_mode_table[mode] = fn return fn - dispatch_key = dispatch_key_or_mode + if isinstance( + dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType + ): + transform = dispatch_key_or_mode_or_transform + self.functorch_table[transform] = fn + return fn + + dispatch_key = dispatch_key_or_mode_or_transform assert ( dispatch_key != torch._C.DispatchKey.Python ), "Please register a mode for the torch._C.DispatchKey.Python key instead." @@ -147,6 +157,9 @@ def inner(fn): def dispatch(self, dispatch_key, *args, **kwargs): from torch.utils._python_dispatch import _get_current_dispatch_mode + if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode: + return dispatch_functorch(self, args, kwargs) + if dispatch_key == torch._C.DispatchKey.Python: # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. curr_mode = type(_get_current_dispatch_mode()) @@ -159,7 +172,7 @@ def dispatch(self, dispatch_key, *args, **kwargs): # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key. return self.python_key_mode_table[curr_mode](*args, **kwargs) - assert dispatch_key in self.table + assert dispatch_key in self.table, dispatch_key return self.table[dispatch_key](*args, **kwargs) def __call__(self, *args, **kwargs): @@ -243,6 +256,21 @@ def __init__(self, overloadpacket, op, op_dk, schema, tags): op.__module__ = overloadpacket.__module__ self.__qualname__ = self._name self.__annotations__ = {} + # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp + self._dispatch_cache = {} + + # Logic replicated from aten/src/ATen/native/MathBitsFallback.h + is_write = None + for a in self._schema.arguments: + if a.alias_info is None: + continue + if is_write is None: + is_write = a.alias_info.is_write + else: + # We will conservatively call mixed mutable/non-mutable + # aliased inputs as NOT a view + is_write = a.alias_info.is_write or is_write + self.is_view = is_write is not None and not is_write # it's a no-op since OpOverload object is immutable and must be unique for a given op overload. def __deepcopy__(self, memo=None): @@ -289,6 +317,7 @@ def inner(fn): assert mode not in self.python_key_mode_table # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys? self.python_key_mode_table[mode] = fn + self._dispatch_cache.clear() return fn assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey) @@ -296,24 +325,35 @@ def inner(fn): dispatch_key_or_mode != torch._C.DispatchKey.Python ), "Please register a mode for the torch._C.DispatchKey.Python key instead." + if dispatch_key_or_mode in self.py_kernels: + raise RuntimeError( + f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}" + ) self.py_kernels[dispatch_key_or_mode] = fn + self._dispatch_cache.clear() return fn return inner - # This implements the pre-computation logic for the Python dispatcher. - def __getattr__(self, attr): - if len(attr) == 0 or not attr[0].isupper(): - raise AttributeError() + # Remove a dispatch key from the dispatch cache. This will force it to get + # recomputed the next time. Does nothing + # WARNING: if you register a dispatch key to py_kernels of an OpOverload, + # calling _del_dispatch on that key is NOT sufficient to apply your change, + # because a single registration may affect MULTIPLE dispatch keys (e.g., + # registering Autograd affects AutogradCPU). del_dispatch is to be used + # only if you are specifically modifying how get_dispatch handles a + # particular input 'key'. + def _uncache_dispatch(self, key): + self._dispatch_cache.pop(key, None) - try: - key = torch._C._dispatch_key_parse(attr) - except Exception as e: - raise AttributeError() + # This implements the pre-computation logic for the Python dispatcher. + def _get_dispatch(self, key): + # This is only called upon a cache miss + assert key not in self._dispatch_cache, f"{self} {key}" if key == torch._C.DispatchKey.Python: if not self.python_key_mode_table: - setattr(self, attr, key) + self._dispatch_cache[key] = key return key def handler(*args, **kwargs): @@ -332,12 +372,25 @@ def handler(*args, **kwargs): # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key. return self.python_key_mode_table[curr_mode](*args, **kwargs) - setattr(self, attr, handler) + self._dispatch_cache[key] = handler return handler - key = resolve_key(self, key) - r = self.py_kernels.get(key, key) - setattr(self, attr, r) + final_key = resolve_key(self, key) + + # TODO: We could potentially have lots of debugging wrappers against + # dispatch keys; design some general registration mechanism instead of + # having if statement for each of them + if key == torch._C.DispatchKey.Functionalize: + import torch._dispatch.python as pydispatch + + if pydispatch.CROSSREF_FUNCTIONALIZE: + handler = pydispatch.make_crossref_functionalize(self, final_key) + self._dispatch_cache[key] = handler + return handler + + # print(self, key, final_key) + r = self.py_kernels.get(final_key, final_key) + self._dispatch_cache[key] = r return r def name(self): @@ -368,6 +421,7 @@ def __init__(self, qualified_op_name, op_name, op, overload_names): self.__name__ = op_name self._op = op self._overload_names = overload_names + self._dir = [] # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op. def __deepcopy__(self, memo=None): @@ -426,6 +480,7 @@ def __getattr__(self, key): overload = OpOverload(self, op_, op_dk_, schema, tags) # cache the overload object setattr(self, key, overload) + self._dir.append(key) return overload except RuntimeError: raise AttributeError( @@ -434,6 +489,9 @@ def __getattr__(self, key): ) ) from None + def __iter__(self): + return iter(self._dir) + def __call__(self, *args, **kwargs): # overloading __call__ to ensure torch.ops.foo.bar() # is still callable from JIT @@ -485,6 +543,10 @@ class _OpNamespace(types.ModuleType): def __init__(self, name): super(_OpNamespace, self).__init__("torch.ops." + name) self.name = name + self._dir = [] + + def __iter__(self): + return iter(self._dir) def __getattr__(self, op_name): # It is not a valid op_name when __file__ is passed in @@ -517,6 +579,7 @@ def __getattr__(self, op_name): # cache the opoverloadpacket to ensure that each op corresponds to # a unique OpOverloadPacket object setattr(self, op_name, opoverloadpacket) + self._dir.append(op_name) return opoverloadpacket @@ -533,6 +596,7 @@ def __init__(self): super(_Ops, self).__init__("torch.ops") self.loaded_libraries = set() self.pyops = _PyOpNamespace() + self._dir = [] def __getattr__(self, name): # Check if the name is a pyop @@ -542,8 +606,12 @@ def __getattr__(self, name): # Here we are creating `torch.ops.my_namespace` namespace = _OpNamespace(name) setattr(self, name, namespace) + self._dir.append(name) return namespace + def __iter__(self): + return iter(self._dir) + def load_library(self, path): """ Loads a shared library from the given path into the current process. diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 8ea992894cf5e..a7cc65ee23131 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -16,8 +16,10 @@ from torch._prims.nvfuser_prims import register_nvprims from torch._prims_common import ( check, + Dim, DimsSequenceType, DimsType, + IntLike, Number, NumberType, RETURN_TYPE, @@ -29,6 +31,7 @@ ) from torch._prims_common.wrappers import backwards_not_supported from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.symbolic_shapes import sym_float from torch.overrides import handle_torch_function, has_torch_function from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -148,6 +151,10 @@ "transpose", "view_of", # + # Functionalized view mutations + # + "as_strided_scatter", + # # Shape prims # "collapse", @@ -196,7 +203,7 @@ # Randomness Prims # "normal", - "uniform", + "_uniform_helper", # # FFT prims # @@ -304,6 +311,7 @@ def _backend_select_impl(*args, **kwargs): p.schema = schema p.prim_impl = _prim_impl p.prim_meta_impl = meta + p.impl_aten = impl_aten return _prim @@ -333,7 +341,7 @@ def _elementwise_meta( args_ = list(args) if args_with_fixed_dtypes is not None: - args_.extend(args_with_fixed_dtypes) + args_ = list(args_with_fixed_dtypes) + args_ utils.check_same_device(*args_, allow_cpu_scalar_tensors=True) utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True) @@ -387,11 +395,18 @@ def _elementwise_meta( return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype) # Number case - # NOTE: this case is not currently exercised # TODO: fix number type promotion (bool, complex->float) - assert not isinstance(number, torch.SymIntNode), "NYI" - assert not isinstance(number, torch.SymFloatNode), "NYI" - return TensorMeta(number) + + # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) + seen_float = False + if isinstance(number, (torch.SymInt, torch.SymFloat)): + for a in args: + assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" + seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) + if seen_float: + number = sym_float(number) + + return TensorMeta(number) # type: ignore[arg-type] def _complex_only_elementwise_meta(*args, **kwargs): @@ -929,7 +944,7 @@ def _fill_aten(a: Tensor, value: NumberType) -> Tensor: # div prim performs truncation division on integer inputs # and true division for floating and complex inputs def _div_aten(a, b): - is_integral = isinstance(a, (bool, int)) or ( + is_integral = isinstance(a, (bool, int, torch.SymInt)) or ( isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype) ) @@ -1139,9 +1154,6 @@ def _minimum_aten( # # View operations -# -# TODO: model view relationships -# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1155,9 +1167,11 @@ def _as_strided_meta( # as_strided to shapes with no elements are trivially valid, so it's OK pass elif isinstance(a, torch.Tensor): - utils.check_in_bounds_for_storage(a.storage(), size, stride, storage_offset) + utils.check_in_bounds_for_storage( + a._typed_storage(), size, stride, storage_offset + ) - return TensorMeta(a, shape=size, strides=stride) + return torch.as_strided(a, size, stride, storage_offset) def _as_strided_aten( @@ -1198,7 +1212,7 @@ def _broadcast_in_dim_meta( # (no relative reordering of dims) of integers and # each dimension must be within the new shape def _greater_than_reduce(acc, x): - assert isinstance(x, int) + assert isinstance(x, Dim) assert x > acc assert x < len(shape) @@ -1222,7 +1236,12 @@ def _greater_than_reduce(acc, x): new_strides.append(a.stride()[original_idx]) original_idx = original_idx + 1 else: - new_strides.append(0) + if shape[idx] != 1: + new_strides.append(0) + elif original_idx == a.ndim: + new_strides.append(1) + else: + new_strides.append(a.stride()[original_idx] * a.size()[original_idx]) return a.as_strided(shape, new_strides, a.storage_offset()) @@ -1271,7 +1290,7 @@ def _collapse_view_helper( strides = (1,) else: shape = a.shape # type: ignore[assignment] - strides = a.stride() + strides = a.stride() # type: ignore[assignment] utils.validate_idx(len(shape), start) utils.validate_exclusive_idx(len(shape), end) @@ -1779,6 +1798,53 @@ def _view_of_aten(a: Tensor) -> Tensor: doc=_view_of_doc, ) +# +# Functionalized view mutations +# + + +def _as_strided_scatter_meta( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: int, +) -> TensorLikeType: + utils.validate_shape(size) + utils.validate_strides(stride) + + required_size = utils.compute_required_storage_length(size, stride, storage_offset) + utils.check( + input.numel() >= required_size, + lambda: ( + f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} " + f" and itemsize {input.element_size()} requiring a storage size of " + f"{required_size * input.element_size()} are out of bounds " + f"for storage of size {input.numel() * input.element_size()}" + ), + ) + utils.check( + utils.is_same_shape(src.shape, size), + lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}", + ) + + return _clone_meta(input) + + +_as_strided_scatter_doc = """ + Creates a new tensor equivalent to ``out = input.clone()`` after mutation by + ``out.as_strided(size, stride, storage_offset).copy_(src)``. +""" + +as_strided_scatter = _make_prim( + schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor", + meta=_as_strided_scatter_meta, + impl_aten=torch.as_strided_scatter, + return_type=RETURN_TYPE.NEW, + doc=_as_strided_scatter_doc, +) + + # # Shape operations # @@ -1876,7 +1942,8 @@ def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor: def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: utils.validate_dimension_indices(a.ndim, dims) - return TensorMeta(a) + out = torch.empty_like(a, memory_format=torch.preserve_format) + return TensorMeta(out) _rev_doc = """ @@ -1931,7 +1998,11 @@ def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorL assert isinstance(a, TensorLike) assert isinstance(dtype, torch.dtype) - strides = utils.compute_elementwise_output_strides(a) + # dtype conversion preserves dense strides + if torch._prims_common.is_non_overlapping_and_dense(a): + strides = a.stride() + else: + strides = utils.compute_elementwise_output_strides(a) return TensorMeta(a, strides=strides, dtype=dtype) @@ -2309,17 +2380,19 @@ def _arange_meta( step != 0, lambda: "step must be nonzero", ) - utils.check( - math.isfinite(start) and math.isfinite(end), - lambda: f"unsupported range: {start} -> {end}", - ) + # SymInts can't represent inf + if not isinstance(start, torch.SymInt) and not isinstance(end, torch.SymInt): + utils.check( + math.isfinite(start) and math.isfinite(end), + lambda: f"unsupported range: {start} -> {end}", + ) utils.check( (step > 0 and end >= start) or (step < 0 and end <= start), lambda: "upper bound and lower bound inconsistent with step sign", ) if dtype is not None: pass - elif all(isinstance(arg, int) for arg in (start, end, step)): + elif all(isinstance(arg, IntLike) for arg in (start, end, step)): dtype = torch.int64 else: dtype = torch.get_default_dtype() @@ -2690,7 +2763,7 @@ def _uniform_aten( """ # TODO: we should more seriously review randomness modeling and prims -uniform = _make_prim( +_uniform_helper = _make_prim( schema=( "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor" ), diff --git a/torch/_prims/context.py b/torch/_prims/context.py index fea3f17a5009b..22452e4daefcf 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -68,7 +68,8 @@ def torch_to_refs_map(): # Support conversions for s in torch._refs._conversions.__all__: - r[getattr(torch.Tensor, s)] = torch._refs._conversions.__dict__.get(s) + tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s) + r[tensor_attr] = torch._refs._conversions.__dict__.get(s) return r @@ -254,10 +255,6 @@ def _is_func_unsupported_nvfuser( class TorchRefsNvfuserCapabilityMode(TorchRefsMode): def __init__(self, *, skip_ops=()): aten_ops_to_skip = ( - "aten.transpose.int", - "aten.t.default", - "aten.unsqueeze.default", - "aten.permute.default", "aten._log_softmax.default", "aten._log_softmax_backward_data.default", "aten.expand.default", @@ -367,6 +364,16 @@ def _is_rand_like(self, func): ) return result + def _is_full(self, func): + result = "torch.full" == torch.overrides.resolve_name(func) or ( + func + in [ + torch.ops.aten.full, + torch.ops.aten.full.names, + ] + ) + return result + def __torch_function__( self, orig_func: Callable, @@ -405,6 +412,12 @@ def __torch_function__( warn("view has ignored kwargs!") return torch.ops.nvprims.view(a, shape) + if orig_func == torch.ops.aten._reshape_alias.default: + a, shape, stride = args + if len(kwargs) > 0: + warn("view has ignored kwargs!") + return torch.ops.nvprims.view(a, shape) + if self._is_native_batch_norm(orig_func): return torch.ops.nvprims.native_batch_norm(*args, **kwargs) @@ -413,5 +426,8 @@ def __torch_function__( warn("rand_like has ignored kwargs!") return torch.ops.nvprims.rand_like(*args) + if self._is_full(orig_func): + return torch.ops.nvprims.full(*args, **kwargs) + # Then we use TorchRefsMode to interpret the rest return super().__torch_function__(orig_func, types, args, kwargs) diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py index e7d3df238bb50..b44f7653ee81d 100644 --- a/torch/_prims/nvfuser_executor.py +++ b/torch/_prims/nvfuser_executor.py @@ -1,3 +1,4 @@ +import operator from copy import deepcopy from dataclasses import dataclass from functools import lru_cache @@ -22,14 +23,23 @@ DataType, Fusion, FusionDefinition, + Tensor, ) else: DataType = None +import os + + +@lru_cache(None) +def get_nvprim_dump_nvtx(): + return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX") + + DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType( { "use_python_fusion_cache": True, - "allow_single_op_fusion": True, + "allow_single_op_fusion": False, } ) @@ -39,8 +49,8 @@ # https://github.com/pytorch/pytorch/issues/80551 @dataclass(frozen=True) class nvFuserTensorTemplate: - size: tuple - stride: tuple + symbolic_shape: tuple + contiguity: tuple dtype: DataType is_cpu: bool @@ -50,12 +60,29 @@ class nvFuserScalarTemplate: dtype: DataType +@lru_cache(maxsize=2048) +def compute_symbolic_shape(shape): + """Computes the symbolic shape of a tensor. + nvFuser specializes on size-1 dimensions as broadcasted dimensions. + -1 is used to represent any size.""" + return tuple(1 if s == 1 else -1 for s in shape) + + +@lru_cache(maxsize=2048) +def compute_contiguity(shape, strides): + """Computes the contiguity information to simplify internal indexing. + Contiguous dimensions are represented by True, strided dimensions + are represented by False. + """ + return torch._C._nvfuser.compute_contiguity(shape, strides) + + def to_nvfuser_template_args(args): def to_nvfuser(arg): if isinstance(arg, torch.Tensor): return nvFuserTensorTemplate( - arg.size(), - arg.stride(), + compute_symbolic_shape(arg.size()), + compute_contiguity(arg.size(), arg.stride()), getnvFuserDtype(arg.dtype), arg.is_cpu, # type: ignore[attr-defined] ) @@ -89,7 +116,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): # Everything in the graph must support nvfuser for node in gm.graph.nodes: - if node.op == "call_function" and "getitem" in node.name: + if node.op == "call_function" and node.target == operator.getitem: continue if ( node.op == "call_function" @@ -115,6 +142,10 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): call_function_nodes ), "Constant tensors that are saved in the graph and used as arguments are not supported yet" + # Checking output dtypes + output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes)) + orig_flat_out, _ = tree_flatten(output_node.args[0]) + fusion = Fusion() with FusionDefinition(fusion) as fd: @@ -152,7 +183,7 @@ def run_node(self, node): def call_function(self, target, args, kwargs): # This handles tuple unpacking - if "getitem" in str(target): + if target == operator.getitem: assert isinstance(args[0], tuple) return target(*args, **kwargs) args = tuple(map(_to_nvfuser_constant, args)) @@ -160,9 +191,23 @@ def call_function(self, target, args, kwargs): args = (fd,) + args return target(*args, **kwargs) + def output(self, target, args, kwargs): + flat_out, unflatten_spec = tree_flatten(args[0]) + for o, orig_o in zip(flat_out, orig_flat_out): + # casting outputs to the original data type + # ensures outputs produced by fusion would always agree with original GraphModule + out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype) # type: ignore[union-attr] + assert isinstance( + o, Tensor + ), "output from codegen has to be tensor type" + fd.add_output(fd.ops.cast(o, dtype=out_dtype)) + return args[0] + def templates_to_nvfuser_inputs(arg): if isinstance(arg, nvFuserTensorTemplate): - x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu) + x = fd.define_tensor( + arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu + ) return x elif isinstance(arg, nvFuserScalarTemplate): x = fd.define_scalar(arg.dtype) @@ -174,8 +219,6 @@ def templates_to_nvfuser_inputs(arg): nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates)) out = FusionInterpreter(gm).run(*nv_args) flat_out, unflatten_spec = tree_flatten(out) - for o in flat_out: - fd.add_output(o) return fusion, unflatten_spec @@ -212,10 +255,30 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None): arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) ) - return tree_unflatten( + if get_nvprim_dump_nvtx(): + torch.cuda.nvtx.range_push( + "fusion: {0}, graph: {1}".format( + fusion.id(), + str( + [ + { + "op": n.op, + "name": n.name, + "args": n.args, + "kwargs": n.kwargs, + } + for n in gm.graph.nodes + ] + ), + ) + ) + result = tree_unflatten( fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] unflatten_spec, # type: ignore[has-type] ) + if get_nvprim_dump_nvtx(): + torch.cuda.nvtx.range_pop() + return result else: warn( "nvfuser_executor is executed with non-cuda args, fallback to aten executor" @@ -237,10 +300,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: ) is not None ) - return ( - node.op == "call_function" - and getattr(node.target, "impl_nvfuser", None) is not None - or "getitem" in node.name # getitem is a special case + return node.op == "call_function" and ( + getattr(node.target, "impl_nvfuser", None) is not None + or node.target == operator.getitem ) @@ -268,30 +330,45 @@ def __call__(self, *args): ) +# A set of operators that are supported by nvFuser +# but should not form a fusion group solely on their own +_non_compute_ops = [ + "torch.ops." + str(getattr(torch.ops.nvprims, prim).default) + for prim in dir(torch.ops.nvprims) + if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket) + and getattr(torch.ops.nvprims, prim).return_type + == torch._prims_common.RETURN_TYPE.VIEW +] + +_allowed_single_node_partition_ops = [ + "torch.ops.nvprims.native_batch_norm.default", + "torch.ops.nvprims.var_mean.default", + "torch.ops.nvprims.var_mean.main", +] + + def _remove_empty_like_fill(gm: GraphModule): # Remove empty_like + fill nodes that prevent lowering to nvprims # This is a workaround for nonoptimal traces of C++ code `(1 - tensor)` # https://github.com/pytorch/pytorch/issues/86612 - # Here when we see a `sub` node, we check if the first input is a result of - # filling a tensor with a scalar - # If so, we replace the first argument of the `sub` node with a scalar - for node in gm.graph.nodes: - if node.op == "call_function": - if node.target == torch.ops.nvprims.sub.default: - # check if the first argument is a fill - if ( - isinstance(node.args[0], torch.fx.Node) - and node.args[0].op == "call_function" - and node.args[0].target == torch.ops.aten.fill.Scalar - ): - # Replace the first argument with the second argument of fill - # aten.fill.Scalar(tensor, scalar) - fill_node = node.args[0] - scalar = fill_node.args[1] - node.args = (scalar, *node.args[1:]) - gm.graph.eliminate_dead_code() - gm.recompile() + def pattern(scalar, tensor): + # pattern for C++ trace of `scalar - tensor`. We are looking for the + # pattern of aten and nvprims.sub specifically because we want to remove + # the empty_like + fill nodes after lowering of AOT Autograd trace to + # nvprims In the future, nvFuser might support fill, and empty_like and + # this workaround can be removed. + empty_like = torch.ops.aten.empty_like.default( + tensor, memory_format=torch.preserve_format + ) + fill = torch.ops.aten.fill.Scalar(empty_like, scalar) + sub = torch.ops.nvprims.sub.default(fill, tensor) + return sub + + def replacement(scalar, tensor): + return torch.ops.nvprims.sub.default(scalar, tensor) + + torch.fx.replace_pattern(gm, pattern, replacement) return gm @@ -325,9 +402,14 @@ def maybe_partition_graph( # CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph gm = deepcopy(gm) partitioner = CapabilityBasedPartitioner( - gm, supported_ops, allows_single_node_partition=allow_single_op_fusion + gm, + supported_ops, + allows_single_node_partition=allow_single_op_fusion, + non_compute_ops=_non_compute_ops, + allowed_single_node_partition_ops=_allowed_single_node_partition_ops, ) partitions = partitioner.propose_partitions() + partitioner.remove_bookend_non_compute_ops(partitions) if len(partitions) == 0: warn( "No partition found for the graph. " @@ -350,11 +432,33 @@ def maybe_partition_graph( NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache), ) + # Go through the graph and replace all the nodes that were converted to + # nvprims but won't be sent to nvFuser with a call to PyTorch's eager + # mode. This is necessary because torch.ops.* have higher overhead than + # calling the eager mode directly. + for node in partitioned_graph.graph.nodes: + if node.op == "call_function" and str(node.target).startswith("nvprims."): + if getattr(node.target, "impl_aten", None) is not None: + node.target = node.target.impl_aten + partitioned_graph.graph.eliminate_dead_code() + partitioned_graph.recompile() return partitioned_graph, any_unsupported else: return gm, any_unsupported +class NVTXInterpreter(torch.fx.Interpreter): + def run_node(self, n): + torch.cuda.nvtx.range_push( + "name: {0}, args: {1}, op: {2}, kwargs: {3}".format( + n.name, n.args, n.op, n.kwargs + ) + ) + result = super().run_node(n) + torch.cuda.nvtx.range_pop() + return result + + def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None): executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG # maybe_partition_graph function is cached so we can't use non-hashable arguments @@ -374,6 +478,9 @@ def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None use_python_fusion_cache=use_python_fusion_cache, ) if is_partitioned: - return gm(*args) + if get_nvprim_dump_nvtx(): + return NVTXInterpreter(gm).run(*args) + else: + return gm(*args) else: return nvfuser_execute(gm, *args, executor_parameters=executor_parameters) diff --git a/torch/_prims/nvfuser_prims.py b/torch/_prims/nvfuser_prims.py index d4132b356473a..fc70bdbc0a124 100644 --- a/torch/_prims/nvfuser_prims.py +++ b/torch/_prims/nvfuser_prims.py @@ -5,20 +5,24 @@ # can be added in the future for the corresponding higher-level torch/aten # functions. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import torch +import torch._prims_common as utils from torch._prims_common import ( DimsSequenceType, + elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, getnvFuserDtype, make_contiguous_strides_for, + NumberType, ShapeType, TensorLikeType, ) from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, backwards_not_supported, elementwise_type_promotion_wrapper, ) @@ -214,6 +218,8 @@ def _{fname}_nvfuser(fd, a, b, c): def _native_batch_norm_nvfuser( fd, input, weight, bias, running_mean, running_var, training, momentum, eps ): + + """ if weight is None: weight = fd.define_null_tensor() if bias is None: @@ -222,15 +228,16 @@ def _native_batch_norm_nvfuser( running_mean = fd.define_null_tensor() if running_var is None: running_var = fd.define_null_tensor() + """ return fd.ops.batch_norm( input, weight, bias, running_mean, running_var, - training, momentum, eps, + training, ) @@ -248,8 +255,8 @@ def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype return fd.ops.cast(a, nvfuser_dtype) # type: ignore[attr-defined] -def _transpose_nvfuser(fd, a, permutation): - return fd.ops.permute(a, permutation) # type: ignore[attr-defined] +def _transpose_nvfuser(fd, a, dims): + return fd.ops.permute(a, dims) # type: ignore[attr-defined] def _squeeze_nvfuser(fd, a, a_shape, dimensions): @@ -336,6 +343,26 @@ def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None): return fd.ops.set(input) +def _full_nvfuser( + fd: Any, + shape: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + requires_grad: bool = False, +): + assert device != torch.device("cpu") + assert layout is None or layout is torch.strided + assert pin_memory is False + assert requires_grad is False + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + nvfuser_dtype = getnvFuserDtype(dtype) + return fd.ops.full(shape, fill_value, nvfuser_dtype) + + _nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser _nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser _nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser @@ -350,6 +377,147 @@ def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None): _nvfuser_impls["var_mean"] = _var_mean_nvfuser _nvfuser_impls["amax"] = _amax_nvfuser _nvfuser_impls["amin"] = _amin_nvfuser +_nvfuser_impls["full"] = _full_nvfuser + + +def register_full(): + name = "full" + + nvprim.define( + "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, " + + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor" + ) + + def _meta_impl( + size, + fill_value, + *, + out=None, + dtype=None, + layout=None, + device=None, + requires_grad=False, + ): + strides = make_contiguous_strides_for(size) + return torch._prims.TensorMeta( + None, + shape=size, + strides=strides, + dtype=dtype, + device=device, + ) + + def _prim_impl( + size, + fill_value, + *, + out=None, + dtype=None, + layout=None, + device=None, + pin_memory=False, + requires_grad=False, + ): + return torch.full( + size, + fill_value, + out=out, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + nvprim_impl.impl(name, _prim_impl) + nvprim_meta_impl.impl(name, _meta_impl) + + prim_packet = getattr(torch.ops.nvprims, name) + prim = prim_packet.default + nvprim_autograd_impl.impl(name, backwards_not_supported(prim)) + for p in (prim_packet, prim): + p.__doc__ = "Create a tensor with given size and filled with value" + p.impl_nvfuser = _nvfuser_impls["full"] + p.is_recomputable = _nvfuser_is_recomputable["full"] + p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] + + +# functorch.compile.min_cut_rematerialization_partition accepts a list of +# operators that can be recomputed in the backward pass. This list is used to +# determine which operators can be recomputed. If an operator is not in this +# list, it will not be recomputed. +_nvfuser_is_recomputable: Dict[str, bool] = { + # Reductions are not allowed to be recomputed + "amax": False, + "amin": False, + "sum": False, + "var": False, + "var_mean": False, + # Normalizations are not allowed to be recomputed + "native_batch_norm": False, + # Random ops are not allowed to be recomputed + "rand_like": False, + # Everything else is allowed to be recomputed + "abs": True, + "acos": True, + "add": True, + "asin": True, + "atan": True, + "atan2": True, + "atanh": True, + "bitwise_and": True, + "bitwise_not": True, + "bitwise_or": True, + "bitwise_xor": True, + "broadcast_in_dim": True, + "ceil": True, + "clone": True, + "convert_element_type": True, + "cos": True, + "cosh": True, + "div": True, + "eq": True, + "erf": True, + "erfc": True, + "exp": True, + "expm1": True, + "floor": True, + "fmod": True, + "full": True, + "ge": True, + "gt": True, + "imag": True, + "isfinite": True, + "le": True, + "lgamma": True, + "log": True, + "log10": True, + "log1p": True, + "log2": True, + "lt": True, + "mul": True, + "ne": True, + "neg": True, + "pow": True, + "real": True, + "reciprocal": True, + "remainder": True, + "round": True, + "rsqrt": True, + "sign": True, + "sin": True, + "sinh": True, + "sqrt": True, + "squeeze": True, + "sub": True, + "tan": True, + "tanh": True, + "transpose": True, + "trunc": True, + "view": True, + "view_of": True, + "where": True, +} def register_native_batch_norm(): @@ -370,15 +538,64 @@ def _prim_impl( ) nvprim_impl.impl(name, _prim_impl) - nvprim_autograd_impl.impl( - name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default) - ) - prim_packet = torch.ops.nvprims.native_batch_norm prim = prim_packet.default + + def _native_batch_norm_ref( + input: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + training: bool, + momentum: float, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if torch._prims_common.is_complex_dtype(input.dtype): + raise NotImplementedError("Complex tensors are not supported") + + # note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype + result_dtype = input.dtype + computation_dtype, _ = elementwise_dtypes( + input, + weight, + bias, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + ) + + input_ = _maybe_convert_to_dtype(input, computation_dtype) + output, mean, rstd = prim( + input_, weight, bias, running_mean, running_var, training, momentum, eps + ) + output_ = _maybe_convert_to_dtype(output, result_dtype) # type: ignore[arg-type] + return (output_, mean, rstd) # type: ignore[return-value] + + def _native_batch_norm_autograd( + input: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + running_mean: Optional[torch.Tensor], + running_var: Optional[torch.Tensor], + training: bool, + momentum: float, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # This wrapper is needed to convert prims calls inside + # _native_batch_norm_ref to nvprims calls + from torch._prims.context import NvfuserPrimsMode + + with NvfuserPrimsMode(): + return backwards_not_supported(_native_batch_norm_ref)( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + + nvprim_autograd_impl.impl(name, _native_batch_norm_autograd) + for p in (prim_packet, prim): p.__doc__ = "Computes batch normalization." p.impl_nvfuser = _nvfuser_impls["native_batch_norm"] + p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"] p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] @@ -437,6 +654,7 @@ def _prim_impl( for p in (prim_packet, prim): p.__doc__ = "Computes rand_like" p.impl_nvfuser = _nvfuser_impls["rand_like"] + p.is_recomputable = _nvfuser_is_recomputable["rand_like"] p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] @@ -535,9 +753,14 @@ def _var_mean_autograd( for p in (prim_packet, prim): p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument" p.impl_nvfuser = _nvfuser_impls["var_mean"] + p.is_recomputable = _nvfuser_is_recomputable["var_mean"] p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] +def _nvprims_view_impl_aten(a, original_shape, new_shape): + return a.reshape(new_shape) + + def register_view(): """This function is used to register the view function in torch.ops.view module.""" # View is implemented as a decomposition into prims.split_dim, @@ -568,7 +791,9 @@ def _view_no_original_shape_overload_impl(a, shape): for p in (prim_packet, prim): p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a." p.impl_nvfuser = _nvfuser_impls["view"] - p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] + p.is_recomputable = _nvfuser_is_recomputable["view"] + p.return_type = torch._prims_common.RETURN_TYPE.VIEW # type: ignore[attr-defined] + p.impl_aten = _nvprims_view_impl_aten def register_nvprims(): @@ -577,6 +802,7 @@ def register_nvprims(): register_view() register_native_batch_norm() register_rand_like() + register_full() for name in nvprim_names: main_prim = getattr(torch.ops.prims, name) @@ -593,4 +819,6 @@ def register_nvprims(): for p in (prim_packet, prim): p.__doc__ = main_prim.__doc__ p.impl_nvfuser = _nvfuser_impls[name] + p.is_recomputable = _nvfuser_is_recomputable.get(name, False) p.return_type = main_prim.return_type # type: ignore[attr-defined] + p.impl_aten = main_prim.impl_aten diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 72a01a85359c8..b34f109c3a2fb 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -20,6 +20,7 @@ torch.bfloat16: DataType.BFloat16, torch.long: DataType.Int, torch.int: DataType.Int32, + torch.uint8: DataType.Int32, torch.bool: DataType.Bool, # Python scalars complex: DataType.ComplexDouble, @@ -42,12 +43,20 @@ def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]): StrideType = Union[List[int], Tuple[int, ...]] DimsType = Union[int, List[int], Tuple[int, ...]] DimsSequenceType = Union[List[int], Tuple[int, ...]] -# TODO: Type[torch.SymIntNode], Type[torch.SymFloatNode] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] # TODO: This needs a lot more type annotations -# NumberType = Union[bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode] +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] NumberType = Union[bool, int, float, complex] -Number = (bool, int, float, complex, torch.SymIntNode, torch.SymFloatNode) + +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat) +# I don't call it Integral because numbers.Integral includes bool, but IntLike +# does not +Dim = int +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) +IntWithoutSymInt = int +FloatWithoutSymFloat = float DeviceLikeType = Union[str, torch.device] Tensor = torch.Tensor @@ -141,20 +150,31 @@ def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=Fals raise RuntimeError(msg) -def check_significant_strides( - a: TensorLikeType, b: TensorLikeType +def _check_strides_helper( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True ) -> Tuple[bool, Optional[int]]: # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch # See https://github.com/pytorch/pytorch/issues/77553 # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 # and for tensors with more than one element - if (a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0: + if (not only_cuda or a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0: for idx in range(a.ndim): - if a.stride()[idx] != b.stride()[idx] and a.shape[idx] > 1: + check = not significant_only or a.shape[idx] > 1 + if a.stride()[idx] != b.stride()[idx] and check: return False, idx return True, None +def check_significant_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True) + +def check_all_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> Tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) + # This function is equivalent to compute_contiguous() from TensorImpl.cpp def is_contiguous(a: TensorLikeType) -> bool: @@ -283,6 +303,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ + if a.is_sparse: + return False + # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True @@ -352,7 +375,7 @@ def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]: shape = tensors[0].shape - def _cmp(idx_a, idx_b): + def should_swap(idx_a, idx_b): for tensor in tensors: stride_a = tensor.stride()[idx_a] stride_b = tensor.stride()[idx_b] @@ -370,24 +393,30 @@ def _cmp(idx_a, idx_b): if shape[idx_a] > shape[idx_b]: return 1 - # NOTE: this case is missing in the C++ impl - if shape[idx_a] < shape[idx_b]: - return -1 - # Note: this case is hit if all strides are zero, # or all strides are equal and all dimensions have the same length return 0 - perm = tuple(range(ndim)) - perm = tuple(sorted(perm, key=cmp_to_key(_cmp), reverse=True)) + perm = list(reversed(range(ndim))) + + # insertion sort with support for ambiguous comparisons + for i in range(1, ndim): + dim1 = i + for dim0 in reversed(range(i)): + comparison = should_swap(perm[dim0], perm[dim1]) + if comparison > 0: + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + dim1 = dim0 + elif comparison < 0: + break permuted_shape = [-1] * ndim - for idx, x in enumerate(perm): + for idx, x in enumerate(reversed(perm)): permuted_shape[idx] = shape[x] new_strides = make_contiguous_strides_for(permuted_shape) permuted_strides = [-1] * ndim - for idx, x in enumerate(perm): + for idx, x in enumerate(reversed(perm)): permuted_strides[x] = new_strides[idx] return tuple(permuted_strides) @@ -433,8 +462,8 @@ def validate_idx(rank: int, idx: int): Assumes the index is already canonicalized. """ - assert isinstance(idx, int) - assert isinstance(rank, int) + assert isinstance(idx, Dim) + assert isinstance(rank, Dim) assert idx >= 0 and idx < rank or idx == 0 @@ -450,14 +479,15 @@ def validate_exclusive_idx(rank: int, ex_idx: int): for the given shape. """ - assert isinstance(ex_idx, int) - assert isinstance(rank, int) + assert isinstance(ex_idx, Dim) + assert isinstance(rank, Dim) assert ex_idx > 0 and ex_idx <= rank # "Wraps" a dim (up to one time) for the given rank, allowing dims to be -# specified using negative indices. For scalar tensors with rank 0, then idx -# must be in the range [-1, 0]. Otherwise, idx should be in the range [-rank, rank-1]. +# specified using negative indices. If `wrap_scalar` is true then scalar +# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, +# idx should be in the range [-rank, rank-1]. def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: if rank < 0: msg = f"Rank cannot be negative but got {rank}" @@ -490,20 +520,20 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: # Takes a dimension or sequence of dimensions and "wraps" them, # mapping negative offsets to positive ones @overload -def canonicalize_dims(rank: int, indices: Sequence[int]) -> Tuple[int, ...]: +def canonicalize_dims(rank: int, indices: Sequence[int], wrap_scalar: bool = True) -> Tuple[int, ...]: pass @overload -def canonicalize_dims(rank: int, indices: int) -> int: +def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: pass -def canonicalize_dims(rank, indices): - if isinstance(indices, int): - return canonicalize_dim(rank, indices) +def canonicalize_dims(rank, indices, wrap_scalar=True): + if isinstance(indices, Dim): + return canonicalize_dim(rank, indices, wrap_scalar) - return tuple(canonicalize_dim(rank, x) for x in indices) + return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices) def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: @@ -703,12 +733,14 @@ def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]: lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", ) if dim is not None: + # Convert to list to produce a compatible error message with core + # PyTorch, which prints sequences in square brackets. + shape = list(shape) check( newsize != 0, - lambda: f"cannot reshape tensor fo 0 elements into shape {shape} because the " - f"unspecified dimension size -1 can be any value and is ambiguous", + lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the " + f"unspecified dimension size -1 can be any value and is ambiguous"), ) - shape = list(shape) shape[dim] = numel // newsize return tuple(shape) @@ -823,10 +855,11 @@ def type_to_dtype(typ: type) -> torch.dtype: if typ is bool: return torch.bool - if typ is int: + if typ in [int, torch.SymInt]: return torch.long - if typ is float: + if typ in [float, torch.SymFloat]: return torch.get_default_dtype() + # TODO: sym_complex_float? if typ is complex: return corresponding_complex_dtype(torch.get_default_dtype()) @@ -1105,10 +1138,10 @@ class RETURN_TYPE(Enum): # TODO: when NumberType contains the sym types, can simplify this -def number_type(x: Union[NumberType, torch.SymIntNode, torch.SymFloatNode]) -> Type: - if isinstance(x, torch.SymIntNode): +def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type: + if isinstance(x, torch.SymInt): return int - elif isinstance(x, torch.SymFloatNode): + elif isinstance(x, torch.SymFloat): return float else: return type(x) @@ -1323,15 +1356,17 @@ def reduction_dtypes( result_dtype = torch.bool return computation_dtype, result_dtype - +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides def make_contiguous_strides_for( shape: ShapeType, row_major: bool = True ) -> Tuple[int, ...]: """ - Returns the strides of a contriguous tensor if row_major + Returns the strides of a contiguous tensor if row_major If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices This is often used when calling external libraries like BLAS/LAPACK/cuSolver... """ + # contiguous_strides from c10/util/strides.h validate_shape(shape) if not shape: return () @@ -1345,6 +1380,7 @@ def make_contiguous_strides_for( result = tuple(reversed(strides)) + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h if row_major: return result else: @@ -1439,13 +1475,52 @@ def set_correction( correction = 1 elif correction is None and unbiased is not None: correction = 0 if unbiased is False else 1 - if not isinstance(correction, int): + # NB: we don't actually support symint here, but it's harmless to accept + if not isinstance(correction, IntLike): raise ValueError("correction argument should be integer") if correction < 0: raise ValueError("correction argument should be non-negative") return correction +def compute_required_storage_length( + shape: ShapeType, strides: StrideType, storage_offset: int +) -> int: + """Computes the minimum storage size to hold the given tensor geometry. + + Example + ======= + + This is the size of a newly allocated tensor's storage, in units of elements + + >>> t = torch.empty((10, 20)) + >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) + 200 + + >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size == t.storage().size() + True + + A valid tensor may have a larger storage size, but never smaller + + >>> slice = torch.empty(100)[20:40] + >>> slice.storage().size() + 100 + + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + # Short-circuits if the shape has no elements + if reduce(operator.mul, shape, 1) == 0: + return 0 + + max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) + # +1 to account for the first element which offsets are taken from + return 1 + storage_offset + max_offset + + def check_in_bounds_for_storage( a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int ): @@ -1453,17 +1528,8 @@ def check_in_bounds_for_storage( Determines if the given shape, strides, and offset are valid for the given storage. """ - # Short-circuits if the shape has no elements - if reduce(operator.mul, shape) == 0: - return - - length = a.size() - storage_offset - max_offset = 0 - for x, y in zip(shape, strides): - max_offset = max_offset + (x - 1) * y - - if max_offset >= length: - required_length = max_offset + storage_offset + required_length = compute_required_storage_length(shape, strides, storage_offset) + if a.size() < required_length: msg = ( "Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, " "which requires a storage of size {4}".format( @@ -1561,6 +1627,27 @@ def mask_tensor(mask: TensorLikeType, t: TensorLikeType): return torch.where(mask, t, 0) +def get_aten_op(fn: Callable, name: str): + """ + Given the __module__ of reference and its name, it returns + (our best guess of) the ATen name of the associated operation + + Note: In ATen, the __name__ of a function within a module often + starts by the module name. E.g. linalg_eigh, or special_zeta + """ + module = fn.__module__ + prefix = "torch._refs" + assert(module.startswith(prefix)) + module = module[len(prefix):] + # We want to go from .special / .nn.functional + # to special and special_ / nn_functional_ + if module: + module = module[1:] + module = module.replace(".", "_") + module = module + "_" + return getattr(torch.ops.aten, f"{module}{name}") + + def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: return dtype if dtype is not None else torch.get_default_dtype() diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 76886f886a726..349e450cf3723 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -4,6 +4,7 @@ NumberType, TensorLike, TensorLikeType, + ShapeType, ELEMENTWISE_TYPE_PROMOTION_KIND, ) import torch._prims_common as utils @@ -11,8 +12,7 @@ from typing import Callable, Sequence, Union, Tuple, NamedTuple import inspect -from functools import wraps, reduce -import operator +from functools import wraps import warnings from itertools import chain @@ -129,25 +129,22 @@ def _fn(*args, **kwargs): # TODO: handle tuples of tensors -def _maybe_resize_out(out: TensorLikeType, shape): - if out.numel() == 0: - return out.resize_(shape) - - if out.numel() != reduce(operator.mul, shape, 1): - msg = ( - "An output with one or more elements was resized since it had shape {0} " - "which does not match the required output shape {1}. " - "This behavior is deprecated, and in a future PyTorch release outputs will not " - "be resized unless they have zero elements. " - "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format( - str(out.shape), str(shape) +def _maybe_resize_out(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return out + else: + if out.numel() != 0: + msg = ( + f"An output with one or more elements was resized since it had shape {str(out.shape)} " + "which does not match the required output shape {str(shape)}. " + "This behavior is deprecated, and in a future PyTorch release outputs will not " + "be resized unless they have zero elements. " + "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." ) - ) - warnings.warn(msg) + warnings.warn(msg) return out.resize_(shape) - return out - def _safe_copy_out( *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index a37673afb72af..f06f5ba34b5a9 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -16,10 +16,13 @@ from torch._prims_common import ( check, DeviceLikeType, + Dim, DimsSequenceType, DimsType, dtype_to_type, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + IntLike, is_weakly_lesser_type, Number, NumberType, @@ -39,6 +42,7 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) +from torch.fx.experimental.symbolic_shapes import sym_float, sym_int # Experimental module containing prototype Python references for existing # PyTorch operations. @@ -81,6 +85,7 @@ "isnan", "isreal", "i0", + "lerp", "lgamma", "log", "log1p", @@ -117,7 +122,7 @@ "bitwise_right_shift", "bitwise_xor", "clamp_min", - # "complex", + "clamp_max", "copysign", "div", "eq", @@ -158,12 +163,10 @@ "rsub", "rtruediv", "rfloordiv", - # # special.xlog1py - # # special.zeta "sub", "true_divide", "trunc_divide", - # 'xlogy', # where?, log, mul + "xlogy", # # Elementwise Ternary References # @@ -217,7 +220,10 @@ "constant_pad_nd", "contiguous", "diag_embed", + "diag", "diagonal", + "diagonal_copy", + "diagonal_scatter", "dsplit", "dstack", "expand", @@ -232,6 +238,7 @@ "movedim", "narrow", "narrow_copy", + "native_group_norm", "native_layer_norm", "permute", "ravel", @@ -278,10 +285,6 @@ "zeros", "zeros_like", # - # Randomness References - # - "uniform", # TODO: add OpInfo -- and testing for randomness? - # # Test-related functions # "allclose", @@ -298,7 +301,7 @@ def _broadcast_shapes(*_shapes): shapes = tuple( - (x,) if isinstance(x, int) else x + (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) ) @@ -315,7 +318,7 @@ def _broadcast_shapes(*_shapes): common_shape = [ 1, ] * reduce(max, (len(shape) for shape in shapes)) - for shape in shapes: + for arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): if common_shape[idx] == 1: if shape[idx] < 0: @@ -326,9 +329,9 @@ def _broadcast_shapes(*_shapes): elif shape[idx] != 1: if common_shape[idx] != shape[idx]: raise RuntimeError( - "Attempting to broadcast a dimension of length ", - str(shape[idx]), - "!", + f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}" ) return common_shape @@ -375,7 +378,6 @@ def _make_elementwise_unary_reference( type_promotion_kind, *, aten_op=infer_aten_op, - disable_meta=False, extra_meta=None, ) -> Callable: def inner(prim: Callable): @@ -389,26 +391,59 @@ def inner(prim: Callable): type_promotion_kind=type_promotion_kind, ) def _ref(a: TensorLikeType) -> TensorLikeType: - if not isinstance(a, TensorLike): - raise RuntimeError( - "Expected a tensor input for an elementwise unary operation!" - ) - if extra_meta is not None: extra_meta(a) return prim(a) if aten_op is infer_aten_op: - aten_op = getattr(torch.ops.aten, prim.__name__) + aten_op = utils.get_aten_op(prim, prim.__name__) if aten_op is not None: - register_decomposition(aten_op, disable_meta=disable_meta)(_ref) + register_decomposition(aten_op)(_ref) return _ref return inner +def _make_alias(fn, name): + """ + This function defines an alias of another function and sets its __name__argument + Note that when naïvely doing `alias = fn`, we have that `alias.__name__ == "fn"`. + """ + + def _fn(*args, **kwargs): + return fn(*args, **kwargs) + + _fn.__name__ = name + return _fn + + +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(torch.ops.aten, inplace_name))(_fn) + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... + from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT) def abs(a): return prims.abs(a) @@ -606,6 +641,10 @@ def isnan(a: TensorLikeType) -> TensorLikeType: return prims.ne(a, a) +# alias +mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type] + + @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=None, # CompositeImplicitAutograd @@ -629,10 +668,6 @@ def lgamma(a): return prims.lgamma(a) -# alias -mvlgamma = torch.special.multigammaln # type: ignore[has-type] - - @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) def log(a): return prims.log(a) @@ -705,10 +740,10 @@ def nan_to_num( nan = 0.0 if posinf is None: - posinf = prims.maximum_value(a.dtype) + posinf = torch.finfo(a.dtype).max if neginf is None: - neginf = prims.minimum_value(a.dtype) + neginf = torch.finfo(a.dtype).min result = where(isnan(a), nan, a) @@ -722,9 +757,14 @@ def nan_to_num( def _neg_meta(a: TensorLikeType): - if a.dtype is torch.bool: - msg = "neg is not supported on bool tensors." - raise RuntimeError(msg) + check( + a.dtype is not torch.bool, + lambda: ( + "Negation, the `-` operator, on a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` " + "operator instead." + ), + ) @_make_elementwise_unary_reference( @@ -842,51 +882,59 @@ def trunc(a): def _make_elementwise_binary_reference( - prim: Callable, - *, type_promotion_kind, aten_op=infer_aten_op, + name=None, has_out=True, supports_lhs_python_scalar=True, supports_rhs_python_scalar=True, - disable_meta=False, + supports_two_python_scalars=False, ) -> Callable: - @elementwise_type_promotion_wrapper( - type_promoting_args=("a", "b"), - type_promotion_kind=type_promotion_kind, - ) - def _ref( - a: Union[Tensor, NumberType], - b: Union[Tensor, NumberType], - ) -> Tensor: - if not supports_lhs_python_scalar and isinstance(a, Number): - raise ValueError( - "Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!" - ) + def inner(prim: Callable): + nonlocal aten_op, name + if name is None: + name = prim.__name__ - if not supports_rhs_python_scalar and isinstance(b, Number): - raise ValueError( - "Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!" + @wraps(prim) + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + def _ref( + a: Union[Tensor, NumberType], + b: Union[Tensor, NumberType], + ) -> Tensor: + check( + supports_lhs_python_scalar or not isinstance(a, Number), + lambda: "{name}: Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!", + ValueError, ) - - # TODO: enable this for operations that support it, like add - if isinstance(a, Number) and isinstance(b, Number): - raise ValueError( - "Receive two Number inputs to an elementwise binary operation!" + check( + supports_rhs_python_scalar or not isinstance(b, Number), + lambda: "{name}: Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!", + ValueError, + ) + check( + supports_two_python_scalars + or not (isinstance(a, Number) and isinstance(b, Number)), + lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", + ValueError, ) + a, b = _maybe_broadcast(a, b) + return prim(a, b) - a, b = _maybe_broadcast(a, b) - return prim(a, b) + if has_out: + _ref = out_wrapper()(_ref) - if has_out: - _ref = out_wrapper()(_ref) + _ref.__name__ = name + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, name) + if aten_op is not None: + register_decomposition(aten_op)(_ref) - if aten_op is infer_aten_op: - aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0]) - if aten_op is not None: - register_decomposition(aten_op, disable_meta=disable_meta)(_ref) + return _ref - return _ref + return inner # Add has its own implementation because it has an alpha argument @@ -906,11 +954,6 @@ def add( Reference implementation of torch.add """ - if isinstance(a, Number) and isinstance(b, Number): - raise ValueError( - "Receive two Number inputs to an elementwise binary operation!" - ) - a, b = _maybe_broadcast(a, b) if alpha is not None: @@ -931,47 +974,61 @@ def add( # TODO: add docstring -atan2 = _make_elementwise_binary_reference( - prims.atan2, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def atan2(a, b): + return prims.atan2(a, b) + # TODO: add docstring -bitwise_and = _make_elementwise_binary_reference( - prims.bitwise_and, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_and(a, b) + # TODO: add docstring -bitwise_left_shift = _make_elementwise_binary_reference( - prims.shift_left, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.bitwise_left_shift, # prim/aten name mismatch ) +def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_left(a, b) + # TODO: add docstring -bitwise_or = _make_elementwise_binary_reference( - prims.bitwise_or, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_or(a, b) + # TODO: add docstring -bitwise_right_shift = _make_elementwise_binary_reference( - prims.shift_right_arithmetic, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.bitwise_right_shift, # prim/aten name mismatch ) +def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_right_arithmetic(a, b) + # TODO: add docstring -bitwise_xor = _make_elementwise_binary_reference( - prims.bitwise_xor, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_xor(a, b) -def _copysign( +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, +) +def copysign( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): if isinstance(b, Number) and isinstance(a, Tensor): @@ -984,14 +1041,6 @@ def _copysign( return where(signbit(b), neg(abs(a)), abs(a)) -# TODO: add docstring -copysign = _make_elementwise_binary_reference( - _copysign, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - supports_lhs_python_scalar=False, - aten_op=torch.ops.aten.copysign, -) - # TODO: add docstring # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) @@ -1022,14 +1071,19 @@ def div( # TODO: add docstring -eq = _make_elementwise_binary_reference( - prims.eq, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.eq(a, b) -def _pow( +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) +def pow( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType], ) -> TensorLikeType: @@ -1045,13 +1099,6 @@ def _pow( return prims.pow(a, b) -# TODO: add docstring -pow = _make_elementwise_binary_reference( - _pow, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, - aten_op=torch.ops.aten.pow, -) - # TODO: add docstring # Float power has its own implementation because it has unique type promotion. # NB: aten_op not registered because CompositeExplicitAutograd @@ -1111,7 +1158,13 @@ def float_power( # # For reference, see CPython's implementation: # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 -def _floor_divide( + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def floor_divide( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): # Wrap scalars because some references only accept tensor arguments. @@ -1149,7 +1202,7 @@ def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: # Convert truncation to flooring: offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) - return prims.div(a, b) - prims.convert_element_type(offset, a.dtype) + return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: @@ -1178,65 +1231,69 @@ def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: # TODO: add docstring -floor_divide = _make_elementwise_binary_reference( - _floor_divide, - type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.floor_divide, -) - - -# TODO: add docstring -fmax = _make_elementwise_binary_reference( - prims.fmax, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmax, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + # TODO: add docstring -fmin = _make_elementwise_binary_reference( - prims.fmin, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmin, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + # TODO: add docstring -fmod = _make_elementwise_binary_reference( - prims.fmod, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmod, supports_lhs_python_scalar=False, supports_rhs_python_scalar=True, ) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + # TODO: add docstring -gcd = _make_elementwise_binary_reference( - prims.gcd, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.gcd, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + # TODO: add docstring -ge = _make_elementwise_binary_reference( - prims.ge, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + # TODO: add docstring -gt = _make_elementwise_binary_reference( - prims.gt, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) -def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: input_eq_zero = eq(input, 0) input_lt_zero = logical_or(lt(input, 0), isnan(input)) zeros_and_ones = where(input_lt_zero, 0, 1) @@ -1244,34 +1301,31 @@ def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: return output -heaviside = _make_elementwise_binary_reference( - _heaviside, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, - supports_lhs_python_scalar=False, - supports_rhs_python_scalar=False, - aten_op=torch.ops.aten.heaviside, -) - -hypot = _make_elementwise_binary_reference( - prims.hypot, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + -igamma = _make_elementwise_binary_reference( - prims.igamma, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) -igammac = _make_elementwise_binary_reference( - prims.igammac, # type: ignore[has-type] + +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) def _check_close_args( @@ -1346,8 +1400,16 @@ def isclose( return result -def _lcm(a: TensorLikeType, b: TensorLikeType): +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): dtype = a.dtype + # promoting to int32 to maintain 100% consistency with C++ and to + # prevent overflow in case of int8 and int16 promote_to_int = dtype in (torch.int8, torch.int16) if promote_to_int: a = prims.convert_element_type(a, torch.int32) @@ -1361,24 +1423,19 @@ def _lcm(a: TensorLikeType, b: TensorLikeType): # TODO: add docstring -lcm = _make_elementwise_binary_reference( - _lcm, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.lcm, +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, - supports_rhs_python_scalar=False, ) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) # TODO: add docstring -le = _make_elementwise_binary_reference( - prims.le, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - supports_lhs_python_scalar=False, ) - - -def _logical_and(a: TensorLikeType, b: TensorLikeType): +def logical_and(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1386,23 +1443,19 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType): return a & b -logical_and = _make_elementwise_binary_reference( - _logical_and, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_and, -) - - -@_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not -) +# TODO: add docstring +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) def logical_not(a: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): return a == 0 return ~a -def _logical_or(a: TensorLikeType, b: TensorLikeType): +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1410,14 +1463,12 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType): return bitwise_or(a, b) -logical_or = _make_elementwise_binary_reference( - _logical_or, +# TODO: add docstring +# TODO: skip unnecessary conversion of long to float +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_or, ) - - -def _logical_xor(a: TensorLikeType, b: TensorLikeType): +def logical_xor(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1425,60 +1476,66 @@ def _logical_xor(a: TensorLikeType, b: TensorLikeType): return a ^ b -# TODO: skip unnecessary conversion of long to float -logical_xor = _make_elementwise_binary_reference( - _logical_xor, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_xor, -) - - # TODO: add docstring -lt = _make_elementwise_binary_reference( - prims.lt, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + # TODO: add docstring -maximum = _make_elementwise_binary_reference( - prims.maximum, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + # TODO: add docstring -minimum = _make_elementwise_binary_reference( - prims.minimum, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + # TODO: add docstring -mul = _make_elementwise_binary_reference( - prims.mul, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, ) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + # TODO: add docstring -ne = _make_elementwise_binary_reference( - prims.ne, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + # TODO: add docstring -nextafter = _make_elementwise_binary_reference( - prims.nextafter, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + # TODO: add docstring -remainder = _make_elementwise_binary_reference( - prims.remainder, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.remainder, ) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + # reverse sub def rsub( @@ -1512,11 +1569,6 @@ def sub( Reference implementation of torch.sub """ - if isinstance(a, Number) and isinstance(b, Number): - raise ValueError( - "Receive two Number inputs to an elementwise binary operation!" - ) - a, b = _maybe_broadcast(a, b) if alpha is not None: @@ -1535,14 +1587,48 @@ def sub( # TODO: add docstring -true_divide = _make_elementwise_binary_reference( - prims.div, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, ) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) -def _trunc_divide( +@register_decomposition(torch.ops.aten.xlogy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + utils.check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): dtype = utils.get_dtype(a) @@ -1552,13 +1638,6 @@ def _trunc_divide( return trunc(prims.div(a, b)) -# TODO: add docstring -trunc_divide = _make_elementwise_binary_reference( - _trunc_divide, - type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=None, # CompositeImplicitAutograd -) - # # Elementwise Ternary References # @@ -1939,8 +2018,8 @@ def _reduction( "dtype argument and out dtype must match in reduction" ) if not accepts_dim_tuple: - assert dims is None or isinstance(dims, int) - if isinstance(dims, int): + assert dims is None or isinstance(dims, Dim) + if isinstance(dims, Dim): dims = (dims,) # type: ignore[assignment] dims = utils.reduction_dims(a.shape, dims) if not has_identity: @@ -1974,6 +2053,25 @@ def _reduction( return result +def _make_copy_from_view(fn): + """ + Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy) + """ + name = fn.__name__ + fn = out_wrapper()(fn) + + def _fn(*args, out=None, **kwargs): + result = fn(*args, out=out, **kwargs) + if out is None: + return result.clone(memory_format=torch.contiguous_format) + return result + + copy_name = f"{name}_copy" + _fn.__name__ = copy_name + _fn = register_decomposition(getattr(torch.ops.aten, copy_name))(_fn) + return _fn + + # Saves Python all py_all = all @@ -1986,7 +2084,7 @@ def all( keepdim: bool = False, ) -> TensorLikeType: # Computes nelem - if isinstance(dim, int): + if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] a_ = _maybe_convert_to_dtype(a, torch.bool) @@ -2157,6 +2255,7 @@ def _dim_var_dispatch(dim=None, unbiased=None): return dim, unbiased +@register_decomposition(torch.ops.aten.var) @out_wrapper() def var( a: TensorLikeType, @@ -2230,11 +2329,14 @@ def mean( # reduces over all dimensions if dim=() is passed if dim == () or dim == []: dim = None + orig_dtype = dtype if dtype is None: dtype = a.dtype # can't use out wrapper because of this argument - if out is not None and out.dtype != dtype: - raise RuntimeError("expected out dtype and dtype to match") + check( + out is None or out.dtype == dtype, + lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", + ) result = _reduction( a, prims.sum, @@ -2244,9 +2346,15 @@ def mean( out=None, output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, ) - if utils.is_integer_dtype(dtype): - raise RuntimeError("result type should be floating point or complex") - if isinstance(dim, int): + check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: ( + f"mean(): could not infer output dtype. " + f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " + f"a floating point or complex dtype. Got: {dtype}" + ), + ) + if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) @@ -2401,9 +2509,27 @@ def atleast_3d( def as_strided( - a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int = 0 + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, ) -> TensorLikeType: - return prims.as_strided(a, size, stride, storage_offset) + storage_offset_int = ( + storage_offset if storage_offset is not None else a.storage_offset() + ) + return prims.as_strided(a, size, stride, storage_offset_int) + + +@register_decomposition(torch.ops.aten.as_strided_scatter) +def as_strided_scatter( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = 0 if storage_offset is None else storage_offset + return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) def broadcast_shapes(*shapes) -> ShapeType: @@ -2432,6 +2558,18 @@ def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, ) def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + format = f + assert format is not None + return format + if len(tensors) == 0: msg = "cat expects at least one tensor, but received zero!" raise ValueError(msg) @@ -2449,6 +2587,8 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: utils.validate_idx(t.ndim, dim) break + memory_format = cat_compute_output_memory_format(tensors) + # Filters tensors with one dimension of length zero filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0)) if len(filtered) == 0: @@ -2460,9 +2600,15 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: except Exception: requires_grad = False - return empty((0,), dtype=t.dtype, device=t.device, requires_grad=requires_grad) + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) - return prims.cat(filtered, dim) + return prims.cat(filtered, dim).clone(memory_format=memory_format) # CompositeImplicitAutograd - don't register decomp @@ -2580,7 +2726,7 @@ def dstack(tensors: TensorSequenceType) -> TensorLikeType: return cat(aligned_tensors, 2) -@register_decomposition(torch.ops.aten.expand, disable_meta=True) +@register_decomposition(torch.ops.aten.expand) def expand(a: Tensor, *shape) -> Tensor: # NOTE: cannot use utils.extract_shape_from_varargs here # because that also validates the shape, but the shape @@ -2687,19 +2833,39 @@ def flipud(a: TensorLikeType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp -def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + check(length >= 0, lambda: "narrow(): length must be non-negative.") dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + # Start being the end is usually invalid since it's out of bounds. So it's + # not allowed by canonicalize_dim. But for narrow it's valid as long as + # the length is 0, which is handled by the check below. + if start != dim_length: + # Negative start means indexing from the end of dim. + # Note: a dimension isn't being canonicalized here, this reuses + # canonicalize_dim because the semantics are similar. + start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] + check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) return prims.slice_in_dim(a, start, start + length, axis=dim) -@register_decomposition(torch.ops.aten.narrow_copy) -@out_wrapper() -def narrow_copy(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: - # TODO: This must return a sparse tensor if the input is sparse, but refs - # have no sparse support. See narrow_copy_sparse in core. - if a.is_sparse: - raise NotImplementedError("narrow_copy ref doesn't support sparse tensors") - return torch.clone(torch.narrow(a=a, dim=dim, start=start, length=length)) # type: ignore[call-overload] +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(narrow) def _normalize( @@ -2719,6 +2885,7 @@ def _normalize( mean (Tensor): mean of the tensor along norm_dims. rstd (Tensor): 1/std of the tensor along norm_dims. """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) computation_dtype = utils.get_computation_dtype(a.dtype) a_acc = _maybe_convert_to_dtype(a, computation_dtype) assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean @@ -2730,6 +2897,72 @@ def _normalize( return out, mean, rstd +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +def _squeeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in reversed(sorted(dimensions)): + x = torch.squeeze(x, dim) + return x + + +@register_decomposition(torch.ops.aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) + out = out.view(input.shape) + + broadcast_dims = [0] + list(dim for dim in range(2, input.ndim)) + unsqueeze_bias = None + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + unsqueeze_weight = None + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + + if unsqueeze_weight is not None: + out = out * unsqueeze_weight + if unsqueeze_bias is not None: + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = _squeeze_multiple(mean, reduction_dims) + rstd = _squeeze_multiple(rstd, reduction_dims) + return (out, mean, rstd) + + @register_decomposition(torch.ops.aten.native_layer_norm) def native_layer_norm( input: Tensor, @@ -2792,16 +3025,16 @@ def native_layer_norm( elif weight is not None and bias is not None: out = out * weight + bias - out = prims.convert_element_type(out, input.dtype) + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] if input.device.type == "cpu": - mean = prims.convert_element_type(mean, input.dtype) - rstd = prims.convert_element_type(rstd, input.dtype) + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] return (out, mean, rstd) # TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. # test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu -@register_decomposition(torch.ops.aten.permute, disable_meta=True) +@register_decomposition(torch.ops.aten.permute) def permute(a: TensorLikeType, *dims) -> TensorLikeType: _permutation = utils.canonicalize_dims( a.ndim, utils.extract_dims_from_varargs(dims) @@ -3090,7 +3323,7 @@ def rot90( elif k == 3: return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) else: - return clone(a) + return clone(a, memory_format=torch.contiguous_format) def _check_stack_inputs(tensors: TensorSequenceType) -> None: @@ -3166,7 +3399,7 @@ def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: dim = utils.canonicalize_dim(t.ndim, dim) check( len(t.shape) > 0, - lambda: "dimension specified as 0 but tensor has no dimensions", + lambda: "Dimension specified as 0 but tensor has no dimensions", IndexError, ) return tuple( @@ -3177,7 +3410,9 @@ def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: @register_decomposition(torch.ops.aten.index_copy) @out_wrapper() def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): - return x.clone().index_copy_(dim, index, tensor) + return x.clone(memory_format=torch.contiguous_format).index_copy_( + dim, index, tensor + ) @register_decomposition(torch.ops.aten.index_copy_) @@ -3234,10 +3469,13 @@ def index_add( *, alpha: NumberType = 1, ): - return x.clone().index_add_(dim, index, tensor, alpha=alpha) # type: ignore[arg-type] + # index_add always returns a new contiguous tensor + return x.clone(memory_format=torch.contiguous_format).index_add_( + dim, index, tensor, alpha=alpha # type: ignore[arg-type] + ) -@register_decomposition(torch.ops.aten.index_select, disable_meta=True) +@register_decomposition(torch.ops.aten.index_select) @out_wrapper() def index_select(x: TensorLike, dim: int, index: TensorLike): dim = utils.canonicalize_dims(x.ndim, dim) @@ -3255,8 +3493,7 @@ def index_select(x: TensorLike, dim: int, index: TensorLike): return x[idx] -# Note: although squeeze is documented as having the out= kwarg it doesn't -@register_decomposition(torch.ops.aten.squeeze, disable_meta=True) +@register_decomposition(torch.ops.aten.squeeze) def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType: if dim is not None: dim = utils.canonicalize_dim(a.ndim, dim) @@ -3299,7 +3536,7 @@ def tensor_split( raise ValueError(msg) # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length - if isinstance(indices_or_sections, int) or ( + if isinstance(indices_or_sections, IntLike) or ( isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0 ): sections: int = ( @@ -3365,7 +3602,7 @@ def hsplit( ), ) dim = 0 if a.ndim == 1 else 1 - if isinstance(indices_or_sections, int): + if isinstance(indices_or_sections, IntLike): split_size = indices_or_sections check( (split_size != 0 and a.shape[dim] % split_size == 0), @@ -3407,17 +3644,17 @@ def vsplit( + " dimensions!" ), ) - if isinstance(indices_or_sections, int): + if isinstance(indices_or_sections, IntLike): split_size = indices_or_sections check( (split_size != 0 and a.shape[0] % split_size == 0), lambda: ( - "torch.vsplit attempted to split along dimension 0 " - + ", but the size of the dimension " - + str(a.shape[0]) - + " is not divisible by the split_size " - + str(split_size) - + "!" + f"torch.vsplit attempted to split along dimension 0" + f", but the size of the dimension " + f"{a.shape[0]}" + f" is not divisible by the split_size " + f"{split_size}" + f"!" ), ) return tensor_split(a, split_size, 0) @@ -3436,7 +3673,43 @@ def vsplit( return tensor_split(a, split_sizes, 0) -@register_decomposition(torch.ops.aten.diagonal, disable_meta=True) +@register_decomposition(torch.ops.aten.diag.out) +@out_wrapper() +def diag( + self: TensorLikeType, + offset: int = 0, +) -> TensorLikeType: + ndim = self.dim() + utils.check( + ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" + ) + if ndim == 1: + return torch.diag_embed(self, offset) + else: + return torch.diagonal_copy(self, offset) + + +@register_decomposition(torch.ops.aten.diagonal_scatter) +@out_wrapper() +def diagonal_scatter( + input: TensorLikeType, + src: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + out = input.clone() + diag = out.diagonal(offset, dim1, dim2) + check( + diag.shape == src.shape, + lambda: "expected src to have a size equal to the diagonal of the input." + f"Got {src.shape} for a diagonal of shape {diag.shape}", + ) + copy_to(diag, src) + return out + + +@register_decomposition(torch.ops.aten.diagonal) def diagonal( self: TensorLikeType, offset: int = 0, @@ -3478,7 +3751,11 @@ def diagonal( return result +diagonal_copy = _make_copy_from_view(diagonal) + + @register_decomposition(torch.ops.aten.diag_embed) +@out_wrapper() def diag_embed( t: TensorLikeType, offset: int = 0, @@ -3529,7 +3806,10 @@ def diag_embed( cond = a_range == b_range.unsqueeze(-1) cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] cond = cond.reshape(cond_shape) - return utils.mask_tensor(cond, t) + + # aten.diag_embed always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(cond, t).contiguous() # CompositeImplicitAutograd - don't register decomp @@ -3538,15 +3818,15 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: raise RuntimeError( f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" ) - if isinstance(sections, int) and (sections == 0 or a.shape[2] % sections != 0): + if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): raise RuntimeError( - "torch._refs.dsplit attempted to split along dimension 2, " + "torch.dsplit attempted to split along dimension 2, " + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" ) return tensor_split(a, sections, 2) -@register_decomposition(torch.ops.aten.t.default, disable_meta=True) +@register_decomposition(torch.ops.aten.t.default) def t(a: TensorLikeType): # TODO: Add sparse support # if a.is_sparse: @@ -3577,7 +3857,7 @@ def T(a: TensorLikeType) -> TensorLikeType: return a.t() -@register_decomposition(torch.ops.aten.transpose, disable_meta=True) +@register_decomposition(torch.ops.aten.transpose) def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] @@ -3607,7 +3887,9 @@ def unfold( @register_decomposition(torch.ops.aten.unfold_copy) @out_wrapper() def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): - return self.unfold(dimension, size, step).clone() + return self.unfold(dimension, size, step).clone( + memory_format=torch.contiguous_format + ) @register_decomposition(torch.ops.aten.cumsum) @@ -3634,7 +3916,8 @@ def cumsum( return sum(masked_a, dim=dim, keepdim=keepdim, dtype=dtype, out=out) -@register_decomposition(torch.ops.aten.unsqueeze, disable_meta=True) +# Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(torch.ops.aten.unsqueeze) def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: # Note that unsqueeze canonicalizes with rank + 1 because it allows # a new innermost dimension to be specified @@ -3647,7 +3930,7 @@ def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: # Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view # doesn't support unpacked shapes # TODO: Turn this into a decomposition (currently fails on reshape meta tests) -@register_decomposition(torch.ops.aten.view, disable_meta=True) +@register_decomposition(torch.ops.aten.view) def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: return _reshape_view_helper(a, *shape, allow_copy=False) @@ -3662,7 +3945,7 @@ def ravel(a: TensorLikeType) -> TensorLikeType: return reshape(a, (-1,)) -@register_decomposition(torch.ops.aten.empty) +@register_decomposition(torch.ops.aten.empty.memory_format) @out_wrapper() def empty( *shape, @@ -3755,7 +4038,7 @@ def new_empty_strided( ) -@register_decomposition(torch.ops.aten.zeros) +@register_decomposition(torch.ops.aten.zeros.default) @out_wrapper() def zeros( *size, @@ -3807,7 +4090,7 @@ def new_zeros( ) -@register_decomposition(torch.ops.aten.ones) +@register_decomposition(torch.ops.aten.ones.default) @out_wrapper() def ones( *size, @@ -3963,6 +4246,37 @@ def arange( ) +@register_decomposition(torch.ops.aten.lerp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("start", "end", "weight"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): + check( + start.dtype == end.dtype, + lambda: f"expected dtype {start.dtype} for `end` but got dtype {end.dtype}", + ) + if isinstance(weight, Number): + weight = start.new_full((), weight) # type: ignore[arg-type] + else: + check( + start.dtype == weight.dtype, + lambda: f"expected dtype {start.dtype} for `weight` but got dtype {weight.dtype}", # type: ignore[union-attr] + ) + assert isinstance(weight, Tensor) # mypy + # We implement it this way for numerical stability. We assume (in the stability optimisation) + # that 0 <= weight <= 1. We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: + # If weight.abs() >= 0.5: + # return (1 - weight) * (start - end) + end + mask = weight.abs() >= 0.5 + coeff = torch.where(mask, weight - 1, weight) + base = torch.where(mask, end, start) + return coeff * (end - start) + base + + @register_decomposition(torch.ops.aten.linspace) @out_wrapper() def linspace( @@ -3976,28 +4290,28 @@ def linspace( pin_memory: bool = False, requires_grad: bool = False, ) -> TensorLikeType: - if dtype is None: - dtype = torch.get_default_dtype() - - # NB: NumPy actually doesn't do this cast, but for this ref, I'd rather have this - # cast than not, because it allows us to always go into the precise path - # if dtype is integral and not worry about whether start/end are float - if prims.utils.is_integer_dtype(dtype): - if isinstance(start, float): - start = int(start) - if isinstance(end, float): - end = int(end) - if py_any(isinstance(arg, complex) for arg in (start, end, steps)): - raise NotImplementedError - assert not isinstance(start, complex) and not isinstance(end, complex) # for mypy + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + # steps does not participate in the computation of the dtype check( - isinstance(steps, int), + isinstance(steps, IntLike), lambda: "steps must be int, not float", exc_type=TypeError, ) - assert isinstance(steps, int) # for mypy + assert isinstance(steps, IntLike) # for mypy check(steps >= 0, lambda: "number of steps must be non-negative") factory_kwargs = { @@ -4007,41 +4321,27 @@ def linspace( "requires_grad": requires_grad, } if steps == 0: - ret = torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - elif steps == 1: - ret = torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - elif start == end: - ret = torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - else: - if prims.utils.is_integer_dtype(dtype): - # We need to cast to int, so to avoid off-by-one issues - # do the entire computation with ints when we can - assert isinstance(start, int) and isinstance(end, int) - step_size_x_denom = end - start - eps = 1 if end > start else -1 - denom = steps - 1 - ret = prims.to_dtype( - torch.arange( - start * denom, - end * denom + eps, - step_size_x_denom, - dtype=torch.int64, - **factory_kwargs, # type: ignore[arg-type] - ) - / denom, - dtype, - ) - else: - step_size = (end - start) / (steps - 1) - eps = step_size / 2 - ret = prims.to_dtype( - torch.arange( # type: ignore[call-overload] - start, end + eps, step_size, dtype=torch.float64, **factory_kwargs - ), - dtype, - ) - - return ret + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + return torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if start == end: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # arange returns values in the interval [start, end) so we add an an eps to make it [start, end] + # The eps is small enough as to always add just the element end + step_size = 1 / (steps - 1) + eps = step_size / 2 + # arange returns a tensor of size divup(end - start, step) and thus, for the arguemnts below + # ceil(div(1 + step_size/2, 1/(steps - 1)) = steps - 1 + ceil(1 / 2) = steps + # torch.arange is an scan algorithm, so we need a high-precision dtype + rg = torch.arange( + 0, 1 + eps, step_size, dtype=torch.float64, **factory_kwargs # type: ignore[arg-type] + ) + double_dtype = torch.complex128 if utils.is_complex_dtype(dtype) else torch.float64 + rg = _maybe_convert_to_dtype(rg, double_dtype) # type: ignore[assignment] + cast = partial(torch.full, (1,), dtype=double_dtype, **factory_kwargs) + out = torch.lerp(cast(start), cast(end), rg) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] @register_decomposition(torch.ops.aten.logspace) @@ -4063,10 +4363,10 @@ def logspace( # NB: NumPy doesn't have this cast if prims.utils.is_integer_dtype(dtype): - if isinstance(start, float): - start = int(start) - if isinstance(end, float): - end = int(end) + if isinstance(start, FloatLike): + start = sym_int(start) + if isinstance(end, FloatLike): + end = sym_int(end) assert not isinstance(base, complex) # for mypy if base < 0: @@ -4174,12 +4474,14 @@ def movedim( if type(destination) is int: destination = (destination,) + # Converts to list to produce a compatible error message with core PyTorch, + # which prints sequences in square brackets. utils.check( len(source) == len(destination), # type: ignore[arg-type] lambda: ( - "movedim: Invalid source or destination dims: source " - f"({source} dims) should contain the same number of dims as " - f"destination ({destination} dims)" + "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] + f"({list(source)} dims) should contain the same number of dims as " + f"destination ({list(destination)} dims)" ), ) @@ -4190,13 +4492,14 @@ def movedim( sss = set(ss) dss = set(ds) + # See above on why this converts to list in error messages. utils.check( len(ss) == len(sss), - lambda: f"movedim: repeated dim in `source` {source}", + lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] ) utils.check( len(ds) == len(dss), - lambda: f"movedim: repeated dim in `destination` {destination}", + lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] ) m = dict(zip(ds, ss)) @@ -4291,6 +4594,7 @@ def eye( # result.requires_grad_(requires_grad) +@register_decomposition(torch.ops.aten.full) @out_wrapper() def full( shape: ShapeType, @@ -4302,6 +4606,12 @@ def full( pin_memory: bool = False, requires_grad: bool = False, ) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + e = empty( shape, dtype=dtype, @@ -4310,7 +4620,7 @@ def full( pin_memory=pin_memory, requires_grad=requires_grad, ) - return fill(e, fill_value) + return torch.fill(e, fill_value) # type: ignore[arg-type] def full_like( @@ -4341,8 +4651,8 @@ def full_like( ones_like = partial(full_like, fill_value=True) -# TODO: add pin_memory support -@register_decomposition(torch.ops.aten.randn) + +@register_decomposition(torch.ops.aten.randn.default) @out_wrapper() def randn( *shape, @@ -4350,16 +4660,14 @@ def randn( device: Optional[torch.device] = None, layout: Optional[torch.layout] = None, requires_grad: bool = False, - pin_memory: Optional[bool] = None, + pin_memory: bool = False, ) -> TensorLikeType: - - check(pin_memory is None, lambda: "pin_memory parameter is not supported!") + utils.check_pin_memory(pin_memory) shape_ = utils.extract_shape_from_varargs(shape) dtype = utils.dtype_or_default(dtype) device = utils.device_or_default(device) - layout = utils.layout_or_default(layout) return prims.normal( shape_, @@ -4391,8 +4699,7 @@ def scalar_tensor( # -@register_decomposition(torch.ops.aten.uniform) -def uniform( +def _uniform_helper( shape: ShapeType, low: Union[bool, int, float] = 0.0, high: Union[bool, int, float] = 1.0, @@ -4402,15 +4709,15 @@ def uniform( ) -> TensorLikeType: utils.validate_shape(shape) - assert isinstance(low, (bool, int, float)) - assert isinstance(high, (bool, int, float)) - low = float(low) - high = float(high) + assert isinstance(low, Number) + assert isinstance(high, Number) + low = sym_float(low) + high = sym_float(high) assert isinstance(dtype, torch.dtype) device = utils.canonicalize_device(device) - return prims.uniform(shape, low=low, high=high, dtype=dtype, device=device) + return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) @register_decomposition( @@ -4450,10 +4757,14 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi # Since `where` allows type-promotion, # cast value to correct type before passing to `where` if isinstance(value, Number): - return torch.where(mask, python_type(value), a) + r = torch.where(mask, python_type(value), a) + else: + assert isinstance(value, TensorLike) + r = torch.where(mask, prims.to_dtype(value, a.dtype), a) - assert isinstance(value, TensorLike) - return torch.where(mask, prims.to_dtype(value, a.dtype), a) + # aten.mask_fill always return a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return r.contiguous() # CompositeImplicitAutograd - don't register decomp @@ -4505,10 +4816,10 @@ def norm( ) -> TensorLikeType: # In these cases we compute the "Frobenius norm" if ( - p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2) + p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2) ) or p is None: p = 2 - if isinstance(dim, int): + if isinstance(dim, Dim): dim = [dim] if isinstance(p, str): # Here we either call the nuclear norm, or we call matrix_norm with some arguments @@ -4555,7 +4866,9 @@ def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: - torch.arange(h, device=a.device).unsqueeze(-1) ) >= diagonal - return utils.mask_tensor(mask, a) + # aten.triu always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() @register_decomposition(torch.ops.aten.tril) @@ -4570,7 +4883,9 @@ def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: - torch.arange(h, device=a.device).unsqueeze(-1) ) <= diagonal - return utils.mask_tensor(mask, a) + # aten.tril always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() # This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h @@ -4773,6 +5088,92 @@ def bucketize( return start.to(dtype=out_dtype) +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = _make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = _make_inplace(sqrt) +square_ = _make_inplace(square) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + import torch._refs._conversions import torch._refs.fft import torch._refs.linalg diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index 11657f7058bd7..abcd5729818d7 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -1,6 +1,12 @@ import torch +import torch._prims_common as utils -from torch._prims_common import TensorLikeType +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +from torch._prims_common import check, TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes # Data conversion references. # @@ -10,6 +16,7 @@ # (like int). __all__ = [ + # dtypes "bfloat16", "bool", "byte", @@ -23,6 +30,8 @@ "int", "long", "short", + # misc + "complex", ] @@ -61,3 +70,37 @@ def fn( long = _make_conversion_method("long", torch.long) short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. +@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index d92ef6914c2d1..738a33fde038b 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -9,7 +9,7 @@ import torch._prims_common as utils from torch._decomp import register_decomposition from torch._prims_common import check, DimsType, ShapeType, TensorLikeType -from torch._prims_common.wrappers import out_wrapper +from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper __all__ = [ # Transforms @@ -76,9 +76,7 @@ def _maybe_promote_tensor_fft( """Helper to promote a tensor to a dtype supported by the FFT primitives""" cur_type = t.dtype new_type = _promote_type_fft(cur_type, require_complex) - if cur_type == new_type: - return t - return prims.convert_element_type(t, new_type) + return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] def _resize_fft_input( @@ -117,7 +115,7 @@ def _fft_c2r( ) -> TensorLikeType: """Common code for performing any complex to real FFT (irfft or hfft)""" input = _maybe_promote_tensor_fft(input, require_complex=True) - dims = (utils.canonicalize_dim(input.ndim, dim),) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) check(last_dim_size >= 1, lambda: f"Invalid number of data points ({n}) specified") @@ -146,7 +144,7 @@ def _fft_r2c( lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", ) input = _maybe_promote_tensor_fft(input) - dims = (utils.canonicalize_dim(input.ndim, dim),) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) if n is not None: input = _resize_fft_input(input, dims, (n,)) @@ -169,7 +167,7 @@ def _fft_c2c( input.dtype.is_complex, lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", ) - dims = (utils.canonicalize_dim(input.ndim, dim),) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) if n is not None: input = _resize_fft_input(input, dims, (n,)) @@ -265,7 +263,7 @@ def _canonicalize_fft_shape_and_dim_args( if dim is not None: if not isinstance(dim, Sequence): dim = (dim,) - ret_dims = utils.canonicalize_dims(input_dim, dim) + ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) # Check dims are unique check(len(set(dim)) == len(dim), lambda: "FFT dims must be unique") diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index c3b8a3c603524..e6c15ec01889f 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -14,11 +14,12 @@ check, check_fp_or_complex, check_is_matrix, + Dim, DimsType, NumberType, TensorLikeType, ) -from torch._prims_common.wrappers import out_wrapper +from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper __all__ = [ "svd", @@ -69,7 +70,7 @@ def vector_norm( # Checks check_fp_or_complex(x.dtype, "linalg.vector_norm") - if isinstance(dim, int): + if isinstance(dim, Dim): dim = [dim] # type: ignore[assignment] elif not isinstance(dim, List) and dim is not None: # refs.amin just accepts List rather than DimType (Tuple) @@ -96,23 +97,23 @@ def vector_norm( x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype ) - to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype) + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) # Implementation if ord == 0.0: return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) elif ord == float("inf"): - return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim)) + return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value] elif ord == float("-inf"): - return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim)) + return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value] else: # From here on the computation dtype is important as the reduction is non-trivial - x = prims.convert_element_type(x, computation_dtype) + x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim) if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)): x = torch.abs(x) - return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] def backshift_permutation(dim0, dim1, ndim): @@ -142,7 +143,7 @@ def matrix_norm( check_is_matrix(A, "linalg.matrix_norm") # dim dim = utils.canonicalize_dims(A.ndim, dim) - if isinstance(dim, int): + if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] check(len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}") check( @@ -167,7 +168,7 @@ def matrix_norm( return vector_norm(A, 2, dim, keepdim, dtype=dtype) else: # ord == "nuc" if dtype is not None: - A = prims.convert_element_type(A, dtype) + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] perm = backshift_permutation(dim[0], dim[1], A.ndim) result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) if keepdim: @@ -190,7 +191,7 @@ def matrix_norm( if abs_ord == 2.0: if dtype is not None: - A = prims.convert_element_type(A, dtype) + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] perm = backshift_permutation(dim[0], dim[1], A.ndim) result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) if keepdim: @@ -219,7 +220,7 @@ def norm( dtype: Optional[torch.dtype] = None, ) -> TensorLikeType: if dim is not None: - if isinstance(dim, int): + if isinstance(dim, Dim): dim = (dim,) # type: ignore[assignment] check( len(dim) in (1, 2), diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 3cde678449476..e979dfd9e03df 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,3 +1,5 @@ +import math +from functools import wraps from typing import Callable, Optional, Union import torch @@ -19,14 +21,12 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) -from torch._refs import ( - _make_elementwise_binary_reference, - _make_elementwise_unary_reference, -) +from torch._refs import _make_inplace from torch._subclasses.fake_tensor import FakeTensor __all__ = [ + "alpha_dropout", "celu", "dropout", "elu", @@ -59,9 +59,90 @@ Tensor = torch.Tensor + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(torch.ops.aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + + if inplace: + raise NotImplementedError + + if not training: + return self + + utils.check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def inplace_wrapper(fn): + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, inplace=False, **kwargs): + if inplace: + check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + return fn(a, *args, inplace=False, out=a, **kwargs) + else: + return fn(a, *args, inplace=False, **kwargs) + + return _fn + + # celu is implemented specially because it has an alpha argument # celu is very similar to elu @register_decomposition(torch.ops.aten.celu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -93,8 +174,9 @@ def celu( return torch.where(a > 0, a, rhs) -# TODO: should we allow the user to set a different dtype for the mask generation? @register_decomposition(torch.ops.aten.dropout) +@inplace_wrapper +@out_wrapper() def dropout( a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False ) -> TensorLikeType: @@ -105,32 +187,36 @@ def dropout( if not training: return a - assert p <= 1 - assert p >= 0 + utils.check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) if p == 1: - return refs.zeros_like(a) + return torch.zeros_like(a) if p == 0: return a - p1m = 1 - p - scale = 1 / p1m - mask = refs.lt( - refs.uniform(a.shape, low=0.0, high=1.0, dtype=torch.float32, device=a.device), - p1m, - ) - return refs.mul(refs.mul(a, mask), scale) + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale -# elu is implemented specially because it has an alpha argument -# This cannot be used as a decomposition because the aten op takes in 2 extra kwargs +@register_decomposition(torch.ops.aten.elu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) def elu( - a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, ) -> TensorLikeType: """ Reference implementation of torch.nn.functional.elu @@ -138,24 +224,27 @@ def elu( if inplace: raise NotImplementedError - rhs: TensorLikeType - if alpha is not None: - python_type = utils.dtype_to_type(a.dtype) - if not utils.is_weakly_lesser_type(type(alpha), python_type): - msg = ( - "alpha argument of type {0} cannot be safely cast to type {1}!".format( - type(alpha), python_type - ) - ) - raise ValueError(msg) - rhs = alpha * torch.expm1(a) - else: - rhs = torch.expm1(a) + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) - return torch.where(a > 0, a, rhs) + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) @register_decomposition(torch.ops.aten.relu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -171,6 +260,46 @@ def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: return torch.where(torch.le(a, 0), 0, a) +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + def layer_norm( input: Tensor, normalized_shape: ShapeType, @@ -185,6 +314,8 @@ def layer_norm( @register_decomposition(torch.ops.aten.leaky_relu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -207,6 +338,8 @@ def leaky_relu( @register_decomposition(torch.ops.aten.mish) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -222,6 +355,8 @@ def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: @register_decomposition(torch.ops.aten.selu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -274,6 +409,7 @@ def softmin( # softplus is implemented specially because it has beta and threshold arguments @register_decomposition(torch.ops.aten.softplus) +@inplace_wrapper @out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), @@ -496,7 +632,7 @@ def _nll_loss_nd( ) -> TensorLikeType: utils.check( input.ndim > 0 and input.ndim <= 3, - lambda: f"Expected input dimension to be either [1, 2, 3] but recieved {input.ndim}.", + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", ) utils.check( @@ -566,11 +702,11 @@ def _nll_loss_nd( @register_decomposition(torch.ops.aten.nll_loss) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("input",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def nll_loss( input: TensorLikeType, target: TensorLikeType, @@ -689,6 +825,8 @@ def tanhshrink(a: TensorLikeType) -> TensorLikeType: @register_decomposition(torch.ops.aten.threshold) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -788,6 +926,8 @@ def _triplet_margin_with_distance_loss( @register_decomposition(torch.ops.aten.hardtanh) +@inplace_wrapper +@out_wrapper() @elementwise_unary_scalar_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("a"), @@ -927,6 +1067,8 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: @register_decomposition(torch.ops.aten.relu6) +@inplace_wrapper +@out_wrapper() def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: """ Reference implementation of torch.nn.functional.relu6 @@ -941,11 +1083,11 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: @register_decomposition(torch.ops.aten.glu) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: dim = utils.canonicalize_dims(a.ndim, dim) check( @@ -970,11 +1112,11 @@ def pairwise_distance( @register_decomposition(torch.ops.aten.pdist) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") check(p >= 0, lambda: "pdist only supports non-negative p values") @@ -988,3 +1130,11 @@ def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index fae9f9d12dbe6..4983823242653 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, Union import torch import torch._prims as prims @@ -8,7 +8,13 @@ from torch import Tensor from torch._decomp import register_decomposition -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, TensorLikeType +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper from torch._refs import ( _make_elementwise_binary_reference, @@ -33,13 +39,13 @@ "ndtri", "softmax", "spherical_bessel_j0", + "xlog1py", "zeta", ] @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_bessel_j0, ) def bessel_j0(a: TensorLikeType) -> TensorLikeType: return prims.bessel_j0(a) @@ -47,7 +53,6 @@ def bessel_j0(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_bessel_j1, ) def bessel_j1(a: TensorLikeType) -> TensorLikeType: return prims.bessel_j1(a) @@ -82,21 +87,21 @@ def erfcx(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i0e(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i0e(a) @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1 + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i1(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i1(a) @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1e + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i1e(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i1e(a) @@ -134,6 +139,31 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: return torch.log(torch.true_divide(self, torch.sub(1, self))) +@register_decomposition(torch.ops.aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + utils.check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(a, TensorLike) and isinstance(b, Number): + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, refs.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + @register_decomposition(torch.ops.aten.mvlgamma) @out_wrapper() @elementwise_type_promotion_wrapper( @@ -191,14 +221,14 @@ def softmax( @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_spherical_bessel_j0, ) def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: return prims.spherical_bessel_j0(a) -zeta = _make_elementwise_binary_reference( - prims.zeta, # type: ignore[has-type] +# TODO: add docstring +@_make_elementwise_binary_reference( type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_zeta, ) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index bb6970303facd..a137bdba1aa17 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,15 +1,13 @@ import contextlib import functools import itertools -import sys -import warnings +import os import weakref from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union import torch -from torch._decomp import meta_table as meta_table from torch._ops import OpOverload from torch._subclasses.meta_utils import MetaConverter, WeakTensorRefKey from torch.fx.operator_schemas import normalize_function @@ -140,15 +138,14 @@ def tree_flatten_only(ty: Type[T], pytree: PyTree): # structure. Like `MetaConverter`, it uses `WeakTensorRefKey` to # hold a weak reference for all memoized tensors. class FakeTensorConverter(object): - tensor_memo: weakref.WeakValueDictionary + @property + def tensor_memo(self): + return self.meta_converter.tensor_memo + meta_converter: MetaConverter constant_storage_mapping: Dict[StorageWeakRef, List[TensorWeakRef]] def __init__(self): - # FakeTensors store the FakeTensorMode which in turn stores a - # FakeTensor, so we need to hold a weak reference to the FakeTensor - # otherwise we would induce a circular reference - self.tensor_memo = weakref.WeakValueDictionary() self.meta_converter = MetaConverter() # map from to storage to corresponding constant tensors @@ -159,7 +156,7 @@ def add_constant_storage_mapping(self, fake_tensor): # const_tensor.add_(torch.rand([1])) # all aliases of it must become no longer const assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None - weak_st = StorageWeakRef(fake_tensor.constant.storage()) + weak_st = StorageWeakRef(fake_tensor.constant._typed_storage()) # we need a map from a weak storage to all of its corresponding # constant tensors. python doesn't have the weak value equivalent @@ -171,7 +168,7 @@ def add_constant_storage_mapping(self, fake_tensor): def invalidate_constant_aliases(self, tensor): assert not isinstance(tensor, FakeTensor) - weak_st = StorageWeakRef(tensor.storage()) + weak_st = StorageWeakRef(tensor._typed_storage()) if weak_st not in self.constant_storage_mapping: return @@ -207,7 +204,9 @@ def del_ten(): weakref.finalize(t, del_ten) self.tensor_memo[th] = v - def from_real_tensor(self, fake_mode, t, make_constant=False, shape_env=None): + def from_real_tensor( + self, fake_mode, t, make_constant=False, shape_env=None, ignore_subclass=False + ): maybe_memo = self._get_memo(t) if maybe_memo is not None: return maybe_memo @@ -215,31 +214,43 @@ def from_real_tensor(self, fake_mode, t, make_constant=False, shape_env=None): # not yet supported in metatensors if t.is_quantized: raise UnsupportedFakeTensorException("quantized nyi in meta tensors") - with no_dispatch(): - meta_t = self.meta_converter(t, shape_env=shape_env) - if meta_t.device.type != "meta": - raise UnsupportedFakeTensorException("meta converter nyi") - out = FakeTensor( - fake_mode, - meta_t, - existing_device, - constant=t if make_constant else None, - ) - out.requires_grad_(t.requires_grad) - if make_constant: - self.add_constant_storage_mapping(out) if type(t) is torch.nn.Parameter: assert not make_constant - out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment] - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") - grad_not_none = t.grad is not None - if grad_not_none: - out.grad = self.from_real_tensor(fake_mode, t.grad) - self.set_tensor_memo(t, out) + + def mk_fake_tensor(make_meta_t): + # NB: don't use in_kernel_invocation_manager. to + # ensure FakeTensor can internally do constant computation + # as necessary. Invocation manager is "more correct" as + # it works for more operators in make_meta_t, but + # invariant is that make_meta_t only calls factories + # for which it is not strictly necessary to use the + # invocation manager (I think!) + with no_dispatch(): + return FakeTensor( + fake_mode, + make_meta_t(), + existing_device, + constant=t if make_constant else None, + ) + + out = self.meta_converter( + t, + shape_env=shape_env, + callback=mk_fake_tensor, + ignore_subclass=ignore_subclass, + ) + if out is NotImplemented: + raise UnsupportedFakeTensorException("meta converter nyi") + if make_constant: + self.add_constant_storage_mapping(out) + # NB: meta_converter set the memo return out + # If you specify the device, it MUST be a meta tensor. def from_meta_and_device(self, fake_mode, t, device): + assert ( + t.device.type == "meta" + ), f"tensor's device must be `meta`, got {t.device.type} instead" maybe_memo = self._get_memo(t) if maybe_memo is not None: return maybe_memo @@ -247,27 +258,29 @@ def from_meta_and_device(self, fake_mode, t, device): self.set_tensor_memo(t, out) return out - # There are two ways to call this. First, you can have manually constructed - # a meta tensor and you need to turn it into a fake tensor. In that case, - # pass a meta tensor and a device argument. Alternately, you can have a - # real tensor that you need to convert into a fake tensor; in that case, - # omit the device. + # You can have a real tensor that you need to convert into a fake tensor. + # If you have a meta tensor already, call from_meta_and_device. # - # The disallowed case: if you specify the device, it MUST be a meta tensor. - # However, you're allowed to pass a meta tensor to be turned into a fake + # You're allowed to pass a meta tensor to be turned into a fake # tensor; although an odd thing to do, this can occur if you're doing - # cross ref testing and the inner test is already operating on meta tensors + # cross ref testing and the inner test is already operating on meta tensors. + # You must have created the FakeTensorMode with allow_meta == True def __call__( - self, fake_mode, t, device=None, *, make_constant=False, shape_env=None + self, + fake_mode, + t, + *, + make_constant=False, + shape_env=None, + ignore_subclass=False, ): - if device is None: - return self.from_real_tensor( - fake_mode, t, make_constant, shape_env=shape_env - ) - else: - assert make_constant is False - assert t.device.type == "meta" - return self.from_meta_and_device(fake_mode, t, device) + return self.from_real_tensor( + fake_mode, + t, + make_constant, + shape_env=shape_env, + ignore_subclass=ignore_subclass, + ) op_implementations = [] @@ -305,7 +318,10 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - r = func(*args, **new_kwargs) + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? hmmm) + with in_kernel_invocation_manager(fake_mode): + r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -318,15 +334,20 @@ def non_kwarg_to(fake_mode, func, *args, **kwargs): out_device = input_device if input_device else new_kwargs["input"].device new_kwargs["device"] = torch.device("meta") inp = new_kwargs.pop("input") - r = func(inp, **new_kwargs) - return fake_mode.fake_tensor_converter(fake_mode, r, out_device) + with in_kernel_invocation_manager(fake_mode): + r = func(inp, **new_kwargs) + # TODO: I think this does the wrong thing if r is inp + return fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, r, out_device + ) # Dont default to default device handling, # since the device of `the_template` is ignored @register_op_impl(aten.resize_as_.default) def resize_as_(fake_mode, func, *args, **kwargs): - return func(*args, **kwargs) + with in_kernel_invocation_manager(fake_mode): + return func(*args, **kwargs) @register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) @@ -335,21 +356,6 @@ def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): return constructors(fake_mode, func, *args, **kwargs) -# _to_copy fails when run with FakeTensors to cuda device -# TODO: debug -@register_op_impl(aten._to_copy.default) -def to_copy(fake_mode, func, *args, **kwargs): - _, new_kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - input_device = new_kwargs.pop("device", None) - out_device = input_device if input_device else new_kwargs["input"].device - with in_kernel_invocation_manager(fake_mode): - input = new_kwargs.pop("input").to("meta") - return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device) - - # index.Tensor data-dependent in only some conditions @register_op_impl( lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined] @@ -423,11 +429,58 @@ def nyi(fake_mode, func, *args, **kwargs): assert func not in _device_not_kwarg_ops, f"NYI: {func}" -# Meta tensors give you the ability to run PyTorch code without having to -# actually do computation through tensors allocated on a `meta` device. -# Because the device is `meta`, meta tensors do not model device propagation. -# FakeTensor extends MetaTensors to also carry an additional `fake_device` -# which tracks devices that would have been used. +@register_op_impl( + lambda func: func in (aten.convolution.default, aten.convolution_backward.default) +) +def conv(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + device = kwargs["input"].fake_device + # need to re-enable mode so the tensors report fake device + with fake_mode: + # if the input is unsqueezed is done in Convolution.cpp we get segfault + k = kwargs["weight"].ndim + if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: + mem_fmt = None + else: + if func is aten.convolution.default: + conv_backend = torch._C._select_conv_backend(**kwargs) + else: + conv_backend = torch._C._select_conv_backend( + kwargs["input"], + kwargs["weight"], + bias=None, + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + transposed=kwargs["transposed"], + output_padding=kwargs["output_padding"], + groups=kwargs["groups"], + bias_sizes=kwargs["bias_sizes"], + ) + mem_fmt = torch._C._conv_determine_backend_memory_format( + kwargs["input"], kwargs["weight"], conv_backend + ) + + def convert(t, mem_fmt): + if t is None: + return t + if mem_fmt is not None: + t = t.to(memory_format=mem_fmt) + return FakeTensor(fake_mode, t, device) + + with in_kernel_invocation_manager(fake_mode): + out = func(**kwargs) + + if func is aten.convolution.default: + return convert(out, mem_fmt) + else: + return ( + convert(out[0], mem_fmt), + convert(out[1], mem_fmt), + convert(out[2], None), + ) @contextlib.contextmanager @@ -448,7 +501,19 @@ def in_kernel_invocation_manager(fake_mode): del guard +class FakeTensorConfig: + debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", False) + + class FakeTensor(torch.Tensor): + """ + Meta tensors give you the ability to run PyTorch code without having to + actually do computation through tensors allocated on a `meta` device. + Because the device is `meta`, meta tensors do not model device propagation. + FakeTensor extends MetaTensors to also carry an additional `fake_device` + which tracks devices that would have been used. + """ + fake_device: torch.device fake_mode: "FakeTensorMode" constant: Optional[torch.Tensor] @@ -499,6 +564,10 @@ def __init__( self.fake_device = device self.fake_mode = fake_mode self.constant = constant + if FakeTensorConfig.debug: + import traceback + + self._debug_trace = traceback.extract_stack() @staticmethod def from_tensor(t, fake_mode): @@ -542,11 +611,14 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) @staticmethod - def _find_common_device(func, args, kwargs): + def _find_common_device(func, args, kwargs) -> Tuple[torch.device, bool]: + # Returns: (common_device, has_scalar_only_inputs) + # cpu - zero-dim tensors can be called in cuda kernels, # so overwrite the common_device if it the only existing # device comes from a cpu zero-dim tensor common_device = None + has_scalar_only_inputs = False is_cpu_zero_dim = None def cpu_zero_dim(t): @@ -598,11 +670,13 @@ def merge_devices(t): ) and common_device is None ): + # ops with scalar only inputs always have result on cpu + has_scalar_only_inputs = True common_device = torch.device("cpu") assert common_device is not None, f"Could not find common device for {func}" - return common_device + return common_device, has_scalar_only_inputs __torch_function__ = torch._C._disabled_torch_function_impl @@ -656,8 +730,19 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): else: return args[0].fake_device + # Some attribute queries that can be serviced directly + # See Note [is_coalesced is dispatched] + if func in { + torch.ops.aten.is_coalesced.default, + torch.ops.aten.dense_dim.default, + torch.ops.aten.sparse_dim.default, + }: + # NB: no_dispatch is ok here too, this func is very simple + with in_kernel_invocation_manager(self): + return func(*args, **kwargs) + flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs)) - flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs)) + flat_symints = tree_flatten_only(torch.SymInt, (args, kwargs)) has_symbolic_sizes = ( any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors]) or len(flat_symints) > 0 @@ -671,38 +756,38 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if func in self.lift_fns: out = func(*args, **kwargs) if self.may_turn_const(out): + # NB: not in_kernel_invocation_manager because we're doing real + # compute here with no_dispatch(): - return converter(self, out.clone(), make_constant=True) - - with no_dispatch(): - flat_arg_tensors = tree_flatten_only(torch.Tensor, (args, kwargs)) - # See [subclass inputs] below - # NB: If you're seeing a mysterious infinite loop involving fake - # tensor, it might be related to this line. Though I'm not sure - # how you'll know to read this comment, as this line won't show up - # in the stack trace. - if self.check_for_subclass(flat_arg_tensors): - return NotImplemented - - # if we are in the dispatch mode, we will enter this function even if the inputs - # are not FakeTensors. For now, throw if any non-Fake Tensor inputs - # and just support constructors. - - # this is generated from torch.tensor(), which does not use the - # dispatcher, to allow wrapper subclasses to wrap the new tensor - if func in self.lift_fns: - assert ( - len(kwargs) == 0 - and len(args) == 1 - and type(args[0]) is torch.Tensor - ), f"{args} {kwargs}" - return converter(self, args[0]) - - if self.check_for_non_fake(flat_arg_tensors): - raise Exception( - "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. " - f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})" - ) + out = out.clone() + return converter(self, out, make_constant=True) + + flat_arg_tensors = tree_flatten_only(torch.Tensor, (args, kwargs)) + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + if self.check_for_subclass(flat_arg_tensors): + return NotImplemented + + # if we are in the dispatch mode, we will enter this function even if the inputs + # are not FakeTensors. For now, throw if any non-Fake Tensor inputs + # and just support constructors. + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + if func in self.lift_fns: + assert ( + len(kwargs) == 0 and len(args) == 1 and type(args[0]) is torch.Tensor + ), f"{args} {kwargs}" + return converter(self, args[0]) + + if self.check_for_non_fake(flat_arg_tensors): + raise Exception( + "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. " + f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})" + ) # The current constant handling only support tracing systems # (aot autograd, torchdynamo) where each operation is run consecutively. @@ -722,115 +807,90 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): and len(flat_arg_fake_tensors) != 0 and not has_symbolic_sizes ): + const_args, const_kwargs = pytree.tree_map_only( + FakeTensor, lambda t: t.constant, (args, kwargs) + ) + + # NB: not in_kernel_invocation_manager(self) as we want to do REAL + # compute with no_dispatch(): - const_args, const_kwargs = pytree.tree_map_only( - FakeTensor, lambda t: t.constant, (args, kwargs) - ) out = func(*const_args, **const_kwargs) - all_constant = pytree.tree_all_only( - torch.Tensor, lambda t: self.may_turn_const(t), out - ) + all_constant = pytree.tree_all_only( + torch.Tensor, lambda t: self.may_turn_const(t), out + ) - if all_constant: - return pytree.tree_map_only( - torch.Tensor, - lambda t: converter(self, t, make_constant=True), - out, - ) + if all_constant: + return pytree.tree_map_only( + torch.Tensor, + lambda t: converter(self, t, make_constant=True), + out, + ) - # we weren't able to turn outputs to constants, - # so invalidate all constants that might be aliases of the outputs - for ten in tree_flatten_only(torch.Tensor, out): - converter.invalidate_constant_aliases(ten) + # we weren't able to turn outputs to constants, + # so invalidate all constants that might be aliases of the outputs + for ten in tree_flatten_only(torch.Tensor, out): + converter.invalidate_constant_aliases(ten) # we are falling through to running non constant tensors, any input constant that # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - from torch._decomp import _disabled_meta_decomps, decomposition_table - - with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table - # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise - if ( - has_symbolic_sizes - and func not in self.functions_with_cpp_meta_impl_that_support_symint - ): - with no_dispatch(): - if func == aten.size.default: - sys.stderr.write( - "Trying to call aten.size on a tensor with symbolic shapes. " - "It's likely that this is from calling tensor.shape in C++" - ) - # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` - return None + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table - with self: - if func in meta_table: - r = meta_table[func](*args, **kwargs) - return r - if func in decomposition_table: + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) + ) + ): + with self: return decomposition_table[func](*args, **kwargs) - if ( - func in decomposition_table - and torch_decomp_decompositions(func) - and func not in _disabled_meta_decomps - and all(not e.is_sparse for e in flat_arg_fake_tensors) - ): with self: - return decomposition_table[func](*args, **kwargs) + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them # and ensure that Meta kernels are dispatched to (see) # Fake Tensor Dispatch Keys # TODO - we should be use the prim aten impl - if ( - "prims::" in func._schema.name - and len(flat_arg_fake_tensors) != 0 - and hasattr(func, "prim_meta_impl") - ): + if "prims::" in func._schema.name and hasattr(func, "prim_meta_impl"): with self: return func.prim_meta_impl(*args, **kwargs) - if has_symbolic_sizes: - if func not in self.functions_with_cpp_meta_impl_that_support_symint: - raise RuntimeError( - f"{func} - couldn't find symbolic meta function/decomposition" - ) - - with no_dispatch(): - # special handling for funcs registered through `register_op_impl`, - # e.g., manipulating args on constructor calls to construct meta tensors - # and then afterwards wrapping them to a FakeTensor - for run_impl_check, op_impl in op_implementations: - if run_impl_check(func): - op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: - return op_impl_out - - # run kernel registered to meta for func, which include - # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) - try: - with in_kernel_invocation_manager(self): - r = func(*args, **kwargs) - except NotImplementedError as not_implemented_error: - # no meta kernel registered, fallback to kernel for the device - if not self.allow_fallback_kernels: - raise not_implemented_error - return run_fallback_kernel( - self, func, args, kwargs, not_implemented_error - ) - - return self.wrap_meta_outputs_with_default_device_logic( - r, func, args, kwargs - ) + # special handling for funcs registered through `register_op_impl`, + # e.g., manipulating args on constructor calls to construct meta tensors + # and then afterwards wrapping them to a FakeTensor + for run_impl_check, op_impl in op_implementations: + if run_impl_check(func): + op_impl_out = op_impl(self, func, *args, **kwargs) + if op_impl_out != NotImplemented: + return op_impl_out + + # run kernel registered to meta for func, which include + # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) + try: + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + # no meta kernel registered, fallback to kernel for the device + if not self.allow_fallback_kernels: + raise not_implemented_error + return run_fallback_kernel(self, func, args, kwargs, not_implemented_error) + + return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs) # [subclass inputs] # Suppose we enable fake tensor mode. This means that fake tensor @@ -870,26 +930,46 @@ def gen_wrap_fn(self, func, args, kwargs): # Lazily initialized, in case there are no tensor returns common_device = None + has_scalar_only_inputs = False def wrap(e, device=None): nonlocal common_device + nonlocal has_scalar_only_inputs if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor): if common_device is None: - common_device = FakeTensor._find_common_device(func, args, kwargs) - return converter(self, e, device or common_device) + ( + common_device, + has_scalar_only_inputs, + ) = FakeTensor._find_common_device(func, args, kwargs) + + if has_scalar_only_inputs: + # Under FakeTensorMode, op accepts scalar only inputs, such as aten.add/sub/mul/div, + # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details. + # We thus directly convert real tensor to fake tensor. + return converter(self, e) + else: + return converter.from_meta_and_device( + self, e, device or common_device + ) else: return e return wrap - @property - def functions_with_cpp_meta_impl_that_support_symint(self): - return [ + def cpp_meta_supports_symint(self, func): + if torch.Tag.view_copy in func.tags: # type: ignore[attr-defined] + return True + return func in [ aten.empty_strided.default, aten.as_strided_scatter.default, aten.as_strided.default, + aten.as_strided_.default, aten.zeros.default, aten.detach.default, + aten.view_as_real.default, + aten.view_as_complex.default, + aten.set_.source_Storage_storage_offset, + aten._sparse_coo_tensor_with_dims_and_tensors.default, ] @property @@ -921,10 +1001,14 @@ def invalidate_written_to_constants( ): self.fake_tensor_converter.invalidate_constant_aliases(v.constant) - def from_tensor(self, tensor, static_shapes=False): + def from_tensor(self, tensor, static_shapes=False, ignore_subclass=False): if static_shapes: - return self.fake_tensor_converter(self, tensor) - return self.fake_tensor_converter(self, tensor, shape_env=self.shape_env) + return self.fake_tensor_converter( + self, tensor, ignore_subclass=ignore_subclass + ) + return self.fake_tensor_converter( + self, tensor, shape_env=self.shape_env, ignore_subclass=ignore_subclass + ) # NB: returns fake tensors @@ -935,8 +1019,11 @@ def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exce if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] raise orig_not_implemented_exception + inp_impls = {} + + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) with no_dispatch(): - inp_impls = {} def to_real_tensor(e): if isinstance(e, FakeTensor): @@ -952,25 +1039,25 @@ def to_real_tensor(e): r = func(*args, **kwargs) - tensor_impls = set() - storages = set() - - for e in tree_flatten((args, kwargs))[0]: - if isinstance(e, torch.Tensor): - if not e.is_sparse: - storages.add(e.storage()._cdata) - - # TODO: also check metadata change on inputs - # proper aliasing/metadata relationship between outputs and inputs will - # not be set up, bc of conversion to device, unless we can reuse an - # input impl - for e in tree_flatten(r)[0]: - if id(e) not in inp_impls and ( - isinstance(e, torch.Tensor) - and not e.is_sparse - and e.storage()._cdata in storages - ): - raise orig_not_implemented_exception + tensor_impls = set() + storages = set() + + for e in tree_flatten((args, kwargs))[0]: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + for e in tree_flatten(r)[0]: + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception def map_out(e): if isinstance(e, torch.Tensor): @@ -995,7 +1082,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # clone will get called in Parameter deepcopy if func == torch._C._TensorBase.clone: - return func(self.fake_mode.from_tensor(args[0]), **kwargs) + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) elif func == torch.Tensor.__deepcopy__: assert len(args) == 2 and len(kwargs) == 0 tensor, memo = args @@ -1003,7 +1092,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if id(tensor) in memo: return memo[id(tensor)] - out = self.fake_mode.from_tensor(tensor) + out = self.fake_mode.from_tensor(tensor, static_shapes=True) memo[id(tensor)] = out return out else: diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index 37ff260c9bd30..6cd3789ae5a08 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -18,12 +18,12 @@ def outputs_alias_inputs(outputs, inputs): input_storages = { - inp.storage()._cdata + inp._typed_storage()._cdata for inp in tree_flatten_only(torch.Tensor, inputs) if torch._C._has_storage(inp) } return any( - torch._C._has_storage(out) and out.storage()._cdata in input_storages + torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages for out in tree_flatten_only(torch.Tensor, outputs) ) @@ -38,7 +38,7 @@ def output_alias_each_other(outputs): for out in tree_flatten_only(torch.Tensor, outputs): if not torch._C._has_storage(out): continue - stor = out.storage()._cdata + stor = out._typed_storage()._cdata if stor in storages: return True storages.add(stor) @@ -136,5 +136,5 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): r_out, fake_out, check_strides=self.check_strides ) except Exception as e: - raise RuntimeError(f"Mismatch on {func}: {e}") + raise RuntimeError(f"Mismatch on {func}: {e}") from e return r diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 80723f1246339..577eb813dde9c 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1,8 +1,10 @@ +import contextlib +import warnings import weakref +from typing import ContextManager import torch from torch.multiprocessing.reductions import StorageWeakRef -from torch.utils._mode_utils import no_dispatch def safe_is_leaf(t): @@ -13,6 +15,49 @@ def safe_is_leaf(t): return False +def safe_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return t.grad + + +def assert_eq(a, b): + assert a == b, f"{a} != {b}" + + +def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False): + def go(m1, m2): + assert_eq(m1.dtype, m2.dtype) + if not skip_symbolic: + assert_eq(m1.shape, m2.shape) + assert_eq(m1.requires_grad, m2.requires_grad) + assert_eq(m1.is_leaf, m2.is_leaf) + assert_eq(m1.grad_fn is None, m2.grad_fn is None) + assert_eq(m1.is_sparse, m2.is_sparse) + assert_eq(m1.is_inference(), m2.is_inference()) + assert_eq(m1.is_conj(), m2.is_conj()) + assert_eq(m1.is_neg(), m2.is_neg()) + assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None) + if safe_grad(m1) is not None: + go(safe_grad(m1), safe_grad(m2)) + if m1.is_sparse: + assert_eq(m1.dense_dim(), m2.dense_dim()) + assert_eq(m1.sparse_dim(), m2.sparse_dim()) + assert_eq(m1.is_coalesced(), m2.is_coalesced()) + else: + if not skip_symbolic: + assert_eq(m1.stride(), m2.stride()) + assert_eq(m1.storage_offset(), m2.storage_offset()) + assert_eq(m1._is_view(), m2._is_view()) + if m1._is_view(): + go(m1._base, m2._base) + # TODO: test if is resizable (no direct query for this atm) + # TODO: audit AutogradMeta to see if it matches + # TODO: test forward AD + + return go(m1, m2) + + # torch.Tensors cannot be used as a key in a dictionary # because they define a custom __eq__ function which when used # to resolve hash collisions will throw when comparing tensors: @@ -56,7 +101,7 @@ def __eq__(self, other): class MetaConverter: def __init__(self): self.storage_memo = {} - self.tensor_memo = {} + self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() self.maybe_storages_to_delete = [] self.check_expired_frequency = 128 self.check_expired_count = 0 @@ -95,10 +140,10 @@ def set_tensor_memo(self, t, v): # hold a weak ref to self, otherwise it will be kept alive # by the del_ten closure self_weak_ref = weakref.ref(self) - if t.is_sparse: + if t.is_sparse or t.is_mkldnn: weak_st = None else: - weak_st = StorageWeakRef(t.storage()) + weak_st = StorageWeakRef(t._typed_storage()) tensor_ref_key = WeakTensorRefKey(t) def del_ten(): @@ -127,33 +172,64 @@ def del_ten(): # NB: doesn't actually return a storage, because meta storage is # not supported - def meta_storage(self, s): + def meta_storage(self, s, callback): # NB: TypedStorage is freshly allocated and cannot be used as hash # key index. # Use a Weak Ref to s in order to not leak memory swr = StorageWeakRef(s) if swr not in self.storage_memo: - self.storage_memo[swr] = torch.empty(s.size(), dtype=s.dtype, device="meta") + self.storage_memo[swr] = callback( + lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta") + )._storage() return self.storage_memo[swr] # This function assumes that it's possible to do the conversion - def meta_tensor(self, t, shape_env=None): + def meta_tensor(self, t, shape_env=None, callback=lambda t: t()): + # This indicates you set no_dispatch() before calling into this + # function. This is an error: we may be creating fake tensors and + # will perform operations on them which need fake tensor mode to + # be active. You will segfault if you are in a no_dispatch() block. + assert not torch._C._dispatch_tls_local_exclude_set().has( + torch._C.DispatchKey.Python + ) arg_cnt = self.arg_cnt self.arg_cnt += 1 - make_symbolic = shape_env is not None + # When we make as_strided calls, we end up generating a guard + # that the new as_strided tensor is in bounds for the old storage + # for the base (since as_strided calls can "bust" out of their + # bounding box.) This guard is unnecessary: if a user is able + # to provide us a tensor with the view base setup this way, we + # don't need to produce a guard, because the fact that they + # were able to produce the view base means its in bounds. + # + # Now, ordinarily, this guard would be harmless. However, the + # generated guard refers to variables bound on the base variable. + # At the moment, Dynamo doesn't actually guard on x._base, because + # according to Voz this results in a lot of spurious invalidations, + # and also if the user doesn't directly make use of _base, its + # pointless anyway (because programs should be parametric over + # whether or not the input tensor is a view or not--unless you're + # mutating the input, but that's a whole 'nother ballgame). So + # for expediency, we suppress these guards so we don't have to + # deal with this (yet, anyway.) + # + # NB: An old version of this code suppressed guards for ALL operations + # happening during meta conversion, not just as_strided calls. + # This is too aggressive: we do duck sizing and 0/1 simplification + # as we allocate variables, and we do need to register guards for + # these cases. + maybe_suppress = contextlib.nullcontext + if shape_env is not None: + maybe_suppress = shape_env.suppress_guards - def sym(x): - if make_symbolic: - return shape_env.create_symbol(x) - else: - return x + make_symbolic = shape_env is not None - def sym_sizes_strides(t): + def sym_sizes_strides_storage_offset(t): if make_symbolic: - return shape_env.create_symbolic_sizes_strides(t) - return (t.size(), t.stride()) + return shape_env.create_symbolic_sizes_strides_storage_offset(t) + return (t.size(), t.stride(), t.storage_offset()) # see expired-storages self.check_expired_count += 1 @@ -166,14 +242,22 @@ def sym_sizes_strides(t): if t.is_sparse: assert shape_env is None, "symbolic on sparse NYI" is_leaf = safe_is_leaf(t) - r = torch.ops.aten._sparse_coo_tensor_with_dims( - t.sparse_dim(), - t.dense_dim(), - t.shape, - dtype=t.dtype, - layout=torch.sparse_coo, - device="meta", + r = callback( + lambda: torch.ops.aten._sparse_coo_tensor_with_dims( + t.sparse_dim(), + t.dense_dim(), + t.shape, + dtype=t.dtype, + layout=torch.sparse_coo, + device="meta", + ) ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + # Note [is_coalesced is dispatched] + # Strangely enough, is_coalesced() is a dispatched operator, + # which means that it will get caught by fake tensor mode. + # Ordinarily this would error, but there's some logic in + # fake tensor ensure this doesn't happen. r._coalesced_(t.is_coalesced()) if t.requires_grad: r.requires_grad = True @@ -181,14 +265,30 @@ def sym_sizes_strides(t): with torch.enable_grad(): r = r.clone() r._coalesced_(t.is_coalesced()) - + elif t.is_mkldnn: + is_leaf = safe_is_leaf(t) + sizes, strides, _storage_offset = sym_sizes_strides_storage_offset( + t + ) + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" + ) + ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + with torch.enable_grad(): + r = r.clone() elif t._is_view(): # Construct views in two steps: recursively meta-fy their - # base, and then create the view off that. NB: doing it + # base, and then create view(s) off that. NB: doing it # directly from storage is WRONG because this won't cause # version counters to get shared. assert t._is_view() - base = self.meta_tensor(t._base) + + base = self.meta_tensor(t._base, shape_env, callback) def is_c_of_r(complex_dtype, real_dtype): return ( @@ -197,55 +297,163 @@ def is_c_of_r(complex_dtype, real_dtype): == real_dtype ) - if base.dtype == t.dtype: - pass - elif is_c_of_r(base.dtype, t.dtype): - base = torch.view_as_real(base) - elif is_c_of_r(t.dtype, base.dtype): - base = torch.view_as_complex(base) - else: - # This is not guaranteed to succeed. If it fails, it - # means there is another dtype-converting view function - # that hasn't been handled here - base = base.view(t.dtype) - - with torch.enable_grad(): - sizes, strides = sym_sizes_strides(t) - r = base.as_strided(sizes, strides, sym(t.storage_offset())) + # In some situations, MetaConverter may be called in a + # context where autograd is disabled. For the _is_view + # assert to pass, we have to setup the autograd view + # metadata anyway. Do this by reenabling the + # ADInplaceOrView key. This is kind of a hack. + old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, False + ) + try: + + if base.dtype == t.dtype: + pass + elif is_c_of_r(base.dtype, t.dtype): + base = torch.view_as_real(base) + elif is_c_of_r(t.dtype, base.dtype): + base = torch.view_as_complex(base) + else: + # This is not guaranteed to succeed. If it fails, it + # means there is another dtype-converting view function + # that hasn't been handled here + base = base.view(t.dtype) + + # This is very tricky. Naively, you might expect this + # to hold: + # + # if t.requires_grad and not safe_is_leaf(t) + # assert t._base.requires_grad + # + # But it's not true! As you can see in the following + # program: + # + # x = torch.zeros(4) + # y = x.view(1, 4) + # y.requires_grad = True + # z = y.view(1, 1, 4) + # assert z._base is x + # + # So we may have to do *two* views out of the base to + # recreate this situation. + + ( + sizes, + strides, + storage_offset, + ) = sym_sizes_strides_storage_offset(t) + + if safe_is_leaf(t): + # Leaf views that track view metadata are created by + # creating a view inside a no_grad block + with torch.no_grad(), maybe_suppress(): + r = base.as_strided(sizes, strides, storage_offset) + # As it's a leaf, we can directly assign requires_grad + r.requires_grad = t.requires_grad + else: + if t._base.requires_grad == t.requires_grad: + # Easy case, just run the view op + with torch.enable_grad(), maybe_suppress(): + r = base.as_strided(sizes, strides, storage_offset) + else: + # Obscure case. Create a leaf view and give it the + # correct requires_grad, then do the final view. + # NB: Can't have a non-leaf without requiring grad! + assert t.requires_grad + with torch.no_grad(): + mid = base.view(base.shape) + mid.requires_grad = t.requires_grad + with torch.enable_grad(), maybe_suppress(): + r = mid.as_strided(sizes, strides, storage_offset) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old_exclude + ) + else: is_leaf = safe_is_leaf(t) - # Fake up some autograd history. - if t.requires_grad: - r = torch.empty( - (0,), dtype=t.dtype, device="meta", requires_grad=True + sizes, strides, storage_offset = sym_sizes_strides_storage_offset(t) + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" ) + ) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = t.requires_grad if not is_leaf: + # Fake up some autograd history. with torch.enable_grad(): - # The backward function here will be wrong, but - # that's OK; our goal is just to get the metadata - # looking as close as possible; we're not going to - # actually try to backward() on these produced - # metas. TODO: would be safer to install some - # sort of unsupported grad_fn here - r = r.clone() + # preserve_format is the default, but we want to + # emphasize how important it is to preserve + # format here + r = r.clone(memory_format=torch.preserve_format) + + s = t._storage() + swr = StorageWeakRef(s) + if ( + swr not in self.storage_memo + and r.stride() == strides + and r.storage_offset() == storage_offset + ): + # You're normal and happy, install the fresh storage into the memo + self.storage_memo[swr] = r._storage() else: - r = torch.empty((0,), dtype=t.dtype, device="meta") - # As long as meta storage is not supported, need to prevent - # redispatching on set_(Storage, ...) which will choke with - # meta storage - s = self.meta_storage(t.storage()) - with no_dispatch(): - sizes, strides = sym_sizes_strides(t) - with torch.no_grad(): - r.set_(s, sym(t.storage_offset()), sizes, strides) + # You're in crazy town; somehow you gave us a tensor + # that wasn't a view, but had nonzero storage offset, + # nontrivial strides (such that clone() couldn't + # preserve them), or already aliases with another + # tensor's storage. The most typical way to end + # up here is with set_. So use set_ to bludgeon this + # in. + r_s = self.meta_storage(s, callback=callback) + # NB: In principle, this should always work, but there + # is some subtle difference in the autograd metadata + # that means we will backprop the set_ call, even if + # r is declared as an input to grad. + # See https://github.com/pytorch/pytorch/issues/87956 + # for the reproducer. + # NB: The in_kernel_invocation_manager here is necessary + # for fake tensor. If we run the set_ call with fake + # tensor on, r will improperly report that it is NOT a + # meta tensor but a cpu tensor, and then the set_ call + # will fail due to device mismatch. no_dispatch() is + # not enough, because the fake tensor will still claim + # to be a CPU tensor and you'll end up in the CPU + # kernel. Arguably this is a hack; a cleaner way to + # solve this is to have a FakeStorage concept which + # would report it's CPU device--no problem now! But + # this is difficult to do because we don't have storage + # subclasses. Relevant test is + # DynamicShapesFunctionTests::test_add_dynamic_shapes in + # test/dynamo/test_dynamic_shapes.py + maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() + from torch._subclasses.fake_tensor import ( + FakeTensor, + in_kernel_invocation_manager, + ) + if isinstance(r, FakeTensor): + maybe_fake_mgr = in_kernel_invocation_manager(r.fake_mode) + with maybe_fake_mgr, torch.no_grad(): + r.set_(r_s, storage_offset, sizes, strides) + + if safe_grad(t) is not None: + r.grad = self.meta_tensor(safe_grad(t), shape_env, callback) torch._C._set_conj(r, t.is_conj()) torch._C._set_neg(r, t.is_neg()) + # This can be skipped if necessary for performance reasons + assert_metadata_eq(assert_eq, t, r, skip_symbolic=True) self.set_tensor_memo(t, r) return self.get_tensor_memo(t) - def __call__(self, t, shape_env=None): + def __call__( + self, t, shape_env=None, *, callback=lambda t: t(), ignore_subclass=False + ): # TODO: zero tensors? We appear to have eliminated them by # excluding complex for now from torch._subclasses.fake_tensor import FakeTensor @@ -253,13 +461,13 @@ def __call__(self, t, shape_env=None): if ( type(t) is torch.Tensor or type(t) is torch.nn.Parameter + or (ignore_subclass and isinstance(t, torch.Tensor)) or isinstance(t, FakeTensor) ): if any( [ t.is_sparse_csr, t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc], - t.is_mkldnn, t.is_quantized, t.is_nested, t._is_view() and t._base is not None and t._base.is_sparse, @@ -280,10 +488,20 @@ def __call__(self, t, shape_env=None): # tests all break so we just exclude this. In any case # the to conversion isn't really right anyhow. self.miss += 1 - return t + return NotImplemented else: self.hit += 1 - r = self.meta_tensor(t, shape_env=shape_env) + # When ignoring subclasses, we treat the input tensor "as if" it + # were a normal tensor and create a non-subclassed fake tensor + # that, modulo type and attributes, resembles the original tensor. + # This can be helpful if you're planning to simulate the subclassness + # by hand, e.g., as is done in Dynamo + ctx = contextlib.nullcontext() + if ignore_subclass: + ctx = torch._C.DisableTorchFunction() + with ctx: + r = self.meta_tensor(t, shape_env=shape_env, callback=callback) + # TODO: this is suspicious, now that we have callback argument if type(t) is torch.nn.Parameter: r = torch.nn.Parameter(r, requires_grad=r.requires_grad) return r @@ -294,7 +512,7 @@ def __call__(self, t, shape_env=None): # support meta. Trying to YOLO this is more trouble than it's # worth. self.miss += 1 - return t + return NotImplemented else: # non-Tensor types don't count as hit or miss return t diff --git a/torch/_tensor.py b/torch/_tensor.py index d0af241c8a221..5e31f565a023a 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -55,9 +55,6 @@ def _rebuild_from_type(func, type, args, dict): def _rebuild_from_type_v2(func, new_type, args, state): - if new_type is Tensor: - return func(*args) - ret = func(*args) if type(ret) is not new_type: ret = ret.as_subclass(new_type) @@ -70,21 +67,7 @@ def _rebuild_from_type_v2(func, new_type, args, state): ): ret.__setstate__(state) else: - if isinstance(state, tuple): - if not len(state) == 2: - raise RuntimeError(f"Invalid serialized state: {state}") - dict_state = state[0] - slots_state = state[1] - else: - dict_state = state - slots_state = None - - for k, v in dict_state.items(): - setattr(ret, k, v) - - if slots_state: - for k, v in slots_state.items(): - setattr(ret, k, v) + ret = torch._utils._set_obj_state(ret, state) return ret @@ -114,7 +97,8 @@ def __deepcopy__(self, memo): # Update the test in test_serialization if you remove 'meta' from here if ( self.is_sparse - or self.device.type in ["lazy", "xla", "mps", "ort", "meta", "hpu"] + or self.device.type + in ["lazy", "xla", "mps", "ort", "meta", "hpu", "ipu"] or ( not torch._C._has_storage(self) and self.device.type == "privateuseone" @@ -132,7 +116,7 @@ def __deepcopy__(self, memo): "different type." ) else: - new_storage = self.storage().__deepcopy__(memo) + new_storage = self._typed_storage()._deepcopy(memo) if self.is_quantized: # quantizer_params can be different type based on torch attribute quantizer_params: Union[ @@ -163,7 +147,9 @@ def __deepcopy__(self, memo): # need to wrap with TypedStorage new_tensor = torch._utils._rebuild_qtensor( torch.storage.TypedStorage( - wrap_storage=new_storage.untyped(), dtype=self.dtype + wrap_storage=new_storage._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), self.size(), @@ -221,31 +207,13 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - if type(self) is Tensor: + state = torch._utils._get_obj_state(self) + if type(self) is Tensor and not state: + # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) func, args = self._reduce_ex_internal(proto) - # Get the state of the python subclass - # This loosely mimicks the function on the object class but since Tensor do not inherit - # from it, we cannot call that function directly - # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 - getstate_fn = getattr(self, "__getstate__", None) - if getstate_fn: - state = getstate_fn() - else: - slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] - if slots_to_save: - state = ( - self.__dict__, - { - name: getattr(self, name) - for name in slots_to_save - if hasattr(self, name) - }, - ) - else: - state = self.__dict__ return (_rebuild_from_type_v2, (func, type(self), args, state)) def storage(self): @@ -257,7 +225,17 @@ def storage(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage, (self,), self) - return torch.TypedStorage(wrap_storage=self._storage(), dtype=self.dtype) + torch.storage._warn_typed_storage_removal(stacklevel=2) + return self._typed_storage() + + # For internal use only, to avoid raising deprecation warning + def _typed_storage(self): + _storage = self._storage() + if isinstance(_storage, torch.TypedStorage): + _storage = _storage._untyped_storage + return torch.TypedStorage( + wrap_storage=_storage, dtype=self.dtype, _internal=True + ) def _reduce_ex_internal(self, proto): check_serializing_named_tensor(self) @@ -331,7 +309,9 @@ def _reduce_ex_internal(self, proto): # need to wrap with TypedStorage args_qtensor = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), tuple(self.size()), @@ -352,22 +332,32 @@ def _reduce_ex_internal(self, proto): "sparse tensor __reduce_ex__ for layout `%s`" % (self.layout) ) return (torch._utils._rebuild_sparse_tensor, args_sparse) - elif self.is_sparse_csr: - if self.layout == torch.sparse_csr: - args_sparse_csr = ( - self.layout, - ( - self.crow_indices(), - self.col_indices(), - self.values(), - self.size(), - ), + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + self.crow_indices(), + self.col_indices(), ) else: - raise NotImplementedError( - "sparse csr tensor __reduce_ex__ for layout `%s`" % (self.layout) + compressed_indices, plain_indices = ( + self.ccol_indices(), + self.row_indices(), ) - return (torch._utils._rebuild_sparse_csr_tensor, args_sparse_csr) + args_sparse_compressed = ( + self.layout, + ( + compressed_indices, + plain_indices, + self.values(), + self.size(), + ), + ) + return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif ( self.data_ptr() == 0 and type(self) is not torch.Tensor @@ -389,7 +379,9 @@ def _reduce_ex_internal(self, proto): # need to wrap with TypedStorage args = ( torch.storage.TypedStorage( - wrap_storage=self.storage().untyped(), dtype=self.dtype + wrap_storage=self._typed_storage()._untyped_storage, + dtype=self.dtype, + _internal=True, ), self.storage_offset(), tuple(self.size()), @@ -397,6 +389,10 @@ def _reduce_ex_internal(self, proto): self.requires_grad, backward_hooks, ) # previously was self._backward_hooks + + metadata = torch._utils.get_tensor_metadata(self) + if metadata: + args = args + (metadata,) # type: ignore[assignment] return (torch._utils._rebuild_tensor_v2, args) def __setstate__(self, state): @@ -607,7 +603,7 @@ def is_shared(self): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.is_shared, (self,), self) - return self.storage().is_shared() + return self._typed_storage()._is_shared() def share_memory_(self): r"""Moves the underlying storage to shared memory. @@ -617,7 +613,7 @@ def share_memory_(self): """ if has_torch_function_unary(self): return handle_torch_function(Tensor.share_memory_, (self,), self) - self.storage().share_memory_() + self._typed_storage()._share_memory_() return self def __reversed__(self): @@ -629,7 +625,13 @@ def __reversed__(self): else: return self.flip(0) - def norm(self, p="fro", dim=None, keepdim=False, dtype=None): + def norm( + self, + p: Optional[Union[float, str]] = "fro", + dim=None, + keepdim=False, + dtype=None, + ): r"""See :func:`torch.norm`""" if has_torch_function_unary(self): return handle_torch_function( @@ -1059,7 +1061,9 @@ def storage_type(self): if has_torch_function_unary(self): return handle_torch_function(Tensor.storage_type, (self,), self) - return self.storage()._get_legacy_storage_class() + torch.storage._warn_typed_storage_removal() + + return self._typed_storage()._get_legacy_storage_class() def refine_names(self, *names): r"""Refines the dimension names of :attr:`self` according to :attr:`names`. diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9f8ce0c0f8520..cc4c9d3a92f63 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1584,7 +1584,7 @@ def add_docstr_all(method, docstr): add_docstr_all( "as_strided_scatter", r""" -as_strided_scatter(src, size, stride, storage_offset=0) -> Tensor +as_strided_scatter(src, size, stride, storage_offset=None) -> Tensor See :func:`torch.as_strided_scatter` """, @@ -2229,14 +2229,15 @@ def add_docstr_all(method, docstr): get_device() -> Device ordinal (Integer) For CUDA tensors, this function returns the device ordinal of the GPU on which the tensor resides. -For CPU tensors, an error is thrown. +For CPU tensors, this function returns `-1`. Example:: >>> x = torch.randn(3, 4, 5, device='cuda:0') >>> x.get_device() 0 - >>> x.cpu().get_device() # RuntimeError: get_device is not implemented for type torch.FloatTensor + >>> x.cpu().get_device() + -1 """, ) @@ -3435,18 +3436,7 @@ def callable(a, b) -> number r""" narrow(dimension, start, length) -> Tensor -See :func:`torch.narrow` - -Example:: - - >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> x.narrow(0, 0, 2) - tensor([[ 1, 2, 3], - [ 4, 5, 6]]) - >>> x.narrow(1, 1, 2) - tensor([[ 2, 3], - [ 5, 6], - [ 8, 9]]) +See :func:`torch.narrow`. """, ) @@ -4262,7 +4252,7 @@ def callable(a, b) -> number Additionally accepts an optional :attr:`reduce` argument that allows specification of an optional reduction operation, which is applied to all -values in the tensor :attr:`src` into :attr:`self` at the indicies +values in the tensor :attr:`src` into :attr:`self` at the indices specified in the :attr:`index`. For each value in :attr:`src`, the reduction operation is applied to an index in :attr:`self` which is specified by its index in :attr:`src` for ``dimension != dim`` and by the corresponding @@ -4795,12 +4785,7 @@ def callable(a, b) -> number add_docstr_all( "std", r""" -std(dim, unbiased=True, keepdim=False) -> Tensor - -See :func:`torch.std` - -.. function:: std(unbiased=True) -> Tensor - :noindex: +std(dim=None, *, correction=1, keepdim=False) -> Tensor See :func:`torch.std` """, @@ -5388,6 +5373,48 @@ def callable(a, b) -> number tensor(indices=tensor([[1]]), values=tensor([[ 9, 0, 10]]), size=(3, 3), nnz=1, layout=torch.sparse_coo) + +.. method:: to_sparse(*, layout=None, blocksize=None) -> Tensor + :noindex: + +Returns a sparse tensor with the specified layout and blocksize. + +.. note:: If the :attr:`self` layout and blocksize parameters match + with the specified layout and blocksize, return + :attr:`self`. Otherwise, return a sparse tensor copy of + :attr:`self`. + +Args: + + layout (:class:`torch.layout`, optional): The desired sparse + layout. One of ``torch.sparse_coo``, ``torch.sparse_csr``, + ``torch.sparse_csc``, ``torch.sparse_bsr``, or + ``torch.sparse_bsc``. Default: if ``None``, + ``torch.sparse_coo``. + + blocksize (list, tuple, :class:`torch.Size`, optional): Block size + of the resulting BSR or BSC tensor. For other layouts, + specifying the block size that is not ``None`` will result in a + RuntimeError exception. A block size must be a tuple of length + two such that its items evenly divide the two sparse dimensions. + +Example:: + + >>> x = torch.tensor([[1, 0], [0, 0], [2, 3]]) + >>> x.to_sparse(layout=torch.sparse_coo) + tensor(indices=tensor([[0, 2, 2], + [0, 0, 1]]), + values=tensor([1, 2, 3]), + size=(3, 2), nnz=3, layout=torch.sparse_coo) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(1, 2)) + tensor(crow_indices=tensor([0, 1, 1, 2]), + col_indices=tensor([0, 0]), + values=tensor([[[1, 0]], + [[2, 3]]]), size=(3, 2), nnz=2, layout=torch.sparse_bsr) + >>> x.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 1)) + RuntimeError: Tensor size(-2) 3 needs to be divisible by blocksize[0] 2 + >>> x.to_sparse(layout=torch.sparse_csr, blocksize=(3, 1)) + RuntimeError: to_sparse for Strided to SparseCsr conversion does not use specified blocksize """, ) @@ -5706,12 +5733,7 @@ def callable(a, b) -> number add_docstr_all( "var", r""" -var(dim, unbiased=True, keepdim=False) -> Tensor - -See :func:`torch.var` - -.. function:: var(unbiased=True) -> Tensor - :noindex: +var(dim=None, *, correction=1, keepdim=False) -> Tensor See :func:`torch.var` """, diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 986be67a52f68..ad5429c61e56d 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -632,6 +632,6 @@ def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None): def _str(self, *, tensor_contents=None): - with torch.no_grad(): + with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes(): guard = torch._C._DisableFuncTorch() return _str_intern(self, tensor_contents=tensor_contents) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 00e7129cfb10e..baf683901ef29 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -108,6 +108,16 @@ def merge_dicts(*dicts): returned Tensor. Default: ``torch.contiguous_format``. """ ), + { + "sparse_factory_device_note": """\ +.. note:: + + If the ``device`` argument is not specified the device of the given + :attr:`values` and indices tensor(s) must match. If, however, the + argument is specified the input Tensors will be converted to the + given device and in turn determine the device of the constructed + sparse tensor.""" + }, ) factory_like_common_args = parse_kwargs( @@ -936,7 +946,7 @@ def merge_dicts(*dicts): size (tuple or ints): the shape of the output tensor stride (tuple or ints): the stride of the output tensor storage_offset (int, optional): the offset in the underlying storage of the output tensor. - If ``None``, the storage_offset of the output tensor will match the input tensor. + If ``None``, the storage_offset of the output tensor will match the input tensor. Example:: @@ -3740,7 +3750,7 @@ def merge_dicts(*dicts): add_docstr( torch.as_strided_scatter, r""" -as_strided_scatter(input, src, size, stride, storage_offset=0) -> Tensor +as_strided_scatter(input, src, size, stride, storage_offset=None) -> Tensor Embeds the values of the :attr:`src` tensor into :attr:`input` along the elements corresponding to the result of calling @@ -7980,8 +7990,10 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (Tensor or int): the starting dimension - length (int): the distance to the ending dimension + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive Example:: @@ -7993,6 +8005,10 @@ def merge_dicts(*dicts): tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) """, ) @@ -8008,8 +8024,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (int): the starting offset - length (int): the distance to the ending dimension + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive Keyword args: {out} @@ -8027,13 +8044,13 @@ def merge_dicts(*dicts): >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) >>> torch.narrow_copy(s, 0, 0, 1) tensor(indices=tensor([[0, 0], - [0, 1]]), - values=tensor([[[0, 1], - [2, 3]], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], - [[4, 5], - [6, 7]]]), - size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) .. seealso:: @@ -8834,7 +8851,7 @@ def merge_dicts(*dicts): If you plan to backpropagate through QR, note that the current backward implementation is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))` columns of :attr:`input` are linearly independent. - This behavior will propably change once QR supports pivoting. + This behavior will probably change once QR supports pivoting. .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs, and may produce different (valid) decompositions on different device types @@ -8979,8 +8996,8 @@ def merge_dicts(*dicts): add_docstr( torch.rand, """ -rand(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ -pin_memory=False) -> Tensor +rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, \ +requires_grad=False, pin_memory=False) -> Tensor """ + r""" Returns a tensor filled with random numbers from a uniform distribution @@ -9660,6 +9677,11 @@ def merge_dicts(*dicts): Slices the :attr:`input` tensor along the selected dimension at the given index. This function returns a view of the original tensor with the given dimension removed. +.. note:: If :attr:`input` is a sparse tensor and returning a view of + the tensor is not possible, a RuntimeError exception is + raised. In this is the case, consider using + :func:`torch.select_copy` function. + Args: {input} dim (int): the dimension to slice @@ -10149,6 +10171,8 @@ def merge_dicts(*dicts): have a look at :ref:`the note on the data type of the indices `. +{sparse_factory_device_note} + Args: compressed_indices (array_like): (B+1)-dimensional array of size ``(*batchsize, compressed_dim_size + 1)``. The last element of @@ -10161,10 +10185,12 @@ def merge_dicts(*dicts): plain_indices (array_like): Plain dimension (column or row) co-ordinates of each element or block in values. (B+1)-dimensional tensor with the same length as values. + values (array_list): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types. that - represents a (1+K)-dimensional or (1+2+K)-dimensional tensor - where ``K`` is the number of dense dimensions. + represents a (1+K)-dimensional (for CSR and CSC layouts) or + (1+2+K)-dimensional tensor (for BSR and BSC layouts) where + ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * blocksize[1], *densesize)`` where ``blocksize[0] == @@ -10214,6 +10240,8 @@ def merge_dicts(*dicts): in CSR format are typically faster than that for sparse tensors in COO format. Make you have a look at :ref:`the note on the data type of the indices `. +{sparse_factory_device_note} + Args: crow_indices (array_like): (B+1)-dimensional array of size ``(*batchsize, nrows + 1)``. The last element of each batch @@ -10227,7 +10255,7 @@ def merge_dicts(*dicts): as values. values (array_list): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types that - represents a (1+K)-dimensonal tensor where ``K`` is the number + represents a (1+K)-dimensional tensor where ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If @@ -10274,6 +10302,8 @@ def merge_dicts(*dicts): for sparse tensors in COO format. Make you have a look at :ref:`the note on the data type of the indices `. +{sparse_factory_device_note} + Args: ccol_indices (array_like): (B+1)-dimensional array of size ``(*batchsize, ncols + 1)``. The last element of each batch @@ -10287,7 +10317,7 @@ def merge_dicts(*dicts): values. values (array_list): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types that - represents a (1+K)-dimensonal tensor where ``K`` is the number + represents a (1+K)-dimensional tensor where ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor: ``(*batchsize, nrows, ncols, *densesize)``. If @@ -10334,6 +10364,8 @@ def merge_dicts(*dicts): for sparse tensors in COO format. Make you have a look at :ref:`the note on the data type of the indices `. +{sparse_factory_device_note} + Args: crow_indices (array_like): (B+1)-dimensional array of size ``(*batchsize, nrowblocks + 1)``. The last element of each @@ -10347,7 +10379,7 @@ def merge_dicts(*dicts): values. values (array_list): Initial values for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types that - represents a (1 + 2 + K)-dimensonal tensor where ``K`` is the + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * @@ -10399,6 +10431,8 @@ def merge_dicts(*dicts): for sparse tensors in COO format. Make you have a look at :ref:`the note on the data type of the indices `. +{sparse_factory_device_note} + Args: ccol_indices (array_like): (B+1)-dimensional array of size ``(*batchsize, ncolblocks + 1)``. The last element of each @@ -10412,7 +10446,7 @@ def merge_dicts(*dicts): as values. values (array_list): Initial blocks for the tensor. Can be a list, tuple, NumPy ``ndarray``, and other types that - represents a (1 + 2 + K)-dimensonal tensor where ``K`` is the + represents a (1 + 2 + K)-dimensional tensor where ``K`` is the number of dense dimensions. size (list, tuple, :class:`torch.Size`, optional): Size of the sparse tensor: ``(*batchsize, nrows * blocksize[0], ncols * @@ -10464,6 +10498,8 @@ def merge_dicts(*dicts): This function returns an :ref:`uncoalesced tensor `. +{sparse_factory_device_note} + Args: indices (array_like): Initial data for the tensor. Can be a list, tuple, NumPy ``ndarray``, scalar, and other types. Will be cast to a :class:`torch.LongTensor` @@ -10641,38 +10677,53 @@ def merge_dicts(*dicts): add_docstr( torch.std, r""" -std(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor +std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the standard deviation over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample deviation is calculated, without any correction. +.. math:: \sigma = \sqrt{\frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} Args: {input} {dim} Keyword args: - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + .. versionchanged:: 1.14 + Previously this argument was called ``unbiased`` and was a boolean with + ``True`` corresponding to ``correction=1`` and ``False`` being ``correction=0``. + {keepdim} {out} +Example: -.. function:: std(input, unbiased) -> Tensor - :noindex: - -Calculates the standard deviation of all elements in the :attr:`input` tensor. - -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample deviation is calculated, without any correction. - -Args: - {input} - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std(a, dim=1, keepdim=True) + tensor([[1.0311], + [0.7477], + [1.2204], + [0.9087]]) -Example:: +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction - >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) - >>> torch.std(a, unbiased=False) - tensor(0.4188) """.format( **multi_dim_common ), @@ -10681,45 +10732,54 @@ def merge_dicts(*dicts): add_docstr( torch.std_mean, r""" -std_mean(input, dim, unbiased, keepdim=False, *, out=None) -> (Tensor, Tensor) +std_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the standard deviation and mean over the dimensions specified by +:attr:`dim`. :attr:`dim` can be a single dimension, list of dimensions, or +``None`` to reduce over all dimensions. + +The standard deviation (:math:`\sigma`) is calculated as + +.. math:: \sigma = \sqrt{\frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2} + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. + +""" + + r""" -If :attr:`unbiased` is ``True``, Bessel's correction will be used to calculate -the standard deviation. Otherwise, the sample deviation is calculated, without -any correction. +{keepdim_details} Args: {input} {opt_dim} Keyword args: - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + .. versionchanged:: 1.14 + Previously this argument was called ``unbiased`` and was a boolean with + ``True`` corresponding to ``correction=1`` and ``False`` being ``correction=0``. {keepdim} {out} Returns: A tuple (std, mean) containing the standard deviation and mean. -.. function:: std_mean(input, unbiased) -> (Tensor, Tensor) - :noindex: - -Calculates the standard deviation and mean of all elements in the :attr:`input` -tensor. - -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample deviation is calculated, without any correction. - -Args: - {input} - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). +Example: -Returns: - A tuple (std, mean) containing the standard deviation and mean. + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.std_mean(a, dim=0, keepdim=True) + (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) -Example:: +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction - >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) - >>> torch.std_mean(a, unbiased=False) - (tensor(0.4188), tensor(-0.8509)) """.format( **multi_dim_common ), @@ -11155,7 +11215,7 @@ def merge_dicts(*dicts): r""" flip(input, dims) -> Tensor -Reverse the order of a n-D tensor along given axis in dims. +Reverse the order of an n-D tensor along given axis in dims. .. note:: `torch.flip` makes a copy of :attr:`input`'s data. This is different from NumPy's `np.flip`, @@ -11312,7 +11372,7 @@ def merge_dicts(*dicts): r""" rot90(input, k=1, dims=[0,1]) -> Tensor -Rotate a n-D tensor by 90 degrees in the plane specified by dims axis. +Rotate an n-D tensor by 90 degrees in the plane specified by dims axis. Rotation direction is from the first towards the second axis if k > 0, and from the second towards the first for k < 0. Args: @@ -12081,37 +12141,52 @@ def merge_dicts(*dicts): add_docstr( torch.var, r""" -var(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor +var(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor + +Calculates the variance over the dimensions specified by :attr:`dim`. :attr:`dim` +can be a single dimension, list of dimensions, or ``None`` to reduce over all +dimensions. + +The variance (:math:`\sigma^2`) is calculated as + +.. math:: \sigma^2 = \frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample variance is calculated, without any correction. +{keepdim_details} Args: {input} {opt_dim} Keyword args: - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + .. versionchanged:: 1.14 + Previously this argument was called ``unbiased`` and was a boolean with + ``True`` corresponding to ``correction=1`` and ``False`` being ``correction=0``. {keepdim} {out} -.. function:: var(input, unbiased) -> Tensor - :noindex: - -Calculates the variance of all elements in the :attr:`input` tensor. - -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample deviation is calculated, without any correction. +Example: -Args: - {input} - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var(a, dim=1, keepdim=True) + tensor([[1.0631], + [0.5590], + [1.4893], + [0.8258]]) -Example:: +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction - >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) - >>> torch.var(a, unbiased=False) - tensor(0.1754) """.format( **multi_dim_common ), @@ -12120,45 +12195,53 @@ def merge_dicts(*dicts): add_docstr( torch.var_mean, r""" -var_mean(input, dim, unbiased, keepdim=False, *, out=None) -> (Tensor, Tensor) +var_mean(input, dim=None, *, correction=1, keepdim=False, out=None) -> (Tensor, Tensor) + +Calculates the variance and mean over the dimensions specified by :attr:`dim`. +:attr:`dim` can be a single dimension, list of dimensions, or ``None`` to +reduce over all dimensions. + +The variance (:math:`\sigma^2`) is calculated as -If :attr:`unbiased` is ``True``, Bessel's correction will be used to calculate -the variance. Otherwise, the sample variance is calculated, without any -correction. +.. math:: \sigma^2 = \frac{1}{N - \delta N}\sum_{i=0}^{N-1}(x_i-\bar{x})^2 + +where :math:`x` is the sample set of elements, :math:`\bar{x}` is the +sample mean, :math:`N` is the number of samples and :math:`\delta N` is +the :attr:`correction`. +""" + + r""" + +{keepdim_details} Args: {input} {opt_dim} Keyword args: - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). + correction (int): difference between the sample size and sample degrees of freedom. + Defaults to `Bessel's correction`_, ``correction=1``. + .. versionchanged:: 1.14 + Previously this argument was called ``unbiased`` and was a boolean with + ``True`` corresponding to ``correction=1`` and ``False`` being ``correction=0``. {keepdim} {out} Returns: A tuple (var, mean) containing the variance and mean. -.. function:: var_mean(input, unbiased) -> (Tensor, Tensor) - :noindex: - -Calculates the variance and mean of all elements in the :attr:`input` -tensor. - -If :attr:`unbiased` is ``True``, Bessel's correction will be used. -Otherwise, the sample deviation is calculated, without any correction. - -Args: - {input} - unbiased (bool): whether to use Bessel's correction (:math:`\delta N = 1`). +Example: -Returns: - A tuple (var, mean) containing the variance and mean. + >>> a = torch.tensor( + ... [[ 0.2035, 1.2959, 1.8101, -0.4644], + ... [ 1.5027, -0.3270, 0.5905, 0.6538], + ... [-1.5745, 1.3330, -0.5596, -0.6548], + ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) + >>> torch.var_mean(a, dim=0, keepdim=True) + (tensor([[1.5926, 1.0056, 1.2005, 0.3646]]), + tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]])) -Example:: +.. _Bessel's correction: https://en.wikipedia.org/wiki/Bessel%27s_correction - >>> a = torch.tensor([[-0.8166, -1.3802, -0.3560]]) - >>> torch.var_mean(a, unbiased=False) - (tensor(0.1754), tensor(-0.8509)) """.format( **multi_dim_common ), @@ -12397,7 +12480,7 @@ def merge_dicts(*dicts): add_docstr( torch.where, r""" -where(condition, x, y) -> Tensor +where(condition, x, y, *, out=None) -> Tensor Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`. @@ -12408,7 +12491,8 @@ def merge_dicts(*dicts): \text{x}_i & \text{if } \text{condition}_i \\ \text{y}_i & \text{otherwise} \\ \end{cases} - +""" + + r""" .. note:: The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be :ref:`broadcastable `. @@ -12419,6 +12503,9 @@ def merge_dicts(*dicts): y (Tensor or Scalar): value (if :attr:`y` is a scalar) or values selected at indices where :attr:`condition` is ``False`` +Keyword args: + {out} + Returns: Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y` @@ -12450,7 +12537,9 @@ def merge_dicts(*dicts): .. note:: See also :func:`torch.nonzero`. -""", +""".format( + **common_args + ), ) add_docstr( @@ -12613,7 +12702,7 @@ def merge_dicts(*dicts): {requires_grad} Returns: - Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window + Tensor: A 1-D tensor of size :math:`(\text{{window\_length}},)` containing the window. """.format( **factory_common_args @@ -13113,7 +13202,7 @@ def merge_dicts(*dicts): Keyword args: output_size (int, optional): Total output size for the given axis - ( e.g. sum of repeats). If given, it will avoid stream syncronization + ( e.g. sum of repeats). If given, it will avoid stream synchronization needed to calculate output shape of the tensor. Returns: @@ -13349,7 +13438,7 @@ def merge_dicts(*dicts): input (Tensor): quantized tensor kernel_size (list of int): the size of the sliding window stride (``list of int``, optional): the stride of the sliding window - padding (``list of int``, opttional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 + padding (``list of int``, optional): padding to be added on both sides, must be >= 0 and <= kernel_size / 2 dilation (``list of int``, optional): The stride between elements within a sliding window, must be > 0. Default 1 ceil_mode (bool, optional): If True, will use ceil instead of floor to compute the output shape. Defaults to False. diff --git a/torch/_utils.py b/torch/_utils.py index 8a539d75f5657..1bf3cf96ad1ce 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,3 +1,4 @@ +import copyreg import sys import traceback import warnings @@ -143,15 +144,33 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): # be a TypedStorage def _rebuild_tensor(storage, storage_offset, size, stride): # first construct a tensor with the correct dtype/device - t = torch.tensor([], dtype=storage.dtype, device=storage.untyped().device) - return t.set_(storage.untyped(), storage_offset, size, stride) + t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) + return t.set_(storage._untyped_storage, storage_offset, size, stride) + + +def get_tensor_metadata(tensor): + # Tensor's Metadata for serializing. + # Currently, this only returns a dict[string, bool] specifing whether + # `conj` or `neg` bit is set. + assert isinstance(tensor, torch.Tensor) + return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined] + + +def set_tensor_metadata(tensor, metadata): + # See `get_tensor_metadata` above + assert isinstance(metadata, dict) + assert isinstance(tensor, torch.Tensor) + torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined] def _rebuild_tensor_v2( - storage, storage_offset, size, stride, requires_grad, backward_hooks + storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None ): tensor = _rebuild_tensor(storage, storage_offset, size, stride) tensor.requires_grad = requires_grad + if metadata: + set_tensor_metadata(tensor, metadata) + # NB: This line exists only for backwards compatibility; the # general expectation is that backward_hooks is an empty # OrderedDict. See Note [Don't serialize hooks] @@ -174,15 +193,30 @@ def _rebuild_tensor_v2( def _validate_loaded_sparse_tensors(): try: for t in _sparse_tensors_to_validate: - if t.is_sparse: + if t.layout is torch.sparse_coo: torch._validate_sparse_coo_tensor_args( t._indices(), t._values(), t.size() ) - elif t.is_sparse_csr: + elif t.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: # TODO: Validation currently involves an expensive traversal # on CPU, which may include a device transfer. - torch._validate_sparse_csr_tensor_args( - t.crow_indices(), t.col_indices(), t.values(), t.size() + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = ( + t.crow_indices(), + t.col_indices(), + ) + else: + compressed_indices, plain_indices = ( + t.ccol_indices(), + t.row_indices(), + ) + torch._validate_sparse_compressed_tensor_args( + compressed_indices, plain_indices, t.values(), t.size(), t.layout ) else: raise NotImplementedError( @@ -207,14 +241,15 @@ def _rebuild_sparse_tensor(layout, data): _sparse_tensors_to_validate.append(result) return result - raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout)) - - -def _rebuild_sparse_csr_tensor(layout, data): - if layout == torch.sparse_csr: - crow_indices, col_indices, values, size = data - result = torch._sparse_csr_tensor_unsafe( - crow_indices, col_indices, values, size + elif layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + compressed_indices, plain_indices, values, size = data + result = torch._sparse_compressed_tensor_unsafe( + compressed_indices, plain_indices, values, size, layout=layout ) _sparse_tensors_to_validate.append(result) return result @@ -317,6 +352,64 @@ def _rebuild_parameter(data, requires_grad, backward_hooks): return param +# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes. +# NOTE: We are just defining it here now for future use. +def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + # Restore state on Parameter like python attr. + param = _set_obj_state(param, state) + return param + + +def _get_obj_state(obj): + # Get the state of the python subclass + # This loosely mimicks the function on the object class but since Tensor do not inherit + # from it, we cannot call that function directly + # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 + getstate_fn = getattr(obj, "__getstate__", None) + if getstate_fn: + state = getstate_fn() + else: + slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] + if slots_to_save: + state = ( + obj.__dict__, + { + name: getattr(obj, name) + for name in slots_to_save + if hasattr(obj, name) + }, + ) + else: + state = obj.__dict__ + + return state + + +def _set_obj_state(obj, state): + if isinstance(state, tuple): + if not len(state) == 2: + raise RuntimeError(f"Invalid serialized state: {state}") + dict_state = state[0] + slots_state = state[1] + else: + dict_state = state + slots_state = None + + for k, v in dict_state.items(): + setattr(obj, k, v) + + if slots_state: + for k, v in slots_state.items(): + setattr(obj, k, v) + return obj + + def _import_dotted_name(name): components = name.split(".") obj = __import__(components[0]) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index ee00db937fc3d..30e10409184f7 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -100,9 +100,12 @@ def _get_allowed_globals(): torch._utils._rebuild_tensor_v2, torch._utils._rebuild_sparse_tensor, torch._utils._rebuild_meta_tensor_no_storage, - torch._utils._rebuild_sparse_csr_tensor, ]: rc[f"torch._utils.{f.__name__}"] = f + + # Handles Tensor Subclasses, Tensor's with attributes. + # NOTE: It calls into above rebuild functions for regular Tensor types. + rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 return rc diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py index f19dbd9a0f587..7c92c470ba5b9 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -35,7 +35,7 @@ def __init__(self, freeze_bn=False, qconfig=None): nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) - assert qconfig, 'qconfig must be provded for QAT module' + assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) diff --git a/torch/ao/nn/qat/dynamic/modules/linear.py b/torch/ao/nn/qat/dynamic/modules/linear.py index a6642b5d2df54..89c5567315956 100644 --- a/torch/ao/nn/qat/dynamic/modules/linear.py +++ b/torch/ao/nn/qat/dynamic/modules/linear.py @@ -17,7 +17,7 @@ class Linear(torch.ao.nn.qat.Linear): def __init__(self, in_features, out_features, bias=True, qconfig=None, device=None, dtype=None) -> None: super().__init__(in_features, out_features, bias, qconfig, device, dtype) - if not torch.ao.quantization.activation_is_memoryless(qconfig): + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): raise ValueError( "Dynamic QAT requires a memoryless observer." + "This means a MovingAverage observer with averaging constant equal to 1" diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index 59f23137097ce..72156a7ba5fe1 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -41,12 +41,22 @@ def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) self.gates = torch.ao.nn.quantized.FloatFunctional() + self.input_gate = torch.nn.Sigmoid() + self.forget_gate = torch.nn.Sigmoid() + self.cell_gate = torch.nn.Tanh() + self.output_gate = torch.nn.Sigmoid() + self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() + self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0) + self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0) + self.hidden_state_dtype: torch.dtype = torch.quint8 + self.cell_state_dtype: torch.dtype = torch.quint8 + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: if hidden is None or hidden[0] is None or hidden[1] is None: hidden = self.initialize_hidden(x.shape[0], x.is_quantized) @@ -58,10 +68,10 @@ def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) - input_gate = torch.sigmoid(input_gate) - forget_gate = torch.sigmoid(forget_gate) - cell_gate = torch.tanh(cell_gate) - out_gate = torch.sigmoid(out_gate) + input_gate = self.input_gate(input_gate) + forget_gate = self.forget_gate(forget_gate) + cell_gate = self.cell_gate(cell_gate) + out_gate = self.output_gate(out_gate) fgate_cx = self.fgate_cx.mul(forget_gate, cx) igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) @@ -75,8 +85,10 @@ def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]: h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size)) if is_quantized: - h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8) - c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8) + (h_scale, h_zp) = self.initial_hidden_state_qparams + (c_scale, c_zp) = self.initial_cell_state_qparams + h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype) + c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype) return h, c def _get_name(self): diff --git a/torch/ao/nn/quantized/functional.py b/torch/ao/nn/quantized/functional.py index d0b100bd30567..b3169279082ae 100644 --- a/torch/ao/nn/quantized/functional.py +++ b/torch/ao/nn/quantized/functional.py @@ -422,7 +422,7 @@ def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = Fals :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` Args: - input: Quaintized input + input: Quantized input negative_slope: The slope of the negative input inplace: Inplace modification of the input tensor scale, zero_point: Scale and zero point of the output tensor. diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index f586de58531a7..92298c3d29b6a 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -119,15 +119,11 @@ NSResultsType, NSNodeTargetType, ) - -from torch.ao.quantization import ( - QConfigMapping, -) from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter from torch.ao.quantization.backend_config import BackendConfig -from torch.ao.quantization.fx.backend_config_utils import get_pattern_to_quantize_handlers from torch.ao.quantization.fx.match_utils import find_matches from torch.ao.quantization.fx.qconfig_mapping_utils import generate_node_name_to_qconfig +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers from torch.ao.quantization.qconfig import QConfigAny from torch.ao.ns.fx.n_shadows_utils import ( OutputProp, @@ -138,6 +134,7 @@ print_n_shadows_summary, handle_subgraph, ) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type @@ -753,8 +750,11 @@ def extend_logger_results_with_comparison( def prepare_n_shadows_model( model: torch.nn.Module, example_inputs: Any, - qconfig_mappings: List[QConfigMapping], + qconfig_multi_mapping: QConfigMultiMapping, backend_config: BackendConfig, + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, + custom_tracer: Any = None, ) -> torch.nn.Module: """ Given a model with a graph with M ops such as @@ -770,9 +770,9 @@ def prepare_n_shadows_model( args_kwargs_m -> op_m -> output_m | | - |---------------------------> mod_with_op_m_transformed_with_qconfig_i + |---------------------------> mod_with_op_m_transformed_with_qconfig_n - Where mod_with_op_m_transformed_with_qconfig_i is a submodule, and its + Where mod_with_op_m_transformed_with_qconfig_n is a submodule, and its inner graph looks like .. code:: @@ -790,11 +790,13 @@ def prepare_n_shadows_model( 1. add deduplication for qconfigs per subgraph 2. figure out a better way to name the output structure 3. return a results data structure instead of printing it out - 4. make specifying sets of QConfigMapping more user friendly - 5. add examples to docblocks + 4. add examples to docblocks """ - tracer = quantize_fx.QuantizationTracer([], []) + if custom_tracer is None: + tracer = quantize_fx.QuantizationTracer([], []) + else: + tracer = custom_tracer mt = torch.fx.GraphModule(model, tracer.trace(model)) # this is necessary to ensure logger FQNs get populated mt._node_name_to_scope = tracer.node_name_to_scope @@ -807,7 +809,7 @@ def prepare_n_shadows_model( # Find the set of subgraphs in the original graph which we need to # consider. modules = dict(mt.named_modules(remove_duplicate=False)) - patterns = get_pattern_to_quantize_handlers(backend_config) + patterns = _get_pattern_to_quantize_handlers(backend_config) root_node_getter_mapping = \ get_fusion_pattern_to_root_node_getter(backend_config) standalone_module_names: List[str] = [] @@ -822,7 +824,7 @@ def prepare_n_shadows_model( # generate node to qconfig for each subgraph # TODO(future PR): deduplicate repeating entries list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = [] - for qconfig_mapping in qconfig_mappings: + for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list: node_name_to_qconfig = generate_node_name_to_qconfig( mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope) list_of_node_name_to_qconfig.append(node_name_to_qconfig) @@ -838,7 +840,9 @@ def prepare_n_shadows_model( enumerate(subgraphs_dedup.items()): handle_subgraph( mt, subgraph_idx, match_name, nodes_in_this_subgraph, - qconfig_mappings, list_of_node_name_to_qconfig) + qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig, + custom_prepare_fn, custom_prepare_kwargs + ) mt.recompile() return mt @@ -866,7 +870,11 @@ def loggers_set_save_activations( if isinstance(child, OutputLogger): child.save_activations = save_activations -def convert_n_shadows_model(model: GraphModule) -> GraphModule: +def convert_n_shadows_model( + model: GraphModule, + custom_convert_fn: Optional[Callable] = None, + custom_convert_kwargs: Optional[Dict[str, Any]] = None +) -> GraphModule: """ Given a model from `prepare_n_shadows_model`, runs `convert_fx` on each shadow submodule. @@ -876,8 +884,13 @@ def convert_n_shadows_model(model: GraphModule) -> GraphModule: # node name string match if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): orig_mod = getattr(model, node.name) - converted_mod = torch.ao.quantization.quantize_fx.convert_fx( - orig_mod) + if custom_convert_fn is None: + converted_mod = torch.ao.quantization.quantize_fx.convert_fx( + orig_mod) + else: + if custom_convert_kwargs is None: + custom_convert_kwargs = {} + converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) setattr(model, node.name, converted_mod) return model diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 85e9be6135f1c..c504a59c995d6 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -500,6 +500,8 @@ def handle_subgraph_candidate( fqn: Optional[str], list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], example_inputs: Any, + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Dict[str, Any] = None, ) -> None: """ Given a subgraph in `mt` and a subgraph candidate idx, inserts the @@ -566,9 +568,24 @@ def handle_subgraph_candidate( .set_non_traceable_module_classes([OutputLogger, OutputComparisonLogger]) # add a call to prepare_fx on the wrapper module - orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( - orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs, - prepare_custom_config=prepare_custom_config) + if custom_prepare_fn is None: + orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( + orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs, + prepare_custom_config=prepare_custom_config) + else: + if custom_prepare_kwargs is None: + custom_prepare_kwargs = {} + for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]: + assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs" + prepare_kwargs: Dict[str, Any] = { + "example_inputs": example_inputs, + "prepare_custom_config": prepare_custom_config, + "qconfig_mapping": qconfig_mapping + } + prepare_kwargs.update(custom_prepare_kwargs) + orig_mod_copy_wrapped = custom_prepare_fn( + orig_mod_copy_wrapped, + **prepare_kwargs) # attach the wrapper to the model attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx) @@ -615,6 +632,8 @@ def handle_subgraph( nodes_in_this_subgraph: List[Any], qconfig_mappings: List[QConfigMapping], list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Dict[str, Any] = None, ) -> None: """ Given a model `mt` and a subgraph_idx, creates the needed copies @@ -690,7 +709,7 @@ def handle_subgraph( handle_subgraph_candidate( mt, subgraph_idx, subgraph_candidate_idx, first_node, last_node, fqn, list_of_node_name_to_qconfig, - example_inputs) + example_inputs, custom_prepare_fn, custom_prepare_kwargs) # TODO(future PR): redesign this to make it easier to consume outputs def group_results_by_subgraph(results: NSResultsType) -> Any: diff --git a/torch/ao/ns/fx/pattern_utils.py b/torch/ao/ns/fx/pattern_utils.py index b8e6a0ee4dc11..b91024bc76c09 100644 --- a/torch/ao/ns/fx/pattern_utils.py +++ b/torch/ao/ns/fx/pattern_utils.py @@ -6,9 +6,10 @@ from torch.fx import GraphModule from torch.fx.graph import Node +from torch.ao.quantization.backend_config import get_native_backend_config +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers from torch.ao.quantization.utils import getattr_from_fqn from .ns_types import NSNodeTargetType -from torch.ao.quantization.fx.backend_config_utils import get_native_quant_patterns from torch.ao.quantization import ( ObserverBase, FakeQuantizeBase, @@ -66,7 +67,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]: # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d) # For fusions, we only care about patterns composed of multiple ops. # TODO(future PR): allow customizations from default patterns. - all_quant_patterns = get_native_quant_patterns() + all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) default_base_op_idx = 0 for quant_pattern, _quant_handler in all_quant_patterns.items(): diff --git a/torch/ao/ns/fx/qconfig_multi_mapping.py b/torch/ao/ns/fx/qconfig_multi_mapping.py new file mode 100644 index 0000000000000..20a005d0c8bf9 --- /dev/null +++ b/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import copy +from typing import Any, Callable, Dict, List, Union + +import torch +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER +from torch.ao.quantization.qconfig import QConfigAny + +__all__ = ["QConfigMultiMapping"] + +_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = { + "global_qconfig": "set_global", + "object_type_qconfigs": "set_object_type", + "module_name_regex_qconfigs": "set_module_name_regex", + "module_name_qconfigs": "set_module_name", + "module_name_object_type_order_qconfigs": "set_module_name_object_type_order", +} + +def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None: + to_remove = [] + for index, cur_qconfig in enumerate(qconfig_list): + if cur_qconfig is None: + to_remove.append(index) + break + for checked_qconfig in qconfig_list[:index]: + if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig): + to_remove.append(index) + break + for index in to_remove[::-1]: + qconfig_list.pop(index) + +class QConfigMultiMapping: + """ + This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s + so that multiple QConfigs can be specified for each QConfig matching style. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfigs + + ``set_object_type`` : sets the QConfigs for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string + + ``set_module_name`` : sets the QConfigs for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a + single QConfig. + + Example usage:: + + qconfig_mapping = QConfigMultiMapping() + .set_global([qconfig1, qconfig2]) + .set_object_type(torch.nn.Linear, [qconfig2, qconfig3]) + .set_object_type(torch.nn.ReLU, [qconfig1]) + .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2]) + .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3]) + .set_module_name("module1", [None]) + .set_module_name("module2", [qconfig2]) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3]) + + """ + + def __init__(self): + # initialize this with 1 QConfigMapping to avoid corner cases + self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()] + + def _handle_list_size_mismatch( + self, qconfig_list: List[QConfigAny], style: str + ) -> None: + # this method handles cases where the size of qconfig_list does not match + # the size of qconfig_mappings_list. + # Issue: Consider a user inserting global_qconfig A and B first, then inserting + # qconfig C as an object_type_qconfig for conv ops. If we internally store + # 1 QConfigMapping with A and C and another with just B, then the + # second QConfigMapping will match B to conv ops (which is not wanted), since B is global. + + # we avoid this by maintaining the invariant that if any QConfigMapping + # has a qconfig style+key with a qconfig in it, all QConfigMappings must + # have either a qconfig or None for that same style+key. In the above + # example, a None qconfig would prevent the unwanted match in the + # second QConfigMapping + + if len(qconfig_list) > len(self.qconfig_mappings_list): + # Case: we have more qconfigs (in qconfig_list) than QConfigMappings + + # Add new QConfigMappings (initialized so we maintain the `invariant`) + + new_qconfig_mapping = QConfigMapping() + # searches other QConfigMappings for qconfig style+keys + # that need to be inserted as `None` into the new QConfigMapping + for qconfig_mapping in self.qconfig_mappings_list: + + # global_qconfig has None by default + for check_style in _QCONFIG_STYLE_ORDER[1:]: + qconfigs_dict = getattr(qconfig_mapping, check_style) + target_qconfigs_dict = getattr(new_qconfig_mapping, check_style) + for key in qconfigs_dict: + target_qconfigs_dict[key] = None + break + + # insert copies of this new QConfigMapping until all entires + # in qconfig_list can fit among the QConfigMappings + while len(qconfig_list) > len(self.qconfig_mappings_list): + self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) + else: + # Case: we have fewer qconfigs in qconfig_list than QConfigMappings + + # pad qconfig_list with `None` until length is same + while len(qconfig_list) < len(self.qconfig_mappings_list): + qconfig_list.append(None) + + # this function applies the insertion method across each QConfigMapping + def _insert_qconfig_list( + self, + style: str, + args: List[Union[str, int, Callable]], + qconfig_list: List[QConfigAny], + ) -> None: + + # we remove duplicates and None to make the ordering of qconfigs + # deterministic upon insertion. + _remove_duplicates_and_none(qconfig_list) + + self._handle_list_size_mismatch(qconfig_list, style) + method_name = _QCONFIG_STYLE_TO_METHOD[style] + for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list): + # uses QConfigMapping set method to insert qconfig + set_method = getattr(qconfig_mapping, method_name) + set_method(*args, qconfig) + + def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping: + """ + Set global QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info + """ + self._insert_qconfig_list("global_qconfig", [], global_qconfig_list) + return self + + def set_object_type( + self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set object type QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info + """ + self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list) + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name_regex QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info + """ + self._insert_qconfig_list( + "module_name_regex_qconfigs", [module_name_regex], qconfig_list + ) + return self + + def set_module_name( + self, module_name: str, qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info + """ + self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list) + return self + + def set_module_name_object_type_order( + self, + module_name: str, + object_type: Callable, + index: int, + qconfig_list: List[QConfigAny], + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info + """ + self._insert_qconfig_list( + "module_name_object_type_order_qconfigs", + [module_name, object_type, index], + qconfig_list, + ) + return self + + def __repr__(self): + return ( + self.__class__.__name__ + + " [" + + "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) + + "\n]" + ) + + @classmethod + def from_list_qconfig_mapping( + cls, qconfig_mapping_list: List[QConfigMapping] + ) -> QConfigMultiMapping: + """ + Creates a QConfigMultiMapping from a list of QConfigMappings + """ + new_qconfig_multi_mapping = cls() + + new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy( + qconfig_mapping_list + ) + + # we need to avoid the issue described in _handle_list_size_mismatch, + # so we reinsert all the qconfigs using the QConfigMultiMapping + # set methods + + # go through all qconfig styles + # note: global can be ignored since it is None by default + for style in _QCONFIG_STYLE_ORDER[1:]: + + # gather all key+qconfigs for current style + # into qconfig_dict_list + qconfig_dict_list: Dict[Any, List[QConfigAny]] = {} + for qconfig_mapping in qconfig_mapping_list: + qconfig_dict = getattr(qconfig_mapping, style) + for key, qconfig in qconfig_dict.items(): + if key not in qconfig_dict_list: + qconfig_dict_list[key] = [] + qconfig_dict_list[key].append(qconfig) + + # reinsert all gathered key+qconfigs + set_method_name = _QCONFIG_STYLE_TO_METHOD[style] + set_method = getattr(new_qconfig_multi_mapping, set_method_name) + for key, qconfig_list in qconfig_dict_list.items(): + if isinstance(key, tuple): + set_method(*key, qconfig_list) + else: + set_method(key, qconfig_list) + + return new_qconfig_multi_mapping diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/README.md b/torch/ao/pruning/_experimental/activation_sparsifier/README.md index 3c2514c2f116b..810b053d92221 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/README.md +++ b/torch/ao/pruning/_experimental/activation_sparsifier/README.md @@ -60,7 +60,7 @@ def mask_fn(tensor, threshold): # threshold is the sparse config here ``` ## API Design -`ActivationSparsifier`: Attaches itself to a model layer and sparsifies the activation flowing through that layer. The user can pass in the default `aggregate_fn`, `reduce_fn` and `mask_fn`. Additionaly, `features` and `feature_dim` are also accepted. +`ActivationSparsifier`: Attaches itself to a model layer and sparsifies the activation flowing through that layer. The user can pass in the default `aggregate_fn`, `reduce_fn` and `mask_fn`. Additionally, `features` and `feature_dim` are also accepted. `register_layer`: Registers a layer for sparsification. Specifically, registers `forward_pre_hook()` that performs aggregation. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/README.md b/torch/ao/pruning/_experimental/data_sparsifier/README.md index c6fc99b36c8c4..faea74355360a 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/README.md +++ b/torch/ao/pruning/_experimental/data_sparsifier/README.md @@ -3,7 +3,7 @@ The data sparsifier inherits from the `BaseSparsifier` class. It attempts to sparsify data tensors in general (trainable and non-trainable). ## Implementation Details -The data sparsifier does not receive a model or a layer to sparsify. Hence, the mask needs to be owned by the data sparsifier. This is acheived by introducing a private container model that registers the data as a parametrized buffer. +The data sparsifier does not receive a model or a layer to sparsify. Hence, the mask needs to be owned by the data sparsifier. This is achieved by introducing a private container model that registers the data as a parametrized buffer. The BaseDataSparsifier handles all the housekeeping while allowing the user to just implement the `update_mask` logic in their implementation. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md index b39e951efec5d..f7f83d7d6f3bb 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/README.md @@ -5,7 +5,7 @@ The objective of this exercise is to use the data sparsifier to prune the embedd 1. **Disk usage savings**: Savings in model size after pruning. 2. **Model Quality**: How and by how much does performance deteriorate after pruning the embedding bags? -3. **Model forward time**: Can we speed up the model forward time by utilizing the sparsity? Specificially, can we introduce torch.sparse interim to reduce number of computations. +3. **Model forward time**: Can we speed up the model forward time by utilizing the sparsity? Specifically, can we introduce torch.sparse interim to reduce number of computations. ## Scope The [DataNormSparsifier](https://github.com/pytorch/pytorch/blob/master/torch/ao/sparsity/_experimental/data_sparsifier/data_norm_sparsifier.py) is used to sparsify the embeddings of the DLRM model. The model is sparsified for all the combinations of - diff --git a/torch/ao/pruning/_experimental/pruner/__init__.py b/torch/ao/pruning/_experimental/pruner/__init__.py index c496e555930a2..d762873277493 100644 --- a/torch/ao/pruning/_experimental/pruner/__init__.py +++ b/torch/ao/pruning/_experimental/pruner/__init__.py @@ -1,15 +1,5 @@ -from .base_pruner import BasePruner +from .base_structured_sparsifier import BaseStructuredSparsifier from .parametrization import ( - ActivationReconstruction, + FakeStructuredSparsity, BiasHook, - PruningParametrization, - ZeroesParametrization, ) - -__all__ = [ - "ActivationReconstruction", - "BasePruner", - "BiasHook", - "PruningParametrization", - "ZeroesParametrization", -] diff --git a/torch/ao/pruning/_experimental/pruner/base_pruner.py b/torch/ao/pruning/_experimental/pruner/base_pruner.py deleted file mode 100644 index fbeed5084abb5..0000000000000 --- a/torch/ao/pruning/_experimental/pruner/base_pruner.py +++ /dev/null @@ -1,247 +0,0 @@ - -import copy -import warnings -import abc - -import torch -from torch import nn -from torch.nn.utils import parametrize - -from torch.nn.modules.container import ModuleDict, ModuleList - -from .parametrization import PruningParametrization, ZeroesParametrization, ActivationReconstruction, BiasHook - -from torch.ao.pruning import BaseSparsifier, module_to_fqn, fqn_to_module -from torch.ao.pruning.sparsifier.utils import get_arg_info_from_tensor_fqn - -__all__ = ["BasePruner"] - -SUPPORTED_MODULES = { # added to config if None given - nn.Linear, - nn.Conv2d, - nn.BatchNorm2d, # will need manual update to match conv2d -} - -NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed - nn.BatchNorm2d -} - -class BasePruner(BaseSparsifier): - r"""Base class for all pruners. - - Abstract methods that need to be implemented: - - - update_mask: Function to compute a new mask for all keys in the - `groups` attribute. - - Args: - - defaults [dict]: default configurations will be attached to the - configuration. Only the keys that don't exist in the `config` will - be updated. - - also_prune_bias [bool]: whether to prune bias in addition to weights (to prune full output channel) - or not; default=True. - - """ - def __init__(self, defaults, also_prune_bias=True): - super().__init__(defaults) - self.prune_bias = also_prune_bias - - def _get_modules_and_tensor_names(self, config, use_path): - modules = [] - tensor_names = [] - if use_path: - if type(config['module']) is tuple: # (Conv2d, BN) - for module_fqn, tensor_name in zip(config['module_fqn'], config['tensor_name']): - module = fqn_to_module(self.model, module_fqn) - modules.append(module) - tensor_names.append(tensor_name) - else: - module = fqn_to_module(self.model, config['module_fqn']) - modules.append(module) - tensor_name = config['tensor_name'] - tensor_names.append(tensor_name) - - else: - if type(config['module']) is tuple: - for module, tensor_name in zip(config['module'], config['tensor_name']): - modules.append(module) - tensor_names.append(tensor_name) - else: - module = config['module'] - modules.append(module) - tensor_name = config['tensor_name'] - tensor_names.append(tensor_name) - return modules, tensor_names - - def _prepare(self, use_path=False, *args, **kwargs): - r"""Adds mask parametrization to the layer weight - """ - self.activation_handles = [] # store removable hook handles - self.bias_handles = [] - - for config in self.groups: - modules, tensor_names = self._get_modules_and_tensor_names(config, use_path) - - for module, tensor_name in zip(modules, tensor_names): - if not isinstance(module, tuple(NEEDS_ZEROS)): - # add pruning parametrization and forward hooks - if getattr(module, 'mask', None) is None: - module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0])) - param = config.get('parametrization', PruningParametrization) - parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True) - - assert isinstance(module.parametrizations, ModuleDict) # make mypy happy - assert isinstance(module.parametrizations.weight, ModuleList) - if isinstance(module, tuple(SUPPORTED_MODULES)): - self.activation_handles.append(module.register_forward_hook( - ActivationReconstruction(getattr(module.parametrizations, tensor_name)[0]) - )) - else: - raise NotImplementedError("This module type is not supported yet.") - - else: # needs zeros - if getattr(module, 'mask', None) is None: - module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0])) - param = config.get('parametrization', ZeroesParametrization) - parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True) - - if module.bias is not None: - module.register_parameter('_bias', nn.Parameter(module.bias.detach())) - module.bias = None - self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias))) - - if len(modules) == 2: # (Conv2d, BN) - # should have the same set of pruned outputs - modules[1].parametrizations.weight[0].pruned_outputs = modules[0].parametrizations.weight[0].pruned_outputs - - def make_config_from_model(self, model, SUPPORTED_MODULES=SUPPORTED_MODULES, NEEDS_ZEROS=NEEDS_ZEROS): - self.config = [] - stack = [model] - while stack: - module = stack.pop() - for name, child in module.named_children(): - if type(child) in SUPPORTED_MODULES: - child_fqn = module_to_fqn(model, child) - assert isinstance(child_fqn, str) # for mypy - self.config.append({'tensor_fqn': child_fqn + '.weight'}) - else: - if NEEDS_ZEROS is not None and type(child) in NEEDS_ZEROS and hasattr(self, "prune_bias") and self.prune_bias: - # only useful for Pruner - warnings.warn(f"Models with {type(child)} layers have config provided by user.") - stack.append(child) - - def prepare(self, model, config): - r"""Prepares a model, by adding the parametrizations and forward post-hooks. - Note:: - The model is modified inplace. If you need to preserve the original - model, use copy.deepcopy. - - Args: - - model [nn.Module]: model to configure. The model itself is not saved - but used for the state_dict saving / loading. - - config [list]: configuration elements could either be instances of - tuples of dict maps or dict maps. The dicts must have a key 'tensor_fqn' with the - value being the fqn of the tensor to be pruned. - """ - self.model = model # TODO: Need to figure out how to load without this. - self.config = config - - # If no config -- try getting all the supported layers - if self.config is None: - # Add all models to the config - self.make_config_from_model(self.model) - - for module_config in self.config: - if type(module_config) is tuple: - first_layer, next_layer = module_config - assert isinstance(first_layer, nn.Conv2d) and isinstance(next_layer, nn.BatchNorm2d) - assert isinstance(module_config, tuple) # for mypy - module_config = {'module': module_config} - local_args = copy.deepcopy(self.defaults) - local_args.update(module_config) - module_fqn_list = [] - tensor_fqn_list = [] - tensor_name_list = [] - for module in local_args['module']: - module_fqn = module_to_fqn(model, module) - if module_fqn is None: - module_fqn = '' - if module_fqn and module_fqn[0] == '.': - module_fqn = module_fqn[1:] - module_fqn_list.append(module_fqn) - tensor_fqn_list.append(module_fqn + '.weight') - tensor_name_list.append('weight') - - local_args['module_fqn'] = module_fqn_list - local_args['tensor_fqn'] = tensor_fqn_list - local_args['tensor_name'] = tensor_name_list - else: - if isinstance(module_config, nn.Module): - module_config = {'module': module_config} # type: ignore[dict-item] - - local_args = copy.deepcopy(self.defaults) - local_args.update(module_config) - - # now that we're working with a dict, does it have the new format? - if local_args.get('tensor_fqn', None) is not None: - tensor_fqn = local_args.get('tensor_fqn') - assert isinstance(tensor_fqn, str) # for mypy - info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) - - for key in info_from_tensor_fqn.keys(): - if key in local_args: - # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that - assert key == 'tensor_fqn' or info_from_tensor_fqn[key] == local_args[key], ( - "Given both `{}` and `tensor_fqn`, it is expected them to " - "agree!".format(key) - ) - local_args.update(info_from_tensor_fqn) - else: - module = local_args['module'] - module_fqn = module_to_fqn(model, module) - if module_fqn and module_fqn[0] == '.': - module_fqn = module_fqn[1:] - local_args['module_fqn'] = module_fqn - local_args['tensor_name'] = "weight" - assert isinstance(module_fqn, str) # for mypy - local_args['tensor_fqn'] = module_fqn + ".weight" - self.groups.append(local_args) - - self._prepare() - - def squash_mask(self, use_path=False, *args, **kwargs): - for config in self.groups: - modules, tensor_names = self._get_modules_and_tensor_names(config, use_path) - - for module, tensor_name in zip(modules, tensor_names): - parametrize.remove_parametrizations(module, tensor_name, - leave_parametrized=True) - if getattr(module._parameters, 'mask', None): - del module._parameters['mask'] - elif getattr(module._buffers, 'mask', None): - del module._buffers['mask'] - delattr(module, 'mask') - - def get_module_pruned_outputs(self, module, tensor_name='weight'): - r"""Returns the set of pruned indices of module""" - assert parametrize.is_parametrized(module) # can only get pruned indices of pruned module - return getattr(module.parametrizations, tensor_name)[0].pruned_outputs # assume only one parametrization attached - - def step(self, use_path=False): - if not self.enable_mask_update: - return - with torch.no_grad(): - for config in self.groups: - modules, tensor_names = self._get_modules_and_tensor_names(config, use_path) - - untupled_args: dict = {} - untupled_args.update() - # only need to update the first module in modules if len(modules) > 1 - # since they should share the same set of pruned outputs - untupled_args['module'] = modules[0] - untupled_args['tensor_name'] = tensor_names[0] - self.update_mask(**config) - - @abc.abstractmethod - def update_mask(self, module, tensor_name, **kwargs): - pass diff --git a/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000000..3b568f1557d07 --- /dev/null +++ b/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py @@ -0,0 +1,294 @@ +from itertools import chain +import torch +import torch.nn.functional as F +from torch import nn +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize +from typing import Type, Set, Dict, Callable, Tuple, Optional, Union + +from torch.ao.pruning import BaseSparsifier +from .parametrization import FakeStructuredSparsity, BiasHook +from .match_utils import apply_match +from .prune_functions import ( + prune_linear, + prune_linear_linear, + prune_linear_activation_linear, + prune_conv2d, + prune_conv2d_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> Dict[ + Tuple[Union[Type[nn.Module], Callable[[torch.Tensor], torch.Tensor], str], ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: Dict[ + Tuple[Union[Type[nn.Module], Callable[[torch.Tensor], torch.Tensor], str], ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Optional[Set[Type]] = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. + """ + self.bias_handles = [] + + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter("_bias", nn.Parameter(module.bias.detach())) + module.bias = None + module.prune_bias = prune_bias + + self.bias_handles.append( + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) + ) + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has apropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and isinstance( + first_module.parametrizations["weight"][0], + FakeStructuredSparsity, + ) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced diff --git a/torch/ao/pruning/_experimental/pruner/match_utils.py b/torch/ao/pruning/_experimental/pruner/match_utils.py new file mode 100644 index 0000000000000..d0f7a9f6293d9 --- /dev/null +++ b/torch/ao/pruning/_experimental/pruner/match_utils.py @@ -0,0 +1,59 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" +import torch +from torch import nn +from torch.ao.quantization.utils import ( + MatchAllNode, +) +from torch.fx import Node +from torch.nn.utils import parametrize +from typing import Any, Dict, List, Optional, Tuple, Union + +def _match(modules: Dict[str, nn.ModuleDict], node: Node, current: Union[nn.Module, Any]) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + +def apply_match( + modules: Dict[str, nn.ModuleDict], + pattern: Union[Tuple[Any], Any], + node: Node, + matched_node_pattern: List[Node], +) -> Optional[List[Node]]: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git a/torch/ao/pruning/_experimental/pruner/parametrization.py b/torch/ao/pruning/_experimental/pruner/parametrization.py index 77c86a22e175a..aeddd0a841525 100644 --- a/torch/ao/pruning/_experimental/pruner/parametrization.py +++ b/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -1,72 +1,44 @@ import torch from torch import nn -from typing import Any, List -__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook'] -class PruningParametrization(nn.Module): - def __init__(self, original_outputs): - super().__init__() - self.original_outputs = set(range(original_outputs.item())) - self.pruned_outputs = set() # Will contain indicies of outputs to prune - def forward(self, x): - valid_outputs = self.original_outputs - self.pruned_outputs - return x[list(valid_outputs)] +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ -class ZeroesParametrization(nn.Module): - r"""Zero out pruned channels instead of removing. - E.g. used for Batch Norm pruning, which should match previous Conv2d layer.""" - def __init__(self, original_outputs): + def __init__(self, mask): super().__init__() - self.original_outputs = set(range(original_outputs.item())) - self.pruned_outputs = set() # Will contain indicies of outputs to prune + self.register_buffer("mask", mask) def forward(self, x): - x.data[list(self.pruned_outputs)] = 0 - return x - - -class ActivationReconstruction: - def __init__(self, parametrization): - self.param = parametrization - - def __call__(self, module, input, output): - max_outputs = self.param.original_outputs - pruned_outputs = self.param.pruned_outputs - valid_columns = list(max_outputs - pruned_outputs) - - # get size of reconstructed output - sizes = list(output.shape) - sizes[1] = len(max_outputs) - - # get valid indices of reconstructed output - indices: List[Any] = [] - for size in output.shape: - indices.append(slice(0, size, 1)) - indices[1] = valid_columns - - reconstructed_tensor = torch.zeros(sizes, - dtype=output.dtype, - device=output.device, - layout=output.layout) - reconstructed_tensor[indices] = output - return reconstructed_tensor + assert isinstance(self.mask, torch.Tensor) + assert self.mask.shape[0] == x.shape[0] + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} class BiasHook: + def __init__(self, parametrization, prune_bias): self.param = parametrization self.prune_bias = prune_bias def __call__(self, module, input, output): - pruned_outputs = self.param.pruned_outputs if getattr(module, '_bias', None) is not None: bias = module._bias.data if self.prune_bias: - bias[list(pruned_outputs)] = 0 + bias[~self.param.mask] = 0 # reshape bias to broadcast over output dimensions idx = [1] * len(output.shape) diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py new file mode 100644 index 0000000000000..ee8bffb7f9f3e --- /dev/null +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -0,0 +1,359 @@ +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propogation +""" +from typing import cast, Optional, Callable, Tuple + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList +from .parametrization import FakeStructuredSparsity, BiasHook + + +# BIAS PROPOGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propogate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]: + r""" + In the case that we need to propogate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propogate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + return mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Optional[Callable[[Tensor], Tensor]], + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propogate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] + # adjusted bias that to keep in conv2d_1 + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(Tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propogate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(Tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Optional[Callable[[Tensor], Tensor]], + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Optional[Callable[[Tensor], Tensor]], + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + assert ( + linear_ic % conv2d_oc == 0 + ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propogate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 8a66280cc852d..2ba2584616e21 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -99,7 +99,7 @@ def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shap dw = (block_w - w % block_w) % block_w if mask is None: - mask = torch.ones(h, w, device=data.device) + mask = torch.ones(h + dh, w + dw, device=data.device) if sparsity_level >= 1.0: mask.data = torch.zeros_like(mask) @@ -141,14 +141,15 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. """ - if mask is None: - mask = torch.ones(data.shape, device=data.device) h, w = data.shape[-2:] block_h, block_w = sparse_block_shape dh = (block_h - h % block_h) % block_h dw = (block_w - w % block_w) % block_w values_per_block = reduce((lambda x, y: x * y), sparse_block_shape) + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + if values_per_block == zeros_per_block: # Everything should be sparsified mask.data = torch.zeros_like(mask) @@ -168,7 +169,7 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape ) - mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() return mask def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index dc0e0a07381f5..1ba2a60ed3d12 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -40,11 +40,8 @@ "RecordingObserver", "ReuseInputObserver", "UniformQuantizationObserverBase", - "activation_is_memoryless", - "add_module_to_qconfig_obs_ctr", "add_observer_", "add_quant_dequant", - "assert_valid_qconfig", "convert", "convert_dynamic_jit", "convert_jit", @@ -117,12 +114,9 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", - "get_valid_patterns", "is_activation_post_process", - "is_reuse_input_qconfig", "load_observer_state_dict", "no_observer_set", - "obs_or_fq_ctr_equals", "per_channel_weight_observer_range_neg_127_to_127", "prepare", "prepare_dynamic_jit", @@ -130,19 +124,14 @@ "prepare_qat", "propagate_qconfig_", "qconfig_equals", - "quant_type_to_str", "quantize", "quantize_dynamic", "quantize_dynamic_jit", "quantize_jit", "quantize_qat", "register_activation_post_process_hook", - "reverse2", - "reverse3", - "reverse_sequential_wrapper2", "script_qconfig", "script_qconfig_dict", - "sequential_wrapper2", "swap_module", "weight_observer_range_neg_127_to_127", ] diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 0d9017533166a..7dfc58dfe52ad 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -5,6 +5,14 @@ import torch.ao.quantization import torch.ao.ns._numeric_suite as ns +__all__ = [ + "get_module", + "parent_child_names", + "get_param", + "MeanShadowLogger", + "bias_correction", +] + _supported_modules = {nn.Linear, nn.Conv2d} _supported_modules_quantized = {nnq.Linear, nnq.Conv2d} diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index 1da025ca7a0d2..b15ffc65b7ad1 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -2,6 +2,19 @@ import copy from typing import Dict, Any +__all__ = [ + "set_module_weight", + "set_module_bias", + "get_module_weight", + "get_module_bias", + "max_over_ndim", + "min_over_ndim", + "channel_range", + "cross_layer_equalization", + "equalize", + "converged", +] + _supported_types = {torch.nn.Conv2d, torch.nn.Linear} _supported_intrinsic_types = {torch.nn.intrinsic.ConvReLU2d, torch.nn.intrinsic.LinearReLU} _all_supported_types = _supported_types.union(_supported_intrinsic_types) diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index 9be2a4c5900ad..10600363d3564 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -1,6 +1,8 @@ import torch from torch.nn.parameter import Parameter +from typing import List +__all__: List[str] = [] class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): r""" This is an extension of the FakeQuantize module in fake_quantize.py, which diff --git a/torch/ao/quantization/backend_config/README.md b/torch/ao/quantization/backend_config/README.md index a170581d5638b..5d37fce9ec502 100644 --- a/torch/ao/quantization/backend_config/README.md +++ b/torch/ao/quantization/backend_config/README.md @@ -1,10 +1,34 @@ -The patterns are we matching against is float modules types, functional operators and pytorch operators in reverse order: +## BackendConfig Overview + +BackendConfig allows PyTorch quantization to work with different backend or kernel libraries. These backends may have different sets of supported quantized operator patterns, and the same operator patterns may require different handling across different backends. To make quantization work with different backends and allow maximum flexibility, we strived to make all the parts of the quantization flow configurable with BackendConfig. Currently, it is only used by FX graph mode quantization. For more details on how it integrates with the FX graph mode quantization flow, refer to this [README](/torch/ao/quantization/fx/README.md). + +BackendConfig configures quantization behavior in terms of operator patterns. For each operator pattern, we need to specify what the supported data types are for the input and output activations, weights, and biases, and also specify the QAT modules, the reference quantized modules etc., which will be used in module swapping during the quantization passes. + +Quantized backends can have different support in terms of the following aspects: +* Quantization scheme (symmetric vs asymmetric, per-channel vs per-tensor) +* Data type (float32, float16, int8, uint8, bfloat16, etc.) for input/output/weight/bias +* Quantized (and fused) mapping: Some quantized operators may have different numerics compared to a naive (dequant - float_op - quant) reference implementation. For weighted operators, such as conv and linear, we need to be able to specify custom reference modules and a mapping from the float modules +* QAT mapping: For weighted operators, we need to swap them with the Quantization Aware Training (QAT) versions that add fake quantization to the weights + +As an example, here is what fbgemm looks like: +| | fbgemm | +|-------------------------------------------|-----------------------------------------------------------------------| +| Quantization Scheme | activation: per tensor, weight: per tensor or per channel | +| Data Type | activation: quint8 (with qmin/qmax range restrictions), weight: qint8 | +| Quantized and Fused Operators and Mapping | e.g. torch.nn.Conv2d -> torch.ao.nn.quantized.reference.Conv2d | +| QAT Module Mapping | e.g. torch.nn.Conv2d -> torch.ao.nn.qat.Conv2d | + +Instead of hardcoding the fusion mappings, float to reference quantized module mappings, fusion patterns etc., we will derive everything from the BackendConfig throughout the code base. This allows PyTorch Quantization to work with all first-party (fbgemm and qnnpack) and third-party backends (TensorRT, executorch etc.) that may differ from native backends in different aspects. With the recent addition of xnnpack, integrated as part of the qnnpack backend in PyTorch, the BackendConfig is needed to define the new constraints required for xnnpack quantized operators. + +## Pattern Specification + +The operator patterns used in BackendConfig are float modules, functional operators and pytorch operators specified in reverse order: ``` operator = module_type | functional | torch op | native op | MatchAllNode Pattern = (operator, Pattern, Pattern, ...) | operator ``` -where the first item for Pattern is the operator we want to match, and the rest are the patterns for the arguments of the operator. -For example, pattern (nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) would match the following graph: +where the first item for each Pattern is the operator, and the rest are the patterns for the arguments of the operator. +For example, the pattern (nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) would match the following graph: ``` tensor_1 tensor_2 | | @@ -17,4 +41,137 @@ tensor_1 tensor_2 nn.ReLU ``` -we’ll match the last node as the anchor point of the match, and we can retrieve the whole graph by tracing back from the node, e.g. in the example above, we matched nn.ReLU node, then node.args[0] is the operator.add node. +During prepare and convert, we’ll match the last node, which will be the anchor point of the match, and we can retrieve the whole graph by tracing back from the node. E.g. in the example above, we matched the `nn.ReLU` node, and `node.args[0]` is the `operator.add` node. + +## BackendConfig Implementation + +The BackendConfig is comprised of a list of BackendPatternConfigs, each of which define the specifications and the requirements for an operator pattern. Here is an example usage: + +``` +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) + +weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + +def fuse_conv2d_relu(is_qat, relu, conv): + """Return a fused ConvReLU2d from individual conv and relu modules.""" + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + +# For quantizing Linear +linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + +# For fusing Conv2d + ReLU into ConvReLU2d +conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + +# For quantizing ConvReLU2d +fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) + +backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) +``` + +### Observer Insertion + +Relevant APIs: +* `set_observation_type` + +During the prepare phase, we insert observers (or QuantDeQuantStubs in the future) into the graph for this operator pattern based on the observation type, which specifies whether to use different observers for the inputs and the outputs of the pattern. For more detail, see `torch.ao.quantization.backend_config.ObservationType`. + +### Reference Quantized Patterns + +Relevant APIs: +* `set_root_module` +* `set_reference_quantized_module` + +During the convert phase, when we construct the reference quantized model, the root modules (e.g. `torch.nn.Linear` for `nni.LinearReLU` or `nniqat.LinearReLU`) will be swapped to the corresponding reference quantized modules (e.g. `torch.ao.nn.reference.Linear`). This allows custom backends to specify custom reference quantized module implementations to match the numerics of their lowered operators. Since this is a one-to-one mapping, both the root module and the reference quantized module must be specified in the same BackendPatternConfig in order for the conversion to take place. + +### Fusion + +Relevant APIs: +* `set_fuser_method` +* `set_fused_module` +* `_set_root_node_getter` +* `_set_extra_inputs_getter` + +As an optimization, operator patterns such as (`torch.nn.ReLU`, `torch.nn.Linear`) may be fused into `nni.LinearReLU`. This is performed during the prepare phase according to the function specified in `set_fuser_method`, which replaces the pattern with the fused module. During the convert phase, these fused modules (identified by `set_fused_module`) will then be converted to the reference quantized versions of the modules. + +In FX graph mode quantization, we replace the corresponding nodes in the graph using two helper functions set by the user: `root_node_getter`, which returns the root node (typically the weighted module in the pattern like `torch.nn.Linear`) to replace the matched pattern in the graph, and `extra_inputs_getter`, which returns a list of extra input arguments that will be appended to the existing arguments of the fused module (copied over from the root node). See [this snippet](https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6) for an example usage. + +### Data Type Restrictions + +Relevant APIs: +* `add_dtype_config` +* `set_dtype_configs` + +DTypeConfig specifies a set of supported data types for input/output/weight/bias along with the associated constraints, if any. There are two ways of specifying `input_dtype`, `output_dtype`, and `weight_dtype`, as simple `torch.dtype`s or as `DTypeWithConstraints`, e.g.: + +``` +import torch +from torch.ao.quantization.backend import DTypeConfig, DTypeWithConstraints + +dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + +dtype_config_with_constraints = DTypeConfig( + input_dtype=DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_min_lower_bound=2 ** -12, + ), + output_dtype=DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_min_lower_bound=2 ** -12, + ), + weight_dtype=DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-128, + quant_max_upper_bound=127, + scale_min_lower_bound=2 ** -12, + ), + bias_dtype=torch.float) +``` + +During the prepare phase of quantization, we will compare the data types specified in these DTypeConfigs to the ones specified in the matching QConfig for a given operator pattern. If the data types do not match (or the constraints are not satisfied) for all the DTypeConfigs specified for the operator pattern, then we will simply ignore the QConfig and skip quantizing this pattern. + +#### Quantization range + +The user's QConfig may specify `quant_min` and `quant_max`, which are min and max restrictions on the quantization values. Here we set the lower bound for the `quant_min` and then upper bound for the `quant_max` to represent the limits of the backend. If a QConfig exceeds these limits in either direction, it will be treated as violating this constraint. + +#### Scale range + +Similarly, the user's QConfig may specify a minimum value for the quantization scale (currently exposed as `eps` but will change in the future to better reflect the semantics). Here we set the lower bound for the `scale_min` to represent the limits of the backend. If a QConfig's min scale value falls below this limit, the QConfig will be treated as violating this constraint. Note that `scale_max_upper_bound` is currently not used, because there is no corresponding mechanism to enforce this on the observer yet. + +#### Fixed quantization parameters + +For ops with fixed quantization parameters such as `torch.nn.Sigmoid` or `torch.nn.Tanh`, the BackendConfig can specify the specific scale and zero point values as constraints on the input and output activations. The user's QConfigs for these ops must use `FixedQParamsObserver` or `FixedQParamsFakeQuantize` for their activations with matching scale and zero point values, otherwise these QConfigs will be ignored. diff --git a/torch/ao/quantization/backend_config/__init__.py b/torch/ao/quantization/backend_config/__init__.py index 6443b756f716c..9aba6d2e9853f 100644 --- a/torch/ao/quantization/backend_config/__init__.py +++ b/torch/ao/quantization/backend_config/__init__.py @@ -16,5 +16,6 @@ "BackendConfig", "BackendPatternConfig", "DTypeConfig", + "DTypeWithConstraints", "ObservationType", ] diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index bc6f678485fb6..3d95b8b38a38b 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -1,3 +1,4 @@ +import copy import operator import torch import torch.nn.functional as F @@ -7,23 +8,24 @@ import torch.nn.qat as nnqat import torch.nn.quantized._reference as nnqr from collections import namedtuple -from typing import List +from typing import Callable, Dict, List, Union from .backend_config import ( BackendPatternConfig, DTypeConfig, + DTypeWithConstraints, ObservationType, ) -from ..fake_quantize import FixedQParamsFakeQuantize from ..fuser_method_mappings import ( - reverse_sequential_wrapper2, - reverse2, - reverse3, + _reverse_sequential_wrapper2, + _reverse2, + _reverse3, fuse_conv_bn, fuse_conv_bn_relu, fuse_linear_bn, fuse_convtranspose_bn, ) -from ..qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER + +__all__: List[str] = [] # TODO: rename to be more explict, e.g. qat_conv_relu _ConvMetadata = namedtuple( @@ -48,6 +50,38 @@ nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + def _get_binary_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: binary_op_configs: List[BackendPatternConfig] = [] num_tensor_args_to_observation_type_mapping = { @@ -115,13 +149,13 @@ def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPattern linear_configs.append( BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU)) # linear relu, linear module + functional relu linear_configs.append( BackendPatternConfig((torch.nn.functional.relu, torch.nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU)) # 2.2 linear module + relu, fused module configs @@ -158,7 +192,7 @@ def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPattern linear_configs.append( BackendPatternConfig((nn.BatchNorm1d, nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_linear_bn)) + .set_fuser_method(_reverse2(fuse_linear_bn)) .set_fused_module(nni.LinearBn1d)) # 3.2 linear bn fused @@ -218,13 +252,13 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((torch.nn.ReLU, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # conv relu fusion, conv module + functional relu conv_configs.append( BackendPatternConfig((F.relu, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # 2.2 conv module + relu fused module configs # conv relu, fused module @@ -273,20 +307,20 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((convs.bn, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_conv_bn)) + .set_fuser_method(_reverse2(fuse_conv_bn)) .set_fused_module(convs.fused_conv_bn)) # conv + bn + relu module fusion conv_configs.append( BackendPatternConfig((nn.ReLU, (convs.bn, convs.root))) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse3(fuse_conv_bn_relu)) + .set_fuser_method(_reverse3(fuse_conv_bn_relu)) .set_fused_module(convs.fused_conv_bn_relu)) # conv + bn + relu functional fusion conv_configs.append( BackendPatternConfig((F.relu, (convs.bn, convs.root))) .set_dtype_configs(dtype_configs) # noqa: E131 .set_root_module(convs.root) - .set_fuser_method(reverse3(fuse_conv_bn_relu)) + .set_fuser_method(_reverse3(fuse_conv_bn_relu)) .set_fused_module(convs.fused_conv_bn_relu)) # TODO: we can add fusion for torch.relu as well @@ -330,7 +364,7 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((convs.bn, convs.transpose)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_convtranspose_bn)) + .set_fuser_method(_reverse2(fuse_convtranspose_bn)) .set_root_module(convs.transpose) .set_reference_quantized_module(convs.transpose_reference)) @@ -393,21 +427,45 @@ def _get_default_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPat ) return configs +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: List[DTypeConfig], + constraints: DTypeWithConstraints, +) -> List[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. + """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [dc.input_dtype_with_constraints, dc.output_dtype_with_constraints]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError("scale_min_lower_bound is invalid for fixed qparams ops: %s" % dtype_config) + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError("scale_max_upper_bound is invalid for fixed qparams ops: %s" % dtype_config) + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + def _get_fixed_qparams_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: fixed_qparams_op_configs = [] - for fixed_qparam_op, output_observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs(dtype_configs, constraints) fixed_qparams_op_configs.append( - # TODO: The _overwrite_output keys are temporary, since we don't want to put observer - # in the configs we expect that it's provided by user - # What we want to put here is the requirement on observers, in this case dtype, - # quant_min, quant_max etc., but we need to first move all configs to - # backend_config_dict to do that, we'll remove these keys after we fully migrated - # everything to use backend_config_dict BackendPatternConfig(fixed_qparam_op) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 - .set_dtype_configs(dtype_configs) - ._set_overwrite_output_fake_quantize(FixedQParamsFakeQuantize.with_args(observer=output_observer)) - ._set_overwrite_output_observer(output_observer)) + .set_dtype_configs(new_dtype_configs)) return fixed_qparams_op_configs def _get_share_qparams_op_configs(dtype_configs): @@ -497,13 +555,13 @@ def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConf bn_configs.append( BackendPatternConfig((torch.nn.ReLU, bn)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(fused_bn)) + .set_fuser_method(_reverse_sequential_wrapper2(fused_bn)) .set_fused_module(fused_bn)) # bn module + F.relu fusion config bn_configs.append( BackendPatternConfig((torch.nn.functional.relu, bn)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(bn_to_fused_bn[bn])) + .set_fuser_method(_reverse_sequential_wrapper2(bn_to_fused_bn[bn])) .set_fused_module(fused_bn)) bn_configs.append( BackendPatternConfig(bn) @@ -557,10 +615,3 @@ def _get_embedding_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendP .set_reference_quantized_module(ref_embedding_op) ._set_input_output_observed(False)) # This is temporary, and will be removed soon return embedding_op_configs - -__all__ = [ - "_get_binary_op_configs", - "_get_linear_configs", - "_get_conv_configs", - "_get_share_qparams_op_configs", -] diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index e0d7e0b9d7428..4b3d4d3aa8130 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union import torch -from torch.ao.quantization.observer import _PartialWrapper from torch.ao.quantization.utils import Pattern from enum import Enum @@ -42,8 +41,6 @@ NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed" -OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize" -OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer" # TODO: maybe rename this to something that's not related to observer @@ -69,14 +66,17 @@ class ObservationType(Enum): @dataclass class DTypeWithConstraints: """ - Config for specifying additional constraints for a given dtype, such as quantization value - ranges and scale value ranges, to be used in :class:`~torch.ao.quantization.backend_config.DTypeConfig`. + Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. """ dtype: Optional[torch.dtype] = None quant_min_lower_bound: Union[int, float, None] = None quant_max_upper_bound: Union[int, float, None] = None scale_min_lower_bound: Union[int, float, None] = None scale_max_upper_bound: Union[int, float, None] = None + scale_exact_match: Optional[float] = None + zero_point_exact_match: Optional[int] = None @dataclass @@ -228,31 +228,49 @@ class BackendConfig: Example usage:: import torch - from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType - from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 + from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, + ) weighted_int8_dtype_config = DTypeConfig( input_dtype=torch.quint8, output_dtype=torch.quint8, weight_dtype=torch.qint8, - bias_type=torch.float) + bias_dtype=torch.float) + def fuse_conv2d_relu(is_qat, relu, conv): + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + + # For quantizing Linear linear_config = BackendPatternConfig(torch.nn.Linear) \ .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ .add_dtype_config(weighted_int8_dtype_config) \ .set_root_module(torch.nn.Linear) \ - .set_qat_module(torch.nn.qat.Linear) \ - .set_reference_quantized_module(torch.nn.quantized._reference.Linear) + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + # For fusing Conv2d + ReLU into ConvReLU2d conv_relu_config = BackendPatternConfig((torch.nn.ReLU, torch.nn.Conv2d)) \ .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ .add_dtype_config(weighted_int8_dtype_config) \ - .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \ - .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d)) + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + + # For quantizing ConvReLU2d + fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) backend_config = BackendConfig("my_backend") \ .set_backend_pattern_config(linear_config) \ - .set_backend_pattern_config(conv_relu_config) + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) """ def __init__(self, name: str = ""): @@ -336,14 +354,20 @@ def __init__(self, pattern: Pattern): self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {} self._input_type_to_index: Dict[str, int] = {} self._input_output_observed: Optional[bool] = None - self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None - self._overwrite_output_observer: Optional[_PartialWrapper] = None def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig: """ - Set how observers should be inserted for this pattern. - See :class:`~torch.ao.quantization.backend_config.ObservationType` for details + Set how observers should be inserted in the graph for this pattern. + There are two observation types: + + `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance will be + different from the input. This is the most common observation type. + + `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the same as the input. + This is useful for operators like `cat`. + Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs with + observers (and fake quantizes) attached instead of observers themselves. """ self.observation_type = observation_type return self @@ -395,6 +419,11 @@ def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatter def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: """ Set the function that specifies how to fuse the pattern for this pattern. + + The first argument of this function should be `is_qat`, and the rest of the arguments + should be the items in the tuple pattern, e.g. (`torch.nn.ReLU`, `torch.nn.Linear`) + will have a function with three arguments, `is_qat`, `relu`, and `linear`. + The return value of this function should be the resulting fused module. """ self.fuser_method = fuser_method return self @@ -420,14 +449,6 @@ def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatt self._input_output_observed = input_output_observed return self - def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig: - self._overwrite_output_fake_quantize = overwrite_output_fake_quantize - return self - - def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig: - self._overwrite_output_observer = overwrite_output_observer - return self - @classmethod def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig: """ @@ -474,8 +495,6 @@ def _get_dtype_config(obj: Any) -> DTypeConfig: backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {})) conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})) conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None)) - conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None)) - conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None)) return conf def to_dict(self) -> Dict[str, Any]: @@ -508,8 +527,4 @@ def to_dict(self) -> Dict[str, Any]: backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index if self._input_output_observed is not None: backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed - if self._overwrite_output_fake_quantize is not None: - backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize - if self._overwrite_output_observer is not None: - backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer return backend_pattern_config_dict diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py index 4c0f2a48b552e..fcccec6c2225f 100644 --- a/torch/ao/quantization/backend_config/executorch.py +++ b/torch/ao/quantization/backend_config/executorch.py @@ -5,9 +5,18 @@ import torch.nn as nn import torch.nn.qat as nnqat import torch.nn.quantized._reference as nnqr -from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, +) +from .qnnpack import ( + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_default_op_qint8_symmetric_dtype_config +) from ._common_operator_config_utils import _Conv2dMetadata -from ..fuser_method_mappings import reverse_sequential_wrapper2 +from ..fuser_method_mappings import _reverse_sequential_wrapper2 __all__ = [ @@ -47,7 +56,6 @@ is_dynamic=True, ) - # ============================= # | BACKEND PATTERN CONFIGS | # ============================= @@ -58,6 +66,7 @@ def _get_linear_configs() -> List[BackendPatternConfig]: """ observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, executorch_weighted_op_int8_dtype_config, executorch_default_dynamic_int8_dtype_config, executorch_default_dynamic_float16_dtype_config, @@ -84,7 +93,10 @@ def _get_conv_configs() -> List[BackendPatternConfig]: Return all configs related to conv modules and ops. """ observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - dtype_configs = [executorch_weighted_op_int8_dtype_config] + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config + ] conv_configs = [] for convs in [_Conv2dMetadata]: # conv module @@ -105,13 +117,13 @@ def _get_conv_configs() -> List[BackendPatternConfig]: conv_configs.append( BackendPatternConfig((torch.nn.ReLU, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # conv module + functional relu conv_configs.append( BackendPatternConfig((F.relu, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # fused conv relu module conv_configs.append( @@ -137,7 +149,10 @@ def _get_binary_ops_configs() -> List[BackendPatternConfig]: """ Return all configs related to binary ops. """ - dtype_configs = [executorch_weighted_op_int8_dtype_config] + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config + ] num_tensor_args_to_observation_type_mapping = { # TODO: this is not used right now since we have extra check in prepare # will need to change this to NO_OBSERVER later after we implemented @@ -165,7 +180,10 @@ def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]: observer_0 - avgpool2d - observer_0 (same observer instance as input) """ observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT - dtype_configs = [executorch_default_op_quint8_dtype_config] + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config + ] share_qparams_ops = [ F.adaptive_avg_pool2d, F.relu, @@ -192,7 +210,10 @@ def _get_bn_configs() -> List[BackendPatternConfig]: Return all configs related to batchnorm. """ observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - dtype_configs = [executorch_default_op_quint8_dtype_config] + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config + ] bn_configs = [] bn_configs.append( BackendPatternConfig(nn.BatchNorm2d) @@ -200,6 +221,17 @@ def _get_bn_configs() -> List[BackendPatternConfig]: .set_dtype_configs(dtype_configs)) return bn_configs +def _get_cat_configs() -> List[BackendPatternConfig]: + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config + ] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs)) + return cat_configs # ===================== # | BACKEND CONFIGS | @@ -214,4 +246,5 @@ def get_executorch_backend_config() -> BackendConfig: .set_backend_pattern_configs(_get_conv_configs()) \ .set_backend_pattern_configs(_get_binary_ops_configs()) \ .set_backend_pattern_configs(_get_share_qparams_ops_configs()) \ - .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_bn_configs()) \ + .set_backend_pattern_configs(_get_cat_configs()) diff --git a/torch/ao/quantization/backend_config/fbgemm.py b/torch/ao/quantization/backend_config/fbgemm.py index de38272b00e9f..d2bc87879c44f 100644 --- a/torch/ao/quantization/backend_config/fbgemm.py +++ b/torch/ao/quantization/backend_config/fbgemm.py @@ -13,6 +13,9 @@ ) from .backend_config import BackendConfig, DTypeConfig +__all__ = [ + "get_fbgemm_backend_config", +] # =================== # | DTYPE CONFIGS | @@ -108,7 +111,3 @@ def get_fbgemm_backend_config() -> BackendConfig: .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) - -__all__ = [ - "get_fbgemm_backend_config", -] diff --git a/torch/ao/quantization/backend_config/native.py b/torch/ao/quantization/backend_config/native.py index f584aff82a12b..ad5a12e6053b1 100644 --- a/torch/ao/quantization/backend_config/native.py +++ b/torch/ao/quantization/backend_config/native.py @@ -14,6 +14,19 @@ ) from .backend_config import BackendConfig, DTypeConfig +__all__ = [ + "get_test_only_legacy_native_backend_config", + "default_op_quint8_dtype_config", + "default_op_fp16_dtype_config", + "default_dynamic_int8_dtype_config", + "default_dynamic_float16_dtype_config", + "input_output_only_quint8_dtype_config", + "weight_only_quint8_dtype_config", + "weight_only_quint4x2_dtype_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_test_only_legacy_native_backend_config_dict", +] # =================== # | DTYPE CONFIGS | @@ -182,17 +195,3 @@ def get_test_only_legacy_native_backend_config_dict(): fp16 ops in dictionary form. """ return get_test_only_legacy_native_backend_config().to_dict() - -__all__ = [ - "get_test_only_legacy_native_backend_config", - "default_op_quint8_dtype_config", - "default_op_fp16_dtype_config", - "default_dynamic_int8_dtype_config", - "default_dynamic_float16_dtype_config", - "input_output_only_quint8_dtype_config", - "weight_only_quint8_dtype_config", - "weight_only_quint4x2_dtype_config", - "get_native_backend_config", - "get_native_backend_config_dict", - "get_test_only_legacy_native_backend_config_dict", -] diff --git a/torch/ao/quantization/backend_config/qnnpack.py b/torch/ao/quantization/backend_config/qnnpack.py index 391acf55614af..772a25c655744 100644 --- a/torch/ao/quantization/backend_config/qnnpack.py +++ b/torch/ao/quantization/backend_config/qnnpack.py @@ -13,6 +13,9 @@ ) from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints +__all__ = [ + "get_qnnpack_backend_config", +] # =================== # | DTYPE CONFIGS | @@ -155,7 +158,3 @@ def get_qnnpack_backend_config() -> BackendConfig: .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) - -__all__ = [ - "get_qnnpack_backend_config", -] diff --git a/torch/ao/quantization/backend_config/tensorrt.py b/torch/ao/quantization/backend_config/tensorrt.py index 9b6fb39e06160..a617f765adf77 100644 --- a/torch/ao/quantization/backend_config/tensorrt.py +++ b/torch/ao/quantization/backend_config/tensorrt.py @@ -12,6 +12,11 @@ _get_share_qparams_op_configs, ) +__all__ = [ + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", +] + def get_tensorrt_backend_config() -> BackendConfig: """ Return the `BackendConfig` for the TensorRT backend. @@ -69,8 +74,3 @@ def get_tensorrt_backend_config_dict(): Return the `BackendConfig` for the TensorRT backend in dictionary form. """ return get_tensorrt_backend_config().to_dict() - -__all__ = [ - "get_tensorrt_backend_config", - "get_tensorrt_backend_config_dict", -] diff --git a/torch/ao/quantization/backend_config/x86.py b/torch/ao/quantization/backend_config/x86.py index ce92ed9bc42b2..78a3f76187821 100644 --- a/torch/ao/quantization/backend_config/x86.py +++ b/torch/ao/quantization/backend_config/x86.py @@ -13,6 +13,9 @@ ) from .backend_config import BackendConfig, DTypeConfig +__all__ = [ + "get_x86_backend_config", +] # =================== # | DTYPE CONFIGS | @@ -105,7 +108,3 @@ def get_x86_backend_config() -> BackendConfig: .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) - -__all__ = [ - "get_x86_backend_config", -] diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py index eb7296e38f60f..6cf37af0cf934 100644 --- a/torch/ao/quantization/fuse_modules.py +++ b/torch/ao/quantization/fuse_modules.py @@ -160,7 +160,7 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo modules_to_fuse, is_qat=False, inplace=inplace, - fuser_func=fuse_known_modules, + fuser_func=fuser_func, fuse_custom_config_dict=None) def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): @@ -171,5 +171,5 @@ def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_know modules_to_fuse, is_qat=True, inplace=inplace, - fuser_func=fuse_known_modules, + fuser_func=fuser_func, fuse_custom_config_dict=None) diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 2e39f87321d41..db4cc9a04d767 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -10,13 +10,7 @@ "fuse_conv_bn_relu", "fuse_linear_bn", "fuse_convtranspose_bn", - "sequential_wrapper2", "get_fuser_method", - "reverse_sequential_wrapper2", - "reverse2", - "reverse3", - "DEFAULT_PATTERN_TO_FUSER_METHOD", - "get_valid_patterns", "get_fuser_method_new", ] @@ -156,7 +150,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn): else: return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) -def sequential_wrapper2(sequential): +def _sequential_wrapper2(sequential): """ Given a sequential class for two modules, return a function that takes is_qat, and then two modules as argument, that ignores the is_qat flag and always returns the sequential that combines the two input modules @@ -165,20 +159,20 @@ def fuser_method(is_qat, m1, m2): return sequential(m1, m2) return fuser_method -DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = { +_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = { (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, - (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d), - (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d), - (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d), + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, - (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU), - (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d), - (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d), + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, @@ -190,13 +184,13 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None): ''' if additional_fuser_method_mapping is None: additional_fuser_method_mapping = {} - all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, + all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping) fuser_method = all_mappings.get(op_list, None) assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list) return fuser_method -def reverse_sequential_wrapper2(sequential): +def _reverse_sequential_wrapper2(sequential): """ Given a sequential class for two modules, return a function that takes is_qat, and then two modules as argument, that ignores the is_qat flag and always returns the sequential that combines the two input modules, with @@ -206,37 +200,37 @@ def fuser_method(is_qat, m1, m2): return sequential(m2, m1) return fuser_method -def reverse2(f): +def _reverse2(f): def reversed(is_qat, x, y): return f(is_qat, y, x) return reversed -def reverse3(f): +def _reverse3(f): def reversed(is_qat, x, w): y, z = w return f(is_qat, z, y, x) return reversed -DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = { - (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu), - (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu), - (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu), - (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d), - (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d), - (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d), - (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn), - (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU), - (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d), - (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d), - (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn), - (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn), - (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn), +_DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = { + (nn.BatchNorm1d, nn.Conv1d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): _reverse3(fuse_conv_bn_relu), + (nn.BatchNorm2d, nn.Conv2d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): _reverse3(fuse_conv_bn_relu), + (nn.BatchNorm3d, nn.Conv3d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): _reverse3(fuse_conv_bn_relu), + (nn.ReLU, nn.Conv1d): _reverse_sequential_wrapper2(nni.ConvReLU1d), + (nn.ReLU, nn.Conv2d): _reverse_sequential_wrapper2(nni.ConvReLU2d), + (nn.ReLU, nn.Conv3d): _reverse_sequential_wrapper2(nni.ConvReLU3d), + (nn.BatchNorm1d, nn.Linear): _reverse2(fuse_linear_bn), + (nn.ReLU, nn.Linear): _reverse_sequential_wrapper2(nni.LinearReLU), + (nn.ReLU, nn.BatchNorm2d): _reverse_sequential_wrapper2(nni.BNReLU2d), + (nn.ReLU, nn.BatchNorm3d): _reverse_sequential_wrapper2(nni.BNReLU3d), + (nn.BatchNorm1d, nn.ConvTranspose1d): _reverse2(fuse_convtranspose_bn), + (nn.BatchNorm2d, nn.ConvTranspose2d): _reverse2(fuse_convtranspose_bn), + (nn.BatchNorm3d, nn.ConvTranspose3d): _reverse2(fuse_convtranspose_bn), } -def get_valid_patterns(op_pattern): +def _get_valid_patterns(op_pattern): """ Returns a list of valid patterns generated from the op_pattern, since MatchAllNode can match all types of nodes, @@ -261,7 +255,7 @@ def get_valid_patterns(op_pattern): if isinstance(op_pattern, (tuple, list)): sub_combs = [] for sub_pattern in op_pattern: - sub_combs.append(get_valid_patterns(sub_pattern)) + sub_combs.append(_get_valid_patterns(sub_pattern)) result = list(itertools.product(*sub_combs)) else: result = [op_pattern, MatchAllNode] @@ -274,9 +268,9 @@ def get_fuser_method_new( Would like to implement this first and have a separate PR for deprecation """ if fuser_method_mapping is None: - fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD + fuser_method_mapping = _DEFAULT_PATTERN_TO_FUSER_METHOD - op_patterns = get_valid_patterns(op_pattern) + op_patterns = _get_valid_patterns(op_pattern) fuser_method = None for op_pattern in op_patterns: fuser_method = fuser_method_mapping.get(op_pattern, None) diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md index 389a5e428627d..7816247dc3291 100644 --- a/torch/ao/quantization/fx/README.md +++ b/torch/ao/quantization/fx/README.md @@ -81,7 +81,7 @@ What we did in this example are: ``` BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) ._set_root_node_getter(my_root_node_getter) ._set_extra_inputs_getter(my_extra_inputs_getter) ``` @@ -169,15 +169,20 @@ input - qat_linear_relu - output 'pattern': nnqat.LinearReLU, 'dtype_configs': [{input: torch.quint8, output: torch.quint8, weight: torch.qint8}], } +``` + +step 1: assign qconfig to each op (please see [TODO: link] for details) -# step 1: assign qconfig to each op (please see [TODO: link] for details) -# step 2: determine which qconfigs are valid according to the backend configuration (please see [TODO: link] for details) +step 2: determine which qconfigs are valid according to the backend configuration (please see [TODO: link] for details) (we should add a warning here) -# step 3: for subgraphs with validated qconfigs, insert qstub/dqstub/qdqstub needed -# To talk about what happens in this step, let’s first define some terms. Let’s view the computation graph we showed about as a Graph consists of nodes and edges, each node here will be an FX Node that represents some computation, for example linear, and each edge will be a connection between two nodes, and each edge can both be viewed as the output of the previous Node or the input of the next Node. -# The end goal for this step is to insert QDQStubs at edges so that we produce a graph of quantized reference model when each QDQStub represents a quantize operator followed by a dequantize operator. +step 3: for subgraphs with validated qconfigs, insert qstub/dqstub/qdqstub needed + +To talk about what happens in this step, let’s first define some terms. Let’s view the computation graph we showed above as a Graph consists of nodes and edges, each node here will be an FX Node that represents some computation, for example linear, and each edge will be a connection between two nodes, and each edge can both be viewed as the output of the previous Node or the input of the next Node. +The end goal for this step is to insert QDQStubs at edges so that we produce a graph of quantized reference model when each QDQStub represents a quantize operator followed by a dequantize operator. + +``` # graph 2: input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - output | @@ -185,11 +190,13 @@ input - QDQStub1 (FakeQuantize) - qat_linear_relu - QDQStub2 (FakeQuantize) - ou (need to be updated with QDQStub + FakeQuantize) | weight +``` Note: weight + FakeQuantize is a part of qat_linear_relu -# The overall logic to insert QDQStub1 and QDQStub2 inplace is the following: -# 0. For each node in the original graph, we compute the target_dtype for input and output for it based on qconfig, for graph1, configured with qconfig_mapping, we have: -# node_name_to_target_dtype = +The overall logic to insert QDQStub1 and QDQStub2 inplace is the following: +0. For each node in the original graph, we compute the target_dtype for input and output for it based on qconfig, for graph1, configured with qconfig_mapping, we have: +``` +# node_name_to_target_dtype_info = # { # # this is placeholder node in FX Graph # “input” : {“input_activation”: torch.float32, “output_activation”: torch.float32}, @@ -197,35 +204,44 @@ Note: weight + FakeQuantize is a part of qat_linear_relu # # this is the return node in FX Graph # “output”: {“input_activation”: torch.float32, “output_activation”: torch.float32} # } -# Note: this map is generated before we insert qdqstub to graph1, and will not change in the process. -# -# 1. Inserting QDQStub1 (for input of qat_linear_relu) -# We need to look at the edge between `input` Node and `qat_linear_relu` Node here, we need to decide if we need to insert a -# QDQStub at this edge, which could serve as an input argument for `qat_linear_relu` Node (and also output for `input` Node) -# The way we decide if we want to insert QDQStub here is to figure out -# (1). The target dtype for output of `input` Node, which is torch.float32 -# (2). The target dtype for input of `qat_linear_relu` Node, which is torch.quint8 -# There is a mismatch here and (2) is a quantized dtype, so we need to insert QDQStub at the edge. -# We also need to attach observer/fakequant module to the QDQStub we inserted here. -# 2. Insert QDQStub2 (for output of qat_linear_relu) -# The logic for inserting QDQStub for output is much easier, since we assume all modules/functions in the graph produce fp32 output -# by default (we can have additional checks and extend this to work for other dtypes after we have type inference ready), -# we just need to look at the target output dtype for qat_linear_relu Node, and if it is a quantized dtype (quint8, qint8, float16), -# we would insert a QDQStub here. -# -# Questions: How to avoid inserting duplicate QDQStubs? -# e.g. when we have a single input being used by multiple ops: -# input — linear1 —- -# \--- linear2 — -# how do we make sure we only insert one QDQStub for input of both linear1 and linear2? -# input - QDQStub — linear1 - -# \ —- linear2 - -# -# The way we do it right now is before we insert QDQStub, we look at all users of `input` Node here and make sure there is no QDQStubs -# with the same target_dtype, that is, if we already inserted a QDQStub with dtype quint8 for linear1, and linear2 is also connected to it, if we request another QDQStub with dtype quint8 when processing linear2 Node, we’ll detect that the desired QDQStub already exists and do nothing - -# Question: What is the logic for keeping output to be float32? -# Let’s say the output of `qat_linear_relu` Node is configured as float32, both in qconfig_mapping and backend_config: +``` +Note: this map is generated before we insert qdqstub to graph1, and will not change in the process. + +1. Inserting QDQStub1 (for input of qat_linear_relu) + We need to look at the edge between `input` Node and `qat_linear_relu` Node here, we need to decide if we need to insert a + QDQStub at this edge, which could serve as an input argument for `qat_linear_relu` Node (and also output for `input` Node) + The way we decide if we want to insert QDQStub here is to figure out + + (1). The target dtype for output of `input` Node, which is torch.float32 + + (2). The target dtype for input of `qat_linear_relu` Node, which is torch.quint8 + There is a mismatch here and (2) is a quantized dtype, so we need to insert QDQStub at the edge. + + We also need to attach observer/fakequant module to the QDQStub we inserted here. +2. Insert QDQStub2 (for output of qat_linear_relu) + The logic for inserting QDQStub for output is much easier, since we assume all modules/functions in the graph produce fp32 output + by default (we can have additional checks and extend this to work for other dtypes after we have type inference ready), + we just need to look at the target output dtype for qat_linear_relu Node, and if it is a quantized dtype (quint8, qint8, float16), + we would insert a QDQStub here. + +Questions: How to avoid inserting duplicate QDQStubs? +e.g. when we have a single input being used by multiple ops: +``` +input — linear1 —- + \--- linear2 — +``` +how do we make sure we only insert one QDQStub for input of both linear1 and linear2? +``` +input - QDQStub — linear1 - + \ —- linear2 - +``` + +The way we do it right now is before we insert QDQStub, we look at all users of `input` Node here and make sure there is no QDQStubs +with the same target_dtype, that is, if we already inserted a QDQStub with dtype quint8 for linear1, and linear2 is also connected to it, if we request another QDQStub with dtype quint8 when processing linear2 Node, we’ll detect that the desired QDQStub already exists and do nothing + +Question: What is the logic for keeping output to be float32? +Let’s say the output of `qat_linear_relu` Node is configured as float32, both in qconfig_mapping and backend_config: +``` # qconfig_mapping (simplified, shown as dict) {'qat_linear_relu': QConfig( weight=MinMaxObserver.with_args(dtype=torch.qint8), @@ -238,33 +254,44 @@ Note: weight + FakeQuantize is a part of qat_linear_relu 'pattern': nnqat.LinearReLU, 'dtype_configs': [{input: torch.quint8, output: torch.float32, weight: torch.qint8}], } - -# What we’ll do here is when we are trying to insert output QDQStub for `qat_linear_relu`, we look at the target output dtype for this node (node_name_to_target_dtype[“qat_linear_relu”][“output_activation”], and find that it is float, which is not a quantized dtype, so -# will do nothing here. -# Note that this does not prevent other operators following `qat_linear_relu` to insert a QDQStub at the output of `qat_linear_relu`, since we are dealing with an `edge` of the graph here, and an `edge` is connected to two nodes, which means -# the output of `qat_linear_relu` will also be the input of a node following `qat_linear_relu`. ``` +What we’ll do here is when we are trying to insert output QDQStub for `qat_linear_relu`, we look at the target output dtype for this node (node_name_to_target_dtype_info[“qat_linear_relu”][“output_activation”], and find that it is float, which is not a quantized dtype, so +will do nothing here. +Note that this does not prevent other operators following `qat_linear_relu` to insert a QDQStub at the output of `qat_linear_relu`, since we are dealing with an `edge` of the graph here, and an `edge` is connected to two nodes, which means +the output of `qat_linear_relu` will also be the input of a node following `qat_linear_relu`. + `backend_config` configurations used in this step: ``` BackendConfig(nniqat.LinearReLU) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFFERENT_OBSERVER_AS_INPUT) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs([ DTypeConfig(input_dtype=torch.quint8, output_dtype = torch.quint8, weight_dtype = torch.qint8, bias_dtype = torch.float32)] ) ``` Pattern in this case is the same as before, it defines the pattern for the subgraph we are dealing with + `set_observation_type`: sets the observation type for the patter, currently only two types: + +`OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` means the output observer instance will be different from the input, which is the most common type of observer placement. + +`OUTPUT_SHARE_OBSERVER_WITH_INPUT` means the output observer is shared with input, they will be the same instance. This is useful for operators like cat. + +`set_dtype_configs`: sets a list of supported (activation, weight, bias, etc.) dtype combinations for qconfigs for the pattern. Note that we represent different modes of quantization (static/dynamic/`weight_only`) purely through this combination, for example, fbgemm static quantization can be represented as: ``` -OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT means the output observer instance will be different from the input, which is the most common type of observer placement. -OUTPUT_SHARE_OBSERVER_WITH_INPUT means the output observer is shared with input, they will be the same instance. This is useful for operators like cat. +{ + "input_activation": torch.quint8, + "weight": torch.qint8, + "output_activation": torch.quint8 +} ``` -`set_dtype_configs`: sets a list of supported (activation, weight, bias, etc.) dtype combinations for qconfigs for the pattern. Note that we represent different modes of quantization (static/dynamic/`weight_only`) purely through this combination, for example, fbgemm static quantization can be represented as: {"`input_activation`": torch.quint8, "weight": torch.qint8, "`output_activation`": torch.quint8} Note: the dtype config will be used to configure the support for dynamic quantization as well + Note: we may extend this to support more fine grained configurations of args, kwargs, attributes and outputs in the future -Note: we are referring to observer here, which is an implementation detail, we can change this to talk about quantization parameters instead, e.g. `QParamsType.OUTPUT_USE_DIFFERENT_QPARAMS_AS_INPUT and QParamsType.OUTPUT_USE_SAME_QPARAMS_AS_INPUT` + +Note: we are referring to observer here, which is an implementation detail, we can change this to talk about quantization parameters instead, e.g. `QParamsType.OUTPUT_USE_DIFFERENT_QPARAMS_AS_INPUT` and `QParamsType.OUTPUT_USE_SAME_QPARAMS_AS_INPUT` ### 2. Calibration/Training After we insert observers, we run the model to calibrate observers or to fine tune. This step is identical to eager mode quantization. After that the observer/fakequantize modules contain sufficient information to determine quantization parameters according to the observed data. @@ -292,7 +319,9 @@ def forward(self, x): ``` After we insert observers, we’ll need to convert the model to a reference quantized model. Reference quantized model is a model that uses reference patterns to represent quantized operators, this serves as the standard interface for quantized operators between PyTorch quantization and backend lowering passes. For more details, please take a look at this [RFC](https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md). This pass is pretty straightforward, what we do is: + (1). for each QDQStub (attached with Observer for FakeQuantize modules) in the graph, we'll convert it to calls to quantize and dequantize functions based on the attributes of attached Observer and FakeQuantize modules (e.g. qscheme, dtype etc.) + (2). for weighted modules like linear/conv, we convert them to corresponding reference quantized module. Example: @@ -319,9 +348,13 @@ input - quantize - dequantize - reference_linear_relu - quantize - dequantize - Note: weight + quantize + dequantize is a part of reference_linear_relu module To decide which quantize node we want to use, we’ll look at: + (1). dtype of attached Observer/FakeQuantize module + (2). qscheme of attached Observer/FakeQuantize module + (3). (optionally) other attributes of attached Observer/FakeQuantize module + The quantize operator we can choose from right now are: (quantize_per_tensor, quantize_per_channel, to, quantize_per_tensor_dynamic) ``` @@ -329,7 +362,7 @@ backend_config configurations used in this step: BackendConfig(nniqat.LinearReLU) .set_root_module(nn.Linear) .set_reference_quantized_module_for_root(nnqr.Linear) - .set_fused_module(nni.Linear) + .set_fused_module(nni.LinearReLU) ``` Pattern in this case is the same as before, it defines the pattern for the subgraph we are dealing with @@ -338,8 +371,9 @@ Pattern in this case is the same as before, it defines the pattern for the subgr `set_reference_quantized_module_for_root`: Sets the corresponding reference quantized module class for root module class, e.g. when root_module is nn.Linear, this will be nn.quantized.reference.Linear, used to swap the root module to be a reference quantized module. -Note: we are only swapping `root_module` here, for example, in the current example, the original module is nniqat.LinearReLU, when we are converting weight modules(step (2)), we first convert nniqat.LinearReLU to a float module, in this case, the fused LinearReLU module: nni.LinearReLU, and then swap the root_module (nn.Linear) with reference quantized module (nnqr.Linear), so we end up with a nni.LinearReLU module, which is a sequential module of a nnqr.Linear and nn.ReLU. -Basically, the corresponding reference quantized module for both nniqat.LinearReLU and nni.LinearReLU would be a nni.LinearReLU sequential module (originally nn.Linear + nn.ReLU) with nn.Linear being replaced by nnqr.Linear: nni.LinearReLU(nnqr.Linear, nn.ReLU). +Note: we are only swapping `root_module` here, for example, in the current example, the original module is `nniqat.LinearReLU`, when we are converting weight modules(step (2)), we first convert `nniqat.LinearReLU` to a float module, in this case, the fused LinearReLU module: `nni.LinearReLU`, and then swap the root_module (`nn.Linear`) with reference quantized module (`nnqr.Linear`), so we end up with a `nni.LinearReLU` module, which is a sequential module of a `nnqr.Linear` and `nn.ReLU`. + +Basically, the corresponding reference quantized module for both `nniqat.LinearReLU` and `nni.LinearReLU` would be a `nni.LinearReLU` Sequential module (originally `nn.Linear` + `nn.ReLU`) with `nn.Linear` being replaced by `nnqr.Linear`: `nni.LinearReLU(nnqr.Linear, nn.ReLU)`. `set_fused_module`: This is the corresponding fused module class for the pattern, used to identify fused modules that needs to be converted to reference quantized module @@ -359,43 +393,33 @@ def forward(self, x): ``` Currently, PyTorch has native quantized backends: fbgemm and qnnpack, so we need a lowering pass to lower the reference quantized model to a model that is using native quantized operators in PyTorch. What this pass did is -* Recognize the reference patterns like: "dequantize - `float_op` - quantize" in the graph and replace them with the quantized modules (under torch.nn.quantized namespace) or operators (under torch.ops.quantized namespace, or torch namespace) + +1. Recognize the reference patterns like: "dequantize - `float_op` - quantize" in the graph and replace them with the quantized modules (under torch.nn.quantized namespace) or operators (under torch.ops.quantized namespace, or torch namespace) In general there are three types of patterns: -** Static quantization: "dequantize - `float_op` - `quantize_per_tensor`" -** Dynamic quantization: "`quantize_per_tensor_dynamic` - dequantize - `float_op`" -** Weight only quantization: + +* Static quantization: ``` - Input - float_op - output - weight - quantize_per_tensor - dequantize / +dequantize -> float_op -> quantize_per_tensor ``` -* Prepack and fold the weights for quantized linear and quantized conv operator -* The lowering pass is also going to keep some patterns for quantized operators unfused, since user may explicitly request some operators to stay in float by configuring the qconfig to be None -There are no configurations related to lowering in `backend_config` since it is backend developer’s responsibility to implement lowering pass and each of the backend developers may have their own configurations. So from end to end, `backend_config` and together with qconfig_mapping controls what Reference Quantized Model is produced by FX Graph Mode Quantization, not lowered model. - -However, for some operator based backends, like the current pytorch native backends including fbgemm and qnnpack. We could interpret `backend_config` in terms of configurations for operators as well. e.g. configuring `input_dtype`=quint8, `weight_dtype`=qint8, `output_dtype`=torch.quint8 for nn.Linear is saying that the quantized linear will take a quint8 activation and qint8 weight as input and outputs a quint8 activation. But there is no guarantee that this interpretation will always work in the future, especially when we add new flavors of quantized operators. +* Dynamic quantization: +``` +quantize_per_tensor_dynamic -> dequantize -> float_op +``` -## Extensibility -Different backend or kernel libraries may have different support for quantization. They may have different quantized operators, and the quantized operators might work for Tensors with different dtypes, the observers may need to be placed in different places. To make quantization work for different backends, and allow maximum flexibility, we also strived to make all the parts of the flow configurable with backend_config. +* Weight only quantization: +``` + input - float_op - output + weight - quantize_per_tensor - dequantize / +``` -backend_config configures quantization behavior in terms of operator patterns. We need to define a operator pattern and specify what are the supported dtypes for input/output/weight/bias for the pattern, and also specify the qat modules, reference modules etc. for the pattern, which will be used in module swapping during the quantization passes. +2. Prepack and fold the weights for quantized linear and quantized conv operator +3. The lowering pass is also going to keep some patterns for quantized operators unfused, since user may explicitly request some operators to stay in float by configuring the qconfig to be None -Quantized Backends can have different support in the following aspects: -* Quantization Scheme (symmetric vs asymmetric, per-channel vs per-tensor) -* Data Type (float32, float16, int8, uint8, bfloat16, etc) for input/output/weight/bias -* Quantized (and Fused) Operators and Mapping The quantized operators supported by the backend. For example: quantized conv2d, quantized linear etc. Some quantized operators may have different numerics compared to a naive (dequant - float_op - quant) implementation For weighted operators (conv and linear) we need to define a reference module and a mapping -* QAT Module Mapping For modules with weights, e.g. Conv2d and Linear, we need to swap them with qat (quantization aware training) module that adds fake quantization to the weights +There are no configurations related to lowering in `backend_config` since it is backend developer’s responsibility to implement lowering pass and each of the backend developers may have their own configurations. So from end to end, `backend_config` and together with qconfig_mapping controls what Reference Quantized Model is produced by FX Graph Mode Quantization, not lowered model. -As an example, here is what fbgemm looks like: -+-------------------------------------------+-----------------------------------------------------------------------+ -| | fbgemm | -|-------------------------------------------|-----------------------------------------------------------------------| -| Quantization Scheme | activation: per tensor, weight: per tensor or per channel | -| Data Type | activation: quint8 (with qmin/qmax range restrictions), weight: qint8 | -| Quantized and Fused Operators and Mapping | e.g. nn.Conv2d -> torch.ao.nn.quantized.reference.Conv2d | -| QAT Module Mapping | nn.Conv -> torch.ao.nn.qat.Conv2d | -+-------------------------------------------+-----------------------------------------------------------------------+ +However, for some operator based backends, like the current pytorch native backends including fbgemm and qnnpack. We could interpret `backend_config` in terms of configurations for operators as well. e.g. configuring `input_dtype=quint8`, `weight_dtype=qint8`, `output_dtype=torch.quint8` for nn.Linear is saying that the quantized linear will take a `quint8` activation and `qint8` weight as input and outputs a `quint8` activation. But there is no guarantee that this interpretation will always work in the future, especially when we add new flavors of quantized operators. -So instead of hardcoding the fusion mappings, float to quantized module mappings, fusion patterns etc. we will derive everything through `backend_config` throughout the code base. This allows PyTorch Quantization to work for all first-party or third-party backends that may differ from native backends in different aspects. +## Extensibility -For use cases, we will use TensorRT as an example use case and have a tutorial talking about `backend_config`, pytorch native backends fbgemm and qnnpack will be using this to define their behaviors as well, especially with the recent addition of xnnpack (integrated as a part of qnnpack backend in pytorch), the `backend_config` api is needed to define the new constraints from xnnpack. +FX graph mode quantization can be extended to work with different backends, which may have different sets of supported quantized operator patterns and different requirements for each pattern. For more detail, please refer to the [BackendConfig README](/torch/ao/quantization/backend_config/README.md). diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py new file mode 100644 index 0000000000000..a6f5ad7a3d0b9 --- /dev/null +++ b/torch/ao/quantization/fx/_decomposed.py @@ -0,0 +1,326 @@ +import torch +from torch.library import Library, impl +from torch.ao.quantization import MinMaxObserver +from typing import Tuple + +# Note: decomposed means decomposed quantized tensor, using decomposed so that the +# name is not too long +quantized_decomposed_lib = Library("quantized_decomposed", "DEF") + +_DTYPE_TO_QVALUE_BOUNDS = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (-(2**31), 2**31 - 1) +} + +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + + assert quant_min >= quant_min_lower_bound, \ + "quant_min out of bound for dtype, " \ + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + + assert quant_max <= quant_max_upper_bound, \ + "quant_max out of bound for dtype, " \ + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + + Args: + input (torch.Tensor): original float32 Tensor + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + inv_scale = 1.0 / scale + return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype) + +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd") +def quantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with + quantization parameters in the argument of this function (scale/zero_point) + + scale (float): quantization parameter for affine quantization + + zero_point (int): quantization parameter for affine quantization + + quant_min (int): minimum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): dtype for input Tensor (not used in computation, + reserved for pattern matching) + + Returns: + dequantized float32 Tensor + """ + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if dtype in [torch.uint8, torch.int8, torch.int32]: + # TODO: investigate why + # (input - zero_point).to(torch.float32) * scale + # failed the test + return (input.to(torch.float32) - zero_point) * scale + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if dtype in [torch.uint8, torch.int8, torch.int32]: + return torch.empty_like(input, dtype=torch.float32) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " + "ScalarType dtype) -> (Tensor, Tensor)") + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor( + input: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> Tuple[float, int]: + """ Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}" + + # Its weird to create an observer manually just to calculate qparams. I tried refactoring this functionality out of observer + # into a util and then use that util directly, but I kept running into jit typing errors related to torch.qscheme not + # being recognized as a type. TODO: properly refactor this out to avoid observer overhead + tensor_dtype_to_observer_dtype = {torch.uint8: torch.quint8, torch.int8: torch.qint8} + observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max, dtype=tensor_dtype_to_observer_dtype[dtype]) + observer(input) + scale, zero_point = observer.calculate_qparams() + return (scale, zero_point) + +# Helper function used to implement per-channel quantization against any axis +def _permute_to_axis_zero(x, axis): + new_axis_list = list(range(x.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(tuple(new_axis_list)) + return y, new_axis_list + +quantized_decomposed_lib.define( + "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd") +def quantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values + + Args: + input (torch.Tensor): original float32 Tensor + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + zero_point (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + res = torch.zeros_like(input) + + for i in range(input.size(0)): + res[i] = torch.clamp( + torch.round(input[i] * (1.0 / scales[i])) + zero_points[i], + quant_min, + quant_max + ) + + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") +def dequantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with + quantization parameter in the argument of this function (scales/zero_points/axis) + + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + + zero_points (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + + quant_min (int): minimum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): requested dtype for output Tensor (not used in computation, + reserved for pattern matching) + + Returns: + dquantized float32 Tensor + """ + assert input.dtype == dtype, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + res = torch.zeros_like(input, dtype=torch.float32) + + for i in range(input.size(0)): + # TODO: investigate why + # (input[i] - zero_points[i]).to(torch.float32) * scales[i] + # failed the test + res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i] + + out = res.permute(tuple(permute_axis_list)) + return out diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index aa71fafbf00e9..c60385537fe41 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -111,6 +111,9 @@ def is_copy_node(node, modules): torch.flatten, torch.mean, operator.floordiv, + # F.channel_shuffle and torch.channel_shuffle are essentially the same thing + # so we only need to put one of them here + torch.channel_shuffle, ] method_list = [ "clamp", @@ -131,6 +134,7 @@ def is_copy_node(node, modules): torch.nn.MaxPool3d, torch.nn.ReLU, torch.nn.ReLU6, + torch.nn.ChannelShuffle, ] return _is_node_in_list(node, modules, func_list, method_list, module_type_list) @@ -280,7 +284,7 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigA } # Mapping from a functional to a dictionary, where the key is a 2-tuple of -# (activation_compute_dtype, weight_dtype) and the value is a 2-tuple of +# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of # 1) The dynamically quantized version of the op # 2) The dynamically quantized version of the op fused with relu, if it exists, else None DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]] = { @@ -533,9 +537,9 @@ def _lower_dynamic_weighted_ref_module(model: QuantizedGraphModule): input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic: continue - activation_compute_dtype = input_dynamic_q_node.args[1] - is_fp16 = activation_compute_dtype == torch.float16 - is_int8 = activation_compute_dtype in [torch.quint8, torch.qint8] + activation_dtype = input_dynamic_q_node.args[1] + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] if not is_int8 and not is_fp16: continue @@ -688,9 +692,9 @@ def _lower_dynamic_weighted_ref_functional( continue reduce_range_node = None - (pattern_input, activation_compute_dtype, reduce_range_node) = input_dynamic_q_node.args - is_fp16 = activation_compute_dtype == torch.float16 - is_int8 = activation_compute_dtype in [torch.quint8, torch.qint8] + (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] if not is_int8 and not is_fp16: continue @@ -698,7 +702,7 @@ def _lower_dynamic_weighted_ref_functional( weight_dtype = quantized_weight.args[-1] # Step 1: Try to select reference pattern with the corresponding quantized op - dynamic_quant_dtype_key = (activation_compute_dtype, weight_dtype) + dynamic_quant_dtype_key = (activation_dtype, weight_dtype) if dynamic_quant_dtype_key not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]: print(f"Didn't find dtype combination {dynamic_quant_dtype_key} during " f"dynamic quantized op lowering for {func_node.target}") @@ -825,7 +829,8 @@ def special_pattern_replacement(model: QuantizedGraphModule): is_call_function, is_call_method, is_call_module = is_special_pattern_node(ref_node, modules) if not (is_call_module or is_call_function or is_call_method): continue - dq_node_or_nodes = ref_node.args[0] + assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0 + dq_node_or_nodes = ref_node.args[0] if len(ref_node.args) > 0 else list(ref_node.kwargs.values())[0] assert isinstance(dq_node_or_nodes, Node) or isinstance(dq_node_or_nodes, (tuple, list)) is_dequantize = False if isinstance(dq_node_or_nodes, Node): diff --git a/torch/ao/quantization/fx/_model_report/README.md b/torch/ao/quantization/fx/_model_report/README.md index 0c4943ad6a755..6275b49b54e2b 100644 --- a/torch/ao/quantization/fx/_model_report/README.md +++ b/torch/ao/quantization/fx/_model_report/README.md @@ -5,7 +5,7 @@ ModelReport > ⚠️ *While the example below uses the Fx Workflow, the use of the ModelReport class **does not depend** on the Fx Workflow to work*. The requirements are detector dependent. - Most detectors require a **traceable GraphModule**, but some (ex. `PerChannelDetector`) require just a `nn.Module`. + Most detectors require a **traceable GraphModule**, but some (ex. `PerChannelDetector`) require just an `nn.Module`. #### Typical Fx Workflow - Initialize model → Prepare model → Callibrate model → Convert model → ... @@ -32,7 +32,7 @@ model_report = ModelReport(model, detector_set) ready_for_callibrate = model_report.prepare_detailed_callibration() # callibrate model and generate report -ready_for_callibrate(example_input) # TODO run callibration of model with relavent data +ready_for_callibrate(example_input) # TODO run callibration of model with relevant data reports = model_report.generate_model_report(remove_inserted_observers=True) for report_name in report.keys(): text_report, report_dict = reports[report_name] @@ -61,8 +61,8 @@ This is so that we can keep track of where we want to insert observers on a dete - `prepare_detailed_calibration(self)` → `GraphModule` inserts observers into the locations specified by each detector in the model. It then returns the GraphModule with the detectors inserted into both the regular module structure as well as the node structure. - `generate_model_report(self, remove_inserted_observers: bool)` → `Dict[str, Tuple[str, Dict]]` uses callibrated GraphModule to optionally removes inserted observers, and generate, for each detector the ModelReport instance was initialized with: - - A string-based report that is easily digestable and actionable explaining the data collected by relavent observers for that detector - - A dictionary containing statistics collected by the relavent observers and values calculated by the detector for futher analysis or plotting + - A string-based report that is easily digestable and actionable explaining the data collected by relevant observers for that detector + - A dictionary containing statistics collected by the relevant observers and values calculated by the detector for further analysis or plotting ## ModelReportVisualizer Overview @@ -127,21 +127,21 @@ return_dict = { "[unique_observer_fqn_of_insert_location]" : { "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node), - "insert_observer" -> the intialized observer we wish to insert (ObserverBase), + "insert_observer" -> the initialized observer we wish to insert (ObserverBase), "insert_post" -> True if this is meant to be a post-observer for target_node, False if pre-observer, "observer_args" -> The arguments that are meant to be passed into the observer, } } ``` - `get_detector_name(self)` -> `str`: returns the name of the detector. -You should give your detector a unique name different from exisiting detectors. +You should give your detector a unique name different from existing detectors. - `generate_detector_report(self, model)` -> `Tuple[str, Dict[str, Any]]`: generates a report based on the information the detector is trying to collect. This report consists of both a text-based report as well as a dictionary of collected and calculated statistics. This report is returned to the `ModelReport` instance, which will then compile all the reports of all the Detectors requested by the user. ## ModelReportObserver Overview -As seen in the [requirments to implement a detector section](#requirements-to-implement-a-detector), one of the key parts of implementing a detector is to specify what `Observer` we are trying to insert. +As seen in the [requirements to implement a detector section](#requirements-to-implement-a-detector), one of the key parts of implementing a detector is to specify what `Observer` we are trying to insert. All the detectors in the ModelReport API use the [`ModelReportObserver`](https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/_model_report/model_report_observer.py). While the core purpose of many observers in PyTorch's Quantization API is to collect min / max information to help determine quantization parameters, the `ModelReportObserver` collects additional statistics. @@ -152,7 +152,7 @@ The statistics collected by the `ModelReportObserver` include: - Ratio of 100th percentile to some *n*th percentile - Number of constant value batches to pass through each channel -After the `ModelReportObserver` collects the statistics above during the callibration process, the detectors then extract the information they need to generate their reports from the relavent observers. +After the `ModelReportObserver` collects the statistics above during the callibration process, the detectors then extract the information they need to generate their reports from the relevant observers. ### Using Your Own Observer @@ -187,7 +187,7 @@ Since you are also implementing your own detector in this case, it is up to you - A line plot (for both per-tensor and per-channel statistics) - A histogram (for both per-tensor and per-channel statistics) - `model_report.py`: File containing the `ModelReport` class - - Main class users are interacting with to go through the ModelReport worflow + - Main class users are interacting with to go through the ModelReport workflow - API described in detail in [Overview section](#modelreport-overview) # Tests @@ -200,7 +200,7 @@ These tests include: - Test class for the `ModelReportVisualizer` class - Test class for **each** of the implemented Detectors -If you wish to add a Detector, make sure to create a test class modeled after one of the exisiting classes and test your detector. +If you wish to add a Detector, make sure to create a test class modeled after one of the existing classes and test your detector. Because users will be interacting with the Detectors through the `ModelReport` class and not directly, ensure that the tests follow this as well. # Future Tasks and Improvements diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 239137aaaabba..c92733bbc1c32 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -10,7 +10,7 @@ from torch.ao.quantization.qconfig import ( QConfig, default_qconfig, - assert_valid_qconfig, + _assert_valid_qconfig, ) from torch.ao.quantization.observer import ( ObserverBase, @@ -84,7 +84,7 @@ def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig: weight = default_per_channel_weight_observer if rec[1] else default_weight_observer test_config = QConfig(activation, weight) try: - assert_valid_qconfig(test_config, module) + _assert_valid_qconfig(test_config, module) module_qconfig = test_config break except AssertionError: diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index dfe777a540585..ee96dd4bf5a9c 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -385,7 +385,7 @@ def _reformat_reports_for_visualizer(self) -> OrderedDict: module_fqns_to_features[module_fqn] = {**new_info, **present_info} else: error_str = "You have the same key with different values across detectors. " - error_str += "Someone incorrectly implemented a detector with conflicting keys to exisiting detectors." + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." raise ValueError(error_str) else: # we just set it diff --git a/torch/ao/quantization/fx/backend_config_utils.py b/torch/ao/quantization/fx/backend_config_utils.py deleted file mode 100644 index eef4979a0a064..0000000000000 --- a/torch/ao/quantization/fx/backend_config_utils.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -from torch.ao.quantization.fx.pattern_utils import get_default_quant_patterns, sorted_patterns_dict -from torch.ao.quantization.backend_config import ( - get_native_backend_config, - ObservationType, -) -from torch.ao.quantization.utils import ( - activation_dtype, - get_combined_dict, - Pattern, - NodePattern, - QuantizerCls, -) - -from ..backend_config import BackendConfig -from .quantization_patterns import QuantizeHandler -from .fusion_patterns import DefaultFuseHandler - -from typing import Dict, Any, Callable, Optional - -def get_quantize_handler_cls( - observation_type, - dtype_configs, - num_tensor_args_to_observation_type, - overwrite_output_fake_quantizer, - overwrite_output_observer, - input_output_observed): - - class ConfigurableQuantizeHandler(QuantizeHandler): - def __init__( - self, - node_pattern: NodePattern, - modules: Dict[str, torch.nn.Module], - root_node_getter: Callable = None): - super().__init__(node_pattern, modules, root_node_getter) - if num_tensor_args_to_observation_type: - assert self.num_tensor_args in num_tensor_args_to_observation_type, \ - f"Must provide observation_type config for tensor number {self.num_tensor_args}" \ - f" in num_tensor_args_to_observation_type for {node_pattern}" - self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args] - else: - self.observation_type = observation_type - self.dtype_configs = dtype_configs - self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer - self.overwrite_output_observer = overwrite_output_observer - self.input_output_observed_ = input_output_observed - - def is_general_tensor_value_op(self) -> bool: - return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT - - # TODO: change this to output activation - def get_activation_ctr( - self, - qconfig: Any, - pattern: Pattern, - is_training: bool, - ) -> Optional[Callable]: - """ - Returns the constructor for the activation observer which should be - used for the pattern matched to this handler. Some handlers override - this to a different value than what is specified in the qconfig. - """ - act_dtype = activation_dtype(qconfig) - # TODO: change to is_qat - if is_training: - if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None: - return self.overwrite_output_fake_quantizer - else: - if act_dtype == torch.quint8 and self.overwrite_output_observer is not None: - return self.overwrite_output_observer - return qconfig.activation - - # This is temporary, and will be removed soon - def input_output_observed(self): - return self.input_output_observed_ - - - return ConfigurableQuantizeHandler - -def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]: - """ - Note: Quantize handler is just a holder for some check methods like - (should_insert_observer_for_output), maybe this can be a enum as well, - we can refactor this after we convert the path for fbgemm/qnnpack fully to the - new path, this is not exposed to backend developers - """ - pattern_to_quantize_handlers = {} - for pattern, config in backend_config.configs.items(): - observation_type = config.observation_type - dtype_configs = config.dtype_configs - num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type - overwrite_fake_quantizer = config._overwrite_output_fake_quantize - overwrite_observer = config._overwrite_output_observer - input_output_observed = config._input_output_observed - if input_output_observed is None: - input_output_observed = True - pattern_to_quantize_handlers[pattern] = \ - get_quantize_handler_cls( - observation_type, - dtype_configs, - num_tensor_args_to_observation_type, - overwrite_fake_quantizer, - overwrite_observer, - input_output_observed) - - return pattern_to_quantize_handlers - -# TODO: move this to torch/ao/quantization/backend_config/utils.py -def get_fusion_pattern_to_fuse_handler_cls( - backend_config: BackendConfig) -> Dict[Pattern, Callable]: - fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {} - for pattern, config in backend_config.configs.items(): - if config.fuser_method is not None: - # TODO: is this logic right? - fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler - - return fusion_pattern_to_fuse_handlers - -# TODO: remove when all uses are changed to backend_config -def get_native_quant_patterns(additional_quant_patterns: Dict[Pattern, QuantizerCls] = None) -> Dict[Pattern, QuantizerCls]: - """ - Return a map from pattern to quantize handlers based on the default patterns and the native backend_config. - The returned map is sorted such that longer patterns will be encountered first when iterating through it. - """ - patterns = get_default_quant_patterns() - if additional_quant_patterns is not None: - patterns = get_combined_dict(patterns, additional_quant_patterns) - # TODO: currently we just extend the quantize handlers generated from - # `get_native_backend_config` - # in the future we can just assign backend_config when everything is defined - for pattern, quantize_handler in get_pattern_to_quantize_handlers(get_native_backend_config()).items(): - patterns[pattern] = quantize_handler - return sorted_patterns_dict(patterns) - -get_fusion_pattern_to_fuse_handler_cls.__module__ = "torch.ao.quantization.fx.backend_config_utils" -get_native_quant_patterns.__module__ = "torch.ao.quantization.fx.backend_config_utils" -get_pattern_to_quantize_handlers.__module__ = "torch.ao.quantization.fx.backend_config_utils" - -__all__ = [ - "get_quantize_handler_cls", - "get_fusion_pattern_to_fuse_handler_cls", - "get_native_quant_patterns", - "get_pattern_to_quantize_handlers", -] diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index aa402e882abc8..e795b3bca8584 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable from torch.ao.quantization.quant_type import QuantType import torch import copy @@ -23,14 +23,12 @@ qconfig_equals ) from ..qconfig_mapping import QConfigMapping -from ..qconfig_mapping_utils import ( - update_qconfig_for_qat, -) from .qconfig_mapping_utils import ( generate_node_name_to_qconfig, compare_prepare_convert_qconfig_mappings, update_qconfig_for_fusion, is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_qat, ) from torch.ao.quantization.backend_config.utils import ( get_root_module_to_quantized_reference_module, @@ -53,12 +51,15 @@ _get_module, _is_custom_module_lstm, get_custom_module_class_keys, - get_quantize_node_info, create_getattr_from_value, collect_producer_nodes, graph_module_from_producer_nodes, node_arg_is_weight, ) +from torch.ao.quantization.utils import ( + is_per_channel, + to_underlying_dtype, +) from torch.ao.quantization.quantize import ( _remove_qconfig, is_activation_post_process, @@ -69,7 +70,9 @@ PrepareCustomConfig, ) from .lower_to_fbgemm import lower_to_fbgemm - +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +import operator # TODO: revisit this list. Many helper methods shouldn't be public __all__ = [ @@ -86,6 +89,369 @@ "run_weight_observers", ] +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.nn.Module, + graph: Graph, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all([ + has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())]) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find correponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ + (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op : Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + assert dtype_ in [torch.uint8, torch.int8], \ + "only uint8 and int8 are supported in reference flow for " \ + "dynamic quantization right now" + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + for key, value in qparams.items(): + # we have quant_min, quant_max and dtype, all should be stored + # as literals + choose_qparams_op_inputs.append(value) + choose_qparams_node = graph.create_node( + "call_function", + torch.ops.quantized_decomposed.choose_qparams.tensor, + tuple(choose_qparams_op_inputs), + {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 0), + {} + ) + zero_point_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 1), + {} + ) + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype + } + + # 3. replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + raise NotImplementedError("decomposed to float16 op not implemented yet") + + # should not reach since we have checks in the begining to make sure the + # activation_post_process is supported + +def _replace_observer_with_quantize_dequantize_node( + model: torch.nn.Module, + graph: Graph, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + """ + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all([ + has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())]) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find correponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ + (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op : Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype} + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the begining to make sure the + # activation_post_process is supported + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph): + call_custom_module_node = node.args[0] + assert isinstance(call_custom_module_node, Node), \ + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + insert_dequantize_node(call_custom_module_node, graph) + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in [torch.quint8, torch.qint8, torch.qint32] and (not is_dynamic)) or # type: ignore[return-value] + is_dynamic or + dtype == torch.float16 + ) def restore_state( observed: torch.nn.Module @@ -485,7 +851,8 @@ def convert( is_standalone_module: bool = False, _remove_qconfig_flag: bool = True, qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, - backend_config: Union[BackendConfig, Dict[str, Any], None] = None) -> torch.nn.Module: + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, + is_decomposed: bool = False) -> torch.nn.Module: """ We will convert an observed model (a module with observer calls) to a reference quantized model, the rule is simple: @@ -497,13 +864,21 @@ def convert( is stored in observed_node_names, we can decide whether we need to swap the module based on this set - standalone_module means it a submodule that is not inlined in - parent module, and will be quantized separately as one unit. + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) - Returns a quantized standalone module, whether input/output is quantized is - specified by prepare_custom_config, with - input_quantized_idxs, output_quantized_idxs, please - see docs for prepare_fx for details + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details """ if convert_custom_config is None: convert_custom_config = ConvertCustomConfig() @@ -552,7 +927,7 @@ def convert( modules_copy = copy.deepcopy(modules) if model._is_qat: - update_qconfig_for_qat(qconfig_mapping, {}) + _update_qconfig_for_qat(qconfig_mapping, {}) update_qconfig_for_fusion(model, qconfig_mapping) compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type] @@ -588,75 +963,6 @@ def convert( if node.op == 'placeholder': graph_inputs.append(node.name) - # TODO: move this outside of this function - def replace_observer_with_quantize_dequantize_node( - model: torch.nn.Module, - graph: Graph, - node: Node, - modules: Dict[str, torch.nn.Module], - node_name_to_scope: Dict[str, Tuple[str, type]], - node_name_to_qconfig: Dict[str, QConfigAny]) -> None: - """ Replace activation_post_process module call node with quantize and - dequantize node - - Before: - ... -> observer_0(x) -> ... - After: - ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... - """ - assert modules is not None - assert isinstance(node.target, str) - module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) - observer_module = modules[node.target] - maybe_quantize_node_info = get_quantize_node_info(observer_module) - # Skip replacing observers to quant/dequant nodes if the qconfigs of all - # consumers and producers of this observer are None - skip_replacement = all([ - has_none_qconfig(n, node_name_to_qconfig) for n in - list(node.args) + list(node.users.keys())]) - if skip_replacement or maybe_quantize_node_info is None: - # didn't find correponding quantize op and info for the observer_module - # so we just remove the observer - with graph.inserting_before(node): - node.replace_all_uses_with(node.args[0]) - graph.erase_node(node) - else: - # otherwise, we can convert the observer moduel call to quantize/dequantize node - node_type, quantize_op, qparams = maybe_quantize_node_info - # replace observer node with quant - dequant node - with graph.inserting_before(node): - input_node = node.args[0] - inputs = [input_node] - for key, value in qparams.items(): - # TODO: we can add the information of whether a value needs to - # be registered as an attribute in qparams dict itself - if key in ['_scale_', '_zero_point_']: - # For scale and zero_point values we register them as buffers in the root module. - # TODO: maybe need more complex attr name here - qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) - inputs.append(qparam_node) - else: - # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. - inputs.append(value) - - quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {}) - dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) - node.replace_all_uses_with(dequantized_node) - graph.erase_node(node) - - # this is a temporary hack for custom module, we may want to implement - # this properly after the custom module class design is finalized - # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted - # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs - # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. - def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph): - call_custom_module_node = node.args[0] - assert isinstance(call_custom_module_node, Node), \ - f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" - node.replace_all_uses_with(call_custom_module_node) - graph.erase_node(node) - insert_dequantize_node(call_custom_module_node, graph) - # additional state to override inputs to be quantized, if specified # by the user placeholder_node_seen_cnt = 0 @@ -707,13 +1013,18 @@ def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Gra if is_activation_post_process(mod): observed_node = node.args[0] if observed_node in statically_quantized_custom_module_nodes: - replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) else: - replace_observer_with_quantize_dequantize_node( - model, model.graph, node, modules, node_name_to_scope, - node_name_to_qconfig) + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, model.graph, node, modules, node_name_to_scope, + node_name_to_qconfig) + else: + _replace_observer_with_quantize_dequantize_node( + model, model.graph, node, modules, node_name_to_scope, + node_name_to_qconfig) elif isinstance(mod, DeQuantStub): - replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) elif is_observed_standalone_module(mod): convert_standalone_module( node, modules, model, is_reference, backend_config) diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 0f5f5bfe8d158..9d08853a41260 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -4,7 +4,7 @@ from torch.ao.quantization import QConfigMapping from torch.ao.quantization.backend_config import BackendConfig -from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, quant_type_to_str +from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str __all__ = [ @@ -263,7 +263,7 @@ def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items(): if FLOAT_TO_OBSERVED_DICT_KEY not in d: d[FLOAT_TO_OBSERVED_DICT_KEY] = {} - d[FLOAT_TO_OBSERVED_DICT_KEY][quant_type_to_str(quant_type)] = float_to_observed_mapping + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping if len(self.non_traceable_module_names) > 0: d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names if len(self.non_traceable_module_classes) > 0: @@ -350,7 +350,7 @@ def to_dict(self) -> Dict[str, Any]: for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items(): if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} - d[OBSERVED_TO_QUANTIZED_DICT_KEY][quant_type_to_str(quant_type)] = observed_to_quantized_mapping + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping if len(self.preserved_attributes) > 0: d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes return d diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 8a4fcb6d11251..930be1bdcc4fc 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -24,11 +24,13 @@ get_fusion_pattern_to_root_node_getter, get_fusion_pattern_to_extra_inputs_getter, ) -from .backend_config_utils import get_fusion_pattern_to_fuse_handler_cls from .custom_config import FuseCustomConfig -from .fusion_patterns import * # noqa: F401,F403 +from .fuse_handler import ( + _get_fusion_pattern_to_fuse_handler_cls, + FuseHandler, +) from typing import Any, Callable, Dict, List, Tuple, Union import warnings @@ -38,6 +40,9 @@ __all__ = [ "fuse", + # TODO: We should make this private in the future + # This is currently needed for test_public_bindings for some reason + "FuseHandler", ] @@ -69,7 +74,7 @@ def fuse( if backend_config is None: backend_config = get_native_backend_config() - fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(get_fusion_pattern_to_fuse_handler_cls(backend_config)) + fusion_pattern_to_fuse_handler_cls = sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config)) fuser_method_mapping = get_fuser_method_mapping(backend_config) fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config) diff --git a/torch/ao/quantization/fx/fusion_patterns.py b/torch/ao/quantization/fx/fuse_handler.py similarity index 89% rename from torch/ao/quantization/fx/fusion_patterns.py rename to torch/ao/quantization/fx/fuse_handler.py index 075a0cfa03315..2106dc4e33143 100644 --- a/torch/ao/quantization/fx/fusion_patterns.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -1,4 +1,5 @@ import torch +from torch.ao.quantization.backend_config import BackendConfig from torch.fx.graph import Node, Graph from ..utils import _parent_name, NodePattern, Pattern from ..fuser_method_mappings import get_fuser_method_new @@ -38,7 +39,6 @@ def fuse(self, is_qat: bool) -> Node: pass -# TODO: move this to backend_config_utils class DefaultFuseHandler(FuseHandler): def __init__( self, @@ -108,3 +108,12 @@ def get_matched_types(m): args.extend(extra_args) node.args = tuple(args) return node + +def _get_fusion_pattern_to_fuse_handler_cls( + backend_config: BackendConfig) -> Dict[Pattern, Callable]: + fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {} + for pattern, config in backend_config.configs.items(): + if config.fuser_method is not None: + # TODO: is this logic right? + fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler + return fusion_pattern_to_fuse_handlers diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index e50b89f9ce408..f9b6c442476a0 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -5,7 +5,7 @@ Node, ) from torch.ao.quantization.utils import Pattern -from .quantization_patterns import ( +from .quantize_handler import ( QuantizeHandler, ) from ..qconfig import ( @@ -18,7 +18,7 @@ is_observed_standalone_module, ) from torch.nn.utils.parametrize import type_before_parametrizations -from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set +from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable # TODO: revisit this list. Many helper methods shouldn't be public @@ -53,6 +53,9 @@ def is_match(modules, node, pattern, max_uses=sys.maxsize): if isinstance(self_match, type) and issubclass(self_match, MatchAllNode): return True + if node == pattern: + return True + if not isinstance(node, Node) or len(node.users) > max_uses: return False @@ -133,6 +136,8 @@ def _recursive_record_node_in_match_map( if isinstance(node_pattern, Node): match_map[node_pattern.name] = ( last_node, matched_node_pattern, pattern, match_value) + elif not isinstance(node_pattern, Iterable): + return else: for n in node_pattern: _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value) @@ -146,6 +151,7 @@ def record_match( match_map): if isinstance(pattern, tuple): s, *args = pattern + is_single_arg = len(args) == 1 current_node_pattern: List[Node] = [] record_match( s, @@ -162,7 +168,17 @@ def record_match( current_node_pattern, match_map) if len(current_node_pattern) > 1: - matched_node_pattern.append(tuple(current_node_pattern)) + # current_node_pattern is the node pattern we get from matching + # the subpattern with arguments of the node + # we use is_single_arg to recover the original structure of the pattern + # if the original pattern has a single argument, we will have + # (original_op, (original_arg, ...)) + # otherwise, we'll have a list of arguments + # (original_op, arg0, arg1, arg2, ...) + if is_single_arg: + matched_node_pattern.append(tuple(current_node_pattern)) + else: + matched_node_pattern.extend(list(current_node_pattern)) else: matched_node_pattern.append(current_node_pattern[0]) else: diff --git a/torch/ao/quantization/fx/pattern_utils.py b/torch/ao/quantization/fx/pattern_utils.py index c4971b542627a..10b67d075b216 100644 --- a/torch/ao/quantization/fx/pattern_utils.py +++ b/torch/ao/quantization/fx/pattern_utils.py @@ -2,7 +2,6 @@ from typing import Dict, Any from torch.ao.quantization.utils import Pattern from ..fake_quantize import FixedQParamsFakeQuantize -# from .quantization_patterns import BinaryOpQuantizeHandler from ..observer import ObserverBase import copy diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index cc6fc65bd8906..9985a5c049720 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -1,6 +1,5 @@ import copy import torch -import operator import warnings from torch.fx import ( GraphModule, @@ -18,26 +17,22 @@ ObserverBase, ) from ..qconfig import ( - obs_or_fq_ctr_equals, - float16_dynamic_qconfig, - float16_static_qconfig, - is_reuse_input_qconfig, + _is_reuse_input_qconfig, QConfigAny, ) from ..qconfig_mapping import ( - _FIXED_QPARAMS_OP_TO_OBSERVER, QConfigMapping, ) -from ..qconfig_mapping_utils import ( - get_flattened_qconfig_dict, - update_qconfig_for_qat, -) from .qconfig_mapping_utils import ( generate_node_name_to_qconfig, update_qconfig_for_fusion, + _get_flattened_qconfig_dict, + _update_qconfig_for_qat, ) -from .quantization_patterns import ( +from .quantize_handler import ( + _default_root_node_getter, + _get_pattern_to_quantize_handlers, QuantizeHandler, ) @@ -46,8 +41,6 @@ NodePattern, ) -from torch.ao.quantization import FixedQParamsFakeQuantize - from ._equalize import ( is_equalization_observer, node_supports_equalization, @@ -92,7 +85,6 @@ get_qconfig_dtypes, get_swapped_custom_module_class, activation_is_statically_quantized, - activation_is_int8_quantized, ) from ..backend_config.utils import ( @@ -105,10 +97,6 @@ DTypeConfig, get_native_backend_config, ) -from .backend_config_utils import ( - get_pattern_to_quantize_handlers, -) - from .custom_config import ( PrepareCustomConfig, StandaloneModuleConfigEntry, @@ -251,10 +239,10 @@ def _is_pattern_dtype_config_and_qconfig_supported_by_backend( assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) - # TODO: this only works for one input and one output patterns, need to generalize to multiple - # inputs/output - root_node = _default_root_node_getter(matched_node_pattern) + root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter) + root_node = root_node_getter(matched_node_pattern) input_node = root_node output_node = matched_node_pattern[0] for dtype_config in dtype_configs: @@ -307,12 +295,6 @@ def add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]) for maybe_node in matched_node_pattern: add_matched_node_name_to_set(maybe_node, s) -# this is temporary, will be removed soon -def _default_root_node_getter(node_pattern): - while not isinstance(node_pattern, Node): - node_pattern = node_pattern[-1] - return node_pattern - def insert_observer( node: Node, observer: ObserverBase, @@ -402,21 +384,11 @@ def get_target_activation_dtype_for_node( "output_activation_dtype": None, } - # TODO(future PR): consider stopping matching getitem - is_getitem = node.op == 'call_function' and \ - node.target == operator.getitem - if is_getitem: - return { - "input_activation_dtype": (torch.float, False), - "output_activation_dtype": (torch.float, False), - } - # get qconfig to determine the eventual dtype of this node if qconfig is not None: if qhandler is not None and qhandler.input_output_observed(): - act_dtype, weight_dtype, act_compute_dtype = \ + act_dtype, weight_dtype, input_act_is_dynamic = \ get_qconfig_dtypes(qconfig) - input_act_is_dynamic = act_compute_dtype is not None # Currently `QConfig` only has one `activation` field. # For static quantization, it is reused for both input @@ -426,13 +398,13 @@ def get_target_activation_dtype_for_node( # In the future this may change as we add more fields # to the `QConfig` object. output_act_dtype = act_dtype \ - if input_act_is_dynamic is not True else torch.float + if (not input_act_is_dynamic) else torch.float bias_dtype = torch.float16 \ if ( act_dtype == torch.float16 and weight_dtype == torch.float16 - and act_compute_dtype is None + and (not input_act_is_dynamic) ) else torch.float return { "input_activation_dtype": (act_dtype, input_act_is_dynamic), @@ -585,7 +557,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg( # regular flow for most nodes, except standalone modules is_weight = node_arg_is_weight(node, arg, backend_config) - is_reuse_input_qconfig_ = is_reuse_input_qconfig(qconfig) + _is_reuse_input_qconfig_ = _is_reuse_input_qconfig(qconfig) act_post_process_ctr = qconfig.weight if is_weight else \ qconfig.activation @@ -613,7 +585,7 @@ def maybe_insert_input_observer_for_arg_or_kwarg( # if arg output dtype is in DO_NOT_OBS_DTYPE_LIST do not insert observer (arg_as_output_target_dtype not in DO_NOT_OBS_DTYPE_LIST) and # if qconfig is reuse_input qconfig, we won't insert extra observer for input - not is_reuse_input_qconfig_ + not _is_reuse_input_qconfig_ ) or ( # need to add input observer for dynamic quantization # only add observer for first input for now, we may need to extend @@ -826,13 +798,7 @@ def maybe_insert_output_observer_for_node( (not is_standalone_module) if should_insert_observer: - act_post_process_ctr = qconfig.activation - if activation_is_int8_quantized(qconfig): - act_post_process_ctr = qhandler.get_activation_ctr( - qconfig, - matched_pattern, - is_qat) - observer = act_post_process_ctr() + observer = qconfig.activation() return insert_observer(node, observer, model, modules, graph) else: return None @@ -946,9 +912,7 @@ def maybe_propagate_dtype_for_node( ) -> None: """ Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node` - is a general tensor shape op - (see GeneralTensorShapeOpQuantizeHandler in quantization_patterns.py for more details) - also call this function recursively on + is a general tensor shape op, also call this function recursively on the first argument, to propagate the dtype to the caller. """ node_name_to_target_dtype_info[node.name]["input_activation_dtype"] = (target_dtype, False) @@ -1119,14 +1083,10 @@ def swap_custom_module_to_observed( def insert_observers_for_model( model: GraphModule, - modules: Dict[str, torch.nn.Module], matches: Dict[str, _MatchResultWithQConfig], node_name_to_qconfig: Dict[str, QConfigAny], - graph: Graph, prepare_custom_config: PrepareCustomConfig, equalization_config_map: Dict[str, Any], - input_quantized_idxs: List[int], - output_quantized_idxs: List[int], backend_config: BackendConfig, observed_node_names: Set[str], is_qat: bool, @@ -1198,6 +1158,8 @@ def insert_observers_for_model( for node in model.graph.nodes: root_node, _, pattern, qhandler, qconfig = matches.get( node.name, (None, None, None, None, None)) + input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes node_name_to_target_dtype_info[node.name] = get_target_activation_dtype_for_node( node, qconfig, inputs_seen_counter, outputs_seen_counter, input_quantized_idxs, output_quantized_idxs, qhandler, @@ -1243,14 +1205,10 @@ def insert_observers_for_model( this_node_dtype_info = node_name_to_target_dtype_info[node.name] output_not_a_tensor = this_node_dtype_info is None - # TODO(future PR): consider stopping matching getitem - is_getitem = node.op == 'call_function' and \ - node.target == operator.getitem skip_inserting_observers = ( (qconfig is None) or - output_not_a_tensor or - is_getitem + output_not_a_tensor ) and ( not node.op == 'output' ) @@ -1290,15 +1248,14 @@ def insert_observers_for_model( if user != node and is_user_quantized: is_quantized_branch = True - # TODO: this only works for sequential fusion right now, extend it - # it to automatically detect all input nodes based on the pattern - # need to change find_matches function to return this information - root_node = _default_root_node_getter(matched_node_pattern) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter) + root_node = root_node_getter(matched_node_pattern) is_input_node_of_the_pattern = node is root_node if is_input_node_of_the_pattern: # this modifies node inplace maybe_insert_input_observers_for_node( - node, qconfig, model, modules, graph, + node, qconfig, model, modules, model.graph, node_name_to_target_dtype_info, qhandler, prepare_custom_config, @@ -1306,13 +1263,13 @@ def insert_observers_for_model( # Insert equalization input observers if needed maybe_insert_input_equalization_observers_for_node( - node, equalization_qconfig, model, modules, graph, + node, equalization_qconfig, model, modules, model.graph, node_name_to_target_dtype_info, is_quantized_branch, backend_config) is_last_node_of_pattern = node is last_node is_general_tensor_value_op = \ (qhandler is not None and qhandler.is_general_tensor_value_op()) - is_reuse_input_qconfig_ = is_reuse_input_qconfig(qconfig) + _is_reuse_input_qconfig_ = _is_reuse_input_qconfig(qconfig) if is_last_node_of_pattern: if _is_custom_module_lstm(node, modules, qconfig, qhandler): @@ -1327,12 +1284,12 @@ def insert_observers_for_model( # these output observers are the same as DeQuantStubs. In the future, we # should resolve this inconsistency by inserting DeQuantStubs for all custom # modules, not just for LSTM. - _insert_dequant_stubs_for_custom_module_lstm_output(node, model, modules, graph) + _insert_dequant_stubs_for_custom_module_lstm_output(node, model, modules, model.graph) swap_custom_module_to_observed(node, qconfig, modules, prepare_custom_config) else: # this returns the new observer node if it was needed maybe_output_obs_node = maybe_insert_output_observer_for_node( - node, model, modules, graph, matches, + node, model, modules, model.graph, matches, node_name_to_target_dtype_info, pattern, qhandler, is_qat) if maybe_output_obs_node is not None: @@ -1364,7 +1321,7 @@ def insert_observers_for_model( # to make all inputs and outputs use the first input's # observer if (is_general_tensor_value_op and is_observer_in_same_graph_) or \ - is_reuse_input_qconfig_: + _is_reuse_input_qconfig_: if not maybe_make_input_output_share_observers(node, model, modules): remove_output_observer(node, model, modules) @@ -1375,7 +1332,7 @@ def insert_observers_for_model( maybe_insert_observers_before_graph_output( node, output_quantized_idxs, node_name_to_target_dtype_info, node_name_to_qconfig, - model, modules, graph) + model, modules, model.graph) # # After this point, the current node has input and output observers @@ -1392,45 +1349,6 @@ def insert_observers_for_model( return results_node -def _validate_fixed_qparams_qconfigs(model: GraphModule, node_name_to_qconfig: Dict[str, QConfigAny]): - """ - Validate whether the correct observers are configured for fixed qparams ops in the model, if any. - """ - # TODO: handle fp16 qconfigs properly - allowed_observer_ctrs = [ - float16_dynamic_qconfig.activation, - float16_static_qconfig.activation, - ] - named_modules = dict(model.named_modules(remove_duplicate=False)) - for node in model.graph.nodes: - if node.op == "call_function": - module_type_or_function_or_method = node.target - elif node.op == "call_module": - module_type_or_function_or_method = type(named_modules[node.target]) - else: - module_type_or_function_or_method = None - - if module_type_or_function_or_method in _FIXED_QPARAMS_OP_TO_OBSERVER: - bad_observer = True - qconfig = node_name_to_qconfig.get(node.name, None) - if qconfig is None: - bad_observer = False - else: - for observer_ctr in allowed_observer_ctrs + [_FIXED_QPARAMS_OP_TO_OBSERVER[module_type_or_function_or_method]]: - if obs_or_fq_ctr_equals( - qconfig.activation, - FixedQParamsFakeQuantize.with_args(observer=observer_ctr)) or \ - obs_or_fq_ctr_equals(qconfig.activation, observer_ctr): - bad_observer = False - if bad_observer: - raise ValueError("QConfigMapping must specify fixed qparams observer for fixed qparams op " - "'%s' type: '%s'. Please use torch.ao.quantization.get_default_qconfig_mapping or " - "torch.ao.quantization.get_default_qat_qconfig_mapping" - " instead. Example: \n" - " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\") \n" - " model = prepare_fx(model, qconfig_mapping, example_inputs)" - "" % (node.format_node(), module_type_or_function_or_method)) - def run_prepare_fx_on_standalone_modules( model: torch.nn.Module, is_qat: bool, @@ -1573,7 +1491,7 @@ def prepare( pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {} if backend_config is None: backend_config = get_native_backend_config() - pattern_to_quantize_handler = get_pattern_to_quantize_handlers(backend_config) + pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) pattern_to_quantize_handler = sorted_patterns_dict(pattern_to_quantize_handler) root_node_getter_mapping = \ @@ -1581,14 +1499,14 @@ def prepare( update_qconfig_for_fusion(model, qconfig_mapping) update_qconfig_for_fusion(model, _equalization_config) - flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping) + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) # TODO: support regex as well propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) if is_qat: module_to_qat_module = get_module_to_qat_module(backend_config) qat_swap_modules(model, module_to_qat_module) - update_qconfig_for_qat(qconfig_mapping, {}) + _update_qconfig_for_qat(qconfig_mapping, {}) # mapping from fully qualified module name to module instance # for example, @@ -1603,7 +1521,6 @@ def prepare( equalization_node_name_to_qconfig = generate_node_name_to_qconfig( model, modules, model.graph, _equalization_config, node_name_to_scope) node_name_to_qconfig = generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope) - _validate_fixed_qparams_qconfigs(model, node_name_to_qconfig) # match the patterns that will get quantized standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) @@ -1620,9 +1537,6 @@ def prepare( match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) matches[node_name] = match_with_qconfig - input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes - output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes - run_prepare_fx_on_standalone_modules( model, is_qat, modules, matches, prepare_custom_config, backend_config) @@ -1632,14 +1546,15 @@ def prepare( observed_node_names: Set[str] = set() result_node = insert_observers_for_model( - model, modules, matches, node_name_to_qconfig, - model.graph, prepare_custom_config, + model, + matches, + node_name_to_qconfig, + prepare_custom_config, equalization_node_name_to_qconfig, - input_quantized_idxs, - output_quantized_idxs, backend_config, observed_node_names, - is_qat) + is_qat + ) save_state(model, node_name_to_qconfig, node_name_to_scope, prepare_custom_config, equalization_node_name_to_qconfig, qconfig_mapping, is_qat, observed_node_names) @@ -1654,6 +1569,8 @@ def prepare( # these inputs are observed in parent # converting List[int] to Tensor since module attribute is # Union[Tensor, Module] + input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes model._standalone_module_input_quantized_idxs = \ torch.tensor(input_quantized_idxs) model._standalone_module_output_quantized_idxs = torch.tensor(output_quantized_idxs) diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 16f61f78daffe..9248890dbc158 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -1,8 +1,9 @@ import torch +import re from collections import defaultdict, OrderedDict -from typing import Callable, Any, Dict, Tuple, Set, List +from typing import Callable, Any, Dict, Tuple, Set, List, Union from torch.ao.quantization import QConfig -from torch.ao.quantization.qconfig import add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals +from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals from torch.ao.quantization.quantize import ( is_activation_post_process, ) @@ -21,19 +22,18 @@ from ..utils import ( _parent_name, get_qconfig_dtypes, + get_combined_dict ) from ..qconfig_mapping import ( - OBJECT_TYPE_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, QConfigMapping, ) -from ..qconfig_mapping_utils import ( - get_object_type_qconfig, - maybe_adjust_qconfig_for_module_type_or_name, +from ..quantization_mappings import ( + get_default_qat_module_mappings, ) - # TODO: revisit this list. Many helper methods shouldn't be public __all__ = [ "check_is_valid_config_dict", @@ -121,17 +121,17 @@ def generate_node_name_to_qconfig( qconfig = None if node.op == "get_attr": module_name, _ = _parent_name(node.target) - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[module_name]), module_name, global_qconfig) - qconfig_with_device_check = add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) elif node.op == "call_function": # precedence: module_name_qconfig # > function_qconfig > global_qconfig # module_name takes precedence over function qconfig - function_qconfig = get_object_type_qconfig( + function_qconfig = _get_object_type_qconfig( qconfig_mapping, node.target, global_qconfig) module_path, module_type = node_name_to_scope[node.name] - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, module_type, module_path, function_qconfig) cur_object_type_idx = \ @@ -139,28 +139,28 @@ def generate_node_name_to_qconfig( submodule_to_object_type_to_cur_idx[module_path][node.target] += 1 qconfig = maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig) - qconfig_with_device_check = add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) elif node.op == "call_method": module_path, module_type = node_name_to_scope[node.name] # first use node.target (string) to get the qconfig # this is to support configs like # "object_type": [("reshpe", qconfig)] - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, node.target, module_path, global_qconfig) # if there is no special config for the method, we'll fall back to the # config for the module that contains the call_method node - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, module_type, module_path, qconfig) # currently call_method does not support modifying qconfig # by order, we can add this later if it is needed. - qconfig_with_device_check = add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) elif node.op == 'call_module': # if the node is an observer, just continue - don't add it to the qconfig_map if is_activation_post_process(modules[node.target]): continue - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[node.target]), node.target, global_qconfig) module_path, module_type = node_name_to_scope[node.name] @@ -174,7 +174,7 @@ def generate_node_name_to_qconfig( qconfig = maybe_adjust_qconfig_for_module_name_object_type_order( qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig) - qconfig_with_device_check = add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) # regex is not supported eager mode propagate_qconfig_, we'll # need to set the qconfig explicitly here in case regex @@ -223,7 +223,7 @@ def compare_prepare_convert_qconfig_mappings( convert_qconfig_mapping.module_name_qconfigs, convert_qconfig_mapping.module_name_regex_qconfigs, ] - dict_names = [OBJECT_TYPE_DICT_KEY, MODULE_NAME_DICT_KEY, MODULE_NAME_REGEX_DICT_KEY] + dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY] for i in range(len(prepare_dicts)): for name, qconfig in prepare_dicts[i].items(): assert name in convert_dicts[i], "Missing key {} {} in convert QConfigMapping \ @@ -242,7 +242,7 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[ weight_dtype = dtype_config.weight_dtype or torch.float bias_dtype = dtype_config.bias_dtype or torch.float output_dtype = dtype_config.output_dtype or torch.float - qconfig_activation_dtype, qconfig_weight_dtype, qconfig_compute_dtype = \ + qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \ get_qconfig_dtypes(qconfig) qconfig_bias_dtype = torch.float16 \ if ( @@ -252,7 +252,8 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[ ) else torch.float if is_dynamic: - is_match = input_dtype == qconfig_compute_dtype and \ + is_match = qconfig_input_act_is_dynamic and \ + input_dtype == qconfig_activation_dtype and \ output_dtype == torch.float and \ weight_dtype == qconfig_weight_dtype else: @@ -263,3 +264,89 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[ if is_match: return True return False + +def _get_object_type_qconfig( + qconfig_mapping: QConfigMapping, + object_type: Union[Callable, str], + fallback_qconfig: QConfigAny) -> QConfigAny: + return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) + + +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): + for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + + +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): + if module_name == '': + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_mapping.module_name_qconfigs: + return qconfig_mapping.module_name_qconfigs[module_name] + else: + parent, _ = _parent_name(module_name) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + + +def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig): + # get qconfig for module_name, + # fallback to module_name_regex_qconfig, module_type_qconfig, + # global_qconfig if necessary + module_type_qconfig = _get_object_type_qconfig( + qconfig_mapping, module_type, global_qconfig) + module_name_regex_qconfig = _get_module_name_regex_qconfig( + qconfig_mapping, module_name, module_type_qconfig) + module_name_qconfig = _get_module_name_qconfig( + qconfig_mapping, module_name, module_name_regex_qconfig) + return module_name_qconfig + + +def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]: + """ flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig} + for obj, qconfig in qconfig_mapping.object_type_qconfigs.items(): + flattened[obj] = qconfig + for obj, qconfig in qconfig_mapping.module_name_qconfigs.items(): + flattened[obj] = qconfig + return flattened + + +def _update_qconfig_for_qat( + qconfig_mapping: QConfigMapping, + additional_qat_module_mapping: Dict[Callable, Callable]): + """ + Update the qconfig_dict to account for module swaps during QAT. + During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. + """ + all_qat_mappings = get_combined_dict( + get_default_qat_module_mappings(), additional_qat_module_mapping) + object_type_dict = qconfig_mapping.object_type_qconfigs + new_object_type_dict = object_type_dict.copy() + for k, v in new_object_type_dict.items(): + if k in all_qat_mappings: + object_type_dict[all_qat_mappings[k]] = v diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantize_handler.py similarity index 63% rename from torch/ao/quantization/fx/quantization_patterns.py rename to torch/ao/quantization/fx/quantize_handler.py index c24adb9e11e90..8670eee3ed776 100644 --- a/torch/ao/quantization/fx/quantization_patterns.py +++ b/torch/ao/quantization/fx/quantize_handler.py @@ -6,13 +6,19 @@ from .utils import ( all_node_args_have_no_tensors, ) +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, +) from torch.ao.quantization.utils import ( - Pattern, NodePattern, + Pattern, + QuantizerCls, ) from abc import ABC -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict, List, Type __all__ = [ "QuantizeHandler", @@ -38,7 +44,6 @@ def _default_root_node_getter(node_pattern): node_pattern = node_pattern[-1] return node_pattern -# TODO: move to backend_config_utils.py # Base Pattern Handler class QuantizeHandler(ABC): """ Base handler class for the quantizer patterns @@ -98,25 +103,70 @@ def is_general_tensor_value_op(self) -> bool: """ return False - def get_activation_ctr( - self, - qconfig: Any, - pattern: Pattern, - is_training: bool, - ) -> Optional[Callable]: - """ - Returns the constructor for the activation observer which should be - used for the pattern matched to this handler. Some handlers override - this to a different value than what is specified in the qconfig. - """ - return qconfig.activation - def is_custom_module(self): return self.is_custom_module_ def is_standalone_module(self): return self.is_standalone_module_ +def _get_quantize_handler_cls( + observation_type: ObservationType, + dtype_configs: List[DTypeConfig], + num_tensor_args_to_observation_type: Dict[int, ObservationType], + input_output_observed: bool) -> Type[QuantizeHandler]: + """ + Return a configurable QuantizeHandler that matches the given specifications from the backend. + """ + + class ConfigurableQuantizeHandler(QuantizeHandler): + def __init__( + self, + node_pattern: NodePattern, + modules: Dict[str, torch.nn.Module], + root_node_getter: Callable = None): + super().__init__(node_pattern, modules, root_node_getter) + if num_tensor_args_to_observation_type: + assert self.num_tensor_args in num_tensor_args_to_observation_type, \ + f"Must provide observation_type config for tensor number {self.num_tensor_args}" \ + f" in num_tensor_args_to_observation_type for {node_pattern}" + self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args] + else: + self.observation_type = observation_type + self.dtype_configs = dtype_configs + self.input_output_observed_ = input_output_observed + + def is_general_tensor_value_op(self) -> bool: + return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + + # This is temporary, and will be removed soon + def input_output_observed(self): + return self.input_output_observed_ + + return ConfigurableQuantizeHandler + +def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]: + """ + Note: Quantize handler is just a holder for some check methods like + (should_insert_observer_for_output), maybe this can be a enum as well, + we can refactor this after we convert the path for fbgemm/qnnpack fully to the + new path, this is not exposed to backend developers + """ + pattern_to_quantize_handlers = {} + for pattern, config in backend_config.configs.items(): + observation_type = config.observation_type + dtype_configs = config.dtype_configs + num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type + input_output_observed = config._input_output_observed + if input_output_observed is None: + input_output_observed = True + pattern_to_quantize_handlers[pattern] = \ + _get_quantize_handler_cls( + observation_type, + dtype_configs, + num_tensor_args_to_observation_type, + input_output_observed) + return pattern_to_quantize_handlers + # TODO: remove this class, this is still exposed in torch.quantization # but we should be able to break bc class BinaryOpQuantizeHandler(QuantizeHandler): diff --git a/torch/ao/quantization/fx/tracer.py b/torch/ao/quantization/fx/tracer.py index 3a959447cfd6b..1ac98a13c548e 100644 --- a/torch/ao/quantization/fx/tracer.py +++ b/torch/ao/quantization/fx/tracer.py @@ -1,67 +1,13 @@ import torch from torch.fx._symbolic_trace import Tracer -from torch.fx.node import Target, Node, Argument +from torch.fx.proxy import Scope from torch.nn.intrinsic import _FusedModule -from typing import List, Callable, Tuple, Any, Dict, Optional +from typing import List, Callable __all__ = [ "QuantizationTracer", ] -class Scope(object): - """ Scope object that records the module path and the module type - of a module. Scope is used to track the information of the module - that contains a Node in a Graph of GraphModule. For example:: - - class Sub(torch.nn.Module): - def forward(self, x): - # This will be a call_method Node in GraphModule, - # scope for this would be (module_path="sub", module_type=Sub) - return x.transpose(1, 2) - - class M(torch.nn.Module): - def __init__(self): - self.sub = Sub() - - def forward(self, x): - # This will be a call_method Node as well, - # scope for this would be (module_path="", None) - x = x.transpose(1, 2) - x = self.sub(x) - return x - - """ - - def __init__(self, module_path: str, module_type: Any): - super().__init__() - self.module_path = module_path - self.module_type = module_type - - -class ScopeContextManager(object): - """ A context manager to track the Scope of Node during symbolic tracing. - When entering a forward function of a Module, we'll update the scope information of - the current module, and when we exit, we'll restore the previous scope information. - """ - - def __init__( - self, scope: Scope, current_module: torch.nn.Module, current_module_path: str - ): - super().__init__() - self.prev_module_type = scope.module_type - self.prev_module_path = scope.module_path - self.scope = scope - self.scope.module_path = current_module_path - self.scope.module_type = type(current_module) - - def __enter__(self): - return - - def __exit__(self, *args): - self.scope.module_path = self.prev_module_path - self.scope.module_type = self.prev_module_type - return - class QuantizationTracer(Tracer): def __init__( self, skipped_module_names: List[str], skipped_module_classes: List[Callable] @@ -75,7 +21,6 @@ def __init__( # We can change this if there is a use case that configures # qconfig using top level module type self.scope = Scope("", None) - self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} self.record_stack_traces = True def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: @@ -88,32 +33,3 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool or type(m) in self.skipped_module_classes or isinstance(m, _FusedModule) ) - - def call_module( - self, - m: torch.nn.Module, - forward: Callable[..., Any], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - ) -> Any: - module_qualified_name = self.path_of_module(m) - # Creating scope with information of current module - # scope will be restored automatically upon exit - with ScopeContextManager(self.scope, m, module_qualified_name): - return super().call_module(m, forward, args, kwargs) - - def create_node( - self, - kind: str, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - ) -> Node: - node = super().create_node(kind, target, args, kwargs, name, type_expr) - self.node_name_to_scope[node.name] = ( - self.scope.module_path, - self.scope.module_type, - ) - return node diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index f359bd90f9e61..242e18935aa2a 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -1,5 +1,4 @@ import copy -import re import torch import torch.nn as nn from torch.ao.quantization import ( @@ -10,14 +9,21 @@ BackendConfig, DTypeWithConstraints, ) -from torch.ao.quantization.fake_quantize import FakeQuantizeBase -from torch.ao.quantization.observer import ObserverBase -from torch.ao.quantization.stubs import DeQuantStub -from torch.ao.quantization.utils import ( - activation_is_statically_quantized, - is_per_tensor, - is_per_channel, +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + ObserverBase, ) +from torch.ao.quantization.qconfig import ( + float16_static_qconfig, + float16_dynamic_qconfig, + qconfig_equals, +) +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import activation_is_statically_quantized from torch.ao.quantization.quantize import is_activation_post_process from torch.fx import GraphModule, map_arg @@ -27,6 +33,8 @@ Node, ) from .custom_config import PrepareCustomConfig +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type from collections import namedtuple @@ -41,28 +49,20 @@ "collect_producer_nodes", "create_getattr_from_value", "create_node_from_old_node_preserve_meta", - "create_qparam_nodes", "EMPTY_ARG_DICT", "get_custom_module_class_keys", "get_linear_prepack_op_for_dtype", "get_new_attr_name_with_prefix", "get_non_observable_arg_indexes_and_types", - "get_per_tensor_qparams", - "get_qconv_op", "get_qconv_prepack_op", - "get_quantize_node_info", "get_skipped_module_name_and_classes", "graph_module_from_producer_nodes", - "graph_pretty_str", - "is_get_tensor_info_node", "maybe_get_next_module", "NodeInfo", - "node_return_type_is_int", "node_arg_is_bias", "node_arg_is_weight", "NON_OBSERVABLE_ARG_DICT", "NON_QUANTIZABLE_WEIGHT_OPS", - "quantize_node", "return_arg_list", ] @@ -86,183 +86,6 @@ def node_arg_is_bias(node: Node, arg: Any, backend_config: BackendConfig) -> boo return node.kwargs.get("bias") is arg return False -def graph_pretty_str(g, shorten=True) -> str: - """Returns a printable representation of the ops in the graph of g. - If shorten is True, tries to abbreviate fields. - """ - built_in_func_re = re.compile('') - built_in_meth_re = re.compile('') - op_dict = { - 'placeholder': 'plchdr', - 'get_attr': 'gt_prm', - 'call_function': 'cl_fun', - 'call_module': 'cl_mod', - 'call_method': 'cl_meth', - } - - max_lens = {} - col_names = ("name", "op", "target", "args", "kwargs") - for s in col_names: - max_lens[s] = len(s) - - results = [] - for n in g.nodes: - - # activation_post_process_0 -> obs_0 - name = str(n.name) - if shorten: - name = name.replace("activation_post_process", "obs") - - op = str(n.op) - # placeholder -> plchdr, and so on - if shorten and op in op_dict: - op = op_dict[op] - - target = str(n.target) - # -> , and so on - if shorten: - built_in_func = built_in_func_re.search(target) - if built_in_func: - target = f"" - built_in_meth = built_in_meth_re.search(target) - if built_in_meth: - target = f"" - target = target.replace("activation_post_process", "obs") - - args = str(n.args) - if shorten: - args = args.replace("activation_post_process", "obs") - - kwargs = str(n.kwargs) - - # calculate maximum length of each column, so we can tabulate properly - for k, v in zip(col_names, (name, op, target, args, kwargs)): - max_lens[k] = max(max_lens[k], len(v)) - results.append([name, op, target, args, kwargs]) - - res_str = "" - format_str = "{:<{name}} {:<{op}} {:<{target}} {:<{args}} {:<{kwargs}}\n" - res_str += format_str.format(*col_names, **max_lens) - for result in results: - res_str += format_str.format(*result, **max_lens) - - # print an exra note on abbreviations which change attribute names, - # since users will have to un-abbreviate for further debugging - if shorten: - res_str += "*obs_{n} = activation_post_process_{n}\n" - return res_str - -def get_per_tensor_qparams(activation_post_process): - assert is_per_tensor(activation_post_process.qscheme), 'Only per tensor quantization is supported' - scale, zero_point = activation_post_process.calculate_qparams() - scale = float(scale) - zero_point = int(zero_point) - dtype = activation_post_process.dtype - return scale, zero_point, dtype - -def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]: - ''' Given an activation_post_process module, - return node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary - of extracted qparams from the module - ''' - dtype = activation_post_process.dtype # type: ignore[attr-defined] - compute_dtype = None - if hasattr(activation_post_process, "compute_dtype"): - compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] - quantize_op : Optional[Union[Callable, str]] = None - if dtype in [torch.quint8, torch.qint8] and \ - not hasattr(activation_post_process, 'compute_dtype'): - node_type = "call_function" - scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] - if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] - ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined] - qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype} - quantize_op = torch.quantize_per_channel - else: - scale = float(scale) - zero_point = int(zero_point) - qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} - quantize_op = torch.quantize_per_tensor - elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]: - # TODO(future PR): switch compute_dtype to is_dynamic - # dynamic quantization - node_type = "call_function" - quantize_op = torch.quantize_per_tensor_dynamic - # TODO: get reduce range from observer - # reduce_range = activation_post_process.reduce_range - reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") - qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range} - elif dtype == torch.float16: - node_type = "call_method" - quantize_op = "to" - qparams = {"_dtype_": dtype} - else: - warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}") - return None - return node_type, quantize_op, qparams - -def quantize_node( - in_node: Node, - obs_module: torch.nn.Module, - obs_node: Node, - modules: Dict[str, torch.nn.Module], - quantized_graph: Graph, - node_name_to_scope: Dict[str, Tuple[str, type]], - is_input: bool, - output_prefix: str = "_output") -> Node: - ''' Add quantization nodes (eg. quantize_per_tensor/per_channel) for given node to graph - with the qparams calculated from activation_post_process (obs_module). - The observer node (obs_node) is used to find the FQN of the user of act_post_process. - e.g. Given input `node` in `node = self.conv(x)`, insert node: - `quantized_node = torch.quantize_per_tensor(x, self._scale_0, self._zer_point_0, self._dtype_0)` - where self._scale_0, self._zero_point_0 and self._dtype_0 are - calculated from `obs_module` - ''' - # Find the first use of the observer node, we use this to get the scope of the module. - if is_input: - # if the quantize function is at the input of op, then we find the first user of the observer_node - # to get the path. If a linear call_function is in the user list, we return the first instance - # of linear node to get the FQN. - users = list(obs_node.users) - first_linear_use_or_first_use = users[0] if users else None - linear_node = None - for n in users: - if n.op == "call_function" and n.target == torch.nn.functional.linear: - linear_node = n - break - if linear_node: - first_linear_use_or_first_use = linear_node - prefix = "_input" - else: - # if the quantize function is at the output of the op, we use the observer input node to get the path - first_linear_use_or_first_use = in_node - prefix = output_prefix - - if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope: - module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] - else: - # TODO: it's not used, so actually we can skip quantization - # but this requires changing return type of quantize_node - # we can fix it later if needed - module_path = "" - root_module = modules[''] - graph = quantized_graph - maybe_quantize_node_info = get_quantize_node_info(obs_module) - assert maybe_quantize_node_info is not None, \ - f"Expecting quantize node info not to be None, observer: {obs_module}" - node_type, quantize_op, qparams = maybe_quantize_node_info - inputs = [in_node] - - for key, value in qparams.items(): - if key in ['_scale_', '_zero_point_']: - # For scale and zero_point values we register them as buffers in the root module. - qparam_node = create_getattr_from_value(root_module, graph, module_path + prefix + key, value) - inputs.append(qparam_node) - else: - # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. - inputs.append(value) - return graph.create_node(node_type, quantize_op, tuple(inputs), {}) - def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]: r""" Get all the unique custom module keys in the custom config dict e.g. @@ -309,24 +132,6 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable: assert prepack_op, "Didn't find prepack op for {}".format(conv_op) return prepack_op -def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable: - qconv_op = { - # has relu - True: { - torch.nn.functional.conv1d: torch.ops.quantized.conv1d_relu, - torch.nn.functional.conv2d: torch.ops.quantized.conv2d_relu, - torch.nn.functional.conv3d: torch.ops.quantized.conv3d_relu - }, - False: { - torch.nn.functional.conv1d: torch.ops.quantized.conv1d, - torch.nn.functional.conv2d: torch.ops.quantized.conv2d, - torch.nn.functional.conv3d: torch.ops.quantized.conv3d - } - } - qconv = qconv_op[has_relu].get(conv_op) - assert qconv, "Can't find corresponding quantized conv op for {} {}".format(conv_op, has_relu) - return qconv - # Returns a function that can get a new attribute name for module with given # prefix, for example, # >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') @@ -426,25 +231,6 @@ def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str attr_node = graph.create_node("get_attr", attr_name) return attr_node -def create_qparam_nodes( - node_name: str, - scale: Any, - zero_point: Any, - modules: Dict[str, torch.nn.Module], - quantized_graph: Graph, - node_name_to_scope: Dict[str, Tuple[str, type]] -) -> Tuple[Node, Node]: - """ - Create getattr nodes in the quantized graph for scale and zero point values. - The nodes are registered with the root_module of the model. - """ - root_module = modules[''] - module_path, _ = node_name_to_scope[node_name] - scale_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_scale_"), scale) - zero_point_node = create_getattr_from_value(root_module, quantized_graph, (module_path + "_zero_point_"), zero_point) - return (scale_node, zero_point_node) - - def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool: """ If we know for sure that all of this node's args have no @@ -589,22 +375,6 @@ def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, tor return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT) -def node_return_type_is_int(node: Node) -> bool: - """ - Returns true if this node results in an integer, even if some of the args - are Tensors. - """ - return node.op == 'call_method' and node.target == 'size' - - -def is_get_tensor_info_node(node: Node) -> bool: - """ Returns True if this node is a node that takes a Tensor as input and output some - meta information about the Tensor, e.g. shape, size etc. - """ - result: bool = \ - node.op == "call_function" and node.target == getattr and node.args[1] == "shape" # type: ignore[assignment] - return result - def maybe_get_next_module( node: Node, modules: Dict[str, nn.Module], @@ -667,7 +437,7 @@ def _is_custom_module_lstm( """ mod = _get_module(node, named_modules) if qconfig is not None and qhandler is not None: - assert isinstance(qhandler, torch.ao.quantization.fx.quantization_patterns.QuantizeHandler) # type: ignore[attr-defined] + assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined] return isinstance(mod, torch.nn.LSTM) and \ activation_is_statically_quantized(qconfig) and \ qhandler.is_custom_module() @@ -977,10 +747,13 @@ def _qconfig_satisfies_dtype_config_constraints( 1. QConfig specified a quantization range that falls within the backend's, if any 2. QConfig specified a min scale value that is >= the backend's, if any + 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has + scale and zero point that match the backend's, if any If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`. If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True. """ + # TODO: log warnings only when the user enabled a debug flag def _activation_post_process_satisfies_dtype_config_constraints( activation_post_process: Union[ObserverBase, FakeQuantizeBase], dtype_with_constraints: DTypeWithConstraints, @@ -994,6 +767,8 @@ def _activation_post_process_satisfies_dtype_config_constraints( backend_quant_min = dtype_with_constraints.quant_min_lower_bound backend_quant_max = dtype_with_constraints.quant_max_upper_bound backend_scale_min = dtype_with_constraints.scale_min_lower_bound + backend_scale_exact_match = dtype_with_constraints.scale_exact_match + backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match # check quantization ranges if backend_quant_min is not None and backend_quant_max is not None: if app_quant_min is None or app_quant_max is None: @@ -1016,6 +791,30 @@ def _activation_post_process_satisfies_dtype_config_constraints( "the backend's min scale value (%s), ignoring %s") % (debug_string, app_scale_min, backend_scale_min, qconfig)) return False + # check fixed scale and zero point + if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None: + # For tests only, accept the following qconfigs for now + # TODO: handle fp16 qconfigs properly + for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]: + if qconfig_equals(qconfig, accepted_qconfig): + return True + suggestion_str = ( + "Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n" + " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n" + " model = prepare_fx(model, qconfig_mapping, example_inputs)" + ) + if not isinstance(activation_post_process, FixedQParamsObserver) and \ + not isinstance(activation_post_process, FixedQParamsFakeQuantize): + warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize " + "for fixed qparams ops, ignoring %s.\n%s") % (qconfig, suggestion_str)) + return False + if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match: + warnings.warn(("QConfig fixed scale (%s) and zero point (%s) do not match the backend's " + "(%s and %s), ignoring %s.\n%s") % + (observer.scale, observer.zero_point, backend_scale_exact_match, + backend_zero_point_exact_match, qconfig, suggestion_str)) + return False return True if qconfig is None or dtype_with_constraints.dtype is None: diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index e704444d0a6dc..f8683024cee52 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1019,7 +1019,7 @@ def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in caffe2/quantization/server/norm_minimization.cc """ - assert self.histogram.size()[0] == self.bins, "bins mistmatch" + assert self.histogram.size()[0] == self.bins, "bins mismatch" bin_width = (self.max_val - self.min_val) / self.bins # cumulative sum @@ -1316,26 +1316,38 @@ class PlaceholderObserver(ObserverBase): Args: dtype: dtype argument to the `quantize` node needed to implement the reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_min: maximum value in quantized domain custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation (Can be used in Graph Mode Passes for special case ops). - compute_dtype: if set, marks the future quantize function to use + compute_dtype (deprecated): if set, marks the future quantize function to use dynamic quantization instead of static quantization. - Note: this field will be removed in the near future and - replaced with `is_dynamic`. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. """ def __init__( - self, dtype=torch.float32, custom_op_name="", compute_dtype=None + self, dtype=torch.float32, custom_op_name="", compute_dtype=None, + quant_min=None, quant_max=None, is_dynamic=False, ) -> None: - super(PlaceholderObserver, self).__init__(dtype=dtype) + super().__init__(dtype=dtype) # dtype of input of the target operator, e.g. for dynamic quantization # ops, the dtype will be float32 self.dtype = dtype + self.quant_min = quant_min + self.quant_max = quant_max self.custom_op = custom_op_name # used for configuration of computation type for dynamic quantization - # TODO(future PR): replace this with `is_dynamic` if compute_dtype: - self.compute_dtype = compute_dtype + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch." + ) + self.is_dynamic = is_dynamic def forward(self, x): return x @@ -1551,7 +1563,7 @@ def load_observer_state_dict(mod, obs_dict): """ default_dynamic_quant_observer = PlaceholderObserver.with_args( - dtype=torch.quint8, compute_dtype=torch.quint8 + dtype=torch.quint8, quant_min=0, quant_max=255, is_dynamic=True, ) """ Default observer for dynamic quantization. diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 8e662e5745ce6..2dec48498aa58 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -72,13 +72,8 @@ "get_default_qat_qconfig", "get_default_qconfig_dict", "get_default_qat_qconfig_dict", - "assert_valid_qconfig", - "add_module_to_qconfig_obs_ctr", "QConfigAny", - "obs_or_fq_ctr_equals", "qconfig_equals", - "activation_is_memoryless", - "is_reuse_input_qconfig", ] class QConfig(namedtuple('QConfig', ['activation', 'weight'])): @@ -157,7 +152,7 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): Default dynamic qconfig. """ -float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, compute_dtype=torch.float16), +float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), weight=PlaceholderObserver.with_args(dtype=torch.float16)) """ Dynamic qconfig with weights quantized to `torch.float16`. @@ -223,17 +218,24 @@ def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape """ -def get_default_qconfig(backend='fbgemm', version=0): +def get_default_qconfig(backend='x86', version=0): """ Returns the default PTQ qconfig for the specified backend. Args: - * `backend`: a string representing the target backend. Currently supports - `x86`, `fbgemm` (default), `qnnpack` and `onednn`. + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. Return: qconfig """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + str(backend) + + " not supported. backend must be one of {}".format(supported_backends) + ) + if version == 0: if backend == 'fbgemm': qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True), @@ -249,6 +251,7 @@ def get_default_qconfig(backend='fbgemm', version=0): qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True), weight=default_per_channel_weight_observer) else: + # won't reach qconfig = default_qconfig else: raise AssertionError("Version number: " + str(version) + @@ -298,18 +301,25 @@ def get_default_qconfig(backend='fbgemm', version=0): default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32), weight=default_embedding_fake_quant_4bit) -def get_default_qat_qconfig(backend='fbgemm', version=1): +def get_default_qat_qconfig(backend='x86', version=1): """ Returns the default QAT qconfig for the specified backend. Args: - * `backend`: a string representing the target backend. Currently supports - `x86`, `fbgemm` (default), `qnnpack` and `onednn`. + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. * `version`: version, for backwards compatibility. Can be `None` or `1`. Return: qconfig """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + str(backend) + + " not supported. backend must be one of {}".format(supported_backends) + ) + # Histogram observer is too slow for quantization aware training if version == 0: if backend == 'fbgemm': @@ -329,7 +339,7 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_min=0, quant_max=255), weight=default_per_channel_weight_fake_quant) - if backend == 'x86': + elif backend == 'x86': qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, @@ -392,20 +402,20 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): eps=2 ** -12), weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127) -def get_default_qconfig_dict(backend='fbgemm', version=0): +def get_default_qconfig_dict(backend='x86', version=0): warnings.warn( "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in " "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.") return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() -def get_default_qat_qconfig_dict(backend='fbgemm', version=1): +def get_default_qat_qconfig_dict(backend='x86', version=1): warnings.warn( "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in " "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.") return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict() -def assert_valid_qconfig(qconfig: Optional[QConfig], - mod: torch.nn.Module) -> None: +def _assert_valid_qconfig(qconfig: Optional[QConfig], + mod: torch.nn.Module) -> None: """ Verifies that this `qconfig` is valid. """ @@ -427,11 +437,10 @@ def assert_valid_qconfig(qconfig: Optional[QConfig], assert not is_per_channel, \ 'Per channel weight observer is not supported yet for ConvTranspose{n}d.' -# TODO: remove QConfigAny and replace it with Optional[QConfig] QConfigAny = Optional[QConfig] QConfigAny.__module__ = "torch.ao.quantization.qconfig" -def add_module_to_qconfig_obs_ctr( +def _add_module_to_qconfig_obs_ctr( qconfig: QConfigAny, module: Optional[nn.Module]) -> Any: r"""This is a helper function for use in quantization prepare that updates a qconfig so that @@ -475,7 +484,7 @@ def configure_constructor_to_put_obs_on_module_device(original_constructor): _ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, ObserverBase, FakeQuantizeBase] -def obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor): +def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor): if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper): return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) return obs_or_fq1 == obs_or_fq2 @@ -488,9 +497,9 @@ def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWra obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) keywords_equal = True - # compare observer constructor with obs_or_fq_ctr_equals since direct compare would fail + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: - keywords_equal = keywords_equal and obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"]) + keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"]) obs_or_fq1_keywords.pop("observer") obs_or_fq2_keywords.pop("observer") keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords @@ -508,13 +517,13 @@ def qconfig_equals(q1: QConfigAny, q2: QConfigAny): # Qconfig weight and activation can be either a partial wrapper, # or an observer class. Special handling is required (above) for # comparing partial wrappers. - activation_same = obs_or_fq_ctr_equals(q1.activation, q2.activation) - weight_same = obs_or_fq_ctr_equals(q1.weight, q2.weight) + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) return activation_same and weight_same except AttributeError: return q1 == q2 -def activation_is_memoryless(qconfig: QConfig): +def _activation_is_memoryless(qconfig: QConfig): """ Return whether the observer for activations defined in the given QConfig is memoryless. This means a MovingAverage observer with averaging constant equal to 1. @@ -527,7 +536,7 @@ def _is_memoryless(observer): else: return _is_memoryless(act) -def is_reuse_input_qconfig(qconfig: Optional[QConfig]): +def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): return qconfig is not None and \ isinstance(qconfig.activation(), ReuseInputObserver) and \ isinstance(qconfig.weight(), NoopObserver) diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 4dc4431aa99d1..69b86b0186181 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -1,6 +1,6 @@ from __future__ import annotations from collections import OrderedDict -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Tuple, Union, List import torch @@ -33,12 +33,13 @@ # TODO: replace all usages with these constants -GLOBAL_DICT_KEY = "" -OBJECT_TYPE_DICT_KEY = "object_type" -MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" -MODULE_NAME_DICT_KEY = "module_name" -MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" +# TODO: derive this map from the BackendConfig _FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = { torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, @@ -82,26 +83,12 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC qconfig_mapping = QConfigMapping() \ .set_global(qconfig) \ .set_object_type("reshape", default_reuse_input_qconfig) \ - .set_object_type(torch.nn.Conv1d, qconfig) \ - .set_object_type(torch.nn.Conv2d, qconfig) \ - .set_object_type(torch.nn.Conv3d, qconfig) \ .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \ .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \ .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \ - .set_object_type(torch.nn.Linear, qconfig) \ - .set_object_type(torch.nn.functional.conv1d, qconfig) \ - .set_object_type(torch.nn.functional.conv2d, qconfig) \ - .set_object_type(torch.nn.functional.conv3d, qconfig) \ .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \ .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \ .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \ - .set_object_type(torch.nn.functional.linear, qconfig) \ - .set_object_type(torch.nn.ReLU, qconfig) \ - .set_object_type(torch.nn.functional.relu, qconfig) \ - .set_object_type(torch.relu, qconfig) \ - .set_object_type(torch.nn.BatchNorm1d, qconfig) \ - .set_object_type(torch.nn.BatchNorm2d, qconfig) \ - .set_object_type(torch.nn.BatchNorm3d, qconfig) \ .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \ .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \ @@ -121,26 +108,26 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC return qconfig_mapping -def get_default_qconfig_mapping(backend="fbgemm", version=0) -> QConfigMapping: +def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping: """ Return the default QConfigMapping for post training quantization. Args: - * ``backend`` : the quantization backend for the default qconfig mapping, should be - one of ["x86", "fbgemm" (default), "qnnpack", "onednn"] - * ``version`` : the version for the default qconfig mapping + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping """ # TODO: add assert for backend choices return _get_default_qconfig_mapping(False, backend, version) -def get_default_qat_qconfig_mapping(backend="fbgemm", version=1) -> QConfigMapping: +def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping: """ Return the default QConfigMapping for quantization aware training. Args: - * ``backend`` : the quantization backend for the default qconfig mapping, should be - one of ["x86", "fbgemm" (default), "qnnpack", "onednn"] - * ``version`` : the version for the default qconfig mapping + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping """ return _get_default_qconfig_mapping(True, backend, version) @@ -156,6 +143,14 @@ def _get_symmetric_qnnpack_qconfig_mapping(): qconfig_mapping.set_object_type(pattern, default_symmetric_qnnpack_qconfig) return qconfig_mapping +_QCONFIG_STYLE_ORDER: List[str] = [ + "global_qconfig", + "object_type_qconfigs", + "module_name_regex_qconfigs", + "module_name_qconfigs", + "module_name_object_type_order_qconfigs", +] + class QConfigMapping: """ Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. @@ -256,6 +251,18 @@ def set_module_name_object_type_order( self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig return self + def __repr__(self) -> str: + output = self.__class__.__name__ + " (" + for style_name in _QCONFIG_STYLE_ORDER: + output += f"\n {style_name}" + qconfigs = getattr(self, style_name) + if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0: + for key, qconfig in qconfigs.items(): + output += f"\n {key}: {qconfig}" + else: + output += f"\n {qconfigs}" + return output + "\n)" + # TODO: remove this def to_dict(self) -> Dict[str, Any]: """ @@ -274,11 +281,11 @@ def to_dict(self) -> Dict[str, Any]: The values of this dictionary are lists of tuples. """ return { - GLOBAL_DICT_KEY: self.global_qconfig, - OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), - MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), - MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() ], } @@ -302,14 +309,14 @@ def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping: The values of this dictionary are expected to be lists of tuples. """ conf = cls() - if GLOBAL_DICT_KEY in qconfig_dict: - conf.set_global(qconfig_dict[GLOBAL_DICT_KEY]) - for object_type, qconfig in qconfig_dict.get(OBJECT_TYPE_DICT_KEY, []): + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): conf.set_object_type(object_type, qconfig) - for module_name_regex, qconfig in qconfig_dict.get(MODULE_NAME_REGEX_DICT_KEY, []): + for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []): conf.set_module_name_regex(module_name_regex, qconfig) - for module_name, qconfig in qconfig_dict.get(MODULE_NAME_DICT_KEY, []): + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): conf.set_module_name(module_name, qconfig) - for module_name, object_type, index, qconfig in qconfig_dict.get(MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): + for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) return conf diff --git a/torch/ao/quantization/qconfig_mapping_utils.py b/torch/ao/quantization/qconfig_mapping_utils.py deleted file mode 100644 index 09bce4fbebb09..0000000000000 --- a/torch/ao/quantization/qconfig_mapping_utils.py +++ /dev/null @@ -1,110 +0,0 @@ -import re -from typing import Dict, Callable, Union - -from .utils import ( - get_combined_dict, - _parent_name, -) -from .quantization_mappings import ( - get_default_qat_module_mappings, -) -from .qconfig import QConfigAny -from .qconfig_mapping import QConfigMapping - - -# TODO: revisit this list. Many helper methods shouldn't be public -__all__ = [ - "get_flattened_qconfig_dict", - "get_object_type_qconfig", - "get_module_name_qconfig", - "get_module_name_regex_qconfig", - "maybe_adjust_qconfig_for_module_type_or_name", - "update_qconfig_for_qat", -] - - -def get_object_type_qconfig( - qconfig_mapping: QConfigMapping, - object_type: Union[Callable, str], - fallback_qconfig: QConfigAny) -> QConfigAny: - return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) - - -def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): - for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): - if re.match(regex_pattern, module_name): - # first match wins - return qconfig - return fallback_qconfig - - -def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): - if module_name == '': - # module name qconfig not found - return fallback_qconfig - if module_name in qconfig_mapping.module_name_qconfigs: - return qconfig_mapping.module_name_qconfigs[module_name] - else: - parent, _ = _parent_name(module_name) - return get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) - - -def maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig): - # get qconfig for module_name, - # fallback to module_name_regex_qconfig, module_type_qconfig, - # global_qconfig if necessary - module_type_qconfig = get_object_type_qconfig( - qconfig_mapping, module_type, global_qconfig) - module_name_regex_qconfig = get_module_name_regex_qconfig( - qconfig_mapping, module_name, module_type_qconfig) - module_name_qconfig = get_module_name_qconfig( - qconfig_mapping, module_name, module_name_regex_qconfig) - return module_name_qconfig - - -def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]: - """ flatten the global, object_type and module_name qconfig - to the same qconfig_dict so that it can be used by - propagate_qconfig_ function. - "module_name_regex" is ignored for now since it's not supported - in propagate_qconfig_, but it can be fixed later. - - For example: - Input: { - "": qconfig, - "object_type": [ - (torch.add, qconfig) - ], - "module_name": [ - ("conv", qconfig) - ] - } - - Output: { - "": qconfig, - torch.add: qconfig, - "conv": qconfig - } - """ - flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig} - for obj, qconfig in qconfig_mapping.object_type_qconfigs.items(): - flattened[obj] = qconfig - for obj, qconfig in qconfig_mapping.module_name_qconfigs.items(): - flattened[obj] = qconfig - return flattened - - -def update_qconfig_for_qat( - qconfig_mapping: QConfigMapping, - additional_qat_module_mapping: Dict[Callable, Callable]): - """ - Update the qconfig_dict to account for module swaps during QAT. - During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. - """ - all_qat_mappings = get_combined_dict( - get_default_qat_module_mappings(), additional_qat_module_mapping) - object_type_dict = qconfig_mapping.object_type_qconfigs - new_object_type_dict = object_type_dict.copy() - for k, v in new_object_type_dict.items(): - if k in all_qat_mappings: - object_type_dict[all_qat_mappings[k]] = v diff --git a/torch/ao/quantization/quant_type.py b/torch/ao/quantization/quant_type.py index 9d2a3a2bdc7b2..d3b1d034a1feb 100644 --- a/torch/ao/quantization/quant_type.py +++ b/torch/ao/quantization/quant_type.py @@ -2,7 +2,6 @@ __all__ = [ "QuantType", - "quant_type_to_str", ] # Quantization type (dynamic quantization, static quantization). @@ -21,7 +20,7 @@ class QuantType(enum.IntEnum): } # TODO: make this private -def quant_type_to_str(quant_type: QuantType) -> str: +def _get_quant_type_to_str(quant_type: QuantType) -> str: return _quant_type_to_str[quant_type] def _quant_type_from_str(name: str) -> QuantType: diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 69711aa370605..d18f93987465c 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -20,12 +20,12 @@ from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper from torch.ao.quantization.qconfig import ( - add_module_to_qconfig_obs_ctr, + _add_module_to_qconfig_obs_ctr, default_dynamic_qconfig, float16_dynamic_qconfig, float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit, - activation_is_memoryless) + _activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations __all__ = [ @@ -91,9 +91,9 @@ def _propagate_qconfig_helper(module, qconfig_dict, module_qconfig = qconfig_dict.get(prefix, module_qconfig) module_qconfig = getattr(module, 'qconfig', module_qconfig) - torch.ao.quantization.qconfig.assert_valid_qconfig(module_qconfig, module) + torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) - qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module) module.qconfig = qconfig_with_device_check for name, child in module.named_children(): @@ -143,11 +143,13 @@ def register_activation_post_process_hook(module, pre_hook=False): assert hasattr(module, 'activation_post_process'), \ 'Expect activation_post_process attribute already attached to the module' if pre_hook: - handle = module.register_forward_pre_hook(_observer_forward_pre_hook) - module._forward_pre_hooks.move_to_end(handle.id, last=False) + handle = module.register_forward_pre_hook( + _observer_forward_pre_hook, prepend=True + ) else: - handle = module.register_forward_hook(_observer_forward_hook) - module._forward_hooks.move_to_end(handle.id, last=False) + handle = module.register_forward_hook( + _observer_forward_hook, prepend=True + ) def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None): @@ -201,7 +203,7 @@ def insert_activation_post_process(m, special_act_post_process=None): m.qconfig, device, special_act_post_process)) # Register observer as the first entry in the hook list # All post forward hooks are preserved and will be executed after the observer before convert - register_activation_post_process_hook(m, pre_hook=activation_is_memoryless(m.qconfig)) + register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig)) for name, child in module.named_children(): # TODO remove Dropout special after codebase stable @@ -214,12 +216,12 @@ def insert_activation_post_process(m, special_act_post_process=None): # activation_post_process are now added directly to nn.Sequentail/_FusedModule if needs_observation(child): insert_activation_post_process(child) - elif _has_special_act_post_process(child): - special_act_post_process = _get_special_act_post_process(child) - insert_activation_post_process(child, special_act_post_process) elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list: if needs_observation(child): insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping: observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child) setattr(module, name, observed_child) diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index fb6f3dc1fe574..2bcd2e4ca7125 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -17,7 +17,6 @@ FuseCustomConfig, PrepareCustomConfig, ) -from .fx.utils import graph_pretty_str # noqa: F401 from .fx.utils import get_custom_module_class_keys # noqa: F401 from .fx.utils import get_skipped_module_name_and_classes from .qconfig_mapping import QConfigMapping @@ -64,61 +63,6 @@ def _fuse_fx( graph_module, is_qat, fuse_custom_config, backend_config) # type: ignore[operator] -class Scope(object): - """ Scope object that records the module path and the module type - of a module. Scope is used to track the information of the module - that contains a Node in a Graph of GraphModule. For example:: - - class Sub(torch.nn.Module): - def forward(self, x): - # This will be a call_method Node in GraphModule, - # scope for this would be (module_path="sub", module_type=Sub) - return x.transpose(1, 2) - - class M(torch.nn.Module): - def __init__(self): - self.sub = Sub() - - def forward(self, x): - # This will be a call_method Node as well, - # scope for this would be (module_path="", None) - x = x.transpose(1, 2) - x = self.sub(x) - return x - - """ - - def __init__(self, module_path: str, module_type: Any): - super().__init__() - self.module_path = module_path - self.module_type = module_type - - -class ScopeContextManager(object): - """ A context manager to track the Scope of Node during symbolic tracing. - When entering a forward function of a Module, we'll update the scope information of - the current module, and when we exit, we'll restore the previous scope information. - """ - - def __init__( - self, scope: Scope, current_module: torch.nn.Module, current_module_path: str - ): - super().__init__() - self.prev_module_type = scope.module_type - self.prev_module_path = scope.module_path - self.scope = scope - self.scope.module_path = current_module_path - self.scope.module_type = type(current_module) - - def __enter__(self): - return - - def __exit__(self, *args): - self.scope.module_path = self.prev_module_path - self.scope.module_type = self.prev_module_type - return - - def _prepare_fx( model: torch.nn.Module, qconfig_mapping: Union[QConfigMapping, Dict[str, Any]], @@ -291,7 +235,7 @@ def prepare_fx( * `_equalization_config`: config for specifying how to perform equalization on the model * `backend_config` (BackendConfig): config that specifies how operators are quantized - in a backend, this includes how the operaetors are observed, + in a backend, this includes how the operators are observed, supported fusion patterns, how quantize/dequantize ops are inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details @@ -489,7 +433,7 @@ def train_loop(model, train_data): qconfig_mapping = get_default_qat_qconfig("fbgemm") # We can customize qconfig_mapping in different ways, please take a look at - # the doctring for :func:`~torch.ao.quantization.prepare_fx` for different ways + # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways # to configure this # example_inputs is a tuple of inputs, that is used to infer the type of the @@ -530,6 +474,7 @@ def _convert_fx( _remove_qconfig: bool = True, qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, backend_config: Union[BackendConfig, Dict[str, Any], None] = None, + is_decomposed: bool = False, ) -> torch.nn.Module: """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx` """ @@ -552,6 +497,7 @@ def _convert_fx( _remove_qconfig_flag=_remove_qconfig, qconfig_mapping=qconfig_mapping, backend_config=backend_config, + is_decomposed=is_decomposed, ) preserved_attributes = convert_custom_config.preserved_attributes @@ -676,6 +622,59 @@ def convert_to_reference_fx( backend_config=backend_config, ) +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> torch.nn.Module: + r""" Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantzied model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) + def _convert_standalone_module_fx( graph_module: GraphModule, diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 47ca7e64e329a..984386d205042 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -140,6 +140,17 @@ def getattr_from_fqn(obj: Any, fqn: str) -> Any: """ return functools.reduce(getattr, fqn.split("."), obj) +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + } + assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + qdtype + return DTYPE_MAPPING[qdtype] + def get_qparam_dict(observer_or_fake_quant): qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None dtype = observer_or_fake_quant.dtype @@ -210,9 +221,9 @@ def activation_is_dynamically_quantized(qconfig): dynamically quantized or not, this includes dynamically quantizing to quint8, qint8 and float16 """ - activation_dtype, _, activation_compute_dtype = \ + activation_dtype, _, activation_is_dynamic = \ get_qconfig_dtypes(qconfig) - return activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16] + return activation_is_dynamic def activation_is_int8_quantized(qconfig): """ Given a qconfig, decide if the activation needs to be @@ -242,25 +253,24 @@ def op_is_int8_dynamically_quantized(qconfig) -> bool: """ Given a qconfig, returns True if this op is using int8 dynamic quantization """ - activation_dtype, weight_dtype, activation_compute_dtype = \ + activation_dtype, weight_dtype, activation_is_dynamic = \ get_qconfig_dtypes(qconfig) return ( activation_dtype is torch.quint8 and # for now, the lines below assume fbgemm or qnnpack weight_dtype is torch.qint8 and - activation_compute_dtype is torch.quint8 - # TODO(future PR): add is_dynamic + activation_is_dynamic ) def get_qconfig_dtypes(qconfig): r""" returns the qconfig tuple for qconfig: - (activation_dtype, weight_dtype, activation_compute_dtype) + (activation_dtype, weight_dtype, activation_is_dynamic) """ assert qconfig is not None activation = qconfig.activation() weight = qconfig.weight() - compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None - return (activation.dtype, weight.dtype, compute_dtype) + act_is_dynamic = activation.is_dynamic if hasattr(activation, 'is_dynamic') else False + return (activation.dtype, weight.dtype, act_is_dynamic) def get_quant_type(qconfig): assert qconfig is not None @@ -268,7 +278,7 @@ def get_quant_type(qconfig): weight = qconfig.weight() static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32] if weight.dtype in static_dtypes: - if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: + if hasattr(activation, 'is_dynamic') and activation.is_dynamic: return QuantType.DYNAMIC elif activation.dtype in static_dtypes: return QuantType.STATIC @@ -276,7 +286,7 @@ def get_quant_type(qconfig): return QuantType.WEIGHT_ONLY if weight.dtype == torch.float16: - if hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: + if hasattr(activation, 'is_dynamic') and activation.is_dynamic: return QuantType.DYNAMIC elif activation.dtype == torch.float16: return QuantType.STATIC @@ -535,6 +545,93 @@ def _patched_module_call(self, *args, **kwargs): torch.nn.Module.__call__ = orig_module_call return fqn_to_example_inputs +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + # Use Callable instead of _PartialWrapper here to avoid circular dependencies + linear_output_obs_ctr: Optional[Callable] = None, + sigmoid_obs_ctr: Optional[Callable] = None, + tanh_obs_ctr: Optional[Callable] = None, + cell_state_obs_ctr: Optional[Callable] = None, + hidden_state_obs_ctr: Optional[Callable] = None, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + Args: + `float_lstm`: The float LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + attached to the inner submodules. + """ + def make_qconfig(obs_ctr: Callable) -> torch.ao.quantization.QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. + """ + if isinstance(obs_ctr(), torch.ao.quantization.FakeQuantizeBase): + weight = torch.ao.quantization.default_weight_fake_quant + else: + weight = torch.ao.quantization.default_weight_observer + return torch.ao.quantization.QConfig(activation=obs_ctr, weight=weight) + + observed_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias, + float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional) + + # Assign QConfigs with fixed qparams to all inner submodules + # Module hierarchy: LSTM > _LSTMLayer > _LSTMSingleLayer (forward or backward) > LSTMCell + for layer in observed_lstm.layers: + inner_layers = [layer.layer_fw] + if float_lstm.bidirectional: + inner_layers.append(layer.layer_bw) + for inner_layer in inner_layers: + cell = inner_layer.cell + if linear_output_obs_ctr is not None: + qconfig = make_qconfig(linear_output_obs_ctr) + cell.igates.qconfig = qconfig + cell.hgates.qconfig = qconfig + if sigmoid_obs_ctr is not None: + qconfig = make_qconfig(sigmoid_obs_ctr) + cell.input_gate.qconfig = qconfig + cell.forget_gate.qconfig = qconfig + cell.output_gate.qconfig = qconfig + if tanh_obs_ctr is not None: + cell.cell_gate.qconfig = make_qconfig(tanh_obs_ctr) + if cell_state_obs_ctr is not None: + cell.fgate_cx_igate_cgate.qconfig = make_qconfig(cell_state_obs_ctr) + obs = cell_state_obs_ctr() + if hasattr(obs, "scale") and hasattr(obs, "zero_point"): + cell.initial_cell_state_qparams = (obs.scale, obs.zero_point) + cell.cell_state_dtype = obs.dtype + if hidden_state_obs_ctr is not None: + cell.ogate_cy.qconfig = make_qconfig(hidden_state_obs_ctr) + obs = hidden_state_obs_ctr() + if hasattr(obs, "scale") and hasattr(obs, "zero_point"): + cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point) + cell.hidden_state_dtype = obs.dtype + + # Insert the observers based on the previously attached QConfigs + # Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden + torch.ao.quantization.add_observer_( + observed_lstm, + non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh] + ) + return observed_lstm __all__ = [ "NodePattern", @@ -562,4 +659,5 @@ def _patched_module_call(self, *args, **kwargs): "calculate_qmin_qmax", "has_no_children_ignoring_parametrizations", "get_fqn_to_example_inputs", + "to_underlying_dtype", ] diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 0a4ff26b50641..415928f5c22d3 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -86,8 +86,7 @@ def make_dual(tensor, tangent, *, level=None): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE # Currently broken for 3.11, see https://github.com/pytorch/pytorch/issues/85506 if (os.environ.get("PYTORCH_JIT", "1" if sys.version_info < (3, 11) else "0") == "1" and - __debug__ and - os.environ.get('PYTORCH_DISABLE_LIBRARY', "0") == "0"): + __debug__): from torch._decomp import decompositions_for_jvp # noqa: F401 if level is None: diff --git a/torch/autograd/function.py b/torch/autograd/function.py index f4810712cab3a..386fc235592d3 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -1,8 +1,10 @@ import torch import torch._C as _C from torch._C import _functions +import torch._functorch as _functorch import torch.utils.hooks as hooks from torch._six import with_metaclass +from torch.autograd.grad_mode import _DecoratorContextManager import functools import warnings from collections import OrderedDict @@ -290,54 +292,7 @@ def __init__(cls, name, bases, attrs): # mypy doesn't understand `with_metaclass` from torch._six -class Function(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc] - r"""Base class to create custom `autograd.Function` - - To create a custom `autograd.Function`, subclass this class and implement - the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom - op in the forward pass, call the class method ``apply``. Do not call - :meth:`forward` directly. - - To ensure correctness and best performance, make sure you are calling the - correct methods on ``ctx`` and validating your backward function using - :func:`torch.autograd.gradcheck`. - - See :ref:`extending-autograd` for more details on how to use this class. - - Examples:: - - >>> class Exp(Function): - >>> @staticmethod - >>> def forward(ctx, i): - >>> result = i.exp() - >>> ctx.save_for_backward(result) - >>> return result - >>> - >>> @staticmethod - >>> def backward(ctx, grad_output): - >>> result, = ctx.saved_tensors - >>> return grad_output * result - >>> - >>> # Use it by calling the apply method: - >>> # xdoctest: +SKIP - >>> output = Exp.apply(input) - """ - def __init__(self, *args, **kwargs): - cls = self.__class__ - warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions" - "are all static, so you should invoke them on the class itself. " - "Instantiating an autograd function will raise an " - "error in a future version of PyTorch.", DeprecationWarning) - - def __call__(self, *args, **kwargs): - raise RuntimeError( - "Legacy autograd function with non-static forward method is deprecated. " - "Please use new-style autograd function with static forward method. " - "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)") - - # for the tracer - is_traceable = False - +class _SingleLevelFunction(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc] @staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: r"""Performs the operation. @@ -409,6 +364,77 @@ def jvp(ctx: Any, *grad_inputs: Any) -> Any: raise NotImplementedError("You must implement the jvp function for custom " "autograd.Function to use it with forward mode AD.") + +class Function(_SingleLevelFunction): + r"""Base class to create custom `autograd.Function` + + To create a custom `autograd.Function`, subclass this class and implement + the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom + op in the forward pass, call the class method ``apply``. Do not call + :meth:`forward` directly. + + To ensure correctness and best performance, make sure you are calling the + correct methods on ``ctx`` and validating your backward function using + :func:`torch.autograd.gradcheck`. + + See :ref:`extending-autograd` for more details on how to use this class. + + Examples:: + + >>> class Exp(Function): + >>> @staticmethod + >>> def forward(ctx, i): + >>> result = i.exp() + >>> ctx.save_for_backward(result) + >>> return result + >>> + >>> @staticmethod + >>> def backward(ctx, grad_output): + >>> result, = ctx.saved_tensors + >>> return grad_output * result + >>> + >>> # Use it by calling the apply method: + >>> # xdoctest: +SKIP + >>> output = Exp.apply(input) + """ + def __init__(self, *args, **kwargs): + cls = self.__class__ + warnings.warn(f"{cls} should not be instantiated. Methods on autograd functions" + "are all static, so you should invoke them on the class itself. " + "Instantiating an autograd function will raise an " + "error in a future version of PyTorch.", DeprecationWarning) + + def __call__(self, *args, **kwargs): + raise RuntimeError( + "Legacy autograd function with non-static forward method is deprecated. " + "Please use new-style autograd function with static forward method. " + "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)") + + # for the tracer + is_traceable = False + + @classmethod + def apply(cls, *args, **kwargs): + if not torch._C._is_autograd_function_extension_enabled(): + return super().apply(*args, **kwargs) + + # TODO: fix circular import + # https://github.com/pytorch/pytorch/issues/90224 + from torch._functorch.autograd_function import custom_function_call + if not torch._C._are_functorch_transforms_active(): + # See NOTE: [functorch vjp and autograd interaction] + args = _functorch.utils.unwrap_dead_wrappers(args) + return super().apply(*args, **kwargs) + + if not hasattr(cls, 'setup_context'): + # TODO: link documentation in error message + # https://github.com/pytorch/pytorch/issues/90224 + raise RuntimeError( + 'In order to use an autograd.Function with functorch transforms ', + '(vmap, grad, jvp, jacrev, ...), it must have a setup_context ', + 'staticmethod.') + return custom_function_call(cls, *args, **kwargs) + def once_differentiable(fn): @functools.wraps(fn) @@ -468,6 +494,19 @@ def traceable(fn_cls): return fn_cls +# Private feature flag. Not user-facing. +class _set_autograd_function_extension_enabled(_DecoratorContextManager): + def __init__(self, enabled=True): + self.enabled = enabled + + def __enter__(self): + self.prev_state = torch._C._is_autograd_function_extension_enabled() + torch._C._set_autograd_function_extension_enabled(self.enabled) + + def __exit__(self, *args, **kwargs): + torch._C._set_autograd_function_extension_enabled(self.prev_state) + + class InplaceFunction(Function): def __init__(self, inplace=False): @@ -516,13 +555,11 @@ def _iter(obj): return elif isinstance(obj, (list, tuple)): for o in obj: - for var in _iter(o): - yield var + yield from _iter(o) elif isinstance(obj, dict): # We only accept primitive key types, so we needn't inspect them for o in obj.values(): - for var in _iter(o): - yield var + yield from _iter(o) elif allow_unknown: yield obj else: diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 94ca204fc9ab4..e5e410eeb42ee 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -2,6 +2,7 @@ import torch import functools import inspect +import warnings from typing import Any, Callable, TypeVar, cast __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled', @@ -18,6 +19,12 @@ class _DecoratorContextManager: """Allow a context manager to be used as a decorator""" def __call__(self, func: F) -> F: + if inspect.isclass(func): + warnings.warn("Decorating classes is deprecated and will be disabled in " + "future versions. You should only decorate functions or methods. " + "To preserve the current behavior of class decoration, you can " + "directly decorate the `__init__` method and nothing else.") + if inspect.isgeneratorfunction(func): return self._wrap_generator(func) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 2f43423a2bd6f..9f9a80ed50931 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -20,6 +20,15 @@ class GradcheckError(RuntimeError): pass + +def _is_sparse_compressed_tensor(obj: torch.Tensor): + return obj.layout in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + + +def _is_sparse_any_tensor(obj: torch.Tensor): + return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo + + def _is_float_or_complex_tensor(obj): return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex()) @@ -80,7 +89,7 @@ def _iter_tensor(x_tensor): # # where x is the t.data of the original tensor. Perturbing the entry of x # at index (1, 1) yields the 3rd column of the overall Jacobian matrix. - if x_tensor.is_sparse: + if _is_sparse_any_tensor(x_tensor): def get_stride(size): dim = len(size) tmp = 1 @@ -91,8 +100,17 @@ def get_stride(size): return stride x_nnz = x_tensor._nnz() x_size = list(x_tensor.size()) - x_indices = x_tensor._indices().t() - x_values = x_tensor._values() + if x_tensor.layout is torch.sparse_coo: + x_indices = x_tensor._indices().t() + x_values = x_tensor._values() + elif x_tensor.layout is torch.sparse_csr: + x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.crow_indices(), x_tensor.col_indices()).t() + x_values = x_tensor.values() + elif x_tensor.layout is torch.sparse_csc: + x_indices = torch._convert_indices_from_csr_to_coo(x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True).t() + x_values = x_tensor.values() + else: + raise NotImplementedError(f'_iter_tensor for {x_tensor.layout} input') x_stride = get_stride(x_size) # Use .data here to get around the version check x_values = x_values.data @@ -249,7 +267,7 @@ def _prepare_input(input: torch.Tensor, maybe_perturbed_input: Optional[torch.Te return maybe_perturbed_input.to_mkldnn() else: return input - elif input.layout == torch.sparse_coo: + elif _is_sparse_any_tensor(input): if fast_mode and maybe_perturbed_input is not None: # entry is already a "cloned" version of the original tensor # thus changes to entry are not reflected in the input @@ -386,7 +404,7 @@ def _get_input_to_perturb(input): if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn # Convert to dense so we can perform operations that require strided tensors input_to_perturb = input.to_dense() - elif input.layout == torch.sparse_coo: + elif _is_sparse_any_tensor(input): # Clone because input may require grad, and copy_ calls resize_, # which is not allowed for .data input_to_perturb = input.clone() @@ -414,10 +432,10 @@ def jvp_fn(delta): def _reshape_tensor_or_tuple(u, shape): # We don't need to reshape when input corresponding to u is sparse if isinstance(u, tuple): - if u[0].layout != torch.sparse_coo: + if not _is_sparse_any_tensor(u[0]): return (u[0].reshape(shape), u[1].reshape(shape)) else: - if u.layout != torch.sparse_coo: + if not _is_sparse_any_tensor(u): return u.reshape(shape) return u @@ -642,7 +660,7 @@ def _get_analytical_vjps_wrt_specific_output(vjp_fn, sample_output, v) -> List[L def _check_inputs(tupled_inputs, check_sparse_nnz) -> bool: - if not check_sparse_nnz and any(t.is_sparse or t.is_sparse_csr for t in tupled_inputs if isinstance(t, torch.Tensor)): + if not check_sparse_nnz and any(_is_sparse_any_tensor(t) for t in tupled_inputs if isinstance(t, torch.Tensor)): raise GradcheckError('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.') # Make sure that gradients are saved for at least one input any_input_requiring_grad = False @@ -656,7 +674,7 @@ def _check_inputs(tupled_inputs, check_sparse_nnz) -> bool: 'not of double precision floating point or complex. ') if inp.is_sparse: content = inp._values() - elif inp.is_sparse_csr: + elif _is_sparse_compressed_tensor(inp): content = inp.values() else: content = inp @@ -679,7 +697,7 @@ def _check_inputs(tupled_inputs, check_sparse_nnz) -> bool: def _check_outputs(outputs) -> None: - if any(t.layout == torch.sparse_coo for t in outputs if isinstance(t, torch.Tensor)): + if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)): # it is easier to call to_dense() on the sparse output than # to modify analytical jacobian raise ValueError('Sparse output is not supported at gradcheck yet. ' @@ -801,7 +819,7 @@ def jvp(tangent: torch.Tensor): except RuntimeError as ex: # Rethrow to provide a better error message raise GradcheckError( - f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') + f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}') from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): @@ -843,7 +861,7 @@ def vjp(v): # autograd.grad instead of the C++ traceback of what line in the # backward formula raise GradcheckError( - f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') + f'While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}') from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): @@ -866,11 +884,12 @@ def _test_backward_mul_by_grad_output(outputs, inputs, check_sparse_nnz) -> bool if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: if gi.layout != di.layout: raise GradcheckError('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')') - if gi.layout == torch.sparse_coo: + if _is_sparse_any_tensor(gi): + sparse_kind = str(gi.layout).replace('torch.', '').replace('_coo', '') if gi.sparse_dim() != di.sparse_dim(): - raise GradcheckError('grad is sparse tensor, but has incorrect sparse_dim') + raise GradcheckError(f'grad is {sparse_kind} tensor, but has incorrect sparse_dim') if gi.dense_dim() != di.dense_dim(): - raise GradcheckError('grad is sparse tensor, but has incorrect dense_dim') + raise GradcheckError(f'grad is {sparse_kind} tensor, but has incorrect dense_dim') gi = gi.to_dense() di = di.to_dense() @@ -958,12 +977,12 @@ def check_undefined_grad_support(output_to_check): try: grads_input = torch.autograd.grad(output_to_check, diff_input_list, grads_output, allow_unused=True) - except RuntimeError: + except RuntimeError as e: warn_bc_breaking() raise GradcheckError(( 'Expected backward function to handle undefined output grads. ' 'Please look at "Notes about undefined output gradients" in ' - '"tools/autograd/derivatives.yaml"')) + '"tools/autograd/derivatives.yaml"')) from e for gi, i in zip(grads_input, diff_input_list): if (gi is not None) and (not gi.eq(0).all()): @@ -1164,9 +1183,21 @@ def _vec_from_tensor(x, generator, downcast_complex=False): dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype values = torch.rand(x_values.numel(), generator=generator) \ .to(dtype=dtype, device=x.device) \ - .reshape(x_values.shape) + .view(x_values.shape) values /= values.norm() vec = torch.sparse_coo_tensor(x._indices(), values, x.size()) + elif _is_sparse_compressed_tensor(x): + if x.layout in {torch.sparse_csr, torch.sparse_bsr}: + compressed_indices, plain_indices = x.crow_indices(), x.col_indices() + else: + compressed_indices, plain_indices = x.ccol_indices(), x.row_indices() + x_values = x.values() + dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype + values = torch.rand(x_values.numel(), generator=generator) \ + .to(dtype=dtype, device=x.device) \ + .view(x_values.shape) + values /= values.norm() + vec = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, x.size(), layout=x.layout) else: dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype vec = torch.rand(x.numel(), generator=generator).to(dtype=dtype, device=x.device) @@ -1189,6 +1220,7 @@ def _adjusted_atol(atol, u, v): # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i) # TODO: properly handle case when u is tuple instead of only taking first element u = u[0] if isinstance(u, tuple) else u + # TODO: replace torch.sparse.sum(u) with u.sum() sum_u = torch.sparse.sum(u) if u.layout == torch.sparse_coo else u.sum() sum_v = 1. if v is None else torch.sparse.sum(v) if v.layout == torch.sparse_coo else v.sum() return atol * float(sum_u) * float(sum_v) @@ -1241,7 +1273,7 @@ def new_fn(inp): def _to_flat_dense_if_sparse(tensor): - if tensor.layout == torch.sparse_coo: + if _is_sparse_any_tensor(tensor): return tensor.to_dense().reshape(-1) else: return tensor diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 9c333c70bcf22..fc490a9d8e31c 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,15 +1,17 @@ import torch import contextlib -from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List +from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set from torch.utils.hooks import RemovableHandle - -__all__ = ["saved_tensors_hooks", "save_on_cpu"] +from torch.utils._python_dispatch import TorchDispatchMode +from collections import defaultdict +import weakref __all__ = [ "saved_tensors_hooks", "save_on_cpu", "disable_saved_tensors_hooks", "register_multi_grad_hook", + "allow_mutation_on_saved_tensors", ] class saved_tensors_hooks(): @@ -270,3 +272,158 @@ def __setstate__(self, state): handles.append(t.register_hook(get_inner_hook(i))) return Handle(tuple(handles)) + + +# NOTE [Allow mutation on tensors saved for backward] +# +# 1. Tensor gets saved for backward +# - remember the python object id and the version of the tensor +# - remember aliasing information (data_ptr of base + version) +# - save the original so we control its lifetime +# 2. Any time a tensor gets in-placed +# - for each tensor aliased to it: +# - check using its object id and version to see if it has been saved +# - if it has been saved, clone it +# - delete the reference to the original +# 3. during backward +# - if the clone exists, the tensor must've been modified in-place +_allow_mutation_on_saved_tensors_enabled = False + +def _get_tid(t) -> Tuple[int, int, int]: + return (id(t), t.data_ptr(), t._version) + +def _get_sid(t) -> Tuple[int, int]: + return (t.data_ptr(), t._version) + +class _Handle(): + pass + +class _swap_with_cloned(saved_tensors_hooks): + def __init__(self, ctx): + def pack_hook(t): + tid = _get_tid(t) + sid = _get_sid(t) + # Tensors saved for backward have an entry in _tid_to_weakhandle + handle: Optional[_Handle] = None + + # Save aliasing information + ctx.sid_to_tid[sid].add(tid) + + # NB: The same tensor (of the same version) can be saved multiple times + if tid not in ctx.tid_to_weakhandle: + handle = _Handle() + ctx.tid_to_weakhandle[tid] = handle + ctx.original[handle] = t + else: + # Store an additional strong reference to the handle + handle = ctx.tid_to_weakhandle[tid] + return handle + + def unpack_hook(tup): + handle = tup + error_msg = ( + "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + "in which the graph was originally recorded.") + assert _allow_mutation_on_saved_tensors_enabled, error_msg + if handle in ctx.cloned: + res = ctx.cloned[handle] + else: + assert handle in ctx.original, error_msg + res = ctx.original[handle] + return res + + super().__init__(pack_hook, unpack_hook) + +class _CloneArgBeforeMutateMode(TorchDispatchMode): + def __init__(self, ctx): + self.ctx = ctx + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + for idx, arg in enumerate(func._schema.arguments): + if arg.alias_info is not None and arg.alias_info.is_write: + t = kwargs["out"] if arg.is_out else args[idx] + tid = _get_tid(t) + sid = _get_sid(t) + ctx = self.ctx + if sid in ctx.sid_to_tid: + for tid in ctx.sid_to_tid[sid]: + if tid not in ctx.tid_to_weakhandle: + # We know that if tid is in sid_to_tid, then it must also be in + # tid_to_weakhandle. However, it is possible for the tensor to be + # saved at one point, but cleared by backward before it is modified + # in-place. Consider the following example: + # + # >>> a = torch.randn(2, 3, requires_grad=True).clone() + # >>> out = (a**2).sum() + # >>> out.backward() + # >>> a.sin_() + continue + handle = ctx.tid_to_weakhandle[tid] + if handle in ctx.cloned: + # The same exact tensor has been cloned already + continue + ctx.cloned[handle] = ctx.original[handle].clone() + del ctx.original[handle] + + rs = func(*args, **kwargs) + return rs + +class _AllowMutationOnSavedContext(): + def __init__(self): + self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set) + + def clear(self): + self.cloned.clear() + self.original.clear() + self.tid_to_weakhandle.clear() + self.sid_to_tid.clear() + +@contextlib.contextmanager +def allow_mutation_on_saved_tensors(): + """Context manager under which mutating tensors saved for backward is allowed + + Under this context manager, tensors saved for backward are cloned on mutation, + so the original version can still be used during backward. Normally, mutating a tensor + saved for backward will result in an error raised when it's used during backward. + + To ensure the correct behavior, both the forward and backward should be run under + the same context manager. + + returns: + An _AllowMutationOnSavedContext object storing the state managed by this + context manager. This object can be useful for debugging purposes. The state + managed by the context manager is automatically cleared upon exiting. + + Example:: + + >>> import torch + >>> with torch.autograd.graph.allow_mutation_on_saved_tensors(): + ... # forward + ... a = torch.ones(2, 3, requires_grad=True) + ... b = a.clone() + ... out = (b**2).sum() + ... b.sin_() + ... # backward + ... out.sum().backward() + ... + tensor([[0.8415, 0.8415, 0.8415], + [0.8415, 0.8415, 0.8415]], grad_fn=) + """ + global _allow_mutation_on_saved_tensors_enabled + + ctx = _AllowMutationOnSavedContext() + + with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): + try: + if _allow_mutation_on_saved_tensors_enabled: + raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested") + _allow_mutation_on_saved_tensors_enabled = True + yield ctx + finally: + ctx.clear() + _allow_mutation_on_saved_tensors_enabled = False diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index ddd0ad6d0a289..e70ec6c4ed8ca 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -481,17 +481,29 @@ def __init__(self, name: str, args: Optional[str] = None): self.args: Optional[str] = args # Whether or not we should run record function's end callbacks when exiting. self.run_callbacks_on_exit: bool = True - # Stores underlying RecordFunction as a tensor. TODO: move to custom - # class (https://github.com/pytorch/pytorch/issues/35026). - self.handle: torch.Tensor = torch.zeros(1) + # TODO: TorchScript ignores standard type annotation here + # self.record: Optional["torch.classes.profiler._RecordFunction"] = None + self.record = torch.jit.annotate(Optional["torch.classes.profiler._RecordFunction"], None) def __enter__(self): - self.handle = torch.ops.profiler._record_function_enter(self.name, self.args) + self.record = torch.ops.profiler._record_function_enter_new(self.name, self.args) return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): - if self.run_callbacks_on_exit: - torch.ops.profiler._record_function_exit(self.handle) + if not self.run_callbacks_on_exit: + return + + # Local variable is needed by TorchScript to refine Optional[T] to T + record = self.record + assert record is not None + + # TODO: Too slow with __torch_function__ handling enabled + # See https://github.com/pytorch/pytorch/issues/76410 + if not torch.jit.is_scripting(): + with torch._C.DisableTorchFunction(): + torch.ops.profiler._record_function_exit._RecordFunction(record) + else: + torch.ops.profiler._record_function_exit(record) def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]: """ @@ -518,7 +530,19 @@ def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]: # We are scheduling to run this RecordFunction's end callbacks when the # passed in future completes, so don't run end callbacks on exit. self.run_callbacks_on_exit = False - profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(self.handle, fut) + + # Local variable is needed by TorchScript to refine Optional[T] to T + record = self.record + assert record is not None + + # TODO: Too slow with __torch_function__ handling enabled + # See https://github.com/pytorch/pytorch/issues/76410 + if not torch.jit.is_scripting(): + with torch._C.DisableTorchFunction(): + profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction( + record, fut) + else: + profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(record, fut) return profiled_future diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index eece8f2646164..50735e125ec36 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,12 +1,14 @@ import sys import torch import contextlib +from enum import IntEnum from typing import Union __all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager", - "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "enable_flash_sdp", - "flash_sdp_enabled", "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"] + "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "SDPBackend", "enable_flash_sdp", + "flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled", + "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"] def is_built(): r"""Returns whether PyTorch is built with CUDA support. Note that this @@ -163,6 +165,20 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] return torch._C._get_linalg_preferred_backend() +class SDPBackend(IntEnum): + r"""Enum class for the scaled dot product attention backends. + + .. warning:: This flag is experimental and subject to change.' + + This class needs to stay inline with the enum defined in: + pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h + """ + ERROR = -1 + MATH = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + + def flash_sdp_enabled(): r""" .. warning:: This flag is experimental and subject to change. @@ -180,6 +196,22 @@ def enable_flash_sdp(enabled: bool): """ torch._C._set_sdp_use_flash(enabled) +def mem_efficient_sdp_enabled(): + r""" + .. warning:: This flag is experimental and subject to change. + + Returns whether memory efficient sdp is enabled or not. + """ + return torch._C._get_mem_efficient_sdp_enabled() + + +def enable_mem_efficient_sdp(enabled: bool): + r""" + .. warning:: This flag is experimental and subject to change. + + Enables or disables memory efficient sdp. + """ + torch._C._set_sdp_use_mem_efficient(enabled) def math_sdp_enabled(): r""" @@ -200,23 +232,26 @@ def enable_math_sdp(enabled: bool): @contextlib.contextmanager -def sdp_kernel(enable_flash: bool = True, enable_math: bool = True): +def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True): r""" .. warning:: This flag is experimental and subject to change. - This context manager can be used to temporarily enable or disable flash sdp and math sdp. + This context manager can be used to temporarily enable or disable flash/memory efficient sdp and math sdp. Upon exiting the context manager, the previous state of the flags will be restored. """ previous_flash: bool = flash_sdp_enabled() + previous_mem_efficient: bool = mem_efficient_sdp_enabled() previous_math: bool = math_sdp_enabled() try: enable_flash_sdp(enable_flash) + enable_mem_efficient_sdp(enable_mem_efficient) enable_math_sdp(enable_math) yield{} except RuntimeError as err: raise err finally: enable_flash_sdp(previous_flash) + enable_mem_efficient_sdp(previous_mem_efficient) enable_math_sdp(previous_math) cufft_plan_cache = cuFFTPlanCacheManager() diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e187d6d26aed8..2b63a63796650 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -37,6 +37,8 @@ def _init(): else: cudnn_compatible = runtime_minor >= compile_minor if not cudnn_compatible: + if os.environ.get('PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK', '0') == '1': + return True base_error_msg = (f'cuDNN version incompatibility: ' f'PyTorch was compiled against {compile_version} ' f'but found runtime version {runtime_version}. ' diff --git a/torch/backends/mps/__init__.py b/torch/backends/mps/__init__.py index b6cec317eb54d..1664c87ee5de7 100644 --- a/torch/backends/mps/__init__.py +++ b/torch/backends/mps/__init__.py @@ -11,4 +11,6 @@ def is_built() -> bool: @_lru_cache() def is_available() -> bool: r"""Returns a bool indicating if MPS is currently available.""" + if not hasattr(torch._C, '_is_mps_available'): + return False return torch._C._is_mps_available() diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 69632cb208628..da55a9e605e10 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -598,7 +598,7 @@ def create_args(parser=None): _add_multi_instance_params(parser) # positional parser.add_argument("program", type=str, - help="The full path to the proram/script to be launched. " + help="The full path to the program/script to be launched. " "followed by all the arguments for the script") # rest from the training program diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp index 9a2c47a5f7a84..d18a23ebe4e68 100644 --- a/torch/csrc/CudaIPCTypes.cpp +++ b/torch/csrc/CudaIPCTypes.cpp @@ -195,7 +195,7 @@ CudaIPCSentData::~CudaIPCSentData() { try { if (event_sync_required_) { at::cuda::CUDAGuard device_guard(device_.index()); - cudaEventDestroy(event_); + C10_CUDA_CHECK(cudaEventDestroy(event_)); if (!CudaIPCGlobalEntities::alive) { return; } diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index b3021ffe0d8d8..00674cf81229e 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -77,7 +78,11 @@ THPLayout* getTHPLayout(at::Layout layout) { PyObject* createPyObject(const at::Storage& storage) { if (storage.device_type() != at::DeviceType::Meta && - storage.data() == nullptr && storage.nbytes() != 0) { + storage.data() == nullptr && storage.sym_nbytes() != 0 && + // Grabbing storage() from FunctionalTensorWrapper is allowed. + // This is useful for checking aliasing info from python + dynamic_cast( + storage.unsafeGetStorageImpl()) == nullptr) { TORCH_CHECK_NOT_IMPLEMENTED( false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details"); @@ -135,7 +140,7 @@ at::Storage createStorageGetType( TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_obj)); scalar_type = reinterpret_cast(dtype_obj)->scalar_type; - untyped_storage_obj = PyObject_GetAttrString(obj, "_storage"); + untyped_storage_obj = PyObject_GetAttrString(obj, "_untyped_storage"); TORCH_INTERNAL_ASSERT(untyped_storage_obj); Py_DECREF(untyped_storage_obj); diff --git a/torch/csrc/DynamicTypes.h b/torch/csrc/DynamicTypes.h index 6765916634c53..7ca18942564df 100644 --- a/torch/csrc/DynamicTypes.h +++ b/torch/csrc/DynamicTypes.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -24,7 +25,7 @@ namespace torch { void registerDtypeObject(THPDtype* dtype, at::ScalarType scalarType); void registerLayoutObject(THPLayout* thp_layout, at::Layout layout); -PyObject* createPyObject(const at::Storage& storage); +TORCH_PYTHON_API PyObject* createPyObject(const at::Storage& storage); at::Storage createStorage(PyObject* obj); at::Storage createStorageGetType( PyObject* obj, diff --git a/torch/csrc/Exceptions.cpp b/torch/csrc/Exceptions.cpp index b9e4c0a1fca72..67ac3decd6b13 100644 --- a/torch/csrc/Exceptions.cpp +++ b/torch/csrc/Exceptions.cpp @@ -13,7 +13,7 @@ #include PyObject *THPException_FatalError, *THPException_LinAlgError, - *THPException_OutOfMemoryError; + *THPException_OutOfMemoryError, *THPException_DistBackendError; #define ASSERT_TRUE(cond) \ if (!(cond)) \ @@ -63,6 +63,16 @@ could not be completed because the input matrix is singular.", PyModule_AddObject( module, "_OutOfMemoryError", THPException_OutOfMemoryError) == 0); + ASSERT_TRUE( + THPException_DistBackendError = PyErr_NewExceptionWithDoc( + "torch.distributed.DistBackendError", + "Exception raised when a backend error occurs in distributed", + PyExc_RuntimeError, + nullptr)); + ASSERT_TRUE( + PyModule_AddObject( + module, "_DistBackendError", THPException_DistBackendError) == 0); + return true; } diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 666f240764217..01caa6a702c0a 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -75,6 +76,8 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) { _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \ _CATCH_GENERIC_ERROR( \ OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \ + _CATCH_GENERIC_ERROR( \ + DistBackendError, THPException_DistBackendError, retstmnt) \ _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \ catch (torch::PyTorchError & e) { \ auto msg = torch::processErrorMsg(e.what()); \ @@ -146,7 +149,7 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) { #define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr) extern PyObject *THPException_FatalError, *THPException_LinAlgError, - *THPException_OutOfMemoryError; + *THPException_OutOfMemoryError, *THPException_DistBackendError; // Throwing this exception means that the python error flags have been already // set and control should be immediately returned to the interpreter. @@ -373,17 +376,17 @@ struct PyWarningHandler { namespace detail { template -using Arg = typename function_traits::template arg::type; +using Arg = typename invoke_traits::template arg::type; template auto wrap_pybind_function_impl_(Func&& f, std::index_sequence) { - using traits = function_traits; + using result_type = typename invoke_traits::result_type; namespace py = pybind11; // f=f is needed to handle function references on older compilers - return [f = f](Arg... args) -> typename traits::result_type { + return [f = std::forward(f)](Arg... args) -> result_type { HANDLE_TH_ERRORS - return f(std::forward>(args)...); + return c10::guts::invoke(f, std::forward>(args)...); END_HANDLE_TH_ERRORS_PYBIND }; } @@ -393,7 +396,7 @@ auto wrap_pybind_function_impl_(Func&& f, std::index_sequence) { // Returns a function object suitable for registering with pybind11. template auto wrap_pybind_function(Func&& f) { - using traits = function_traits; + using traits = invoke_traits; return torch::detail::wrap_pybind_function_impl_( std::forward(f), std::make_index_sequence{}); } diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 31dcfefaea8d8..241628ae8938b 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -17,6 +17,10 @@ #include #endif +#ifdef USE_MPS +#include +#endif + using namespace at; using namespace torch; @@ -52,25 +56,24 @@ static PyObject* THPGenerator_pynew( auto device = r.deviceWithDefault(0, at::Device(at::kCPU)); THPGeneratorPtr self((THPGenerator*)type->tp_alloc(type, 0)); -#ifdef USE_CUDA if (device.type() == at::kCPU) { self->cdata = make_generator(); - } else if (device.type() == at::kCUDA) { + } +#ifdef USE_CUDA + else if (device.type() == at::kCUDA) { self->cdata = make_generator(device.index()); - } else { + } +#elif USE_MPS + else if (device.type() == at::kMPS) { + self->cdata = make_generator(); + } +#endif + else { AT_ERROR( "Device type ", c10::DeviceTypeName(device.type()), " is not supported for torch.Generator() api."); } -#else - TORCH_CHECK( - device.type() == at::kCPU, - "Device type ", - c10::DeviceTypeName(device.type()), - " is not supported for torch.Generator() api."); - self->cdata = make_generator(); -#endif return (PyObject*)self.release(); END_HANDLE_TH_ERRORS } @@ -92,11 +95,10 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) { using namespace torch::autograd; HANDLE_TH_ERRORS - if (!THPVariable_Check(_new_state)) { - throw torch::TypeError( - "expected a torch.ByteTensor, but got %s", - Py_TYPE(_new_state)->tp_name); - } + TORCH_CHECK_TYPE( + THPVariable_Check(_new_state), + "expected a torch.ByteTensor, but got ", + Py_TYPE(_new_state)->tp_name); auto self = (THPGenerator*)_self; auto& gen = self->cdata; const auto& new_state_tensor = THPVariable_Unpack(_new_state); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 08b9b81217e93..2dd1109c9987e 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -8,10 +9,10 @@ #include #include #include +#include #include #include #include -#include #include #include #include @@ -55,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -183,6 +185,25 @@ static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) { return THPUtils_packInt32((int)y); } +static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) { + // This code shoud work perfectly fine, as vtables are idential for Foo and + // Baz unless rtti and ubsan are enabled + struct Foo { + virtual int bar() = 0; + virtual ~Foo() = default; + }; + struct Baz { + virtual int bar() { + return 17; + } + virtual ~Baz() = default; + }; + Baz x{}; + auto y = static_cast(static_cast(&x)); + auto rc = y->bar(); + return THPUtils_packInt32(rc); +} + static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) { THPUtils_assert( THPUtils_checkLong(arg), @@ -441,6 +462,20 @@ PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true)); END_HANDLE_TH_ERRORS } +static PyObject* THModule_rename_privateuse1_backend( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + THPUtils_assert( + THPUtils_checkString(arg), + "_rename_privateuse1_backend expects a str, " + "but got %s", + THPUtils_typename(arg)); + const std::string backend_name = THPUtils_unpackString(arg); + c10::register_privateuse1_backend(backend_name); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { THPUtils_assert( @@ -499,6 +534,21 @@ PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) { else Py_RETURN_FALSE; } +PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) { + THPUtils_assert( + PyBool_Check(arg), + "set_sdp_use_math expects a bool, " + "but got %s", + THPUtils_typename(arg)); + at::globalContext().setSDPUseMemEfficient(arg == Py_True); + Py_RETURN_NONE; +} +PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) { + if (at::globalContext().userEnabledMemEfficientSDP()) + Py_RETURN_TRUE; + else + Py_RETURN_FALSE; +} PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) { THPUtils_assert( PyBool_Check(arg), @@ -799,6 +849,28 @@ PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) { END_HANDLE_TH_ERRORS } +PyObject* THPModule_getCurrentGraphTaskExecutionOrder( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + std::vector nodes = + torch::autograd::get_current_graph_task_execution_order(); + TORCH_CHECK( + nodes.size(), + "_current_graph_task_execution_order should only be called during the backward pass"); + auto list = THPObjectPtr(PyList_New(nodes.size())); + if (!list) + return nullptr; + for (const auto i : c10::irange(nodes.size())) { + // This node is guaranteed to be alive since the backward is still running + PyObject* pyobj_node = + torch::autograd::functionToPyObject(nodes[i]->getptr()); + PyList_SET_ITEM(list.get(), i, pyobj_node); + } + return list.release(); + END_HANDLE_TH_ERRORS +} + PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packInt64(torch::autograd::get_current_graph_task_id()); @@ -881,6 +953,7 @@ static PyMethodDef TorchMethods[] = { {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr}, {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr}, {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr}, + {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr}, {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr}, {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr}, {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr}, @@ -916,6 +989,14 @@ static PyMethodDef TorchMethods[] = { METH_NOARGS, nullptr}, {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr}, + {"_get_mem_efficient_sdp_enabled", + userEnabledMemEfficientSDP, + METH_NOARGS, + nullptr}, + {"_set_sdp_use_mem_efficient", + THPModule_setSDPUseMemEfficient, + METH_O, + nullptr}, {"_get_math_sdp_enabled", THPModule_userEnabledMathSDP, METH_NOARGS, @@ -990,6 +1071,10 @@ static PyMethodDef TorchMethods[] = { {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, + {"_rename_privateuse1_backend", + THModule_rename_privateuse1_backend, + METH_O, + nullptr}, {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr}, {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr}, {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr}, @@ -1001,6 +1086,10 @@ static PyMethodDef TorchMethods[] = { THPModule_willEngineExecuteNode, METH_O, nullptr}, + {"_current_graph_task_execution_order", + THPModule_getCurrentGraphTaskExecutionOrder, + METH_NOARGS, + nullptr}, {"_current_graph_task_id", THPModule_getCurrentGraphTaskId, METH_NOARGS, @@ -1059,6 +1148,10 @@ void initIttBindings(PyObject* module); } // namespace torch #endif +#ifdef USE_MPS +PyMethodDef* MPSModule_methods(); +#endif + namespace torch { void initVerboseBindings(PyObject* module); } // namespace torch @@ -1114,6 +1207,9 @@ PyObject* initModule() { #ifdef USE_CUDA THPUtils_addPyMethodDefs(methods, THCPModule_methods()); #endif +#ifdef USE_MPS + THPUtils_addPyMethodDefs(methods, MPSModule_methods()); +#endif #if defined(USE_DISTRIBUTED) && defined(USE_C10D) THPUtils_addPyMethodDefs( methods, torch::distributed::c10d::python_functions()); @@ -1309,7 +1405,9 @@ Call this whenever a new thread is created in order to propagate values from .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d) .value( "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise) - .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d); + .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d) + .value("Mps", at::native::ConvBackend::Mps) + .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose); py_module.def( "_select_conv_backend", @@ -1317,10 +1415,10 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias_opt, at::IntArrayRef stride_, - at::IntArrayRef padding_, + at::SymIntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::IntArrayRef output_padding_, + at::SymIntArrayRef output_padding_, int64_t groups_) { return at::native::select_conv_backend( input, @@ -1331,8 +1429,62 @@ Call this whenever a new thread is created in order to propagate values from dilation_, transposed_, output_padding_, - groups_); - }); + groups_, + c10::nullopt); + }, + py::arg("input"), + py::arg("weight"), + py::arg("bias"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("transposed"), + py::arg("output_padding"), + py::arg("groups")); + + // overload for bias_sizes_opt/backward TODO: figure out default value + py_module.def( + "_select_conv_backend", + [](const at::Tensor& input, + const at::Tensor& weight, + const c10::optional& bias, + at::IntArrayRef stride_, + at::SymIntArrayRef padding_, + at::IntArrayRef dilation_, + bool transposed_, + at::SymIntArrayRef output_padding_, + int64_t groups_, + c10::optional> bias_sizes_opt) { + c10::OptionalArrayRef ref = c10::nullopt; + if (bias_sizes_opt) { + ref = (*bias_sizes_opt); + } + return at::native::select_conv_backend( + input, + weight, + bias, + stride_, + padding_, + dilation_, + transposed_, + output_padding_, + groups_, + ref); + }, + py::arg("input"), + py::arg("weight"), + py::arg("bias"), + py::arg("stride"), + py::arg("padding"), + py::arg("dilation"), + py::arg("transposed"), + py::arg("output_padding"), + py::arg("groups"), + py::arg("bias_sizes")); + + py_module.def( + "_conv_determine_backend_memory_format", + at::native::_determine_backend_memory_format); py::enum_(py_module, "_LinalgBackend") .value("Default", at::LinalgBackend::Default) @@ -1360,8 +1512,6 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE(set_module_attr("has_cuda", has_cuda)); ASSERT_TRUE(set_module_attr("has_mps", has_mps)); - py_module.def("_is_mps_available", []() { return at::hasMPS(); }); - ASSERT_TRUE( set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False)); @@ -1400,6 +1550,12 @@ Call this whenever a new thread is created in order to propagate values from "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); }); py_module.def( "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); }); + py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata); + py_module.def( + "_set_tensor_metadata", + static_cast)>( + torch::jit::setTensorMetadata)); py_module.def("_dispatch_key_set", [](const at::Tensor& x) { return toString(x.key_set()); }); diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 36419f20eccd0..ba4090bfb6845 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -59,7 +59,7 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) { TORCH_CHECK( !torch::jit::tracer::isTracing(), "JIT Tracing of SymInts isn't supported"); - auto py_symint = py::cast(si.toSymIntNodeImpl()).release().ptr(); + auto py_symint = py::cast(si).release().ptr(); if (!py_symint) throw python_error(); PyTuple_SET_ITEM(ret.get(), i, py_symint); @@ -98,7 +98,7 @@ static PyObject* THPSize_pynew( if (THPUtils_checkLong(item)) { continue; } - if (torch::is_symint_node(item)) { + if (torch::is_symint(item)) { continue; } if (torch::jit::tracer::isTracing() && isTracedZeroDimVar(item)) { @@ -135,7 +135,7 @@ static PyObject* THPSize_repr(THPSize* self) { auto item = PyTuple_GET_ITEM(self, i); auto ih = py::handle(item); - repr += torch::is_symint_node(ih) + repr += torch::is_symint(ih) ? std::string(py::str(ih)) : std::to_string(THPUtils_unpackLong(PyTuple_GET_ITEM(self, i))); } diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 2b74c8a2fd290..29f0f67ce6ecb 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -41,7 +41,7 @@ static PyObject* THPStorage_nbytes(PyObject* _self, PyObject* noargs) { HANDLE_TH_ERRORS auto self = (THPStorage*)_self; - return THPUtils_packUInt64(self->cdata->nbytes()); + return py::cast(self->cdata->sym_nbytes()).release().ptr(); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/api/include/torch/mps.h b/torch/csrc/api/include/torch/mps.h new file mode 100644 index 0000000000000..669cecfc5de49 --- /dev/null +++ b/torch/csrc/api/include/torch/mps.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include +#include + +namespace torch { +namespace mps { + +/// Returns true if MPS device is available. +bool TORCH_API is_available(); + +/// Sets the seed for the current GPU. +void TORCH_API manual_seed(uint64_t seed); + +/// Waits for all streams on a MPS device to complete. +void TORCH_API synchronize(); + +} // namespace mps +} // namespace torch diff --git a/torch/csrc/api/include/torch/nested.h b/torch/csrc/api/include/torch/nested.h index 1730583d5e149..d91c878348bd5 100644 --- a/torch/csrc/api/include/torch/nested.h +++ b/torch/csrc/api/include/torch/nested.h @@ -1,7 +1,9 @@ #pragma once #include -#include +#include +#include +#include namespace torch { namespace nested { @@ -12,19 +14,51 @@ namespace nested { /// https://pytorch.org/docs/master/nested.html#torch.nested.nested_tensor /// /// ``` -inline Tensor nested_tensor( - TensorList list, - c10::optional dtype = c10::nullopt, - c10::optional device = c10::nullopt, - c10::optional requires_grad = false, - c10::optional pin_memory = false) { - std::vector new_list; - for (const auto i : c10::irange(list.size())) { - new_list.push_back(list[i].clone().detach()); +// implemented on python object to allow torch.nested.nested_tensor to be +// constructed with arbitrarily nested python objects - for now, only arbitrary +// python lists and lists of Tensors +// See torch/csrc/autograd/python_nested_functions_manual.cpp for Python +// implementation +// See here for C++ implementation +inline at::Tensor nested_tensor( + at::TensorList nested_tensor_data, + const at::TensorOptions& options = {}) { + auto out = at::_nested_tensor_from_tensor_list( + nested_tensor_data, + c10::typeMetaToScalarType(options.dtype()), + c10::nullopt, + options.device(), + options.pinned_memory()); + if (options.has_requires_grad() && options.requires_grad()) { + out.requires_grad_(true); + } + return out; +} + +inline at::Tensor nested_tensor( + at::ArrayRef nested_tensor_data, + const at::TensorOptions& options = {}) { + for (const auto& tdc : nested_tensor_data) { + TORCH_CHECK( + tdc.is_init_list(), + "nested_tensor() not implemented for these parameters"); } - auto out = torch::_nested_tensor_from_tensor_list( - new_list, dtype, c10::nullopt, device, pin_memory); - if (requires_grad.has_value() && requires_grad.value()) { + // Construct a TensorList using nested_tensor_data + std::vector tensor_list(nested_tensor_data.size()); + std::transform( + nested_tensor_data.begin(), + nested_tensor_data.end(), + tensor_list.begin(), + [&](const detail::TensorDataContainer& tdc) { + return tdc.convert_to_tensor(options); + }); + auto out = at::_nested_tensor_from_tensor_list( + tensor_list, + c10::typeMetaToScalarType(options.dtype()), + c10::nullopt, + options.device(), + options.pinned_memory()); + if (options.has_requires_grad() && options.requires_grad()) { out.requires_grad_(true); } return out; @@ -36,10 +70,10 @@ inline Tensor nested_tensor( /// https://pytorch.org/docs/master/nested.html#torch.nested.as_nested_tensor /// /// ``` -inline Tensor as_nested_tensor( - TensorList list, - c10::optional dtype = c10::nullopt, - c10::optional device = c10::nullopt) { +inline at::Tensor as_nested_tensor( + at::TensorList list, + c10::optional dtype = c10::nullopt, + c10::optional device = c10::nullopt) { return at::_nested_tensor_from_tensor_list( list, dtype, c10::nullopt, device, c10::nullopt); } @@ -50,11 +84,11 @@ inline Tensor as_nested_tensor( /// https://pytorch.org/docs/master/nested.html#torch.nested.to_padded_tensor /// /// ``` -inline Tensor to_padded_tensor( - const Tensor& self, +inline at::Tensor to_padded_tensor( + const at::Tensor& self, double padding, - OptionalIntArrayRef output_size = c10::nullopt) { - return torch::nested_to_padded_tensor(self, padding, output_size); + at::OptionalIntArrayRef output_size = c10::nullopt) { + return at::nested_to_padded_tensor(self, padding, output_size); } } // namespace nested diff --git a/torch/csrc/api/src/mps.cpp b/torch/csrc/api/src/mps.cpp new file mode 100644 index 0000000000000..83bb7ef2d3215 --- /dev/null +++ b/torch/csrc/api/src/mps.cpp @@ -0,0 +1,31 @@ +#include +#include + +#include + +namespace torch { +namespace mps { + +bool is_available() { + return at::detail::getMPSHooks().hasMPS(); +} + +/// Sets the seed for the MPS's default generator. +void manual_seed(uint64_t seed) { + if (is_available()) { + auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator(); + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_current_seed(seed); + } + } +} + +void synchronize() { + TORCH_CHECK(is_available(), "No MPS devices are available"); + at::detail::getMPSHooks().deviceSynchronize() +} + +} // namespace mps +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index 6d643fc7354f0..df08c629da561 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -466,7 +466,7 @@ Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) { // Treat 0 dim valid here TORCH_CHECK( sz >= 0, - "Input size must be non-negative to genearte a valid square subsequent mask, but got ", + "Input size must be non-negative to generate a valid square subsequent mask, but got ", sz); // check IEEE754 support here since -inf is not guaranteed to be valid on non @@ -479,7 +479,7 @@ Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) { // platform else { TORCH_WARN_ONCE( - "IEEE754 is not supporetd on this platform, generate_square_subsequent_mask will fill " + "IEEE754 is not supported on this platform, generate_square_subsequent_mask will fill " "the mask with smallest float number on this platform instead of -inf"); return torch::triu( torch::full({sz, sz}, std::numeric_limits::lowest()), 1); diff --git a/torch/csrc/api/src/optim/optimizer.cpp b/torch/csrc/api/src/optim/optimizer.cpp index 95165d850cf6f..f73e54d2835f2 100644 --- a/torch/csrc/api/src/optim/optimizer.cpp +++ b/torch/csrc/api/src/optim/optimizer.cpp @@ -64,13 +64,13 @@ void OptimizerParamState::serialize( double OptimizerOptions::get_lr() const { TORCH_CHECK( false, - "double get_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + "double get_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } void OptimizerOptions::set_lr(const double lr) { TORCH_CHECK( false, - "double set_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + "double set_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } std::unique_ptr OptimizerOptions::clone() const { diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 86b893bb014e6..fa4b4fde96c13 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5,9 +5,9 @@ #include #include -#include #include #include +#include #include #include #include @@ -30,6 +30,7 @@ #include #include #include + // Helper functions for autogenerated code // These used to be inlined into the codegened Functions.cpp @@ -570,13 +571,13 @@ Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { Tensor rad2deg_backward(const Tensor& grad) { constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564; - return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_180_PI))); + return at::mul(grad, Scalar(M_180_PI)); } Tensor deg2rad_backward(const Tensor& grad) { constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417; - return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180))); + return at::mul(grad, Scalar(M_PI_180)); } Tensor unsqueeze_multiple( @@ -1098,15 +1099,15 @@ Tensor convolution_jvp( const Tensor& bias_p, const Tensor& bias_t, IntArrayRef stride, - IntArrayRef padding, + at::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, + at::SymIntArrayRef output_padding, int64_t groups) { auto bias_t_opt = bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; return ( - at::convolution( + at::convolution_symint( input_t, weight_p, c10::nullopt, @@ -1116,7 +1117,7 @@ Tensor convolution_jvp( transposed, output_padding, groups) + - at::convolution( + at::convolution_symint( input_p, weight_t, bias_t_opt, @@ -1136,10 +1137,10 @@ Tensor _convolution_jvp( const Tensor& bias_p, const Tensor& bias_t, IntArrayRef stride, - IntArrayRef padding, + at::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, + at::SymIntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, @@ -1148,7 +1149,7 @@ Tensor _convolution_jvp( auto bias_t_opt = bias_t.defined() ? c10::optional(bias_t) : c10::nullopt; return ( - at::_convolution( + at::_convolution_symint( input_t, weight_p, c10::nullopt, @@ -1162,7 +1163,7 @@ Tensor _convolution_jvp( deterministic, cudnn_enabled, allow_tf32) + - at::_convolution( + at::_convolution_symint( input_p, weight_t, bias_t_opt, @@ -2906,8 +2907,10 @@ Tensor as_strided_scatter_backward( // take the perf hit and contiguify grad for now. auto grad_ = grad.contiguous(); auto grad_slice = grad_.as_strided_symint(sizes, strides, storage_offset); - auto result = grad_.new_empty_strided_symint( - input_geometry.sym_sizes(), input_geometry.sym_strides()); + auto result = + grad_.new_zeros_symint(input_geometry.sym_sizes()) + .as_strided_symint( + input_geometry.sym_sizes(), input_geometry.sym_strides()); auto result_slice = result.as_strided_symint(sizes, strides, storage_offset); result_slice.copy_(grad_slice); return result; @@ -4833,14 +4836,32 @@ std::tuple _trilinear_backward( } Tensor log1p_backward(const Tensor& grad, const Tensor& self) { - if (self.is_sparse()) { - AT_ERROR( - "log1p of a sparse tensor is made to be non-differentiable since ", - "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ", - "Use a different mathematical operation which preserves sparsity of gradients, ", - "or report a bug if you think this is an error."); + // We must conditionally initalize this using to_dense if sparse, sparse + // addition is not supported without exact shape match + Tensor self_p1_conj; + if (self.layout() == c10::kSparse || self.layout() == c10::kSparseCsr || + self.layout() == c10::kSparseCsc || self.layout() == c10::kSparseBsr || + self.layout() == c10::kSparseBsc) { + // The warning only applies to the sparsity of self, dense grad is never + // materialized so if self is strided and grad is sparse nothing unepected + // happens memory wise + TORCH_WARN( + "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape"); + self_p1_conj = (self.to_dense() + 1).conj(); + } else { + // Although calling self.to_dense() would just return self when it has + // strided layout, that would breaks functorch tests. + self_p1_conj = (self + 1).conj(); + } + if (grad.layout() == c10::kSparse || grad.layout() == c10::kSparseCsr || + grad.layout() == c10::kSparseCsc || grad.layout() == c10::kSparseBsr || + grad.layout() == c10::kSparseBsc) { + // If grad is sparse we can't divide by the n-d (self + 1).conj(), so we + // must multiply by the recipricol, layout of grad is preserved which is + // important to gradcheck + return grad * self_p1_conj.reciprocal_(); } - return grad / (self + 1).conj(); + return grad / self_p1_conj; } Tensor sinc_backward(const Tensor& grad, const Tensor& self) { @@ -4870,10 +4891,10 @@ Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) { return at::constant_pad_nd_symint(grad, negated_pad, 0); } -Tensor embedding_dense_double_backward( +Tensor embedding_dense_double_backward_symint( const Tensor& grad, const Tensor& indices, - int64_t padding_idx) { + c10::SymInt padding_idx) { // since first backward takes care of scaling by frequency, // we don't need to worry about it here. auto gg_weight = grad.index_select(0, indices.reshape(-1)); diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 04416c2b49e08..edc7dcd140f7b 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -517,10 +517,10 @@ at::Tensor sinc_backward(const at::Tensor& grad, const at::Tensor& self); at::Tensor sparse_constructor_values_backward( const at::Tensor& sparse_grad_out, const at::Tensor& indices); -at::Tensor embedding_dense_double_backward( +at::Tensor embedding_dense_double_backward_symint( const at::Tensor& grad, const at::Tensor& indices, - int64_t padding_idx); + c10::SymInt padding_idx); at::Tensor index_backward( at::Tensor zeros_like_self, const torch::List>& indices, @@ -937,10 +937,10 @@ Tensor convolution_jvp( const Tensor& bias_p, const Tensor& bias_t, IntArrayRef stride, - IntArrayRef padding, + at::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, + at::SymIntArrayRef output_padding, int64_t groups); Tensor _convolution_jvp( @@ -951,10 +951,10 @@ Tensor _convolution_jvp( const Tensor& bias_p, const Tensor& bias_t, IntArrayRef stride, - IntArrayRef padding, + at::SymIntArrayRef padding, IntArrayRef dilation, bool transposed, - IntArrayRef output_padding, + at::SymIntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index db00d67576d3b..d11cd68e1800a 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -82,7 +82,7 @@ using at::Tensor; // base if needed. Case 5 is handled in fw_grad by reading the forward grad from // the base if needed. -namespace { +namespace utils { // Enforcing that the metadata between the primal and tangent are same has two // goals: @@ -139,7 +139,8 @@ bool has_same_meta(const Variable& base, const Variable& other) { } return true; } -} // anonymous namespace + +} // namespace utils // This function is will ensure that the fw_grad_ is properly a view of the base // for inplace ops on Tensors that do not have forward grad originally. @@ -219,7 +220,8 @@ void AutogradMeta::set_fw_grad( // Enforce same meta here to make sure that the view op below is // always valid Tensor new_base_fw_grad; - if (has_same_meta(new_grad, base) && has_same_meta(new_grad, self)) { + if (utils::has_same_meta(new_grad, base) && + utils::has_same_meta(new_grad, self)) { // TODO extend this special case to when the underlying storage of // new_grad can be re-used. new_base_fw_grad = new_grad; @@ -248,7 +250,7 @@ void AutogradMeta::set_fw_grad( } // Enforce the basic layout constraint - if (!has_same_meta(new_grad, self)) { + if (!utils::has_same_meta(new_grad, self)) { if (is_view_) { auto this_view_meta = static_cast(this); TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index bc7489292c239..d7670d924b1fa 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -300,7 +300,7 @@ auto Function::apply(Args&&... args) TORCH_CHECK( false, "jvp is not implemented for the c++ API of custom Function yet.", - "Please open a feature request on Github if you need this."); + "Please open a feature request on GitHub if you need this."); }; auto wrapped_outputs = _wrap_outputs( diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index ca9ae4e443df5..ef4856cf4796a 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -271,10 +271,7 @@ void Engine::stop() { // Under some conditions, autograd threads can hang on shutdown // Do not wait for them to shutdown indefinitely but rely on timeout auto wait_duration_str = getenv("TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"); - if (!wait_duration_str) { - wait_duration_str = "10.0"; - } - auto wait_duration = std::atof(wait_duration_str); + auto wait_duration = wait_duration_str ? std::atof(wait_duration_str) : 10.0; bool noBackward = true; for (auto& queue : device_ready_queues_) { noBackward = noBackward && queue->empty(); @@ -398,6 +395,66 @@ void add_node_to_current_graph_task_exec_info(Node* fn) { current_graph_task->exec_info_[fn].needed_ = true; } +// NB: The engine itself does not use the outputs of this function. +std::vector get_current_graph_task_execution_order() { + std::shared_ptr task = current_graph_task; + if (!task) { + return {}; + } + + // We could potentially check if there is only a single device here + // but explicitly require this context doens't seem bad either + TORCH_CHECK( + !c10::AutogradState::get_tls_state().get_multithreading_enabled(), + "get_current_graph_task_execution_order expects the current backward to be " + "executed with multithreading disabled, e.g. by running:\n\n" + ">>> with torch.autograd.set_multithreading_enabled(False):\n" + "... torch.autograd.grad(...)\n"); + + const bool check_exec_info = !task->exec_info_.empty(); + std::vector out{}; + std::unordered_set seen{}; + + auto compare_seq_nr = [](Node* n1, Node* n2) { + return n1->sequence_nr() < n2->sequence_nr(); + }; + std::priority_queue, decltype(compare_seq_nr)> heap( + compare_seq_nr); + + for (Node* ptr : task->graph_roots_) { + heap.push(ptr); + } + + // Implementation notes: + // - Don't need to count dependencies because we have sequence_nr + // - Don't need to check topological_nr because we have exec_info + while (!heap.empty()) { + Node* fn = heap.top(); + heap.pop(); + + const bool was_inserted = seen.insert(fn).second; + if (!was_inserted) { + continue; + } + + out.push_back(fn); + for (const auto& edge : fn->next_edges()) { + Node* next_ptr = edge.function.get(); + if (!next_ptr) { + continue; + } + if (check_exec_info) { + auto it = task->exec_info_.find(next_ptr); + if (it == task->exec_info_.end() || !it->second.should_execute()) { + continue; + } + } + heap.push(next_ptr); + } + } + return out; +} + // NOTE: graph_tasks do not necessarily form a stack. Imagine this // case: // @@ -1050,7 +1107,7 @@ auto Engine::compute_dependencies( } auto Engine::execute( - const edge_list& roots, + const edge_list& root_edges, const variable_list& inputs, bool keep_graph, bool create_graph, @@ -1058,9 +1115,9 @@ auto Engine::execute( const edge_list& outputs) -> variable_list { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) validate_outputs( - roots, const_cast(inputs), [](const std::string& msg) { - return msg; - }); + root_edges, + const_cast(inputs), + [](const std::string& msg) { return msg; }); if (accumulate_grad && create_graph) { TORCH_WARN_ONCE( "Using backward() with create_graph=True will create a reference cycle " @@ -1083,17 +1140,25 @@ auto Engine::execute( init_local_ready_queue(); bool not_reentrant_backward_call = worker_device == NO_DEVICE; + // Store root nodes so we can traverse through the graph later + // e.g., for get_current_graph_task_execution_order + c10::SmallVector temp_roots{root_edges.size()}; + for (const auto i : c10::irange(root_edges.size())) { + temp_roots[i] = root_edges[i].function.get(); + } + auto graph_task = std::make_shared( /* keep_graph */ keep_graph, /* create_graph */ create_graph, /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1, - /* cpu_ready_queue */ local_ready_queue); + /* cpu_ready_queue */ local_ready_queue, + /* graph_roots */ std::move(temp_roots)); // If we receive a single root, skip creating extra root node - bool skip_dummy_node = roots.size() == 1; + bool skip_dummy_node = root_edges.size() == 1; auto graph_root = skip_dummy_node - ? roots.at(0).function - : std::make_shared(roots, inputs); + ? root_edges.at(0).function + : std::make_shared(root_edges, inputs); auto min_topo_nr = compute_min_topological_nr(outputs); // Now compute the dependencies for all executable functions @@ -1106,14 +1171,17 @@ auto Engine::execute( // Queue the root if (skip_dummy_node) { - InputBuffer input_buffer(roots.at(0).function->num_inputs()); + InputBuffer input_buffer(root_edges.at(0).function->num_inputs()); auto input = inputs.at(0); const auto input_stream = InputMetadata(input).stream(); const auto opt_next_stream = - roots.at(0).function->stream(c10::DeviceType::CUDA); + root_edges.at(0).function->stream(c10::DeviceType::CUDA); input_buffer.add( - roots.at(0).input_nr, std::move(input), input_stream, opt_next_stream); + root_edges.at(0).input_nr, + std::move(input), + input_stream, + opt_next_stream); execute_with_graph_task(graph_task, graph_root, std::move(input_buffer)); } else { diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp index 5ab3447ca9ef3..22c67d0771d3e 100644 --- a/torch/csrc/autograd/function.cpp +++ b/torch/csrc/autograd/function.cpp @@ -99,5 +99,17 @@ void deleteNode(Node* function) { } } +namespace { +bool kAutogradFunctionExtensionEnabled = false; +} + +bool isAutogradFunctionExtensionEnabled() { + return kAutogradFunctionExtensionEnabled; +} + +void setAutogradFunctionExtensionEnabled(bool enabled) { + kAutogradFunctionExtensionEnabled = enabled; +} + } // namespace autograd } // namespace torch diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index aa82e3ad2c77c..d27d473b3f805 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -55,6 +55,15 @@ class NodeGuard { std::shared_ptr last_evaluating_node_; }; +// Global (not thread-local) feature flag for the new autograd.Function +// extension. The extension consists of: +// - splitting autograd.Function.forward into forward() and setup_context(). +// - adding a vmap staticmethod to autograd.Function +// The feature flag is for preventing users from unknowningly stumbling upon +// the feature and will be removed once we've ironed out the details. +TORCH_API bool isAutogradFunctionExtensionEnabled(); +TORCH_API void setAutogradFunctionExtensionEnabled(bool enabled); + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Node //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -143,6 +152,9 @@ struct TORCH_API Node : std::enable_shared_from_this { Node& operator=(Node&& other) = delete; virtual ~Node() = default; + std::shared_ptr getptr() { + return shared_from_this(); + } /// Evaluates the function on the given inputs and returns the result of the /// function call. variable_list operator()(variable_list&& inputs) { diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 5a9a0b914a871..9089d541f96b9 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 377c40ce388e2..51afb8203186e 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -88,18 +88,49 @@ auto CopySlices::apply(variable_list&& inputs) -> variable_list { result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset); } - // Adding the missing nodes to the current graph's `exec_info`. - // This is a workaround because the current `GraphTask::init_to_execute` - // does not traverse into CopySlices node. + // See Note [View + Inplace update for view tensor] For more details on this + // block Since the gradient edge for the 0th input is different between `this` + // and `fn`, make sure that the one from `fn` has the same metadata in the + // current GraphTask's exec_info as the one on `this`. const auto exec_info = get_current_graph_task_exec_info(); if (exec_info && !exec_info->empty()) { - for (const auto& next : fn->next_edges()) { - if (next.is_valid()) { - add_node_to_current_graph_task_exec_info(next.function.get()); + const auto& fn_edge = fn->next_edge(0); + const auto& this_edge = this->next_edge(0); + TORCH_INTERNAL_ASSERT(fn_edge.is_valid() == this_edge.is_valid()); + if (fn_edge.is_valid()) { + const auto fn_next_node = fn_edge.function.get(); + auto it = exec_info->find(fn_next_node); + if (it == exec_info->end()) { + // Node is not in the exec_info already + if (task_should_compute_output(0)) { + // And we need gradient for the corresponding output + add_node_to_current_graph_task_exec_info(fn_next_node); + // There is no need to remove this after execution because we are + // guaranteed that this->next_edge(0) must be in the history of + // fn->next_edge(0) (we cannot easily assert this as it might be far + // away if there were many chained views). This means that, since + // fn->next_edge(0) was not needed (no exec_info entry for it), we + // know that nothing downstream of fn->next_edge(0) is needed either + // (otherwise the whole path from that Node to this->next_edge(0) + // would be needed as well). This means that no other Node will ever + // look at fn->next_edge(0) metadata and thus there is no need to + // clean them up. + } + } else { + TORCH_INTERNAL_ASSERT( + it->second.should_execute() == task_should_compute_output(0)); } } } + // Sanity check that the graph was never modified after the fact (it is + // read-only!) + TORCH_INTERNAL_ASSERT(num_outputs() == fn->num_outputs()); + for (const auto i : c10::irange(1, this->num_outputs())) { + TORCH_INTERNAL_ASSERT( + fn->next_edge(i).function.get() == this->next_edge(i).function.get()); + } + // TODO: We clone grad_slice because we modify it below and "fn" might save // it for the backward of res. We might be able to avoid the clone() if // double-backprop is disabled. diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index cd77c8ceb7244..06f155a754b07 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -21,7 +21,60 @@ struct TORCH_API CopyBackwards : public Node { }; // Note [View + Inplace update for base tensor] -// Performs grad_view = fn(grad_view), but out-of-place. +// +// This note covers a few important topics related to view + inplace handling. +// - It explains what is the CopySlices Node and why we need it. +// - It explains the considerations on what is saved for backward in +// CopySlices. +// - It explains why we need to sometimes change the exec_info of the current +// backward +// +// What is CopySlices? +// ~~~~~~~~~~~~~~~~~~~ +// +// We support autograd with inplace mutation; e.g., if you write x.mul_(2) +// the autograd will work as if you now had multiple Tensors under the hood and +// you did +// x = t.clone() +// x0 = x +// x1 = x0 * 2 +// x = x1 +// As you can see here, after this operation, x.grad_fn now points to x1.grad_fn +// (the MulBackward node) and this node points to x's original grad_fn (which is +// also x0.grad_fn). It is important to keep in mind that after the inplace, +// there is no Tensor object that represents the x0 state anymore. But the graph +// for it is still around in autograd (in case x was used before being modified +// inplace). See Example 1 in +// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE +// We call this rebasing the history of the Tensor. +// +// Now, a difficult situation is what happens if x is a differentiable view +// of a base b. +// b = t.clone() +// x = b.select(0, 0) +// x *= 2 +// With the same approach as above, this will become +// b = t.clone() +// x = b.select(0, 0) +// b0 = b +// x0 = x +// x1 = x0 * 2 +// b1 = b0.select_scatter(x1, 0, 0) +// x2 = b1.select(0, 0) +// x = x2 +// b = b1 +// As you can see here, not only we need to modify x's grad_fn, we also need to +// modify the one from b. We also need to ensure that the new grad_fn on x is +// linked to b's new grad_fn. The chain the select_scatter, multiplication and +// select is what CopySlices does, all wrapped into a single Node. +// +// See Example 1 in +// https://docs.google.com/drawings/d/1-T5DyYfChMX1ONQkY-zU-hj_ayQ2zmA5CBOKDWqvEhE +// +// What do we need to save in CopySlices to run backward? +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// We need to perform grad_view = fn(grad_view), but out-of-place. // view_fn_ is an optional lambda function saved in DifferentiableViewMeta // from forward pass, so that we can recover we when as_strided is not // supported. It preserves the invariants: @@ -57,8 +110,6 @@ struct TORCH_API CopyBackwards : public Node { // efficient than the as_strided one so we should be careful to only use it when // necessary. // -// What do we use in CopySlices backward? -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // - For CPU/CUDA we save TensorGeometry of both base and view tensors, // That's all we need to pass into as_strided. // E.g. int[] sizes, int[] strides, and int storage_offset. @@ -66,7 +117,7 @@ struct TORCH_API CopyBackwards : public Node { // by **value**. // E.g for at::narrow, int dim, int start, in length are saved. // -// Theorectically we could also save Tensor `view` in CopySlices Node, but +// Theoretically we could also save Tensor `view` in CopySlices Node, but // it's far more expensive than what we currently save. // 1. We cannot afford keeping large tensors alive to recover views only. // 2. There are inplace checks when Tensors are loaded back to make sure @@ -76,9 +127,28 @@ struct TORCH_API CopyBackwards : public Node { // allows the user to modify the original Tensor without preventing the // backward pass from running. // -// When an in-place operation is done on a differentiable view, the base's -// grad_fn is updated to become a `CopySlice` wrapping the backward of the -// in-place operation. +// Why do we manually change exec_info in the apply? +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Using the same example as before, +// b = t.clone() +// x = b.select(0, 0) +// x *= y +// +// You can see the visualization at +// https://docs.google.com/drawings/d/1Bx-Hcz-zlIv7PabQqnPhUIVIs9F8WWi48svqMsAUMFs +// which contains the wrapped MulBackward Node and show what it links to. +// Since a backward can happen between any subset of the inputs (t and y) and +// outputs (o, x, b). It is possible to get into a state where CopySlices's 0th +// next function (CloneBackward) needs gradient but MulBackward's 0th next +// function (SelectBackward) is not. This happens if you do autograd.grad +// between x and t for example. +// In such a case, we do need to mark SelectBackward as requiring gradient such +// that, during the execution of MulBackward, we will actually compute gradient +// for the 0th input. +// +// All the other next functions are always shared (this is asserted in the apply +// code) and so nothing needs to be done for them. // See Note [View + Inplace update for view tensor] for what we do to view // tensor when an in-place operation happens. diff --git a/torch/csrc/autograd/graph_task.h b/torch/csrc/autograd/graph_task.h index 8eb122313d0a0..4efbc905fed37 100644 --- a/torch/csrc/autograd/graph_task.h +++ b/torch/csrc/autograd/graph_task.h @@ -37,6 +37,7 @@ struct GraphTask : std::enable_shared_from_this { // Records the nodes that are in the graph std::unordered_set nodes_in_graph_; + c10::SmallVector graph_roots_; // Note [Exec info] // Exec info is created for each GraphTask, which allows filtering paths on // the graph that are not needed. It has a bit complicated semantics. If it's @@ -164,8 +165,10 @@ struct GraphTask : std::enable_shared_from_this { bool grad_mode, int reentrant_depth, std::shared_ptr cpu_ready_queue, + c10::SmallVector graph_roots, bool exit_on_error = false) : keep_graph_(keep_graph), + graph_roots_(std::move(graph_roots)), owner_(NO_DEVICE), reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), @@ -198,6 +201,7 @@ get_current_graph_task_exec_info(); TORCH_API const std::unordered_set* get_current_graph_task_nodes_in_graph(); TORCH_API bool get_current_graph_task_keep_graph(); +TORCH_API std::vector get_current_graph_task_execution_order(); TORCH_API int get_current_graph_task_id(); void add_node_to_current_graph_task_exec_info(Node* fn); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 007150002dbb6..709cc46308f3c 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -278,8 +279,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ActivityType::CPU}; -#if defined(USE_KINETO) && !defined(LIBKINETO_NOCUPTI) - if (at::getNumGPUs() > 0 && !at::hasHIP()) { +#if defined(USE_KINETO) && \ + (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) + if (at::getNumGPUs() > 0) { activities.insert(ActivityType::CUDA); } #endif @@ -513,6 +515,30 @@ static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { END_HANDLE_TH_ERRORS } +static PyObject* is_autograd_function_extension_enabled( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + if (torch::autograd::isAutogradFunctionExtensionEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +static PyObject* set_autograd_function_extension_enabled( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + if (!PyBool_Check(arg)) { + throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); + } + torch::autograd::setAutogradFunctionExtensionEnabled(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (!PyBool_Check(arg)) { @@ -606,24 +632,11 @@ static PyObject* python_exit_dual_level( END_HANDLE_TH_ERRORS } -static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) { - HANDLE_TH_ERRORS - if (arg == Py_None) { - at::impl::PythonTorchFunctionTLS::set_mode(nullptr); - } else { - Py_INCREF(arg); - at::impl::PythonTorchFunctionTLS::set_mode( - std::make_shared(arg, getPyInterpreter())); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS; -} - static PyObject* is_torch_function_mode_enabled( PyObject* _unused, PyObject* _unused2) { HANDLE_TH_ERRORS - if (at::impl::function_mode_enabled()) { + if (at::impl::torch_function_mode_enabled()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ -682,19 +695,6 @@ static PyObject* len_torch_function_stack( END_HANDLE_TH_ERRORS } -static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) { - HANDLE_TH_ERRORS - if (arg == Py_None) { - c10::impl::TorchDispatchModeTLS::set_mode(nullptr); - } else { - Py_INCREF(arg); - c10::impl::TorchDispatchModeTLS::set_mode( - std::make_shared(arg, getPyInterpreter())); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS; -} - static PyObject* push_on_torch_dispatch_stack( PyObject* _unused, PyObject* arg) { @@ -777,6 +777,14 @@ static PyMethodDef methods[] = { // NOLINT METH_NOARGS, nullptr}, {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr}, + {"_set_autograd_function_extension_enabled", + set_autograd_function_extension_enabled, + METH_O, + nullptr}, + {"_is_autograd_function_extension_enabled", + is_autograd_function_extension_enabled, + METH_NOARGS, + nullptr}, {"set_anomaly_enabled", castPyCFunctionWithKeywords(set_anomaly_mode_enabled), METH_VARARGS | METH_KEYWORDS, @@ -795,7 +803,6 @@ static PyMethodDef methods[] = { // NOLINT is_torch_function_mode_enabled, METH_NOARGS, nullptr}, - {"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr}, {"_push_on_torch_function_stack", push_on_torch_function_stack, METH_O, @@ -812,7 +819,6 @@ static PyMethodDef methods[] = { // NOLINT len_torch_function_stack, METH_NOARGS, nullptr}, - {"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr}, {"_push_on_torch_dispatch_stack", push_on_torch_dispatch_stack, METH_O, diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 6cc6acefc9d45..50d4c0ce0aa60 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include #include #include diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 7cb9e8aedb195..8060c11ac4575 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -125,13 +125,13 @@ struct InputMetadata { if (grad.is_nested()) { ss << at::native::get_nested_size_tensor(grad); } else { - ss << grad.sizes(); + ss << grad.sym_sizes(); } ss << " but expected shape compatible with "; if (is_nested_tensor()) { ss << shape_as_tensor(); } else { - ss << c10::asIntArrayRefSlow(shape_as_dim_vector()); + ss << shape_as_dim_vector(); } return ss; } diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index c7ab982b38897..c9b7b9fa92960 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -1,3 +1,4 @@ +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include @@ -12,8 +13,10 @@ #include #include #include +#include #include #include +#include #include #include #include @@ -61,11 +64,40 @@ using torch::profiler::impl::ActiveProfilerType; using torch::profiler::impl::dtypesToStr; using torch::profiler::impl::EventType; using torch::profiler::impl::ExtraFields; +using torch::profiler::impl::op_input_t; using torch::profiler::impl::ProfilerStateBase; using torch::profiler::impl::PyExtraFieldsBase; using torch::profiler::impl::Result; using torch::profiler::impl::shapesToStr; using torch::profiler::impl::stacksToStr; +using torch::profiler::impl::TensorMetadata; + +auto shapesAndDtypes(const std::vector& inputs) { + std::vector> shapes; + std::vector dtypes; + for (const auto& i : inputs) { + c10::visit( + c10::overloaded( + [&](const TensorMetadata& t) { + shapes.emplace_back(t.sizes_); + dtypes.emplace_back(scalarTypeToTypeMeta(t.dtype_).name()); + }, + [&](const std::vector&) { + shapes.emplace_back(); + dtypes.emplace_back("TensorList"); + }, + [&](const c10::IValue&) { + shapes.emplace_back(); + dtypes.emplace_back("Scalar"); + }, + [&](const auto&) { + shapes.emplace_back(); + dtypes.emplace_back(); + }), + i); + } + return std::make_pair(shapes, dtypes); +} struct MetadataBase { MetadataBase(const std::shared_ptr& result) @@ -134,10 +166,12 @@ struct AddTensorboardFields : public MetadataBase { }; struct AddGenericMetadata : public MetadataBase { - AddGenericMetadata(std::shared_ptr& result, const bool verbose) - : MetadataBase(result) { + AddGenericMetadata( + std::shared_ptr& result, + const torch::profiler::impl::ProfilerConfig* config) + : MetadataBase(result), config_(config) { result->visit(*this); - if (verbose) { + if (config->experimental_config.verbose) { result->visit_if_base( [&, this](const auto& i) -> void { this->addMetadata("Python thread", std::to_string(i.python_tid_)); @@ -146,14 +180,22 @@ struct AddGenericMetadata : public MetadataBase { } void operator()(ExtraFields& op_event) { - auto& shapes = op_event.inputs_.shapes_; - if (!shapes.empty()) { - addMetadata("Input Dims", shapesToStr(shapes)); + const auto shapes_and_dtypes = shapesAndDtypes(op_event.inputs_); + if (!shapes_and_dtypes.first.empty()) { + addMetadata("Input Dims", shapesToStr(shapes_and_dtypes.first)); + } + + if (!shapes_and_dtypes.second.empty()) { + addMetadata("Input type", dtypesToStr(shapes_and_dtypes.second)); } - auto& dtypes = op_event.inputs_.dtypes_; - if (!dtypes.empty()) { - addMetadata("Input type", dtypesToStr(dtypes)); + if (config_ && !config_->experimental_config.performance_events.empty()) { + auto& event_names = config_->experimental_config.performance_events; + for (auto i = 0; i < op_event.perf_event_counters_->size(); ++i) { + addMetadata( + event_names[i], + std::to_string((*op_event.perf_event_counters_)[i])); + } } // add information about an associated forward op, if a sequence number @@ -197,6 +239,10 @@ struct AddGenericMetadata : public MetadataBase { template void operator()(const T&) {} + + private: + /* To get names of the performance events */ + const torch::profiler::impl::ProfilerConfig* config_; }; // Assumption: Total threads number will not exceed 2^16-1, and total ops will @@ -315,7 +361,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { kineto_events_.emplace_back(e, config_.experimental_config.verbose); AddTensorboardFields add_tb(e, kineto_events_.back()); - AddGenericMetadata add_generic(e, config_.experimental_config.verbose); + AddGenericMetadata add_generic(e, &config_); // It is not safe to use the activity after post processing. e->kineto_activity_ = nullptr; @@ -433,6 +479,10 @@ void onFunctionExit( TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr); kineto_ctx_ptr->event_->end_time_ = torch::profiler::impl::getApproximateTime(); + if (!config.experimental_config.performance_events.empty()) { + state_ptr->record_queue_.getSubqueue()->disable_perf_profiler( + *kineto_ctx_ptr->event_->counters_); + } kineto_ctx_ptr->event_->basic_fields_.end_tid_ = at::RecordFunction::currentThreadId(); if (config.state == ProfilerState::KINETO_GPU_FALLBACK) { @@ -518,6 +568,33 @@ void prepareProfiler( "Supported only in Kineto profiler"); torch::profiler::impl::kineto::prepareTrace( /*cpuOnly=*/!at::hasCUDA(), activities, config.experimental_config); + + if (config.experimental_config.performance_events.size()) { + /* For now only CPU activity is supported */ + TORCH_CHECK( + activities.count(torch::autograd::profiler::ActivityType::CPU), + "Cannot run cpu hardware profiler without CPU activities, please only use CPU activity type"); + /* + * Sending a warning and passing the non-standard event to the backend + * Backend can abort if the event is not supported. + * TODO Should we gracefully drop the invalid event if we have atleast one + * valid? + */ + auto is_standard_event = [](const std::string& event) -> bool { + for (auto e : torch::profiler::ProfilerPerfEvents) { + if (!std::strcmp(event.c_str(), e)) { + return true; + } + } + return false; + }; + + for (const auto& e : config.experimental_config.performance_events) { + if (!is_standard_event(e)) { + TORCH_WARN("Forwarding a non-standard CPU performance event : ", e); + } + } + } } void enableProfilerWithEventPostProcess( @@ -636,6 +713,10 @@ KinetoEvent::KinetoEvent( parent = parent->parent_.lock(); } } + + result->visit_if_base>([&](const auto& op) { + std::tie(shapes_, dtypes_) = shapesAndDtypes(op.inputs_); + }); } bool KinetoEvent::isPythonFunction() const { @@ -644,6 +725,22 @@ bool KinetoEvent::isPythonFunction() const { return out; } +bool KinetoEvent::hasShapes() const { + return !shapes_.empty(); +} + +const c10::ArrayRef> KinetoEvent::shapes() const { + return shapes_; +} + +bool KinetoEvent::hasTypes() const { + return !dtypes_.empty(); +} + +const c10::ArrayRef KinetoEvent::dtypes() const { + return dtypes_; +} + const c10::ArrayRef KinetoEvent::stack() const { auto get = [&](const auto& i) -> auto& { return !i.jit_stack_.empty() ? i.jit_stack_ : python_stack_; @@ -709,6 +806,21 @@ int64_t KinetoEvent::cudaElapsedUs() const { return -1; } +void KinetoEvent::getPerfEventCounters(std::vector& in) const { + return result_->visit(c10::overloaded( + [&in](const ExtraFields& e) -> void { + const size_t n = e.perf_event_counters_->size(); + // should be rare + if (in.size() < n) { + in.resize(n, 0); + } + for (size_t i = 0; i < n; ++i) { + in[i] = (*e.perf_event_counters_)[i]; + } + }, + [](const auto&) -> void { return; })); +} + #define FORWARD_FROM_RESULT(method_name, result_expr) \ decltype(std::declval().method_name()) \ KinetoEvent::method_name() const { \ @@ -746,10 +858,6 @@ FORWARD_FROM_RESULT(deviceResourceId, kineto_info_.resource) TYPED_ATTR_WITH_DEFAULT(TorchOp, sequenceNr, e.sequence_number_, -1) TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0) -TYPED_ATTR(TorchOp, hasShapes, !e.inputs_.shapes_.empty()) -TYPED_ATTR(TorchOp, shapes, e.inputs_.shapes_) -TYPED_ATTR(TorchOp, hasTypes, !e.inputs_.dtypes_.empty()) -TYPED_ATTR(TorchOp, dtypes, e.inputs_.dtypes_) TYPED_ATTR(TorchOp, scope, static_cast(e.scope_)) TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty()) TYPED_ATTR(TorchOp, isAsync, e.is_async_) diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index 5e5b430aa2814..d85232f96cb58 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -54,6 +55,7 @@ struct TORCH_API KinetoEvent { std::string backend() const; bool isPythonFunction() const; int64_t cudaElapsedUs() const; + void getPerfEventCounters(torch::profiler::perf_counters_t&) const; private: torch::profiler::impl::ProfilerEventStub fallbackStart() const; @@ -61,6 +63,10 @@ struct TORCH_API KinetoEvent { std::shared_ptr result_; std::vector python_stack_; + + // Copy fields from result so we can return ArrayRefs. + std::vector> shapes_; + std::vector dtypes_; }; // Consolidating events returned directly from Kineto diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index aee3702b8b105..5cf08afcbd1f0 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -128,7 +129,7 @@ class CallTypeHelper final { std::index_sequence); template - static void map(T& t, FunctorT& f, Args... args) { + static void map(T& t, FunctorT& f, Args&&... args) { f(std::get(t), args...); c10::guts::if_constexpr( [&](auto _) { map(_(t), f, std::forward(args)...); }); @@ -138,7 +139,7 @@ class CallTypeHelper final { using tuple_type = decltype(make_tuple_impl(std::make_index_sequence{})); template - static void map(tuple_type& t, FunctorT& f, Args... args) { + static void map(tuple_type& t, FunctorT& f, Args&&... args) { map<0>(t, f, std::forward(args)...); } }; @@ -204,20 +205,40 @@ struct Config { static constexpr EventType event_type = EventType::PyCall; }; -template <> -struct Config { - using key_t = PyModuleSelf; - using cls_t = PyModuleCls; +template +struct ExtendedPyCallConfig { + using key_t = Key; + using cls_t = Cls; using ephemeral_t = PyFrameObject*; - using info_t = std::pair>; - struct cache_t { - c10::optional location_; // nn.Module.forward; - ska::flat_hash_map modules_and_params_; + + struct ClsAndParameters { + cls_t cls_; + std::vector parameters_; + }; + + struct Cache { + // `nn.Module.forward` or `optim.Optimizer._optimizer_step_code` + c10::optional location_; + ska::flat_hash_map cls_and_parameters_; ska::flat_hash_map cls_names_; }; + using cache_t = Cache; + static constexpr EventType event_type = EventType::PyCall; }; +template <> +struct Config : ExtendedPyCallConfig< + PyModuleSelf, + PyModuleCls, + NNModuleInfo::ParameterInfo> {}; + +template <> +struct Config : ExtendedPyCallConfig< + PyOptimizerSelf, + PyOptimizerCls, + OptimizerInfo::ParameterInfo> {}; + template <> struct Config { using key_t = PyMethod; @@ -226,25 +247,6 @@ struct Config { static constexpr EventType event_type = EventType::PyCCall; }; -template <> -struct Config { - using key_t = PyOptimizerSelf; - using cls_t = PyOptimizerCls; - using ephemeral_t = PyFrameObject*; - struct info_t { - cls_t cls_; - std::vector params_; - std::vector> states_; - }; - struct cache_t { - c10::optional - location_; // optim.Optimizer._optimizer_step_code; - ska::flat_hash_map optimizer_data_; - ska::flat_hash_map cls_names_; - }; - static constexpr EventType event_type = EventType::PyCall; -}; - // ============================================================================ // == Callsite & ValueCache: Storage during profiling ========================= // ============================================================================ @@ -269,52 +271,6 @@ class Callsite { Config::key_t caller_; }; -void check_and_store( - const pybind11::handle& name, - const pybind11::handle& param_handle, - std::vector& storeroom) { - auto param_ptr = param_handle.ptr(); - if (py::isinstance(name) && THPVariable_CheckExact(param_ptr)) { - const auto& param = THPVariable_Unpack(param_ptr); - auto grad_ptr = py::getattr(param_handle, "grad", py::none()).ptr(); - c10::optional grad_metadata; - - if (THPVariable_CheckExact(grad_ptr)) { - grad_metadata = c10::optional( - TensorMetadata(THPVariable_Unpack(grad_ptr))); - } else { - grad_metadata = c10::nullopt; - } - - storeroom.push_back( - {name.cast(), TensorMetadata(param), grad_metadata}); - } -} - -void check_and_store( - const pybind11::handle& name, - const pybind11::handle& param_handle, - std::vector, TensorMetadata>>& - storeroom) { - auto param_ptr = param_handle.ptr(); - if (py::isinstance(name) && THPVariable_CheckExact(param_ptr)) { - const auto& param = THPVariable_Unpack(param_ptr); - - storeroom.emplace_back(name.cast(), param); - } -} - -void check_and_store( - const pybind11::handle& param_handle, - std::vector& storeroom) { - auto param_ptr = param_handle.ptr(); - if (THPVariable_CheckExact(param_ptr)) { - const auto& param = THPVariable_Unpack(param_ptr); - - storeroom.emplace_back(param); - } -} - // ============================================================================ // == Type specific store and load implementations. =========================== // ============================================================================ @@ -325,6 +281,9 @@ using PyOptimizerCallKey = Config::key_t; class ValueCache { public: + ValueCache() = default; + ValueCache(const ValueCache&) = delete; + template void store(const typename Config::key_t&, typename Config::ephemeral_t); @@ -339,6 +298,9 @@ class ValueCache { load(callsite.value_)}; } + c10::optional recordIfTensor(py::handle p); + std::vector> unpackTensorMap( + py::dict tensor_map); void trimPrefixes(); private: @@ -374,6 +336,34 @@ typename Config::cls_t set_class( return cls; } +TensorMetadata toTensorMetadata(PyObject* self) { + TORCH_INTERNAL_ASSERT(THPVariable_CheckExact(self)); + const auto& t = THPVariable_Unpack(self); + RawTensorMetadata m{t}; + return TensorMetadata{ + m, + t.sizes().vec(), + m.layout_ == at::kStrided ? t.strides().vec() : std::vector()}; +} + +c10::optional ValueCache::recordIfTensor(py::handle p) { + return THPVariable_CheckExact(p.ptr()) + ? c10::optional{toTensorMetadata(p.ptr())} + : c10::nullopt; +} + +std::vector> ValueCache::unpackTensorMap( + py::dict tensor_map) { + std::vector> out; + for (auto& it : tensor_map) { + auto* value = it.second.ptr(); + if (py::isinstance(it.first) && THPVariable_CheckExact(value)) { + out.push_back({py::cast(it.first), toTensorMetadata(value)}); + } + } + return out; +} + template <> void ValueCache::store(const PyCallKey& key, no_ephemeral_t) { auto& locations = std::get(state_); @@ -397,16 +387,22 @@ void ValueCache::store( Config::ephemeral_t frame) { auto& cache = std::get(state_); if (C10_UNLIKELY( - cache.modules_and_params_.find(key) == - cache.modules_and_params_.end())) { + cache.cls_and_parameters_.find(key) == + cache.cls_and_parameters_.end())) { auto cls = set_class(this, cache, key, frame); py::dict params = py::handle((PyObject*)key).attr("_parameters"); - std::vector params_; + std::vector params_; for (auto& it : params) { - check_and_store(it.first, it.second, params_); + auto* p = it.second.ptr(); + if (py::isinstance(it.first) && THPVariable_CheckExact(p)) { + params_.push_back( + {it.first.cast(), + toTensorMetadata(p), + recordIfTensor(py::getattr(it.second, "grad", py::none()))}); + } } - cache.modules_and_params_[key] = make_pair(cls, params_); + cache.cls_and_parameters_[key] = {cls, params_}; } } @@ -415,45 +411,45 @@ ExtraFields::args_t ValueCache::load( const PyModuleCallKey& key) const { auto& cache = std::get(state_); TORCH_INTERNAL_ASSERT(cache.location_.has_value()); - auto cls = cache.modules_and_params_.at(key).first; - auto fwd = std::get(state_).at(*cache.location_); + const auto& cls_and_parameters = cache.cls_and_parameters_.at(key); + const auto& cls = cls_and_parameters.cls_; + NNModuleInfo info{ + key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_}; return { - fwd, - NNModuleInfo{ - key, - cls, - cache.cls_names_.at(cls), - cache.modules_and_params_.at(key).second}}; + /*frame_state_=*/std::get(state_).at(*cache.location_), + /*module_info_=*/std::move(info), + /*optimizer_info_=*/c10::nullopt}; } + template <> void ValueCache::store( const PyOptimizerCallKey& key, Config::ephemeral_t frame) { auto& cache = std::get(state_); if (C10_UNLIKELY( - cache.optimizer_data_.find(key) == cache.optimizer_data_.end())) { + cache.cls_and_parameters_.find(key) == + cache.cls_and_parameters_.end())) { auto cls = set_class(this, cache, key, frame); - py::list param_groups_handle = - py::handle((PyObject*)key).attr("param_groups"); - std::vector params_; - // param_groups is a list of dict - for (auto& param_group : param_groups_handle) { - for (auto& param : - py::cast(param_group).attr("get")("params")) { - check_and_store(param, params_); - } - } - std::vector> states_; - py::dict state_handle = py::handle((PyObject*)key).attr("state"); - for (auto& it : state_handle) { - TORCH_INTERNAL_ASSERT( - py::isinstance(it.second), "Expects a dict type element"); - for (auto& state_elem : py::cast(it.second)) { - check_and_store(state_elem.first, state_elem.second, states_); + const py::handle self{(PyObject*)key}; + std::vector params; + + for (const auto& i : (py::list)self.attr("param_groups")) { + for (auto& param : py::cast(i).attr("get")("params")) { + if (THPVariable_CheckExact(param.ptr())) { + // While `self.state` is permitted to store data in an arbitrary way, + // all generic optimizers (SGD, Adam, etc) use param as the key since + // the state in question is tied to particular parameters. We can + // relax this assumption if the need arises. + params.push_back( + {toTensorMetadata(param.ptr()), + recordIfTensor(py::getattr(param, "grad", py::none())), + unpackTensorMap(py::cast(self.attr("state")) + .attr("get")(param, py::dict()))}); + } } } - cache.optimizer_data_[key] = {cls, params_, states_}; + cache.cls_and_parameters_[key] = {cls, params}; } } @@ -461,17 +457,14 @@ template <> ExtraFields::args_t ValueCache::load< CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const { auto& cache = std::get(state_); - auto cls = cache.optimizer_data_.at(key).cls_; - auto frame_state = std::get(state_).at(*cache.location_); + const auto& cls_and_parameters = cache.cls_and_parameters_.at(key); + auto cls = cls_and_parameters.cls_; + OptimizerInfo info{ + key, cls, cache.cls_names_.at(cls), cls_and_parameters.parameters_}; return { - frame_state, - c10::nullopt, - OptimizerInfo{ - key, - cls, - cache.cls_names_.at(cls), - cache.optimizer_data_.at(key).params_, - cache.optimizer_data_.at(key).states_}}; + /*frame_state_=*/std::get(state_).at(*cache.location_), + /*module_info_=*/c10::nullopt, + /*optimizer_info_=*/std::move(info)}; } template <> diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 3bd12f480d409..dc365c1700088 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -1,7 +1,7 @@ #include -#include -#include +#include +#include #include #include #include diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 39374f7f82978..a66897f7f0095 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -40,7 +41,6 @@ using namespace torch; using namespace torch::autograd; -using namespace torch::jit; using at::Tensor; PyObject* THPFunctionClass = nullptr; @@ -600,6 +600,7 @@ static void _append_subgraph( torch::jit::Graph* graph, std::vector trace_outputs, bool unpack_output) { + using Value = torch::jit::Value; node->g_( torch::jit::attr::Subgraph, std::make_shared(graph->current_scope())); @@ -692,8 +693,8 @@ static void _trace_post_record( node->addOutput(); auto old_node = node; if (!unpack_output) { - std::vector tuple_values(num_outputs, TensorType::get()); - TypePtr tuple_type = TupleType::create(std::move(tuple_values)); + std::vector tuple_values(num_outputs, at::TensorType::get()); + auto tuple_type = at::TupleType::create(std::move(tuple_values)); // Original type is tuple of tensors "without" element type and shape. // The missed parts will be added below. node->output()->setType(tuple_type); @@ -705,7 +706,7 @@ static void _trace_post_record( for (const auto i : c10::irange(num_outputs)) { PyObject* obj = PyTuple_GET_ITEM(output_objects, i); if (THPVariable_Check(obj)) { - Value* value = node->outputs()[i]; + auto value = node->outputs()[i]; const auto& tensor = THPVariable_Unpack(obj); if (tensor.defined()) { value->inferTypeFrom(tensor); @@ -723,12 +724,12 @@ static void _trace_post_record( // If TupleUnpack operator is created, we copy its output type back // to the original tuple type. if (!unpack_output) { - std::vector new_tuple_values; + std::vector new_tuple_values; for (const auto i : c10::irange(num_outputs)) { - TypePtr ptr = node->outputs()[i]->type(); + auto ptr = node->outputs()[i]->type(); new_tuple_values.push_back(ptr); } - TypePtr tuple_type = TupleType::create(std::move(new_tuple_values)); + auto tuple_type = at::TupleType::create(std::move(new_tuple_values)); // The i-th tuple element receives a new tensor type with element type and // shape. old_node->output()->setType(tuple_type); @@ -821,6 +822,43 @@ PyObject* THPFunction_maybe_clear_saved_tensors( END_HANDLE_TH_ERRORS } +namespace { + +THPObjectPtr make_ctx_input_tuple( + THPFunction* ctx, + const UnpackedInput& unpacked_input, + int64_t num_args) { + THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); + if (!ctx_input_tuple) + return {}; + Py_INCREF(ctx); + PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx); + for (const auto i : c10::irange(num_args)) { + PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); + Py_INCREF(arg); + PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); + } + return ctx_input_tuple; +} + +THPObjectPtr make_ctx_input_output_tuple( + THPFunction* ctx, + UnpackedInput& unpacked_input, + PyObject* outputs) { + THPObjectPtr result(PyTuple_New(3)); + if (!result) + return {}; + Py_INCREF(ctx); + Py_INCREF(unpacked_input.input_tuple.get()); + Py_INCREF(outputs); + PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx); + PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get()); + PyTuple_SET_ITEM(result.get(), 2, outputs); + return result; +} + +} // namespace + PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { HANDLE_TH_ERRORS @@ -865,29 +903,51 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { ctx->needs_input_grad = input_info.needs_input_grad.release(); ctx->is_variable_input = std::move(input_info.is_variable_input); - // Prepend ctx to input_tuple, in preparation for static method call + // autograd.Function may optionally contain a setup_context staticmethod. + // In this case, autograd.Function.forward does NOT accept a ctx object. + bool has_separate_setup_context_fn = + (isAutogradFunctionExtensionEnabled() && + PyObject_HasAttrString(cls, "setup_context")); + auto num_args = PyTuple_GET_SIZE(inputs); - THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); - if (!ctx_input_tuple) - return nullptr; - Py_INCREF(ctx); - PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx); - for (const auto i : c10::irange(num_args)) { - PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); - Py_INCREF(arg); - PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); - } // Call forward - THPObjectPtr tensor_outputs; + THPObjectPtr outputs; { AutoGradMode grad_mode(false); at::AutoFwGradMode fw_grad_mode(false); THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward")); if (!forward_fn) return nullptr; - tensor_outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); - if (!tensor_outputs) + if (has_separate_setup_context_fn) { + // call forward followed by setup_context + outputs = PyObject_CallObject(forward_fn, unpacked_input.input_tuple); + if (!outputs) { + return nullptr; + } + // signature is setup_context(ctx, inputs, outputs) + auto ctx_input_output_tuple = + make_ctx_input_output_tuple(ctx, unpacked_input, outputs); + if (!ctx_input_output_tuple) { + return nullptr; + } + THPObjectPtr setup_context_fn( + PyObject_GetAttrString(cls, "setup_context")); + auto result = + PyObject_CallObject(setup_context_fn, ctx_input_output_tuple); + if (!result) { + return nullptr; + } + } else { + // call forward + auto ctx_input_tuple = + make_ctx_input_tuple(ctx, unpacked_input, num_args); + if (!ctx_input_tuple) { + return nullptr; + } + outputs = PyObject_CallObject(forward_fn, ctx_input_tuple); + } + if (!outputs) return nullptr; } @@ -897,7 +957,7 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { ctx, unpacked_input, inputs, - std::move(tensor_outputs), + std::move(outputs), is_executable, node); END_HANDLE_TH_ERRORS diff --git a/torch/csrc/autograd/python_nested_functions.h b/torch/csrc/autograd/python_nested_functions.h index 8b0bf9c115d17..6a86a3a7a1fe0 100644 --- a/torch/csrc/autograd/python_nested_functions.h +++ b/torch/csrc/autograd/python_nested_functions.h @@ -3,7 +3,9 @@ namespace torch { namespace autograd { +PyMethodDef* get_nested_functions_manual(); + void initNestedFunctions(PyObject* module); -} +} // namespace autograd } // namespace torch diff --git a/torch/csrc/autograd/python_nested_functions_manual.cpp b/torch/csrc/autograd/python_nested_functions_manual.cpp new file mode 100644 index 0000000000000..0e1823e192b3a --- /dev/null +++ b/torch/csrc/autograd/python_nested_functions_manual.cpp @@ -0,0 +1,44 @@ +#include +#include +#include +#include + +namespace torch { +namespace autograd { + +static PyObject* THPVariable_nested_tensor( + PyObject* /*self*/, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({ + "nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", + }); + + constexpr int ctor_num_args = 5; + ParsedArgs parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + + jit::tracer::warn( + "torch.nested.nested_tensor", jit::tracer::WARN_CONSTRUCTOR); + return THPVariable_Wrap(torch::utils::nested_tensor_ctor( + torch::tensors::get_default_dispatch_key(), + torch::tensors::get_default_scalar_type(), + r)); + END_HANDLE_TH_ERRORS +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) +static PyMethodDef nested_functions_manual[] = { + {"nested_tensor", + castPyCFunctionWithKeywords(THPVariable_nested_tensor), + METH_VARARGS | METH_KEYWORDS, + nullptr}, +}; + +PyMethodDef* get_nested_functions_manual() { + return nested_functions_manual; +} + +} // namespace autograd +} // namespace torch diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 949bf1219f5ab..6aaaaf0eff6e9 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -1,6 +1,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -405,12 +408,30 @@ static PyObject* THPVariable__to_functional_tensor( PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser( - {"_to_functional_tensor(Tensor t)"}, /*traceable=*/true); + {"_to_functional_tensor(Tensor t, *, bool mirror_autograd_meta=False)"}, + /*traceable=*/true); - ParsedArgs<1> parsed_args; + ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); auto self_ = r.tensor(0); + auto mirror_autograd_meta = r.toBool(1); auto wrapped = at::functionalization::impl::to_functional_tensor(self_); + if (mirror_autograd_meta) { + // Here, we unsafely set the grad function on the wrapper to be the same as + // the inner. We expect this grad_fn to NEVER be used. It's needed so that + // .is_leaf metadata is accurate on the wrapper + auto inner_autograd_meta = impl::get_autograd_meta(self_); + if (inner_autograd_meta) { + wrapped.set_requires_grad(self_.requires_grad()); + if (wrapped.requires_grad()) { + auto new_grad_fn = std::shared_ptr( + new torch::autograd::Error( + "Cannot backprop through mirrored meta, file a bug in PyTorch"), + torch::autograd::deleteNode); + torch::autograd::set_history(wrapped, new_grad_fn); + } + } + } return wrap(wrapped); END_HANDLE_TH_ERRORS } @@ -431,6 +452,22 @@ static PyObject* THPVariable__from_functional_tensor( END_HANDLE_TH_ERRORS } +static PyObject* THPVariable__freeze_functional_tensor( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser( + {"_freeze_functional_tensor(Tensor t)"}, /*traceable=*/true); + + ParsedArgs<1> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + auto self_ = r.tensor(0); + at::functionalization::impl::freeze_functional_tensor(self_); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static PyObject* THPVariable__is_functional_tensor( PyObject* self, PyObject* args, @@ -535,6 +572,10 @@ static PyMethodDef torch_functions_manual[] = { castPyCFunctionWithKeywords(THPVariable__from_functional_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"_freeze_functional_tensor", + castPyCFunctionWithKeywords(THPVariable__freeze_functional_tensor), + METH_VARARGS | METH_KEYWORDS | METH_STATIC, + nullptr}, {"_sync", castPyCFunctionWithKeywords(THPVariable__sync), METH_VARARGS | METH_KEYWORDS | METH_STATIC, @@ -672,7 +713,7 @@ static PyObject* THPVariable_numel( } if (r.idx == 0) { - return wrap(r.tensor(0).numel()); + return py::cast(r.tensor(0).sym_numel()).release().ptr(); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 9b52f7b50943a..a08d6f7761fd2 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1,8 +1,10 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -31,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -167,7 +170,7 @@ void pushPyOutToStack( if (num_returns == 0) { // Check that we got a None return from Python. Anything else is an error. TORCH_CHECK( - out.is(py::none()), + out.is_none(), "Expected ", msg, " for ", @@ -219,6 +222,14 @@ struct ConcretePyInterpreterVTable final const c10::OperatorHandle& op, c10::DispatchKeySet, torch::jit::Stack* stack) const override; + // NB: this is defined in python_dispatch.cpp + void python_op_registration_trampoline( + const c10::OperatorHandle& op, + c10::DispatchKey key, + torch::jit::Stack* stack) const override { + torch::impl::dispatch::python_op_registration_trampoline_impl( + op, key, stack); + } bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override; bool is_strides_like(const TensorImpl* self, at::MemoryFormat) const override; @@ -314,7 +325,10 @@ class PyInterpreterHolder { public: PyInterpreterHolder() : impl_(new c10::impl::PyInterpreter( - ConcretePyInterpreterVTable::instance())) {} + ConcretePyInterpreterVTable::instance())) { + is_main_interpreter_ = + at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_); + } // NB: intentionally leaks the PyInterpreter, as there may still be // references to it that are live, living in objects that aren't being // destructed while Python is being cleaned up. @@ -324,9 +338,13 @@ class PyInterpreterHolder { c10::impl::PyInterpreter* get() const noexcept { return impl_; } + bool is_main_interpreter() const noexcept { + return is_main_interpreter_; + } private: c10::impl::PyInterpreter* impl_; + bool is_main_interpreter_; }; PyInterpreterHolder self_interpreter; @@ -352,6 +370,10 @@ c10::impl::PyInterpreter* getPyInterpreter() { return self_interpreter.get(); } +bool isMainPyInterpreter() { + return self_interpreter.is_main_interpreter(); +} + std::string ConcretePyInterpreterVTable::name() const { std::stringstream ss; ss << getPyInterpreter(); @@ -416,6 +438,13 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { Py_RETURN_NONE; } + if (c10::impl::HermeticPyObjectTLS::get_state()) { + return THPVariable_NewWithVar( + (PyTypeObject*)THPVariableClass, + std::move(var), + c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + } + c10::optional mb_obj = var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()); c10::impl::PyInterpreterStatus status; @@ -489,6 +518,11 @@ bool isResurrectable(THPVariable* self) { return false; } auto const& tensor = THPVariable_Unpack(self); + // Check if this is hermetic. If it is, no resurrection. + if (tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) != + c10::make_optional((PyObject*)self)) { + return false; + } if (!tensor.defined() || tensor.use_count() <= 1) { return false; } @@ -531,6 +565,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) { // Flip THPVariable to be non-owning // (near use-after-free miss here: fresh MaybeOwned is created breaking // reference on Tensor in struct BEFORE we overwrite the old one) + TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state()); self->cdata = MaybeOwned::borrowed(tensor); // NB: At this point, tensor *could* be dead (e.g., some other C++ thread @@ -582,7 +617,9 @@ static int THPVariable_clear(THPVariable* self) { // unsafeIsBorrowed() is TRUE. We're deallocating the PyObject // because Tensor asked us to (it's already destructing). - if (!self->cdata.unsafeIsBorrowed()) { + if (!self->cdata.unsafeIsBorrowed() && + tensor.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get()) == + c10::make_optional((PyObject*)self)) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -647,6 +684,36 @@ static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) { Py_RETURN_NONE; } +static PyObject* THPVariable_view_func(PyObject* self_, PyObject* arg) { + HANDLE_TH_ERRORS + const auto& self = THPVariable_Unpack(self_); + TORCH_CHECK( + THPVariable_Check(arg), + "_view_func expect a single argument that is a Tensor"); + const auto& new_base = THPVariable_Unpack(arg); + + // Ensure that self is indeed a backward differentiable view + auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); + TORCH_CHECK( + diff_view_meta && diff_view_meta->has_bw_view(), + "_view_func can only be called on " + "a Tensor that is a backward differentiable view."); + const auto& view_info = diff_view_meta->get_backward_view(); + // Ensure that the newly provided base is similar to the original base + TORCH_CHECK( + torch::autograd::utils::has_same_meta(new_base, view_info.base_), + "The new base passed to _view_func must have the same metadata as the Tensors's base"); + + // Do the actual view replay + if (view_info.has_view_fn()) { + return THPVariable_Wrap(view_info.view_fn()(new_base)); + } else { + return THPVariable_Wrap(new_base.as_strided( + self.sizes(), self.strides(), self.storage_offset())); + } + END_HANDLE_TH_ERRORS +} + // Instantiates a subclass of self with the same data. static PyObject* THPVariable_as_subclass( PyObject* _self, @@ -686,7 +753,9 @@ static PyObject* THPVariable_make_subclass( throw torch::TypeError( "cls must be a type (got %s)", Py_TYPE(cls)->tp_name); } - torch_dispatch_mode::StashTorchDispatchModeGuard td_g; + // guard completely turns off torch dispatch modes, doesn't just pop off the + // stack + torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; auto data = r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED) @@ -818,7 +887,7 @@ static PyObject* THPVariable_make_wrapper_subclass( if (sizes_strides_policy.has_value()) { TORCH_CHECK( false, - "Setting sizes_strides_policy isn't suppored for this overload") + "Setting sizes_strides_policy isn't supported for this overload") } } @@ -1606,6 +1675,7 @@ static PyMethodDef extra_methods[] = { METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, {"_fix_weakref", THPVariable_fix_weakref, METH_NOARGS, nullptr}, + {"_view_func", THPVariable_view_func, METH_O, nullptr}, {nullptr}}; /* From https://github.com/python/cpython/blob/v3.7.0/Modules/xxsubtype.c @@ -1885,11 +1955,27 @@ static PyObject* THPVariable_NewWithVar( auto v = (THPVariable*)obj; // TODO: named constructor to avoid default initialization new (&v->cdata) MaybeOwned(); - v->cdata = MaybeOwned::owned(std::move(_var)); - const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->init_pyobj(self_interpreter.get(), obj, status); - if (check_has_torch_dispatch(obj)) { - var.unsafeGetTensorImpl()->set_python_dispatch(true); + if (c10::impl::HermeticPyObjectTLS::get_state()) { + // Do NOT initialize pyobj field on the tensor, you own the C++ + v->cdata = MaybeOwned::owned(std::move(_var)); + TORCH_INTERNAL_ASSERT( + !check_has_torch_dispatch(obj), + "While HermeticPyObject was enabled, we attempted to create a tensor " + "subclass with __torch_dispatch__. This violates the invariant that " + "operations in HermeticPyObject have equivalent C++ implementations. " + "If your operator registered from Python operator registration isn't " + "doing anything strange, there may be an internal PyTorch bug involving " + "not appropriately disabling TorchDispatchMode before executing " + "Python op registration."); + } else { + // Normal codepath + v->cdata = MaybeOwned::owned(std::move(_var)); + const auto& var = THPVariable_Unpack(v); + var.unsafeGetTensorImpl()->init_pyobj( + self_interpreter.get(), obj, status); + if (check_has_torch_dispatch(obj)) { + var.unsafeGetTensorImpl()->set_python_dispatch(true); + } } } return obj; @@ -2262,11 +2348,20 @@ void ConcretePyInterpreterVTable::python_dispatcher( torch::jit::Stack* stack) const { py::gil_scoped_acquire g; py::handle torch_api_function_overload = getTorchApiFunction(op); + // TODO: if necessary, can optimize to cache the cache lookup + // TODO: if necessary, can optimize OpOverload to have slots + auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache")); + if (cache.ptr() == nullptr) { + throw python_error(); + } c10::DispatchKey k = ks.highestPriorityTypeId(); - auto handler = torch_api_function_overload.attr(toString(k)); + // TODO: allow this to be non-owning + auto handler = py::reinterpret_borrow( + PyDict_GetItem(cache.ptr(), py::cast(k).ptr())); if (handler.ptr() == nullptr) { - throw python_error(); + // Slow path + handler = torch_api_function_overload.attr("_get_dispatch")(k); } if (py::isinstance(handler)) { // NB: not redispatch, as that will permanently remove the python @@ -2351,7 +2446,7 @@ bool ConcretePyInterpreterVTable::is_contiguous( {py::cast(memory_format)}); } - if (out.is(py::none())) { + if (out.is_none()) { return self->is_contiguous_default(memory_format); } @@ -2384,7 +2479,7 @@ bool ConcretePyInterpreterVTable::is_strides_like( "torch.ops.aten", {py::cast(memory_format)}); - if (out.is(py::none())) { + if (out.is_none()) { return self->is_strides_like_default(memory_format); } @@ -2413,7 +2508,7 @@ bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { return self->is_non_overlapping_and_dense_default(); } @@ -2485,7 +2580,7 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { TORCH_CHECK( !self->has_symbolic_sizes_strides(), "Cannot call strides on a tensor with symbolic shapes/strides"); @@ -2544,7 +2639,7 @@ c10::IntArrayRef ConcretePyInterpreterVTable::sizes( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { TORCH_CHECK( !self->has_symbolic_sizes_strides(), "Cannot call sizes on a tensor with symbolic shapes/strides"); @@ -2575,7 +2670,7 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { return self->sym_sizes_default(); } // We need to squeeze SymIntNodes and ints into `SymInts` @@ -2638,15 +2733,14 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { TORCH_CHECK( !self->has_symbolic_sizes_strides(), "Cannot call numel on a tensor with symbolic shapes/strides"); return self->sym_numel_default(); } - return torch::is_symint_node(out) - ? out.cast()->toSymInt() - : c10::SymInt{py::cast(out)}; + return torch::is_symint(out) ? out.cast() + : c10::SymInt{py::cast(out)}; } c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( @@ -2664,12 +2758,11 @@ c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { return self->sym_storage_offset_default(); } - return torch::is_symint_node(out) - ? out.cast()->toSymInt() - : c10::SymInt{py::cast(out)}; + return torch::is_symint(out) ? out.cast() + : c10::SymInt{py::cast(out)}; } c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( @@ -2688,7 +2781,7 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( .ptr(), "torch.ops.aten"); - if (out.is(py::none())) { + if (out.is_none()) { return self->sym_strides_default(); } // We need to squeeze SymIntNodes and ints into `SymInts` @@ -2699,9 +2792,8 @@ c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( py::list symints; for (auto it = out.begin(); it != out.end(); it++) { auto elm = *it; - auto si = torch::is_symint_node(elm) - ? elm.cast()->toSymInt() - : c10::SymInt{py::cast(elm)}; + auto si = torch::is_symint(elm) ? elm.cast() + : c10::SymInt{py::cast(elm)}; symints.append(si.as_int_unchecked()); } diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index be0bd458197e0..8f448df06b327 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -69,6 +69,7 @@ inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { } TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); +TORCH_PYTHON_API bool isMainPyInterpreter(); std::pair parseIValuesToPyArgsKwargs( const c10::OperatorHandle& op, diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index a2e0f05b63943..d438205e8947f 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { : grad_fn_; if (!is_leaf_ && !grad_fn) { - TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor"); + // This issue was introduced when we added logic to save the original + // because now we rely on data_.grad_fn(), but can be unreliable if the + // autograd_meta of that saved tensor is cleared with an in-place detach. + // As a simple fix, we choose to disallow that behavior here even though + // it makes behavior inconsistent depending on whether you are saving + // input or output. + TORCH_CHECK( + saved_for, + "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_()." + "This is not supported, please use out-of-place `.detach()` instead"); grad_fn = std::move(saved_for); } diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 2addde79c8ec2..37dda0f9acaac 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -28,9 +28,9 @@ inline bool obeys_layout_contract( return false; } else if (variable.is_non_overlapping_and_dense()) { // Only look at stride for dimensions that are not of size 1. - const auto& grad_sizes = grad.sizes(); - const auto& grad_strides = grad.strides(); - const auto& variable_strides = variable.strides(); + const auto& grad_sizes = grad.sym_sizes(); + const auto& grad_strides = grad.sym_strides(); + const auto& variable_strides = variable.sym_strides(); for (const auto idx : c10::irange(grad_sizes.size())) { if (grad_sizes[idx] != 1) { if (grad_strides[idx] != variable_strides[idx]) { diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index a2f075fcf1cf0..368a55ea8c1a7 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -761,7 +761,38 @@ void handle_view_on_rebase( } else { modified_obj = "is being"; } - if (grad_fn) { + + if (creation_meta == CreationMeta::INFERENCE_MODE || + creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) { + std::string prefix; + if (grad_fn) { + prefix = c10::str( + "Output ", + diff_view_meta->output_nr_, + " of ", + grad_fn->name(), + " is a view of a view which was created in"); + } else { + prefix = "A view was created in"; + } + if (creation_meta == CreationMeta::INFERENCE_MODE) { + msg = c10::str( + prefix, + " inference mode and ", + modified_obj, + " modified inplace in normal mode."); + } else { + // create_meta is not necessarily CreationMeta::NO_GRAD_MODE + // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that + // if there is no grad_fn, that means that the view was performed in + // no-grad mode + msg = c10::str( + prefix, + " no_grad mode and ", + modified_obj, + " modified inplace with grad mode enabled."); + } + } else { msg = c10::str( "Output ", diff_view_meta->output_nr_, @@ -770,16 +801,6 @@ void handle_view_on_rebase( " is a view and ", modified_obj, " modified inplace."); - } else if (creation_meta == CreationMeta::INFERENCE_MODE) { - msg = c10::str( - "A view was created in inference mode and ", - modified_obj, - " modified inplace in normal mode."); - } else { - msg = c10::str( - "A view was created in no_grad mode and ", - modified_obj, - " modified inplace with grad mode enabled."); } if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { @@ -789,7 +810,6 @@ void handle_view_on_rebase( " allow the output views to be modified inplace. You should replace the inplace operation by an" " out-of-place one."); } else if (creation_meta == CreationMeta::NO_GRAD_MODE) { - TORCH_INTERNAL_ASSERT(!grad_fn); msg = c10::str( msg, " Given that this use case is ambiguous and error-prone, it is forbidden." @@ -797,14 +817,12 @@ void handle_view_on_rebase( " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" " the inplace to be tracked)."); } else if (creation_meta == CreationMeta::INFERENCE_MODE) { - TORCH_INTERNAL_ASSERT(!grad_fn); msg = c10::str( msg, " Given that this use case is ambiguous and error-prone, it is forbidden." " You can clarify your code by moving both the view and the inplace either both" " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want" " the inplace to be tracked)."); - TORCH_CHECK(false, msg); } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { msg = c10::str( msg, diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 49905fe803f46..52ce34ec394d0 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -791,6 +791,11 @@ inline Variable make_variable( return Variable(); } +namespace utils { + +TORCH_API bool has_same_meta(const Variable& base, const Variable& other); + +} // namespace utils } // namespace autograd } // namespace torch diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp new file mode 100644 index 0000000000000..56927c16a0de8 --- /dev/null +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -0,0 +1,317 @@ +#include +#include +#include + +#include + +namespace torch { +namespace cuda { +namespace CUDAPluggableAllocator { + +int device_count = 0; + +void custom_raw_deleter(void* ptr); + +// This is a fast API to just register allocators +// based on function pointers (ie. external .so libraries) +// This avoids having to link against libtorch for C++ based custom allocators +// And also use this from python +CUDAPluggableAllocator::CUDAPluggableAllocator( + std::function alloc_fn, + std::function free_fn) + : alloc_fn_(alloc_fn), free_fn_(free_fn) {} + +CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other) + : alloc_fn_(other.alloc_fn_), + free_fn_(other.free_fn_), + init_fn_(other.init_fn_), + reset_fn_(other.reset_fn_), + memory_fraction_fn_(other.memory_fraction_fn_), + base_alloc_fn_(other.base_alloc_fn_), + record_stream_fn_(other.record_stream_fn_), + capture_begin_fn_(other.capture_begin_fn_), + capture_about_to_end_fn_(other.capture_about_to_end_fn_), + capture_ended_fn_(other.capture_ended_fn_), + capture_destroy_fn_(other.capture_destroy_fn_) {} + +void CUDAPluggableAllocator::set_init_fn(std::function init_fn) { + init_fn_ = init_fn; +} + +void CUDAPluggableAllocator::set_reset_fn(std::function reset_fn) { + reset_fn_ = reset_fn; +} + +void CUDAPluggableAllocator::set_memory_fraction_fn( + std::function memory_fraction_fn) { + memory_fraction_fn_ = memory_fraction_fn; +} + +void CUDAPluggableAllocator::set_base_alloc_fn( + std::function base_alloc_fn) { + base_alloc_fn_ = base_alloc_fn; +} + +void CUDAPluggableAllocator::set_record_stream_fn( + std::function record_stream_fn) { + record_stream_fn_ = record_stream_fn; +} + +void CUDAPluggableAllocator::set_capture_begin_fn( + std::function + capture_begin_fn) { + capture_begin_fn_ = capture_begin_fn; +} + +void CUDAPluggableAllocator::set_capture_about_to_end_fn( + std::function capture_about_to_end_fn) { + capture_about_to_end_fn_ = capture_about_to_end_fn; +} + +void CUDAPluggableAllocator::set_capture_ended_fn( + std::function capture_ended_fn) { + capture_ended_fn_ = capture_ended_fn; +} + +void CUDAPluggableAllocator::set_capture_destroy_fn( + std::function capture_destroy_fn) { + capture_destroy_fn_ = capture_destroy_fn; +} + +void* CUDAPluggableAllocator::malloc( + size_t size, + int device, + cudaStream_t stream) { + void* r = alloc_fn_(size, device, stream); + { + const std::lock_guard lock(allocator_mutex_); + allocation_metadata_.emplace(r, std::make_pair(size, stream)); + } + return r; +} + +c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const { + int device; + C10_CUDA_CHECK(cudaGetDevice(&device)); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); + void* r = + const_cast(this)->malloc(size, device, stream); + c10::DataPtr data_ptr = { + r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)}; + return data_ptr; +} + +c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const { + return &custom_raw_deleter; +} + +void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) { + int device; + C10_CUDA_CHECK(cudaGetDevice(&device)); + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); + return malloc(nbytes, device, stream); +} + +void* CUDAPluggableAllocator::raw_alloc_with_stream( + size_t nbytes, + cudaStream_t stream) { + int device; + C10_CUDA_CHECK(cudaGetDevice(&device)); + return malloc(nbytes, device, stream); +} + +void CUDAPluggableAllocator::raw_delete(void* ptr) { + cudaStream_t stream; + size_t size; + { + const std::lock_guard lock(allocator_mutex_); + TORCH_CHECK( + allocation_metadata_.count(ptr), + "Trying to free a pointer not allocated here"); + auto pair = allocation_metadata_[ptr]; + size = pair.first; + stream = pair.second; + allocation_metadata_.erase(ptr); + } + free_fn_(ptr, size, stream); +} + +void CUDAPluggableAllocator::init(int device_count) { + if (init_fn_) { + init_fn_(device_count); + } + initialized_ = true; +} + +bool CUDAPluggableAllocator::initialized() { + return initialized_; +} + +void CUDAPluggableAllocator::setMemoryFraction(double fraction, int device) { + if (memory_fraction_fn_) { + memory_fraction_fn_(fraction, device); + } +} + +void CUDAPluggableAllocator::emptyCache(void) { + if (reset_fn_) { + return reset_fn_(); + } +} + +void CUDAPluggableAllocator::cacheInfo(int dev_id, size_t* largestBlock) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support cacheInfo. " + "If you need it, please file an issue describing your use case."); +} + +void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) { + if (base_alloc_fn_) { + return base_alloc_fn_(ptr, size); + } else { + return ptr; + } +} + +void CUDAPluggableAllocator::recordStream( + const c10::DataPtr& ptr, + streamType stream) { + if (record_stream_fn_) { + record_stream_fn_(ptr.get(), stream); + } +} + +c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: + getDeviceStats(int device) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support getDeviceStats. " + "If you need it, please file an issue describing your use case."); +} + +void CUDAPluggableAllocator::resetAccumulatedStats(int device) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support resetAccumulatedStats. " + "If you need it, please file an issue describing your use case."); +} + +void CUDAPluggableAllocator::resetPeakStats(int device) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support resetPeakStats. " + "If you need it, please file an issue describing your use case."); +} + +c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator:: + snapshot() { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support snapshot. " + "If you need it, please file an issue describing your use case."); +} + +std::shared_ptr CUDAPluggableAllocator::getIpcDevPtr(std::string handle) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support getIpcDevPtr. " + "If you need it, please file an issue describing your use case."); +} + +// CUDAGraph interactions +void CUDAPluggableAllocator::notifyCaptureBegin( + int device, + c10::cuda::CaptureId_t graph_id, + c10::cuda::MempoolId_t mempool_id) { + if (capture_begin_fn_) { + capture_begin_fn_(device, graph_id, mempool_id); + } +} + +void CUDAPluggableAllocator::notifyCaptureAboutToEnd( + int device, + c10::cuda::CaptureId_t graph_id) { + if (capture_about_to_end_fn_) { + capture_about_to_end_fn_(device, graph_id); + } +} + +void CUDAPluggableAllocator::notifyCaptureEnded( + int device, + c10::cuda::CaptureId_t graph_id) { + if (capture_ended_fn_) { + capture_ended_fn_(device, graph_id); + } +} + +void CUDAPluggableAllocator::notifyCaptureDestroy( + int device, + c10::cuda::MempoolId_t mempool_id) { + if (capture_destroy_fn_) { + capture_destroy_fn_(device, mempool_id); + } +} + +void CUDAPluggableAllocator::recordHistory( + bool enabled, + c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + bool alloc_trace_record_context) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support recordHistory. " + "If you need it, please file an issue describing your use case."); +} + +void CUDAPluggableAllocator::attachOutOfMemoryObserver( + c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) { + TORCH_CHECK( + false, + "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. " + "If you need it, please file an issue describing your use case."); +} + +bool CUDAPluggableAllocator::needsPoolSpecificPeerAccess() { + return false; +} + +std::string CUDAPluggableAllocator::name() { + return "pluggable"; +} + +std::shared_ptr + current_custom_allocator; + +std::shared_ptr +getCurrentAllocator() { + return current_custom_allocator; +} + +// TODO: add more functions in the argument +std::shared_ptr +createCustomAllocator( + std::function alloc_fn, + std::function free_fn) { + std::shared_ptr allocator( + new CUDAPluggableAllocator(alloc_fn, free_fn)); + allocator->init(device_count); + return allocator; +} + +void changeCurrentAllocator( + std::shared_ptr allocator) { + TORCH_CHECK( + !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(), + "Can't swap an already initialized allocator"); + c10::cuda::CUDACachingAllocator::allocator.store(allocator.get()); + current_custom_allocator = allocator; +} + +void custom_raw_deleter(void* ptr) { + current_custom_allocator->raw_delete(ptr); +} + +} // namespace CUDAPluggableAllocator +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h new file mode 100644 index 0000000000000..a02acabe3cd85 --- /dev/null +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch { + +namespace cuda { + +namespace CUDAPluggableAllocator { + +#if defined(TORCH_HIP_VERSION) +using streamType = c10::hip::HIPStream; +#else +using streamType = c10::cuda::CUDAStream; +#endif + +std::shared_ptr +getCurrentAllocator(); +std::shared_ptr +createCustomAllocator( + std::function alloc_fn, + std::function free_fn); +void changeCurrentAllocator( + std::shared_ptr allocator); + +struct CUDAPluggableAllocator + : public c10::cuda::CUDACachingAllocator::CUDAAllocator { + CUDAPluggableAllocator( + std::function alloc_fn, + std::function free_fn); + + CUDAPluggableAllocator(CUDAPluggableAllocator& other); + + void set_init_fn(std::function init_fn); + + void set_reset_fn(std::function reset_fn); + + void set_memory_fraction_fn( + std::function memory_fraction_fn); + + void set_base_alloc_fn(std::function base_alloc_fn); + + void set_record_stream_fn( + std::function record_stream_fn); + + void set_capture_begin_fn( + std::function + capture_begin_fn); + + void set_capture_about_to_end_fn( + std::function capture_about_to_end_fn); + + void set_capture_ended_fn( + std::function capture_ended_fn); + + void set_capture_destroy_fn( + std::function capture_destroy_fn); + + void* malloc(size_t size, int device, cudaStream_t stream); + + c10::DataPtr allocate(size_t size) const; + c10::DeleterFnPtr raw_deleter() const; + + virtual void* raw_alloc(size_t nbytes) override; + virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) + override; + virtual void raw_delete(void* ptr) override; + virtual void init(int device_count) override; + virtual bool initialized() override; + virtual void setMemoryFraction(double fraction, int device) override; + virtual void emptyCache() override; + virtual void cacheInfo(int dev_id, size_t* largestBlock) override; + virtual void* getBaseAllocation(void* ptr, size_t* size) override; + + virtual void recordStream(const c10::DataPtr&, streamType stream) override; + + virtual c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( + int device) override; + virtual void resetAccumulatedStats(int device) override; + virtual void resetPeakStats(int device) override; + virtual c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override; + virtual void notifyCaptureBegin( + int device, + c10::cuda::CaptureId_t graph_id, + c10::cuda::MempoolId_t mempool_id) override; + virtual void notifyCaptureAboutToEnd( + int device, + c10::cuda::CaptureId_t graph_id) override; + virtual void notifyCaptureEnded(int device, c10::cuda::CaptureId_t graph_id) + override; + virtual void notifyCaptureDestroy( + int device, + c10::cuda::MempoolId_t mempool_id) override; + virtual std::shared_ptr getIpcDevPtr(std::string handle) override; + virtual void recordHistory( + bool enabled, + c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, + size_t alloc_trace_max_entries, + bool alloc_trace_record_context) override; + virtual void attachOutOfMemoryObserver( + c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; + virtual bool needsPoolSpecificPeerAccess() override; + virtual std::string name() override; + + protected: + std::function alloc_fn_; + std::function free_fn_; + std::function init_fn_; + std::function reset_fn_; + std::function memory_fraction_fn_; + std::function base_alloc_fn_; + std::function record_stream_fn_; + std::function + capture_begin_fn_; + std::function capture_about_to_end_fn_; + std::function capture_ended_fn_; + std::function capture_destroy_fn_; + std::mutex allocator_mutex_; + // We do the bookeeping here in order to simplify custom allocators + std::unordered_map> + allocation_metadata_; + + bool initialized_ = false; +}; +} // namespace CUDAPluggableAllocator +} // namespace cuda +} // namespace torch diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index 0866b82f659dd..f43a7debb5e41 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -30,23 +30,37 @@ void THCPGraph_init(PyObject* module) { // docs aren't clear. But it works. .def( "capture_begin", - &::at::cuda::CUDAGraph::capture_begin, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_begin), py::call_guard(), py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) .def( "capture_end", - &::at::cuda::CUDAGraph::capture_end, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_end), py::call_guard()) .def( "replay", - &::at::cuda::CUDAGraph::replay, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::replay), py::call_guard()) .def( "reset", - &::at::cuda::CUDAGraph::reset, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::reset), py::call_guard()) .def( "pool", - &::at::cuda::CUDAGraph::pool, - py::call_guard()); + torch::wrap_pybind_function(&at::cuda::CUDAGraph::pool), + py::call_guard()) + .def( + "debug_dump", + torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump), + py::call_guard()) + .def( + "enable_debug_mode", + torch::wrap_pybind_function( + &::at::cuda::CUDAGraph::enable_debug_mode), + py::call_guard()) + .def( + "debug_dump", + torch::wrap_pybind_function(&::at::cuda::CUDAGraph::debug_dump), + py::call_guard(), + py::arg("debug_path")); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 10dac0e0d0f7b..b526f87edd75d 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #ifdef USE_NCCL @@ -21,6 +22,7 @@ #include #include +#include #include #include #include @@ -851,6 +853,123 @@ static void registerCudaDeviceProperties(PyObject* module) { }); } +static void registerCudaPluggableAllocator(PyObject* module) { + auto m = py::handle(module).cast(); + + py::class_< + c10::cuda::CUDACachingAllocator::CUDAAllocator, + std::shared_ptr>( + m, "_cuda_CUDAAllocator"); + m.def("_cuda_getAllocator", []() { + return py::cast(torch::cuda::CUDAPluggableAllocator::getCurrentAllocator()); + }); + + m.def( + "_cuda_changeCurrentAllocator", + [](std::shared_ptr + allocator) { + torch::cuda::CUDAPluggableAllocator::changeCurrentAllocator(allocator); + }); + py::class_< + torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator, + c10::cuda::CUDACachingAllocator::CUDAAllocator, + std::shared_ptr< + torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator>>( + m, "_CUDAPluggableAllocator") + .def( + "set_init_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int); + std::function func = + reinterpret_cast(func_ptr); + self.set_init_fn(func); + }) + .def( + "set_reset_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(); + std::function func = + reinterpret_cast(func_ptr); + self.set_reset_fn(func); + }) + .def( + "set_memory_fraction_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(double, int); + std::function func = + reinterpret_cast(func_ptr); + self.set_memory_fraction_fn(func); + }) + .def( + "set_base_alloc_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void*(void*, size_t*); + std::function func = + reinterpret_cast(func_ptr); + self.set_base_alloc_fn(func); + }) + .def( + "set_record_stream_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(void*, cudaStream_t); + std::function func = + reinterpret_cast(func_ptr); + self.set_record_stream_fn(func); + }) + .def( + "set_capture_begin_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = + void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t); + std::function func = + reinterpret_cast(func_ptr); + self.set_capture_begin_fn(func); + }) + .def( + "set_capture_about_to_end_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int, c10::cuda::CaptureId_t); + std::function func = + reinterpret_cast(func_ptr); + self.set_capture_about_to_end_fn(func); + }) + .def( + "set_capture_ended_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int, c10::cuda::CaptureId_t); + std::function func = + reinterpret_cast(func_ptr); + self.set_capture_ended_fn(func); + }) + .def( + "set_capture_destroy_fn", + [](torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator& self, + uint64_t func_ptr) { + using FuncType = void(int, c10::cuda::MempoolId_t); + std::function func = + reinterpret_cast(func_ptr); + self.set_capture_destroy_fn(func); + }); + m.def("_cuda_customAllocator", [](uint64_t malloc_ptr, uint64_t free_ptr) { + using MallocFuncType = void*(size_t, int, cudaStream_t); + using FreeFuncType = void(void*, size_t, cudaStream_t); + std::function malloc_fn = + reinterpret_cast(malloc_ptr); + std::function free_fn = + reinterpret_cast(free_ptr); + return torch::cuda::CUDAPluggableAllocator::createCustomAllocator( + malloc_fn, free_fn); + }); +} + static void bindGetDeviceProperties(PyObject* module) { // Add method to torch.cuda auto m = py::handle(module).cast(); @@ -1141,6 +1260,7 @@ void initModule(PyObject* module) { shared::initCudnnBindings(module); #endif registerCudaDeviceProperties(module); + registerCudaPluggableAllocator(module); } } // namespace cuda diff --git a/torch/csrc/cuda/Tensor.cpp b/torch/csrc/cuda/Tensor.cpp index beb81f187a6e2..f9486164358d4 100644 --- a/torch/csrc/cuda/Tensor.cpp +++ b/torch/csrc/cuda/Tensor.cpp @@ -1,4 +1,6 @@ +#ifndef __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS +#endif // Order of these includes matters, which should be fixed. // clang-format off diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 117f6b571792b..e215ce0e3ed67 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -180,12 +180,12 @@ tensor_list2d broadcast_coalesced( unique_type_checker type_checker; at::cuda::CUDAGuard device_guard(devices[0]); - for (auto& chunk : utils::take_tensors(tensors, buffer_size)) { + for (auto& chunk : torch::utils::take_tensors(tensors, buffer_size)) { auto type_id = chunk.type_id(); type_checker.show(type_id); std::vector results; if (chunk.options().is_sparse()) { - auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors); + auto flat_tuple = torch::utils::flatten_sparse_tensors(chunk.tensors); auto broadcast_indices = broadcast(flat_tuple.first, devices); auto broadcast_values = broadcast(flat_tuple.second, devices); results.reserve(devices.size()); @@ -194,20 +194,20 @@ tensor_list2d broadcast_coalesced( auto& device_outputs = outputs[i]; auto& inds = broadcast_indices[i]; auto& vals = broadcast_values[i]; - for (const auto& var : - utils::unflatten_sparse_tensors(inds, vals, chunk.tensors)) { + for (const auto& var : torch::utils::unflatten_sparse_tensors( + inds, vals, chunk.tensors)) { // See NOTE [ Version Counter in comm.*_coalesced ] device_outputs.push_back(make_variable(var.tensor_data(), false)); } } } else { - auto results = - broadcast(utils::flatten_dense_tensors(chunk.tensors), devices); + auto results = broadcast( + torch::utils::flatten_dense_tensors(chunk.tensors), devices); for (size_t i = 1, num_devices = devices.size(); i < num_devices; ++i) { device_guard.set_index(devices[i]); auto& device_outputs = outputs[i]; for (auto& var : - utils::unflatten_dense_tensors(results[i], chunk.tensors)) { + torch::utils::unflatten_dense_tensors(results[i], chunk.tensors)) { // See NOTE [ Version Counter in comm.*_coalesced ] device_outputs.push_back(make_variable(var.tensor_data(), false)); } @@ -218,7 +218,7 @@ tensor_list2d broadcast_coalesced( // If we only saw a single tensor type, then we can skip expensive reordering if (!type_checker.unique) { for (auto& o : outputs) - utils::reorder_tensors_like(o, tensors); + torch::utils::reorder_tensors_like(o, tensors); } return outputs; } diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index a1d96f7e5d6cd..8a3c8af797cc2 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -142,7 +143,7 @@ struct NcclCommList { if (comms) { for (const auto i : c10::irange(ndevices)) { int dummy_var; - if (cudaGetDevice(&dummy_var) != cudaSuccess) { + if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) { /* there are cases when this destructor is called after the CUDA driver is already unloaded from the process. In these cases, skip ncclCommDestroy */ diff --git a/torch/csrc/cuda/shared/cudart.cpp b/torch/csrc/cuda/shared/cudart.cpp index 9e098d44808ba..f18c883a2a06a 100644 --- a/torch/csrc/cuda/shared/cudart.cpp +++ b/torch/csrc/cuda/shared/cudart.cpp @@ -71,25 +71,26 @@ void initCudartBindings(PyObject* module) { "cuda" "HostRegister", [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t { - return cudaHostRegister((void*)ptr, size, flags); + return C10_CUDA_ERROR_HANDLED( + cudaHostRegister((void*)ptr, size, flags)); }); cudart.def( "cuda" "HostUnregister", [](uintptr_t ptr) -> cudaError_t { - return cudaHostUnregister((void*)ptr); + return C10_CUDA_ERROR_HANDLED(cudaHostUnregister((void*)ptr)); }); cudart.def( "cuda" "StreamCreate", [](uintptr_t ptr) -> cudaError_t { - return cudaStreamCreate((cudaStream_t*)ptr); + return C10_CUDA_ERROR_HANDLED(cudaStreamCreate((cudaStream_t*)ptr)); }); cudart.def( "cuda" "StreamDestroy", [](uintptr_t ptr) -> cudaError_t { - return cudaStreamDestroy((cudaStream_t)ptr); + return C10_CUDA_ERROR_HANDLED(cudaStreamDestroy((cudaStream_t)ptr)); }); #if !defined(USE_ROCM) cudart.def( @@ -104,7 +105,7 @@ void initCudartBindings(PyObject* module) { c10::cuda::CUDAGuard guard(device); size_t device_free = 0; size_t device_total = 0; - cudaMemGetInfo(&device_free, &device_total); + C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); return {device_free, device_total}; }); } diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 2da315644845c..06c6927e4c467 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -185,6 +185,13 @@ void DistEngine::computeDependencies( bool retainGraph) { TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!"); + // Store root nodes so we can traverse through the graph later + // e.g., for get_current_graph_task_execution_order + c10::SmallVector temp_roots{rootEdges.size()}; + for (const auto i : c10::irange(rootEdges.size())) { + temp_roots[i] = rootEdges[i].function.get(); + } + // Build the graph task and graph root. // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask // as we use execute_graph_task_until_ready_queue_empty, which will build @@ -194,6 +201,7 @@ void DistEngine::computeDependencies( /* create_graph */ false, /* depth */ 0, /* cpu_ready_queue */ global_cpu_ready_queue_, + /* graph_roots */ temp_roots, /* exit_on_error */ true); // Run BFS to traverse the graph locally. The roots of the graph are diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index a9bea0e67d7bb..fb5d91d2e11cf 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -55,7 +55,7 @@ std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ "\n" + getNcclErrorDetailStr(result, failureReason); \ - TORCH_CHECK(false, err); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ } \ } while (0) diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ea77bb337b4a8..52a8b17c290dc 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -40,6 +40,19 @@ std::tuple, c10::intrusive_ptr> allreduce_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + c10::intrusive_ptr reduce_( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -75,6 +88,22 @@ allgather_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + +c10::intrusive_ptr allgather_coalesced_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_( const std::vector& output_tensors, const std::vector>& input_tensors, @@ -91,6 +120,19 @@ std::tuple, c10::intrusive_ptr> reduce_scatter_( output_tensors, work); } +c10::intrusive_ptr _reduce_scatter_base_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + c10::intrusive_ptr gather_( const std::vector>& output_tensors, const std::vector& input_tensors, @@ -131,6 +173,21 @@ c10::intrusive_ptr alltoall_( AllToAllOptions{std::chrono::milliseconds(timeout)}); } +c10::intrusive_ptr alltoall_base_( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->alltoall_base( + output, + input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + c10::intrusive_ptr barrier( const c10::intrusive_ptr& process_group, const std::vector& device_ids, @@ -139,6 +196,17 @@ c10::intrusive_ptr barrier( BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); } +void monitored_barrier_( + at::Tensor /* unused */, + const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, + const std::vector& device_ids, + int64_t timeout, + bool wait_all_ranks) { + process_group->monitoredBarrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}, + wait_all_ranks); +} + c10::intrusive_ptr send( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -159,6 +227,14 @@ c10::intrusive_ptr recv_( tensor_vec, static_cast(srcRank), static_cast(tag)); } +c10::intrusive_ptr recv_any_source_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->recvAnysource(tensor_vec, static_cast(tag)); +} + TORCH_LIBRARY(c10d, m) { // The following ProcessGroup, Work, and ReduceOp definitions are more like // declarations. They don't expose the details of the two classes into @@ -177,12 +253,27 @@ TORCH_LIBRARY(c10d, m) { m.def( "allreduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_)); + m.def( + "allreduce_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_)); m.def( "allgather_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); + m.def( + "_allgather_base_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_)); + m.def( + "allgather_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allgather_coalesced_)); m.def( "reduce_scatter_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); + m.def( + "_reduce_scatter_base_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, _reduce_scatter_base_)); m.def( "reduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_)); @@ -195,11 +286,21 @@ TORCH_LIBRARY(c10d, m) { m.def( "alltoall_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_)); + m.def( + "alltoall_base_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_base_)); m.def( "barrier", dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier)); + m.def( + "monitored_barrier_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, monitored_barrier_)); m.def("send", dispatch(c10::DispatchKey::CompositeExplicitAutograd, send)); m.def("recv_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_)); + m.def( + "recv_any_source_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_any_source_)); } } // namespace @@ -249,6 +350,25 @@ c10::intrusive_ptr allreduce( opts.timeout.count())); } +c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allreduce_coalesced_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + + return op.call( + tensors, + process_group, + c10::make_intrusive(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, @@ -267,6 +387,36 @@ c10::intrusive_ptr allgather( output_tensors, input_tensors, process_group, opts.timeout.count())); } +c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::_allgather_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + + return op.call(output_tensor, input_tensor, process_group); +} + +c10::intrusive_ptr allgather_coalesced( + const c10::intrusive_ptr& process_group, + const std::vector>& output_lists, + const std::vector& input_list, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_coalesced_", "") + .typed( + const std::vector>&, + const std::vector&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + + return op.call(output_lists, input_list, process_group); +} + c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, @@ -289,6 +439,27 @@ c10::intrusive_ptr reduce_scatter( opts.timeout.count())); } +c10::intrusive_ptr _reduce_scatter_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ReduceScatterOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::_reduce_scatter_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + return op.call( + output_tensor, + input_tensor, + process_group, + c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr reduce( const c10::intrusive_ptr& process_group, at::TensorList tensors, @@ -370,6 +541,53 @@ c10::intrusive_ptr alltoall( output_tensors, input_tensors, process_group, opts.timeout.count()); } +c10::intrusive_ptr alltoall_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output, + at::Tensor& input, + std::vector output_split_sizes, + std::vector input_split_sizes, + const AllToAllOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::alltoall_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + std::vector, + std::vector, + int64_t)>(); + return op.call( + output, + input, + process_group, + output_split_sizes, + input_split_sizes, + opts.timeout.count()); +} + +void monitored_barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts, + bool wait_all_ranks) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::monitored_barrier_", "") + .typed&, + const std::vector&, + int64_t, + bool)>(); + // Default to using cpu implementation, monitored barrier is only for GLOO + at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU)); + op.call( + tensor, + process_group, + opts.device_ids, + opts.timeout.count(), + wait_all_ranks); +} + c10::intrusive_ptr barrier( const c10::intrusive_ptr& process_group, const BarrierOptions& opts) { @@ -412,5 +630,18 @@ c10::intrusive_ptr recv( return op.call(tensors, process_group, srcRank, tag); } +c10::intrusive_ptr recv_any_source( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t tag) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::recv_any_source_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + int64_t)>(); + return op.call(tensors, process_group, tag); +} + } // namespace ops } // namespace c10d diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index adc64066a885e..db9006995c1a8 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -21,18 +21,41 @@ TORCH_API c10::intrusive_ptr allreduce( at::TensorList tensors, const AllreduceOptions& opts = {}); +TORCH_API c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts = {}); + TORCH_API c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, const std::vector& input_tensors, const AllgatherOptions& opts = {}); +TORCH_API c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts = {}); + +TORCH_API c10::intrusive_ptr allgather_coalesced( + const c10::intrusive_ptr& process_group, + const std::vector>& output_lists, + const std::vector& input_list, + const AllgatherOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, const std::vector>& input_tensors, const ReduceScatterOptions& opts = {}); +TORCH_API c10::intrusive_ptr _reduce_scatter_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ReduceScatterOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce( const c10::intrusive_ptr& process_group, at::TensorList tensors, @@ -50,6 +73,14 @@ TORCH_API c10::intrusive_ptr scatter( const std::vector>& input_tensors, const ScatterOptions& opts = {}); +TORCH_API c10::intrusive_ptr alltoall_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output, + at::Tensor& input, + const std::vector outputSplitSizes, + const std::vector inputSplitSizes, + const AllToAllOptions& opts = {}); + TORCH_API c10::intrusive_ptr alltoall( const c10::intrusive_ptr& process_group, at::TensorList output_tensors, @@ -60,6 +91,11 @@ TORCH_API c10::intrusive_ptr barrier( const c10::intrusive_ptr& process_group, const BarrierOptions& opts = {}); +TORCH_API void monitored_barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts, + bool waitAllRanks); + TORCH_API c10::intrusive_ptr send( const c10::intrusive_ptr& process_group, at::TensorList tensors, @@ -72,5 +108,10 @@ TORCH_API c10::intrusive_ptr recv( int64_t srcRank, int64_t tag); +TORCH_API c10::intrusive_ptr recv_any_source( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + int64_t tag); + } // namespace ops } // namespace c10d diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 8254ce3126e3f..6eb69c664d16d 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -49,6 +49,22 @@ c10::intrusive_ptr recv_cuda_( tensor_vec, static_cast(srcRank), static_cast(tag)); } +c10::intrusive_ptr recv_any_source_cpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->recvAnysource(tensor_vec, static_cast(tag)); +} + +c10::intrusive_ptr recv_any_source_cuda_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + int64_t tag) { + auto tensor_vec = tensors.vec(); + return process_group->recvAnysource(tensor_vec, static_cast(tag)); +} + c10::intrusive_ptr reduce_cpu_( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -149,6 +165,32 @@ std::tuple, c10::intrusive_ptr> allreduce_cuda_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_cpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + +c10::intrusive_ptr allreduce_coalesced_cuda_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + std::tuple>, c10::intrusive_ptr> allgather_cpu_( const std::vector>& output_tensors, @@ -185,6 +227,235 @@ allgather_cuda_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_cpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + +c10::intrusive_ptr _allgather_base_cuda_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + +c10::intrusive_ptr allgather_coalesced_cpu_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + +c10::intrusive_ptr allgather_coalesced_cuda_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + +std::tuple, c10::intrusive_ptr> +reduce_scatter_cpu_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto work = process_group->reduce_scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + output_tensors, work); +} + +std::tuple, c10::intrusive_ptr> +reduce_scatter_cuda_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto work = process_group->reduce_scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + output_tensors, work); +} + +c10::intrusive_ptr _reduce_scatter_base_cpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr _reduce_scatter_base_cuda_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr gather_cpu_( + const std::vector>& output_tensors, + const std::vector& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + return process_group->gather( + const_cast>&>(output_tensors), + const_cast&>(input_tensors), + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr gather_cuda_( + const std::vector>& output_tensors, + const std::vector& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + return process_group->gather( + const_cast>&>(output_tensors), + const_cast&>(input_tensors), + GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); +} + +std::tuple, c10::intrusive_ptr> scatter_cpu_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto work = process_group->scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ScatterOptions{root_rank, std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + output_tensors, work); +} + +std::tuple, c10::intrusive_ptr> scatter_cuda_( + const std::vector& output_tensors, + const std::vector>& input_tensors, + const c10::intrusive_ptr& process_group, + int64_t root_rank, + int64_t timeout) { + auto work = process_group->scatter( + const_cast&>(output_tensors), + const_cast>&>(input_tensors), + ScatterOptions{root_rank, std::chrono::milliseconds(timeout)}); + + return std::tuple, c10::intrusive_ptr>( + output_tensors, work); +} + +c10::intrusive_ptr alltoall_cpu_( + at::TensorList output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + return process_group->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr alltoall_cuda_( + at::TensorList output_tensors, + at::TensorList input_tensors, + const c10::intrusive_ptr& process_group, + int64_t timeout) { + auto output_tensors_vec = output_tensors.vec(); + auto input_tensors_vec = input_tensors.vec(); + return process_group->alltoall( + output_tensors_vec, + input_tensors_vec, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr alltoall_base_cpu_( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->alltoall_base( + output, + input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr alltoall_base_cuda_( + at::Tensor& output, + at::Tensor& input, + const c10::intrusive_ptr& process_group, + std::vector output_split_sizes, + std::vector input_split_sizes, + int64_t timeout) { + return process_group->alltoall_base( + output, + input, + output_split_sizes, + input_split_sizes, + AllToAllOptions{std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr barrier_cpu( + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->barrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr barrier_cuda( + const c10::intrusive_ptr& process_group, + const std::vector& device_ids, + int64_t timeout) { + return process_group->barrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); +} + +void monitored_barrier_cpu_( + at::Tensor /* unused */, + const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, + const std::vector& device_ids, + int64_t timeout, + bool wait_all_ranks) { + process_group->monitoredBarrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}, + wait_all_ranks); +} + // register functions to dispatcher namespace { TORCH_LIBRARY_IMPL(c10d, CPU, m) { @@ -203,6 +474,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("recv_", recv_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("recv_any_source_", recv_any_source_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("recv_any_source_", recv_any_source_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("reduce_", reduce_cpu_); } @@ -237,6 +516,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allreduce_", allreduce_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("allgather_", allgather_cpu_); } @@ -244,6 +531,83 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) { TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allgather_", allgather_cuda_); } + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("_allgather_base_", _allgather_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("_allgather_base_", _allgather_base_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allgather_coalesced_", allgather_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allgather_coalesced_", allgather_coalesced_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("reduce_scatter_", reduce_scatter_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("reduce_scatter_", reduce_scatter_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("_reduce_scatter_base_", _reduce_scatter_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("_reduce_scatter_base_", _reduce_scatter_base_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("gather_", gather_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("gather_", gather_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("scatter_", scatter_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("scatter_", scatter_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("alltoall_", alltoall_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("alltoall_", alltoall_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("alltoall_base_", alltoall_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("alltoall_base_", alltoall_base_cuda_); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("barrier", barrier_cpu); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("barrier", barrier_cuda); +} + +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("monitored_barrier_", monitored_barrier_cpu_); +} + } // namespace } // namespace ops diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index c92d24af21c84..387fe5eb4dcc7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -83,11 +83,10 @@ ncclRedOpRAII unpackPreMulSum( const auto* preMulSupplement = reinterpret_cast(reduceOp.supplement_.get()); ncclRedOp_t preMulSum; - bool has_tensor = !preMulSupplement->tensor_factors.empty(); + bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; - T* ptr_factor = has_tensor - ? preMulSupplement->tensor_factors[dev_in_group].data_ptr() - : nullptr; + T* ptr_factor = + has_tensor ? preMulSupplement->tensor_factor.data_ptr() : nullptr; T scalar_factor = T(preMulSupplement->double_factor); ncclRedOpCreatePreMulSum( &preMulSum, @@ -1120,7 +1119,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( "[", rank_, "] is setting up NCCL communicator and " - "retreiving ncclUniqueId from [0] via c10d key-value store by key '", + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", storeKey, "', but store->get('", storeKey, @@ -1133,7 +1132,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( "Unknown exception while [", rank_, "] is setting up NCCL communicator and " - "retreiving ncclUniqueId from [0] via c10d key-value store by key '", + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", storeKey, "'")); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp index 61f03abc112de..b03ca490cea9d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp @@ -169,6 +169,11 @@ void read_config() { for (auto op : parse_blocking_wait(blocking_wait_str)) { torch_ucc_config.blocking_wait[(std::uint8_t)op] = true; } + // barrier is always blocking + torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true; + + // barrier is always blocking + torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true; torch_ucc_config.use_future = std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE")); @@ -461,7 +466,8 @@ void Comm::enqueue_collective( ucc_coll_req_h request; TORCH_UCC_CHECK( ucc_collective_init(&coll, &request, team), "failed to init collective"); - TORCH_UCC_CHECK(ucc_collective_post(request), "failed to post collective"); + TORCH_UCC_CHECK_REQUEST( + request, ucc_collective_post(request), "failed to post collective"); auto entry = std::make_shared(&ucc_comm, request); @@ -490,7 +496,8 @@ void Comm::enqueue_cuda_collective( comp_ev.ev_context = nullptr; comp_ev.ev_context_size = 0; comp_ev.req = request; - TORCH_UCC_CHECK( + TORCH_UCC_CHECK_REQUEST( + request, ucc_collective_triggered_post(ee, &comp_ev), "failed to post triggered collective"); ucc_status_t st = ucc_ee_get_event(ee, &post_ev); @@ -758,9 +765,10 @@ c10::intrusive_ptr ProcessGroupUCC::collective_post( std::vector& inputTensors, std::vector& outputTensors, const char* prof_title) { + seq_++; set_timeout(coll); auto work = c10::make_intrusive( - opType, prof_title, inputTensors, logger); + opType, seq_, prof_title, inputTensors, logger); if (opType == OpType::RECV) { work->sourceRank_ = coll.root; @@ -783,7 +791,9 @@ c10::intrusive_ptr ProcessGroupUCC::collective_post( work->future_ = c10::make_intrusive( c10::ListType::create(c10::TensorType::get())); } + preproc(); comm->enqueue_collective(std::move(data), work, coll, team); + postproc(); return work; } #ifdef USE_CUDA @@ -1569,6 +1579,12 @@ c10::intrusive_ptr ProcessGroupUCC::recv( "ucc:recv"); } +void ProcessGroupUCC::setSequenceNumberForGroup() {} + +uint64_t ProcessGroupUCC::getSequenceNumberForGroup() { + return seq_; +} + c10::intrusive_ptr ProcessGroupUCC::createProcessGroupUCC( const c10::intrusive_ptr<::c10d::Store>& store, int rank, diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp index 243cf301290e8..03d5d234873da 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp @@ -117,10 +117,11 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup { public: WorkUCC( OpType opType, + uint64_t seq, const char* prof_title, const c10::optional>& inputs, const c10::intrusive_ptr& logger) - : Work(-1, opType, prof_title, inputs), logger_(logger) {} + : Work(-1, opType, prof_title, inputs), logger_(logger), seq_(seq) {} ~WorkUCC(); void setException(); void setAndThrowException(); @@ -135,9 +136,11 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup { event_pool_t* ep = nullptr; #endif int sourceRank_; + protected: std::shared_ptr entry_; c10::intrusive_ptr logger_; + uint64_t seq_; private: // The future returned by getFuture. @@ -251,6 +254,18 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup { int srcRank, int tag) override; + // Counting for the sequential number of UCC collective_post call. + uint64_t seq_{0}; + + // Agrees on an initial sequence number for the whole group by having rank 0 + // create it and broadcast it to other ranks using the store. + void setSequenceNumberForGroup() override; + + // Retrieves the current sequence number for the whole group, which should be + // in sync. If the returned number is not consistent across the group, it + // may indicate that there is some sort of collective desynchronization. + uint64_t getSequenceNumberForGroup() override; + static c10::intrusive_ptr createProcessGroupUCC( const c10::intrusive_ptr<::c10d::Store>& store, int rank, @@ -264,6 +279,7 @@ class TORCH_API ProcessGroupUCC : public ProcessGroup { uint32_t comm_id; ucc_team_h team{nullptr}; ucc_ee_h cuda_ee{nullptr}; + #ifdef USE_CUDA std::unique_ptr stream = nullptr; event_pool_t ep; diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 4d928976d87ee..9c163af5cb8e7 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace c10d { @@ -21,14 +22,17 @@ struct TORCH_API _SupplementBase : torch::CustomClassHolder { // The point of use in ProcessGroupNCCL knows how to unpack it. struct NCCLPreMulSumSupplement : _SupplementBase { double double_factor{0.0}; - std::vector tensor_factors; + at::Tensor tensor_factor; NCCLPreMulSumSupplement(double f) : double_factor{f} {} - NCCLPreMulSumSupplement(std::vector f) : tensor_factors{std::move(f)} {} + NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} { + TORCH_CHECK_EQ(tensor_factor.numel(), 1); + } }; // Other ReduceOps that need different supplementary data can also // derive from _SupplementBase. struct TORCH_API ReduceOp : torch::CustomClassHolder { + // note(crcrpar): RedOpType could be defined outside of `ReduceOp` enum RedOpType : uint8_t { SUM = 0, AVG = 1, @@ -46,7 +50,9 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder { ReduceOp(RedOpType op) : op_(op) { TORCH_INTERNAL_ASSERT( - op_ != PREMUL_SUM, "PREMUL_SUM requires a scale factor tensor or scalar argument"); + op_ != PREMUL_SUM, + "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM" + ); } ReduceOp(RedOpType op, c10::intrusive_ptr<_SupplementBase> optional_supplement) { @@ -57,7 +63,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder { } } - // The heap resource supplement_, if it exists, is managed by a shared_ptr, + // The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr, // so constructors and operator= can be simple ReduceOp(const ReduceOp& other) : op_(other.op_), supplement_(other.supplement_) {} @@ -79,6 +85,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder { return *this == static_cast(other); } + // todo(crcrpar): Handle `RedOpType::PREMUL_SUM` with its scaling factor. bool operator==(const ReduceOp& other) { return *this == other.op_; } diff --git a/torch/csrc/distributed/c10d/UCCUtils.cpp b/torch/csrc/distributed/c10d/UCCUtils.cpp index ef934d1597f9d..590a931f2f110 100644 --- a/torch/csrc/distributed/c10d/UCCUtils.cpp +++ b/torch/csrc/distributed/c10d/UCCUtils.cpp @@ -186,7 +186,7 @@ void CommUCC::free_request(ucc_coll_req_h request) { CommUCC::~CommUCC() { if (context != nullptr) { TORCH_UCC_CHECK( - ucc_context_destroy(context), "failed to destory UCC context"); + ucc_context_destroy(context), "failed to destroy UCC context"); } if (lib != nullptr) { TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library"); diff --git a/torch/csrc/distributed/c10d/UCCUtils.hpp b/torch/csrc/distributed/c10d/UCCUtils.hpp index 50510a6ea9a03..3482a1d34ee52 100644 --- a/torch/csrc/distributed/c10d/UCCUtils.hpp +++ b/torch/csrc/distributed/c10d/UCCUtils.hpp @@ -8,27 +8,48 @@ namespace c10d { +// Macro to generate the error message on a non-successful UCC return value. +#define TORCH_UCC_GET_ERROR_MSG(_err, _error_msg, _result) \ + do { \ + _err = c10::str( \ + "[", \ + std::string(__FILE__), \ + ":", \ + std::to_string(__LINE__), \ + "] ", \ + logger->getLogPrefix(), \ + _error_msg, \ + ", error code ", \ + _result, \ + ": ", \ + ucc_status_string(_result), \ + ", system error code ", \ + errno); \ + } while (0) + // Macro to throw on a non-successful UCC return value. -#define TORCH_UCC_CHECK(_cmd, _error_msg) \ - do { \ - ucc_status_t result = _cmd; \ - if (result != UCC_OK) { \ - std::string err = c10::str( \ - "[", \ - std::string(__FILE__), \ - ":", \ - std::to_string(__LINE__), \ - "] ", \ - logger->getLogPrefix(), \ - _error_msg, \ - ", error code ", \ - result, \ - ": ", \ - ucc_status_string(result), \ - ", system error code ", \ - errno); \ - TORCH_CHECK(false, err); \ - } \ +#define TORCH_UCC_CHECK(_cmd, _error_msg) \ + do { \ + ucc_status_t result = _cmd; \ + if (result != UCC_OK) { \ + std::string err; \ + TORCH_UCC_GET_ERROR_MSG(err, _error_msg, result); \ + TORCH_CHECK(false, err); \ + } \ + } while (0) + +// Macro and throw on a non-successful UCC return value and free its request. +#define TORCH_UCC_CHECK_REQUEST(_request, _cmd, _error_msg) \ + do { \ + ucc_status_t result = _cmd; \ + if (result != UCC_OK) { \ + std::string err; \ + TORCH_UCC_GET_ERROR_MSG(err, _error_msg, result); \ + if (_request != nullptr) { \ + ucc_collective_finalize(_request); \ + } \ + TORCH_CHECK(false, err); \ + } \ } while (0) // Macros to print logs with unified format diff --git a/torch/csrc/distributed/c10d/comm.cpp b/torch/csrc/distributed/c10d/comm.cpp index c873eec5fbcf9..d011e5543a5da 100644 --- a/torch/csrc/distributed/c10d/comm.cpp +++ b/torch/csrc/distributed/c10d/comm.cpp @@ -32,7 +32,13 @@ class BroadcastWork { flat_tensor_.front(), bucket_tensors_); TORCH_INTERNAL_ASSERT(output_tensors.size() == bucket_tensors_.size()); for (const auto i : c10::irange(output_tensors.size())) { - bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true); + // if output_tensor is empty, no need to copy it back, + // this can avoid error when both bucket_tensor and output_tensor + // are empty, but they have different shapes, see + // https://github.com/pytorch/pytorch/issues/87280 + if (output_tensors[i].numel() != 0) { + bucket_tensors_[i].copy_(output_tensors[i], /*non_blocking=*/true); + } } } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 327c041357266..51ae468ea8068 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -235,6 +236,61 @@ void _register_builtin_comm_hook( reducer.register_builtin_comm_hook(comm_hook_type); } +// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility. +// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to +// struct from enum, sacrificing some of the Python built-in function supports +// such as `isinstance` (see https://github.com/pytorch/pytorch/issues/87191) +// and `copy` (see +// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below, +// we define a custom `isinstance` in CPython/pybind11 +// (`reduceopmeta___instancecheck__`) and modify the default metaclass of +// pybind11 (`GetReduceOpMetaclass`) so that +// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)` +// returns :obj:`True` as if `ReduceOp` is enum. +// Ref: +// - https://docs.python.org/3/extending/newtypes_tutorial.html +// - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods +// - https://github.com/pybind/pybind11/issues/2696 +static PyObject* reduceopmeta___instancecheck__( + PyObject* self, + PyObject* args) { + if (Py_TYPE(self) == Py_TYPE(args)) { + Py_RETURN_TRUE; + } + if (c10::string_view(args->ob_type->tp_name).find("RedOpType") != + c10::string_view::npos) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} +static PyMethodDef reduceopmeta_methods[] = { + {"__instancecheck__", + (PyCFunction)reduceopmeta___instancecheck__, + METH_O, + "Custom `__instancecheck__` for ReduceOp"}, + {NULL, NULL}}; +PyTypeObject* GetReduceOpMetaclass() { + static auto* metaclass = [] { + PyTypeObject* base_metaclass = + pybind11::detail::get_internals().default_metaclass; + PyType_Slot slots[] = { + {Py_tp_base, base_metaclass}, + {Py_tp_methods, reduceopmeta_methods}, + {0}, + }; + PyType_Spec spec = {}; + spec.name = "torch._C._distributed_c10d._ReduceOpMeta"; + spec.basicsize = base_metaclass->tp_basicsize; + spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + spec.slots = slots; + PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec); + if (!metaclass) + throw py::error_already_set(); + return metaclass; + }(); + return metaclass; +} + PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); @@ -515,10 +571,15 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO R"(Sets the debug level of the torch.distributed package from the ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)"); + // TODO(crcrpar): Hardening `ReduceOp`. + // While keeping most op types as enum value, + // making `PREMUL_SUM` callable, i.e., allowing for + // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol. // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types - py::class_<::c10d::ReduceOp> reduce_op(module, "ReduceOp", R"( + py::class_<::c10d::ReduceOp> reduce_op( + module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"( An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, -``MIN``, ``MAX``, ``BAND``, ``BOR``, and ``BXOR``. +``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``. ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when using the ``NCCL`` backend. @@ -529,13 +590,16 @@ and only for NCCL versions 2.10 or later. ``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction. ``PREMUL_SUM`` is only available with the ``NCCL`` backend, -and only available for NCCL versions 2.11 or later. +and only available for NCCL versions 2.11 or later. Users are supposed to +use ``torch.distributed._make_nccl_premul_sum``. Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``. They are used in specifying strategies for reduction collectives, e.g., -:func:`reduce`, :func:`all_reduce_multigpu`, etc.)"); +:func:`reduce`, :func:`all_reduce_multigpu`, etc. + +This class does not support ``__members__`` property.)"); reduce_op.def(py::init<::c10d::ReduceOp::RedOpType>()) .def_readwrite("op", &::c10d::ReduceOp::op_); @@ -544,18 +608,70 @@ They are used in specifying strategies for reduction collectives, e.g., // take hash of `::c10d::ReduceOp`. To avoid losing these functionality, here // I define some member methods. reduce_op + // todo(crcrpar): Support `RedOpType == ReduceOp`. .def( + // This calls `operator==(const ReduceOp::RedOpType)` "__eq__", [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp::RedOpType& other) { return self == other; }) .def( + // This calls `operator==(const ReduceOp)` for the future support of + // `PREMUL_SUM` comparison "__eq__", [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) { - return self == other.op_; + return self == other; }) - .def("__hash__", [](const ::c10d::ReduceOp& self) { return self.op_; }); + .def( + // With the above custom `__eq__`'s, I have to manually support the + // other types. + "__eq__", + [](const ::c10d::ReduceOp& self, py::object) { return false; }) + .def( + "__hash__", + [](const ::c10d::ReduceOp& self) { + return static_cast(self.op_); + }) + .def( + "__copy__", + [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); }) + .def( + "__deepcopy__", + [](const ::c10d::ReduceOp& self, const py::dict& memo) { + return ::c10d::ReduceOp(self); + }) + .def(py::pickle( + [](const ::c10d::ReduceOp& r) { + // __getstate__ + if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { + return py::make_tuple(r.op_, py::none()); + } + TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp"); + const auto* preMulSupplement = + reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>( + r.supplement_.get()); + if (!preMulSupplement->tensor_factor.defined()) { + return py::make_tuple(r.op_, preMulSupplement->double_factor); + } else { + return py::make_tuple(r.op_, preMulSupplement->tensor_factor); + } + }, + [](const py::tuple t) { + // __setstate__ + TORCH_CHECK(t.size() == 2, "Invalid state"); + const auto op = + static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast()); + if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { + return ::c10d::ReduceOp(op); + } + const auto preMulSupplement_factor = t[1]; + if (py::isinstance(preMulSupplement_factor)) { + return ::c10d::makeNCCLPreMulSum(t[1].cast()); + } else { + return ::c10d::makeNCCLPreMulSum(t[1].cast()); + } + })); py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType") .value("SUM", ::c10d::ReduceOp::RedOpType::SUM) @@ -569,7 +685,8 @@ They are used in specifying strategies for reduction collectives, e.g., .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM) .export_values(); - // Ref: [Implicit + // note(crcrpar): This could be removed because users will not pass + // `RedOpType` to reduce collective ops Ref: [Implicit // conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions) // Let us skip the explicit construction of `c10d::ReduceOp` from // `c10d::ReduceOp::RedOpType` in Python. @@ -584,7 +701,7 @@ They are used in specifying strategies for reduction collectives, e.g., py::call_guard()) .def( "_make_nccl_premul_sum", - &::c10d::makeNCCLPreMulSum>, + &::c10d::makeNCCLPreMulSum, py::arg("factor").noconvert(), py::return_value_policy::copy, // seems safest py::call_guard()); @@ -1121,10 +1238,10 @@ that adds a prefix to each key inserted to the store. .def( "allreduce_coalesced", - [](::c10d::ProcessGroup& self, - std::vector& xs, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& xs, ::c10d::AllreduceCoalescedOptions opts) { - return self.allreduce_coalesced(xs, opts); + return ::c10d::ops::allreduce_coalesced(self, xs, opts); }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), @@ -1174,7 +1291,13 @@ that adds a prefix to each key inserted to the store. .def( "_allgather_base", - &::c10d::ProcessGroup::_allgather_base, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::_allgather_base( + self, output_tensor, input_tensor, opts); + }, py::arg("output"), py::arg("input"), py::arg("opts") = ::c10d::AllgatherOptions(), @@ -1196,7 +1319,13 @@ that adds a prefix to each key inserted to the store. .def( "allgather_coalesced", - &::c10d::ProcessGroup::allgather_coalesced, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector>& output_lists, + const std::vector& input_list, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::allgather_coalesced( + self, output_lists, input_list, opts); + }, py::arg("output_lists"), py::arg("input_list"), py::arg("opts") = ::c10d::AllgatherOptions(), @@ -1297,7 +1426,13 @@ that adds a prefix to each key inserted to the store. .def( "_reduce_scatter_base", - &::c10d::ProcessGroup::_reduce_scatter_base, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ::c10d::ReduceScatterOptions& opts) { + return ::c10d::ops::_reduce_scatter_base( + self, output_tensor, input_tensor, opts); + }, py::arg("outputTensor"), py::arg("inputTensor"), py::arg("opts") = ::c10d::ReduceScatterOptions(), @@ -1305,34 +1440,26 @@ that adds a prefix to each key inserted to the store. .def( "alltoall_base", - &::c10d::ProcessGroup::alltoall_base, - py::arg("output_tensor"), - py::arg("input_tensor"), - py::arg("output_split_sizes"), - py::arg("input_split_sizes"), - py::arg("opts") = ::c10d::AllToAllOptions(), - py::call_guard()) - - .def( - "alltoall_base", - [](::c10d::ProcessGroup& self, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, at::Tensor& output, at::Tensor& input, std::vector outputSplitSizes, - std::vector inputSplitSizes) { - return self.alltoall_base( + std::vector inputSplitSizes, + const ::c10d::AllToAllOptions& opts) { + return ::c10d::ops::alltoall_base( + self, output, input, outputSplitSizes, inputSplitSizes, - ::c10d::AllToAllOptions()); + opts); }, py::arg("output"), py::arg("input"), py::arg("output_split_sizes"), py::arg("input_split_sizes"), + py::arg("opts") = ::c10d::AllToAllOptions(), py::call_guard()) - .def( "alltoall", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -1387,7 +1514,11 @@ that adds a prefix to each key inserted to the store. .def( "recv_anysource", - &::c10d::ProcessGroup::recvAnysource, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& tensors, + int64_t tag) { + return ::c10d::ops::recv_any_source(self, tensors, tag); + }, py::call_guard()) .def( @@ -1413,7 +1544,7 @@ that adds a prefix to each key inserted to the store. bool waitAllRanks) { ::c10d::BarrierOptions opts; opts.timeout = timeout; - return self->monitoredBarrier(opts, waitAllRanks); + return ::c10d::ops::monitored_barrier(self, opts, waitAllRanks); }, py::arg("timeout") = ::c10d::kUnsetTimeout, py::arg("wait_all_ranks") = false, @@ -1887,7 +2018,7 @@ Example:: Returns: A ``Work`` object which is associated with the completion of the ``torch.futures.Future``. - This is the prefered way of constructing Work objects when writing a custom ProcessGroup + This is the preferred way of constructing Work objects when writing a custom ProcessGroup in python. Example:: >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup): diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index 2480b21d105f1..c885713637421 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -1260,18 +1260,22 @@ void TensorPipeAgent::updateGroupMembership( workerNameToInfo_.erase(name); workerNameToURL_.erase(name); - for (const auto& it : reverseDeviceMaps_) { - if (reverseDeviceMaps.find(it.first) == reverseDeviceMaps.end()) { - reverseDeviceMaps_.erase(it.first); + // remove reverse device maps that are no longer used + for (auto it = reverseDeviceMaps_.begin(); + it != reverseDeviceMaps_.end();) { + if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) { + it = reverseDeviceMaps_.erase(it); + } else { + it++; } } - auto iter = devices_.begin(); - while (iter != devices_.end()) { - if (std::find(devices.begin(), devices.end(), *iter) == devices.end()) { - iter = devices_.erase(iter); + // remove devices that are no longer used + for (auto it = devices_.begin(); it != devices_.end();) { + if (std::find(devices.begin(), devices.end(), *it) == devices.end()) { + it = devices_.erase(it); } else { - iter++; + it++; } } } diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 12e3f2edf7558..c20145e82d038 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -38,7 +38,7 @@ void processRemoteProfiledEvents( TORCH_CHECK( enabled, "Profiler was expected to be enabled. This can happen in callback " - " continutations that run in different threads, and the TLS of the " + " continuations that run in different threads, and the TLS of the " " profiler was not propagated."); std::vector events = rpcWithProfilingResp.getProfiledEvents(); const auto& profilingId = rpcWithProfilingResp.getProfilingId(); diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index e81457e4a2487..bbfc1bb2897d2 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -191,6 +191,17 @@ static void destroy_cache_entry(CacheEntry* e) { free(e); } +inline static CacheEntry* get_extra(PyCodeObject* code) { + CacheEntry* extra = NULL; + _PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra); + return extra; +} + +inline static void set_extra(PyCodeObject* code, CacheEntry* extra) { + // TODO(jansel): would it be faster to bypass this? + _PyCode_SetExtra((PyObject*)code, extra_index, extra); +} + #ifdef TORCHDYNAMO_DEBUG inline static const char* name(PyFrameObject* frame) { DEBUG_CHECK(PyUnicode_Check(frame->f_code->co_name)); @@ -216,10 +227,11 @@ static void call_guard_fail_hook( Py_DECREF(args); } -static PyCodeObject* lookup(CacheEntry* e, PyObject* f_locals) { +static PyCodeObject* lookup(CacheEntry* e, PyFrameObject *frame, CacheEntry* prev) { if (e == NULL) { return NULL; } + PyObject *f_locals = frame->f_locals; PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey); PyObject* valid = NULL; if (unlikely(dotzero != NULL)) { @@ -240,12 +252,21 @@ static PyCodeObject* lookup(CacheEntry* e, PyObject* f_locals) { } Py_DECREF(valid); if (valid == Py_True) { + // Keep the head as the most recently used cache entry. + // If the hit cache entry is not the head of the linked list, + // move it to the head + if (prev != NULL) { + CacheEntry* extra = get_extra(frame->f_code); + prev->next = e->next; + e->next = extra; + set_extra(frame->f_code, e); + } return e->code; } if (unlikely(guard_fail_hook != NULL)) { call_guard_fail_hook(guard_fail_hook, e, f_locals); } - return lookup(e->next, f_locals); + return lookup(e->next, frame, e); } static long cache_size(CacheEntry* e) { @@ -255,17 +276,6 @@ static long cache_size(CacheEntry* e) { return 1 + cache_size(e->next); } -inline static CacheEntry* get_extra(PyCodeObject* code) { - CacheEntry* extra = NULL; - _PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra); - return extra; -} - -inline static void set_extra(PyCodeObject* code, CacheEntry* extra) { - // TODO(jansel): would it be faster to bypass this? - _PyCode_SetExtra((PyObject*)code, extra_index, extra); -} - inline static PyObject* eval_custom_code( PyThreadState* tstate, PyFrameObject* frame, @@ -358,7 +368,7 @@ static PyObject* _custom_eval_frame( // we never compile. if (callback == Py_False) { DEBUG_TRACE("In run only mode %s", name(frame)); - PyCodeObject* cached_code = lookup(extra, frame->f_locals); + PyCodeObject* cached_code = lookup(extra, frame, NULL); if (cached_code != NULL) { // used cached version DEBUG_TRACE("cache hit %s", name(frame)); @@ -377,7 +387,7 @@ static PyObject* _custom_eval_frame( // in the shim. eval_frame_callback_set(Py_None); - PyCodeObject* cached_code = lookup(extra, frame->f_locals); + PyCodeObject* cached_code = lookup(extra, frame, NULL); if (cached_code != NULL) { // used cached version DEBUG_TRACE("cache hit %s", name(frame)); diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index b1f696ee3c7d0..8064293016faf 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -438,6 +439,14 @@ void initFuncTorchBindings(PyObject* module) { m.def( "get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed); + m.def( + "set_autograd_function_allowed", + &at::functorch::setAutogradFunctionAllowed); + m.def( + "get_autograd_function_allowed", + &at::functorch::getAutogradFunctionAllowed); + m.def("unwrap_if_dead", &unwrapIfDead); + m.def("is_dead_tensor_wrapper", &isDeadTensorWrapper); m.def("dlevel", &dlevel, "dlevel"); m.def("dump_tensor", &dump_tensor, "dump_tensor"); m.def("reshape_dim_into", &at::functorch::reshape_dim_into); @@ -461,6 +470,40 @@ void initFuncTorchBindings(PyObject* module) { m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) { return maybe_get_level(tensor) != -1; }); + m.def("peek_interpreter_stack", []() -> c10::optional { + const auto& stack = getDynamicLayerStack(); + if (stack.size() == 0) { + return c10::nullopt; + } + auto result = stack.back().interpreter(); + return result; + }); + m.def("pop_dynamic_layer_stack", &popDynamicLayer); + m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t { + return pushDynamicLayer(std::move(layer)); + }); + py::class_(m, "DynamicLayer"); + + py::enum_(m, "TransformType") + .value("Torch", TransformType::Torch) + .value("Grad", TransformType::Grad) + .value("Jvp", TransformType::Jvp) + .value("Functionalize", TransformType::Functionalize) + .value("Vmap", TransformType::Vmap); + py::class_(m, "CInterpreter") + .def("key", &Interpreter::key) + .def("level", &Interpreter::level); + py::class_(m, "CGradInterpreterPtr") + .def(py::init()) + .def("key", &GradInterpreterPtr::key) + .def("level", &GradInterpreterPtr::level) + .def("lift", &GradInterpreterPtr::lift) + .def("prevGradMode", &GradInterpreterPtr::prevGradMode); + py::class_(m, "CVmapInterpreterPtr") + .def(py::init()) + .def("key", &VmapInterpreterPtr::key) + .def("level", &VmapInterpreterPtr::level) + .def("batchSize", &VmapInterpreterPtr::batchSize); } } // namespace impl diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp index f739f834dc293..96e69ea754cc1 100644 --- a/torch/csrc/init_flatbuffer_module.cpp +++ b/torch/csrc/init_flatbuffer_module.cpp @@ -16,8 +16,9 @@ #include #include #include +#include #include -#include +#include namespace py = pybind11; diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index c1bcd57c73a5f..7168967897626 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -894,7 +894,7 @@ def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh): return hy, cy ``` -After going through the the frontend, we start with this unoptimized graph: +After going through the frontend, we start with this unoptimized graph: ``` graph(%x : Tensor, @@ -1408,7 +1408,7 @@ TODO: differentiation, symbolic autograd, fusion, operators We attempt to reduce the number of `prim::Guard` nodes as these nodes may interfere with optimizations. * First, `GuardElimination::moveGuardsToDefs` tries to move `prim::Guards` to their definitions, so the guards guarding the same `Tensor` follow the definition directly or another guard on the same `Tensor`. * This ordering allows us to **coalesce** (done in `GuardElimination::coalesceGuards`) multiple guards into a single one. -* After guards are **coaslesced** , `GuardElimination::eliminateGuards` attempts to eliminate more guards as follows: it inspects each operation and its inputs. It checks if inputs to the operation are guarded and also if the operation produces the consistent shapes given the guarded inputs. For example, if two inputs to `add` are guaranteed to be of shape `(2, 3)`, the output shape will also always be `(2, 3)`. If this property holds, we are allowed to remove the guard guarding operation's output. +* After guards are **coalesced** , `GuardElimination::eliminateGuards` attempts to eliminate more guards as follows: it inspects each operation and its inputs. It checks if inputs to the operation are guarded and also if the operation produces the consistent shapes given the guarded inputs. For example, if two inputs to `add` are guaranteed to be of shape `(2, 3)`, the output shape will also always be `(2, 3)`. If this property holds, we are allowed to remove the guard guarding operation's output. Lastly, we need to be handle cases when the assumptions about `Tensor` shapes fail at runtime. To handle guard failures, we need to be able to run the original code i.e. the code that doesn't rely on assumptions about shapes. As guards can be inserted and moved (by Optimizer) at/to arbitrary points in a computational graph, we need to be able to resume execution starting from those arbitrary points onward. diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm index 3588128cc0522..9db3509dc1d2b 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm @@ -88,9 +88,14 @@ GenericList pack_outputs(const std::vector& output_specs, id(), (float*)val.multiArrayValue.dataPointer, count * sizeof(float)); - outputs.push_back(tensor); + outputs.push_back(std::move(tensor)); } - return c10::impl::toList(outputs); + if(output_specs.size() > 1){ + c10::List> output_res; + output_res.push_back(std::move(outputs)); + return c10::impl::toList(std::move(output_res)); + } + return c10::impl::toList(std::move(outputs)); } class CoreMLBackend: public torch::jit::PyTorchBackendInterface { diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp new file mode 100644 index 0000000000000..a64bf35431fdd --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -0,0 +1,118 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +void XNNCompiler::compileModel( + const void* buffer_pointer, + size_t num_bytes, + XNNExecutor* executor) { + auto output_min = -std::numeric_limits::infinity(); + auto output_max = std::numeric_limits::infinity(); + + auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer); + // initialize xnnpack + xnn_status status = xnn_initialize(/*allocator =*/nullptr); + TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack"); + + // create xnnpack subgraph + xnn_subgraph_t subgraph_ptr = nullptr; + status = xnn_create_subgraph( + /*external_value_ids=*/flatbuffer_graph->num_externs(), + /*flags=*/0, + &subgraph_ptr); + TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph"); + + // mapping from old ids to new created value ids + // The old ids that were serialied were generated AoT, since + // we are re-defining tensor values, the defined IDs could be + // different from the ones generated AoT, as a result, we need + // a new mapping from the old ids to the newly created ones + std::unordered_map remapped_ids; + + for (auto value : *flatbuffer_graph->xvalues()) { + switch (value->xvalue_type()) { + case fb_xnnpack::XValueUnion::XNNTensorValue: { + auto tensor_value = value->xvalue_as_XNNTensorValue(); + + std::vector dims_data; + for (auto dim : *tensor_value->dims()) { + dims_data.push_back(static_cast(dim)); + } + + uint32_t id = XNN_INVALID_VALUE_ID; + const auto& constant_buffer = *flatbuffer_graph->constant_buffer(); + auto buffer_idx = tensor_value->constant_buffer_idx(); + const auto buffer_ptr = buffer_idx == 0 + ? nullptr + : constant_buffer[buffer_idx]->storage()->data(); + status = xnn_define_tensor_value( + /*subgraph=*/subgraph_ptr, + /*datatype=*/xnn_datatype_fp32, + /*num_dims=*/tensor_value->num_dims(), + /*dims=*/dims_data.data(), + /*data=*/buffer_ptr, + /*external_id=*/tensor_value->external_id(), + /*flags=*/tensor_value->flags(), + /*id_out=*/&id); + TORCH_CHECK( + status == xnn_status_success, + "Failed to define tensor values in graph") + // map serialized id to newly generated id + remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id)); + break; + } + default: { + TORCH_CHECK(false, "Unhandled value type found in deserialization"); + } + } + } + + for (auto node : *flatbuffer_graph->xnodes()) { + switch (node->xnode_type()) { + case fb_xnnpack::XNodeUnion::XNNAdd: { + auto graph_node = node->xnode_as_XNNAdd(); + status = xnn_define_add2( + subgraph_ptr, + output_min, + output_max, + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + TORCH_CHECK(status == xnn_status_success, "Failed to create add node") + break; + } + default: + TORCH_CHECK(false, "Unhandled node type found in deserialization"); + } + } + + xnn_runtime_t runtime_ptr = nullptr; + status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr); + TORCH_CHECK(xnn_status_success == status); + + executor->runtime_ = + std::unique_ptr( + runtime_ptr, xnn_delete_runtime); + + for (auto old_id : *flatbuffer_graph->input_ids()) { + executor->input_ids_.emplace_back(remapped_ids.at(old_id)); + } + + for (auto old_id : *flatbuffer_graph->output_ids()) { + executor->output_ids_.emplace_back(remapped_ids.at(old_id)); + } +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h new file mode 100644 index 0000000000000..f74e784111d4f --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -0,0 +1,27 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNCompiler { + public: + // Takes Flatbuffer Serialized XNNPack Model and rebuilds the xnn-subgraph + // returns an executor object that holds the xnn runtime object which we + // can then use to set inputs and run inference using the xnn graph. + static void compileModel( + const void* buffer_pointer, + size_t num_bytes, + XNNExecutor* executor); +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h new file mode 100644 index 0000000000000..2521c0c7749d8 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h @@ -0,0 +1,70 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +#pragma once +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNExecutor { + private: + std::unique_ptr runtime_{ + nullptr, + &xnn_delete_runtime}; + std::vector input_ids_; + std::vector output_ids_; + std::vector externals_; + + public: + XNNExecutor() = default; + + template + bool set_inputs(std::vector& inputs, std::vector& outputs) { + externals_.clear(); + + if (inputs.size() != input_ids_.size()) { + return false; + } + + for (int i = 0; i < inputs.size(); i++) { + externals_.emplace_back(xnn_external_value{input_ids_[i], inputs[i]}); + } + + if (outputs.size() != output_ids_.size()) { + return false; + } + + for (int i = 0; i < outputs.size(); i++) { + externals_.emplace_back(xnn_external_value{output_ids_[i], outputs[i]}); + } + + return true; + } + + bool forward() { + xnn_status status = + xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()); + + if (status != xnn_status_success) { + return false; + } + + status = xnn_invoke_runtime(runtime_.get()); + + if (status != xnn_status_success) { + return false; + } + + return true; + } + + friend class XNNCompiler; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs new file mode 100644 index 0000000000000..cc1290b718fac --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs @@ -0,0 +1,97 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +namespace fb_xnnpack; + +// datatype for xnn-values +enum XNNDatatype : short { + /// Invalid data type. Valid Values never have this datatype. + xnn_datatype_invalid = 0, + /// IEEE754 single-precision floating-point. + xnn_datatype_fp32 = 1, + /// IEEE754 half-precision floating-point. + xnn_datatype_fp16 = 2, + /// Quantized 8-bit signed integer with shared per-Value quantization parameters. + xnn_datatype_qint8 = 3, + /// Quantized 32-bit signed integer with shared per-Value quantization parameters. + xnn_datatype_qint32 = 4, +} + +// taken from executorch +// Data buffer abstraction. +table Buffer { + storage:[ubyte] (force_align: 16); +} + +table XNNTensorValue { + // type of the tensor elements. + datatype:XNNDatatype; + // number of dimensions in the shape. + num_dims:uint; + // pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. + // XNNPACK does not keep any pointers to this array after the function returns. + dims:[uint]; + // Index to the program's constant buffer table, value 0 is reserved to indicate non constant + constant_buffer_idx:uint; + // external ID for the Value. The ID must be within the range of reserved Value IDs specified on + // the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be + // created for the Value. + external_id:uint; + // binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT + // and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. + flags:uint; + // pointer to the variable that will be initialized with the Value ID upon successful return. If a + // valid @a external_id was provided, the variable will be initialized with the @a external_id value. + id_out:uint; +} + +union XNodeUnion { + XNNAdd, +} + +union XValueUnion { + XNNTensorValue, +} + +table XNode { + xnode:XNodeUnion; + // An int which can be linked back to the node in the origin graph + debug_handle:uint; +} + +table XValue { + xvalue:XValueUnion; +} + +table XNNAdd { + input1_id:uint; + input2_id:uint; + output_id:uint; + flags:uint; +} + +table XNNGraph { + // Schema version. + version:string; + xnodes:[XNode]; + xvalues:[XValue]; + + // Number of external inputs/outputs + num_externs:uint; + + // Ids of external inputs + input_ids:[uint]; + + // Ids of external outputs + output_ids:[uint]; + + // Tables of constant data, used for constant Values (e.g. + // data field of weight tensors). Each constant is assigned an index into the table + // which are each individually aligned. 0 index is reserved to be pointed to by non-constant + // Tensors + constant_buffer:[Buffer]; + + // the list index is memory buffer id, the value is the memory buffer size. + mem_buffer_sizes: [uint]; +} + +root_type XNNGraph; diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp new file mode 100644 index 0000000000000..637f7cdf4c521 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp @@ -0,0 +1,102 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +using namespace fb_xnnpack; + +void XNNSerializer::serializeAddNode( + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags) { + const auto addNode = + CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags); + const auto flatbufferNode = + CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union()); + _nodes.push_back(flatbufferNode); +} + +size_t XNNSerializer::serializeData(const uint8_t* data_ptr, size_t num_bytes) { + size_t constant_buffer_idx = 0; + // Handling the tensor _values with data + if (data_ptr != nullptr) { + // steps: + // 1. creating flatbuffer byte-vector for tensor data + auto storage = _builder.CreateVector(data_ptr, num_bytes); + + // 2. put it in the common buffer + constant_buffer_idx = _constantBuffer.size(); + _constantBuffer.emplace_back(CreateBuffer(_builder, storage)); + + // 3. record size into bufferSizes + _bufferSizes.push_back(num_bytes); + assert(_bufferSizes.size() == _constantBuffer.size()); + } + return constant_buffer_idx; +} + +void XNNSerializer::serializeTensorValue( + uint32_t xnn_datatype, + size_t num_dims, + std::vector dims, + size_t data_buffer_idx, + uint32_t external_id, + uint32_t flags, + uint32_t id_out) { + std::vector serialized_dims; + serialized_dims.reserve(dims.size()); + for (auto dim : dims) { + serialized_dims.push_back(static_cast(dim)); + } + + const auto tensorValue = CreateXNNTensorValueDirect( + _builder, + XNNDatatype(xnn_datatype), + num_dims, + &serialized_dims, + data_buffer_idx, + external_id, + flags, + id_out); + + const auto flatbufferValue = + CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union()); + _values.push_back(flatbufferValue); +} + +std::string XNNSerializer::finishAndSerialize( + std::vector input_ids, + std::vector output_ids, + size_t num_extern_ids) { + auto xnnGraph = CreateXNNGraphDirect( + _builder, + _version_sha1, + &_nodes, + &_values, + num_extern_ids, + &input_ids, + &output_ids, + &_constantBuffer, + &_bufferSizes); + + _builder.Finish(xnnGraph); + + std::stringstream ss; + ss.write( + reinterpret_cast(_builder.GetBufferPointer()), _builder.GetSize()); + + return ss.str(); +} + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h new file mode 100644 index 0000000000000..5a683c3dc3233 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -0,0 +1,86 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +using namespace fb_xnnpack; // Specified in the schema + +class XNNSerializer { + public: + // Constructors + // initial buffersize of 1024 which will grow + // automatically, constant buffer and buffer sizes initialized with dummy + // values as 0 index is reserved for non-constant tensors + XNNSerializer() : XNNSerializer(1024) {} + + explicit XNNSerializer(size_t bufferSize) + : _builder(bufferSize), + _nodes(), + _values(), + _constantBuffer({CreateBuffer( + _builder, + {})}), // index 0 is reserved for non-const data + _bufferSizes({0}) {} + + // Serializing Nodes + + // Serialize add node, we are serializing the argument needed to call + // xnn_define_add2. Serializing these values, and at run time we build + // teh graph by re running xnn_define_add2 + void serializeAddNode( + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + + // Serializing Values + void serializeTensorValue( + uint32_t xnn_datatype, + size_t num_dims, + std::vector dims, + size_t buffer_data_idx, + uint32_t external_id, + uint32_t flags, + uint32_t id_out); + + // finish and serialize xnngraph returning serialized data + std::string finishAndSerialize( + std::vector input_ids, + std::vector output_ids, + size_t num_extern_ids); + + // decoupled data serialization with tensor values. This way constant tensor + // data can be referenced by multiple intermediate tensors. This call + // serializes the num_bytes of the data_ptr and returns the index it was + // placed in. + size_t serializeData(const uint8_t* data_ptr, size_t num_bytes); + + private: + // xnnpack version we are serializing + const char* _version_sha1 = "ae108ef49aa5623b896fc93d4298c49d1750d9ba"; + + // flatbuffer objects we will create and serialize together to create xnngraph + flatbuffers_fbsource::FlatBufferBuilder _builder; + + // Vector of the serialized xnnpack nodes + std::vector> _nodes; + + // Vector of the serialized xnnpack values + std::vector> _values; + + std::vector> _constantBuffer; + std::vector _bufferSizes; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index 4d1e934de4d97..46c7458039d47 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -1,15 +1,27 @@ +#include #include #include #include #include -#include +#include +#include namespace torch { namespace jit { namespace xnnpack { namespace delegate { +class XNNModelWrapper : public CustomClassHolder { + public: + XNNExecutor executor_; + XNNModelWrapper(XNNExecutor executor) : executor_(std::move(executor)){}; + + XNNModelWrapper() = delete; + + XNNModelWrapper(const XNNModelWrapper& oldObject) = delete; +}; + class XNNPackBackend : public PyTorchBackendInterface { public: // Constructor. @@ -26,9 +38,28 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { auto dict = processed.toGenericDict(); + + // Compiling and wrapping exeuction object + const std::string& ser_model = dict.at("ser_model").toStringRef(); + XNNExecutor executor; + XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor); + + auto model_ptr = c10::make_intrusive(std::move(executor)); + auto runtime_handle = IValue::make_capsule(model_ptr); + auto wrapper = c10::static_intrusive_pointer_cast( + runtime_handle.toCapsule()); + + // Packing outputs into generic dict c10::Dict handles( c10::StringType::get(), c10::AnyType::get()); - handles.insert("forward", dict); + + c10::Dict ret( + c10::StringType::get(), c10::AnyType::get()); + + ret.insert("runtime", runtime_handle); + ret.insert("output_shapes", dict.at("outputs")); + + handles.insert("forward", ret); return handles; } @@ -41,9 +72,38 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::impl::GenericList execute( c10::IValue handle, c10::impl::GenericList inputs) override { - c10::List output_list; - auto answer = handle.toGenericDict().at("Answer"); - output_list.emplace_back(answer.toTensor()); + auto dict = handle.toGenericDict(); + auto output_shapes = dict.at("output_shapes").toList(); + + auto capsule = dict.at("runtime").toCapsule(); + auto model_wrapper = + c10::static_intrusive_pointer_cast(capsule); + + XNNExecutor& executor = model_wrapper->executor_; + + std::vector input_pointers; + for (int i = 0; i < inputs.size(); ++i) { + at::IValue val = inputs.get(i); + TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported"); + input_pointers.push_back(val.toTensor().data_ptr()); + } + + std::vector output_tensors; + std::vector output_pointers; + output_tensors.reserve(output_shapes.size()); + for (int i = 0; i < output_shapes.size(); i++) { + auto o_shape = output_shapes.get(i).toIntVector(); + auto output = at::empty(o_shape, c10::ScalarType::Float); + output_tensors.push_back(output); + output_pointers.push_back(output.data_ptr()); + } + + TORCH_CHECK( + executor.set_inputs(input_pointers, output_pointers), + "Number of inputs/outputs does not match expected number of inputs/outputs"); + TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime"); + + c10::List output_list(output_tensors); return c10::impl::toList(output_list); } }; diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp index 6d739f4097444..b4b7c912554a5 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp @@ -6,6 +6,7 @@ #include #include +#include namespace torch { namespace jit { @@ -18,7 +19,10 @@ namespace delegate { // } // or // { -// "forward" : {"inputs" : c10::List} +// "forward" : { +// "inputs" : c10::List, +// "outputs" : c10::List +// } // } // in which the value for "inputs" is the input shape to the module. // The module fed to the xnnpack backend must first be traced in order @@ -28,8 +32,9 @@ c10::IValue preprocess( const Module& mod, const c10::Dict& method_compile_spec, const BackendDebugHandleGenerator& generate_debug_handles) { - auto output_min = -std::numeric_limits::infinity(); - auto output_max = std::numeric_limits::infinity(); + auto eval_mod = mod.clone(); + eval_mod.eval(); + eval_mod = torch::jit::freeze(eval_mod); c10::Dict compiled(StringType::get(), TensorType::get()); @@ -58,7 +63,7 @@ c10::IValue preprocess( "method_compile_spec does not contain either a Tensor or TensorList, under it's \"outputs\" key."); // Graph preprocessing - const auto& forward_method = mod.get_method("forward"); + const auto& forward_method = eval_mod.get_method("forward"); auto graph = toGraphFunction(forward_method.function()).graph()->copy(); graph = tensorexpr::removeUnusedSelfArgument(graph); @@ -71,7 +76,6 @@ c10::IValue preprocess( example_inputs.reserve(inp_list.size()); for (const auto i : c10::irange(inp_list.size())) { - graph->inputs()[i]->setType(TensorType::create(inp_list[i])); example_inputs.emplace_back(inp_list[i]); } } else { @@ -79,11 +83,43 @@ c10::IValue preprocess( graph->inputs().size() == 1, "method_compile_spec inputs do not match expected number of forward inputs"); - graph->inputs()[0]->setType(TensorType::create(inp.toTensor())); example_inputs.emplace_back(inp.toTensor()); } - compiled.insert("Answer", at::empty({1}, c10::ScalarType::Float)); + // inp above has been confirmed to be either Tensor or TensorList + XNNGraph graph_builder; + graph_builder.buildXNNGraph(graph, example_inputs); + // at this point graph is complete, for the sake of testing preprocess at this + // point we will do runtime setup and run with some default values + + // grabbing the inputs from compile spec for testing + + // gather sample inputs from compile spec + std::vector inputs; + auto input_list = inp.toList(); + + for (int i = 0; i < input_list.size(); i++) { + inputs.push_back(input_list.get(i).toTensor()); + } + std::vector outputs; + auto output_list = out.toList(); + std::vector output_shapes; + + // gather sample outputs from compile spec + for (int i = 0; i < output_list.size(); i++) { + auto sample_output = output_list.get(i).toTensor(); + outputs.push_back(sample_output); + // also gather output shapes to forward along to device + output_shapes.push_back(sample_output.sizes()); + } + + // sample run on sample inputs + graph_builder.runGraphOnInputs(inputs, outputs); + c10::List shapes_list(output_shapes); + + compiled.insert("ser_model", graph_builder.serializedXNNGraph()); + compiled.insert("outputs", shapes_list); + compiled.insert("Answer", outputs); return compiled; } diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp new file mode 100644 index 0000000000000..7c7bb2d02e4c2 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp @@ -0,0 +1,324 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +// graph passes +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +std::shared_ptr XNNGraph::optimizeAndTraceGraph( + std::shared_ptr graph, + std::vector& example_inputs) { + OptimizeFrozenGraph(graph, true); + RemoveListMutation(graph); + RemoveTensorMutation(graph); + LowerAllTuples(graph); + ConstantPropagation(graph); + graph = TraceGraph(graph, example_inputs); + + return graph; +} + +void XNNGraph::buildXNNGraph( + std::shared_ptr& graph, + std::vector example_inputs) { + graph = optimizeAndTraceGraph(graph, example_inputs); + checkOpsToDelegate(graph); + gatherTensorValues(graph); + + // count unique input/outputs (some inputs can be outputs) + std::unordered_set externals; + for (auto inp : _inputs) { + externals.insert(inp); + } + for (auto out : _outputs) { + externals.insert(out); + } + + // create subgraph + xnn_status status = xnn_create_subgraph( + /*external_value_ids=*/externals.size(), + /*flags=*/0, + &_subgraph_ptr); + TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph"); + + defineAllTensorValues(); + defineAllNodes(graph); + // at this point graph is complete, for the sake of testing preprocess at + // this point we will do runtime setup and run with some default values +} + +void XNNGraph::runGraphOnInputs( + std::vector tensor_inputs, + std::vector tensor_outputs) { + TORCH_CHECK( + _subgraph_ptr != nullptr, + "run buildXNNGraph before running graph on inputs"); + xnn_runtime_t runtime = nullptr; + xnn_status status = + xnn_create_runtime_v2(_subgraph_ptr, nullptr, /*flags=*/0, &runtime); + TORCH_CHECK( + xnn_status_success == status, + "failed to create runtime for running inputs"); + + // smart pointer for runtime + std::unique_ptr auto_runtime( + runtime, xnn_delete_runtime); + + std::vector external_values; + TORCH_CHECK( + tensor_inputs.size() == _inputs.size(), + "supplied inputs does not match expected inputs"); + for (int i = 0; i < tensor_inputs.size(); i++) { + external_values.push_back( + {_val_to_ids[_inputs[i]], tensor_inputs[i].data_ptr()}); + } + + TORCH_CHECK( + tensor_outputs.size() == _outputs.size(), + "supplied outputs does not match expected outputs"); + for (int i = 0; i < tensor_outputs.size(); i++) { + external_values.push_back( + {_val_to_ids[_outputs[i]], tensor_outputs[i].data_ptr()}); + } + status = xnn_setup_runtime( + auto_runtime.get(), external_values.size(), external_values.data()); + TORCH_CHECK(xnn_status_success == status, "runtime not properly setup"); + + TORCH_CHECK(xnn_status_success == xnn_invoke_runtime(auto_runtime.get())); +} + +void XNNGraph::checkOpsToDelegate(std::shared_ptr& graph) { + std::unordered_set unsupported_ops; + DepthFirstGraphNodeIterator it(graph); + Node* node = nullptr; + while ((node = it.next()) != nullptr) { + switch (node->kind()) { + case prim::Constant: + case aten::add: { + break; + } + default: { + unsupported_ops.insert(node->kind().toDisplayString()); + } + } + } + std::stringstream error; + for (auto itr = unsupported_ops.begin(); itr != unsupported_ops.end(); + itr++) { + error << *itr << std::endl; + ; + } + TORCH_CHECK( + unsupported_ops.empty(), + "the module contains the following unsupported ops:\n" + error.str()); +} + +std::string XNNGraph::serializedXNNGraph() { + std::vector input_ids; + std::vector output_ids; + std::unordered_set num_externs; + + for (auto val : _inputs) { + input_ids.push_back(_val_to_ids[val]); + num_externs.emplace(_val_to_ids[val]); + } + + for (auto val : _outputs) { + output_ids.push_back(_val_to_ids[val]); + num_externs.emplace(_val_to_ids[val]); + } + + return _serializer.finishAndSerialize( + input_ids, output_ids, num_externs.size()); +} + +std::vector> XNNGraph::getGraphOutputShapes() { + std::vector> output_shapes; + for (auto val : _outputs) { + auto tensor_ptr = val->type()->cast(); + std::vector sizes = tensor_ptr->sizes().concrete_sizes().value(); + output_shapes.push_back(sizes); + } + + return output_shapes; +} + +void XNNGraph::defineAllNodes(std::shared_ptr& graph) { + DepthFirstGraphNodeIterator it(graph); + Node* node = nullptr; + while ((node = it.next()) != nullptr) { + switch (node->kind()) { + case prim::Constant: { + break; + } + case aten::add: { + // todo: handle alpha for aten::add + uint32_t input1_id = _val_to_ids[node->inputs()[0]]; + uint32_t input2_id = _val_to_ids[node->inputs()[1]]; + TORCH_CHECK( + node->inputs()[2]->type()->cast() == 1, + "non-1 alpha values not supported"); + uint32_t output_id = _val_to_ids[node->outputs()[0]]; + + xnn_status status = xnn_define_add2( + _subgraph_ptr, + output_min, + output_max, + input1_id, + input2_id, + output_id, + /*flags=*/0); + _serializer.serializeAddNode(input1_id, input2_id, output_id, 0); + TORCH_CHECK(status == xnn_status_success, "failed to create add node"); + break; + } + default: { + throw std::exception(); + TORCH_CHECK( + false, + "The node of ", + node->kind().toQualString(), + " is not supported yet"); + break; + } + } + } +} + +void XNNGraph::defineAllTensorValues() { + uint32_t external_id = + std::numeric_limits::min(); + for (auto val : _intermediate_tensors) { + if (_val_to_ids.find(val) == _val_to_ids.end()) { + uint32_t id = XNN_INVALID_VALUE_ID; + + // cast value to tensortype + auto tensor_ptr = val->type()->cast(); + auto num_dims = tensor_ptr->dim().value(); + + // create size_t* for tensor shape, casting must be done from long -> + // size_t + std::vector sizes = tensor_ptr->sizes().concrete_sizes().value(); + std::vector tensor_shape; + tensor_shape.reserve(sizes.size()); + for (auto dim : sizes) { + TORCH_CHECK(dim >= 0, "Input Dims should be unsigned"); + tensor_shape.push_back(static_cast(dim)); + } + + // ext_id value + uint32_t ext_id = XNN_INVALID_VALUE_ID; + + // update flag for if tensor is either graph input/output + uint32_t flags = 0; + + // Check if value was produced by prim::Constant + void* value_data = nullptr; + size_t buffer_idx = 0; + size_t num_bytes = 0; + if (val->node()->kind() == prim::Constant) { + c10::optional constant = val->node()->t(attr::value); + auto const_val = constant->toIValue().toTensor(); + // Need tensor data to be contiguous for serialization + auto cont_const_val = const_val.contiguous(); + value_data = cont_const_val.data_ptr(); + + num_bytes = const_val.storage().nbytes(); + buffer_idx = _serializer.serializeData( + static_cast(value_data), num_bytes); + } + + if (isGraphInput(val) || isGraphOutput(val)) { + if (isGraphInput(val)) { + flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT; + } + if (isGraphOutput(val)) { + flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT; + } + ext_id = external_id++; + } + xnn_status status = xnn_define_tensor_value( + /*subgraph=*/_subgraph_ptr, + /*datatype=*/xnn_datatype_fp32, + /*num_dims=*/num_dims, + /*dims=*/tensor_shape.data(), + /*data=*/value_data, + /*external_id=*/ext_id, + /*flags=*/flags, + /*id_out=*/&id); + TORCH_CHECK( + status == xnn_status_success, + "failed to define xnn_tensor_id for: " + val->debugName()); + _serializer.serializeTensorValue( + xnn_datatype_fp32, + num_dims, + tensor_shape, + buffer_idx, + ext_id, + flags, + id); + _val_to_ids.insert({val, id}); + } + } +} + +void XNNGraph::gatherTensorValues(std::shared_ptr& graph) { + for (auto input : graph->inputs()) { + if (input->isCompleteTensor()) { + _intermediate_tensors.insert(input); + _inputs.push_back(input); + } + } + + DepthFirstGraphNodeIterator it(graph); + Node* n = nullptr; + while ((n = it.next()) != nullptr) { + gatherNodeInputs(*n); + } + + for (auto output : graph->outputs()) { + if (output->isCompleteTensor()) { + _intermediate_tensors.insert(output); + _outputs.push_back(output); + } + } +} + +void XNNGraph::gatherNodeInputs(torch::jit::Node& node) { + switch (node.kind()) { + case aten::add: { + // this case will support all ops with only two inputs i.e. sub, add, + for (auto value : node.inputs()) { + if (value->isCompleteTensor()) { + _intermediate_tensors.insert(value); + } + } + } + } +} + +bool XNNGraph::isGraphInput(torch::jit::Value* val) { + return std::count(_inputs.begin(), _inputs.end(), val) > 0; +}; + +bool XNNGraph::isGraphOutput(torch::jit::Value* val) { + return std::count(_outputs.begin(), _outputs.end(), val) > 0; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h new file mode 100644 index 0000000000000..0ef0757f23196 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h @@ -0,0 +1,93 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +#include +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNGraph { + private: + const float output_min = -std::numeric_limits::infinity(); + const float output_max = std::numeric_limits::infinity(); + + // serializer class + XNNSerializer _serializer; + // xnn subgraph + xnn_subgraph_t _subgraph_ptr; + // Set of all the tensor values throughout the jit graph + std::unordered_set _intermediate_tensors; + // Set of all the tensor values mapped to the xnnpack ids + std::unordered_map _val_to_ids; + // Vector containing the torch valued inputs/outputs, + // must be ordered to preserve the order of input/outputs + std::vector _inputs; + std::vector _outputs; + + // Graph passes for optimizing and tracing torchscript graph + // Essentially massaging the graph into a digestiable format for + // xnnpack graph lowering. + std::shared_ptr optimizeAndTraceGraph( + std::shared_ptr graph, + std::vector& example_inputs); + + // Gather all the intermediate tensor values within a graph. This + // skips through all prim constants. The purpose of this is for defining + // the tensor values beforehand for the xnnpack subgraph. + void gatherTensorValues(std::shared_ptr& graph); + + // Gathers the tensor values in a give node + void gatherNodeInputs(torch::jit::Node& node); + + // Helper function to determine if a jit value is a graph input + bool isGraphInput(torch::jit::Value* val); + + // Helper function to determine if a jit value is a graph output + bool isGraphOutput(torch::jit::Value* val); + + // Defines all xnnpack nodes for the nodes in the graph + void defineAllNodes(std::shared_ptr& graph); + + // Defines all xnn tensor values used throughout the graph + void defineAllTensorValues(); + + // Makes a pass through the graph and throws if any ops are unsupported + void checkOpsToDelegate(std::shared_ptr& graph); + + public: + XNNGraph() : _serializer(), _subgraph_ptr(nullptr) { + xnn_status status = xnn_initialize(/*allocator =*/nullptr); + TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack"); + } + + ~XNNGraph() { + xnn_deinitialize(); + if (_subgraph_ptr != nullptr) { + xnn_delete_subgraph(_subgraph_ptr); + } + } + + void buildXNNGraph( + std::shared_ptr& graph, + std::vector example_inputs); + + void runGraphOnInputs( + std::vector tensor_inputs, + std::vector tensor_outputs); + + std::string serializedXNNGraph(); + + std::vector> getGraphOutputShapes(); +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/README.md b/torch/csrc/jit/codegen/cuda/README.md index be8aed6c5ce44..284fd14111962 100644 --- a/torch/csrc/jit/codegen/cuda/README.md +++ b/torch/csrc/jit/codegen/cuda/README.md @@ -197,8 +197,8 @@ First thing is to check that you have fusion kernel running properly. Try to run If turning on NVFuser produces unexpected outputs, set the `PYTORCH_NVFUSER_DISABLE` environment variable to disable some of the optional features, e.g.: - `fma`: disable using FMA instructions -- `index_hoist`: disble optimization to hoist comon index expressions -- `predicate_elimination`: disble optimization to eliminate redundant predicates +- `index_hoist`: disable optimization to hoist common index expressions +- `predicate_elimination`: disable optimization to eliminate redundant predicates - `unroll_with_rng`: disable unrolling when RNG is used For example, `export PYTORCH_NVFUSER_DISABLE=fma,index_hoist` would disable FMA and index hoisting. diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index de282dfc8182a..d4e1348ee6933 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -449,10 +449,99 @@ TensorView* rand(const std::vector& shape, DataType dtype) { .contiguity(std::vector(n, true)) .shape(shape) .build(); - IrBuilder::create(RNGOpType::Uniform, out); + IrBuilder::create(RNGOpType::Uniform, out, dtype); return out; } +// TENSOR FACTORIES +TensorView* uniform( + const std::vector& shape, + Val* low, + Val* high, + DataType dtype) { + auto n = shape.size(); + auto out = TensorViewBuilder() + .ndims(n) + .dtype(dtype) + .contiguity(std::vector(n, true)) + .shape(shape) + .build(); + IrBuilder::create( + RNGOpType::UniformRange, out, dtype, std::vector{low, high}); + return out; +} + +TensorView* rand_like(TensorView* tv) { + TORCH_CHECK( + isFloatingPointType(tv->dtype()), + "input must have floating point type, but got ", + tv->dtype()); + std::vector shape; + auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); + shape.reserve(dom.size()); + for (auto id : dom) { + shape.emplace_back(id->getMaybeExpandedExtent()); + } + return rand(shape, tv->dtype()); +} + +Val* rand_like(Val* v) { + return rand_like(v->as()); +} + +TensorView* full( + const std::vector& shape, + Val* fill_value, + DataType dtype) { + auto n = shape.size(); + auto out = TensorViewBuilder() + .ndims(n) + .dtype(dtype) + .contiguity(std::vector(n, true)) + .shape(shape) + .build(); + IrBuilder::create(out, fill_value, dtype); + return out; +} + +TensorView* full_like(TensorView* tv, Val* fill_value) { + std::vector shape; + auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); + shape.reserve(dom.size()); + for (auto id : dom) { + shape.emplace_back(id->getMaybeExpandedExtent()); + } + return full(shape, fill_value, tv->dtype()); +} + +Val* full_like(Val* v, Val* fill_value) { + return full_like(v->as(), fill_value); +} + +TensorView* zeros(const std::vector& shape, DataType dtype) { + return full(shape, FusionGuard::getCurFusion()->zeroVal(), dtype); +} + +TensorView* zeros_like(TensorView* tv) { + return full_like(tv, FusionGuard::getCurFusion()->zeroVal()); +} + +Val* zeros_like(Val* v) { + return zeros_like(v->as()); +} + +TensorView* ones(const std::vector& shape, DataType dtype) { + return full(shape, FusionGuard::getCurFusion()->oneVal(), dtype); +} + +TensorView* ones_like(TensorView* tv) { + return full_like(tv, FusionGuard::getCurFusion()->oneVal()); +} + +Val* ones_like(Val* v) { + return ones_like(v->as()); +} + TensorView* arange(Val* end, DataType dtype) { return arange(FusionGuard::getCurFusion()->zeroVal(), end, dtype); } @@ -471,17 +560,36 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) { end = castOp(DataType::Double, end); step = castOp(DataType::Double, step); } - auto size = castOp(DataType::Int, ceilDiv(sub(end, start), step)); + // Make sure no negative value is passed to ceilDiv as the device + // implementation of ceilDiv assumes positive inputs + auto size = castOp(DataType::Int, ceilDiv(abs(sub(end, start)), abs(step))); auto out = TensorViewBuilder() .ndims(1) .dtype(dtype) .contiguity({true}) .shape({size}) .build(); - IrBuilder::create(out, start, end, step); + IrBuilder::create(out, start, end, step, dtype); + return out; +} + +TensorView* eye(Val* rows, Val* cols, DataType dtype) { + TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int"); + TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int"); + auto out = TensorViewBuilder() + .ndims(2) + .dtype(dtype) + .contiguity({true, true}) + .shape(std::vector{rows, cols}) + .build(); + IrBuilder::create(out, dtype); return out; } +TensorView* eye(Val* size, DataType dtype) { + return eye(size, size, dtype); +} + // UNARY OPERATIONS #define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \ @@ -504,23 +612,6 @@ NVFUSER_DEFINE_UNARY_OP(trunc, Trunc) NVFUSER_DEFINE_UNARY_OP(print, Print) #undef NVFUSER_DEFINE_UNARY_OP -TensorView* randlike(TensorView* v) { - TORCH_CHECK( - isFloatingPointType(v->dtype()), - "input must have floating point type, but got ", - v->dtype()); - std::vector shape; - shape.reserve(v->getMaybeRFactorDomain().size()); - for (auto id : v->getMaybeRFactorDomain()) { - shape.emplace_back(id->getMaybeExpandedExtent()); - } - return rand(shape, v->dtype()); -} - -Val* randlike(Val* v) { - return randlike(v->as()); -} - Val* bitwise_not(Val* v) { TORCH_CHECK( isIntegralType(v->dtype()) || v->dtype() == DataType::Bool, @@ -1003,7 +1094,7 @@ static TensorView* newForReduction( TORCH_INTERNAL_ASSERT( !axes_set.empty(), - "Asked for ouput of reduction, but no reduction axis provided."); + "Asked for output of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain.size(), @@ -1092,7 +1183,7 @@ TensorView* reductionOp( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, @@ -1427,7 +1518,7 @@ WelfordResult Welford( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, @@ -2137,7 +2228,7 @@ static TensorView* newForMma( TORCH_INTERNAL_ASSERT( !axes_set.empty(), - "Asked for ouput of reduction, but no reduction axis provided."); + "Asked for output of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain_a.size(), @@ -2228,7 +2319,7 @@ TensorView* fusedMultiplySum( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h index d8e6b65882146..66344c74880c0 100644 --- a/torch/csrc/jit/codegen/cuda/arith.h +++ b/torch/csrc/jit/codegen/cuda/arith.h @@ -121,10 +121,39 @@ TORCH_CUDA_CU_API WelfordResult Welford( // import IrBuilder just for this one interface. Int* init_N = nullptr); -// TENSOR FACTORIES +// RNG OPERATIONS TORCH_CUDA_CU_API TensorView* rand( const std::vector& shape, DataType dtype); +TORCH_CUDA_CU_API Val* rand_like(Val*); +TORCH_CUDA_CU_API TensorView* rand_like(TensorView*); + +TORCH_CUDA_CU_API TensorView* uniform( + const std::vector& shape, + Val* low, + Val* high, + DataType dtype); + +// TENSOR FACTORIES +TORCH_CUDA_CU_API TensorView* full( + const std::vector& shape, + Val* fill_value, + DataType dtype); +TORCH_CUDA_CU_API TensorView* full_like(TensorView* tv, Val* fill_value); +TORCH_CUDA_CU_API Val* full_like(Val* tv, Val* fill_value); +TORCH_CUDA_CU_API TensorView* zeros( + const std::vector& shape, + DataType dtype); +TORCH_CUDA_CU_API TensorView* zeros_like(TensorView*); +TORCH_CUDA_CU_API Val* zeros_like(Val*); +TORCH_CUDA_CU_API TensorView* ones( + const std::vector& shape, + DataType dtype); +TORCH_CUDA_CU_API TensorView* ones_like(TensorView*); +TORCH_CUDA_CU_API Val* ones_like(Val*); +//! WARNING: giving invalid combinations of the start, end and step +//! arguments can result in undefined behavior. Specifically, the +//! signs of `end - start` and step must be the same. TORCH_CUDA_CU_API TensorView* arange(Val* end, DataType dtype = DataType::Int); TORCH_CUDA_CU_API TensorView* arange( Val* start, @@ -135,6 +164,8 @@ TORCH_CUDA_CU_API TensorView* arange( Val* end, Val* step, DataType dtype = DataType::Int); +TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype); +TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype); // UNARY OPERATIONS // abs @@ -200,9 +231,6 @@ TORCH_CUDA_CU_API TensorView* log2(TensorView*); // neg TORCH_CUDA_CU_API Val* neg(Val*); TORCH_CUDA_CU_API TensorView* neg(TensorView*); -// randlike -TORCH_CUDA_CU_API Val* randlike(Val*); -TORCH_CUDA_CU_API TensorView* randlike(TensorView*); // real TORCH_CUDA_CU_API Val* real(Val*); TORCH_CUDA_CU_API TensorView* real(TensorView*); diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp index 6ebb2753ecb8a..e62528fdabc3e 100644 --- a/torch/csrc/jit/codegen/cuda/codegen.cpp +++ b/torch/csrc/jit/codegen/cuda/codegen.cpp @@ -264,9 +264,6 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << " static_cast(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n"; indent() << " philox_args.offset_.val;\n"; - indent() << "auto seed = philox_args.captured_ ?\n"; - indent() - << " static_cast(*(philox_args.seed_.ptr)) : philox_args.seed_.val;\n"; indent() << "uint4 rng_result;\n"; indent() << "nvfuser_index_t rng_subseq = -1;\n"; indent() << "nvfuser_index_t rng_offset = -1;\n"; @@ -546,9 +543,18 @@ class CudaKernelGenerator : private OptOutConstDispatch { void genCpAsync(const LoadStoreOp* ldst, int vec_size) { auto dtype = ldst->in()->getDataType().value(); - indent() << "Ampere::cpAsync(" - << genVectorPointer(ldst->out(), dtype, vec_size) << "," - << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n"; + if (ldst->predicate() == nullptr) { + // Out of line predicate variant + indent() << "Ampere::cpAsync(" + << genVectorPointer(ldst->out(), dtype, vec_size) << "," + << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n"; + } else { + // Inline predicate variant + indent() << "Ampere::cpAsync(" + << genVectorPointer(ldst->out(), dtype, vec_size) << "," + << genVectorPointer(ldst->in(), dtype, vec_size) << "," + << genInline(ldst->predicate()) << ");\n"; + } } void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) { @@ -563,14 +569,26 @@ class CudaKernelGenerator : private OptOutConstDispatch { << "&" << gen(ldst->in()) << ");\n"; } + void handle(const FullOp* fop) final { + indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")" + << gen(fop->getFillValue()) << ";\n"; + } + void handle(const ARangeOp* aop) final { - auto index = genTensorIndex(aop->getLinearIndex()->as()); - indent() << gen(aop->output(0)) << " = arange<" << aop->output(0)->dtype() - << ">"; + auto index = + genTensorIndex(aop->getLinearLogicalIndex()->as()); + indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">"; code_ << "(" << index << ", " << gen(aop->start()) << ", " << gen(aop->step()) << ");\n"; } + void handle(const EyeOp* aop) final { + auto index1 = gen(aop->getIndex1()); + auto index2 = gen(aop->getIndex2()); + indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")"; + code_ << "(" << index1 << " == " << index2 << ");\n"; + } + void handle(const UnaryOp* uop) final { bool is_vector_op = false; size_t vector_word_size = 1; @@ -762,9 +780,8 @@ class CudaKernelGenerator : private OptOutConstDispatch { void handle(const RNGOp* rop) final { // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an // innermost ID of size 4 (float) or size 2 (double)? - auto out_tv = rop->output(0)->as()->view(); auto index = genTensorIndex(rop->getPhiloxIndex()->as()); - int multiple = out_tv->getDataType() == DataType::Double ? 2 : 4; + int multiple = rop->dtype() == DataType::Double ? 2 : 4; indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index << ";\n"; indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index" @@ -775,6 +792,9 @@ class CudaKernelGenerator : private OptOutConstDispatch { << rop->getRNGOffset() << ";\n"; indent() << "if (rng_subseq != rng_subseq" << rop->name() << " || rng_offset != rng_offset" << rop->name() << ") {\n"; + indent() << " auto seed = philox_args.captured_ ?\n" + << " static_cast(*(philox_args.seed_.ptr)) : \n" + << " philox_args.seed_.val;\n"; indent() << " rng_result = philox(seed, rng_subseq" << rop->name() << ", philox_offset / 4 + rng_offset" << rop->name() << ");\n"; indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n"; @@ -782,11 +802,20 @@ class CudaKernelGenerator : private OptOutConstDispatch { indent() << "}\n"; auto op_type = rop->getRNGOpType(); indent() << gen(rop->output(0)) << " = " << op_type; - if (needFloatSuffix(op_type) && - rop->output(0)->dtype() == DataType::Float) { + if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) { code_ << "f"; } - code_ << "(rng_result, rng_component" << rop->name() << ");\n"; + code_ << "(rng_result, rng_component" << rop->name(); + switch (op_type) { + case RNGOpType::UniformRange: { + auto parameters = rop->getParameters(); + TORCH_INTERNAL_ASSERT(parameters.size() == 2); + code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]); + break; + } + default:; + } + code_ << ");\n"; } std::string genBinaryOp( diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp index ae6231614b7ff..d8f950848f8fc 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp @@ -213,20 +213,21 @@ void ComputeAt::runAt( auto selected = getPropagationSubgraph(producer, consumer); ComputeAtSelector selector(selected); - InlinePropagator inline_propagator( - consumer, consumer_position, mode, selector.selected()); - MaxRootDomainInfoSpanningTree path(consumer, consumer_position, &selector); if (mode == ComputeAtMode::MostInlined) { MostInlinedTransformPropagator propagator; path.traverse(&propagator); + inlineMost(selected); } else { TransformPropagator propagator(consumer, consumer_position); path.traverse(&propagator); + inlineSelectedAt( + selected, + consumer, + consumer_position, + mode == ComputeAtMode::BestEffort); } - - path.traverse(&inline_propagator); } void ComputeAt::runWith( @@ -253,19 +254,21 @@ void ComputeAt::runWith( auto selected = getPropagationSubgraph(producer, consumer); ComputeAtSelector selector(selected); - InlinePropagator inline_propagator( - producer, producer_position, mode, selector.selected()); - MaxRootDomainInfoSpanningTree path(producer, producer_position, &selector); if (mode == ComputeAtMode::MostInlined) { MostInlinedTransformPropagator propagator; path.traverse(&propagator); + inlineMost(selected); } else { TransformPropagator propagator(producer, producer_position); path.traverse(&propagator); + inlineSelectedAt( + selected, + producer, + producer_position, + mode == ComputeAtMode::BestEffort); } - path.traverse(&inline_propagator); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/compute_at.h b/torch/csrc/jit/codegen/cuda/compute_at.h index 98100334d72b6..d3d3fdb299dd6 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at.h +++ b/torch/csrc/jit/codegen/cuda/compute_at.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index e223c0ce51646..1c2ac627b5756 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -6,6 +6,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -29,8 +31,22 @@ bool idIsALeafDomain(IterDomain* id, TensorView* tv) { } // namespace -IterDomainGraph::IterDomainGraph(Fusion* fusion) { +IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { build(fusion); + + if (!allow_self_mapping) { + TORCH_INTERNAL_ASSERT( + !hasSelfMapping(), + "Unsupported domain mapping detected in ", + std::get<0>(*self_mapping_info_)->toString(), + ". ", + std::get<3>(*self_mapping_info_), + " domains, ", + std::get<1>(*self_mapping_info_)->toString(), + " and ", + std::get<2>(*self_mapping_info_)->toString(), + ", are mapped with each other."); + } } //! Map corresponding inputs and outputs of swizzle op together @@ -55,7 +71,11 @@ void mapMaybeSwizzleOp( } } -bool IterDomainGraph::exprsMap(Expr* first, Expr* second, bool forward) { +bool IterDomainGraph::exprsMap( + Expr* first, + Expr* second, + bool forward, + const DisjointSets& id_map) { if (first == nullptr || second == nullptr) { return false; } @@ -101,8 +121,7 @@ bool IterDomainGraph::exprsMap(Expr* first, Expr* second, bool forward) { zipped_ids.begin(), zipped_ids.end(), [&](std::pair id_pair) { - return !exact_nodes_.strictAreMapped( - id_pair.first, id_pair.second); + return !id_map.strictAreMapped(id_pair.first, id_pair.second); })) { return false; } @@ -151,7 +170,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return; } - if (!exprsMap(first, second, forward)) { + if (!exprsMap(first, second, forward, exact_nodes_)) { return; } @@ -173,6 +192,78 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { } } +namespace { + +// Returns a pair of mapped IDs +c10::optional> detectMappablePair( + const std::vector& ids, + const IterDomainGraph& id_graph) { + for (auto id1 : ids) { + for (auto id2 : ids) { + if (id1 == id2) { + continue; + } + if (id_graph.permissiveNodes().disjointSetMap().at(id1)->has(id2)) { + return std::make_pair(id1, id2); + } + } + } + + return {}; +} + +// It is assumed that for any tensor represented by a list of domains, +// those domains should never be mapped with each other. It may be +// possible to lift this assumption, but it's unclear if it could +// matter in practice. +c10::optional> +findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) { + for (auto tv : ir_utils::allTvs(fusion)) { + // For each tensor, make sure root, rfactor and leaf domains + // should not include domains that are mapped with another domain + // in the same set of domains. This may be overly conservative, + // and it maybe enough to check the root domains. + + // Root domains + auto self_mappped_root_pair = + detectMappablePair(tv->getRootDomain(), id_graph); + if (self_mappped_root_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_root_pair->first, + self_mappped_root_pair->second, + "Root"); + } + + // Rfactor domains + if (tv->hasRFactor()) { + auto self_mappped_rf_pair = + detectMappablePair(tv->getRFactorDomain(), id_graph); + if (self_mappped_rf_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_rf_pair->first, + self_mappped_rf_pair->second, + "RFactor"); + } + } + + // Leaf domains + auto self_mappped_leaf_pair = + detectMappablePair(tv->domain()->domain(), id_graph); + if (self_mappped_leaf_pair.has_value()) { + return std::make_tuple( + tv, + self_mappped_leaf_pair->first, + self_mappped_leaf_pair->second, + "Leaf"); + } + } + return c10::nullopt; +} + +} // namespace + void IterDomainGraph::build(Fusion* fusion) { FusionGuard fg(fusion); @@ -240,7 +331,7 @@ void IterDomainGraph::build(Fusion* fusion) { c_tv->getRootDomain().size() == first_output_tv->getRootDomain().size(), "Multiple outputs with mismatched dimensions is not supported. ", - "Only supported case is welford op where all outputs tvs have idential domains."); + "Only supported case is welford op where all outputs tvs have identical domains."); // p->f, c->c std::unordered_map c2f_root_map; for (const auto i : @@ -515,6 +606,7 @@ void IterDomainGraph::build(Fusion* fusion) { } } } + self_mapping_info_ = findFirstSelfMapping(fusion, *this); } void IterDomainGraph::initializeId( @@ -587,7 +679,7 @@ void ComputeAtMap::allocateIndexVariables() { // Halo extended parallel loops currently are handled // differently and an index variable would still // be allocated in this case. - (GpuLower::current()->haloInfo().getExtent(id) == nullptr)) { + (GpuLower::current()->haloInfo()->getExtent(id) == nullptr)) { ptype = id->getParallelType(); return true; } diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index 31c2d8752f712..5ea92dff16447 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -54,7 +54,7 @@ namespace cuda { // Do not forward through any broadcast IDs class TORCH_CUDA_CU_API IterDomainGraph { public: - IterDomainGraph(Fusion* fusion); + IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false); const DisjointSets& permissiveNodes() const { return permissive_nodes_; @@ -88,15 +88,25 @@ class TORCH_CUDA_CU_API IterDomainGraph { return view_rfactor_ids_; } + // Returns if first and second are expressions through which the provided + // id_map have matching inputs (if forward), or outputs (if not forward). + // Returning true means the expressions are "the same", in terms they modify + // matching original extents, by the same amount. + static bool exprsMap( + Expr* first, + Expr* second, + bool forward, + const DisjointSets& id_map); + + bool hasSelfMapping() const { + return self_mapping_info_.has_value(); + } + private: void build(Fusion* fusion); void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id); - // Returns if first and second are expressions with inputs match through exact - // map (if forward), or outputs match (if not forward). - bool exprsMap(Expr* first, Expr* second, bool forward); - // Checks if exprsMap then if forward will map outputs else inputs in exact // and permissive map. void mapThroughExpr(Expr* first, Expr* second, bool forward); @@ -116,6 +126,9 @@ class TORCH_CUDA_CU_API IterDomainGraph { VectorOfUniqueEntries all_ids_; std::unordered_set view_rfactor_ids_; + + c10::optional> + self_mapping_info_ = c10::nullopt; }; class TrivialReductionInfo; diff --git a/torch/csrc/jit/codegen/cuda/contiguity.cpp b/torch/csrc/jit/codegen/cuda/contiguity.cpp index 4817693bebdc3..dcb39d948c672 100644 --- a/torch/csrc/jit/codegen/cuda/contiguity.cpp +++ b/torch/csrc/jit/codegen/cuda/contiguity.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -8,20 +9,454 @@ namespace jit { namespace fuser { namespace cuda { +OrderedIdInformation::OrderedIdInformation( + const std::vector& ids, + const std::vector& root_domain, + std::shared_ptr concrete_info) + : active_ids_(root_domain), concrete_info_(concrete_info) { + if (ids.empty() || root_domain.empty()) { + return; + } + + // Grab root ids and initialize them. + for (const auto root_i : c10::irange(root_domain.size())) { + auto root_id = root_domain[root_i]->as(); + + // Initialize id_to_root_ids to map roots to themselves + id_to_root_ids_[root_id] = {root_id}; + + // Initialize roots as being made up of correctly ordered transforms. + consistently_ordered_ids_.emplace(root_id); + + exclusively_consumes_roots_.emplace(root_id); + } + + // Iterate from the root domain to the provided ids and fill + // consistently_ordered_ids_, id_to_root_ids_, and exclusively_consumes_roots_ + // for all the IDs + auto exprs = StmtSort::getExprsBetween( + ids[0]->fusion(), + {root_domain.begin(), root_domain.end()}, + {ids.begin(), ids.end()}); + + for (auto expr : exprs) { + OptInDispatch::handle(expr); + } +} + +bool OrderedIdInformation::checkExclusivelyConsumesRoots(IterDomain* id) { + TORCH_INTERNAL_ASSERT( + std::find(active_ids_.begin(), active_ids_.end(), id) != + active_ids_.end(), + "Error replaying transforms in contiguous ID checker, expected ", + id->toString(), + " to be in the active ID set."); + + auto root_id_it = id_to_root_ids_.find(id); + TORCH_INTERNAL_ASSERT( + root_id_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker, couldn't find mapped roots of ", + id->toString()); + + const auto& root_ids = root_id_it->second; + + // Check all the roots of all other ids, to see if any root_ids in id are also + // in them. + for (auto other_active_id : active_ids_) { + if (other_active_id == id || other_active_id == nullptr) { + continue; + } + + auto root_id_it = id_to_root_ids_.find(other_active_id); + TORCH_INTERNAL_ASSERT( + root_id_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker, couldn't find mapped roots of ", + other_active_id->toString()); + + const auto& other_root_ids = root_id_it->second; + + for (auto other_root_id : other_root_ids) { + if (root_ids.has(other_root_id)) { + return false; + } + } + } + return true; +} + +void OrderedIdInformation::handle(Merge* merge) { + // Find inputs in the active_ids_ vector + const auto inner_it = + std::find(active_ids_.begin(), active_ids_.end(), merge->inner()); + const auto outer_it = + std::find(active_ids_.begin(), active_ids_.end(), merge->outer()); + + // If either aren't in active_ids_ it means the inputs were detected to not be + // ordered correctly before hitting this expression. + if (inner_it == active_ids_.end() || outer_it == active_ids_.end()) { + return; + } + + auto inner_pos = std::distance(active_ids_.begin(), inner_it); + auto outer_pos = std::distance(active_ids_.begin(), outer_it); + + // Find inputs in the ordered transforms map + const auto inner_ordered_it = consistently_ordered_ids_.find(merge->inner()); + const auto outer_ordered_it = consistently_ordered_ids_.find(merge->outer()); + + bool inner_ordered = inner_ordered_it != consistently_ordered_ids_.end(); + bool outer_ordered = outer_ordered_it != consistently_ordered_ids_.end(); + + // Get root ids of the two inputs + const auto inner_root_ids_it = id_to_root_ids_.find(merge->inner()); + const auto outer_root_ids_it = id_to_root_ids_.find(merge->outer()); + + TORCH_INTERNAL_ASSERT( + inner_root_ids_it != id_to_root_ids_.end() && + outer_root_ids_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker."); + + const auto& inner_root_ids = inner_root_ids_it->second; + const auto& outer_root_ids = outer_root_ids_it->second; + + // TODO: Concretization may prevent contiguous indexing or vectorization. + // It prevents contiguous indexing if the concretization is within the IDs + // that are used for indexing. + // For vectorization it just means we need to make sure the extents of the + // axes to the right of the broadcast root domain in the contigous merge is + // bigger than the vectorization dimension. And that the tensor buffer + // supports the vector word size (always done). + bool outer_is_concretized_bcast = merge->outer()->isBroadcast() && + concrete_info_->isConcretized(merge->outer()); + + bool inner_is_concretized_bcast = merge->inner()->isBroadcast() && + concrete_info_->isConcretized(merge->inner()); + + // Update maps + // Find the position inner would have to have to be considered ordered + auto pos_after_outer = outer_pos + 1; + for (; pos_after_outer < active_ids_.size(); pos_after_outer++) { + if (active_ids_[pos_after_outer] == nullptr) { + // Can't be considered ordered after a nullptr + break; + } + if (active_ids_[pos_after_outer]->isReduction() || + ((active_ids_[pos_after_outer]->isBroadcast() && + !concrete_info_->isConcretized(active_ids_[pos_after_outer])))) { + // Skip reduction or broadcast axes that aren't concretized in the fusion + continue; + } + break; + } + + // The output is ordered as long as the inputs were ordered and outer position + // is directly left of the inner position. + bool out_ordered = inner_ordered && outer_ordered; + out_ordered = out_ordered && + // If inner_pos is before outer_pos it's not ordered correctly. If for + // some reason it's the same, that would be an error. + inner_pos > outer_pos && + // Inner could be a broadcast, so doesn't have to be right on + // pos_after_outer as that ID (if it exists) should not be a broadcast. + // However, merging over a broadcast should be fine. + inner_pos <= pos_after_outer && !inner_is_concretized_bcast && + !outer_is_concretized_bcast; + + if (out_ordered) { + consistently_ordered_ids_.emplace(merge->out()); + } + + // Don't just remove active_ids_, as if we have something like: + // [i0, i1, i2, i3] + // ->merge(0, 2) + // ->merge(1) + // The latter merge looks like it's ordered correctly, if we update the active + // map as: + // [i0, i1, i2, i3] -> [i0*i2, i1, i3] + // Hoever if we instead mark it as: + // [i0, i1, i2, i3] -> [i0*i2, i1, nullptr, i3] + // Or: + // [i0, i1, i2, i3] -> [nullptr, i1, i0*i2, i3] + // It's clear the second merge is not ordered correctly. Doesn't matter which + // direction we put the iter domain in, prefer putting it in outer as we often + // are looking for inner dimensions that are contiguous. We don't want to + // always do this, as it could make ordered merges look non-ordered. + // For exmaple: [i0, i1, i2, i3] + // ->merge(0) + // ->merge(1) + // ->merge(0) + // If it's updated as: + // [i0, i1, i2, i3] + // -> [i0*i1, nullptr, i2, i3] + // -> [i0*i1, nullptr, i2*i3, nullptr] + // Now the final merge looks non-ordered but it is. So only insert a nullptr + // entry if the out is not ordered. + active_ids_[outer_pos] = merge->out(); + + if (!out_ordered) { + active_ids_[inner_pos] = nullptr; + } else { + active_ids_.erase(active_ids_.begin() + inner_pos); + for (auto i = outer_pos + 1; i < inner_pos; i++) { + // If there's broadcast axes between outer and inner and the merge was + // contiguous, there may be broadcasts between outer and inner that cannot + // be ordered merged anywhere else so remove them. + active_ids_.erase(active_ids_.begin() + outer_pos + 1); + } + } + + // Update the root_id entry for the output. + VectorOfUniqueEntries root_ids = inner_root_ids; + root_ids.pushBack(outer_root_ids); + + id_to_root_ids_[merge->out()] = root_ids; + + // Need to check this after updating active_ids_ and id_to_root_ids_ + if (checkExclusivelyConsumesRoots(merge->out())) { + exclusively_consumes_roots_.emplace(merge->out()); + } +} + +void OrderedIdInformation::handle(Split* split) { + // Find the input in the active_ids_ vector + const auto in_it = + std::find(active_ids_.begin(), active_ids_.end(), split->in()); + + if (in_it == active_ids_.end()) { + return; + } + + auto in_pos = std::distance(active_ids_.begin(), in_it); + + // Find the input in the ordered transforms map + const auto in_ordered_it = consistently_ordered_ids_.find(split->in()); + + bool in_ordered = in_ordered_it != consistently_ordered_ids_.end(); + + // Get root ids of the input + const auto in_root_ids_it = id_to_root_ids_.find(split->in()); + + TORCH_INTERNAL_ASSERT( + in_root_ids_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker."); + + VectorOfUniqueEntries in_root_ids = in_root_ids_it->second; + + // Update map for outputs + // Remove inputs from the active_ids_ and insert the output ID + active_ids_[in_pos] = split->outer(); + active_ids_.insert(active_ids_.begin() + in_pos + 1, split->inner()); + + // The outputs are ordered as long as the input is ordered. + if (in_ordered) { + consistently_ordered_ids_.emplace(split->outer()); + consistently_ordered_ids_.emplace(split->inner()); + } + + // Update the root_id entry for the outputs. + id_to_root_ids_[split->outer()] = in_root_ids; + id_to_root_ids_[split->inner()] = in_root_ids; +} + +// Swizzle generally can't be contiguous because of the non-affine nature of it, +// but we can still analyze the operation in the same way as merge/split. +void OrderedIdInformation::handle(Swizzle2D* swizzle) { + // Find inputs in the active_ids_ vector + const auto in_x_it = + std::find(active_ids_.begin(), active_ids_.end(), swizzle->inX()); + const auto in_y_it = + std::find(active_ids_.begin(), active_ids_.end(), swizzle->inY()); + + if (in_x_it == active_ids_.end() || in_y_it == active_ids_.end()) { + return; + } + + auto in_x_pos = std::distance(active_ids_.begin(), in_x_it); + auto in_y_pos = std::distance(active_ids_.begin(), in_y_it); + + // Find inputs in the ordered transforms map + const auto in_x_ordered_it = consistently_ordered_ids_.find(swizzle->inX()); + const auto in_y_ordered_it = consistently_ordered_ids_.find(swizzle->inY()); + + bool in_x_ordered = in_x_ordered_it != consistently_ordered_ids_.end(); + bool in_y_ordered = in_y_ordered_it != consistently_ordered_ids_.end(); + + // Get root ids of the two inputs + const auto in_x_root_ids_it = id_to_root_ids_.find(swizzle->inX()); + const auto in_y_root_ids_it = id_to_root_ids_.find(swizzle->inY()); + + TORCH_INTERNAL_ASSERT( + in_x_root_ids_it != id_to_root_ids_.end() && + in_y_root_ids_it != id_to_root_ids_.end(), + "Error replaying transforms in contiguous ID checker."); + + const auto& in_x_root_ids = in_x_root_ids_it->second; + const auto& in_y_root_ids = in_y_root_ids_it->second; + + // Update map for outputs + // Remove inputs from the active_ids_ and insert the output ID + active_ids_[in_x_pos] = swizzle->outX(); + active_ids_[in_y_pos] = swizzle->outY(); + + // In the case of no real swizzle we can forward properties on each domain + // independently. + if (swizzle->swizzleType() == Swizzle2DType::NoSwizzle) { + if (in_x_ordered) { + consistently_ordered_ids_.emplace(swizzle->outX()); + } + + if (exclusivelyConsumesRoots(swizzle->inX())) { + exclusively_consumes_roots_.emplace(swizzle->outX()); + } + + if (in_y_ordered) { + consistently_ordered_ids_.emplace(swizzle->outY()); + } + + if (exclusivelyConsumesRoots(swizzle->inY())) { + exclusively_consumes_roots_.emplace(swizzle->outY()); + } + + id_to_root_ids_[swizzle->outX()] = in_x_root_ids; + id_to_root_ids_[swizzle->outY()] = in_y_root_ids; + } else { + VectorOfUniqueEntries root_ids = in_x_root_ids; + root_ids.pushBack(in_y_root_ids); + id_to_root_ids_[swizzle->outX()] = root_ids; + id_to_root_ids_[swizzle->outY()] = root_ids; + } +} + +NonDivisibleSplitDependencies::NonDivisibleSplitDependencies( + // TODO: Revisit reduction rfactor axes and propagation. Should probably use + // ca_map to propogate non divisibility dependencies across exact map. Still + // need to think through divisible split and non divisible dependencies to + // see if there's conflicts where a split might look non divisible but + // actually is divisible and one's overruling the other. + const std::vector& ids, + const std::vector& root_domain, + const std::unordered_set& divisible_splits) { + if (ids.empty() || root_domain.empty()) { + return; + } + auto transforms = StmtSort::getExprsBetween( + ids[0]->fusion(), + {root_domain.begin(), root_domain.end()}, + {ids.begin(), ids.end()}); + for (auto transform : transforms) { + auto inp_ids = ir_utils::filterByType(transform->inputs()); + for (auto inp_id : inp_ids) { + if (std::find(root_domain.begin(), root_domain.end(), inp_id) != + root_domain.end()) { + // This generally shouldn't happen as there shouldn't be + // transformations before the root ids, but in case for some reason + // we eventually do have cases like that, we should reset the + // root_ids if for some reason they've been placed in the non + // divisible split set. + depends_on_non_divisible_split.erase(inp_id); + } + } + + bool inputs_non_divisible = + std::any_of(inp_ids.begin(), inp_ids.end(), [this](IterDomain* inp_id) { + return depends_on_non_divisible_split.find(inp_id) != + depends_on_non_divisible_split.end(); + }); + + auto out_ids = ir_utils::filterByType(transform->outputs()); + + if (inputs_non_divisible) { + // If any inputs are known to be dependent on a divisible split + // Mark outputs as dependent on a non_divisible split + depends_on_non_divisible_split.insert(out_ids.begin(), out_ids.end()); + continue; + } + + if (!transform->isA()) { + continue; + } + + auto split = transform->as(); + // If this transform is a non-divisible split + if (divisible_splits.find(split) == divisible_splits.end()) { + // Mark outputs as dependent on a non_divisible split + auto out_ids = ir_utils::filterByType(transform->outputs()); + depends_on_non_divisible_split.insert(out_ids.begin(), out_ids.end()); + } + } +} + +ContigIDs::ContigIDs( + const std::vector& ids, + const std::vector& root_domain, + const std::vector& root_contiguity, + const std::unordered_set& final_ids, + const std::unordered_map& index_map, + const std::unordered_set& divisible_splits, + std::unordered_map p2c_id_map, + bool ignore_indexability, + bool ignore_consistent_ordering) + : root_domain_(root_domain), + root_contiguity_(root_contiguity), + final_ids_(final_ids), + index_map_(index_map), + divisible_splits_(divisible_splits), + p2c_id_map_(std::move(p2c_id_map)), + ignore_indexability_(ignore_indexability), + ignore_consistent_ordering_(ignore_consistent_ordering), + non_divisible_id_info_(ids, root_domain_, divisible_splits_) { + if (ids.size() > 0) { + // This constructor doesn't provide the following information so it needs to + // be built. + ca_map_ = std::make_shared(ids[0]->fusion()); + halo_info_ = std::make_shared(ids[0]->fusion(), ca_map_); + concrete_info_ = + std::make_shared(ids[0]->fusion()); + + consistent_transform_info_ = std::make_unique( + ids, root_domain, concrete_info_); + } + build(ids); +} + ContigIDs::ContigIDs( const std::vector& ids, const std::vector& root_domain, const std::vector& root_contiguity, - std::unordered_map concrete_to_ref, + const std::unordered_set& final_ids, + const std::unordered_map& index_map, + const std::unordered_set& divisible_splits, + std::shared_ptr ca_map, + std::shared_ptr halo_info, + std::shared_ptr concrete_info, std::unordered_map p2c_id_map, - bool ignore_halo_constraint, - bool ignore_indexability) + bool ignore_indexability, + bool ignore_consistent_ordering) : root_domain_(root_domain), root_contiguity_(root_contiguity), - concrete_to_ref_(std::move(concrete_to_ref)), + final_ids_(final_ids), + index_map_(index_map), + divisible_splits_(divisible_splits), + ca_map_(ca_map), + halo_info_(halo_info), + concrete_info_(concrete_info), p2c_id_map_(std::move(p2c_id_map)), - ignore_indexability_(ignore_indexability) { - if (ids.empty()) { + ignore_indexability_(ignore_indexability), + ignore_consistent_ordering_(ignore_consistent_ordering), + consistent_transform_info_(std::make_unique( + ids, + root_domain, + concrete_info_)), + non_divisible_id_info_(ids, root_domain, divisible_splits_) { + build(ids); +} + +ContigIDs ContigIDs::getNonContigIDs() { + return ContigIDs({}, {}, {}, {}, {}, {}); +} + +void ContigIDs::build(const std::vector& ids) { + if (ids.empty() || root_domain_.empty()) { return; } @@ -32,35 +467,29 @@ ContigIDs::ContigIDs( " != ", root_contiguity_.size()); - // GpuLower is required to honor halo constraints - if (!ignore_halo_constraint) { - TORCH_INTERNAL_ASSERT(GpuLower::hasCurrent(), "GpuLower not found"); - } - - for (const auto i : c10::irange(root_domain_.size())) { - auto root_domain_i = root_domain_[i]->as(); - root_to_indexed_id_[root_domain_i] = root_domain_i; + for (const auto root_domain_i : c10::irange(root_domain_.size())) { + auto root_domain_id = root_domain_[root_domain_i]->as(); + root_to_indexed_id_[root_domain_id] = root_domain_id; // Initialize to false - is_contig_root_[root_domain_i] = false; + is_contig_root_[root_domain_id] = false; // If a root domain has halo, can't use merged domain even if // both inputs are contiguous. HaloInfo is also initialized for // rfactor root domains, which should just return "zero" // RootAxisInfo. This should be safe as no rfactor tensor should // need halo. - if (root_contiguity_[i] && - (ignore_halo_constraint || - !GpuLower::current() - ->haloInfo() - .getRootAxisInfo(root_domain_i) - .hasHalo())) { - contig_ids_.emplace(root_domain_i); - is_contig_root_[root_domain_i] = true; - within_contig_ids_[root_domain_i] = std::unordered_set(); + if (root_contiguity_[root_domain_i] && + !halo_info_->getRootAxisInfo(root_domain_id).hasHalo()) { + contig_ids_.emplace(root_domain_id); + is_contig_root_[root_domain_id] = true; + within_contig_ids_[root_domain_id] = std::unordered_set(); } } if (!contig_ids_.empty()) { - auto exprs = StmtSort::getExprs(ids[0]->fusion(), {ids.begin(), ids.end()}); + auto exprs = StmtSort::getExprsBetween( + ids[0]->fusion(), + {root_domain_.begin(), root_domain_.end()}, + {ids.begin(), ids.end()}); for (auto expr : exprs) { handle(expr); } @@ -68,114 +497,99 @@ ContigIDs::ContigIDs( } void ContigIDs::handle(Merge* merge) { - // If either input is non-contiguous so is output. - const auto inner = merge->inner(); - const auto outer = merge->outer(); - const auto out = merge->out(); + // If output is not consistently ordered or doesn't solely consume all root + // domains in its dependencies, then it can't be a contiguously indexable + // iterdomain. + if (!(ignore_consistent_ordering_ || + consistent_transform_info_->isConsistentlyOrdered(merge->out()))) { + return; + } - if (!isContig(inner) || !isContig(outer)) { + if (!consistent_transform_info_->exclusivelyConsumesRoots(merge->out())) { return; } - // Stop contig merging if the merge output is not indexable. - if (!ignore_indexability_ && !isIndexable(out)) { + // If output is not "directly indexable" then it's definitely not contiguously + // indexable. + if (!ignore_indexability_ && !isIndexable(merge->out())) { return; } - // Grab inputs, make sure they're in root domain, check if they're - // contiguous. + // If inputs are marked as final, stop + if (final_ids_.count(merge->inner()) || final_ids_.count(merge->outer())) { + return; + } - auto lhs_inputs = - ir_utils::iterDomainInputsOfOrderedAs({outer}, root_domain_); - auto rhs_inputs = - ir_utils::iterDomainInputsOfOrderedAs({inner}, root_domain_); + // Check root domains for contiguity + auto root_ids_it = + consistent_transform_info_->idToRootIds().find(merge->out()); TORCH_INTERNAL_ASSERT( - inRoot(lhs_inputs) && inRoot(rhs_inputs), - "Found an invalid merge operation, inputs of its arguments are not in the root domain."); - - std::deque ordered_inputs(lhs_inputs.begin(), lhs_inputs.end()); - ordered_inputs.insert( - ordered_inputs.end(), rhs_inputs.begin(), rhs_inputs.end()); - - // If any root input is not contig, output is not contig - if (!(std::all_of( - ordered_inputs.begin(), ordered_inputs.end(), [this](IterDomain* id) { - // Allow reduction tensors in contiguity check since we're using - // this to check contiguous vectors of reference tensors in - // schedulers (to set vectorization sizes), those reference tensors - // may have reduction dims, don't bail on contiguity just because - // it's a reduction dimension. - return is_contig_root_.at(id); - }))) { - return; - } + root_ids_it != consistent_transform_info_->idToRootIds().end(), + "\nError in contiguous analysis, merge info doesn't exist for:\n", + merge->toString(), + "\nId: ", + merge->out()->toString()); - std::deque root_copy(root_domain_.begin(), root_domain_.end()); + VectorOfUniqueEntries root_ids = root_ids_it->second; - // Forward to first matching argument - while (!root_copy.empty() && !ordered_inputs.empty()) { - if (root_copy.front() != ordered_inputs.front()) { - root_copy.pop_front(); - } else { - break; - } - } + bool is_indexing_pass = !ignore_consistent_ordering_; - // Forward through all matching arguments - while (!root_copy.empty() && !ordered_inputs.empty()) { - if (root_copy.front() == ordered_inputs.front()) { - root_copy.pop_front(); - ordered_inputs.pop_front(); - } else if ( - root_copy.front()->isReduction() || root_copy.front()->isBroadcast()) { - // This was a cause of an error with - // ReductionSchedulerMultiDimNonFastest. The test no longer - // fails. - root_copy.pop_front(); - } else { - break; + IterDomain* last_root = nullptr; + for (auto root_id_i : c10::irange(root_domain_.size())) { + auto root_id = root_domain_[root_id_i]; + if (root_ids.has(root_id)) { + // ID found, remove it + root_ids.erase(root_id); + // If we're indexing: + // we could still potentially consider this ID linearly indexable, as we + // could multiple the index by the last root's stride. + // + // If we're computing predicates (ignore_consistent_ordering_==true), + // then we don't have this same constraint, we can just ignore + // contiguity of the roots all together. + if (!root_contiguity_[root_id_i] && is_indexing_pass) { + if (!root_ids.empty()) { + return; + } + } + last_root = root_id; } } - // If we matched all inputs, the output is contiguous. Only want to keep the - // top contig ID, lower ids should be placed in the "within_contig_ids" map - // of top id. - if (ordered_inputs.empty()) { - if (contig_ids_.find(inner) != contig_ids_.end()) { - contig_ids_.erase(inner); - } + // If there's a non_divisible split in the history of merge->out then it can't + // be contiguously indexable. + if (non_divisible_id_info_.dependsOnNonDivisibleSplit(merge->out())) { + return; + } - if (contig_ids_.find(outer) != contig_ids_.end()) { - contig_ids_.erase(outer); - } + // Now we know merge->out is a contiguously indexable ID - contig_ids_.emplace(out); + TORCH_INTERNAL_ASSERT( + last_root != nullptr, + "Issue processing root ids for ", + merge->out()->toString()); - std::unordered_set within_out; - within_out.emplace(inner); - if (within_contig_ids_.find(inner) != within_contig_ids_.end()) { - auto in_inner = within_contig_ids_.at(inner); - within_out.insert(in_inner.begin(), in_inner.end()); - within_contig_ids_.erase(inner); - } + // Reset root_ids + root_ids = root_ids_it->second; + for (auto root_id : root_ids) { + root_to_indexed_id_[root_id] = merge->out(); + } - within_out.emplace(outer); - if (within_contig_ids_.find(outer) != within_contig_ids_.end()) { - auto in_outer = within_contig_ids_.at(outer); - within_out.insert(in_outer.begin(), in_outer.end()); - within_contig_ids_.erase(outer); - } + auto all_within_vals = DependencyCheck::getAllValsBetween( + {root_domain_.begin(), root_domain_.end()}, {merge->out()}); + auto all_within_ids = ir_utils::filterByType(all_within_vals); - within_contig_ids_[out] = within_out; + std::unordered_set within_id_set( + all_within_ids.begin(), all_within_ids.end()); - for (auto root : lhs_inputs) { - root_to_indexed_id_[root] = out; - } - for (auto root : rhs_inputs) { - root_to_indexed_id_[root] = out; - } + within_id_set.erase(merge->out()); + within_contig_ids_[merge->out()] = within_id_set; + for (auto id : all_within_ids) { + contig_ids_.erase(id); } + + contig_ids_.emplace(merge->out()); } IterDomain* ContigIDs::getMappedId(IterDomain* id) const { @@ -187,24 +601,16 @@ IterDomain* ContigIDs::getMappedId(IterDomain* id) const { } } -IterDomain* ContigIDs::getCAIndexConcreteId(IterDomain* id) const { - TORCH_INTERNAL_ASSERT( - GpuLower::current() != nullptr, "GpuLower is not found"); - - auto c_id = GpuLower::current()->caMap()->getConcreteMappedID( - getMappedId(id), IdMappingMode::EXACT); - return c_id; -} - bool ContigIDs::isIndexable(IterDomain* id) const { // If ID is mapped to consumer through persmissive map but not exact map it // will not be mapped through to the exact map through the p2c map. Therefore // reject because it involves broadcast resolution. - if (!GpuLower::current()->caMap()->idExistsInMap(getMappedId(id))) { + if (!ca_map_->idExistsInMap(getMappedId(id))) { return false; } - auto c_id = getCAIndexConcreteId(id); - return concrete_to_ref_.find(c_id) != concrete_to_ref_.end(); + auto c_id = + ca_map_->getConcreteMappedID(getMappedId(id), IdMappingMode::EXACT); + return index_map_.find(c_id) != index_map_.end(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/contiguity.h b/torch/csrc/jit/codegen/cuda/contiguity.h index 7293901310eb6..e3be65a5bbc08 100644 --- a/torch/csrc/jit/codegen/cuda/contiguity.h +++ b/torch/csrc/jit/codegen/cuda/contiguity.h @@ -2,13 +2,128 @@ #include +#include +#include #include +#include +#include namespace torch { namespace jit { namespace fuser { namespace cuda { +// Goes through the transformations associated with a series of ids and root +// ids. Checks the ordering of the iteration domains through these operations to +// pick out which operations are consistently ordered. For example: +// [i0, i1, i2] +// ->split(0, 4)->merge(1)->merge(1)->merge(0) +// are consistently ordered from largest to smallest extents, but +// ->split(0, 4)->merge(1)->merge(0, 2)->merge(0) is not consistently ordered +// with the roots. +// +// This property is important to understand the contiguity of dimensions through +// complex transformations. +class OrderedIdInformation : public OptInDispatch { + public: + OrderedIdInformation() = delete; + + OrderedIdInformation( + const std::vector& ids, + const std::vector& root_domain, + std::shared_ptr concrete_info); + + const std::unordered_map>& + idToRootIds() const { + return id_to_root_ids_; + } + + bool isConsistentlyOrdered(IterDomain* id) const { + return consistently_ordered_ids_.find(id) != + consistently_ordered_ids_.end(); + } + + bool exclusivelyConsumesRoots(IterDomain* id) const { + return exclusively_consumes_roots_.find(id) != + exclusively_consumes_roots_.end(); + } + + private: + // Returns if the id in active_ids should be in exclusively_consumes_roots_ + bool checkExclusivelyConsumesRoots(IterDomain* id); + + void handle(Split*) override; + + void handle(Merge* merge) override; + + void handle(Swizzle2D* swizzle) override; + + // Track which root ids were used to generate each iter domain + std::unordered_map> + id_to_root_ids_; + + // Track all IterDomains that have correct ordered transforms for contiguity. + // i.e. if we have: + // + // root = [i0, i1, i2] + // i3 = merge(i0, i2) + // would not be consistently ordered transformed + // + // root = [i0, i1, i2] + // i4, i5 = spit(merge(merge(i0, i1), i2), 4) + // would be consistently ordered transforms + // + // root = [i0, i1, i2, i3] + // i4 = merge(i1, i2) would also be consistently ordered transformed + std::unordered_set consistently_ordered_ids_; + + // Active series of IterDomains that are updated while we're processing the + // domain. Helps us identify which ids are consistently_ordered_ids_. Used + // for intermediate storage, not to return. + std::vector active_ids_; + + // IterDomains in this set exclusively consume all the uses of their roots. + // For example: + // [i0, i1] split(0, f)->merge(1) + // [ceilDiv(i0, f), f*i1] + // neither iter domains exclusively consume the roots. With another: + // merge(0) -> [ceilDiv(i0, f)*f*i1] + // The resulting iter domain does exclusively consume the roots. + // + // Also: + // [i0, i1, i2, i3] merge(1)->merge(1) + // ->[i0, i1*i2*i3] + // both resulting iter domains do exclusively consume their roots + std::unordered_set exclusively_consumes_roots_; + + // Broadcast domains that are concretized cannot be considered contiguously + // indexable. + // TODO: This constraint is more conservative than necessary as it's only if + // the domain is concretized within the local indexing, not in the entire + // fusion. + std::shared_ptr concrete_info_; +}; + +// Based on provided divisible split set, goes through expressions and marks all +// IterDomains that are dependent on a non-divisible split. +class NonDivisibleSplitDependencies : public OptInDispatch { + public: + NonDivisibleSplitDependencies() = delete; + + NonDivisibleSplitDependencies( + const std::vector& ids, + const std::vector& root_domain, + const std::unordered_set& divisible_splits); + + bool dependsOnNonDivisibleSplit(IterDomain* id) const { + return depends_on_non_divisible_split.find(id) != + depends_on_non_divisible_split.end(); + } + + private: + std::unordered_set depends_on_non_divisible_split; +}; + // A merge is contiguous if: // Inputs of outer are to the left in the root domain of the inputs of RHS. // All inputs are contiguous in the root domain: @@ -22,8 +137,6 @@ namespace cuda { class ContigIDs : public OptInDispatch { public: - ContigIDs() = delete; - //! Check through the history of ids whose inputs map to root_domain with //! contiguity root_contiguity. Return unordered_set of all merges that are //! contiguous. Ignore root order is primarily used for predicate generation. @@ -42,21 +155,55 @@ class ContigIDs : public OptInDispatch { //! If ignore_indexability and ignore_halo_constraint are true, //! ignore the constraint on indexing and halo, respectively. It is //! the caller that is responsible for its correctness. - //! - //! The function interface with many parameters looks ugly, but it - //! is also important to make ignore_indexability and - //! ignore_halo_constraint explicit to avoid any surprise. - //! //! Not really sure why but clang-tidy only complains about //! std::unordered_map if passed as a const reference. ContigIDs( const std::vector& ids, const std::vector& root_domain, const std::vector& root_contiguity, - std::unordered_map concrete_to_ref, + const std::unordered_set& final_ids, + const std::unordered_map& index_map, + const std::unordered_set& divisible_splits, + std::unordered_map p2c_id_map = {}, + bool ignore_indexability = false, + bool ignore_consistent_ordering = false); + + //! \param ids IterDomains on the leaves of the domain we're looking for + //! contiguous indexing into. + //! \param root_domain the root domain of the domain we're looking for + //! contiguous indexing into. + //! \param root_contiguity the contiguity of the root_domain. + //! \param concrete_to_ref concrete ids of the exact map that the reference + //! index is using for indexing. + //! \param divisible_splits a set of all splits in the fusion that are + //! divisible. + //! \param ca_map compute at map of the fusion. + //! \param halo_info halo information of the fusion. + //! \param concrete_info concretized broadcast information of the fusion. + //! \param p2c_id_map map from producer to consumer ids used for indexing + //! producer tensors. + //! \param ignore_consistent_ordering true for actual indexing into tensors + //! but false for predicate analysis. Ordering of merges don't matter for + //! predicate generation as they don't map to a physical address. + //! \param ignore_indexability can only be true if providing a real + //! concrete_to_ref map. As what it's checking is if the index is actually + //! indexable based on the reference. + ContigIDs( + const std::vector& ids, + const std::vector& root_domain, + const std::vector& root_contiguity, + const std::unordered_set& final_ids, + const std::unordered_map& index_map, + const std::unordered_set& divisible_splits, + std::shared_ptr ca_map, + std::shared_ptr halo_info, + std::shared_ptr concrete_info, std::unordered_map p2c_id_map = {}, bool ignore_indexability = false, - bool ignore_halo_constraint = false); + bool ignore_consistent_ordering = false); + + //! Return an empty ContigIDs with no contiguous ID + static ContigIDs getNonContigIDs(); const std::unordered_set& contigIDs() const { return contig_ids_; @@ -71,6 +218,14 @@ class ContigIDs : public OptInDispatch { return root_to_indexed_id_; } + VectorOfUniqueEntries indexedRootIDs(IterDomain* id) const { + auto root_ids_it = consistent_transform_info_->idToRootIds().find(id); + if (root_ids_it == consistent_transform_info_->idToRootIds().end()) { + return {}; + } + return root_ids_it->second; + } + private: using OptInDispatch::handle; @@ -107,17 +262,32 @@ class ContigIDs : public OptInDispatch { IterDomain* getMappedId(IterDomain* id) const; private: + void build(const std::vector& ids); + //! Root domains to analyze contiguity const std::vector& root_domain_; //! Contiguity of root_domain_ const std::vector& root_contiguity_; - //! Mapping of concrete to reference domains. If a concrete domain - //! is not mapped, it is not indexable as there's no mapped index. - const std::unordered_map concrete_to_ref_; + //! Domains where indexing/predicates cannot be done with their + //! consumers domains + const std::unordered_set& final_ids_; + //! Mapping of concrete domains to indices. Just used to check if + //! there's an index for an IterDomain. + const std::unordered_map index_map_; + // Divisible split information as we can still consider iter domains + // contiguous through divisible splits. + const std::unordered_set& divisible_splits_; + + std::shared_ptr ca_map_; + std::shared_ptr halo_info_; + std::shared_ptr concrete_info_; + //! Producer-to-consumer index map in the case of analyzing replayed //! producer tensors const std::unordered_map p2c_id_map_; + const bool ignore_indexability_ = false; + const bool ignore_consistent_ordering_ = false; //! Mapping of root domain to bool indicating contiguity std::unordered_map is_contig_root_; @@ -129,6 +299,10 @@ class ContigIDs : public OptInDispatch { //! Mapping of root domain to the actual indexed domain, which can //! be itself or a contig merged domain if found. std::unordered_map root_to_indexed_id_; + + std::unique_ptr consistent_transform_info_; + + NonDivisibleSplitDependencies non_divisible_id_info_; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/disjoint_set.h b/torch/csrc/jit/codegen/cuda/disjoint_set.h index b325bedcf7b9e..8fd60dab5bd22 100644 --- a/torch/csrc/jit/codegen/cuda/disjoint_set.h +++ b/torch/csrc/jit/codegen/cuda/disjoint_set.h @@ -260,7 +260,7 @@ class DisjointSets { entry_it != disjointSetMap().end(), "Strict mapping failed on element: ", abstractToString(entry0), - " either an error occured, or non strict mapping should have been used."); + " either an error occurred, or non strict mapping should have been used."); return entry_it->second->has(entry1); } @@ -302,17 +302,14 @@ class DisjointSets { std::string toString() const { std::stringstream ss; ss << "disjoint sets{\n"; + const std::string sep(" "); for (auto s_ptr : disjoint_sets_) { auto& set = *s_ptr; - ss << " { "; + ss << sep << "{\n"; for (auto entry : set.vector()) { - ss << abstractToString(entry); - // DomainKey defines == but not != - if (!(entry == set.back())) { - ss << "; "; - } + ss << sep << sep << abstractToString(entry) << "\n"; } - ss << " }\n"; + ss << sep << "}\n"; } ss << "}"; return ss.str(); diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 7f66d3c69495c..70e9ae16375e5 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -95,9 +95,15 @@ void Val::dispatch(T handler, Val* val) { template void Expr::dispatch(T handler, Expr* expr) { switch (*(expr->getExprType())) { + case ExprType::FullOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::ARangeOp: ptr(handler)->handle(expr->as()); return; + case ExprType::EyeOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -281,9 +287,15 @@ void Val::constDispatch(T handler, const Val* val) { template void Expr::constDispatch(T handler, const Expr* expr) { switch (*(expr->getExprType())) { + case ExprType::FullOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::ARangeOp: ptr(handler)->handle(expr->as()); return; + case ExprType::EyeOp: + ptr(handler)->handle(expr->as()); + return; case ExprType::UnaryOp: ptr(handler)->handle(expr->as()); return; @@ -475,9 +487,15 @@ void Val::mutatorDispatch(T mutator, Val* val) { template void Expr::mutatorDispatch(T mutator, Expr* expr) { switch (*(expr->getExprType())) { + case ExprType::FullOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::ARangeOp: ptr(mutator)->mutate(expr->as()); return; + case ExprType::EyeOp: + ptr(mutator)->mutate(expr->as()); + return; case ExprType::UnaryOp: ptr(mutator)->mutate(expr->as()); return; @@ -734,9 +752,15 @@ void OptOutConstDispatch::handle(const kir::IntPair* stmt) { } // Exprs +void OptOutConstDispatch::handle(const FullOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const ARangeOp* stmt) { unhandled(stmt); } +void OptOutConstDispatch::handle(const EyeOp* stmt) { + unhandled(stmt); +} void OptOutConstDispatch::handle(const UnaryOp* stmt) { unhandled(stmt); } @@ -890,9 +914,15 @@ void OptOutDispatch::handle(kir::IntPair* stmt) { } // Exprs +void OptOutDispatch::handle(FullOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(ARangeOp* stmt) { unhandled(stmt); } +void OptOutDispatch::handle(EyeOp* stmt) { + unhandled(stmt); +} void OptOutDispatch::handle(UnaryOp* stmt) { unhandled(stmt); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 6b35a9775ecf7..4fea698191ec4 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -68,7 +68,9 @@ class ComplexDouble; class NamedScalar; // Exprs +class FullOp; class ARangeOp; +class EyeOp; class UnaryOp; class BinaryOp; class TernaryOp; @@ -144,7 +146,9 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase { virtual void handle(const kir::IntPair*); // Exprs + virtual void handle(const FullOp* stmt); virtual void handle(const ARangeOp* stmt); + virtual void handle(const EyeOp* stmt); virtual void handle(const UnaryOp* stmt); virtual void handle(const BinaryOp* stmt); virtual void handle(const TernaryOp* stmt); @@ -211,7 +215,9 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase { virtual void handle(kir::IntPair*); // Exprs + virtual void handle(FullOp* stmt); virtual void handle(ARangeOp* stmt); + virtual void handle(EyeOp* stmt); virtual void handle(UnaryOp* stmt); virtual void handle(BinaryOp* stmt); virtual void handle(TernaryOp* stmt); @@ -319,7 +325,9 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase { virtual void mutate(kir::IntPair*); // Exprs + virtual void mutate(FullOp*); virtual void mutate(ARangeOp*); + virtual void mutate(EyeOp*); virtual void mutate(UnaryOp*); virtual void mutate(BinaryOp*); virtual void mutate(TernaryOp*); diff --git a/torch/csrc/jit/codegen/cuda/dynamic_type.h b/torch/csrc/jit/codegen/cuda/dynamic_type.h index aba725e0ea60a..5cf9f0930929d 100644 --- a/torch/csrc/jit/codegen/cuda/dynamic_type.h +++ b/torch/csrc/jit/codegen/cuda/dynamic_type.h @@ -296,6 +296,14 @@ inline IntOrDouble min(const IntOrDouble& a, const IntOrDouble& b) { return (a < b ? a : b).cast(); } +inline IntOrDouble abs(const IntOrDouble& a) { + if (a.is_int()) { + return IntOrDouble(std::abs(a.as())); + } else { + return IntOrDouble(std::abs(a.as())); + } +} + } // namespace IntOrDouble_functions } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp index bab8586247bfd..ae280b4ac44c8 100644 --- a/torch/csrc/jit/codegen/cuda/evaluator_common.cpp +++ b/torch/csrc/jit/codegen/cuda/evaluator_common.cpp @@ -196,7 +196,13 @@ template void PrecomputedValuesBase::validate() { FUSER_PERF_SCOPE("PrecomputedValuess::Validate"); for (auto it : binding_log_) { - TORCH_INTERNAL_ASSERT(values_[it.first] == it.second); + TORCH_INTERNAL_ASSERT( + values_[it.first] == it.second, + "Precomputed values failed to validate.", + "\nSomething unexpected changed between the compilation and execution.\n", + values_[it.first], + " != ", + it.second); } has_valid_values_ = true; } @@ -295,6 +301,7 @@ void NaiveValueMachine::runInstruction(int index) { template void NaiveValueMachine::runUnaryOp(int index) { + using namespace IntOrDouble_functions; int src_index = src0_[index]; bool src_defined = precomputed_values_.defined_[src_index]; bool src_is_const = precomputed_values_.is_constant_[src_index]; @@ -323,6 +330,9 @@ void NaiveValueMachine::runUnaryOp(int index) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + dest = abs(src); + break; default: TORCH_CHECK(!"Unexpected operator type ", uop_type_[index]); } diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index b93c9514fcf02..23be5f4232aad 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,7 @@ #include #include +#include #include namespace torch { @@ -29,6 +31,16 @@ namespace cuda { int FusionExecutor::fusion_id_counter_ = 0; // NOLINT +bool fill_allocation_with_nan_ = false; + +bool shouldFillAllocationWithNan() { + return fill_allocation_with_nan_; +} + +void setFillAllocationWithNan(bool value) { + fill_allocation_with_nan_ = value; +} + namespace { static const char* defineIndexMode(KernelIndexMode index_mode) { @@ -226,7 +238,7 @@ void FusionExecutor::compileFusion( #ifndef USE_ROCM device_smem_limit_ = properties->sharedMemPerBlockOptin; #else - // don't know if rocm supports opt-in shared memroy reconfiguration + // don't know if rocm supports opt-in shared memory reconfiguration device_smem_limit_ = properties->sharedMemPerBlock; #endif warp_size_ = properties->warpSize; @@ -245,6 +257,27 @@ void FusionExecutor::compileFusion( kernel->print(); } + if (isDebugDumpEnabled(DebugDumpOption::BankConflictInfo)) { + auto bank_conflict_info = getBankConflictInfo(kernel); + if (bank_conflict_info.empty()) { + std::cout << "===== No bank confliction =====" << std::endl; + } else { + std::cout << "======= Bank confliction =======" << std::endl; + for (auto info : bank_conflict_info) { + std::cout << "Expr: " << info.first->toString() << std::endl; + auto conflict = info.second; + if (conflict.first > 1) { + std::cout << "input conflict: " << conflict.first << " way, "; + } + if (conflict.second > 1) { + std::cout << "output conflict: " << conflict.second << " way"; + } + std::cout << std::endl; + } + std::cout << "================================" << std::endl; + } + } + kernel_code_ = codegen::generateCudaKernel(kernel, kernelName()); const auto structured_code = getStructuredCode(kernel_code_); @@ -314,6 +347,42 @@ void FusionExecutor::compileFusion( namespace { +void fillTensorWithNan(at::Tensor& t) { + switch (t.scalar_type()) { + case at::ScalarType::Byte: + t.fill_(0xFF); + break; + case at::ScalarType::Char: + t.fill_(0x7F); + break; + case at::ScalarType::Short: + t.fill_(0x7FFF); + break; + case at::ScalarType::Int: + t.fill_(0x7FFFFFFF); + break; + case at::ScalarType::Long: + t.fill_(0x7FFFFFFFFFFFFFFFL); + break; + case at::ScalarType::Bool: + t.fill_(true); + break; + case at::ScalarType::Half: + case at::ScalarType::Float: + case at::ScalarType::Double: + case at::ScalarType::BFloat16: + t.fill_(std::nan("")); + break; + case at::ScalarType::ComplexHalf: + case at::ScalarType::ComplexFloat: + case at::ScalarType::ComplexDouble: + t.fill_(c10::complex(std::nan(""), std::nan(""))); + break; + default: + TORCH_INTERNAL_ASSERT(false, "Unknown dtype"); + } +} + at::Tensor inferAndAlloc( const TensorView* tv, const std::vector& sizes, @@ -383,6 +452,9 @@ at::Tensor inferAndAlloc( // Non Variable type guard for empty_cuda call at::AutoDispatchBelowADInplaceOrView non_variable_type_mode; auto empty = at::empty(isizes, tensor_options); + if (shouldFillAllocationWithNan()) { + fillTensorWithNan(empty); + } if (expanded_dim) { return empty.expand(expanded_sizes); } @@ -700,29 +772,24 @@ FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals( } std::vector FusionExecutor::allocOutputs( + const KernelArgumentHolder& args, kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices) { FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs"); const auto kernel = lowered_->kernel(); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector outputs; + TORCH_INTERNAL_ASSERT( + args.size() == kernel->inputs().size(), + "kernel arguments length does not match runtime arguments."); for (const auto out_i : c10::irange(kernel->outputs().size())) { - // TODO: FIX this short-cut where we trivially forward inputs to outputs if (kernel->outputs()[out_i]->isFusionInput()) { - TORCH_INTERNAL_ASSERT(false, "trivial input forwarding NOT IMPLEMENTED"); - // for (auto inp_i : c10::irange(kernel->inputs().size())) { - // if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) { - // TORCH_INTERNAL_ASSERT( - // inp_i < inputs.size(), - // "Issue with an input showing up as output, couldn't find - // input."); - // TORCH_INTERNAL_ASSERT( - // inputs[inp_i].isTensor(), - // "Cannot register a scalar as an output in a fusion."); - // outputs.push_back(inputs[inp_i].toTensor()); - // break; - // } - // } + // pushing empty tensor for trivial forwarding. Since we handle this in + // integration, see step 1 - note [trivial forwarding] + c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex()); + const auto tensor_options = + at::TensorOptions().dtype(at::kFloat).device(device); + outputs.emplace_back(at::empty({0}, tensor_options)); } else { TORCH_INTERNAL_ASSERT( kernel->outputs()[out_i]->isA(), @@ -762,7 +829,8 @@ KernelArgumentHolder FusionExecutor::evaluateOutputSizes( meta_options.device = c10::Device(DeviceType::Meta, 0); for (const auto out_i : c10::irange(kernel->outputs().size())) { - // If the output is just trivially the input, just "copy" it over. + // If the output is just trivially the input, just "copy" it over, see note + // [trivial forwarding] if (kernel->outputs()[out_i]->isFusionInput()) { for (auto inp_i : c10::irange(kernel->inputs().size())) { if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) { @@ -884,6 +952,8 @@ std::vector FusionExecutor::runFusion( !args.getCacheId().has_value() || outputs.empty(), "short cut input cache is not compatible with pre-allocated output"); + size_t num_inputs = args.size(); + if (isDebugDumpEnabled(DebugDumpOption::FusionArgs)) { std::cout << "Arguments for fusion" << fusion_id_ << ":" << std::endl << "Inputs:" << std::endl; @@ -930,6 +1000,9 @@ std::vector FusionExecutor::runFusion( c10::nullopt, options_.device, c10::nullopt)); + if (shouldFillAllocationWithNan()) { + fillTensorWithNan(allocated_outputs.back()); + } } // Note: aliased output is not returned as output. But we still need it // for kernel execution, so would need to push them to args @@ -970,6 +1043,9 @@ std::vector FusionExecutor::runFusion( c10::nullopt, options_.device, c10::nullopt)); + if (shouldFillAllocationWithNan()) { + fillTensorWithNan(global_buffers.buffers.back()); + } global_buffers.zero_init.push_back(false); } } @@ -1075,7 +1151,7 @@ std::vector FusionExecutor::runFusion( auto& output_alias_indices = output_alias_indices_entry.get(); - allocated_outputs = allocOutputs(expr_eval, output_alias_indices); + allocated_outputs = allocOutputs(args, expr_eval, output_alias_indices); for (const auto& entry : alias_indices) { auto aliased_output_index = entry.first; @@ -1175,9 +1251,9 @@ std::vector FusionExecutor::runFusion( if (measure_kernel_time_ || isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth) || isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { - cudaEventCreate(&start_event); - cudaEventCreate(&finish_event); - cudaEventRecord(start_event); + C10_CUDA_CHECK(cudaEventCreate(&start_event)); + C10_CUDA_CHECK(cudaEventCreate(&finish_event)); + C10_CUDA_CHECK(cudaEventRecord(start_event)); } if (execute_kernel_) { @@ -1233,16 +1309,17 @@ std::vector FusionExecutor::runFusion( if (measure_kernel_time_ || isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth) || isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { - cudaEventRecord(finish_event); - cudaEventSynchronize(start_event); - cudaEventSynchronize(finish_event); - cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event); - cudaEventDestroy(start_event); - cudaEventDestroy(finish_event); + C10_CUDA_CHECK(cudaEventRecord(finish_event)); + C10_CUDA_CHECK(cudaEventSynchronize(start_event)); + C10_CUDA_CHECK(cudaEventSynchronize(finish_event)); + C10_CUDA_CHECK( + cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event)); + C10_CUDA_CHECK(cudaEventDestroy(start_event)); + C10_CUDA_CHECK(cudaEventDestroy(finish_event)); bytes_processed_ = 0; // Figure how many bytes are inputs, outputs, and temporary buffers - for (auto i : c10::irange(args.size())) { + for (auto i : c10::irange(num_inputs)) { if (auto tensor_arg_abstract = dynamic_cast(args[i])) { bytes_processed_ += tensor_arg_abstract->numel() * diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 1d6ff4487b8f6..9d4775b37ca95 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -16,6 +16,9 @@ namespace jit { namespace fuser { namespace cuda { +TORCH_CUDA_CU_API bool shouldFillAllocationWithNan(); +TORCH_CUDA_CU_API void setFillAllocationWithNan(bool value); + // TODO: Should this actually be in launch params? struct TORCH_CUDA_CU_API CompileOptions { c10::Device device = c10::Device(c10::DeviceType::CUDA, 0); @@ -217,6 +220,7 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { // skip allocating real storage for those, but still maintain its spot to // maintain the indexing from output aliases to inputs std::vector allocOutputs( + const KernelArgumentHolder& args, kir::ExpressionEvaluator& expr_eval, const std::unordered_set& alias_indices = {}); diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index db9764eb3059e..cc435ae4bb3b8 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -155,7 +155,7 @@ bool validateKernelArgTensor( } if (!is_cpu_scalar(arg) && !arg.is_cuda()) { - msg << "Argumnet is a CPU tensor which is not supported in fusions.\n"; + msg << "Argument is a CPU tensor which is not supported in fusions.\n"; return false; } @@ -824,7 +824,7 @@ void bindInputForExprEvaluation( if (root_domain[dim]->hasExpandedExtent()) { TORCH_INTERNAL_ASSERT( tensor_arg_stride == 0, - "Execting an expanded dimension on dimension ", + "Expecting an expanded dimension on dimension ", dim, " but found stride ", tensor_arg_stride); @@ -838,18 +838,13 @@ void bindInputForExprEvaluation( *maybe_expanded_size == tensor_arg_size, "Expecting expanded extent of ", *maybe_expanded_size, - " but recieved value of ", + " but received value of ", tensor_arg_size); } } const auto value = root_domain[dim]->hasExpandedExtent() ? 1 : tensor_arg_size; - if (value == 0 && cg_tensor->uses().empty()) { - // If there's no uses, ignore there's a size-0 dimension. - continue; - } - TORCH_INTERNAL_ASSERT(value != 0, "Cannot handle size-0 dimensions"); bool should_bind = true; if (check_consistency) { const auto prev_value = expr_eval.evaluate(extent); @@ -941,7 +936,7 @@ void initializeCudaContext() { if (!pctx) { std::unique_lock cudaFreeMutexLock( *(c10::cuda::getFreeMutex())); - cudaFree(nullptr); + C10_CUDA_CHECK(cudaFree(nullptr)); } } @@ -1014,7 +1009,7 @@ std::pair nvrtcCompile( }); #ifdef USE_ROCM - std::vector args = {"--std=c++14"}; + std::vector args = {"--std=c++17"}; #if ROCM_VERSION >= 40200 args.push_back("-hip-pch"); #endif @@ -1023,6 +1018,12 @@ std::pair nvrtcCompile( // compile to sass is not allowed prior to CUDA 11.1 compile_to_sass = false; #endif + + if (isOptionDisabled(DisableOption::CompileToSass)) { + // Allows manually disabling compilation to sass + // so the intermediate ptx could be checked. + compile_to_sass = false; + } // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_) // which gives better backwards compatibility to work on older driver, // (since older driver doesn't necessrily recognize PTX emitted by new @@ -1035,7 +1036,7 @@ std::pair nvrtcCompile( std::to_string(minor); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector args = { - "--std=c++14", compute.c_str(), "-default-device"}; + "--std=c++17", compute.c_str(), "-default-device"}; #endif const bool disable_fma = isOptionDisabled(DisableOption::Fma); diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 7bda8682189ee..6e1c628111113 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -54,7 +54,17 @@ void ExpressionEvaluator::bind(Val* value, const IntOrDouble& concrete_value) { TORCH_CHECK( value->definition() == nullptr, "Tried to bind to a value that is computed in the fusion IR"); - known_values_[value] = concrete_value; + if (value->isA()) { + known_named_scalars_[value->as()->name()] = concrete_value; + } else { + known_values_[value] = concrete_value; + } +} + +void ExpressionEvaluator::bind( + const std::string& name, + const IntOrDouble& concrete_value) { + known_named_scalars_[name] = concrete_value; } c10::optional ExpressionEvaluator::evaluate(Val* value) { @@ -88,7 +98,7 @@ void ExpressionEvaluator::print() const { c10::optional ExpressionEvaluator::getValue(Val* value) { TORCH_INTERNAL_ASSERT( value->isAnInt() || value->isADouble(), - "Expression Evaluation does not support values other than integers at this time."); + "Expression Evaluation does not support values other than integers/doubles at this time."); if (value->getValType().value() == ValType::Scalar) { if (value->isAnInt() && value->as()->value().has_value()) { @@ -99,12 +109,20 @@ c10::optional ExpressionEvaluator::getValue(Val* value) { } } - const auto it = known_values_.find(value); - return it != known_values_.end() ? c10::optional(it->second) - : c10::nullopt; + if (value->isA()) { + const auto it = known_named_scalars_.find(value->as()->name()); + return it != known_named_scalars_.end() + ? c10::optional(it->second) + : c10::nullopt; + } else { + const auto it = known_values_.find(value); + return it != known_values_.end() ? c10::optional(it->second) + : c10::nullopt; + } } void ExpressionEvaluator::handle(UnaryOp* uop) { + using namespace IntOrDouble_functions; const auto in = evaluate(uop->in()); if (in.has_value()) { switch (uop->getUnaryOpType()) { @@ -123,6 +141,9 @@ void ExpressionEvaluator::handle(UnaryOp* uop) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + known_values_[uop->out()] = abs(*in); + break; default: TORCH_CHECK( !"Unexpected operator type ", diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index 8d906ff58e43d..4329f9604304b 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -7,6 +7,7 @@ #include +#include #include namespace torch { @@ -30,6 +31,9 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { //! Bind a concrete value to an IR variable void bind(Val* value, const IntOrDouble& concrete_value); + //! Bind a concrete value to a named scalar + void bind(const std::string& name, const IntOrDouble& concrete_value); + //! Try to evaluate a Fusion IR value c10::optional evaluate(Val* value); @@ -49,9 +53,11 @@ class TORCH_CUDA_CU_API ExpressionEvaluator : private OptOutDispatch { void handle(UnaryOp*) final; void handle(BinaryOp*) final; + // TODO: handle swizzle private: std::unordered_map known_values_; + std::unordered_map known_named_scalars_; Fusion* fusion_ = nullptr; FusionPrecomputedValues* evaluator_precomputed_values_ = nullptr; }; diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 04c367c667275..e4f24f0473a19 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -339,6 +340,20 @@ void Fusion::printKernel(DataType index_type) { std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel()); } +std::unordered_map> Fusion::bankConflictInfo( + DataType index_type) { + GpuLower lower(this, index_type); + auto kernel = lower.kernel(); + auto info = getBankConflictInfo(kernel); + // The container of exprs goes out of scope, so we return a map of string here + std::unordered_map> result; + result.reserve(info.size()); + for (auto i : info) { + result[i.first->toString()] = i.second; + } + return result; +} + void Fusion::printMath(bool from_outputs_only) { FUSER_PERF_SCOPE("Fusion::printMath"); diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index e726d793be756..2c0c59fae2b9b 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -136,6 +136,10 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer { //! Lower the fusion and print a kernel void printKernel(DataType index_type = DataType::Int); + //! Lower the fusion and evaluate bank conflict info + std::unordered_map> bankConflictInfo( + DataType index_type = DataType::Int); + //! Return a list of topologically sorted expressions. This only includes //! exprs required to genereate registered outputs. std::vector exprs(); diff --git a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp index f993705c9bdc2..c0bf81dc688bf 100644 --- a/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp @@ -3190,7 +3190,7 @@ class ForceHalfAnnotation : public IterVisitor { val->getDataType().value() == DataType::BFloat16); }); - annotation.traverseFrom(fusion, fp16_outputs); + annotation.traverseTo(fusion, fp16_outputs); return annotation.force_fp16_tv_set_; } diff --git a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp index 1b51c87075471..c2427f9386278 100644 --- a/torch/csrc/jit/codegen/cuda/graph_fuser.cpp +++ b/torch/csrc/jit/codegen/cuda/graph_fuser.cpp @@ -1554,10 +1554,6 @@ void guardFusionGroup( profiled_ivalue_indices.insert(index); } } - // we should assert on non-tensor inputs - TORCH_INTERNAL_ASSERT( - tensor_inputs_to_check.size(), - "CudaFusionGuard expects at least one tensor input"); // insert the if block first; auto versioning_if = @@ -2176,7 +2172,10 @@ void decomposeLinearOps(Block* block) { void replaceAliasOpsWithCopy(std::shared_ptr& graph, Block* block) { static std::unordered_map alias_to_copy_mapping( {{aten::expand, prim::expand_copy}, - {aten::expand_as, prim::expand_as_copy}}); + {aten::expand_as, prim::expand_as_copy}, + {aten::permute, prim::permute_copy}, + {aten::transpose, prim::transpose_copy}, + {aten::t, prim::t_copy}}); // TODO: revert disabled aten::view // ({{aten::view, prim::view_copy}, // {aten::reshape, prim::reshape_copy}, @@ -2228,7 +2227,10 @@ void replaceAliasOpsWithCopy(std::shared_ptr& graph, Block* block) { void revertAliasCopyOps(std::shared_ptr& graph, Block* block) { static std::unordered_map copy_to_alias_mapping( {{prim::expand_copy, aten::expand}, - {prim::expand_as_copy, aten::expand_as}}); + {prim::expand_as_copy, aten::expand_as}, + {prim::permute_copy, aten::permute}, + {prim::transpose_copy, aten::transpose}, + {prim::t_copy, aten::t}}); // TODO: revert disabled aten::view // ({{prim::view_copy, aten::view}, // {prim::flatten_copy, aten::flatten}, diff --git a/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp b/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp index 5931eb3427aa9..d907a0665e9f6 100644 --- a/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/grouped_reduction.cpp @@ -38,7 +38,7 @@ bool hasMatchingTransformations(TensorView* ref, TensorView* other) { } // Validate grouping of reductions and return a new max producer position -unsigned int validateReductionGrouping( +void validateReductionGrouping( const std::vector& inputs, const std::vector& outputs) { TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size()); @@ -57,7 +57,6 @@ unsigned int validateReductionGrouping( const auto num_root_dims = ref_domain.size(); const auto num_dims = ref_tv->nDims(); const auto ref_ca_pos = ref_tv->getComputeAtPosition(); - auto max_producer_pos = ref_tv->getMaxProducerPosition(); for (const auto i : c10::irange(inputs.size())) { auto output_tv = outputs.at(i)->as(); const auto& output_domain = output_tv->getRootDomain(); @@ -136,9 +135,6 @@ unsigned int validateReductionGrouping( ref_tv->toString(), ". Mismatched tensor: ", output_tv->toString()); - - max_producer_pos = - std::max(max_producer_pos, output_tv->getMaxProducerPosition()); } // Must not have any data dependency from outputs to inputs @@ -152,8 +148,6 @@ unsigned int validateReductionGrouping( } TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str()); } - - return max_producer_pos; } } // namespace @@ -194,14 +188,14 @@ void groupReductions(const std::vector& reduction_outputs) { inputs.at(i) = rop->in(); } - auto max_producer_pos = validateReductionGrouping(inputs, outputs); - - for (auto output : ir_utils::filterByType(outputs)) { - output->setMaxProducer(max_producer_pos); - } + validateReductionGrouping(inputs, outputs); IrBuilder::create( container, op_types, init_vals, outputs, inputs); + + for (auto output : ir_utils::filterByType(outputs)) { + output->updateMaxProducerPosition(); + } } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 5ad56bda15f21..9028f93e9a20f 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -51,8 +51,8 @@ int getProducerHaloOffset( IterDomain* consumer_id = it->second; const auto& halo_map = GpuLower::current()->haloInfo(); - const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0); - const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0); + const auto p_pad = halo_map->getRootAxisInfo(producer_id).width(0); + const auto c_pad = halo_map->getRootAxisInfo(consumer_id).width(0); auto offset = p_pad - c_pad; @@ -178,7 +178,8 @@ Val* getConcreteProducerOffsetWithGather( Val* window_idx = nullptr; if (use_concrete_map) { - window_idx = index_map.at(ir_utils::caMapExactConcreteId(window_id)); + window_idx = index_map.at(GpuLower::current()->caMap()->getConcreteMappedID( + window_id, IdMappingMode::EXACT)); } else { window_idx = index_map.at(window_id); } @@ -440,9 +441,8 @@ void IndexCompute::handle(Merge* merge) { // When the reference has halo extent for inner_id, that extent needs to // be used to un-merge - if (reference_halo_extent_map_.find(inner_id) != - reference_halo_extent_map_.end()) { - inner_extent = reference_halo_extent_map_[inner_id]; + if (halo_extent_map_.find(inner_id) != halo_extent_map_.end()) { + inner_extent = halo_extent_map_[inner_id]; } const auto outer_extent = getExtent(outer_id); @@ -587,20 +587,16 @@ IndexCompute::IndexCompute( std::unordered_set zero_domains, std::unordered_set zero_merged_in, std::unordered_set preferred_paths, - std::unordered_map reference_halo_extent_map) + std::unordered_map halo_extent_map) : IndexCompute( _td, std::move(initial_index_map), std::move(extent_map), std::move(zero_domains), std::move(zero_merged_in), - ContigIDs( - _td->domain(), - _td->getMaybeRFactorDomain(), - std::vector(_td->getMaybeRFactorDomain().size(), false), - {}), + ContigIDs::getNonContigIDs(), std::move(preferred_paths), - std::move(reference_halo_extent_map)) {} + std::move(halo_extent_map)) {} IndexCompute::IndexCompute( const TensorDomain* _td, @@ -610,14 +606,14 @@ IndexCompute::IndexCompute( std::unordered_set zero_merged_in, const ContigIDs& contig_finder, std::unordered_set preferred_paths, - std::unordered_map reference_halo_extent_map) + std::unordered_map halo_extent_map) : td_(_td), index_map_(std::move(initial_index_map)), extent_map_(std::move(extent_map)), zero_domains_(std::move(zero_domains)), zero_merged_in_(std::move(zero_merged_in)), preferred_paths_(std::move(preferred_paths)), - reference_halo_extent_map_(std::move(reference_halo_extent_map)) { + halo_extent_map_(std::move(halo_extent_map)) { FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute"); // Make sure we recompute any indices we can that map to a contiguous access @@ -640,11 +636,11 @@ IndexCompute::IndexCompute( std::unordered_map initial_index_map, std::unordered_set zero_domains, std::unordered_set preferred_paths, - std::unordered_map reference_halo_extent_map) + std::unordered_map halo_extent_map) : index_map_(std::move(initial_index_map)), zero_domains_(std::move(zero_domains)), preferred_paths_(std::move(preferred_paths)), - reference_halo_extent_map_(std::move(reference_halo_extent_map)) { + halo_extent_map_(std::move(halo_extent_map)) { FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute"); concrete_id_pass_ = true; swizzle_mode_ = SwizzleMode::Loop; @@ -703,7 +699,9 @@ void IndexCompute::collectIndexIntoPermissiveMap( auto id_outputs = ir_utils::filterByType(expr->outputs()); if (std::all_of( id_outputs.begin(), id_outputs.end(), [this](IterDomain* id) { - return index_map_.count(ir_utils::caMapExactConcreteId(id)); + return index_map_.count( + GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT)); })) { // Visit this expression: // LoopIndexingAnalysis::traverseFromDomainVals made sure that each @@ -715,7 +713,9 @@ void IndexCompute::collectIndexIntoPermissiveMap( for (auto id : id_inputs) { // Collect backward pass results from this expression if they are // made available in by this expression. - auto idx_it = index_map_.find(ir_utils::caMapExactConcreteId(id)); + auto idx_it = + index_map_.find(GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT)); if (idx_it != index_map_.end()) { permissive_index_map_ @@ -730,7 +730,8 @@ void IndexCompute::collectIndexIntoPermissiveMap( void IndexCompute::updateIndexMapFromPermissiveMap(const Expr* id_expr) { auto id_outputs = ir_utils::filterByType(id_expr->outputs()); for (auto id : id_outputs) { - auto concrete_id = ir_utils::caMapExactConcreteId(id); + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); // Only try to copy index val from permissive map when // the index is missing. if (!index_map_.count(concrete_id)) { @@ -750,7 +751,7 @@ void IndexCompute::run() { const std::vector domain_vals( td_->domain().begin(), td_->domain().end()); - traverseFrom(td_->fusion(), domain_vals, false); + traverseTo(td_->fusion(), domain_vals, false); } IterDomain* IndexCompute::maybeGetExactMapConcreteID(IterDomain* id) { @@ -784,15 +785,14 @@ bool IndexCompute::isZero(IterDomain* id) const { IndexCompute IndexCompute::updateIndexCompute( const TensorDomain* new_td, const std::unordered_map& id_map, - const ContigIDs& contig_finder, - const std::unordered_map& reference_halo_extent_map) - const { + const ContigIDs& contig_finder) const { FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute"); std::unordered_map updated_index_map; std::unordered_map updated_extent_map; std::unordered_set updated_zero_domains; std::unordered_set updated_zero_merged_in; + std::unordered_map updated_halo_extent_map; for (auto id_entry : id_map) { IterDomain* prev_id = id_entry.first; @@ -811,6 +811,11 @@ IndexCompute IndexCompute::updateIndexCompute( if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) { updated_zero_merged_in.emplace(new_id); } + + auto halo_extent_it = halo_extent_map_.find(prev_id); + if (halo_extent_it != halo_extent_map_.end()) { + updated_halo_extent_map[new_id] = halo_extent_it->second; + } } IndexCompute updated_index_compute( @@ -821,25 +826,7 @@ IndexCompute IndexCompute::updateIndexCompute( updated_zero_merged_in, contig_finder, {}, - reference_halo_extent_map); - - if (concrete_id_pass_) { - // This should be the same behavior as with a reference tensor - // created, since originally halo was pulled through exact - // ca mapping and in the concrete_id_pass case, the id_map - // also represents exact ca mapping. - // TODO: might need to re-visit pathological cases when we may - // need to traverse and propagate halo info again in here. - for (auto id_entry : id_map) { - IterDomain* prev_id = id_entry.first; - IterDomain* new_id = id_entry.second; - auto halo_extent_it = reference_halo_extent_map_.find(prev_id); - if (halo_extent_it != reference_halo_extent_map_.end()) { - updated_index_compute.reference_halo_extent_map_[new_id] = - halo_extent_it->second; - } - } - } + updated_halo_extent_map); updated_index_compute.run(); @@ -860,7 +847,7 @@ class UpdateLeafIndices : public IterVisitor { const std::vector domain_vals( td_->domain().begin(), td_->domain().end()); - traverseFrom(td_->fusion(), domain_vals, false); + traverseTo(td_->fusion(), domain_vals, false); } const std::unordered_map& indexMap() const { @@ -985,7 +972,7 @@ Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) { normal_extent = id->extent(); } - const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id); + const auto& halo = GpuLower::current()->haloInfo()->getRootAxisInfo(id); if (halo.hasHalo()) { auto halo_extent = SimplifyingIrBuilder::addExpr( normal_extent, SimplifyingIrBuilder::create(halo.width())); @@ -1506,7 +1493,8 @@ std::vector Index::getGlobalProducerStridedIndices( // effort which means some domains may be producer's original domains. std::vector> p_id_backup; for (auto entry : c2p_map) { - auto ref_id = ir_utils::caMapExactConcreteId(entry.first); + auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( + entry.first, IdMappingMode::EXACT); auto p_id = entry.second; if (ref_id->getParallelType() == ParallelType::Vectorize) { p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); @@ -1745,7 +1733,8 @@ std::vector Index::getNonGlobalProducerStridedIndices( // effort which means some domains may be the originals. std::vector> p_id_backup; for (auto entry : c2p_index_map) { - auto ref_id = ir_utils::caMapExactConcreteId(entry.first); + auto ref_id = GpuLower::current()->caMap()->getConcreteMappedID( + entry.first, IdMappingMode::EXACT); auto p_id = entry.second; if (ref_id->getParallelType() == ParallelType::Vectorize) { p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType())); @@ -1937,52 +1926,27 @@ std::vector Index::getNonGlobalProducerStridedIndices( return strided_inds; } -std::vector Index::getLinearIndex( +std::vector Index::getLinearLogicalIndex( TensorView* consumer_tv, const std::vector& loops) { - // Use domain guard to ignore the contiguity of - // consumer tv. - TensorDomain* consumer_tv_no_contiguity_domain = nullptr; - auto contiguity_vector = - std::vector(consumer_tv->getMaybeRFactorDomain().size(), true); - if (consumer_tv->hasRFactor()) { - consumer_tv_no_contiguity_domain = IrBuilder::create( - consumer_tv->getRootDomain(), - consumer_tv->getRFactorDomain(), - consumer_tv->domain()->domain(), - contiguity_vector); - } else { - consumer_tv_no_contiguity_domain = IrBuilder::create( - consumer_tv->getRootDomain(), - consumer_tv->domain()->domain(), - contiguity_vector); - } - - ir_utils::TVDomainGuard domain_guard( - consumer_tv, consumer_tv_no_contiguity_domain); - - // TODO: - // More optimization on the underlying tensor layout - // will be done in a follow up. + auto guard = ir_utils::overrideContiguityGuard(consumer_tv, true); return getGlobalConsumerStridedIndices(consumer_tv, loops); } -std::vector Index::getGlobalConsumerStridedIndices( - const TensorView* consumer_tv, +std::vector Index::getPerDimLogicalIndex( + TensorView* consumer_tv, const std::vector& loops) { - FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); - - auto gpu_lower = GpuLower::current(); - - auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); - - auto consumer_indexing = index_from_id_graph.index; + auto guard = ir_utils::overrideContiguityGuard(consumer_tv, false); + IndexFromIdGraph index_from_id_graph = + getTensorIndexFromIdGraph(loops, consumer_tv); + return getRootIndices(consumer_tv, loops, index_from_id_graph); +} +std::vector Index::getStrides(const TensorView* tv) { // Indices should now be mapped onto IterDomains in consumer, so just grab // and use them. - auto root_dom = consumer_tv->getMaybeRFactorDomain(); + auto root_dom = tv->getMaybeRFactorDomain(); - // TODO: Abstract stride logic to reuse with producer indexing std::vector strides( root_dom.size(), GpuLower::current()->kernel()->oneVal()); { @@ -1993,14 +1957,13 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } std::stringstream ss; - ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]"; + ss << "T" << tv->name() << ".stride[" << stride_i++ << "]"; strides[i] = SimplifyingIrBuilder::create(ss.str(), DataType::Int); } } - TORCH_INTERNAL_ASSERT( - root_dom.size() == consumer_tv->domain()->contiguity().size()); + TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size()); Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal(); for (const auto i : c10::irange(root_dom.size())) { auto dim = root_dom.size() - i - 1; @@ -2008,24 +1971,7 @@ std::vector Index::getGlobalConsumerStridedIndices( continue; } - Val* root_ind = nullptr; - if (consumer_indexing.indexMap().find(root_dom[dim]) != - consumer_indexing.indexMap().end()) { - root_ind = consumer_indexing.indexMap().at(root_dom[dim]); - } else if (root_dom[dim]->isBroadcast()) { - root_ind = GpuLower::current()->kernel()->zeroVal(); - } - - TORCH_INTERNAL_ASSERT( - root_ind != nullptr, - "Couldn't find root mapping for ", - consumer_tv->toString(), - " dim: ", - dim, - " id: ", - root_dom[dim]->toString()); - - if (consumer_tv->domain()->contiguity()[dim]) { + if (tv->domain()->contiguity()[dim]) { // If contig, used the stored stride which may be the previous // dimensions stride * previous dimensions size strides[dim] = cur_contig_stride; @@ -2041,12 +1987,18 @@ std::vector Index::getGlobalConsumerStridedIndices( strides[dim], getHaloExtentOfRootAxis(root_dom[dim])); } } + return strides; +} - auto vectorize_shift = - loops.empty() ? nullptr : loops.back()->vectorize_shift(); +std::vector Index::getRootIndices( + const TensorView* tv, + const std::vector& loops, + const IndexFromIdGraph& index_from_id_graph) { + auto gpu_lower = GpuLower::current(); + auto root_dom = tv->getMaybeRFactorDomain(); + auto indexing = index_from_id_graph.index; - // Global striding - std::vector strided_inds( + std::vector root_inds( root_dom.size(), GpuLower::current()->kernel()->zeroVal()); for (const auto i : c10::irange(root_dom.size())) { // See a comment in indexing to root domains in getGlobalProducerIndex. @@ -2057,22 +2009,21 @@ std::vector Index::getGlobalConsumerStridedIndices( } TORCH_INTERNAL_ASSERT( - consumer_indexing.indexMap().find(root_dom[i]) != - consumer_indexing.indexMap().end(), + indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(), "Couldn't find root mapping for ", - consumer_tv->toString(), + tv->toString(), " dim: ", i, " id: ", root_dom[i]->toString()); - auto root_ind = consumer_indexing.indexMap().at(root_dom[i]); + auto root_ind = indexing.indexMap().at(root_dom[i]); // index hoist must be done before the adjustments for halo root_ind = hoistConsumerIndex( root_dom[i], - consumer_tv, - consumer_indexing, + tv, + indexing, index_from_id_graph.resolved_loop_domains, index_from_id_graph.initial_concrete_index_map, loops, @@ -2080,12 +2031,33 @@ std::vector Index::getGlobalConsumerStridedIndices( root_ind = SimplifyingIrBuilder::addExpr( root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i])); + root_inds[i] = root_ind; + } + return root_inds; +} - if (root_ind->isZeroInt()) { +std::vector Index::getGlobalConsumerStridedIndices( + const TensorView* consumer_tv, + const std::vector& loops) { + FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex"); + + auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv); + auto consumer_indexing = index_from_id_graph.index; + auto strides = getStrides(consumer_tv); + auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph); + + // Global striding + auto vectorize_shift = + loops.empty() ? nullptr : loops.back()->vectorize_shift(); + std::vector strided_inds( + root_inds.size(), GpuLower::current()->kernel()->zeroVal()); + for (const auto i : c10::irange(root_inds.size())) { + if (root_inds[i]->isZeroInt()) { continue; } else { - auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]); - if (i == root_dom.size() - 1 && vectorize_shift != nullptr) { + auto strided_ind = + SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]); + if (i == strides.size() - 1 && vectorize_shift != nullptr) { strided_inds[i] = SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift); } else { @@ -2354,103 +2326,71 @@ std::vector getPredicateContigIds( const auto& consumer_root_domain = consumer_tv->getRootDomain(); - std::vector contiguous_ids = consumer_root_domain; - - if (contiguous_ids.empty()) { + if (consumer_root_domain.empty()) { return std::vector(); } - // If root IDs are partial, i.e., start is non-zero and stop is not - // equal to extent, predication can't be done with merged domains as - // start and stop information is only available with root - // domains. Similarly, merged domains don't have enough information - // about halo to do correct predication, so they must be excluded. - std::unordered_set excluded_ids; + std::unordered_map concrete_index_map; + for (auto entry : consumer_index_map) { + auto c_id = gpu_lower->caMap()->getConcreteMappedID( + entry.first, IdMappingMode::EXACT); + concrete_index_map[c_id] = entry.second; + } - for (auto consumer_root_id : consumer_root_domain) { - if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) { - excluded_ids.insert(consumer_root_id); - continue; - } - if (consumer_root_id->maybePartial()) { - excluded_ids.insert(consumer_root_id); - continue; - } - // When consumer_root_id is a broadcast domain, do not allow contig - // predication as the merged output is not mapped with the - // reference unless the concrete domain is also a broadcast - // domain. - if (consumer_root_id->isBroadcast() && - !GpuLower::current() - ->caMap() - ->getConcreteMappedID(consumer_root_id, IdMappingMode::PERMISSIVE) - ->isBroadcast()) { - excluded_ids.insert(consumer_root_id); + std::vector predicate_contiguity(consumer_root_domain.size(), true); + std::unordered_set final_ids; + for (auto root_i : c10::irange(predicate_contiguity.size())) { + auto root_id = consumer_root_domain[root_i]; + if (root_id->maybePartial()) { + final_ids.insert(root_id); continue; } // Shifted or gathered axes need to be predicated at the root domain auto shift_expr = dynamic_cast(consumer_tv->definition()); auto gather_expr = dynamic_cast(consumer_tv->definition()); - if (shift_expr == nullptr && gather_expr == nullptr) { - continue; - } - auto consumer_root_pos = consumer_tv->domain()->rootPosOf(consumer_root_id); - if ((shift_expr && shift_expr->offset(consumer_root_pos) != 0) || - (gather_expr && consumer_root_pos < gather_expr->windowShape().size() && - gather_expr->windowShape().at(consumer_root_pos) != 1)) { - excluded_ids.insert(consumer_root_id); + if ((shift_expr && shift_expr->offset(root_i) != 0) || + (gather_expr && root_i < gather_expr->windowShape().size() && + gather_expr->windowShape().at(root_i) != 1)) { + final_ids.insert(root_id); } } - // Run through iteration domain history - auto exprs = StmtSort::getExprs( - consumer_tv->fusion(), - {consumer_tv->domain()->domain().begin(), - consumer_tv->domain()->domain().end()}); + ContigIDs contig_finder( + consumer_tv->domain()->domain(), + consumer_root_domain, + predicate_contiguity, + final_ids, + concrete_index_map, + GpuLower::current()->divisbleSplitSet(), + GpuLower::current()->caMap(), + GpuLower::current()->haloInfo(), + GpuLower::current()->concretizedBroadcastDomains(), + {}, + false, + true); - for (auto expr : exprs) { - // If not a merge, output is not contiguous - if (expr->isA()) { - auto merge = expr->as(); - auto inner_contig_it = std::find( - contiguous_ids.begin(), contiguous_ids.end(), merge->inner()); - auto outer_contig_it = std::find( - contiguous_ids.begin(), contiguous_ids.end(), merge->outer()); + std::vector contig_id_infos; + std::unordered_set covered_roots; - if (excluded_ids.count(merge->inner()) > 0 || - excluded_ids.count(merge->outer()) > 0) { - continue; - } + // Create entries and return them + for (auto root_id : consumer_root_domain) { + if (covered_roots.count(root_id) > 0) { + continue; + } - // Do not try to predicate the merge output domain if the output - // domain has not a predicate that is mapped from the reference. - // See FusionContigPredicate_CUDA for a concrete example. - if (consumer_index_map.find(merge->out()) == consumer_index_map.end()) { - continue; - } + auto contig_id_it = contig_finder.rootToIndexedID().find(root_id); - if (inner_contig_it != contiguous_ids.end() && - outer_contig_it != contiguous_ids.end()) { - // If inner and outer are contiguous, out must be contiguous. Remove - // inner and outer, and add out. - contiguous_ids.erase(outer_contig_it); - contiguous_ids.erase(std::find( - contiguous_ids.begin(), contiguous_ids.end(), merge->inner())); - contiguous_ids.emplace_back(merge->out()); - } - } - } + TORCH_INTERNAL_ASSERT( + contig_id_it != contig_finder.rootToIndexedID().end(), + "Error in predicate contiguity analysis, missing index for root ", + root_id->toString()); - std::vector contig_id_infos; + auto contig_id = contig_id_it->second; - // Create entries and return them - for (auto contig_id : contiguous_ids) { // Pick inputs from the starting domains, i.e., // reference_predicated_root_domain. - auto contig_root_vals = IterVisitor::getInputsTo( - {contig_id}, - {consumer_root_domain.begin(), consumer_root_domain.end()}); - auto contig_root_ids = ir_utils::filterByType(contig_root_vals); + auto contig_root_ids = contig_finder.indexedRootIDs(contig_id); + covered_roots.insert(contig_root_ids.begin(), contig_root_ids.end()); PredicateDomainInfo contig_id_info; contig_id_info.id = contig_id; contig_id_info.covered_ids = std::unordered_set( @@ -2504,7 +2444,7 @@ int getUnswitchStopOffset( const auto gpu_lower = GpuLower::current(); AxisHaloInfo halo_info = - gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id); + gpu_lower->haloInfo()->getRootAxisInfo(consumer_root_id); // If the consumer root domain to predicate does not have halo, no // adjustment is required. @@ -2528,7 +2468,7 @@ int getUnswitchStopOffset( unswitch_it, consumer_tv->domain()->domain().end(), [&gpu_lower, &consumer_root_id](auto leaf_id) { - return gpu_lower->haloInfo().isHaloInherited( + return gpu_lower->haloInfo()->isHaloInherited( consumer_root_id, leaf_id); })) { return halo_info.width(); @@ -2686,7 +2626,8 @@ std::pair getStartAndStopLimitOffsets( Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset()); if (!non_divisible_pred) { - AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id); + AxisHaloInfo halo_info = + gpu_lower->haloInfo()->getRootAxisInfo(consumer_id); // Below, "left" and "right" halo mean halo at offset zero and // axis extent, respectively. @@ -2710,8 +2651,8 @@ std::pair getStartAndStopLimitOffsets( // that it is less than the extent of the predicated ID + // halo. Note that getRootAxisInfo doesn't work since consumer_id // isn't a root domain. - if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) { - auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id); + if (gpu_lower->haloInfo()->hasHaloWidth(consumer_id)) { + auto halo = gpu_lower->haloInfo()->getHaloWidth(consumer_id); stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo); } } @@ -2858,8 +2799,8 @@ bool canOmitStopPredicate( // to be predicated, not its merged contig id even if it exists. So, // if contig_id does not have root axis info, contig_id is // guaranteed to have no halo. - auto halo_ext = gpu_lower->haloInfo().hasRootAxisInfo(contig_id) - ? gpu_lower->haloInfo().getRootAxisInfo(contig_id).width() + auto halo_ext = gpu_lower->haloInfo()->hasRootAxisInfo(contig_id) + ? gpu_lower->haloInfo()->getRootAxisInfo(contig_id).width() : 0; if (halo_ext + stop_offset_val.value() > 0) { @@ -2977,14 +2918,6 @@ std::vector Index::getReferenceRootPredicates( auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv); - // Indexing is done without considering contig merging. Actual - // predicated domains are determined by considering contiguity. - const ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), - {}); - // Generate start and stop indexing from idgraph. // // Both start and stop positions may need to be predicated. Indexing diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index 43cde710fdfc4..9a94ee94ac09c 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include @@ -17,40 +16,40 @@ * indices (based on input indices) that match the root dimension. * * For example with GLOBAL tensor: - * TV[I, J] - * TV[Io, Ii{4}, J] = TV.split(I, factor=4) + * TV[I, K] + * TV[Io, Ii{4}, K] = TV.split(I, factor=4) * ALLOC: NONE * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k} - * FLATTENED_INDEX: {i * 4 + j, k} -> {i * 4 + j * J + k} + * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k} * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I * * * For example with SHARED tensor: * - * global_TV[I, J] - * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4) + * global_TV[I, K] + * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4) * smem_TV.compute_at(global_TV, 1) * global_TV.parallelize(1, threadIDx.x) * - * ALLOC: alloc(smem_TV, 4 x J) + * ALLOC: alloc(smem_TV, 4 x K) * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k} - * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {threadIdx.x * 4 + j * J + k} + * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k} * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if * global * * * For example with LOCAL tensor: - * global_TV[I, J, K] - * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4) - * reg_TV.compute_at(global_TV, 1) + * global_TV[I, K, L] + * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4) + * reg_TV.compute_at(global_TV, 2) * global_TV.parallelize(1, threadIDx.x) * global_TV{i, j, k, l} -> { i * 4 + j, k, l } - * global_TV{ i * 4 + j, k, l } -> { i * 4 + j * J * K + k * K + l} + * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l} * - * ALLOC: alloc(reg_TV, J x K) + * ALLOC: alloc(reg_TV, K x L) * INDEX: {k, l} -> {k, l} - * FLATTENED_INDEX: {k, l} -> {k * J + l} - * PREDICATE: i * 4 + j < I && k < J && l < K -> // Same as if global + * FLATTENED_INDEX: {k, l} -> {k * L + l} + * PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global * * These indices can then be flattened later based on strides. */ @@ -62,6 +61,7 @@ namespace cuda { class ContigIDs; class LoopIndexing; +struct IndexFromIdGraph; class IndexCompute : public BackwardVisitor { protected: @@ -134,9 +134,8 @@ class IndexCompute : public BackwardVisitor { // if there's an option std::unordered_set preferred_paths_; - // Map from IterDomains to halo-extended extents in corresponding - // reference tensor - std::unordered_map reference_halo_extent_map_; + // Map from IterDomains to halo-extended extents + std::unordered_map halo_extent_map_; // Temporary flag which tells IndexCompute to use concrete id's from the exact // map rather than the actual IDs used in the ID expressions. @@ -188,7 +187,7 @@ class IndexCompute : public BackwardVisitor { std::unordered_set zero_domains, std::unordered_set _zero_merged_in, std::unordered_set preferred_paths = {}, - std::unordered_map reference_halo_extent_map = {}); + std::unordered_map halo_extent_map = {}); IndexCompute( const TensorDomain* _td, @@ -198,7 +197,7 @@ class IndexCompute : public BackwardVisitor { std::unordered_set _zero_merged_in, const ContigIDs& contig_finder, std::unordered_set preferred_paths = {}, - std::unordered_map reference_halo_extent_map = {}); + std::unordered_map halo_extent_map = {}); // Entry point used for using concrete id based traversal. This traversal is // assumed to start at leaf IDs provided by initial_index_map. @@ -213,9 +212,7 @@ class IndexCompute : public BackwardVisitor { IndexCompute updateIndexCompute( const TensorDomain* new_td, const std::unordered_map& id_map, - const ContigIDs& contig_finder, - const std::unordered_map& reference_halo_extent_map = - {}) const; + const ContigIDs& contig_finder) const; // Interface to run index traversal through loop indexing analysis result to // be used with the entry point for concrete id based traversal. @@ -331,6 +328,15 @@ class Index { const TensorView* consumer, const std::vector& loops); + // get the strides of a tensor used for the index lowering + static std::vector getStrides(const TensorView* tv); + + // get the root indices of a tensor used for the index lowering + static std::vector getRootIndices( + const TensorView* tv, + const std::vector& loops, + const IndexFromIdGraph& index_from_id_graph); + public: // Indexing functions // Consumer = Producer @@ -363,19 +369,28 @@ class Index { const TensorView* consumer, const std::vector& loops); - //! Returns a vector of strided indices mapped onto the (rfactor) + //! Returns the logical index linearized from a multi-dimension address into a + //! linear memory address a consumer tensor. The returned index is intended to + //! be used for the computation of some tensor factories, such as: arange and + //! rand (for Philox pseudo random sequences) + static std::vector getLinearLogicalIndex( + TensorView* consumer_tv, + const std::vector& loops); + + //! Returns a vector of logical indices mapped onto the (rfactor) //! root domain of a consumer tensor. The returned index is intended - //! to be used to index into arange or Philox pseudo random sequences - static std::vector getLinearIndex( + //! to be used for the computation of some tensor factories, such as: + //! eye + static std::vector getPerDimLogicalIndex( TensorView* consumer_tv, const std::vector& loops); //! Take a consumer tensorview and loop nest and generates predicates //! associated with the concrete roots of the loop nest. Returns a list of - //! predicates, and a list of concrete roots they're associated with. It is - //! assumed that no predicate is required if index[i] is an index directly - //! from a for loop. This will not catch all cases if we actually have static - //! size information for example: + //! predicates, and a list of concrete roots they're associated with. It + //! is assumed that no predicate is required if index[i] is an index + //! directly from a for loop. This will not catch all cases if we actually + //! have static size information for example: //! //! TV[I].split(4) //! would produce the code: @@ -384,14 +399,14 @@ class Index { //! if( i * 4 + j < TV.size(0)) //! TV[i * 4 + j]... //! - //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't need - //! the predicate. This will be caught by canOmitPredicate in the predicate - //! lowering + //! However if we had TV.size[0] = 16 at "compile time" then we wouldn't + //! need the predicate. This will be caught by canOmitPredicate in the + //! predicate lowering //! - //! unswitch_or_vec_loop is the for loop to start the unswitch like predicate, - //! this is not a bool value as if we have an unswitch loop with a vectorized - //! loop inside, we only want to base the "unswitch" like predicate on the - //! vectorized loop. + //! unswitch_or_vec_loop is the for loop to start the unswitch like + //! predicate, this is not a bool value as if we have an unswitch loop + //! with a vectorized loop inside, we only want to base the "unswitch" + //! like predicate on the vectorized loop. static std::vector getReferenceRootPredicates( TensorView* consumer_tv, const std::vector& loops, diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp b/torch/csrc/jit/codegen/cuda/inline_propagator.cpp deleted file mode 100644 index a5edae083a32a..0000000000000 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.cpp +++ /dev/null @@ -1,385 +0,0 @@ -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -MaxPosCalculator::MaxPosCalculator( - ComputeAtMode mode, - std::unordered_set uninlinable_ids) - : mode_(mode), uninlinable_ids_(std::move(uninlinable_ids)) { - buildUnmappableDims(); -} - -void MaxPosCalculator::buildUnmappableDims() { - ComputeAtRootDomainMap root_map; - root_map.build(); - - auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); - for (auto tv : all_tvs) { - auto consumers = ir_utils::consumerTvsOf(tv); - for (auto consumer : consumers) { - // Grab dimensions in producer and consumer that are mappable to eachother - // based on the computeAtRootDomainMap. This will tell us which dimensions - // can be inlined based on avoiding trying to inline non-trivial - // reduction structures. - auto mappable_roots = - root_map.getMappableDims(tv->domain(), consumer->domain()); - for (auto tv_root_id : tv->getMaybeRFactorDomain()) { - if (mappable_roots.find(tv_root_id) == mappable_roots.end() && - !tv_root_id->isTrivialReduction()) { - unmappable_dims_.emplace(tv_root_id); - } - } - } - } -} - -bool MaxPosCalculator::isAllowedID( - IterDomain* id, - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const { - bool allowed = true; - - if (!allow_reduction) { - allowed = allowed && !id->isReduction(); - } - - if (uninlinable_ids_.count(id)) { - return false; - } - - if (!allow_vectorize) { - // Avoid inlining if marked as Vectorize or Group. In the case of - // BestEffort and MostInlined modes, avoid Unroll as well. - bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || - id->getParallelType() == ParallelType::Group || - ((mode_ == ComputeAtMode::BestEffort || - mode_ == ComputeAtMode::MostInlined) && - id->getParallelType() == ParallelType::Unroll); - allowed = allowed && !is_vectorize; - } - - if (!allow_unmappable) { - auto root_dom = tv->getMaybeRFactorDomain(); - std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); - auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); - bool is_unmappable = false; - for (auto val : all_vals) { - auto id = val->as(); - if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { - is_unmappable = true; - break; - } - } - allowed = allowed && !is_unmappable; - } - - return allowed; -} - -size_t MaxPosCalculator::getMaxPosSelf( - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const { - auto dom = tv->domain()->domain(); - auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { - return !isAllowedID( - id, tv, allow_reduction, allow_vectorize, allow_unmappable); - }); - return std::distance(dom.begin(), iter); -} - -// Return the max position in producer that can be inlined to consumer -// Cannot inline: -// Vectorized dimensions in consumer -// Unrolled dimensions in consumer -size_t MaxPosCalculator::getMaxProducerPosFromConsumer( - TensorView* producer, - TensorView* consumer) const { - auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); - auto replay_CasP = - BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); - auto p2c_replay_map = replay_CasP.getReplay(); - - for (size_t producer_pos = 0; producer_pos < producer->nDims(); - producer_pos++) { - // If the producer position is mismatching with the consumer, then we can - // not inline into this position, otherwise the max producer position of - // the consumer will become invalid and expression sort will fail. - if (TransformReplay::getMatchedLeafPosWithoutReplayCasP( - consumer, producer, producer_pos + 1) < 0) { - return producer_pos; - } - auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); - if (map_it != p2c_replay_map.end()) { - auto c_id = map_it->second; - if (!isAllowedID(c_id, consumer, true, false, true)) { - return producer_pos; - } - } - } - return producer->nDims(); -} - -size_t InlinePropagator::getMaxPosAll(TensorView* tv, bool check_siblings) { - auto max_pos = max_pos_calc.getMaxPosSelf(tv, false, false, false); - for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - max_pos = std::min( - max_pos, max_pos_calc.getMaxProducerPosFromConsumer(tv, consumer_tv)); - } - if (check_siblings) { - for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { - max_pos = std::min(max_pos, getMaxPosAll(sibling_tv, false)); - } - } - return max_pos; -} - -void InlinePropagator::setCAPos(TensorView* tv) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - size_t pos = mapped_reference_pos_.at(tv); - if (debug) { - std::cout << " Setting CA pos of " << tv << ":" << std::endl; - std::cout << " mapped position: " << pos << std::endl; - } - if ((selected_.empty() || selected_.count(tv)) && !tv->isFusionInput()) { - auto max_pos = getMaxPosAll(tv); - if (debug) { - std::cout << " max inlinable position: " << max_pos << std::endl; - } - if (mode_ == ComputeAtMode::Standard) { - TORCH_INTERNAL_ASSERT( - pos <= max_pos, - "Invalid compute at position detected in InlinePropagator when trying to set the CA position of: ", - tv, - " to ", - pos, - ", max position that's allowed is ", - max_pos); - } else if (mode_ == ComputeAtMode::BestEffort) { - pos = std::min(pos, max_pos); - } else { - pos = max_pos; - } - // hoist inner most broadcast - while (pos > 0 && tv->axis(pos - 1)->isBroadcast()) { - pos--; - } - auto current_ca_pos = tv->getComputeAtPosition(); - if (debug) { - std::cout << " current CA position: " << current_ca_pos << std::endl; - } - if (pos > current_ca_pos) { - if (debug) { - std::cout << " new CA position: " << pos << std::endl; - } - tv->setComputeAt(pos); - for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { - needs_update_max_producer_.insert(consumer_tv); - } - } else if (debug) { - std::cout << " CA position not changed" << std::endl; - } - } else if (debug) { - std::cout << " tensor not selected, skip" << std::endl; - } -} - -InlinePropagator::InlinePropagator( - TensorView* reference, - int64_t reference_pos, - ComputeAtMode mode, - std::unordered_set selected, - std::unordered_set uninlinable_ids) - : max_pos_calc(mode, std::move(uninlinable_ids)), - selected_(std::move(selected)), - reference_(reference), - mode_(mode) { - if (reference_pos < 0) { - reference_pos += int64_t(reference->nDims()) + 1; - } - TORCH_INTERNAL_ASSERT( - reference_pos >= 0 && reference_pos <= reference->nDims(), - "Invalid computeAt axis, received ", - reference_pos, - " but should be > -", - reference->nDims(), - " and <= ", - reference->nDims(), - "."); - reference_pos_ = reference_pos; -} - -void InlinePropagator::setUp() { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - mapped_reference_pos_[reference_] = reference_pos_; - if (debug) { - std::cout << "InlinePropagator::setUp" << std::endl; - std::cout << " reference: " << reference_ << " @ " << reference_pos_ - << std::endl; - } - setCAPos(reference_); -} - -namespace { - -// Try to find the aligned position on consumer's domain corresponding to the -// compute at position of producer domain. Used in InlinePropagator pass only. -// No checking on actual producer-consumer relationship. -unsigned int getConsumerPosAlignedToProducerCA( - TensorView* consumer, - TensorView* producer) { - // Locate consumer's position that aligns with - // the producer's new compute at axis. We need broadcast axes forwarded so we - // need to replay PasC as CasP will not forward braodcast dims. For example - // if we have: - // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) - // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will - // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to - // NVFuserTest.FusionComplexBCast1_CUDA - - auto disjoint_sets = - BestEffortReplay::replayPasC( - producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)) - .getDisjointSets(); - - // Find the innermost position of consumer that has - // been mapped within the producer ca axis. - unsigned int consumer_pos = consumer->nDims(); - while (consumer_pos > 0) { - auto consumer_id = consumer->axis((int)consumer_pos - 1); - auto p_dom = producer->domain()->domain(); - if (std::any_of( - p_dom.begin(), - p_dom.begin() + producer->getComputeAtPosition(), - [&consumer_id, &disjoint_sets](IterDomain* p_id) { - return disjoint_sets.permissiveAreMapped(consumer_id, p_id); - })) { - break; - } - consumer_pos--; - } - - return consumer_pos; -} - -} // namespace - -void InlinePropagator::tearDown() { - for (auto consumer : needs_update_max_producer_) { - unsigned int consumer_pos = 0; - for (auto producer : ir_utils::producerTvsOf(consumer)) { - consumer_pos = std::max( - consumer_pos, getConsumerPosAlignedToProducerCA(consumer, producer)); - } - consumer->setMaxProducer(consumer_pos); - } -} - -void InlinePropagator::propagateC2P(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateC2P" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - int from_pos = mapped_reference_pos_.at(from); - auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); - if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from consumer ", - from, - " at ", - from_pos, - " to producer ", - to, - " because this would require replay."); - } else { - // For MostInlined and BestEffort inline propagation, we allow the DAG to - // be not replayed fully consistently. For such case, we just don't inline - // into the mismatched dimension. - while (to_pos < 0) { - from_pos--; - to_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC( - to, from, from_pos); - } - } - mapped_reference_pos_[to] = to_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -void InlinePropagator::propagateP2C(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateP2C" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - int from_pos = mapped_reference_pos_.at(from); - auto to_pos = - TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); - if (mode_ == ComputeAtMode::Standard) { - TORCH_CHECK( - to_pos >= 0, - "Unable to propagate CA position from producer ", - from, - " at ", - from_pos, - " to consumer ", - to, - " because this would require replay."); - } else { - // For MostInlined and BestEffort inline propagation, we allow the DAG to - // be not replayed fully consistently. For such case, we just don't inline - // into the mismatched dimension. - while (to_pos < 0) { - from_pos--; - to_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP( - to, from, from_pos); - } - } - mapped_reference_pos_[to] = to_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -void InlinePropagator::propagateSibling(TensorView* from, TensorView* to) { - bool debug = isDebugDumpEnabled(DebugDumpOption::InlinePropagator); - if (debug) { - std::cout << "InlinePropagator::propagateSibling" << std::endl; - std::cout << " from: " << from << std::endl; - std::cout << " to: " << to << std::endl; - } - // Step 1: find mapped_reference_pos_[to] - auto from_pos = mapped_reference_pos_.at(from); - TORCH_CHECK( - TransformReplay::fullSelfMatching(to, from), - "Unable to propagate CA position from ", - from, - " to sibling ", - to, - " because this would require replay."); - mapped_reference_pos_[to] = from_pos; - // Step 2: set CA position of `to` - setCAPos(to); -} - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inline_propagator.h b/torch/csrc/jit/codegen/cuda/inline_propagator.h deleted file mode 100644 index d1bdeebd06d63..0000000000000 --- a/torch/csrc/jit/codegen/cuda/inline_propagator.h +++ /dev/null @@ -1,118 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -class TORCH_CUDA_CU_API MaxPosCalculator { - ComputeAtMode mode_ = ComputeAtMode::Standard; - - // Root domains in producer that's unmappable to any of its consumers - std::unordered_set unmappable_dims_; - - // User set IterDomains to not inline, used in schedulers to avoid inlining - // trivial reductions - std::unordered_set uninlinable_ids_; - - // Iterate through all TVs and collect the dimensions of each TV that don't - // map to all its consumer TVs. - void buildUnmappableDims(); - - // Utility function to return if an id of tv is a valid iter domain to inline - // within. This is used in getMaxPos{PasC,CasP}. Different variations of the - // bool values are used if checking max position of PasC, CasP, or checking - // for a max "self" position. - bool isAllowedID( - IterDomain* id, - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const; - - public: - // Returns the position at which tv can be inlined within. - size_t getMaxPosSelf( - TensorView* tv, - bool allow_reduction, - bool allow_vectorize, - bool allow_unmappable) const; - - // Returns the maximum position producer can be inlined based on consumer - // given the set ComputeAtMode - size_t getMaxProducerPosFromConsumer( - TensorView* producer, - TensorView* consumer) const; - - MaxPosCalculator( - ComputeAtMode mode, - std::unordered_set uninlinable_ids = {}); -}; - -// Propagate inline position to the `selected` tensors in the DAG. If `selected` -// is not specified or empty, then propagate to the entire DAG. -class TORCH_CUDA_CU_API InlinePropagator - : public MaxInfoSpanningTree::Propagator { - // Checks producers and consumers to see what the maximum position in tv is - // that can be shared across both directions. - size_t getMaxPosAll(TensorView* tv, bool check_siblings = true); - - // We use mapped_reference_pos_ to keep track of the outer axes information of - // the reference tensor. That is, mapped_reference_pos_[tv] answers the - // question "What outer axes in tv are shared with the specified reference - // tensor's outer axes?". However, when we actually set the CA position of tv, - // we might not want to set it as mapped_reference_pos_[tv] because because we - // don't want to inline certain things (such as vectorized dimensions, inner - // most broadcasting, etc.). - std::unordered_map mapped_reference_pos_; - - // Actually set the computeAt position. This does not necessarily equal to - // mapped_reference_pos_[tv] because we don't want to inline certain things. - void setCAPos(TensorView* tv); - - const MaxPosCalculator max_pos_calc; - std::unordered_set selected_; - std::unordered_set needs_update_max_producer_; - TensorView* reference_; - size_t reference_pos_; - ComputeAtMode mode_ = ComputeAtMode::Standard; - - public: - InlinePropagator( - TensorView* reference, - int64_t reference_pos, - ComputeAtMode mode = ComputeAtMode::Standard, - std::unordered_set selected = {}, - std::unordered_set uninlinable_ids = {}); - - InlinePropagator( - TensorView* reference, - int64_t reference_pos, - std::unordered_set selected) - : InlinePropagator( - reference, - reference_pos, - ComputeAtMode::Standard, - selected) {} - - ~InlinePropagator() = default; - - // Actually propagate the transformations for the inlining pass. Uses the - // functions above to figure out what position to do the propagation at. - virtual void setUp() override; - virtual void propagateC2P(TensorView* from, TensorView* to) override; - virtual void propagateP2C(TensorView* from, TensorView* to) override; - virtual void propagateSibling(TensorView* from, TensorView* to) override; - virtual void tearDown() override; -}; - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inlining.cpp b/torch/csrc/jit/codegen/cuda/inlining.cpp new file mode 100644 index 0000000000000..da6d229c68f8b --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inlining.cpp @@ -0,0 +1,306 @@ +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +MaxPosCalculator::MaxPosCalculator( + const std::unordered_set& uninlinable_ids) + : uninlinable_ids_(uninlinable_ids) { + buildUnmappableDims(); +} + +void MaxPosCalculator::buildUnmappableDims() { + ComputeAtRootDomainMap root_map; + root_map.build(); + auto all_tvs = ir_utils::allTvs(FusionGuard::getCurFusion()); + for (auto tv : all_tvs) { + auto consumers = ir_utils::consumerTvsOf(tv); + for (auto consumer : consumers) { + // Grab dimensions in producer and consumer that are mappable to eachother + // based on the computeAtRootDomainMap. This will tell us which dimensions + // can be inlined based on avoiding trying to inline non-trivial + // reduction structures. + auto mappable_roots = + root_map.getMappableDims(tv->domain(), consumer->domain()); + for (auto tv_root_id : tv->getMaybeRFactorDomain()) { + if (mappable_roots.find(tv_root_id) == mappable_roots.end() && + !tv_root_id->isTrivialReduction()) { + unmappable_dims_.emplace(tv_root_id); + } + } + } + } +} + +bool MaxPosCalculator::isAllowedID( + IterDomain* id, + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + bool allowed = true; + + if (!allow_reduction) { + allowed = allowed && !id->isReduction(); + } + + if (uninlinable_ids_.count(id)) { + return false; + } + + if (!allow_vectorize) { + // Avoid inlining if marked as Vectorize or Group. In the case of + // BestEffort and MostInlined modes, avoid Unroll as well. + bool is_vectorize = isParallelTypeVectorize(id->getParallelType()) || + id->getParallelType() == ParallelType::Group || + (best_effort && id->getParallelType() == ParallelType::Unroll); + allowed = allowed && !is_vectorize; + } + + if (!allow_unmappable) { + auto root_dom = tv->getMaybeRFactorDomain(); + std::unordered_set root_dom_set(root_dom.begin(), root_dom.end()); + auto all_vals = DependencyCheck::getAllValsBetween(root_dom_set, {id}); + bool is_unmappable = false; + for (auto val : all_vals) { + auto id = val->as(); + if (root_dom_set.count(val) > 0 && unmappable_dims_.count(id) > 0) { + is_unmappable = true; + break; + } + } + allowed = allowed && !is_unmappable; + } + + return allowed; +} + +size_t MaxPosCalculator::getMaxPosSelf( + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const { + auto dom = tv->domain()->domain(); + auto iter = std::find_if(dom.begin(), dom.end(), [=](IterDomain* id) { + return !isAllowedID( + id, + tv, + best_effort, + allow_reduction, + allow_vectorize, + allow_unmappable); + }); + return std::distance(dom.begin(), iter); +} + +// Return the max position in producer that can be inlined to consumer +// Cannot inline: +// Vectorized dimensions in consumer +// Unrolled dimensions in consumer +size_t MaxPosCalculator::getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer, + bool best_effort) const { + auto pairwise_root_map = PairwiseRootDomainMap(producer, consumer); + auto replay_CasP = + BestEffortReplay::replayCasP(consumer, producer, -1, pairwise_root_map); + auto p2c_replay_map = replay_CasP.getReplay(); + + for (size_t producer_pos = 0; producer_pos < producer->nDims(); + producer_pos++) { + // If the producer position is mismatching with the consumer, then we can + // not inline into this position, otherwise the max producer position of + // the consumer will become invalid and expression sort will fail. + if (TransformReplay::getMatchedLeafPosWithoutReplayCasP( + consumer, producer, producer_pos + 1) < 0) { + return producer_pos; + } + auto map_it = p2c_replay_map.find(producer->axis(producer_pos)); + if (map_it != p2c_replay_map.end()) { + auto c_id = map_it->second; + if (!isAllowedID(c_id, consumer, best_effort, true, false, true)) { + return producer_pos; + } + } + } + return producer->nDims(); +} + +size_t MaxPosCalculator::getMaxPosAll( + TensorView* tv, + bool best_effort, + bool check_siblings) { + auto max_pos = getMaxPosSelf(tv, best_effort, false, false, false); + for (auto consumer_tv : ir_utils::consumerTvsOf(tv)) { + max_pos = std::min( + max_pos, getMaxProducerPosFromConsumer(tv, consumer_tv, best_effort)); + } + if (check_siblings) { + for (auto sibling_tv : ir_utils::siblingTvsOf(tv)) { + max_pos = std::min( + max_pos, getMaxPosAll(sibling_tv, best_effort, false)); + } + } + return max_pos; +} + +void inlineMost(const std::unordered_set& uninlinable_ids) { + inlineMost(ir_utils::allTvs(FusionGuard::getCurFusion()), uninlinable_ids); +} + +void inlineMost( + const std::vector& tvs, + const std::unordered_set& uninlinable_ids) { + if (tvs.empty()) { + return; + } + MaxPosCalculator calc(uninlinable_ids); + for (auto tv : tvs) { + tv->inlineAt(-1, true, &calc); + } +} + +void inlineMost( + const std::unordered_set& tvs, + const std::unordered_set& uninlinable_ids) { + if (tvs.empty()) { + return; + } + MaxPosCalculator calc(uninlinable_ids); + for (auto tv : tvs) { + tv->inlineAt(-1, true, &calc); + } +} + +namespace { + +// Find the positions of `selected` tensors that is mapped to the given position +// in the reference tensor. +class FindMappedPositions : public MaxInfoSpanningTree::Propagator { + std::unordered_map& output_; + + public: + FindMappedPositions( + std::unordered_map& output, + TensorView* reference, + int64_t reference_pos); + + ~FindMappedPositions() = default; + + virtual void propagateC2P(TensorView* from, TensorView* to) override; + virtual void propagateP2C(TensorView* from, TensorView* to) override; + virtual void propagateSibling(TensorView* from, TensorView* to) override; +}; + +FindMappedPositions::FindMappedPositions( + std::unordered_map& output, + TensorView* reference, + int64_t reference_pos) + : output_(output) { + if (reference_pos < 0) { + reference_pos += int64_t(reference->nDims()) + 1; + } + TORCH_CHECK( + reference_pos >= 0 && reference_pos <= reference->nDims(), + "Invalid axis received ", + reference_pos, + " but should be > -", + reference->nDims(), + " and <= ", + reference->nDims(), + "."); + output_[reference] = reference_pos; +} + +void FindMappedPositions::propagateC2P(TensorView* from, TensorView* to) { + int from_pos = output_.at(from); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + // If there is no matching position found, we compute the highest matched + // position as the closest approximation + while (to_pos < 0) { + from_pos--; + to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos); + } + output_[to] = to_pos; +} + +void FindMappedPositions::propagateP2C(TensorView* from, TensorView* to) { + int from_pos = output_.at(from); + auto to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + // If there is no matching position found, we compute the highest matched + // position as the closest approximation + while (to_pos < 0) { + from_pos--; + to_pos = + TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos); + } + output_[to] = to_pos; +} + +void FindMappedPositions::propagateSibling(TensorView* from, TensorView* to) { + auto from_pos = output_.at(from); + TORCH_CHECK( + TransformReplay::fullSelfMatching(to, from), + "Transformations in siblings ", + from, + " and ", + to, + " does not match with each other."); + output_[to] = from_pos; +} + +std::unordered_map getPositionsMappedTo( + TensorView* reference_tv, + int64_t reference_pos) { + std::unordered_map mapped_positions; + MaxRootDomainInfoSpanningTree tree(reference_tv, reference_pos); + FindMappedPositions propagator(mapped_positions, reference_tv, reference_pos); + tree.traverse(&propagator); + return mapped_positions; +} + +} // namespace + +void inlineAllAt( + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort, + const std::unordered_set& uninlinable_ids) { + auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); + MaxPosCalculator calc(uninlinable_ids); + for (auto pair : mapped_positions) { + pair.first->inlineAt(pair.second, best_effort, &calc); + } +} + +void inlineSelectedAt( + const std::unordered_set& selected, + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort, + const std::unordered_set& uninlinable_ids) { + auto mapped_positions = getPositionsMappedTo(reference_tv, reference_pos); + MaxPosCalculator calc(uninlinable_ids); + for (auto pair : mapped_positions) { + if (selected.count(pair.first) > 0) { + pair.first->inlineAt(pair.second, best_effort, &calc); + } + } +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/inlining.h b/torch/csrc/jit/codegen/cuda/inlining.h new file mode 100644 index 0000000000000..3b15eb23f9877 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/inlining.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +class MaxPosCalculator { + // Root domains in producer that's unmappable to any of its consumers + std::unordered_set unmappable_dims_; + + // User set IterDomains to not inline, used in schedulers to avoid inlining + // trivial reductions + std::unordered_set uninlinable_ids_; + + // Iterate through all TVs and collect the dimensions of each TV that don't + // map to all its consumer TVs. + void buildUnmappableDims(); + + // Utility function to return if an id of tv is a valid iter domain to inline + // within. This is used in getMaxPos{PasC,CasP}. Different variations of the + // bool values are used if checking max position of PasC, CasP, or checking + // for a max "self" position. + bool isAllowedID( + IterDomain* id, + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + public: + // Returns the position at which tv can be inlined within. + size_t getMaxPosSelf( + TensorView* tv, + bool best_effort, + bool allow_reduction, + bool allow_vectorize, + bool allow_unmappable) const; + + // Returns the maximum position producer can be inlined based on consumer + // given the set ComputeAtMode + size_t getMaxProducerPosFromConsumer( + TensorView* producer, + TensorView* consumer, + bool best_effort) const; + + // Checks producers, consumers, and siblings to see what the maximum position + // in tv is that can be shared across both directions. + size_t getMaxPosAll( + TensorView* tv, + bool best_effort = false, + bool check_siblings = true); + + MaxPosCalculator(const std::unordered_set& uninlinable_ids = {}); +}; + +// Inline to the right most allowed position for all tensors in the current +// fusion. +TORCH_CUDA_CU_API void inlineMost( + const std::unordered_set& uninlinable_ids = {}); +// Inline to the right most allowed position for the selected tensors in the +// current fusion. +TORCH_CUDA_CU_API void inlineMost( + const std::vector& tvs, + const std::unordered_set& uninlinable_ids = {}); +// Inline to the right most allowed position for the selected tensors in the +// current fusion. +TORCH_CUDA_CU_API void inlineMost( + const std::unordered_set& tvs, + const std::unordered_set& uninlinable_ids = {}); + +// Inline to the position corresponding to the reference position in the +// reference tensor for all tensors in the current fusion. +TORCH_CUDA_CU_API void inlineAllAt( + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort = false, + const std::unordered_set& uninlinable_ids = {}); + +// Inline to the position corresponding to the reference position in the +// reference tensor for selected tensors in the current fusion. +TORCH_CUDA_CU_API void inlineSelectedAt( + const std::unordered_set& selected, + TensorView* reference_tv, + int64_t reference_pos, + bool best_effort = false, + const std::unordered_set& uninlinable_ids = {}); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/interface.cpp b/torch/csrc/jit/codegen/cuda/interface.cpp index 664f14d26c759..6b1fa7c44f9c5 100644 --- a/torch/csrc/jit/codegen/cuda/interface.cpp +++ b/torch/csrc/jit/codegen/cuda/interface.cpp @@ -655,6 +655,62 @@ RegisterOperators reg_add_optional({ aliasAnalysisFromSchema()), }); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_permute_copy({ + Operator( + "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "permute_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dims; + pop(stack, self, dims); + push(stack, at::native::view(self.toTensor(), dims.toIntVector())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_transpose_copy({ + Operator( + "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "transpose_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self, dim0, dim1; + pop(stack, self, dim0, dim1); + push( + stack, + at::transpose(self.toTensor(), dim0.toInt(), dim1.toInt())); + }; + }, + aliasAnalysisFromSchema()), +}); + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +RegisterOperators reg_t_copy({ + Operator( + "prim::t_copy(Tensor(a) self) -> Tensor", + [](const Node* node) -> Operation { + return [node](Stack& stack) { + TORCH_CHECK( + node->s(attr::name) == "CudaFusionGroup", + "t_copy is only used by nvfuser to identify non-mutating ", + "alias ops, should be restored after fusion pass!"); + IValue self; + pop(stack, self); + push(stack, at::t(self.toTensor())); + }; + }, + aliasAnalysisFromSchema()), +}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) RegisterOperators reg_view_copy({ Operator( diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index b29a8bc417cd0..ff00f659da637 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -341,6 +341,12 @@ void Expr::setPredicate(kir::Predicate* predicate) { predicate_ = predicate; } +Expr* Expr::withPredicate(kir::Predicate* predicate) { + auto result = shallowCopy(); + result->setPredicate(predicate); + return result; +} + kir::Predicate* Expr::writePredicate() const { TORCH_INTERNAL_ASSERT( container()->isA(), "Function invalid for fusion."); @@ -353,6 +359,19 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) { write_predicate_ = write_predicate; } +Expr* Expr::withWritePredicate(kir::Predicate* predicate) { + auto result = shallowCopy(); + result->setWritePredicate(predicate); + return result; +} + +void Expr::copyPredicatesFrom(const Expr* expr) { + if (container()->isA()) { + predicate_ = expr->predicate_; + write_predicate_ = expr->write_predicate_; + } +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index 7d5ebad25282b..dadabe167ebfc 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -426,6 +426,10 @@ class TORCH_CUDA_CU_API Expr : public Statement { Expr(const Expr* src, IrCloner* ir_cloner); + // Creates a new instance of the expression with all its field copied. + // Note that unlike IrCloner, this function only do a shallow copy + virtual Expr* shallowCopy() const = 0; + c10::optional getExprType() const override { return etype_; } @@ -466,16 +470,27 @@ class TORCH_CUDA_CU_API Expr : public Statement { // TODO: Protect based on being in kernel container kir::Predicate* predicate() const; + // Creates a shallow copy the expression with the given predicate attached. // TODO: Protect based on being in kernel container - void setPredicate(kir::Predicate* predicate); + Expr* withPredicate(kir::Predicate* predicate); // TODO: Protect based on being in kernel container kir::Predicate* writePredicate() const; + // Creates a shallow copy the expression with the given write-predicate + // attached. // TODO: Protect based on being in kernel container - void setWritePredicate(kir::Predicate* write_predicate); + Expr* withWritePredicate(kir::Predicate* write_predicate); protected: + // TODO: Protect based on being in kernel container + void setPredicate(kir::Predicate* predicate); + + // TODO: Protect based on being in kernel container + void setWritePredicate(kir::Predicate* write_predicate); + + void copyPredicatesFrom(const Expr* expr); + // TODO: Add Fusion passkey void addInput(Val* input) { TORCH_INTERNAL_ASSERT(input != nullptr); diff --git a/torch/csrc/jit/codegen/cuda/ir_builder.cpp b/torch/csrc/jit/codegen/cuda/ir_builder.cpp index 189bd7aa666eb..f0fd438c15672 100644 --- a/torch/csrc/jit/codegen/cuda/ir_builder.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_builder.cpp @@ -60,7 +60,9 @@ IR_BUILDER_INSTANTIATE(ShiftOp) IR_BUILDER_INSTANTIATE(GatherOp) IR_BUILDER_INSTANTIATE(ViewAsScalar) IR_BUILDER_INSTANTIATE(ViewOp) +IR_BUILDER_INSTANTIATE(FullOp) IR_BUILDER_INSTANTIATE(ARangeOp) +IR_BUILDER_INSTANTIATE(EyeOp) IR_BUILDER_INSTANTIATE(UnaryOp) IR_BUILDER_INSTANTIATE(BinaryOp) IR_BUILDER_INSTANTIATE(TernaryOp) diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp index bdd1d3b86df7c..489be49ddfc7c 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.cpp @@ -88,10 +88,18 @@ void IrCloner::handle(const TensorView* tv) { clone_ = IrBuilder::clone(tv, this); } +void IrCloner::handle(const FullOp* op) { + clone_ = IrBuilder::clone(op, this); +} + void IrCloner::handle(const ARangeOp* op) { clone_ = IrBuilder::clone(op, this); } +void IrCloner::handle(const EyeOp* op) { + clone_ = IrBuilder::clone(op, this); +} + void IrCloner::handle(const UnaryOp* op) { clone_ = IrBuilder::clone(op, this); } diff --git a/torch/csrc/jit/codegen/cuda/ir_cloner.h b/torch/csrc/jit/codegen/cuda/ir_cloner.h index 7cc118cdcff5a..06e1ec3359d95 100644 --- a/torch/csrc/jit/codegen/cuda/ir_cloner.h +++ b/torch/csrc/jit/codegen/cuda/ir_cloner.h @@ -68,7 +68,9 @@ class TORCH_CUDA_CU_API IrCloner : private OptInConstDispatch { void handle(const ComplexDouble*) override; void handle(const NamedScalar*) override; + void handle(const FullOp*) override; void handle(const ARangeOp*) override; + void handle(const EyeOp*) override; void handle(const UnaryOp*) override; void handle(const BinaryOp*) override; void handle(const TernaryOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp index fa5173dfcaa5a..6c04e4214b07d 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.cpp @@ -407,15 +407,32 @@ void IrGraphGenerator::handle(const TensorView* tv) { tensor_views_.push_back(tv); } -void IrGraphGenerator::handle(const ARangeOp* uop) { +void IrGraphGenerator::handle(const FullOp* fop) { // node - printExpr(uop, "arange"); + printExpr(fop, "full"); // inputs & outputs - addArc(uop->start(), uop); - addArc(uop->end(), uop); - addArc(uop->step(), uop); - addArc(uop, uop->output(0)); + addArc(fop->getFillValue(), fop); + addArc(fop, fop->output(0)); +} + +void IrGraphGenerator::handle(const ARangeOp* aop) { + // node + printExpr(aop, "arange"); + + // inputs & outputs + addArc(aop->start(), aop); + addArc(aop->end(), aop); + addArc(aop->step(), aop); + addArc(aop, aop->output(0)); +} + +void IrGraphGenerator::handle(const EyeOp* eop) { + // node + printExpr(eop, "eye"); + + // inputs & outputs + addArc(eop, eop->output(0)); } void IrGraphGenerator::handle(const UnaryOp* uop) { diff --git a/torch/csrc/jit/codegen/cuda/ir_graphviz.h b/torch/csrc/jit/codegen/cuda/ir_graphviz.h index c68c4fccb6f6c..1f555ed31ec06 100644 --- a/torch/csrc/jit/codegen/cuda/ir_graphviz.h +++ b/torch/csrc/jit/codegen/cuda/ir_graphviz.h @@ -82,7 +82,9 @@ class TORCH_CUDA_CU_API IrGraphGenerator : private OptInConstDispatch { void handle(const ComplexDouble*) override; void handle(const NamedScalar*) override; + void handle(const FullOp*) override; void handle(const ARangeOp*) override; + void handle(const EyeOp*) override; void handle(const UnaryOp*) override; void handle(const BinaryOp*) override; void handle(const TernaryOp*) override; diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 126abba2ae103..dbefc4858d110 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -154,8 +154,6 @@ class TORCH_CUDA_CU_API ComplexDouble : public Val { //! the compute at position to maximum possible through traversal. enum class ComputeAtMode { Standard, BestEffort, MostInlined }; -class InlinePropagator; -class MaxProducerPosUpdater; class TransformPropagator; struct MostInlinedTransformPropagator; class TransformIter; @@ -163,6 +161,8 @@ class TransformReplay; class OptOutMutator; class TensorDomain; +class MaxPosCalculator; + namespace ir_utils { class TVDomainGuard; } @@ -492,21 +492,30 @@ class TORCH_CUDA_CU_API TensorView : public Val { friend TORCH_CUDA_CU_API MostInlinedTransformPropagator; friend TORCH_CUDA_CU_API TransformReplay; friend TORCH_CUDA_CU_API OptOutMutator; - friend TORCH_CUDA_CU_API InlinePropagator; - friend TORCH_CUDA_CU_API MaxProducerPosUpdater; + friend class InlineBatchingGuard; friend class ir_utils::TVDomainGuard; - friend TORCH_CUDA_CU_API void groupReductions( - const std::vector&); + + // Inline the computation of this tensor into its consumer at the given + // position. If this tensor is already inlined in a higher position, then this + // call is a no-op. If the right most dimensions before `pos` are + // broadcasting, then will not inline into these broadcastings. If + // best_effort, then will inline into the highest allowed position that is <= + // `pos`. + void inlineAt( + int64_t pos, + bool best_effort = false, + MaxPosCalculator* calc = nullptr); + + // Update the max producer position of the current tensor. This is required + // when we modify producer-consumer relationship of a scheduled tensor, for + // example, grouping multiple reductions. + void updateMaxProducerPosition(); protected: void setDomain(TensorDomain* td) { domain_ = td; } - void setComputeAt(unsigned int this_pos, bool decrease = false); - - void setMaxProducer(unsigned int this_pos, bool decrease = false); - private: int normalizeAxisPos(int pos) const { if (pos < 0) { diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 8077c9bc920cf..d34b3a9f89c58 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -30,6 +30,29 @@ struct AnalyzeViewResult; //! vals are `Int` will dispatch to v1->as()->sameAs(v2.as()) bool areEqualScalars(Val* v1, Val* v2); +class TORCH_CUDA_CU_API FullOp : public Expr { + public: + FullOp(IrBuilderPasskey, Val* out, Val* fill_value, DataType dtype); + + FullOp(const FullOp* src, IrCloner* ir_cloner); + + Expr* shallowCopy() const override; + + bool sameAs(const Statement* other) const override; + + DataType dtype() const { + return dtype_; + } + + Val* getFillValue() const { + return fill_value_; + } + + private: + const DataType dtype_; + Val* fill_value_; +}; + class TORCH_CUDA_CU_API ARangeOp : public Expr { public: ARangeOp( @@ -38,12 +61,19 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { Val* start, Val* end, Val* step, + DataType dtype, Val* linear_index = nullptr); ARangeOp(const ARangeOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + bool sameAs(const Statement* other) const override; + DataType dtype() const { + return dtype_; + } + Val* start() const { return start_; } @@ -56,7 +86,7 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { return step_; } - Val* getLinearIndex() const { + Val* getLinearLogicalIndex() const { return linear_index_; } @@ -65,12 +95,72 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr { } private: + const DataType dtype_; Val* start_; Val* end_; Val* step_; Val* linear_index_ = nullptr; }; +// Tensor factory for generating identity matrices like +// +// [[1, 0, 0], +// [0, 1, 0], +// [0, 0, 1]] +// +// or +// +// [[1, 0, 0], +// [0, 1, 0], +// [0, 0, 1], +// [0, 0, 0]] +// +// or +// +// [[1, 0, 0, 0], +// [0, 1, 0, 0], +// [0, 0, 1, 0]] +class TORCH_CUDA_CU_API EyeOp : public Expr { + public: + EyeOp( + IrBuilderPasskey, + Val* out, + DataType dtype, + Val* index1 = nullptr, + Val* index2 = nullptr); + + EyeOp(const EyeOp* src, IrCloner* ir_cloner); + + Expr* shallowCopy() const override; + + bool sameAs(const Statement* other) const override; + + DataType dtype() const { + return dtype_; + } + + Val* getIndex1() const { + return index1_; + } + + void setIndex1(Val* index) { + index1_ = index; + } + + Val* getIndex2() const { + return index2_; + } + + void setIndex2(Val* index) { + index2_ = index; + } + + private: + const DataType dtype_; + Val* index1_ = nullptr; + Val* index2_ = nullptr; +}; + //! A specialization for Unary operations. Unary operations take in a single //! input and produce a single output. Examples include: //! 1) Casting operation i.e. float(a_val) @@ -88,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr { UnaryOp(const UnaryOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -117,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr { BinaryOp(const BinaryOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -148,15 +242,23 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { IrBuilderPasskey, RNGOpType type, Val* out, + DataType dtype, + std::vector parameters = {}, int rng_offset = 0, Val* philox_index = nullptr); RNGOp(const RNGOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + RNGOpType getRNGOpType() const { return rng_op_type_; } + DataType dtype() const { + return dtype_; + } + int getRNGOffset() const { return rng_offset_; } @@ -165,6 +267,14 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { rng_offset_ = val; } + const std::vector& getParameters() const { + return parameters_; + } + + const std::vector& getShape() const { + return shape_; + } + Val* getPhiloxIndex() const { return philox_index_; } @@ -177,6 +287,9 @@ class TORCH_CUDA_CU_API RNGOp : public Expr { private: const RNGOpType rng_op_type_; + const DataType dtype_; + std::vector parameters_; + std::vector shape_; int rng_offset_ = -1; // The index used to feed philox's subsequence and component Val* philox_index_ = nullptr; @@ -197,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr { BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -245,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr { ReductionOp(const ReductionOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -293,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr { GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. size_t numExprs() const { @@ -479,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr { WelfordOp(const WelfordOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return output().avg(); } @@ -574,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr { GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + //! Number of expressions grouped horizontally. It does not reflect //! iteration grouping. As horizontal grouping is not supported, //! this always returns 1. @@ -697,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr { MmaOp(const MmaOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -755,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr { TransposeOp(const TransposeOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + TensorView* out() const { return out_; } @@ -785,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr { ExpandOp(const ExpandOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + TensorView* out() const { return out_; } @@ -815,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr { TernaryOp(const TernaryOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -858,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr { ShiftOp(const ShiftOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -907,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr { GatherOp(const GatherOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -953,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr { ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -986,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr { ViewOp(const ViewOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + TensorView* out() const { return out_; } @@ -1011,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr { LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -1275,16 +1416,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val { } //! Check if IterDomain is a reduction axis with size of 1, i.e. - //! a "squeeze" operator. - //! - //! NOTE: Detection of trivial reduction here is not - //! comprehensive. See detectTrivialReductionDerivedDomains for more - //! comprehensive analysis. We typically use this for root domain trivial - //! reduction checks. So we ship to the correct scheduler. It may - //! not be incredibly robust, but it makes sense to keep it for now. - bool isTrivialReduction() const { - return isReduction() && extent()->isOneInt(); - } + //! a "squeeze" operator, or solely derived from such axes. + bool isTrivialReduction() const; //! Split for stride by a given factor. It effectively does an inner //! split by the factor and sets the inner domain as a Stride @@ -1590,6 +1723,8 @@ class TORCH_CUDA_CU_API Split : public Expr { Split(const Split* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + IterDomain* outer() const { return outer_; } @@ -1650,6 +1785,8 @@ class TORCH_CUDA_CU_API Merge : public Expr { Merge(const Merge* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + IterDomain* out() const { return out_; } @@ -1682,6 +1819,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr { Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner); + Expr* shallowCopy() const override; + IterDomain* outX() const { return out_x_; } diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 5229647ac9d5c..e13273c8e75e9 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -248,6 +248,35 @@ void IrPrinter::handle(const NamedScalar* ns) { os_ << ns->name(); } +void IrPrinter::handle(const FullOp* fop) { + if (!print_inline_) { + indent(); + os_ << fop->output(0) << "\n"; + indent_size_++; + indent(); + os_ << " = "; + } else { + checkInlineable(fop); + } + + os_ << "full({"; + for (auto i : c10::irange(fop->inputs().size())) { + if (i == fop->inputs().size() - 1) { + os_ << "}"; + } + if (i > 0) { + os_ << ", "; + } + handle(fop->input(i)); + } + os_ << ", " << fop->dtype() << ")"; + + indent_size_--; + + if (!print_inline_) + os_ << ";\n"; +} + void IrPrinter::handle(const ARangeOp* aop) { if (!print_inline_) { indent() << aop->output(0); @@ -265,7 +294,28 @@ void IrPrinter::handle(const ARangeOp* aop) { handle(aop->end()); os_ << ", "; handle(aop->step()); - os_ << ")"; + os_ << ", " << aop->dtype() << ")"; + + indent_size_--; + + if (!print_inline_) + os_ << ";\n"; +} + +void IrPrinter::handle(const EyeOp* eop) { + if (!print_inline_) { + indent(); + os_ << eop->output(0) << "\n"; + indent_size_++; + indent(); + os_ << " = "; + } else { + checkInlineable(eop); + } + + os_ << "eye("; + handle(eop->input(0)); + os_ << ", " << eop->dtype() << ")"; indent_size_--; @@ -429,21 +479,27 @@ void IrPrinter::handle(const RNGOp* rop) { checkInlineable(rop); } - os_ << rop->getRNGOpType() << "("; + os_ << rop->getRNGOpType() << "({"; bool first = true; - for (auto i : rop->inputs()) { + for (auto i : rop->getShape()) { if (!first) { os_ << ", "; } handle(i); first = false; } - os_ << ")"; + os_ << "}"; + for (auto i : rop->getParameters()) { + os_ << ", "; + handle(i); + } + os_ << ", " << rop->dtype() << ")"; indent_size_--; - if (!print_inline_) + if (!print_inline_) { os_ << ";\n"; + } } void IrPrinter::handle(const ReductionOp* rop) { diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index fd77d91010a48..599e50286d294 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -82,7 +82,9 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch { void handle(const ComplexDouble*) final; void handle(const NamedScalar*) final; + void handle(const FullOp*) final; void handle(const ARangeOp*) final; + void handle(const EyeOp*) final; void handle(const UnaryOp*) final; void handle(const BinaryOp*) final; void handle(const TernaryOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index fb8e28c53de0d..c4d994f272be1 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -182,14 +182,54 @@ bool ComplexDouble::sameAs(const Statement* other) const { return false; } +FullOp::FullOp( + IrBuilderPasskey passkey, + Val* out, + Val* fill_value, + DataType dtype) + : Expr(passkey, ExprType::FullOp), dtype_(dtype), fill_value_(fill_value) { + if (out->isA()) { + addInput(out->as()->getRootDomain()[0]->extent()); + } + addInput(fill_value); + addOutput(out); +} + +FullOp::FullOp(const FullOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + dtype_(src->dtype()), + fill_value_(ir_cloner->clone(src->fill_value_)) {} + +Expr* FullOp::shallowCopy() const { + auto result = IrBuilder::create(output(0), fill_value_, dtype_); + result->copyPredicatesFrom(this); + return result; +} + +bool FullOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + if (dtype_ != other_op->dtype_) { + return false; + } + return Expr::sameAs(other); +} + ARangeOp::ARangeOp( IrBuilderPasskey passkey, Val* out, Val* start, Val* end, Val* step, + DataType dtype, Val* linear_index) : Expr(passkey, ExprType::ARangeOp), + dtype_(dtype), start_(start), end_(end), step_(step), @@ -202,11 +242,19 @@ ARangeOp::ARangeOp( ARangeOp::ARangeOp(const ARangeOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), + dtype_(src->dtype()), start_(ir_cloner->clone(src->start_)), end_(ir_cloner->clone(src->end_)), step_(ir_cloner->clone(src->step_)), linear_index_(ir_cloner->clone(src->linear_index_)) {} +Expr* ARangeOp::shallowCopy() const { + auto result = IrBuilder::create( + output(0), start_, end_, step_, dtype_, linear_index_); + result->copyPredicatesFrom(this); + return result; +} + bool ARangeOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -215,6 +263,9 @@ bool ARangeOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); + if (dtype_ != other_op->dtype_) { + return false; + } if (!start_->sameAs(other_op->start_)) { return false; } @@ -234,6 +285,64 @@ bool ARangeOp::sameAs(const Statement* other) const { return Expr::sameAs(other); } +EyeOp::EyeOp( + IrBuilderPasskey passkey, + Val* out, + DataType dtype, + Val* index1, + Val* index2) + : Expr(passkey, ExprType::EyeOp), + dtype_(dtype), + index1_(index1), + index2_(index2) { + if (out->isA()) { + addInput(out->as()->getRootDomain()[0]->extent()); + if (out->as()->getRootDomain()[1] != + out->as()->getRootDomain()[0]) { + addInput(out->as()->getRootDomain()[1]->extent()); + } + } + addOutput(out); +} + +EyeOp::EyeOp(const EyeOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + dtype_(src->dtype_), + index1_(ir_cloner->clone(src->index1_)), + index2_(ir_cloner->clone(src->index2_)) {} + +Expr* EyeOp::shallowCopy() const { + auto result = IrBuilder::create(output(0), dtype_, index1_, index2_); + result->copyPredicatesFrom(this); + return result; +} + +bool EyeOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + if (dtype_ != other_op->dtype_) { + return false; + } + if ((index1_ == nullptr) != (other_op->index1_ == nullptr)) { + return false; + } + if ((index2_ == nullptr) != (other_op->index2_ == nullptr)) { + return false; + } + if ((index1_ != nullptr) && !index1_->sameAs(other_op->index1_)) { + return false; + } + if ((index2_ != nullptr) && !index2_->sameAs(other_op->index2_)) { + return false; + } + return Expr::sameAs(other); +} + UnaryOp::UnaryOp( IrBuilderPasskey passkey, UnaryOpType type, @@ -254,6 +363,12 @@ UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* UnaryOp::shallowCopy() const { + auto result = IrBuilder::create(unary_op_type_, out_, in_); + result->copyPredicatesFrom(this); + return result; +} + bool UnaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -262,8 +377,9 @@ bool UnaryOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); - if (getUnaryOpType() != other_op->getUnaryOpType()) + if (getUnaryOpType() != other_op->getUnaryOpType()) { return false; + } return Expr::sameAs(other); } @@ -290,6 +406,12 @@ BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner) lhs_(ir_cloner->clone(src->lhs_)), rhs_(ir_cloner->clone(src->rhs_)) {} +Expr* BinaryOp::shallowCopy() const { + auto result = IrBuilder::create(binary_op_type_, out_, lhs_, rhs_); + result->copyPredicatesFrom(this); + return result; +} + bool BinaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -298,8 +420,9 @@ bool BinaryOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); - if (getBinaryOpType() != other_op->getBinaryOpType()) + if (getBinaryOpType() != other_op->getBinaryOpType()) { return false; + } return Expr::sameAs(other); } @@ -330,6 +453,13 @@ TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner) in2_(ir_cloner->clone(src->in2_)), in3_(ir_cloner->clone(src->in3_)) {} +Expr* TernaryOp::shallowCopy() const { + auto result = + IrBuilder::create(ternary_op_type_, out_, in1_, in2_, in3_); + result->copyPredicatesFrom(this); + return result; +} + bool TernaryOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -338,8 +468,9 @@ bool TernaryOp::sameAs(const Statement* other) const { return false; } const auto other_op = other->as(); - if (getTernaryOpType() != other_op->getTernaryOpType()) + if (getTernaryOpType() != other_op->getTernaryOpType()) { return false; + } return Expr::sameAs(other); } @@ -347,26 +478,45 @@ RNGOp::RNGOp( IrBuilderPasskey passkey, RNGOpType type, Val* out, + DataType dtype, + std::vector parameters, int rng_offset, Val* philox_index) : Expr(passkey, ExprType::RNGOp), rng_op_type_(type), + dtype_(dtype), + parameters_(std::move(parameters)), rng_offset_(rng_offset), philox_index_(philox_index) { if (out->isA()) { for (auto id : out->as()->getRootDomain()) { - addInput(id->extent()); + shape_.emplace_back(id->extent()); } } + for (auto v : shape_) { + addInput(v); + } + for (auto v : parameters_) { + addInput(v); + } addOutput(out); } RNGOp::RNGOp(const RNGOp* src, IrCloner* ir_cloner) : Expr(src, ir_cloner), rng_op_type_(src->rng_op_type_), + dtype_(src->dtype()), + parameters_(ir_cloner->clone(src->parameters_)), rng_offset_(src->rng_offset_), philox_index_(ir_cloner->clone(src->philox_index_)) {} +Expr* RNGOp::shallowCopy() const { + auto result = IrBuilder::create( + rng_op_type_, output(0), dtype_, parameters_, rng_offset_, philox_index_); + result->copyPredicatesFrom(this); + return result; +} + bool RNGOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -378,6 +528,17 @@ bool RNGOp::sameAs(const Statement* other) const { if (getRNGOpType() != other_op->getRNGOpType()) { return false; } + if (dtype_ != other_op->dtype_) { + return false; + } + if (parameters_.size() != other_op->parameters_.size()) { + return false; + } + for (auto i : c10::irange(parameters_.size())) { + if (!parameters_[i]->sameAs(other_op->parameters_[i])) { + return false; + } + } if (getRNGOffset() != other_op->getRNGOffset()) { return false; } @@ -439,7 +600,7 @@ BroadcastOp::BroadcastOp( id->isReduction() || id->isStride(), "Invalid broadcast op: ", id, - ". Non-reduction input dim does't match to output."); + ". Non-reduction input dim doesn't match to output."); } } @@ -467,6 +628,12 @@ BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), is_broadcast_dims_(src->is_broadcast_dims_) {} +Expr* BroadcastOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, is_broadcast_dims_); + result->copyPredicatesFrom(this); + return result; +} + bool BroadcastOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -521,6 +688,36 @@ ReductionOp::ReductionOp( addInput(in); } +ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) + : Expr(src, ir_cloner), + reduction_op_type_(src->reduction_op_type_), + init_(ir_cloner->clone(src->init_)), + out_(ir_cloner->clone(src->out_)), + in_(ir_cloner->clone(src->in_)), + is_allreduce_(src->is_allreduce_) {} + +Expr* ReductionOp::shallowCopy() const { + auto result = IrBuilder::create( + reduction_op_type_, init_, out_, in_, is_allreduce_, etype()); + result->copyPredicatesFrom(this); + return result; +} + +bool ReductionOp::sameAs(const Statement* other) const { + if (this == other) { + return true; + } + if (!other->isA()) { + return false; + } + const auto other_op = other->as(); + // Note that init is not part of input vals, so it must be checked separately. + return ( + Expr::sameAs(other) && + getReductionOpType() == other_op->getReductionOpType() && + init()->sameAs(other_op->init())); +} + GroupedReductionOp::GroupedReductionOp( IrBuilderPasskey passkey, std::vector reduction_op_types, @@ -550,6 +747,18 @@ GroupedReductionOp::GroupedReductionOp( init_vals_(ir_cloner->clone(src->init_vals_)), is_allreduce_(src->is_allreduce_) {} +Expr* GroupedReductionOp::shallowCopy() const { + auto result = IrBuilder::create( + reduction_op_types_, + init_vals_, + outputs(), + inputs(), + is_allreduce_, + etype()); + result->copyPredicatesFrom(this); + return result; +} + int GroupedReductionOp::getExprIndexOfOutput(Val* output_val) const { auto it = std::find(outputs().begin(), outputs().end(), output_val); if (it != outputs().end()) { @@ -724,6 +933,13 @@ WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner) init_(src->init_.clone(ir_cloner)), is_allreduce_(src->is_allreduce_) {} +Expr* WelfordOp::shallowCopy() const { + auto result = + IrBuilder::create(output_, input_, init_, is_allreduce_); + result->copyPredicatesFrom(this); + return result; +} + Val* WelfordOp::getInitValOfOutput(Val* output_val) const { auto val_name = output().getNameOf(output_val); @@ -873,6 +1089,13 @@ GroupedWelfordOp::GroupedWelfordOp( init_vals_(WelfordTriplet::clone(src->init_vals_, ir_cloner)), is_allreduce_(src->is_allreduce_) {} +Expr* GroupedWelfordOp::shallowCopy() const { + auto result = IrBuilder::create( + output_vals_, input_vals_, init_vals_, is_allreduce_, etype()); + result->copyPredicatesFrom(this); + return result; +} + bool GroupedWelfordOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -967,6 +1190,13 @@ MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner) init_(ir_cloner->clone(src->init_)), options_(src->options_) {} +Expr* MmaOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_a_, in_b_, init_); + result->options_ = options_; + result->copyPredicatesFrom(this); + return result; +} + bool MmaOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -979,29 +1209,6 @@ bool MmaOp::sameAs(const Statement* other) const { return false; } -ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner) - : Expr(src, ir_cloner), - reduction_op_type_(src->reduction_op_type_), - init_(ir_cloner->clone(src->init_)), - out_(ir_cloner->clone(src->out_)), - in_(ir_cloner->clone(src->in_)), - is_allreduce_(src->is_allreduce_) {} - -bool ReductionOp::sameAs(const Statement* other) const { - if (this == other) { - return true; - } - if (!other->isA()) { - return false; - } - const auto other_op = other->as(); - // Note that init is not part of input vals, so it must be checked separately. - return ( - Expr::sameAs(other) && - getReductionOpType() == other_op->getReductionOpType() && - init()->sameAs(other_op->init())); -} - TransposeOp::TransposeOp( IrBuilderPasskey passkey, TensorView* out, @@ -1046,6 +1253,12 @@ TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner) in_(ir_cloner->clone(src->in_)), new2old_(src->new2old_) {} +Expr* TransposeOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, new2old_); + result->copyPredicatesFrom(this); + return result; +} + std::vector TransposeOp::old2new() const { std::vector old2new(new2old_.size()); for (auto new_axis : c10::irange(new2old_.size())) { @@ -1085,6 +1298,12 @@ ExpandOp::ExpandOp(const ExpandOp* src, IrCloner* ir_cloner) } } +Expr* ExpandOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, expanded_extents_); + result->copyPredicatesFrom(this); + return result; +} + ShiftOp::ShiftOp( IrBuilderPasskey passkey, Val* out, @@ -1132,6 +1351,12 @@ ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner) offsets_(src->offsets_), pad_width_(src->pad_width_) {} +Expr* ShiftOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, offsets_, pad_width_); + result->copyPredicatesFrom(this); + return result; +} + bool ShiftOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -1194,6 +1419,13 @@ GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner) window_shape_(src->window_shape_), pad_width_(src->pad_width_) {} +Expr* GatherOp::shallowCopy() const { + auto result = + IrBuilder::create(out_, in_, window_shape_, pad_width_); + result->copyPredicatesFrom(this); + return result; +} + bool GatherOp::sameAs(const Statement* other) const { if (this == other) { return true; @@ -1240,6 +1472,12 @@ ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner) vector_id_(ir_cloner->clone(src->vector_id_)), index_(ir_cloner->clone(src->index_)) {} +Expr* ViewAsScalar::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, vector_id_, index_); + result->copyPredicatesFrom(this); + return result; +} + ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in) : Expr(passkey, ExprType::ViewOp), out_(out), in_(in) { addOutput(out); @@ -1251,6 +1489,12 @@ ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* ViewOp::shallowCopy() const { + auto result = IrBuilder::create(out_, in_); + result->copyPredicatesFrom(this); + return result; +} + LoadStoreOp::LoadStoreOp( IrBuilderPasskey passkey, LoadStoreOpType op_type, @@ -1270,6 +1514,12 @@ LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner) out_(ir_cloner->clone(src->out_)), in_(ir_cloner->clone(src->in_)) {} +Expr* LoadStoreOp::shallowCopy() const { + auto result = IrBuilder::create(load_store_type_, out_, in_); + result->copyPredicatesFrom(this); + return result; +} + IterDomainBuilder::IterDomainBuilder(Val* _start, Val* _extent) : start_(_start), extent_(_extent) { TORCH_INTERNAL_ASSERT( @@ -1470,6 +1720,37 @@ IterDomain* IterDomain::cloneWithoutRFactor() const { return cloned; } +bool IterDomain::isTrivialReduction() const { + if (!isReduction()) { + return false; + } + + if (extent()->isOneInt()) { + return true; + } + + // If this domain is an output of an expression, i.e., not a root + // domain, check if all root domains are trivial reductions. This is + // almost the same as the analysis done in TrivialReductionInfo, but + // is limited within a single tensor, whereas TrivialReductionInfo + // does more expensive analysis potentially traversing through + // rfactor domains + if (definition()) { + // Note: There's no const version of IterVisitor. + auto id_inputs = InputsOf::output(fusion(), const_cast(this)); + if (std::all_of( + ir_utils::filterByType(id_inputs).begin(), + ir_utils::filterByType(id_inputs).end(), + [](IterDomain* root_id) { + return root_id->isReduction() && root_id->extent()->isOneInt(); + })) { + return true; + } + } + + return false; +} + std::vector IterDomain::clone( const std::vector& domains) { std::vector cloned_domains; @@ -1494,7 +1775,11 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { outer->isReduction() == inner->isReduction() || (!outer->isReduction() && inner->isTrivialReduction()) || (outer->isTrivialReduction() && !inner->isReduction()), - "Merging IterDomains requires that their iteration types match."); + "Merging IterDomains requires that their iteration types match. ", + "Outer: ", + outer->toString(), + ", Inner: ", + inner->toString()); TORCH_CHECK( (outer->isGather() && inner->isGather()) || (!outer->isGather() && !inner->isGather()), @@ -1775,7 +2060,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", root_domain_.size()); @@ -1799,7 +2084,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", root_domain_.size()); @@ -1839,7 +2124,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", getMaybeRFactorDomain().size()); @@ -2380,6 +2665,13 @@ Split::Split(const Split* src, IrCloner* ir_cloner) start_offset_(ir_cloner->clone(src->start_offset_)), stop_offset_(ir_cloner->clone(src->stop_offset_)) {} +Expr* Split::shallowCopy() const { + auto result = IrBuilder::create( + outer_, inner_, in_, factor_, inner_split_, start_offset_, stop_offset_); + result->copyPredicatesFrom(this); + return result; +} + Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) { TORCH_INTERNAL_ASSERT(in_extent != nullptr); @@ -2425,6 +2717,12 @@ Merge::Merge(const Merge* src, IrCloner* ir_cloner) outer_(ir_cloner->clone(src->outer_)), inner_(ir_cloner->clone(src->inner_)) {} +Expr* Merge::shallowCopy() const { + auto result = IrBuilder::create(out_, outer_, inner_); + result->copyPredicatesFrom(this); + return result; +} + bool Merge::sameAs(const Statement* other) const { if (this == other) { return true; @@ -2456,6 +2754,13 @@ Swizzle2D::Swizzle2D( addInput(in_y); } +Expr* Swizzle2D::shallowCopy() const { + auto result = IrBuilder::create( + out_x_, out_y_, in_x_, in_y_, swizzle_type_, swizzle_mode_); + result->copyPredicatesFrom(this); + return result; +} + bool Swizzle2D::sameAs(const Statement* other) const { if (this == other) { return true; diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.cpp b/torch/csrc/jit/codegen/cuda/ir_utils.cpp index 4976518c737b7..dba5ee10adabb 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_utils.cpp @@ -180,6 +180,16 @@ struct SubstituteInExpr : public OptInDispatch { OptInDispatch::handle(expr); } + void handle(FullOp* full_expr) final { + auto out = reference_->sameAs(full_expr->output(0)) ? substitute_ + : full_expr->output(0); + expr_ = IrBuilder::create( + full_expr->container(), + out, + full_expr->getFillValue(), + full_expr->dtype()); + } + void handle(ARangeOp* arange_expr) final { auto start = reference_->sameAs(arange_expr->start()) ? substitute_ @@ -197,7 +207,19 @@ struct SubstituteInExpr : public OptInDispatch { start, end, step, - arange_expr->getLinearIndex()); + arange_expr->dtype(), + arange_expr->getLinearLogicalIndex()); + } + + void handle(EyeOp* eye_expr) final { + auto out = reference_->sameAs(eye_expr->output(0)) ? substitute_ + : eye_expr->output(0); + expr_ = IrBuilder::create( + eye_expr->container(), + out, + eye_expr->dtype(), + eye_expr->getIndex1(), + eye_expr->getIndex2()); } void handle(UnaryOp* unary_expr) final { @@ -244,12 +266,18 @@ struct SubstituteInExpr : public OptInDispatch { } void handle(RNGOp* rng_expr) final { + std::vector subsituted_params; + for (auto v : rng_expr->getParameters()) { + subsituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v); + } auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_ : rng_expr->output(0); expr_ = IrBuilder::create( rng_expr->container(), rng_expr->getRNGOpType(), out, + rng_expr->dtype(), + subsituted_params, rng_expr->getRNGOffset(), rng_expr->getPhiloxIndex()); } @@ -748,7 +776,7 @@ class ValReplacementMutator : private OptOutMutator { // grab all leaves towards outputs and grab stmts from there. auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true); - // Some fusions, such as standalone randlike, can have disconnected DAG, so + // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as // possible // TODO: I think we need a more general mechanism to support disconnected @@ -851,6 +879,30 @@ bool isReductionTvOp(const Expr* expr) { return ir_utils::isTvOp(expr) && isReductionOp(expr); } +TORCH_CUDA_CU_API std::vector getViewOps(Fusion* fusion) { + auto all_exprs = fusion->exprs(); + + auto all_view_ops = ir_utils::filterByType(all_exprs); + + std::vector view_ops; + + std::copy_if( + all_view_ops.begin(), + all_view_ops.end(), + std::back_inserter(view_ops), + [](ViewOp* view) { + return std::any_of( + view->outputs().begin(), view->outputs().end(), [](Val* v) { + if (!v->isA()) { + return false; + } + return v->as()->hasRFactor(); + }); + }); + + return view_ops; +} + namespace { struct ReplaceValInIndexVal : public OptInDispatch { diff --git a/torch/csrc/jit/codegen/cuda/ir_utils.h b/torch/csrc/jit/codegen/cuda/ir_utils.h index ce38ebd27fa40..adfc64fc74adf 100644 --- a/torch/csrc/jit/codegen/cuda/ir_utils.h +++ b/torch/csrc/jit/codegen/cuda/ir_utils.h @@ -317,10 +317,15 @@ TORCH_CUDA_CU_API bool isReductionOp(const Expr*); // Returns if Expr is a reduction op with TensorView or TensorIndex TORCH_CUDA_CU_API bool isReductionTvOp(const Expr*); +// Returns all non-trivial view operations. We shouldn't have trivial view +// operations but this function is to simply make sure if we ever do we don't +// pull them in. +TORCH_CUDA_CU_API std::vector getViewOps(Fusion*); + template std::string toString(const T& nodes) { std::stringstream ss; - for (Statement* stmt : nodes) { + for (const Statement* stmt : nodes) { if (ss.tellp() != 0) { ss << ", "; } diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 08ba663c9fa63..984a22194a20a 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -32,81 +32,44 @@ void remove_visited( } } -// Return all dependencies of a node including members of the node. -class RecursiveDependencies : public OptInDispatch { +class MemberStatements : public OptOutDispatch { public: + // Return all members of the stmt if it's a Val. For expressions it returns + // nothing. static std::vector next(Statement* stmt) { - RecursiveDependencies find_next(stmt); + MemberStatements find_next(stmt); return find_next.next_stmts_; } private: - RecursiveDependencies() = default; + MemberStatements() = default; - RecursiveDependencies(Statement* stmt) { + MemberStatements(Statement* stmt) { handle(stmt); } - using OptInDispatch::handle; - - void handle(Expr* expr) final { - FusionGuard::getCurFusion()->assertInContainer( - expr, - "IterVisitor.cpp::RecursiveDependencies::handle(Expr*) Cannot traverse expr, "); - next_stmts_.insert( - next_stmts_.end(), expr->inputs().begin(), expr->inputs().end()); - } + using OptOutDispatch::handle; void handle(Val* val) final { FusionGuard::getCurFusion()->assertInContainer( val, - "IterVisitor.cpp::RecursiveDependencies::handle(Val*) Cannot traverse val, "); - OptInDispatch::handle(val); - } - - void simpleVal(Val* val) { - if (val->definition() == nullptr) { - return; - } - next_stmts_.push_back(val->definition()); - } - - void handle(Bool* stmt) final { - simpleVal(stmt); - } - - void handle(Double* stmt) final { - simpleVal(stmt); - } - - void handle(Int* stmt) final { - simpleVal(stmt); - } - - void handle(ComplexDouble* stmt) final { - simpleVal(stmt); - } - - void handle(NamedScalar* stmt) final { - simpleVal(stmt); + "IterVisitor.cpp::MemberStatements::handle(Val*) Cannot traverse val, "); + OptOutDispatch::handle(val); } void handle(IterDomain* stmt) final { next_stmts_.push_back(stmt->start()); next_stmts_.push_back(stmt->extent()); next_stmts_.push_back(stmt->stopOffset()); - simpleVal(stmt); } void handle(TensorDomain* stmt) final { next_stmts_.insert( next_stmts_.end(), stmt->domain().begin(), stmt->domain().end()); - simpleVal(stmt); } void handle(TensorView* tv) final { next_stmts_.push_back(tv->domain()); - simpleVal(tv); } std::vector next_stmts_; @@ -169,17 +132,18 @@ void IterVisitor::handle(Val* v) { // To prevent traversing all paths through a DAG (unless we want to) we have a // function to remove visited nodes from being re-added to the stack // (remove_visited). -void IterVisitor::traverseFrom( +void IterVisitor::traverseBetween( Fusion* fusion, - const std::vector& from, - bool traverseAllPaths, - bool traverseIntoMembers) { + const std::unordered_set& from, + const std::vector& to, + bool traverse_all_paths, + bool traverse_into_members) { FusionGuard fg(fusion); std::unordered_set visited; stmt_stack.clear(); - stmt_stack.emplace_back(from.rbegin(), from.rend()); + stmt_stack.emplace_back(to.rbegin(), to.rend()); bool all_inputs_visited = false; @@ -201,7 +165,7 @@ void IterVisitor::traverseFrom( // If we just poped a stmt_stack level, we can finally visit it! if (all_inputs_visited) { // stmt may have be already visited. - if (traverseAllPaths || visited.find(stmt) == visited.end()) { + if (traverse_all_paths || visited.find(stmt) == visited.end()) { // Mark visited visited.insert(stmt); @@ -217,10 +181,20 @@ void IterVisitor::traverseFrom( } else { // We're not ready to process this node, so add all its inputs to be // checked Visit input nodes. - auto next_stmts = - traverseIntoMembers ? RecursiveDependencies::next(stmt) : next(stmt); + std::vector next_stmts; + + if ((stmt->isVal() && from.find(stmt->asVal()) == from.end()) || + stmt->isExpr()) { + next_stmts = next(stmt); + } + + if (traverse_into_members) { + auto members = MemberStatements::next(stmt); + next_stmts.insert(next_stmts.end(), members.begin(), members.end()); + } + // We may want to retraverse nodes, in that case revisit everything! - if (!traverseAllPaths) { + if (!traverse_all_paths) { // If we don't want to retraverse, remove nodes we already visisted. remove_visited(next_stmts, visited); } @@ -238,12 +212,20 @@ void IterVisitor::traverseFrom( } } +void IterVisitor::traverseTo( + Fusion* fusion, + const std::vector& to, + bool traverse_all_paths, + bool traverse_into_members) { + traverseBetween(fusion, {}, to, traverse_all_paths, traverse_into_members); +} + void IterVisitor::traverseHelper(Fusion* fusion, bool traverse_all_paths) { FusionGuard fg(fusion); auto term_val_outs = fusion->getTerminatingOutputs(); if (!term_val_outs.empty()) { - traverseFrom(fusion, term_val_outs, traverse_all_paths); + traverseTo(fusion, term_val_outs, traverse_all_paths); } } @@ -257,8 +239,7 @@ void IterVisitor::traverseAllPaths(Fusion* fusion) { namespace { -// Expr sort will take a fusion and return a topologically sorted list of -// expressions. +// TODO: Also have InputsOf should pick one and remove the other. class Inputs : public IterVisitor { private: //! Optional list of input vals. While traversing to inputs if a value in the @@ -299,7 +280,7 @@ class Inputs : public IterVisitor { return {}; } Inputs inps(all_inputs); - inps.traverseFrom(of[0]->fusion(), of); + inps.traverseTo(of[0]->fusion(), of); return inps.inputs_; } }; @@ -328,7 +309,7 @@ class AllVals : public IterVisitor { Fusion* fusion, const std::vector& from) { AllVals av; - av.traverseFrom(fusion, from, false); + av.traverseTo(fusion, from, false); return av.vals; } }; @@ -386,7 +367,7 @@ void BackwardVisitor::handle(Val* val) { OptOutDispatch::handle(val); } -void BackwardVisitor::traverseFrom( +void BackwardVisitor::traverseTo( Fusion* fusion, const std::vector& from, bool traverseAllPaths) { @@ -538,7 +519,7 @@ struct Dependencies : public IterVisitor { std::unordered_set _dependencies, const std::vector& of) : dependencies_(std::move(_dependencies)) { - traverseFrom(of[0]->fusion(), of, false); + traverseTo(of[0]->fusion(), of, false); }; public: @@ -585,7 +566,7 @@ struct FindOutputs : public IterVisitor { // tracing all paths like this. FindOutputs(const std::unordered_set& _of) : of_(_of) { auto fusion = (*of_.begin())->fusion(); - traverseFrom(fusion, fusion->outputs(), true); + traverseTo(fusion, fusion->outputs(), true); }; static std::unordered_set getAllOutputsOf( @@ -653,7 +634,7 @@ class DependentVals : public IterVisitor { DependentVals(const std::unordered_set& _of) : of_(_of) { createBoundary(); auto fusion = (*of_.begin())->fusion(); - traverseFrom(fusion, fusion->outputs(), false); + traverseTo(fusion, fusion->outputs(), false); }; public: @@ -689,7 +670,7 @@ class DependencyChains : public IterVisitor { DependencyChains(Val* _dependency, Val* _of, bool all_chains_ = false) : dependencies_({_dependency}) { - traverseFrom(_of->fusion(), {_of}, all_chains_); + traverseTo(_of->fusion(), {_of}, all_chains_); } DependencyChains(Val* _dependency, bool all_chains_ = false) @@ -815,12 +796,21 @@ std::vector StmtSort::getExprs(Fusion* fusion, bool traverse_members) { } std::vector StmtSort::getExprs( + Fusion* fusion, + const std::vector& to, + bool traverse_members) { + auto stmts = StmtSort::getStmts(fusion, to, traverse_members); + auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); + std::vector exprs(filter.begin(), filter.end()); + return exprs; +} + +std::vector StmtSort::getExprsBetween( Fusion* fusion, const std::vector& from, + const std::vector& to, bool traverse_members) { - StmtSort es; - es.traverseFrom(fusion, from, false, traverse_members); - auto stmts = StmtSort::getStmts(fusion, from, traverse_members); + auto stmts = StmtSort::getStmtsBetween(fusion, from, to, traverse_members); auto filter = ir_utils::filterByType(stmts.begin(), stmts.end()); std::vector exprs(filter.begin(), filter.end()); return exprs; @@ -834,11 +824,22 @@ std::vector StmtSort::getStmts( } std::vector StmtSort::getStmts( + Fusion* fusion, + const std::vector& to, + bool traverse_members) { + StmtSort es; + es.traverseTo(fusion, to, false, traverse_members); + return es.stmts; +} + +std::vector StmtSort::getStmtsBetween( Fusion* fusion, const std::vector& from, + const std::vector& to, bool traverse_members) { StmtSort es; - es.traverseFrom(fusion, from, false, traverse_members); + es.traverseBetween( + fusion, {from.begin(), from.end()}, to, false, traverse_members); return es.stmts; } @@ -858,7 +859,7 @@ std::vector InputsOf::outputs( Fusion* fusion, const std::vector& outputs_) { InputsOf io; - io.traverseFrom(fusion, outputs_, false); + io.traverseTo(fusion, outputs_, false); return io.ordered_inputs; } diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 8adac390dac89..3ad485f1a17b6 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -75,29 +75,43 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::vector> stmt_stack; - // Statements to stop traversal on if they're hit (pretends they're leaf - // nodes in next) - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - std::unordered_set termination_stmts; - void traverseHelper(Fusion* fusion, bool traverse_all_paths = false); public: - //! Starts at nodes provided in from, traverses from these nodes to inputs. - //! Calls handle on all Statement*s in topological sorted order. + //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. + //! from inputs towards outputs. //! \param traverseAllPaths = false only call handle on each Statement* once - //! traverseAllPaths = true traverses all paths from nodes in from to - //! inputs. Calls handle on a Statement* for every path from "from" nodes, - //! to inputs. + //! traverseAllPaths = true traverses all paths between expressions/values. + //! Calls handle on a Statement* for every path from inputs to "to". //! \param traverseIntoMembers = When hitting nodes like TensorView, //! TensorDomain, or IterDomain where there are members of the nodes that are //! Val's a value of "true" will also traverse into those member Val's, a //! value of "false" will not traverse into the members. - void traverseFrom( + void traverseTo( Fusion* fusion, - const std::vector& from, - bool traverseAllPaths = false, - bool traverseIntoMembers = false); + const std::vector& to, + bool traverse_all_paths = false, + bool traverse_into_members = false); + + //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. + //! from inputs towards outputs. + //! \param traverseAllPaths = false only call handle on each Statement* once + //! traverseAllPaths = true traverses all paths between expressions/values. + //! Calls handle on a Statement* for every path from inputs to "to". + //! \param traverseIntoMembers = When hitting nodes like TensorView, + //! TensorDomain, or IterDomain where there are members of the nodes that are + //! Val's a value of "true" will also traverse into those member Val's, a + //! value of "false" will not traverse into the members. + //! \param from: Specified values to start traversing. If a "from" Val is not + //! on path from inputs to "to" node it will not be visited. If there's a path + //! from inputs to "to" that doesn't go through "from" that input and the path + //! from it will also be traversed. + void traverseBetween( + Fusion* fusion, + const std::unordered_set& from, + const std::vector& to, + bool traverse_all_paths = false, + bool traverse_into_members = false); // Iterates from terminating outputs registered with the fusion. Terminating // means value is not used to generate any other value used in producing @@ -110,6 +124,9 @@ class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { //! Get inputs to vals. Possible input vals can be optionally //! given. If not, vals with no producers are returned. + // + // TODO: This doesn't seem to fit with IterVisitor. Should probably be moved + // out of the class. static std::vector getInputsTo( const std::vector& vals, const std::vector& inputs = {}); @@ -197,7 +214,7 @@ class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch { // traverseAllPaths = false only call handle on each Statement* once // traverseAllPaths = true traverses all paths from nodes in from to inputs. // Handle on a Statement* for every path from "from" nodes, to inputs. - void traverseFrom( + void traverseTo( Fusion* fusion, const std::vector& from, bool traverseAllPaths = false); @@ -251,37 +268,65 @@ class TORCH_CUDA_CU_API DependencyCheck { // expressions. class StmtSort : public IterVisitor { protected: + StmtSort() = default; + std::vector stmts; void handle(Statement* stmt) override; public: // If traverse_members it will also extract all member nodes in the sorted - // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc - static std::vector getExprs( + // statement list in the fusion. i.e. all IterDomains, extents, and associated + // expressions of them + static std::vector getStmts( Fusion* fusion, bool traverse_members = false); + // Returns ordered Statements required to produce from, including from. + static std::vector getStmts( + Fusion* fusion, + const std::vector& to, + bool traverse_members = false); + + // Returns ordered Statements required to produce from, including from. + // Stops traversal once hiting any Statements in to. Includes Statements in + // to. + // + // Warning: this doesn't necessarily prevent statements before `to` from being + // returned. e.g. + // i1 = i0 + // i2 = i1 + // i3 = i2 + // i4 = i3 + i1 + // getExprs(fusion, {i4}, {i3}) + // will return the definition and values {i0, i1, i4} + // i3 is dependent on i1, but since i4 also is then the traversal will go down + // the i4->i1->i0 path, even though the i4->i3-//>i2->i1 path is blocked. + // // If traverse_members it will also extract all member nodes in the sorted // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc - static std::vector getExprs( + static std::vector getStmtsBetween( Fusion* fusion, const std::vector& from, + const std::vector& to, bool traverse_members = false); - // If traverse_members it will also extract all member nodes in the sorted - // statement list in the fusion. i.e. all IterDomains, extents, and associated - // expressions of them - static std::vector getStmts( + // Same as getStmts version but filters to only return the Expr*s + static std::vector getExprs( Fusion* fusion, bool traverse_members = false); - // If traverse_members it will also extract all member nodes in the sorted - // expr list in the fusion. i.e. all IterDomains, extents, and associated - // expressions of them - static std::vector getStmts( + // Same as getStmts version but filters to only return the Expr*s + static std::vector getExprs( + Fusion* fusion, + const std::vector& to, + bool traverse_members = false); + + // Same as getStmts version but filters to only return the Expr*s + static std::vector getExprsBetween( Fusion* fusion, const std::vector& from, + const std::vector& to, bool traverse_members = false); }; diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 85448dc8ac418..c4604042bfaed 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -196,7 +196,15 @@ std::vector FusionExecutorCache::runFusionWithInputs( auto kernel_runtime = getKernelRuntimeFor(args); most_recent_runtime_ = kernel_runtime; + int seq_id = 0; + // Record kernel input and output tensors so profiler can construct + // the data flow graph + RECORD_FUNCTION( + "run_fused_kernel", + std::vector(inputs.begin(), inputs.end()), + seq_id); auto outputs = kernel_runtime->runWithInput(args); + RECORD_OUTPUTS(outputs); // permute output tensor returned by kernel execution. See Part_3 in Note [ // Permutation support in nvfuser ] @@ -485,7 +493,7 @@ void FusionKernelRuntime::startAsyncCompile(KernelArgumentHolder& args_old) { TORCH_INTERNAL_ASSERT( args.size() == segmented_fusion_->inputs().size(), - "Inputs were not set up correctly, recieved ", + "Inputs were not set up correctly, received ", args.size(), " inputs but expecting ", segmented_fusion_->inputs().size()); @@ -602,7 +610,7 @@ std::vector FusionKernelRuntime::runWithInput( TORCH_INTERNAL_ASSERT( args.size() == segmented_fusion_->inputs().size(), - "Inputs were not set up correctly, recieved ", + "Inputs were not set up correctly, received ", args.size(), " inputs but expecting ", segmented_fusion_->inputs().size()); @@ -649,11 +657,16 @@ std::vector FusionKernelRuntime::runWithInput( group_outputs.size() == group_runtime_outputs.size(), "output size does not match"); for (const size_t group_out_i : c10::irange(group_outputs.size())) { - output_holder[group_outputs[group_out_i]] = - group_runtime_outputs[group_out_i]; + // trivial forwarding outputs empty tensor to save bandwidth, skip + // tensor_map update on those, since we want all future use of inputs on + // the original tensor input. See note [trivial forwarding] + if (!group_outputs[group_out_i]->isFusionInput()) { + output_holder[group_outputs[group_out_i]] = + group_runtime_outputs[group_out_i]; - args.push(group_runtime_outputs[group_out_i]); - tensor_map.emplace(group_outputs[group_out_i], args.back()); + args.push(group_runtime_outputs[group_out_i]); + tensor_map.emplace(group_outputs[group_out_i], args.back()); + } } } @@ -668,6 +681,32 @@ std::vector FusionKernelRuntime::runWithInput( const auto iter = output_holder.find(output); if (iter != output_holder.end()) { fusion_outputs.push_back(iter->second); + } else if (output->isFusionInput()) { + // Note [ trivial forwarding ] + // + // Background: + // nvfuser codegen doesn't handle aliases at all. When we have a fusion + // that forwards an input to output without any operations on it, this is + // a no-op for codegen and the output tensor is never written to. However, + // the codegen cannot "forward" an input to output, since all outputs are + // allocated in integration. If we do not special case it, we'll ended up + // having a "fresh" tensor allocated for the forwarded-input. + // + // Approach: + // There are two aspects of the support: + // step 1. Codegen handles forwarding implicitly. Forwarded inputs doesn't + // have any producer in the IR, hence the output argument is not used in + // the code. But it does require to have an argument in the kernel as a + // place-holder so we'll map each arguments correctly. + // step 2. Integration handles the trivial forwarding of inputs. When we + // put together `fusion_outputs` for a given fusion, when outputs are just + // fusion inputs, we directly return the input tensor. + const auto iter = tensor_map.find(output); + TORCH_INTERNAL_ASSERT( + iter != tensor_map.end(), "Can not find output as aliased intput"); + auto arg = dynamic_cast(iter->second); + // See step 2 - note [ trivial forwarding ] + fusion_outputs.push_back(arg->getTensor()); } else { bool empty_type_check = output->getDataType().has_value() && output->getDataType().value() == DataType::Float; diff --git a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp index a4a823ab55605..15a18a6bca83e 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.cpp @@ -132,6 +132,7 @@ void ExpressionEvaluator::handle(const NamedScalar* named_scalar) { } void ExpressionEvaluator::handle(const UnaryOp* unary_op) { + using namespace IntOrDouble_functions; const auto in = evaluate(unary_op->in()); if (in.has_value()) { switch (unary_op->getUnaryOpType()) { @@ -150,6 +151,9 @@ void ExpressionEvaluator::handle(const UnaryOp* unary_op) { TORCH_INTERNAL_ASSERT(false, "dtype not supported in evaluator"); } break; + case UnaryOpType::Abs: + known_values_[unary_op->out()] = abs(*in); + break; default: TORCH_CHECK( false, diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp index 132b99b31c34b..7e69f0307a7a5 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp @@ -78,6 +78,15 @@ TensorIndex::TensorIndex( } } +Val* TensorIndex::index(int i) const { + TORCH_INTERNAL_ASSERT( + nDims() > 0, "Tried to get an index of a 0-dim TensorIndex"); + if (i < 0) + i += nDims(); + TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims())); + return indices_[i]; +} + BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) { TORCH_INTERNAL_ASSERT( @@ -85,6 +94,12 @@ BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync) "IR type only valid for Kernel container."); } +Expr* BlockSync::shallowCopy() const { + auto result = IrBuilder::create(war_sync_); + result->copyPredicatesFrom(this); + return result; +} + GridSync::GridSync( IrBuilderPasskey passkey, ParallelTypeBitmap sync_dims, @@ -93,6 +108,12 @@ GridSync::GridSync( sync_dims_(sync_dims), sync_buffer_(sync_buffer) {} +Expr* GridSync::shallowCopy() const { + auto result = IrBuilder::create(sync_dims_, sync_buffer_); + result->copyPredicatesFrom(this); + return result; +} + CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) : Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) { TORCH_INTERNAL_ASSERT( @@ -100,6 +121,12 @@ CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages) "IR type only valid for Kernel container."); } +Expr* CpAsyncWait::shallowCopy() const { + auto result = IrBuilder::create(keep_stages_); + result->copyPredicatesFrom(this); + return result; +} + CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) : Expr(passkey, ExprType::CpAsyncCommit) { TORCH_INTERNAL_ASSERT( @@ -107,6 +134,12 @@ CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* CpAsyncCommit::shallowCopy() const { + auto result = IrBuilder::create(); + result->copyPredicatesFrom(this); + return result; +} + InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::InitMagicZero) { TORCH_INTERNAL_ASSERT( @@ -114,6 +147,12 @@ InitMagicZero::InitMagicZero(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* InitMagicZero::shallowCopy() const { + auto result = IrBuilder::create(); + result->copyPredicatesFrom(this); + return result; +} + UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) : Expr(passkey, ExprType::UpdateMagicZero) { TORCH_INTERNAL_ASSERT( @@ -121,6 +160,12 @@ UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey) "IR type only valid for Kernel container."); } +Expr* UpdateMagicZero::shallowCopy() const { + auto result = IrBuilder::create(); + result->copyPredicatesFrom(this); + return result; +} + namespace { bool isIntegralScalar(const Val* val) { @@ -147,6 +192,12 @@ PairSelect::PairSelect( TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op"); } +Expr* PairSelect::shallowCopy() const { + auto result = IrBuilder::create(out_, in_, selection_); + result->copyPredicatesFrom(this); + return result; +} + Swizzle2DInt::Swizzle2DInt( IrBuilderPasskey passkey, IntPair* out, @@ -172,6 +223,13 @@ Swizzle2DInt::Swizzle2DInt( addInput(extent_y); } +Expr* Swizzle2DInt::shallowCopy() const { + auto result = IrBuilder::create( + out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_); + result->copyPredicatesFrom(this); + return result; +} + void Scope::insert(std::vector::const_iterator pos, Expr* expr) { exprs_.insert(pos, expr); } @@ -307,6 +365,22 @@ ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other) "IR type only valid for Kernel container."); } +Expr* ForLoop::shallowCopy() const { + auto result = IrBuilder::create( + iter_domain_, + index_, + start_, + stop_, + step_, + vectorize_, + vectorize_shift_, + unroll_required_, + double_buffer_loop_stage_); + result->body_ = body_; + result->copyPredicatesFrom(this); + return result; +} + bool ForLoop::isUnrollable() const { // Start and stop must be constant, must not be a broadcast // dimension, cannot be bound to a parallel dimension, must not be @@ -426,13 +500,12 @@ IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond) addInput(cond); } -Val* TensorIndex::index(int i) const { - TORCH_INTERNAL_ASSERT( - nDims() > 0, "Tried to get an index of a 0-dim TensorIndex"); - if (i < 0) - i += nDims(); - TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims())); - return indices_[i]; +Expr* IfThenElse::shallowCopy() const { + auto result = IrBuilder::create(predicate()); + result->then_body_ = then_body_; + result->else_body_ = else_body_; + result->setWritePredicate(writePredicate()); + return result; } Allocate::Allocate( @@ -495,6 +568,13 @@ Allocate::Allocate( "IR type only valid for Kernel container."); } +Expr* Allocate::shallowCopy() const { + auto result = + IrBuilder::create(buffer_, memory_type_, shape_, zero_init_); + result->copyPredicatesFrom(this); + return result; +} + GridReduction::GridReduction( IrBuilderPasskey passkey, BinaryOpType reduction_op_type, @@ -523,6 +603,22 @@ GridReduction::GridReduction( "IR type only valid for Kernel container."); } +Expr* GridReduction::shallowCopy() const { + auto result = IrBuilder::create( + getReductionOpType(), + init(), + out(), + in(), + reduction_buffer_, + sync_buffer_, + entrance_index_, + entrances_, + isAllreduce()); + result->copyPredicatesFrom(this); + result->thread_predicate_ = thread_predicate_; + return result; +} + GroupedGridReduction::GroupedGridReduction( IrBuilderPasskey passkey, std::vector reduction_op_types, @@ -553,6 +649,23 @@ GroupedGridReduction::GroupedGridReduction( "IR type only valid for Kernel container."); } +Expr* GroupedGridReduction::shallowCopy() const { + auto result = IrBuilder::create( + getReductionOpTypes(), + initVals(), + outputs(), + inputs(), + reduction_buffers_, + sync_buffer_, + entrance_index_, + entrances_, + buffer_stride_, + isAllreduce()); + result->copyPredicatesFrom(this); + result->thread_predicate_ = thread_predicate_; + return result; +} + GridBroadcast::GridBroadcast( IrBuilderPasskey passkey, BroadcastOp* broadcast_op, @@ -567,6 +680,13 @@ GridBroadcast::GridBroadcast( "IR type only valid for Kernel container."); } +Expr* GridBroadcast::shallowCopy() const { + auto result = IrBuilder::create( + broadcast_op_, broadcast_buffer_, sync_buffer_); + result->copyPredicatesFrom(this); + return result; +} + GridWelford::GridWelford( IrBuilderPasskey passkey, WelfordOp* welford_op, @@ -589,6 +709,20 @@ GridWelford::GridWelford( "IR type only valid for Kernel container."); } +Expr* GridWelford::shallowCopy() const { + auto result = IrBuilder::create( + welford_op_, + var_buffer_, + avg_buffer_, + n_buffer_, + sync_buffer_, + entrance_index_, + entrances_); + result->copyPredicatesFrom(this); + result->thread_predicate_ = thread_predicate_; + return result; +} + GroupedGridWelford::GroupedGridWelford( IrBuilderPasskey passkey, std::vector output_vals, @@ -617,6 +751,22 @@ GroupedGridWelford::GroupedGridWelford( "IR type only valid for Kernel container."); } +Expr* GroupedGridWelford::shallowCopy() const { + auto result = IrBuilder::create( + outputVals(), + inputVals(), + initVals(), + reduction_buffers_, + sync_buffer_, + entrance_index_, + entrances_, + buffer_stride_, + isAllreduce()); + result->copyPredicatesFrom(this); + result->thread_predicate_ = thread_predicate_; + return result; +} + AllocateFusedReduction::AllocateFusedReduction( IrBuilderPasskey passkey, GridReduction* grid_reduction) @@ -657,6 +807,36 @@ AllocateFusedReduction::AllocateFusedReduction( "IR type only valid for Kernel container."); } +Expr* AllocateFusedReduction::shallowCopy() const { + if (grid_expr_->isA()) { + auto result = IrBuilder::create( + grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; + } else if (grid_expr_->isA()) { + auto result = IrBuilder::create( + grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; + } else if (grid_expr_->isA()) { + auto result = IrBuilder::create( + grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; + } else if (grid_expr_->isA()) { + auto result = IrBuilder::create( + grid_expr_->as()); + result->setPredicate(predicate()); + result->setWritePredicate(writePredicate()); + return result; + } + TORCH_INTERNAL_ASSERT( + false, "Unknown reduction type in AllocateFusedReduction::shallowCopy"); +} + TensorIndex* AllocateFusedReduction::out() const { TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr); if (grid_expr_->isA() || diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h index 62b245772dd03..cd44e8d8e21b7 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_ir.h +++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h @@ -94,7 +94,7 @@ class TORCH_CUDA_CU_API Predicate final : public Val { return expr_; } - Bool* thread_pred() { + Bool* thread_pred() const { TORCH_INTERNAL_ASSERT( ptype_ == PredicateType::Inline || ptype_ == PredicateType::Misaligned || ptype_ == PredicateType::Shift || @@ -199,6 +199,8 @@ class TORCH_CUDA_CU_API Allocate final : public Expr { Val* size, bool zero_init = false); + Expr* shallowCopy() const override; + Val* buffer() const { return buffer_; } @@ -251,6 +253,8 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr { public: explicit BlockSync(IrBuilderPasskey passkey, bool war_sync = false); + Expr* shallowCopy() const override; + bool isWarHazardSync() const { return war_sync_; } @@ -265,6 +269,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { public: explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0); + Expr* shallowCopy() const override; + //! Returns the remaining number of stages that are not synchronized //! after this op. unsigned int keepStages() const { @@ -282,6 +288,8 @@ class TORCH_CUDA_CU_API CpAsyncWait final : public Expr { class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr { public: explicit CpAsyncCommit(IrBuilderPasskey passkey); + + Expr* shallowCopy() const override; }; // Synchronize all blocks in device, implies cooperative group launch is @@ -293,6 +301,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr { ParallelTypeBitmap sync_dims, Val* sync_buffer); + Expr* shallowCopy() const override; + ParallelTypeBitmap syncDims() const { return sync_dims_; } @@ -311,6 +321,8 @@ class TORCH_CUDA_CU_API GridSync final : public Expr { class TORCH_CUDA_CU_API InitMagicZero final : public Expr { public: explicit InitMagicZero(IrBuilderPasskey passkey); + + Expr* shallowCopy() const override; }; // Simply prints "UPDATE_MAGIC_ZERO" in the code in accordance with magic_zero @@ -318,6 +330,8 @@ class TORCH_CUDA_CU_API InitMagicZero final : public Expr { class TORCH_CUDA_CU_API UpdateMagicZero final : public Expr { public: explicit UpdateMagicZero(IrBuilderPasskey passkey); + + Expr* shallowCopy() const override; }; // TODO(kir): promote to IR node @@ -418,6 +432,8 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr { ForLoop(IrBuilderPasskey passkey, const ForLoop* other); + Expr* shallowCopy() const override; + Val* index() const { return index_; } @@ -512,6 +528,8 @@ class TORCH_CUDA_CU_API IfThenElse final : public Expr { public: explicit IfThenElse(IrBuilderPasskey passkey, Predicate* cond); + Expr* shallowCopy() const override; + Scope& thenBody() { return then_body_; } @@ -557,6 +575,8 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { Val* entrances, bool is_allreduce = false); + Expr* shallowCopy() const override; + Allocate* reduction_buffer() const { return reduction_buffer_; } @@ -579,8 +599,11 @@ class TORCH_CUDA_CU_API GridReduction final : public ReductionOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GridReduction* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -609,6 +632,8 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { Val* buffer_stride, bool is_allreduce = false); + Expr* shallowCopy() const override; + const std::vector& reduction_buffers() const { return reduction_buffers_; } @@ -639,8 +664,11 @@ class TORCH_CUDA_CU_API GroupedGridReduction final : public GroupedReductionOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GroupedGridReduction* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -671,6 +699,8 @@ class TORCH_CUDA_CU_API GridBroadcast final : public Expr { Allocate* broadcast_buffer, Allocate* sync_buffer); + Expr* shallowCopy() const override; + BroadcastOp* broadcast_op() const { return broadcast_op_; } @@ -710,6 +740,8 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { Val* entrance_index, Val* entrances); + Expr* shallowCopy() const override; + WelfordOp* welford_op() const { return welford_op_; } @@ -744,8 +776,10 @@ class TORCH_CUDA_CU_API GridWelford final : public Expr { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GridWelford* withThreadPredicate(const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -777,6 +811,8 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { Val* buffer_stride, bool is_allreduce = false); + Expr* shallowCopy() const override; + const std::array, 3>& reduction_buffers() const { return reduction_buffers_; } @@ -803,8 +839,11 @@ class TORCH_CUDA_CU_API GroupedGridWelford final : public GroupedWelfordOp { return thread_predicate_; } - void setThreadPredicate(const ParallelTypeBitmap& thread_predicate) { - thread_predicate_ = thread_predicate; + GroupedGridWelford* withThreadPredicate( + const ParallelTypeBitmap& thread_predicate) { + auto result = shallowCopy()->as(); + result->thread_predicate_ = thread_predicate; + return result; } private: @@ -839,6 +878,8 @@ class TORCH_CUDA_CU_API AllocateFusedReduction final : public Expr { IrBuilderPasskey passkey, GroupedGridWelford* grouped_grid_welford); + Expr* shallowCopy() const override; + Expr* gridExpr() const { return grid_expr_; } @@ -879,6 +920,8 @@ class TORCH_CUDA_CU_API PairSelect : public Expr { PairSelect(IrBuilderPasskey, Val* out, IntPair* in, Selection selection); + Expr* shallowCopy() const override; + Val* out() const { return out_; } @@ -914,6 +957,8 @@ class TORCH_CUDA_CU_API Swizzle2DInt : public Expr { Val* extent_y, Swizzle2DType swizzle_type); + Expr* shallowCopy() const override; + IntPair* out() const { return out_; } diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 53b9d172f203f..142ee1b7a02fb 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -248,7 +249,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // mappings of all iteration domains across the fusion. There are three types // of mappings Permissive, Exact, and Loop, see compute_at_map.h/cpp for more // information. - compute_at_map_ = std::make_unique(fusion_); + compute_at_map_ = std::make_shared(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ComputeAtMap)) { std::cout << compute_at_map_->toString() << std::endl; @@ -256,8 +257,12 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { compute_at_map_->validateAndPropagatePType(); + // Uses compute_at_map, find all splits that are enforced to be divisible + divisible_splits_ = getAllDivisibleSplits(fusion_, compute_at_map_.get()); + // Used in parallel dimension map - concretized_broadcast_domains_.build(fusion_); + concretized_broadcast_domains_ = + std::make_shared(fusion_); parallelDimensionMap().build(fusion_); if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) { @@ -281,7 +286,7 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // Scan the whole fusion and build mappings about halo extensions of // all IterDomains - haloInfo().build(fusion_); + halo_info_ = std::make_shared(fusion_, compute_at_map_); // Want to run this after parallel map and halo info map are // created. vectorized_accesses_ and vectorized_set_info_ are filled. @@ -298,6 +303,9 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) { // Depends on thread_pred_map_, validates parallelization collects which // tensor views need WAR or RAW syncs sync_map_.build(fusion_); + if (isDebugDumpEnabled(DebugDumpOption::SyncMap)) { + std::cout << sync_map_.toString() << std::endl; + } partialSplitMap().build(fusion_); diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h index d5600e0a25139..250b06a6495fb 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.h +++ b/torch/csrc/jit/codegen/cuda/lower2device.h @@ -62,7 +62,8 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { //! Query if lowering is in progress static bool hasCurrent(); - ConcretizedBroadcastDomains& concretizedBroadcastDomains() { + std::shared_ptr + concretizedBroadcastDomains() { return concretized_broadcast_domains_; } @@ -76,20 +77,16 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { return thread_pred_map_; } - const std::unique_ptr& caMap() const { - return compute_at_map_; + std::shared_ptr caMap() const { + return std::const_pointer_cast(compute_at_map_); } const TrivialReductionInfo& trivialReductionInfo() const { return trivial_reduction_info_; } - const HaloInfo& haloInfo() const { - return halo_info_; - } - - HaloInfo& haloInfo() { - return halo_info_; + std::shared_ptr haloInfo() const { + return std::const_pointer_cast(halo_info_); } const ParallelDimensionMap& parallelDimensionMap() const { @@ -132,6 +129,10 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { return non_divisible_split_info_; } + const auto& divisbleSplitSet() const { + return divisible_splits_; + } + DoubleBufferInfo& doubleBufferInfo() { return double_buffer_info_; } @@ -198,12 +199,13 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { // would be safer to wrap all of these in unique pointers and remove the build // interface and default constructor. That way they couldn't be accessed // without being initialized. - ConcretizedBroadcastDomains concretized_broadcast_domains_; + std::shared_ptr + concretized_broadcast_domains_; ThreadPredicateMap thread_pred_map_; PredicateElimination pred_elimination_; - std::unique_ptr compute_at_map_; + std::shared_ptr compute_at_map_; TrivialReductionInfo trivial_reduction_info_; - HaloInfo halo_info_; + std::shared_ptr halo_info_; LocalAllocationInfoMap local_allocation_info_map_; WarpPaddedParallelInfo warp_pad_info_; ParallelDimensionMap parallel_dimension_map_; @@ -214,6 +216,7 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable { FusedReductionInfo fused_reduction_info_; SyncMap sync_map_; kir::KernelPerformanceProfile profile_; + std::unordered_set divisible_splits_; // Track which tensor views are inputs or outputs of a vectorized operation // and their maximum vectorized access size diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp index 4e84579485509..ef12cce8fd46a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp @@ -18,16 +18,19 @@ namespace fuser { namespace cuda { namespace { +// Alias used for std::transform +IterDomain* exactConcreteId(IterDomain* id) { + return GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); +} -//! Checks that the current loop nest is not realizing a serial -//! broadcast so that each index of producer buffer will only -//! be visited once, which is the only case where aggressive -//! inner sharing is valid. -//! +//! Checks that the current loop nest is realizing a serial +//! broadcast so that each index of producer buffer can be visited +//! multiple times, in which case the aggressive is not valid. bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { //! Note: see issue #1785: //! serial broadcast resolution doesn't only happen to - //! immediate producers of broadcast ops. We can also have + //! immediate outputs of broadcast ops. We can also have //! example: //! T1[I,B] = broadcast(T0[I]]) //! T3[I,I] = T1[I,B] + T2[I,I] @@ -83,7 +86,7 @@ bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { std::inserter( producer_exact_concrete_root_ids, producer_exact_concrete_root_ids.begin()), - ir_utils::caMapExactConcreteId); + exactConcreteId); // Check if serial loop roots indexes any exact root id's that // is not within the set of producer's root exact id's. These @@ -92,7 +95,8 @@ bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { for (auto serial_loop_root : ir_utils::filterByType(serial_loop_roots)) { if (!producer_exact_concrete_root_ids.count( - ir_utils::caMapExactConcreteId(serial_loop_root))) { + GpuLower::current()->caMap()->getConcreteMappedID( + serial_loop_root, IdMappingMode::EXACT))) { return true; } } diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp index 466dc85c8abff..264905cfa213f 100644 --- a/torch/csrc/jit/codegen/cuda/lower_allocation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_allocation.cpp @@ -59,7 +59,7 @@ class AllocationInserter : public kir::ExprMutator { // info.init_place_before, info.alloc_for_loop, info.alloc_place_before void fillAllocationInformation(AllocationInformation& info, Expr* expr) { auto loop_alloc_info = - loop_utils::getAllocInformation(info.buffer, for_loops_); + lower_utils::getAllocInformation(info.buffer, for_loops_); info.init_for_loop = loop_alloc_info.init_for_loop; info.alloc_for_loop = loop_alloc_info.alloc_for_loop; @@ -131,7 +131,7 @@ class AllocationInserter : public kir::ExprMutator { ++init_loop_it) { auto id = *init_loop_it; kir::ForLoop* new_loop = nullptr; - auto extent_with_halo = gpu_lower->haloInfo().getExtent(id); + auto extent_with_halo = gpu_lower->haloInfo()->getExtent(id); if (extent_with_halo) { new_loop = IrBuilder::create( id, @@ -166,7 +166,7 @@ class AllocationInserter : public kir::ExprMutator { } auto extent = id->extent(); // Use halo-extended extent if found - auto halo_extent = gpu_lower->haloInfo().getRootAxisInfo(id); + auto halo_extent = gpu_lower->haloInfo()->getRootAxisInfo(id); if (halo_extent.hasHalo()) { extent = IrBuilder::addExpr( extent, IrBuilder::create(halo_extent.width())); @@ -213,7 +213,7 @@ class AllocationInserter : public kir::ExprMutator { // Get the halo extent if found auto getExtent = [this](IterDomain* id) { - auto extent = gpu_lower->haloInfo().getExtent(id); + auto extent = gpu_lower->haloInfo()->getExtent(id); if (extent == nullptr) { extent = id->extent(); } @@ -368,7 +368,7 @@ class AllocationInserter : public kir::ExprMutator { auto extent = concrete_id->extent(); - if (gpu_lower->haloInfo().getExtent(info.buffer->axis(axis_i)) != + if (gpu_lower->haloInfo()->getExtent(info.buffer->axis(axis_i)) != nullptr) { has_halo = true; } diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp new file mode 100644 index 0000000000000..0b97b973f786e --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.cpp @@ -0,0 +1,332 @@ +#include + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +namespace { + +bool isSmemTensorIndex(Val* value) { + return value->isA() && + value->as()->view()->getMemoryType() == + MemoryType::Shared; +} + +int64_t getVectorizeSize(kir::TensorIndex* ti) { + for (auto id : ti->view()->domain()->domain()) { + if (!isParallelTypeVectorize(id->getParallelType())) { + continue; + } + + ExpressionEvaluator expr_eval(id->fusion()); + auto vector_size_optional = expr_eval.evaluate(id->extent()); + + TORCH_INTERNAL_ASSERT( + vector_size_optional.has_value(), + "Could not evaluate constant value bound to vectorized dim."); + + return vector_size_optional->as(); + } + return 1; +} + +inline int64_t getPhaseSize(int64_t word_size_bytes) { + if (word_size_bytes == 16) { + return 8; + } + if (word_size_bytes == 8) { + return 16; + } + return 32; +} + +bool isThreadIdx(const std::string& name) { + return name == "threadIdx.x" || name == "threadIdx.y" || + name == "threadIdx.z"; +} + +bool isBlockIdx(const std::string& name) { + return name == "blockIdx.x" || name == "blockIdx.y" || name == "blockIdx.z"; +} + +bool isBlockDim(const std::string& name) { + return name == "blockDim.x" && name == "blockDim.y" && name == "blockDim.z"; +} + +bool isGridDim(const std::string& name) { + return name == "gridDim.x" && name == "gridDim.y" && name == "gridDim.z"; +} + +ParallelType getParallelType(const std::string& name) { + if (name == "threadIdx.x") { + return ParallelType::TIDx; + } else if (name == "threadIdx.y") { + return ParallelType::TIDy; + } else if (name == "threadIdx.z") { + return ParallelType::TIDz; + } else if (name == "blockIdx.x") { + return ParallelType::BIDx; + } else if (name == "blockIdx.y") { + return ParallelType::BIDy; + } else if (name == "blockIdx.z") { + return ParallelType::BIDz; + } + TORCH_INTERNAL_ASSERT(false, "Not a parallel type"); +} + +std::vector evaluateAddressesOnFirstPhase( + kir::TensorIndex* ti, + const std::vector& for_loops, + c10::optional launch_params, + const ExpressionEvaluator& expr_eval_common) { + std::vector addresses; + const auto word_size_bytes = + dataTypeSize(*(ti->getDataType())) * getVectorizeSize(ti); + int64_t phase_size = getPhaseSize(word_size_bytes); + + if (launch_params.has_value()) { + phase_size = std::min(phase_size, launch_params->nThreads()); + } + + for (int64_t linear_tidx : c10::irange(phase_size)) { + int64_t tidx = linear_tidx; + int64_t tidy = 0; + int64_t tidz = 0; + if (launch_params.has_value()) { + tidy = tidx / launch_params->bdimx(); + tidx = tidx % launch_params->bdimx(); + tidz = tidy / launch_params->bdimy(); + tidy = tidy % launch_params->bdimy(); + } + int64_t index = 0; + // make a copy of the expression evaluator + ExpressionEvaluator expr_eval = expr_eval_common; + expr_eval.bind("threadIdx.x", tidx); + expr_eval.bind("threadIdx.y", tidy); + expr_eval.bind("threadIdx.z", tidz); + for (auto fl : for_loops) { + if (fl->index()->isA()) { + auto name = fl->index()->as()->name(); + TORCH_INTERNAL_ASSERT( + isThreadIdx(name) || isBlockIdx(name), "unknow loop index"); + } else { + auto start = expr_eval.evaluate(fl->start())->as(); + expr_eval.bind(fl->index(), start); + } + } + for (auto ind : ti->indices()) { + index += expr_eval.evaluate(ind)->as(); + } + addresses.emplace_back(index * word_size_bytes); + } + return addresses; +} + +int getConflictWays(const std::vector& addresses) { + std::unordered_set words_by_bank[32]; + for (auto addr : addresses) { + int64_t word = addr / 4; + int64_t bank = word % 32; + words_by_bank[bank].insert(word); + } + int conflict = 1; + for (const auto& words : words_by_bank) { + conflict = std::max(conflict, words.size()); + } + return conflict; +} + +class InferLaunchParams : public kir::IrVisitor { + public: + static c10::optional get( + const std::vector& exprs, + const std::unordered_map& known_values) { + if (exprs.empty()) { + return c10::nullopt; + } + return InferLaunchParams(exprs, known_values).launch_params_; + } + + private: + InferLaunchParams( + const std::vector& exprs, + const std::unordered_map& known_values) + : expr_eval_(exprs[0]->fusion()) { + for (auto pair : known_values) { + expr_eval_.bind(pair.first, pair.second); + } + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + for (auto fl : for_loops_) { + if (fl->index()->isA()) { + auto name = fl->index()->as()->name(); + if (isThreadIdx(name) || isBlockIdx(name)) { + auto ptype = getParallelType(name); + auto stop = expr_eval_.evaluate(fl->stop()); + if (stop.has_value()) { + if (!launch_params_.has_value()) { + launch_params_ = LaunchParams(); + } + if (launch_params_->getRawVal(ptype) == + LaunchParams::UNINITIALIZED_VAL) { + launch_params_->bind(stop->as(), ptype); + } else { + TORCH_INTERNAL_ASSERT( + launch_params_->getDim(ptype) == stop, + "Unable to infer launch parameters"); + } + } + } + } + } + } + + ExpressionEvaluator expr_eval_; + c10::optional launch_params_; +}; + +class BankConflictInfo : public kir::IrVisitor { + public: + static std::unordered_map> get( + const std::vector& exprs, + c10::optional launch_params, + const std::unordered_map& known_values) { + if (exprs.empty()) { + return {}; + } + return BankConflictInfo(exprs, launch_params, known_values) + .bank_conflict_info_; + } + + private: + BankConflictInfo( + const std::vector& exprs, + c10::optional launch_params, + const std::unordered_map& known_values) + : launch_params_(launch_params), expr_eval_common_(exprs[0]->fusion()) { + expr_eval_common_.bind("blockIdx.x", 0); + expr_eval_common_.bind("blockIdx.y", 0); + expr_eval_common_.bind("blockIdx.z", 0); + if (launch_params.has_value()) { + expr_eval_common_.bind("blockDim.x", launch_params->bdimx()); + expr_eval_common_.bind("blockDim.y", launch_params->bdimy()); + expr_eval_common_.bind("blockDim.z", launch_params->bdimz()); + expr_eval_common_.bind("gridDim.x", launch_params->gdimx()); + expr_eval_common_.bind("gridDim.y", launch_params->gdimy()); + expr_eval_common_.bind("gridDim.z", launch_params->gdimz()); + } + for (auto pair : known_values) { + expr_eval_common_.bind(pair.first, pair.second); + } + handle(exprs); + } + + using kir::IrVisitor::handle; + + void handle(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::handle(expr); + return; + } + + if (expr->isA()) { + auto uop = expr->as(); + if (uop->getUnaryOpType() != UnaryOpType::Set) { + return; + } + std::pair conflict_ways{0, 0}; + if (isSmemTensorIndex(uop->in())) { + conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( + uop->in()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); + } + if (isSmemTensorIndex(uop->out())) { + conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( + uop->out()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); + } + if (conflict_ways.first > 1 || conflict_ways.second > 1) { + bank_conflict_info_[expr] = conflict_ways; + } + } else if (expr->isA()) { + auto ldst = expr->as(); + std::pair conflict_ways{0, 0}; + if (isSmemTensorIndex(ldst->in())) { + conflict_ways.first = getConflictWays(evaluateAddressesOnFirstPhase( + ldst->in()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); + } + if (isSmemTensorIndex(ldst->out())) { + conflict_ways.second = getConflictWays(evaluateAddressesOnFirstPhase( + ldst->out()->as(), + for_loops_, + launch_params_, + expr_eval_common_)); + } + if (conflict_ways.first > 1 || conflict_ways.second > 1) { + bank_conflict_info_[expr] = conflict_ways; + } + } + } + + std::unordered_map> bank_conflict_info_; + c10::optional launch_params_; + ExpressionEvaluator expr_eval_common_; +}; + +} // namespace + +std::unordered_map> getBankConflictInfo( + kir::Kernel* kernel, + c10::optional launch_params, + const std::unordered_map& known_values) { + for (auto pair : known_values) { + TORCH_CHECK( + !isThreadIdx(pair.first), + "threadIdx.{x,y,z} should be computed instead of provided"); + TORCH_CHECK( + !isBlockIdx(pair.first), + "blockIdx.{x,y,z} should not be provided (they are always zero)"); + TORCH_CHECK( + !isBlockDim(pair.first), + "blockDim.{x,y,z} should be provided by launch_params"); + TORCH_CHECK( + !isGridDim(pair.first), + "gridDim.{x,y,z} should be provided by launch_params"); + } + if (!launch_params.has_value()) { + launch_params = + InferLaunchParams::get(kernel->topLevelExprs(), known_values); + } + return BankConflictInfo::get( + kernel->topLevelExprs(), launch_params, known_values); +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h new file mode 100644 index 0000000000000..b651c4ed33e22 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_bank_conflict.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// for more info on shared memory access see page 54-72 of: +// https://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf + +// Warning: The bank confliction checking utility here is not a replacement of +// nsight compute. This utility currently has the following assumptions and +// limitations: +// +// 1. This utility assumes that the data of the tensor is accessed by +// `T0[index]`, where `index` is the one stored in the `TensorIndex` +// object. +// 2. This utility only checks the first iteration. If we have something like +// `T1_s[tidx, 5]`, then different iterations should have different +// conflictions, which will not be evaluated for all of them +// 3. This utility assumes that all tensors are independent, which means: +// 3.1 All shared memory tensors are allocated starting from a multiple of +// 4*32 bytes +// 3.2 The only source of bank confliction is from within a tensor. +// There is no bank conflict between different tensors. +// +// Also note that this utility will not provide accurate estimation if the above +// assumptions are satisfied + +std::unordered_map> getBankConflictInfo( + kir::Kernel* kernel, + c10::optional launch_params = c10::nullopt, + const std::unordered_map& known_values = {}); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp new file mode 100644 index 0000000000000..c1de1201e5d18 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp @@ -0,0 +1,121 @@ + +#include + +#include +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +std::unordered_set getAllDivisibleSplits(Fusion* fusion) { + ComputeAtMap ca_map(fusion); + return getAllDivisibleSplits(fusion, &ca_map); +} + +std::unordered_set getAllDivisibleSplits( + Fusion* fusion, + const ComputeAtMap* ca_map) { + std::unordered_set all_divisible_splits; + + auto all_tvs = ir_utils::allTvs(fusion); + // Find all tensor views with a view like rfactor. Splits used in view + // transformations must be divisible by definition. + for (auto tv : all_tvs) { + auto rfactor_dom = tv->getMaybeRFactorDomain(); + // Not view if there's no rfactor axis + if (!tv->domain()->hasViewLikeRFactor()) { + continue; + } + + // Take the view transformations and add all the splits. Those splits are + // the only divisible splits. + auto view_exprs = + StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); + auto split_exprs = ir_utils::filterByType(view_exprs); + all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); + } + + // Vectorized dimensions are enforced to be a result of divisible splits. + // Gather vectorized splits. + for (auto tv : all_tvs) { + auto vec_id_it = std::find_if( + tv->domain()->domain().begin(), + tv->domain()->domain().end(), + [](IterDomain* id) { + return isParallelTypeVectorize(id->getParallelType()); + }); + + if (vec_id_it == tv->domain()->domain().end()) { + continue; + } + + // We could have a case technically like: + // [8, 2] where we do: + // split(0, 2) + // merge(1) + // so it ends up as [4, 4] + // split(0, 2) must be divisible, but for now we're not going to capture + // cases like this. Just look for direct split's producing a vectorize + // dimension. + auto vec_id = *vec_id_it; + if (vec_id->definition() != nullptr && vec_id->definition()->isA()) { + all_divisible_splits.emplace(vec_id->definition()->as()); + } + } + + // If there's no view like splits, there's nothing to find + if (all_divisible_splits.empty()) { + return all_divisible_splits; + } + + // Track the concrete id in the exact map of the outer output of the split + // expressions. This is how we'll check if there are matching splits. This + // also gets rid of any splits that already match (for processing). + std::unordered_map outer_concrete_id_to_expr; + + for (auto split : all_divisible_splits) { + outer_concrete_id_to_expr[ca_map->getConcreteMappedID( + split->outer(), IdMappingMode::EXACT)] = split; + } + + std::unordered_set visited( + all_divisible_splits.begin(), all_divisible_splits.end()); + + // Find splits that match what we already have: + for (auto entry : outer_concrete_id_to_expr) { + auto concrete_id = entry.first; + auto original_view_split = entry.second; + + const auto& exact_mapped_ids = + ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector(); + for (auto other_id : exact_mapped_ids) { + if (other_id->definition() == nullptr) { + continue; + } + + if (!visited.emplace(other_id->definition()).second) { + // Already visited + continue; + } + + if (IterDomainGraph::exprsMap( + original_view_split, + other_id->definition(), + false, + ca_map->idGraph().exactNodes())) { + all_divisible_splits.emplace(other_id->definition()->as()); + } + } + } + + return all_divisible_splits; +} + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_divisible_split.h b/torch/csrc/jit/codegen/cuda/lower_divisible_split.h new file mode 100644 index 0000000000000..f2c4a78e4895e --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/lower_divisible_split.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { + +// Looks through all transformations assocaited with view, or enforced divisible +// vectorization splits and gathers all splits that provably don't have a +// remainder, therefore the extents of the associated IterDomains do not require +// a ceilDiv expressions. +TORCH_CUDA_CU_API std::unordered_set getAllDivisibleSplits( + Fusion* fusion); + +// Same as above but will use provided ComputeAtMap instead of building its own. +TORCH_CUDA_CU_API std::unordered_set getAllDivisibleSplits( + Fusion* fusion, + const ComputeAtMap* ca_map); + +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp index c4a5beeeabee2..5b659e3e94605 100644 --- a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp @@ -413,9 +413,10 @@ std::vector ExprGroup::getMergeCandidates( "Shouldn't still be traversing in fallback mode if a merge was found."); } - std::vector can_merge(true, neighbors.size()); + std::vector can_merge(neighbors.size(), true); - // Find neighbors with a level that is only 1 differant than this groups level + // Find neighbors with a level that is only 1 different than this group's + // level for (const auto i : c10::irange(neighbors.size())) { if (std::abs(neighbors[i]->payload()->level - payload()->level) > 1) { can_merge[i] = false; @@ -709,7 +710,7 @@ std::vector getLocalDomainOrdering( std::sort( merged_domain.begin(), merged_domain.end(), - IterDomainDependencySorter( + ir_utils::IterDomainDependencySorter( concrete_id_dependencies, GpuLower::current()->caMap())); return merged_domain; } @@ -927,8 +928,8 @@ bool ExprSegmentationSorter::interIterUpdate() { // If we didn't finish and we tried the fallback, throw. TORCH_INTERNAL_ASSERT( !fallback_mode_enabled_, - "Couldn't succcessfully sort out the fusion expressions. ", - "There are remaining connections of the heirarchical segmentation which should have been ", + "Couldn't successfully sort out the fusion expressions. ", + "There are remaining connections of the hierarchical segmentation which should have been ", "flattened to a single ordered group, or disjoint ordered groups."); // We didn't finish, but we haven't tried the fallback, try again with that. fallback_mode_enabled_ = true; @@ -1066,7 +1067,7 @@ void ExprSegmentationSorter::initializeForLoopDependencies() { } } - std::cerr << "Depdencies: " << std::endl; + std::cerr << "Dependencies: " << std::endl; for (const auto& dep_entry : concrete_id_dependencies) { std::cerr << " Deps of " << dep_entry.first->toString() << std::endl << " "; @@ -1398,6 +1399,9 @@ std::vector ExprSegmentationSorter::getExprs() const { std::vector reorderExprsForComputeAt() { auto fusion = FusionGuard::getCurFusion(); + if (fusion->exprs().empty()) { + return {}; + } TORCH_INTERNAL_ASSERT(fusion != nullptr); ExprSegmentationSorter sorter(fusion); sorter.sort(); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp index dc210e98cbc8d..e83a0e9fce996 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp @@ -100,37 +100,73 @@ void IndexLowering::handle(const RNGOp* rop) { // TensorIndex for philox subsequence and component. auto philox_index = SimplifyingIrBuilder::create( - out_tv, Index::getLinearIndex(out_tv, for_loops_)); + out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_)); - // TensorIndex for writing randlike output. + // TensorIndex for writing rand_like output. const auto out = lowerDstIndex(out_tv); auto lowered = IrBuilder::create( - rop->getRNGOpType(), out, rop->getRNGOffset(), philox_index); + rop->getRNGOpType(), + out, + rop->dtype(), + rop->getParameters(), + rop->getRNGOffset(), + philox_index); pushBack(lowered); GpuLower::current()->propagateExprInfo(rop, back()); } +void IndexLowering::handle(const FullOp* fop) { + auto out_tv = dynamic_cast(fop->output(0)); + TORCH_INTERNAL_ASSERT(out_tv != nullptr); + + // TensorIndex for writing output. + const auto out = lowerDstIndex(out_tv); + auto lowered = + IrBuilder::create(out, fop->getFillValue(), fop->dtype()); + + pushBack(lowered); + GpuLower::current()->propagateExprInfo(fop, back()); +} + void IndexLowering::handle(const ARangeOp* aop) { // Write linear tensor indices into the consumer // tensor index if the output is a tensor. auto out_tv = dynamic_cast(aop->output(0)); TORCH_INTERNAL_ASSERT(out_tv != nullptr); - // TensorIndex for philox subsequence and component. + // linear index for computing arange output auto linear_index = SimplifyingIrBuilder::create( - out_tv, Index::getLinearIndex(out_tv, for_loops_)); + out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_)); - // TensorIndex for writing randlike output. + // TensorIndex for writing arange output. const auto out = lowerDstIndex(out_tv); auto lowered = IrBuilder::create( - out, aop->start(), aop->end(), aop->step(), linear_index); + out, aop->start(), aop->end(), aop->step(), aop->dtype(), linear_index); pushBack(lowered); GpuLower::current()->propagateExprInfo(aop, back()); } +void IndexLowering::handle(const EyeOp* eop) { + auto out_tv = dynamic_cast(eop->output(0)); + TORCH_INTERNAL_ASSERT(out_tv != nullptr); + + // linear index for computing eye output + auto indices = Index::getPerDimLogicalIndex(out_tv, for_loops_); + TORCH_INTERNAL_ASSERT(indices.size() == 2); + auto index1 = indices[0]; + auto index2 = indices[1]; + + // TensorIndex for writing eye output. + const auto out = lowerDstIndex(out_tv); + auto lowered = IrBuilder::create(out, eop->dtype(), index1, index2); + + pushBack(lowered); + GpuLower::current()->propagateExprInfo(eop, back()); +} + void IndexLowering::handle(const UnaryOp* uop) { const auto in = lowerSrcIndex(uop->in(), uop->out()); const auto out = lowerDstIndex(uop->out()); @@ -375,10 +411,12 @@ void IndexLowering::handleBlockReduction( ReductionOp* indexed_rop = IrBuilder::create( rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce()); if (rop->predicate()) { - indexed_rop->setPredicate(rop->predicate()); + indexed_rop = + indexed_rop->withPredicate(rop->predicate())->as(); } if (rop->writePredicate()) { - indexed_rop->setWritePredicate(rop->writePredicate()); + indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate()) + ->as(); } pushBack(indexed_rop); @@ -457,13 +495,15 @@ void IndexLowering::handleGridReduction( n_entrances, rop->isAllreduce()); - grid_reduction->setThreadPredicate(thread_pred); + grid_reduction = grid_reduction->withThreadPredicate(thread_pred); if (rop->predicate()) { - grid_reduction->setPredicate(rop->predicate()); + grid_reduction = grid_reduction->withPredicate(rop->predicate()) + ->as(); } if (rop->writePredicate()) { - grid_reduction->setWritePredicate(rop->writePredicate()); + grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate()) + ->as(); } pushBack(grid_reduction); @@ -520,10 +560,12 @@ void IndexLowering::handleBlockReduction( inputs, grouped_rop->isAllreduce()); if (grouped_rop->predicate()) { - indexed_rop->setPredicate(grouped_rop->predicate()); + indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate()) + ->as(); } if (grouped_rop->writePredicate()) { - indexed_rop->setWritePredicate(grouped_rop->writePredicate()); + indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate()) + ->as(); } pushBack(indexed_rop); @@ -602,13 +644,16 @@ void IndexLowering::handleGridReduction( work_buf_size_info.buffer_stride, grouped_rop->isAllreduce()); - grid_reduction->setThreadPredicate(thread_pred); + grid_reduction = grid_reduction->withThreadPredicate(thread_pred); if (grouped_rop->predicate()) { - grid_reduction->setPredicate(grouped_rop->predicate()); + grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate()) + ->as(); } if (grouped_rop->writePredicate()) { - grid_reduction->setWritePredicate(grouped_rop->writePredicate()); + grid_reduction = + grid_reduction->withWritePredicate(grouped_rop->writePredicate()) + ->as(); } pushBack(grid_reduction); @@ -670,10 +715,11 @@ void IndexLowering::handle(const WelfordOp* wop) { wop->isAllreduce()); if (wop->predicate()) { - indexed_wop->setPredicate(wop->predicate()); + indexed_wop = indexed_wop->withPredicate(wop->predicate())->as(); } if (wop->writePredicate()) { - indexed_wop->setWritePredicate(wop->writePredicate()); + indexed_wop = + indexed_wop->withWritePredicate(wop->writePredicate())->as(); } // Serial welford @@ -749,22 +795,27 @@ void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) { entrance_ind, n_entrances); - grid_welford->setThreadPredicate(thread_pred); + grid_welford = grid_welford->withThreadPredicate(thread_pred); const bool block_reduce_separated = out_domain->hasBlockReduction() && !indexed_wop->isAllreduce(); if (indexed_wop->predicate()) { if (block_reduce_separated) { - grid_welford->setPredicate(IrBuilder::create( - GpuLower::current()->kernel()->trueVal())); + grid_welford = grid_welford + ->withPredicate(IrBuilder::create( + GpuLower::current()->kernel()->trueVal())) + ->as(); } else { - grid_welford->setPredicate(indexed_wop->predicate()); + grid_welford = grid_welford->withPredicate(indexed_wop->predicate()) + ->as(); } } if (indexed_wop->writePredicate()) { - grid_welford->setWritePredicate(indexed_wop->writePredicate()); + grid_welford = + grid_welford->withWritePredicate(indexed_wop->writePredicate()) + ->as(); } if (block_reduce_separated) { @@ -909,13 +960,15 @@ void IndexLowering::handleGroupedGridWelford( work_buf_size_info.buffer_stride, op->isAllreduce()); - indexed_op->setThreadPredicate(thread_pred); + indexed_op = indexed_op->withThreadPredicate(thread_pred); if (op->predicate()) { - indexed_op->setPredicate(op->predicate()); + indexed_op = indexed_op->withPredicate(op->predicate()) + ->as(); } if (op->writePredicate()) { - indexed_op->setWritePredicate(op->writePredicate()); + indexed_op = indexed_op->withWritePredicate(op->writePredicate()) + ->as(); } pushBack(indexed_op); @@ -929,7 +982,9 @@ void IndexLowering::handleGroupedGridWelford( void IndexLowering::handle(const LoadStoreOp* ldst) { const auto in = lowerSrcIndex(ldst->in(), ldst->out()); const auto out = lowerDstIndex(ldst->out()); - pushBack(IrBuilder::create(ldst->opType(), out, in)); + auto new_ldst = IrBuilder::create(ldst->opType(), out, in) + ->withPredicate(ldst->predicate()); + pushBack(new_ldst); GpuLower::current()->propagateExprInfo(ldst, back()); } @@ -961,7 +1016,8 @@ void IndexLowering::handle(const BroadcastOp* bop) { const bool block_z = parallel_bitmap.get(ParallelType::BIDz); if (bop->predicate()) { - indexed_expr->setPredicate(bop->predicate()); + indexed_expr = + indexed_expr->withPredicate(bop->predicate())->as(); } const bool grid_broadcast_needed = block_x || block_y || block_z; @@ -988,7 +1044,8 @@ void IndexLowering::handle(const BroadcastOp* bop) { indexed_expr, work_buffer, sync_buffer); if (bop->predicate()) { - grid_broadcast->setPredicate(bop->predicate()); + grid_broadcast = grid_broadcast->withPredicate(bop->predicate()) + ->as(); } pushBack(grid_broadcast); @@ -1040,7 +1097,7 @@ kir::Allocate* IndexLowering::allocateUniqueBuffer( // No existing allocation found. Create a new one auto new_buffer = - ir_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init); + lower_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init); // Keep track of the allocation alloc_map.emplace(out_tv, new_buffer); diff --git a/torch/csrc/jit/codegen/cuda/lower_index.h b/torch/csrc/jit/codegen/cuda/lower_index.h index 75f7fd4aac335..6c08eeb195ea5 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index.h +++ b/torch/csrc/jit/codegen/cuda/lower_index.h @@ -38,7 +38,9 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch { // Insert an expression before the current top-level expression. void insertAtTopLevel(Expr* expr); + void handle(const FullOp*) final; void handle(const ARangeOp*) final; + void handle(const EyeOp*) final; void handle(const ViewAsScalar*) final; void handle(const UnaryOp*) final; diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index 2d4444d340903..140fecc0f8af1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -101,7 +101,7 @@ struct IndexingParameters { }; // Initial loop index map for global producer or consumer case. -IndexingParameters getGlobalIndexParameters( +IndexingParameters getLinearIndexParameters( const LoopIndexing& loop_indexing, bool index_producer = false) { IndexingParameters index_parameters; @@ -112,7 +112,8 @@ IndexingParameters getGlobalIndexParameters( for (auto loop_idx : c10::irange(loops.size())) { auto loop = loops[loop_idx]; - auto index_domain = ir_utils::caMapExactConcreteId(loop_domain[loop_idx]); + auto index_domain = GpuLower::current()->caMap()->getConcreteMappedID( + loop_domain[loop_idx], IdMappingMode::EXACT); if (loop->isTrivial()) { // This is useful information in the case of // MisalignedVectorize and double buffer epilog, etc. @@ -125,7 +126,8 @@ IndexingParameters getGlobalIndexParameters( // Derive the halo extents from the loop indexing result. index_parameters.concrete_id_to_halo_extent = - GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( + loop_indexing); protectNonPredicateIndexWithMagicZero( loops, @@ -148,7 +150,9 @@ IndexingParameters getGlobalIndexParameters( auto loop_id = loop_indexing.loopDomains()[loop_idx]; - auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id); + auto concrete_loop_id = + GpuLower::current()->caMap()->getConcreteMappedID( + loop_id, IdMappingMode::EXACT); auto stage_depth = GpuLower::current()->doubleBufferInfo().getStageDepthFor( @@ -185,7 +189,7 @@ IndexingParameters getNonGlobalInitialIndexParameters( } auto alloc_tv = index_producer ? producer_tv : consumer_tv; - auto alloc_info = loop_utils::getAllocInformation( + auto alloc_info = lower_utils::getAllocInformation( alloc_tv, loops, alloc_id_map, index_producer); std::unordered_map loop_to_ind_map; @@ -216,7 +220,9 @@ IndexingParameters getNonGlobalInitialIndexParameters( auto loop = loops[loop_idx]; auto loop_domain = loop_domains[loop_idx]; - auto concrete_loop_domain = ir_utils::caMapExactConcreteId(loop_domain); + auto concrete_loop_domain = + GpuLower::current()->caMap()->getConcreteMappedID( + loop_domain, IdMappingMode::EXACT); index_parameters.initial_concrete_id_index[concrete_loop_domain] = loop_to_ind_map.at(loop); @@ -233,7 +239,8 @@ IndexingParameters getNonGlobalInitialIndexParameters( // Derive the halo extents from the loop indexing result. index_parameters.concrete_id_to_halo_extent = - GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( + loop_indexing); return index_parameters; } @@ -397,7 +404,8 @@ IndexingParameters getPredicateInitialIndexParameters( for (int loop_idx : c10::irange(loops.size())) { auto loop = loops.at(loop_idx); auto concrete_loop_domain = - ir_utils::caMapExactConcreteId(loop_domains.at(loop_idx)); + GpuLower::current()->caMap()->getConcreteMappedID( + loop_domains.at(loop_idx), IdMappingMode::EXACT); index_parameters.initial_concrete_id_index[concrete_loop_domain] = loop_to_ind_map.at(loop); } @@ -408,7 +416,8 @@ IndexingParameters getPredicateInitialIndexParameters( // Derive the halo extents from the loop indexing result. index_parameters.concrete_id_to_halo_extent = - GpuLower::current()->haloInfo().buildConcreteHaloExtentMap(loop_indexing); + GpuLower::current()->haloInfo()->buildConcreteHaloExtentMap( + loop_indexing); return index_parameters; } @@ -563,7 +572,10 @@ LoopIndexingAnalysis::LoopIndexingAnalysis( // consume each concrete id once so this map is well defined. for (auto expr : replayed_exprs_) { for (auto input_id : ir_utils::filterByType(expr->inputs())) { - concrete_id_to_consumer_[ir_utils::caMapExactConcreteId(input_id)] = expr; + auto concrete_input_id = + GpuLower::current()->caMap()->getConcreteMappedID( + input_id, IdMappingMode::EXACT); + concrete_id_to_consumer_[concrete_input_id] = expr; } } @@ -595,7 +607,8 @@ void LoopIndexingAnalysis::validateLoopStructure( for (auto it_i = loops.begin(); it_i != loops.end(); ++it_i) { // Largely duplicating original logic auto loop_id = (*it_i)->iter_domain(); - auto concrete_loop_id = ir_utils::caMapExactConcreteId(loop_id); + auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( + loop_id, IdMappingMode::EXACT); TORCH_INTERNAL_ASSERT( !concrete_to_loop.count(concrete_loop_id), @@ -659,13 +672,22 @@ void LoopIndexingAnalysis::traverseFromDomainVals() { } IterDomain* LoopIndexingAnalysis::concretizeAndVisitId(IterDomain* id) { - auto concrete_id = ir_utils::caMapExactConcreteId(id); + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); if (replayed_concrete_ids_.pushBack(concrete_id)) { concrete_to_original_id_[concrete_id] = id; } return concrete_id; } +namespace { +// Alias used for std::transform +IterDomain* exactConcreteId(IterDomain* id) { + return GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); +} +} // namespace + void LoopIndexingAnalysis::visitExpr(Expr* expr) { if (auto swizzle2d = dynamic_cast(expr)) { // Swizzle outputs are already forwarded through @@ -700,14 +722,14 @@ void LoopIndexingAnalysis::visitExpr(Expr* expr) { consumed_ids.begin(), consumed_ids.end(), std::inserter(consumed_concrete_, consumed_concrete_.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); auto produced_ids = ir_utils::filterByType(expr->outputs()); std::transform( produced_ids.begin(), produced_ids.end(), std::inserter(produced_concrete_, produced_concrete_.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); } bool LoopIndexingAnalysis::visitIdsAndCheckDuplication( @@ -732,8 +754,36 @@ void LoopIndexingAnalysis::constructLoopDomains() { !concrete_id_to_consumer_.count(concrete_id) && // Use permissive map so the selected ID indeed represents the // loop. - GpuLower::current()->caMap()->areMapped( - concrete_id, loop_id, IdMappingMode::PERMISSIVE); + // Note: see PR https://github.com/csarofeen/pytorch/pull/1960 + // and issue https://github.com/csarofeen/pytorch/issues/1873 + // This mapping look up is part of a staged indexing scheme. + // When we find a replayed exact id that exactly map to the loop + // id, this means that we can resolve indexing involved in this + // loop "locally", i.e. only with and with only the iterdomains + // on the + // + // given consumer tv. + // When we cannot find an exact mapping, the permissive mapping + // would + // help defering the indexing resolution for this loop nest + // level to other iterdomain expressions from tv's that are + // further concretized and usually they are further down the + // consumer chain of the given consumer tv. + // + // Intuitively exact mapping of two iterdomains should imply + // permissive mapping + // of them as well and if that was the case, only looking up + // permissive mapping would be enough to address both of the + // cases above. + // FIXME: But currently exact mapping does not imply permissive + // mapping (See issue: + // https://github.com/csarofeen/pytorch/issues/1963) + // Which means we should check both exact and permissive mapping + // here. + (GpuLower::current()->caMap()->areMapped( + concrete_id, loop_id, IdMappingMode::EXACT) || + GpuLower::current()->caMap()->areMapped( + concrete_id, loop_id, IdMappingMode::PERMISSIVE)); }); TORCH_INTERNAL_ASSERT( @@ -769,7 +819,8 @@ void LoopIndexingAnalysis::constructLoopDomains() { // will complain for not having all outputs of the traversal. for (auto id : ir_utils::filterByType(all_ids_from_root)) { if (id->uses().empty()) { - loop_domains_.pushBack(ir_utils::caMapExactConcreteId(id)); + loop_domains_.pushBack(GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT)); } } } @@ -797,7 +848,7 @@ IndexFromIdGraph getTensorIndexFromIdGraph( } if (is_global) { - index_parameters = getGlobalIndexParameters(loop_indexing, index_producer); + index_parameters = getLinearIndexParameters(loop_indexing, index_producer); } else { index_parameters = getNonGlobalInitialIndexParameters( loop_indexing, consumer_tv, index_producer, producer_tv, p2c_map); @@ -849,7 +900,8 @@ IndexFromIdGraph getTensorIndexFromIdGraph( // Exact id will have to be pulled from consumer side as the // producer side are replayed ids. - auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id); + auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + consumer_id, IdMappingMode::EXACT); index_update_map[exact_concrete_id] = target_id; @@ -864,7 +916,12 @@ IndexFromIdGraph getTensorIndexFromIdGraph( target_tv->domain()->domain(), target_tv->getMaybeRFactorDomain(), target_tv->domain()->contiguity(), - initial_indexable_map, + {}, + indexing.indexMap(), + GpuLower::current()->divisbleSplitSet(), + GpuLower::current()->caMap(), + GpuLower::current()->haloInfo(), + GpuLower::current()->concretizedBroadcastDomains(), p2c_map); auto target_indexing = indexing.updateIndexCompute( @@ -930,18 +987,16 @@ IndexFromIdGraph getPredicateIndexingFromIdGraph( ir_utils::filterByType(all_consumer_vals)) { // Track the non-concrete id we were trying to bind index // to, whether from producer or consumer. - auto exact_concrete_id = ir_utils::caMapExactConcreteId(consumer_id); + auto exact_concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + consumer_id, IdMappingMode::EXACT); index_update_map[exact_concrete_id] = consumer_id; } - // No contiguity info is used in the predicate indexing pass, - // the predicate generation logic that uses the index math - // generated here will take contiguity into account. - ContigIDs contig_finder( - consumer_tv->domain()->domain(), - consumer_tv->getMaybeRFactorDomain(), - std::vector(consumer_tv->getMaybeRFactorDomain().size(), false), - {}); + // No contiguity info is used in the predicate indexing pass, the predicate + // generation logic that uses the index math generated here will take + // contiguity into account. Send an empty ContigID class so nothing is marked + // as contiguous. + auto contig_finder = ContigIDs::getNonContigIDs(); // Run second backward traversal to map back to the consumer_tv auto target_indexing = indexing.updateIndexCompute( @@ -1009,7 +1064,8 @@ LoopIndexingTraversal::LoopIndexingTraversal( auto next_ids = ir_utils::filterByType(nextValsInTraversalOrder(expr)); for (auto id : next_ids) { - auto concrete_id = ir_utils::caMapExactConcreteId(id); + auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT); TORCH_INTERNAL_ASSERT( concrete_id_to_dependency_.insert(std::make_pair(concrete_id, expr)) .second, @@ -1077,7 +1133,8 @@ std::vector LoopIndexingTraversal::getExprList() { for (auto prev_id : ir_utils::filterByType(prevValsInTraversalOrder(top))) { auto prev_expr_it = concrete_id_to_dependency_.find( - ir_utils::caMapExactConcreteId(prev_id)); + GpuLower::current()->caMap()->getConcreteMappedID( + prev_id, IdMappingMode::EXACT)); if (prev_expr_it != concrete_id_to_dependency_.end()) { auto prev_expr = prev_expr_it->second; if (!visited.count(prev_expr)) { @@ -1114,7 +1171,7 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() { consumer_tv_->getComputeAtPosition(), consumer_tv_->domain()->domain().end(), std::inserter(out_of_line_ids, out_of_line_ids.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); // Get the original selected list of index expressions // in reverse topological order. @@ -1129,7 +1186,9 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() { id_outputs.begin(), id_outputs.end(), [&out_of_line_ids](IterDomain* id) { - return out_of_line_ids.count(ir_utils::caMapExactConcreteId(id)); + return out_of_line_ids.count( + GpuLower::current()->caMap()->getConcreteMappedID( + id, IdMappingMode::EXACT)); })) { // Record out of line expression out_of_line_exprs_.push_back(expr); @@ -1140,7 +1199,7 @@ void LoopIndexingAnalysis::collectOutOfLineExprs() { id_inputs.begin(), id_inputs.end(), std::inserter(out_of_line_ids, out_of_line_ids.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); } } } @@ -1161,14 +1220,14 @@ std::unordered_set LoopIndexing::getAllExactConcreteIdSet() const { out_ids.begin(), out_ids.end(), std::inserter(all_id_set, all_id_set.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); auto in_ids = ir_utils::filterByType(expr->inputs()); std::transform( in_ids.begin(), in_ids.end(), std::inserter(all_id_set, all_id_set.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); } return all_id_set; } @@ -1213,7 +1272,9 @@ class LoopIndexingPreferredPathCompute : public IterVisitor { } mapped_id = c_id_it->second; } - auto concrete_original_id = ir_utils::caMapExactConcreteId(mapped_id); + auto concrete_original_id = + GpuLower::current()->caMap()->getConcreteMappedID( + mapped_id, IdMappingMode::EXACT); if (all_concrete_ids.count(concrete_original_id)) { if (original_id->isBroadcast() || original_id->isReduction() || original_id->isStride()) { @@ -1239,8 +1300,10 @@ class LoopIndexingPreferredPathCompute : public IterVisitor { all_iter_inputs.begin(), all_iter_inputs.end(), [&](IterDomain* inp_id) { - return this->preferred_path_.find(ir_utils::caMapExactConcreteId( - inp_id)) != this->preferred_path_.end(); + return this->preferred_path_.find( + GpuLower::current()->caMap()->getConcreteMappedID( + inp_id, IdMappingMode::EXACT)) != + this->preferred_path_.end(); })) { auto all_iter_outputs = ir_utils::filterByType(e->outputs()); @@ -1248,7 +1311,7 @@ class LoopIndexingPreferredPathCompute : public IterVisitor { all_iter_outputs.begin(), all_iter_outputs.end(), std::inserter(preferred_path_, preferred_path_.end()), - ir_utils::caMapExactConcreteId); + exactConcreteId); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp index 12b02d0b51ce3..86ca9d8427e78 100644 --- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp @@ -293,7 +293,7 @@ class WarSyncInserter : private kir::ExprMutator { auto maybe_aliased_tv = alloc_map_.getRealBuffer(tv); auto alloc_it = smem_allocations_.find(maybe_aliased_tv); auto ca_loop = - loop_utils::getAllocInformation(tv, for_loops_).init_for_loop; + lower_utils::getAllocInformation(tv, for_loops_).init_for_loop; if (alloc_it == smem_allocations_.end()) { WarMemoryInfo mem_info; mem_info.ca_loop = ca_loop; @@ -486,7 +486,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { Expr* sync_expr = nullptr; kir::Allocate* maybe_alloc = nullptr; if (sync_bitmap.hasBID()) { - maybe_alloc = ir_utils::allocGlobalBufferForGridComm( + maybe_alloc = lower_utils::allocGlobalBufferForGridComm( getGridSyncBufferSize(sync_bitmap), DataType::Int, true); sync_expr = IrBuilder::create( sync_bitmap, maybe_alloc->buffer()); diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 7fdb149da9359..0653296366ccc 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -33,7 +33,7 @@ LoopNestGenerator::LoopNestGenerator(const std::vector& exprs) { namespace { kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) { - auto extent_with_halo = GpuLower::current()->haloInfo().getExtent(id); + auto extent_with_halo = GpuLower::current()->haloInfo()->getExtent(id); kir::ForLoop* new_scope = nullptr; if (extent_with_halo) { // When an axis is extended with halo, unrolling and vectorization @@ -252,7 +252,7 @@ void LoopNestGenerator::generate(const std::vector& exprs) { std::sort( loop_structure.rbegin(), loop_structure.rend(), - IterDomainDependencySorter( + ir_utils::IterDomainDependencySorter( concrete_id_dependencies, GpuLower::current()->caMap())); loop_structures_[tv] = loop_structure; } diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp index bd3c9baf66e1f..9e713f4cf3a23 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp @@ -462,7 +462,7 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { TORCH_INTERNAL_ASSERT( !gpu_lower->trivialReductionInfo().isDerived(producer_root_id), - "No trivial reduciton axis should exist: ", + "No trivial reduction axis should exist: ", producer_root_id); // If the producer ID is reduction or broadcast, it should be safe diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp index 989c00be81b78..7b0393d491572 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate.cpp @@ -20,7 +20,7 @@ namespace cuda { namespace { -class ConditionalFromPredicateModifier : public kir::IrVisitor { +class ConditionalFromPredicateModifier : public kir::ExprMutator { public: ConditionalFromPredicateModifier() = delete; @@ -32,47 +32,58 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { private: ConditionalFromPredicateModifier(const std::vector& exprs) { FUSER_PERF_SCOPE( - "GpuLower::Lower::ConditionalFromPredicateModifier::process"); - kir::IrVisitor::handle(exprs); + "ConditionalFromPredicateModifier::ConditionalFromPredicateModifier"); + traverseAndInsert(exprs); } - using kir::IrVisitor::handle; + using kir::ExprMutator::handle; void handle(Expr* expr) final { if (expr != nullptr && expr->predicate() != nullptr) { // Replace expr predicate with bool conditional auto conditional = generateConditional(expr->predicate()); if (expr->predicate()->predicate_type() == PredicateType::Vectorize) { - // TODO: This logic doesn't seem to fit well here, for unswitch the - // logic is in the unroll loop to set the thread predicate to the expr. - // I didn't have a quick way to do that so placing this here for now. - TORCH_INTERNAL_ASSERT( - expr->isA(), - "Predicate handling expects ITE statement."); - auto ite = expr->as(); - - TORCH_INTERNAL_ASSERT( - ite->thenBody().size() == 1, - "Expecting predicated body to only have one vectorized expression."); - auto vec_expr = ite->thenBody()[0]; - TORCH_INTERNAL_ASSERT( - vec_expr->isA() || vec_expr->isA(), - "Vectorize predicate exprs only supported on set operations."); - TORCH_INTERNAL_ASSERT( - ir_utils::isTvOp(vec_expr), - "Vectorize predicate exprs only supported on tensor view operations."); - if (!vec_expr->inputs()[0]->isConstScalar()) { + if (expr->isA()) { + // TODO: This logic doesn't seem to fit well here, for unswitch the + // logic is in the unroll loop to set the thread predicate to the + // expr. I didn't have a quick way to do that so placing this here for + // now. + auto ite = expr->as(); + + TORCH_INTERNAL_ASSERT( + ite->thenBody().size() == 1, + "Expecting predicated body to only have one vectorized expression."); + auto vec_expr = ite->thenBody()[0]; + TORCH_INTERNAL_ASSERT( + vec_expr->isA() || vec_expr->isA(), + "Vectorize predicate exprs only supported on set operations."); + TORCH_INTERNAL_ASSERT( + ir_utils::isTvOp(vec_expr), + "Vectorize predicate exprs only supported on tensor view operations."); + if (!vec_expr->inputs()[0]->isConstScalar()) { + conditional = SimplifyingIrBuilder::andExpr( + conditional, + GpuLower::current()->threadPredMap().getPredicate( + ir_utils::getTvOutput(vec_expr))) + ->as(); + } + } else { + TORCH_INTERNAL_ASSERT(lower_utils::supportInlinePredicate(expr)); + auto thread_pred = GpuLower::current()->threadPredMap().getPredicate( + ir_utils::getTvOutput(expr)); + TORCH_INTERNAL_ASSERT( + thread_pred->isConst() && thread_pred->value().value()); conditional = SimplifyingIrBuilder::andExpr( conditional, GpuLower::current()->threadPredMap().getPredicate( - ir_utils::getTvOutput(vec_expr))) + ir_utils::getTvOutput(expr))) ->as(); } } TORCH_INTERNAL_ASSERT(conditional != nullptr); expr->predicate()->setValue(conditional); TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr); - setWritePredicate(expr, conditional); + setWritePredicate(expr); } // Note: [Predicate Inversion for CpAsync] @@ -101,7 +112,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { invertPredicateForGmemToSharedMemInitialize(expr); } - kir::IrVisitor::handle(expr); + kir::ExprMutator::handle(expr); } // Invert the predicate of given expr. @@ -123,7 +134,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { ir_utils::isCpAsyncInit(maybe_init.value()); } - void setWritePredicate(Expr* expr, Bool* read_cond) { + void setWritePredicate(Expr* expr) { if (expr->writePredicate() != nullptr) { auto write_cond = generateConditional(expr->writePredicate()); if (write_cond) { @@ -131,7 +142,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { } else { // If generateConditional returns null, it means no specific // predicate needs to be used. - expr->setWritePredicate(nullptr); + registerReplace(expr, expr->withWritePredicate(nullptr)); } } } @@ -150,7 +161,7 @@ class ConditionalFromPredicateModifier : public kir::IrVisitor { ite->predicate()->setValue(conditional); TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr); } - kir::IrVisitor::handle(ite); + kir::ExprMutator::handle(ite); } // Generate conditional according to PredicateType diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 940de32ce9567..294a2327bbba0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -303,12 +303,12 @@ class PredicateChcker : public IterVisitor { // Shift is not supported yet. bool predicateShift(Expr* expr) const { - auto& halo_info = GpuLower::current()->haloInfo(); + auto halo_info = GpuLower::current()->haloInfo(); auto input_tvs = ir_utils::filterByType(expr->inputs()); - return halo_info.needsShiftPredicate(expr) || + return halo_info->needsShiftPredicate(expr) || std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) { return input_tv->definition() != nullptr && - halo_info.needsShiftPredicate(input_tv->definition()); + halo_info->needsShiftPredicate(input_tv->definition()); }); } @@ -925,7 +925,7 @@ bool PredicateElimination::setReductionInitValue( } else { TORCH_INTERNAL_ASSERT( false, - "Incosistent setting of initialization value for t", + "Inconsistent setting of initialization value for t", tv->name(), ". Prev: ", existing_val->toString(), @@ -991,7 +991,7 @@ Val* PredicateElimination::getInitValue(TensorView* tv) const { } void PredicateElimination::build(Fusion* fusion) { - traverseFrom(fusion, fusion->outputs()); + traverseTo(fusion, fusion->outputs()); } std::string PredicateElimination::toString() const { diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index fe1e0cc509c13..2a7c04243f4cf 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -17,7 +17,7 @@ namespace jit { namespace fuser { namespace cuda { -void ShiftPredicateInserter::insert( +Expr* ShiftPredicateInserter::insert( Expr* expr, const std::vector& loops, Bool* thread_pred, @@ -28,9 +28,9 @@ void ShiftPredicateInserter::insert( TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output"); const bool needs_shift_predicate = - gpu_lower->haloInfo().needsShiftPredicate(out_tv->definition()); + gpu_lower->haloInfo()->needsShiftPredicate(out_tv->definition()); if (!needs_shift_predicate) { - return; + return expr; } // The conditional branches to create: @@ -56,9 +56,8 @@ void ShiftPredicateInserter::insert( // If the expr involves a thread-block barrier, set the predicate of // the expr with shift_pred. Since the expr is not shift, the // padding is safe to omit. - if (ir_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) { - expr->setPredicate(shift_pred); - return; + if (lower_utils::hasBlockSync(expr, gpu_lower->threadPredMap())) { + return expr->withPredicate(shift_pred); } auto shift_ite = IrBuilder::create(shift_pred); @@ -76,7 +75,7 @@ void ShiftPredicateInserter::insert( // No padding condition is required if this is within unswitch. if (within_unswitch) { - return; + return expr; } // Padding by zero @@ -89,6 +88,8 @@ void ShiftPredicateInserter::insert( bounds_ite->thenBody().push_back(pad_expr); // Insert the else block shift_ite->elseBody().push_back(bounds_ite); + + return expr; } int AxisHaloInfo::width() const { @@ -145,13 +146,6 @@ const AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) const { return it->second; } -AxisHaloInfo& HaloInfo::getRootAxisInfo(IterDomain* id) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - return const_cast( - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(this)->getRootAxisInfo(id)); -} - void HaloInfo::setRootAxisInfo( IterDomain* id, const AxisHaloInfo& root_axis_info) { @@ -161,7 +155,9 @@ void HaloInfo::setRootAxisInfo( return; } -void HaloInfo::build(Fusion* fusion) { +HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr ca_map) + // Make a copy of the permissive map for extent comparators + : permissive_map_(ca_map->idGraph().permissiveNodes()) { const auto vals = fusion->usedMathVals(); auto tvs = ir_utils::filterByType(vals); @@ -202,7 +198,7 @@ void HaloInfo::build(Fusion* fusion) { // Note that validation requires consumer halo info for (auto tv : tvs) { - validate(tv); + validate(tv, ca_map); } } @@ -445,8 +441,20 @@ void HaloInfo::build(TensorDomain* td) { } else { setHaloWidth(merge->out(), 0); } - } else if (expr->getExprType().value() == ExprType::Swizzle2D) { + } else if (auto swizzle = dynamic_cast(expr)) { // Assume no halo on swizzled domain for now. + TORCH_INTERNAL_ASSERT( + getExtent(swizzle->inX()) == nullptr, + "Halo is not supported with swizzle. Halo-extended ID: ", + swizzle->inX()->toString(), + " used in ", + swizzle->toString()); + TORCH_INTERNAL_ASSERT( + getExtent(swizzle->inY()) == nullptr, + "Halo is not supported with swizzle. Halo-extended ID: ", + swizzle->inY()->toString(), + " used in ", + swizzle->toString()); for (auto id : ir_utils::filterByType(expr->outputs())) { setHaloWidth(id, 0); } @@ -474,12 +482,13 @@ void HaloInfo::build(TensorDomain* td) { //! Other types of parallelization should be supported except for //! vectorization. Vectorization should be eventually supported but //! needs further work. -void HaloInfo::validate(TensorView* tv) const { +void HaloInfo::validate( + TensorView* tv, + std::shared_ptr ca_map) const { const auto mem_type = tv->getMemoryType(); for (auto axis : tv->domain()->domain()) { - auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( - axis, IdMappingMode::LOOP); + auto concrete_id = ca_map->getConcreteMappedID(axis, IdMappingMode::LOOP); // The extent is assumed to be the same TORCH_INTERNAL_ASSERT( @@ -526,7 +535,7 @@ void HaloInfo::validate(TensorView* tv) const { consumer->domain()->domain().begin(), consumer->domain()->domain().end(), [&](IterDomain* consumer_axis) { - return GpuLower::current()->caMap()->areMapped( + return ca_map->areMapped( axis, consumer_axis, IdMappingMode::PERMISSIVE); }); if (it == consumer->domain()->domain().end()) { @@ -626,11 +635,10 @@ bool extentCompare( const HaloInfo& halo_map, IterDomain* id1, IterDomain* id2, - Cmp cmp) { - auto gpu_lower = GpuLower::current(); + Cmp cmp, + const DisjointSets& permissive_map) { TORCH_INTERNAL_ASSERT( - gpu_lower->caMap()->areMapped(id1, id2, IdMappingMode::PERMISSIVE), - "Invalid axes to compare"); + permissive_map.strictAreMapped(id1, id2), "Invalid axes to compare"); // It's invalid to compare two axes and when only either of them has // halo. @@ -652,10 +660,10 @@ bool extentCompare( auto merge2 = dynamic_cast(id2->definition()); TORCH_INTERNAL_ASSERT( merge2 != nullptr, "Invalid comparison: ", id1, " and ", id2); - auto inner_le = - extentCompare(halo_map, merge1->inner(), merge2->inner(), cmp); - auto outer_le = - extentCompare(halo_map, merge1->outer(), merge2->outer(), cmp); + auto inner_le = extentCompare( + halo_map, merge1->inner(), merge2->inner(), cmp, permissive_map); + auto outer_le = extentCompare( + halo_map, merge1->outer(), merge2->outer(), cmp, permissive_map); return inner_le && outer_le; } else { // This is not considered. Should never reach here. @@ -667,11 +675,11 @@ bool extentCompare( } // namespace bool HaloInfo::extentLessEqual(IterDomain* id1, IterDomain* id2) const { - return extentCompare(*this, id1, id2, std::less_equal<>()); + return extentCompare(*this, id1, id2, std::less_equal<>(), permissive_map_); } bool HaloInfo::extentEqual(IterDomain* id1, IterDomain* id2) const { - return extentCompare(*this, id1, id2, std::equal_to<>()); + return extentCompare(*this, id1, id2, std::equal_to<>(), permissive_map_); } std::string HaloInfo::toString() const { @@ -722,19 +730,20 @@ bool HaloInfo::needsShiftPredicate(Expr* expr) const { } std::unordered_map HaloInfo::buildConcreteHaloExtentMap( - const LoopIndexing& loop_indexing) { + const LoopIndexing& loop_indexing) const { // Use a local workspace to avoid re-defining halo info. - HaloInfo local_halo_info; + HaloInfo local_halo_info = *GpuLower::current()->haloInfo(); - auto& global_halo_info = GpuLower::current()->haloInfo(); + auto global_halo_info = GpuLower::current()->haloInfo(); // Setup root: for (auto consumer_root_id : loop_indexing.consumerTv()->getRootDomain()) { auto consumer_index_concrete_id = - ir_utils::caMapExactConcreteId(consumer_root_id); + GpuLower::current()->caMap()->getConcreteMappedID( + consumer_root_id, IdMappingMode::EXACT); local_halo_info.setRootAxisInfo( consumer_index_concrete_id, - global_halo_info.getRootAxisInfo(consumer_root_id)); + global_halo_info->getRootAxisInfo(consumer_root_id)); } // Track IDs that are generated by merging halo-extended IDs @@ -747,7 +756,8 @@ std::unordered_map HaloInfo::buildConcreteHaloExtentMap( merged_shifted_ids.find(split->in()) == merged_shifted_ids.end(), "Splitting IterDomain that is a merged domain of halo-extended domains is not allowed"); - auto in_id = ir_utils::caMapExactConcreteId(split->in()); + auto in_id = GpuLower::current()->caMap()->getConcreteMappedID( + split->in(), IdMappingMode::EXACT); // If no halo info is found, nothing needs to be done. This ID // must be an ancestor of a domain set by setRootAxisInfo. @@ -759,32 +769,43 @@ std::unordered_map HaloInfo::buildConcreteHaloExtentMap( if (halo_width == 0) { local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(split->outer()), 0); + GpuLower::current()->caMap()->getConcreteMappedID( + split->outer(), IdMappingMode::EXACT), + 0); local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(split->inner()), 0); + GpuLower::current()->caMap()->getConcreteMappedID( + split->inner(), IdMappingMode::EXACT), + 0); continue; } // propagate to inner domain - auto out_id = ir_utils::caMapExactConcreteId(split->inner()); + auto out_id = GpuLower::current()->caMap()->getConcreteMappedID( + split->inner(), IdMappingMode::EXACT); auto expanded_extent = SimplifyingIrBuilder::addExpr(out_id->extent(), halo_width); local_halo_info.extent_map_.insert({out_id, expanded_extent}); local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(split->outer()), 0); + GpuLower::current()->caMap()->getConcreteMappedID( + split->outer(), IdMappingMode::EXACT), + 0); local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(split->inner()), halo_width); + GpuLower::current()->caMap()->getConcreteMappedID( + split->inner(), IdMappingMode::EXACT), + halo_width); // TODO: add support for inheritance map } else if (auto merge = dynamic_cast(expr)) { // If either of the two inputs has halo extension, propagate it // to the merged output ID auto inner_extent = local_halo_info.getExtent( - ir_utils::caMapExactConcreteId(merge->inner())); + GpuLower::current()->caMap()->getConcreteMappedID( + merge->inner(), IdMappingMode::EXACT)); auto outer_extent = local_halo_info.getExtent( - ir_utils::caMapExactConcreteId(merge->outer())); + GpuLower::current()->caMap()->getConcreteMappedID( + merge->outer(), IdMappingMode::EXACT)); if (inner_extent != nullptr || outer_extent != nullptr) { if (inner_extent == nullptr) { inner_extent = merge->inner()->extent(); @@ -795,28 +816,41 @@ std::unordered_map HaloInfo::buildConcreteHaloExtentMap( auto expanded_extent = SimplifyingIrBuilder::mulExpr(outer_extent, inner_extent); local_halo_info.extent_map_.insert( - {ir_utils::caMapExactConcreteId(merge->out()), expanded_extent}); + {GpuLower::current()->caMap()->getConcreteMappedID( + merge->out(), IdMappingMode::EXACT), + expanded_extent}); // Splitting the output of this merge is not allowed, so // remember it - merged_shifted_ids.insert(ir_utils::caMapExactConcreteId(merge->out())); + merged_shifted_ids.insert( + GpuLower::current()->caMap()->getConcreteMappedID( + merge->out(), IdMappingMode::EXACT)); // Note that halo_width_map_ is not updated } else { - setHaloWidth(ir_utils::caMapExactConcreteId(merge->out()), 0); + local_halo_info.setHaloWidth( + GpuLower::current()->caMap()->getConcreteMappedID( + merge->out(), IdMappingMode::EXACT), + 0); } } else if (auto swizzle_2d = dynamic_cast(expr)) { // Swizzle with halo not yet supported, just set the width // to zero at the moment. TORCH_INTERNAL_ASSERT( local_halo_info.getHaloWidth( - ir_utils::caMapExactConcreteId(swizzle_2d->inX())) == 0 && + GpuLower::current()->caMap()->getConcreteMappedID( + swizzle_2d->inX(), IdMappingMode::EXACT)) == 0 && local_halo_info.getHaloWidth( - ir_utils::caMapExactConcreteId(swizzle_2d->inY())) == 0, + GpuLower::current()->caMap()->getConcreteMappedID( + swizzle_2d->inY(), IdMappingMode::EXACT)) == 0, "Swizzle on ID with halo not yet supported."); TORCH_INTERNAL_ASSERT("Swizzle on ID with halo not yet supported."); local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(swizzle_2d->outX()), 0); + GpuLower::current()->caMap()->getConcreteMappedID( + swizzle_2d->outX(), IdMappingMode::EXACT), + 0); local_halo_info.setHaloWidth( - ir_utils::caMapExactConcreteId(swizzle_2d->outY()), 0); + GpuLower::current()->caMap()->getConcreteMappedID( + swizzle_2d->outY(), IdMappingMode::EXACT), + 0); } else { TORCH_INTERNAL_ASSERT(false, "Unsupported expr: ", expr); } diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.h b/torch/csrc/jit/codegen/cuda/lower_shift.h index d1500c5f9f203..f12410703d99d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.h +++ b/torch/csrc/jit/codegen/cuda/lower_shift.h @@ -61,23 +61,12 @@ class AxisHaloInfo { class TORCH_CUDA_CU_API HaloInfo { public: //! Scan a fusion and collect all information for lowering - void build(Fusion* fusion); - - //! Build mappings of extent information of a TensorDomain - void build(TensorDomain* td); + HaloInfo(Fusion* fusion, std::shared_ptr ca_map); //! Almost exact duplicate of build(TensorDomain* td), except that //! the traversal was done on loop indexing expressions. std::unordered_map buildConcreteHaloExtentMap( - const LoopIndexing& loop_indexing); - - //! Set initial AxisHaloInfo of a root axis - //! - //! The axis does not need to be a root domain in the case of - //! reference tensors. Reference tensors get halo information from - //! consumer root domains, which may correspond to rfactor domains - //! of tensors from which reference tensors are derived. - void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info); + const LoopIndexing& loop_indexing) const; //! Returns true if id has the root halo information set by //! setRootAxisInfo. @@ -88,7 +77,6 @@ class TORCH_CUDA_CU_API HaloInfo { //! This is only for root axes. It is an error to query with //! non-root axes. const AxisHaloInfo& getRootAxisInfo(IterDomain* id) const; - AxisHaloInfo& getRootAxisInfo(IterDomain* id); //! Query if an axis has a halo width. //! @@ -139,10 +127,21 @@ class TORCH_CUDA_CU_API HaloInfo { std::string toString() const; private: + //! Build mappings of extent information of a TensorDomain + void build(TensorDomain* td); + //! Propagate root axis information from outputs to inputs of an //! expression void propagateRootAxisInfo(Expr* expr); + //! Set initial AxisHaloInfo of a root axis + //! + //! The axis does not need to be a root domain in the case of + //! reference tensors. Reference tensors get halo information from + //! consumer root domains, which may correspond to rfactor domains + //! of tensors from which reference tensors are derived. + void setRootAxisInfo(IterDomain* id, const AxisHaloInfo& root_axis_info); + //! Adds a domain to the halo inheritance map. //! //! A domain, child, is added to the same set as domain parent. Both @@ -163,11 +162,15 @@ class TORCH_CUDA_CU_API HaloInfo { void initializeFromRootAxisInfo(IterDomain* id); //! Validate shift usage - void validate(TensorView* td) const; + void validate(TensorView* td, std::shared_ptr ca_map) + const; void setHaloWidth(IterDomain* id, int halo_width); private: + // Copy the permissive map from the passed in compute at map + const DisjointSets permissive_map_; + //! Halo information of root axes std::unordered_map root_axis_map_; @@ -222,7 +225,7 @@ class ShiftPredicateInserter { //! the generated predicate. The branch structure is different from //! the usual predicated expression, so the insertion is also done //! here. - static void insert( + static Expr* insert( Expr* expr, const std::vector& loops, Bool* thread_pred, diff --git a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp index 497256b5f850e..9b8ccd4a77ae4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_sync_information.cpp @@ -26,7 +26,7 @@ void validateParallelizationOfTensor(TensorView* tv) { // It doesn't matter if this axis is a non-concretized broadcast // TODO: merging broadcast and non-broadcast if (axis->isBroadcast() && - !GpuLower::current()->concretizedBroadcastDomains().isConcretized( + !GpuLower::current()->concretizedBroadcastDomains()->isConcretized( axis)) { continue; } @@ -195,7 +195,7 @@ void SyncMap::build(Fusion* fusion) { (!parallel_bcast_doms.get(consumer_ptype) || !GpuLower::current() ->concretizedBroadcastDomains() - .isConcretized(consumer_axis))) { + ->isConcretized(consumer_axis))) { continue; } @@ -240,12 +240,12 @@ void SyncMap::build(Fusion* fusion) { p_id, c_id, IdMappingMode::PERMISSIVE)) { const auto halo_info = GpuLower::current()->haloInfo(); - if (halo_info.hasHaloWidth(p_id) != - halo_info.hasHaloWidth(c_id) || - (halo_info.hasHaloWidth(p_id) && - halo_info.hasHaloWidth(c_id) && - halo_info.getHaloWidth(p_id) != - halo_info.getHaloWidth(c_id))) { + if (halo_info->hasHaloWidth(p_id) != + halo_info->hasHaloWidth(c_id) || + (halo_info->hasHaloWidth(p_id) && + halo_info->hasHaloWidth(c_id) && + halo_info->getHaloWidth(p_id) != + halo_info->getHaloWidth(c_id))) { raw_dims.set(parallel_type); continue; } @@ -410,33 +410,13 @@ void SyncMap::build(Fusion* fusion) { } } - // If same parallel type and mapped, no need for syncs unless - // producer is in smem, producer parallel type is a thread - // dimension, and consumer concretizes the dimension. This sync is - // due to the redundant predicate omission in lower thread - // predicate. - auto redundant_preds = GpuLower::current() - ->threadPredMap() - .getPredicateInfo(producer) - .redundant_types; - - if (p_id->isBroadcast() && - GpuLower::current()->concretizedBroadcastDomains().isConcretized( - p_id) && - producer->getMemoryType() == MemoryType::Shared && - redundant_preds.hasTID()) { - redundant_preds.clearAllBID(); - raw_dims |= redundant_preds; - continue; - } - // When the producer axis is a broadcast, it is not really // parallelized unless thread-predicated and concretized if (isParallelTypeThread(producer_ptype) && p_id->isBroadcast() && (!parallel_bcast_doms.get(producer_ptype) || !GpuLower::current() ->concretizedBroadcastDomains() - .isConcretized(p_id))) { + ->isConcretized(p_id))) { continue; } @@ -483,7 +463,7 @@ void SyncMap::build(Fusion* fusion) { } // end for consumers if (raw_dims.any()) { - needs_raw_sync_[producer] = raw_dims; + needs_raw_sync_[producer] |= raw_dims; } } // end producer @@ -492,10 +472,14 @@ void SyncMap::build(Fusion* fusion) { std::string SyncMap::toString() const { std::stringstream ss; - ss << "TVs requiring RAW:" << std::endl; + ss << "SyncMap:"; + bool is_first = true; for (auto entry : needs_raw_sync_) { - ss << " " << entry.first->toString() << " :: " << entry.second.toString() - << std::endl; + if (!is_first) { + ss << ","; + } + ss << " " << entry.first->toString() << " -> " << entry.second.toString(); + is_first = false; } return ss.str(); } diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp index 18a4426cb7c05..dc10224a165c0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp @@ -237,7 +237,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) { id_reductions.set(id->getParallelType()); } if (id->isBroadcast() && - GpuLower::current()->concretizedBroadcastDomains().isConcretized( + GpuLower::current()->concretizedBroadcastDomains()->isConcretized( id)) { id_bcasts.set(id->getParallelType()); } @@ -316,7 +316,7 @@ class RedundantUseAnalysis : BackwardVisitor { public: RedundantUseAnalysis(Fusion* fusion, const ThreadPredicateMap& pred_map) : fusion_(fusion), pred_map_(pred_map) { - traverseFrom(fusion, fusion->terminatingMathVals()); + traverseTo(fusion, fusion->terminatingMathVals()); } //! Returns a bit map signifying the parallel dimensions @@ -575,7 +575,8 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains( for (auto id : iter_domains) { if (!id->isBroadcast() || - !GpuLower::current()->concretizedBroadcastDomains().isConcretized(id)) { + !GpuLower::current()->concretizedBroadcastDomains()->isConcretized( + id)) { continue; } if (id->isBlockDim() || (!output_smem && id->isThreadDim())) { diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp index 324bab279b37e..88a84aa3c5877 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include @@ -10,7 +9,7 @@ namespace jit { namespace fuser { namespace cuda { -void ConcretizedBroadcastDomains::build(Fusion* fusion) { +ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) { exact_map_ = std::make_unique(fusion); // Initialize the origin map with input broadcast domains diff --git a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h index 24658f3cfe7c3..c30fa9951404a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h +++ b/torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h @@ -23,7 +23,8 @@ namespace cuda { //! domains are marked as concretized. class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor { public: - void build(Fusion* fusion); + ConcretizedBroadcastDomains() = delete; + ConcretizedBroadcastDomains(Fusion* fusion); //! Is a domain concretized? bool isConcretized(IterDomain* id) const; diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 434d1711d9c83..63dbbf83d775d 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -54,6 +54,14 @@ bool isReductionInitExpr(const Expr* expr) { } // namespace +void UnrollPass::registerReplace( + Expr* reference, + Expr* new_expr, + kir::Scope* scope) { + kir::ExprMutator::registerReplace(reference, new_expr, scope); + GpuLower::current()->propagateExprInfo(reference, new_expr); +} + void UnrollPass::handle(Expr* expr) { if (ir_utils::isTvOp(expr)) { // If tv op, predicate it @@ -79,11 +87,16 @@ void UnrollPass::handle(Expr* expr) { non_trivial_pred_found_ = true; + Expr* expr_with_predicate = expr; + // When a predicate needs to account for ShiftOp, it is currently // taken care by its own function. - if (GpuLower::current()->haloInfo().needsShiftPredicate(expr)) { - ShiftPredicateInserter::insert( + if (GpuLower::current()->haloInfo()->needsShiftPredicate(expr)) { + expr_with_predicate = ShiftPredicateInserter::insert( expr, for_loops_, thread_pred, unswitched_loop_); + if (expr_with_predicate != expr) { + registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + } return; } @@ -93,17 +106,18 @@ void UnrollPass::handle(Expr* expr) { ? thread_pred_expr : IrBuilder::create( PredicateType::ReductionWrite, expr, thread_pred); - expr->setWritePredicate(write_pred); + expr_with_predicate = expr_with_predicate->withWritePredicate(write_pred); } // For expr calling a device func with block sync, don't create // if-then-else but pass the predicate to the device func - if (ir_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { + if (lower_utils::hasBlockSync(expr, GpuLower::current()->threadPredMap())) { const auto pred = unswitched_loop_ ? thread_pred_expr : IrBuilder::create( PredicateType::Inline, expr, thread_pred); - expr->setPredicate(pred); + expr_with_predicate = expr_with_predicate->withPredicate(pred); + registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); return; } @@ -124,6 +138,12 @@ void UnrollPass::handle(Expr* expr) { PredicateType::Inline, expr, thread_pred); } + if (lower_utils::supportInlinePredicate(expr)) { + expr_with_predicate = expr_with_predicate->withPredicate(pred); + registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + return; + } + // If we need a predicate, put expr inside an if then else kir::IfThenElse* inline_ite = IrBuilder::create(pred); if (for_loops_.empty()) { @@ -135,7 +155,10 @@ void UnrollPass::handle(Expr* expr) { kir::ExprMutator::registerReplace( expr, inline_ite, &for_loops_.back()->body()); } - inline_ite->thenBody().push_back(expr); + if (expr != expr_with_predicate) { + GpuLower::current()->propagateExprInfo(expr, expr_with_predicate); + } + inline_ite->thenBody().push_back(expr_with_predicate); } else if (auto for_loop = dynamic_cast(expr)) { handle(for_loop); } @@ -222,7 +245,7 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { // If there's any expression that requires barrier // synchronization, the else part can't be omitted for (auto expr : loop->body().exprs()) { - if (ir_utils::hasBlockSync(expr, pred_map)) { + if (lower_utils::hasBlockSync(expr, pred_map)) { return false; } } @@ -264,9 +287,7 @@ bool UnrollPass::canOmitElseClause(kir::ForLoop* fl) { return true; } -// Generate the loop nest structure and place it in lowered_exprs UnrollPass::UnrollPass(const std::vector& exprs) { - FUSER_PERF_SCOPE("GpuLower::Lower::UnrollPass::computeMap"); kir::ExprMutator::traverseAndInsert(exprs); } diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.h b/torch/csrc/jit/codegen/cuda/lower_unroll.h index 14725c405b770..786e45115ba65 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.h +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.h @@ -62,6 +62,8 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { static bool canOmitElseClause(kir::ForLoop* fl); private: + void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope); + // Generate the for Expr replacement map UnrollPass(const std::vector& exprs); diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 5802b2b99b4b8..3e92269f278a7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -36,13 +36,42 @@ kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite) { namespace ir_utils { -TVDomainGuard::TVDomainGuard(TensorView* _tv, TensorDomain* td) - : tv_(_tv), prev_domain(tv_->domain()) { +TVDomainGuard::TVDomainGuard(TensorView* tv, TensorDomain* td) + : tv_(tv), prev_domain_(tv_->domain()) { tv_->setDomain(td); } +TVDomainGuard::TVDomainGuard(TVDomainGuard&& guard) + : tv_(nullptr), prev_domain_(guard.prev_domain_) { + std::swap(tv_, guard.tv_); +} + TVDomainGuard::~TVDomainGuard() { - tv_->setDomain(prev_domain); + if (tv_ != nullptr) { + tv_->setDomain(prev_domain_); + } +} + +ir_utils::TVDomainGuard overrideContiguityGuard( + TensorView* tv, + bool contiguity) { + // Use domain guard to ignore the contiguity of + // consumer tv. + TensorDomain* domain_with_specified_contiguity = nullptr; + std::vector contiguity_vector( + tv->getMaybeRFactorDomain().size(), contiguity); + if (tv->hasRFactor()) { + domain_with_specified_contiguity = IrBuilder::create( + tv->getRootDomain(), + tv->getRFactorDomain(), + tv->domain()->domain(), + contiguity_vector); + } else { + domain_with_specified_contiguity = IrBuilder::create( + tv->getRootDomain(), tv->domain()->domain(), contiguity_vector); + } + + return ir_utils::TVDomainGuard(tv, domain_with_specified_contiguity); } std::vector iterDomainInputsOf( @@ -92,7 +121,9 @@ bool isTvOp(const Expr* expr) { expr->getExprType().value() == ExprType::BinaryOp || expr->getExprType().value() == ExprType::TernaryOp || expr->getExprType().value() == ExprType::RNGOp || + expr->getExprType().value() == ExprType::FullOp || expr->getExprType().value() == ExprType::ARangeOp || + expr->getExprType().value() == ExprType::EyeOp || expr->getExprType().value() == ExprType::ReductionOp || expr->getExprType().value() == ExprType::GroupedReductionOp || expr->getExprType().value() == ExprType::WelfordOp || @@ -204,35 +235,6 @@ bool isScalarOp(const Expr* expr) { return true; } -bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { - if (expr->isA()) { - return true; - } - - if (!isTvOp(expr)) { - return false; - } - - if (!(isReductionOp(expr) || expr->isA() || - expr->isA())) { - return false; - } - - // GroupedReductionOp can have multiple output TVs, but they must be - // parallelized in the same way, so just checking one of them is enough. - auto tv = getTvOutput(expr); - - if (tv->hasBlockReduction() || tv->hasGridReduction()) { - return true; - } else if (expr->isA()) { - const ParallelTypeBitmap pt_map = - GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv); - return pt_map.any(); - } - - return false; -} - c10::optional getMaybeWarpReductionDim( const Val* output, const Val* input) { @@ -369,20 +371,6 @@ bool isGlobalLoadInit(const Expr* expr) { return false; } -kir::Allocate* allocGlobalBufferForGridComm( - Val* buffer_size, - DataType dtype, - bool zero_init) { - const std::vector new_buffer_ids = { - IrBuilder::create(IterDomainBuilder( - GpuLower::current()->kernel()->zeroVal(), buffer_size))}; - const auto buffer_domain = IrBuilder::create(new_buffer_ids); - const auto buffer_tv = - IrBuilder::create(buffer_domain, dtype, MemoryType::Global); - return IrBuilder::create( - buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init); -} - namespace { class ExprFlattener : private kir::IrVisitor { @@ -417,112 +405,6 @@ std::vector flattenScopedExprs(const std::vector& loop_nests) { return ExprFlattener::flatten(loop_nests); } -IterDomain* caMapExactConcreteId(IterDomain* id) { - return GpuLower::current()->caMap()->getConcreteMappedID( - id, IdMappingMode::EXACT); -} - -std::vector getAllSwizzlesBetween( - std::vector from, - std::vector to) { - auto all_expr = DependencyCheck::getAllExprsBetween( - {from.begin(), from.end()}, {to.begin(), to.end()}); - - std::vector all_swizzles; - - std::copy_if( - all_expr.begin(), - all_expr.end(), - std::back_inserter(all_swizzles), - [](Expr* expr) { - return expr->getExprType().has_value() && - (expr->etype() == ExprType::Swizzle2D); - }); - - return all_swizzles; -} - -} // namespace ir_utils - -namespace loop_utils { - -BasicAllocInfo getAllocInformation( - const TensorView* tv, - const std::vector& for_loops, - const std::unordered_map& id_map, - bool use_id_map) { - BasicAllocInfo info; - auto gpu_lower = GpuLower::current(); - - bool outer_alloc_found = false; - - for (auto fl : for_loops) { - if (info.alloc_pos == tv->getComputeAtPosition()) { - break; - } - - if (tv->axis(info.alloc_pos)->isReduction()) { - const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs(); - TORCH_INTERNAL_ASSERT( - std::find(outputs.begin(), outputs.end(), tv) != outputs.end(), - "Invalid computeAt of T", - tv->name(), - ". A reducation axis is detected outside computeAt point even though it is not an output tensor."); - break; - } - - auto fl_id = fl->iter_domain(); - - if (fl_id->getParallelType() == ParallelType::Unroll) { - break; - } - - // Shared memory must be allocated outside of unswitched - // domains. See issue #1133. - if (fl_id->getParallelType() == ParallelType::Unswitch && - tv->getMemoryType() == MemoryType::Shared) { - outer_alloc_found = true; - } - - // Assume global memory is allocated at outer most scope. - if (tv->getMemoryType() == MemoryType::Global) { - outer_alloc_found = true; - } - - // Allocation of a double buffered tensor is placed outside its - // double buffer axis. - if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) && - tv->axis(info.alloc_pos) == - gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) { - outer_alloc_found = true; - } - - auto local_id = tv->axis(info.alloc_pos); - - if (use_id_map) { - auto id_it = id_map.find(local_id); - if (id_it != id_map.end()) { - local_id = id_it->second; - } - } - - if (GpuLower::current()->caMap()->areMapped( - local_id, fl_id, IdMappingMode::PERMISSIVE)) { - info.alloc_pos++; - } - - info.init_for_loop = fl; - - if (!outer_alloc_found) { - info.alloc_for_loop = fl; - } - } - - return info; -} - -} // namespace loop_utils - namespace { class ReplaceExprInput : private kir::ExprMutator { @@ -564,8 +446,8 @@ class ReplaceExprInput : private kir::ExprMutator { // Copy predicates and register expression replacement void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) { - new_expr->setPredicate(old_expr->predicate()); - new_expr->setWritePredicate(old_expr->writePredicate()); + new_expr = new_expr->withPredicate(old_expr->predicate()) + ->withWritePredicate(old_expr->writePredicate()); registerReplace(old_expr, new_expr); } @@ -703,15 +585,161 @@ std::vector replaceInputsInExpr( return ReplaceExprInput::replace(exprs, replacement_map); } -bool isTrivialIterDomain(IterDomain* id) { - auto pt = id->getParallelType(); - return id->isReduction() || id->isBroadcast() || id->isStride() || - (id->extent()->isOneInt() && id->start()->isZeroInt()) || - pt == ParallelType::Vectorize || - (isParallelTypeThread(pt) && - !GpuLower::current()->haloInfo().hasHaloWidth(id)); +std::vector getAllSwizzlesBetween( + std::vector from, + std::vector to) { + auto all_expr = DependencyCheck::getAllExprsBetween( + {from.begin(), from.end()}, {to.begin(), to.end()}); + + std::vector all_swizzles; + + std::copy_if( + all_expr.begin(), + all_expr.end(), + std::back_inserter(all_swizzles), + [](Expr* expr) { + return expr->getExprType().has_value() && + (expr->etype() == ExprType::Swizzle2D); + }); + + return all_swizzles; +} + +} // namespace ir_utils + +namespace lower_utils { + +bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) { + if (expr->isA()) { + return true; + } + + if (!ir_utils::isTvOp(expr)) { + return false; + } + + if (!(ir_utils::isReductionOp(expr) || expr->isA() || + expr->isA())) { + return false; + } + + // GroupedReductionOp can have multiple output TVs, but they must be + // parallelized in the same way, so just checking one of them is enough. + auto tv = ir_utils::getTvOutput(expr); + + if (tv->hasBlockReduction() || tv->hasGridReduction()) { + return true; + } else if (expr->isA()) { + const ParallelTypeBitmap pt_map = + GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv); + return pt_map.any(); + } + + return false; +} + +kir::Allocate* allocGlobalBufferForGridComm( + Val* buffer_size, + DataType dtype, + bool zero_init) { + const std::vector new_buffer_ids = { + IrBuilder::create(IterDomainBuilder( + GpuLower::current()->kernel()->zeroVal(), buffer_size))}; + const auto buffer_domain = IrBuilder::create(new_buffer_ids); + const auto buffer_tv = + IrBuilder::create(buffer_domain, dtype, MemoryType::Global); + return IrBuilder::create( + buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init); +} + +BasicAllocInfo getAllocInformation( + const TensorView* tv, + const std::vector& for_loops, + const std::unordered_map& id_map, + bool use_id_map) { + BasicAllocInfo info; + auto gpu_lower = GpuLower::current(); + + bool outer_alloc_found = false; + + for (auto fl : for_loops) { + if (info.alloc_pos == tv->getComputeAtPosition()) { + break; + } + + if (tv->axis(info.alloc_pos)->isReduction()) { + const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs(); + TORCH_INTERNAL_ASSERT( + std::find(outputs.begin(), outputs.end(), tv) != outputs.end(), + "Invalid computeAt of T", + tv->name(), + ". A reducation axis is detected outside computeAt point even though it is not an output tensor."); + break; + } + + auto fl_id = fl->iter_domain(); + + if (fl_id->getParallelType() == ParallelType::Unroll) { + break; + } + + // Shared memory must be allocated outside of unswitched + // domains. See issue #1133. + if (fl_id->getParallelType() == ParallelType::Unswitch && + tv->getMemoryType() == MemoryType::Shared) { + outer_alloc_found = true; + } + + // Assume global memory is allocated at outer most scope. + if (tv->getMemoryType() == MemoryType::Global) { + outer_alloc_found = true; + } + + // Allocation of a double buffered tensor is placed outside its + // double buffer axis. + if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) && + tv->axis(info.alloc_pos) == + gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) { + outer_alloc_found = true; + } + + auto local_id = tv->axis(info.alloc_pos); + + if (use_id_map) { + auto id_it = id_map.find(local_id); + if (id_it != id_map.end()) { + local_id = id_it->second; + } + } + + if (GpuLower::current()->caMap()->areMapped( + local_id, fl_id, IdMappingMode::PERMISSIVE)) { + info.alloc_pos++; + } + + info.init_for_loop = fl; + + if (!outer_alloc_found) { + info.alloc_for_loop = fl; + } + } + + return info; +} + +//! Implementing this in here to avoid including too many headers +//! in type.cpp. Conceptually this should be a generic definition +//! rather than a util. +bool supportInlinePredicate(Expr* expr) { + if (ir_utils::isCpAsyncOp(expr)) { + return true; + } + // TODO: build out support. + return false; } +} // namespace lower_utils + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.h b/torch/csrc/jit/codegen/cuda/lower_utils.h index d8821fd0d4ebe..4807c1e5520ea 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.h +++ b/torch/csrc/jit/codegen/cuda/lower_utils.h @@ -39,24 +39,32 @@ namespace ir_utils { // producers with a consumer set of indices, so we need to view the producer // transformed like consumer while we index. This will set the tv with td for // the life of this context guard. -class TVDomainGuard { +class TORCH_CUDA_CU_API TVDomainGuard { private: TensorView* tv_; - TensorDomain* prev_domain; + TensorDomain* prev_domain_; public: - explicit TVDomainGuard(TensorView* _tv, TensorDomain* td); + explicit TVDomainGuard(TensorView* tv, TensorDomain* td); + TVDomainGuard(const TVDomainGuard&) = delete; + TVDomainGuard(TVDomainGuard&&); //! An utility to access the tensordomain before the temporary //! view. This is used to retrieve information, like swizzle //! information that can only be reliably kept at the original domain. const TensorDomain* prevDomain() const { - return prev_domain; + return prev_domain_; } ~TVDomainGuard(); }; +// Create a TVDomainGuard that temporarily view a tensorview with specified +// all-true or all-false contiguity. +TORCH_CUDA_CU_API ir_utils::TVDomainGuard overrideContiguityGuard( + TensorView* tv, + bool contiguity); + //! Return inputs of provided IterDomains that are IterDomains. A list //! of input IterDomain can be optionally given. Otherwise, //! IterDomains with no defining expression are returned. @@ -82,8 +90,6 @@ TORCH_CUDA_CU_API TensorView* getTvOutput(const Expr*); // Returns the first input of Expr that is a TensorView TORCH_CUDA_CU_API TensorView* getTvInput(const Expr*); -bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map); - //! Returns the iterdomain that maps to the thread dimension grouped //! to warps. Returns nullopt if the reduction is not to be lowered to //! a warp reduction. @@ -108,13 +114,6 @@ bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis); std::unordered_map getParallelDomains( const Val* val); -// Allocate global buffer for a grid communication calls, i.e. grid reduce, grid -// welford reduce, grid broadcast. -kir::Allocate* allocGlobalBufferForGridComm( - Val* buffer_size, - DataType dtype, - bool zero_init); - //! Returns true if the expression will be lowered to //! a ldmatrix intrinsic. bool isLdMatrixOp(const Expr* expr); @@ -150,49 +149,12 @@ bool isTensorScalarFillOp(const Expr* expr); TORCH_CUDA_CU_API std::vector flattenScopedExprs( const std::vector& loop_nests); -//! Returns the concretized iterdomain according to -//! the exact compute at map. -IterDomain* caMapExactConcreteId(IterDomain* id); - //! Returns all swizzle ops between the set of iterdomains //! in `from` and `to`. std::vector getAllSwizzlesBetween( std::vector from, std::vector to); -} // namespace ir_utils - -namespace loop_utils { - -struct BasicAllocInfo { - // The for loop that the initialization of this allocation must be - // placed in, nullptr if not within a loop - kir::ForLoop* init_for_loop = nullptr; - - // Keep track of the actual allocation loop. This can be different - // from init_for_loop only with unswitched shared memory allocations, - // which are moved outer loops to avoid duplicated allocations. This means - // that the alloc position may be outside what's expected. Most applications - // outside lower_allocation is likely looking for init_for_loop which is - // more directly related to how large an allocation is and how it's used. - // (see issue #1133). - kir::ForLoop* alloc_for_loop = nullptr; - - // The allocation position relative to buffer IDs, it could be outside the - // compute at position if it's shared memory with a compute at inside an - // unswitch - size_t alloc_pos = 0; -}; - -// Fill the above allocation struct based on provided information. id_map is -// used if we're looking at a producer tensor but loops on a consumer tensor. -BasicAllocInfo getAllocInformation( - const TensorView* tv, - const std::vector& loops, - const std::unordered_map& id_map = {}, - bool use_id_map = false); -} // namespace loop_utils - // Replace value pass on Kernel IR. // Replace each use of any Val* that apears in the given `replacement_map` // Keeps the predicate carried by each expr @@ -203,9 +165,6 @@ std::vector replaceInputsInExpr( const std::vector& exprs, const std::unordered_map& replacement_map); -// True if an IterDomain does not materialize a loop -bool isTrivialIterDomain(IterDomain* id); - // Go through all expressions and compute a local ordering of loops. operator< // is implemented based on the concrete_id_dependencies analysis done. If // there's no dependency between two IDs then order doesn't mater, otherwise we @@ -235,7 +194,7 @@ struct TORCH_CUDA_CU_API IterDomainDependencySorter { IterDomainDependencySorter( const std::unordered_map>& concrete_id_dependencies, - const std::unique_ptr& compute_at_map) + std::shared_ptr compute_at_map) : concrete_id_dependencies_(concrete_id_dependencies), compute_at_map_(compute_at_map) {} @@ -261,9 +220,56 @@ struct TORCH_CUDA_CU_API IterDomainDependencySorter { const std::unordered_map>& concrete_id_dependencies_; - const std::unique_ptr& compute_at_map_; + const std::shared_ptr compute_at_map_; }; +} // namespace ir_utils + +namespace lower_utils { + +bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map); + +// Allocate global buffer for a grid communication calls, i.e. grid reduce, grid +// welford reduce, grid broadcast. +kir::Allocate* allocGlobalBufferForGridComm( + Val* buffer_size, + DataType dtype, + bool zero_init); + +struct BasicAllocInfo { + // The for loop that the initialization of this allocation must be + // placed in, nullptr if not within a loop + kir::ForLoop* init_for_loop = nullptr; + + // Keep track of the actual allocation loop. This can be different + // from init_for_loop only with unswitched shared memory allocations, + // which are moved outer loops to avoid duplicated allocations. This means + // that the alloc position may be outside what's expected. Most applications + // outside lower_allocation is likely looking for init_for_loop which is + // more directly related to how large an allocation is and how it's used. + // (see issue #1133). + kir::ForLoop* alloc_for_loop = nullptr; + + // The allocation position relative to buffer IDs, it could be outside the + // compute at position if it's shared memory with a compute at inside an + // unswitch + size_t alloc_pos = 0; +}; + +// Fill the above allocation struct based on provided information. id_map is +// used if we're looking at a producer tensor but loops on a consumer tensor. +BasicAllocInfo getAllocInformation( + const TensorView* tv, + const std::vector& loops, + const std::unordered_map& id_map = {}, + bool use_id_map = false); + +//! Returns true if the expression has a variant that takes a predicate +//! as an inline argument. +bool supportInlinePredicate(Expr* expr); + +} // namespace lower_utils + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index de2c1135ad202..f6f71c2ec123a 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -86,7 +86,7 @@ class ValidateSiblings : public IterVisitor { auto sibling_id = it->second; TORCH_INTERNAL_ASSERT( sibling->axis(i) == sibling_id, - "Invalid matching sinbling ID detected. Expr: ", + "Invalid matching sibling ID detected. Expr: ", expr->toString(), "Sibling ID: ", sibling_id->toString()); @@ -1183,7 +1183,7 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { // Halo is not allowed TORCH_CHECK( - GpuLower::current()->haloInfo().getExtent(id) == nullptr, + GpuLower::current()->haloInfo()->getExtent(id) == nullptr, "Invalid use of ParallelType::Group.", " Grouping of halo-extended IterDomain, ", id->toString(), diff --git a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp index 1d87790c014fb..ff603c1d18f64 100644 --- a/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_warp_reduce.cpp @@ -136,7 +136,7 @@ class EliminateDeadBroadcastAndAllocate { //! be removed, and generates a replacement map from the broadcast //! output to reduction output. //! -//! 2. kir_utils::replaceInputsInExpr replaces applicable uses of +//! 2. ir_utils::replaceInputsInExpr replaces applicable uses of //! the broadcast output with the corresponding reduction output. //! //! 3. EliminateDeadBroadcastAndAllocate removes the broadcast ops @@ -145,8 +145,8 @@ class FuseBroadcastWithWarpReduce : private kir::IrVisitor { public: static std::vector fuse(const std::vector& exprs) { FuseBroadcastWithWarpReduce fuse_broadcast_map(exprs); - const auto replaced_inputs = - replaceInputsInExpr(exprs, fuse_broadcast_map.val_replacement_map_); + const auto replaced_inputs = ir_utils::replaceInputsInExpr( + exprs, fuse_broadcast_map.val_replacement_map_); return EliminateDeadBroadcastAndAllocate::run(replaced_inputs); } diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index 22f914de407ee..4eb61c78b749f 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -62,12 +62,16 @@ namespace { // in the fallback path. void enableAliasCopyNodes(const std::shared_ptr& graph, Block* block) { static std::unordered_set alias_copy_op( - {prim::view_copy, - prim::reshape_copy, - prim::expand_copy, + {prim::expand_copy, prim::expand_as_copy, + prim::flatten_copy, + prim::permute_copy, + prim::reshape_copy, prim::squeeze_copy, - prim::unsqueeze_copy}); + prim::t_copy, + prim::transpose_copy, + prim::unsqueeze_copy, + prim::view_copy}); for (Node* n : block->nodes()) { for (Block* b : n->blocks()) { diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index 96bc40c20c90c..12a3de15f4a7f 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -125,7 +125,18 @@ void OptOutMutator::mutate(kir::TensorIndex*) { TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); } -// MUTATE FUNCTIONS FOR EXPRESSIONS. +void OptOutMutator::mutate(FullOp* fop) { + Val* out = maybeMutated(fop->output(0)); + Val* fill_value = maybeMutated(fop->getFillValue()); + + if (out->sameAs(fop->output(0))) { + return; + } + auto container = fop->container(); + container->removeExpr(fop); + IrBuilder::create(container, out, fill_value, fop->dtype()); +} + void OptOutMutator::mutate(ARangeOp* aop) { Val* out = maybeMutated(aop->output(0)); @@ -140,7 +151,20 @@ void OptOutMutator::mutate(ARangeOp* aop) { aop->start(), aop->end(), aop->step(), - aop->getLinearIndex()); + aop->dtype(), + aop->getLinearLogicalIndex()); +} + +void OptOutMutator::mutate(EyeOp* eop) { + Val* out = maybeMutated(eop->output(0)); + + if (out->sameAs(eop->output(0))) { + return; + } + auto container = eop->container(); + container->removeExpr(eop); + IrBuilder::create( + container, out, eop->dtype(), eop->getIndex1(), eop->getIndex2()); } void OptOutMutator::mutate(UnaryOp* uop) { @@ -190,8 +214,13 @@ void OptOutMutator::mutate(TernaryOp* top) { void OptOutMutator::mutate(RNGOp* rop) { Val* out = maybeMutated(rop->output(0)); + auto& parameters = rop->getParameters(); + std::vector mutated_parameters; + for (auto v : parameters) { + mutated_parameters.emplace_back(maybeMutated(v)); + } - if (out == rop->output(0)) { + if (out == rop->output(0) && mutated_parameters == parameters) { return; } @@ -199,7 +228,13 @@ void OptOutMutator::mutate(RNGOp* rop) { auto rop_type = rop->getRNGOpType(); container->removeExpr(rop); IrBuilder::create( - container, rop_type, out, rop->getRNGOffset(), rop->getPhiloxIndex()); + container, + rop_type, + out, + rop->dtype(), + mutated_parameters, + rop->getRNGOffset(), + rop->getPhiloxIndex()); } void OptOutMutator::mutate(ReductionOp* rop) { diff --git a/torch/csrc/jit/codegen/cuda/non_divisible_split.cpp b/torch/csrc/jit/codegen/cuda/non_divisible_split.cpp index 3a2ab5f5eb5be..eaff9274892dd 100644 --- a/torch/csrc/jit/codegen/cuda/non_divisible_split.cpp +++ b/torch/csrc/jit/codegen/cuda/non_divisible_split.cpp @@ -23,7 +23,7 @@ void NonDivisibleSplitInfo::build(Fusion* fusion) { tv->domain()->domain().begin(), tv->domain()->domain().end()); current_tv_ = tv; clearReachability(); - traverseFrom(fusion, domain_vals); + traverseTo(fusion, domain_vals); current_tv_ = nullptr; } @@ -53,7 +53,16 @@ void NonDivisibleSplitInfo::handle(Split* split) { splits_to_validate_.insert(split); } else { // Not proven to be a divisible split - splits_to_predicate_[current_tv_].push_back(split); + auto gpu_lower = GpuLower::current(); + TORCH_INTERNAL_ASSERT(gpu_lower != nullptr); + + // If we know this split must be divisible, it's either validated as + // above, exact matches to a case matching the above, or exact matches + // to a transformation from view which must be divisible. + if (gpu_lower->divisbleSplitSet().find(split) == + gpu_lower->divisbleSplitSet().end()) { + splits_to_predicate_[current_tv_].push_back(split); + } } is_protected = true; diff --git a/torch/csrc/jit/codegen/cuda/nvfuser.cmake b/torch/csrc/jit/codegen/cuda/nvfuser.cmake index 526a674e4fb4c..147003054766b 100644 --- a/torch/csrc/jit/codegen/cuda/nvfuser.cmake +++ b/torch/csrc/jit/codegen/cuda/nvfuser.cmake @@ -1,6 +1,4 @@ -if(BUILD_SPLIT_CUDA) - set(TORCHLIB_FLAVOR torch_cuda_cu) # chose torch_cuda_cu here since JIT is in torch_cuda_cpp -elseif(USE_CUDA) +if(USE_CUDA) set(TORCHLIB_FLAVOR torch_cuda) elseif(USE_ROCM) set(TORCHLIB_FLAVOR torch_hip) diff --git a/torch/csrc/jit/codegen/cuda/ops/alias.cpp b/torch/csrc/jit/codegen/cuda/ops/alias.cpp index b51c64a0bab0e..20c6ee533063d 100644 --- a/torch/csrc/jit/codegen/cuda/ops/alias.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/alias.cpp @@ -36,6 +36,8 @@ TensorView* applyViewTransforms( TensorView* orig_tv, TensorView* post_reduce_tv, const AnalyzeViewResult& view_analysis) { + TORCH_INTERNAL_ASSERT(orig_tv != nullptr, "Input is invalid."); + TORCH_INTERNAL_ASSERT(post_reduce_tv != nullptr, "Input is invalid."); TORCH_INTERNAL_ASSERT( !post_reduce_tv->hasComputeAt(), "Cannot modify rfactor domain after compute at has been set."); @@ -43,10 +45,6 @@ TensorView* applyViewTransforms( TORCH_INTERNAL_ASSERT( post_reduce_tv->nDims() > 0, "Tried to view a 0-dim TensorView"); - TORCH_CHECK( - !post_reduce_tv->domain()->hasRFactor(), - "Cannot call view on the same TensorView twice."); - TORCH_INTERNAL_ASSERT(!view_analysis.transforms.empty()); TensorView* consumer = IrBuilder::create( @@ -62,6 +60,7 @@ TensorView* applyViewTransforms( } // namespace TensorView* view(TensorView* x, DataType dtype) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); if (x->getDataType() == dtype) { return x; } @@ -81,6 +80,7 @@ TensorView* view( TensorView* x, const std::vector& original_sizes, const std::vector& new_sizes) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); TORCH_INTERNAL_ASSERT( TensorDomain::noReductions(x->getMaybeRFactorDomain()).size() == original_sizes.size()); @@ -111,6 +111,7 @@ TensorView* view( } TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain()); if (start_dim < 0) { start_dim += inp_domain.size(); @@ -140,6 +141,7 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { } TensorView* squeeze(TensorView* x, const std::vector& sizes) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); const auto ndims = static_cast(x->domain()->noReductions().size()); TORCH_INTERNAL_ASSERT( @@ -163,6 +165,7 @@ TensorView* squeeze(TensorView* x, const std::vector& sizes) { } TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); const auto ndims = static_cast(x->domain()->noReductions().size()); TORCH_INTERNAL_ASSERT( @@ -191,6 +194,7 @@ TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim) { } TensorView* unsqueeze(TensorView* x, int dim) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); const auto ndims = static_cast(x->domain()->noReductions().size()); if (dim < 0) { @@ -210,17 +214,31 @@ TensorView* unsqueeze(TensorView* x, int dim) { } TensorView* permute(TensorView* x, const std::vector& new2old) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); if (new2old.size() == 0) { return set(x); } auto inp_domain = TensorDomain::noReductions(x->getMaybeRFactorDomain()); std::vector out_domain(inp_domain.size()); + TORCH_CHECK( + inp_domain.size() == new2old.size(), + "The number of dimensions in the tensor input does not match the length", + " of the desired ordering of dimensions i.e. input.dim() = ", + inp_domain.size(), + " is not equal to len(dims) = ", + new2old.size()); + + // Return scalar tensors immediately + if (inp_domain.size() == 0) { + return set(x); + } + auto normalized_new2old = ir_utils::normalizeNew2Old(new2old, inp_domain.size()); for (const auto i : c10::irange(out_domain.size())) { - auto in_id = inp_domain[new2old[i]]; + auto in_id = inp_domain[normalized_new2old[i]]; out_domain[i] = in_id->cloneWithoutRFactor(); } @@ -233,6 +251,7 @@ TensorView* permute(TensorView* x, const std::vector& new2old) { } TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); const auto ndims = static_cast(x->domain()->noReductions().size()); if (dim0 < 0) { @@ -263,6 +282,7 @@ TensorView* transpose(TensorView* x, int64_t dim0, int64_t dim1) { } TensorView* transpose(TensorView* x) { + TORCH_INTERNAL_ASSERT(x != nullptr, "Input is invalid."); const auto ndims = static_cast(x->domain()->noReductions().size()); TORCH_CHECK( diff --git a/torch/csrc/jit/codegen/cuda/ops/composite.cpp b/torch/csrc/jit/codegen/cuda/ops/composite.cpp index 5aa1d64c5cf1a..a7905c4894c15 100644 --- a/torch/csrc/jit/codegen/cuda/ops/composite.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/composite.cpp @@ -27,7 +27,7 @@ ForwardDropoutResult dropout(TensorView* x, Val* prob, Val* scale) { scale->getDataType().value() == DataType::Double, "Scale is not a valid Double."); - auto rand_vals = randlike(x); + auto rand_vals = rand_like(x); auto mask = lt(rand_vals, prob); auto apply_mask = mul(x, mask); auto y = mul(apply_mask, scale); diff --git a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp index 4484539467cd4..f1739c665f035 100644 --- a/torch/csrc/jit/codegen/cuda/ops/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/ops/normalization.cpp @@ -589,8 +589,7 @@ ForwardNormResult batch_norm( // During inference, mean/invstd output are empty tensors // on CPU, but not on CUDA. We need to make sure we have the same // behavior as with eager mode on CUDA. - mean = set(running_mean); // use set to avoid "trivial input forwarding NOT - // IMPLEMENTED" error + mean = running_mean; invstd = unbiased_invstd; y = mul(x_sub_mean, invstd_bcast); } @@ -843,8 +842,10 @@ ForwardNormResult instance_norm( broadcast(unbiased_invstd, channels_only_broadcast_mask); // During inference, mean/invstd output are empty tensors - mean = TensorViewBuilder().shape(std::vector{0}).build(); - invstd = TensorViewBuilder().shape(std::vector{0}).build(); + // on CPU, but not on CUDA. We need to make sure we have the same + // behavior as with eager mode on CUDA. + mean = running_mean; + invstd = unbiased_invstd; y = mul(x_sub_mean, invstd_bcast); } diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 251e1e6f11a2d..e78d5effbee3e 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -1321,7 +1321,7 @@ class IrParser { } } - auto out = randlike(operand); + auto out = rand_like(operand); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, @@ -3378,6 +3378,115 @@ class IrParser { }, nullptr); } + + { + auto ptr_op = getOperatorForLiteral( + "prim::permute_copy.int(Tensor(a) self, int[] dims) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, value_map[node->inputs()[0]->unique()]); + auto self_t = list_val.front(); + list_val.pop_front(); + auto self = self_t->as(); + + auto dims = constant_as>(node->input(1)); + TORCH_INTERNAL_ASSERT( + dims.has_value(), "The dims parameter is required."); + TORCH_INTERNAL_ASSERT( + dims.value().size() == self->getMaybeRFactorDomain().size()); + + auto output = permute(self, dims->vec()); + value_map.emplace( + node->output()->unique(), ValueHolder(output, format)); + }, + [](const Node* node) -> bool { + if (!isInputNonSizeZeroTensor(node)) { + return false; + } + auto dims = constant_as>(node->input(1)); + if (!dims.has_value()) { + return false; + } + + return true; + }, + nullptr); + } + + { + auto ptr_op = getOperatorForLiteral( + "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, value_map[node->inputs()[0]->unique()]); + auto self_t = list_val.front(); + list_val.pop_front(); + auto self = self_t->as(); + + auto dim0 = constant_as(node->input(1)); + TORCH_INTERNAL_ASSERT( + dim0.has_value(), "dim0 in transpose is not valid."); + + auto dim1 = constant_as(node->input(2)); + TORCH_INTERNAL_ASSERT( + dim1.has_value(), "dim1 in transpose is not valid."); + + auto output = transpose(self, dim0.value(), dim1.value()); + value_map.emplace( + node->output()->unique(), ValueHolder(output, format)); + }, + [](const Node* node) -> bool { + if (!isInputNonSizeZeroTensor(node)) { + return false; + } + if (node->input(1)->node()->kind() != prim::Constant) { + return false; + } + if (node->input(2)->node()->kind() != prim::Constant) { + return false; + } + return true; + }, + nullptr); + } + + { + auto ptr_op = + getOperatorForLiteral("prim::t_copy(Tensor(a) self) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, value_map[node->inputs()[0]->unique()]); + auto self_t = list_val.front(); + list_val.pop_front(); + auto self = self_t->as(); + + TORCH_INTERNAL_ASSERT(self->getMaybeRFactorDomain().size() <= 2); + + auto output = transpose(self); + value_map.emplace( + node->output()->unique(), ValueHolder(output, format)); + }, + [](const Node* node) -> bool { + if (!isInputNonSizeZeroTensor(node)) { + return false; + } + + return true; + }, + nullptr); + } } void processJitNode(const JitOp* node) { @@ -4141,6 +4250,49 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { return true; } + static auto permute_schema = + getOperatorForLiteral( + "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)") + ->schema(); + static auto permute_copy_schema = + getOperatorForLiteral( + "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor") + ->schema(); + if (node->matches(permute_schema) || node->matches(permute_copy_schema)) { + switch (offset) { + // argument 1: dims; + case 1: + profileIntList(pr, node, offset); + break; + default: + return false; + } + return true; + } + + static auto transpose_int_copy_schema = + getOperatorForLiteral( + "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)") + ->schema(); + static auto transpose_int_schema = + getOperatorForLiteral( + "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor") + ->schema(); + if (node->matches(transpose_int_copy_schema) || + node->matches(transpose_int_schema)) { + switch (offset) { + // argument 1: dim0; + // argument 2: dim1; + case 1: + case 2: + profileInt(pr, node, offset); + break; + default: + return false; + } + return true; + } + static auto batch_norm_impl_index_schema = getOperatorForLiteral( "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)") @@ -4352,6 +4504,30 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { } } + static auto var_dim_schema = + getOperatorForLiteral( + "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor") + ->schema(); + static auto std_dim_schema = + getOperatorForLiteral( + "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor") + ->schema(); + if (node->matches(var_dim_schema) || node->matches(std_dim_schema)) { + switch (offset) { + case 1: + profileIntList(pr, node, offset); + return true; + case 2: + profileBool(pr, node, offset); + return true; + case 3: + profileBool(pr, node, offset); + return true; + default: + return false; + } + } + return false; } diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/README.md b/torch/csrc/jit/codegen/cuda/python_frontend/README.md index 7f3364e05c69b..d519e69bcda3c 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/README.md +++ b/torch/csrc/jit/codegen/cuda/python_frontend/README.md @@ -51,7 +51,7 @@ nvf_out = fs.execute([input1, input2])[0] * `id()`: Returns the fusion id for a given `Fusion`. * `print()`: Prints the low level IR for the currently defined fusion. -### `FusionDefiniton` Context Manager - Interface for Defining Fusions +### `FusionDefinition` Context Manager - Interface for Defining Fusions #### Defining Input Tensors _All intermediate tensors are created by operations. Constant tensors do not exist._ @@ -108,7 +108,7 @@ python -c "from torch._C._nvfuser import FusionDefinition; help(FusionDefinition ``` #### Notating Outputs -The `FusionDefintion` `add_output` method is used to indicate an intermediate is an output to the fusion. +The `FusionDefinition` `add_output` method is used to indicate an intermediate is an output to the fusion. ```python add_output(output: Tensor) diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h index f124cf36e0092..1974fc66f6fa9 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h @@ -19,11 +19,11 @@ enum class RecordType { Op, BatchNormOp, BroadcastOp, + BroadcastInDimOp, CastOp, Constant, End, Tensor, - NullTensor, Output, ReductionOp, Scalar, @@ -33,6 +33,7 @@ enum class RecordType { VarianceMeanOp, ViewOp, PermuteOp, + FullOp }; //! RecordFunctor is the base class record for operations recorded by @@ -144,13 +145,14 @@ struct RecordFunctor { os << ", "; } if (arg.stype == StateType::Scalar) { - os << "S"; + os << "S" << arg.index; } else if (arg.stype == StateType::Tensor) { - os << "T"; + os << "T" << arg.index; + } else if (arg.stype == StateType::None) { + os << "None"; } else { TORCH_INTERNAL_ASSERT(false, "Unsupported StateType"); } - os << arg.index; } if (close_function) { os << ")"; @@ -377,13 +379,13 @@ struct PermuteOpRecord : RecordFunctor { PermuteOpRecord( std::vector _args, std::vector _outputs, - std::vector& permutation) + std::vector& dims) : RecordFunctor( std::move(_args), std::move(_outputs), - "permute", + "ops.permute", RecordType::PermuteOp), - permutation_(std::move(permutation)) {} + dims_(std::move(dims)) {} virtual ~PermuteOpRecord() = default; virtual RecordFunctor* clone() final { return new PermuteOpRecord(*this); @@ -391,11 +393,11 @@ struct PermuteOpRecord : RecordFunctor { virtual size_t hash() const final { auto result = RecordFunctor::hash(); - size_t permutation_hash = 0; - for (auto p : permutation_) { - permutation_hash ^= static_cast(p); + size_t dims_hash = 0; + for (auto dim : dims_) { + dims_hash ^= static_cast(dim); } - return result | (permutation_hash & 0xffff); + return result | (dims_hash & 0xffff); } virtual bool operator==(const RecordFunctor& other) const final { @@ -403,10 +405,10 @@ struct PermuteOpRecord : RecordFunctor { if (auto child_ptr = dynamic_cast(&other)) { result = RecordFunctor::operator==(other); if (result) { - result = (permutation_.size() == child_ptr->permutation_.size()); + result = (dims_.size() == child_ptr->dims_.size()); if (result) { - for (size_t i = 0; i < permutation_.size(); ++i) { - if (permutation_[i] != child_ptr->permutation_[i]) { + for (size_t i = 0; i < dims_.size(); ++i) { + if (dims_[i] != child_ptr->dims_[i]) { result = false; break; } @@ -420,13 +422,31 @@ struct PermuteOpRecord : RecordFunctor { void operator()(FusionDefinition& fd) final { auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = torch::jit::fuser::cuda::permute(arg, permutation_); + auto output = Nvf::permute(arg, dims_); fd.setFusionState(outputs_.at(0).index, output); } + virtual void print(std::ostream& os, bool close_function = true) const { + RecordFunctor::print(os, false); + os << ", dims=["; + bool first_arg = true; + for (auto dim : dims_) { + if (first_arg) { + first_arg = false; + } else { + os << ", "; + } + os << dim; + } + os << "]"; + if (close_function) { + os << ")"; + } + } + private: //! Represents the mapping from the original shape to the new shape - std::vector permutation_; + std::vector dims_; }; struct SqueezeOpRecord : RecordFunctor { @@ -438,7 +458,7 @@ struct SqueezeOpRecord : RecordFunctor { : RecordFunctor( std::move(_args), std::move(_outputs), - "squeeze", + "ops.squeeze", RecordType::SqueezeOp), original_shape_(std::move(original_shape)), dim_(dim) {} @@ -518,8 +538,8 @@ struct SqueezeOpRecord : RecordFunctor { //! Specialized Record Functor for the FusionDefinition's broadcast_in_dim op. -struct BroadcastOpRecord : RecordFunctor { - BroadcastOpRecord( +struct BroadcastInDimOpRecord : RecordFunctor { + BroadcastInDimOpRecord( std::vector _args, std::vector _outputs, std::string _name, @@ -529,12 +549,12 @@ struct BroadcastOpRecord : RecordFunctor { std::move(_args), std::move(_outputs), _name, - RecordType::BroadcastOp), + RecordType::BroadcastInDimOp), output_shape_(std::move(output_shape)), broadcast_dims_(std::move(broadcast_dims)) {} - virtual ~BroadcastOpRecord() = default; + virtual ~BroadcastInDimOpRecord() = default; virtual RecordFunctor* clone() final { - return new BroadcastOpRecord(*this); + return new BroadcastInDimOpRecord(*this); } //! Child specific hash function in lower 32 bits. @@ -556,7 +576,7 @@ struct BroadcastOpRecord : RecordFunctor { virtual bool operator==(const RecordFunctor& other) const final { auto result = false; - if (auto child_ptr = dynamic_cast(&other)) { + if (auto child_ptr = dynamic_cast(&other)) { result = RecordFunctor::operator==(other); if (result) { result = @@ -680,6 +700,77 @@ struct BroadcastOpRecord : RecordFunctor { std::vector broadcast_dims_; }; +//! Specialized Record Functor for the FusionDefinition's broadcast op. + +struct BroadcastOpRecord : RecordFunctor { + BroadcastOpRecord( + std::vector _args, + std::vector _outputs, + std::string _name, + std::vector& is_broadcast_dim) + : RecordFunctor( + std::move(_args), + std::move(_outputs), + _name, + RecordType::BroadcastOp), + is_broadcast_dim_(std::move(is_broadcast_dim)) {} + virtual ~BroadcastOpRecord() = default; + virtual RecordFunctor* clone() final { + return new BroadcastOpRecord(*this); + } + + virtual size_t hash() const final { + auto result = RecordFunctor::hash(); + size_t is_broadcast_dim_hash = 0; + for (size_t i = 0; i < is_broadcast_dim_.size(); ++i) { + is_broadcast_dim_hash |= + (is_broadcast_dim_[i] << (is_broadcast_dim_.size() - 1 - i)); + } + return result | (is_broadcast_dim_hash & 0xfff); + } + + virtual bool operator==(const RecordFunctor& other) const final { + auto result = false; + if (auto child_ptr = dynamic_cast(&other)) { + result = RecordFunctor::operator==(other); + result &= std::equal( + is_broadcast_dim_.begin(), + is_broadcast_dim_.end(), + child_ptr->is_broadcast_dim_.begin()); + } + return result; + } + + virtual void operator()(FusionDefinition& fd) final { + auto arg = + fd.getFusionState(args_.at(0).index)->template as(); + auto output = Nvf::broadcast(arg, is_broadcast_dim_); + fd.setFusionState(outputs_.at(0).index, output); + } + + virtual void print(std::ostream& os, bool close_function = true) const { + RecordFunctor::print(os, false); + os << ", is_broadcast_dim=["; + bool first_arg = true; + for (auto dim : is_broadcast_dim_) { + if (first_arg) { + first_arg = false; + } else { + os << ", "; + } + os << (dim ? "True" : "False"); + } + os << "]"; + if (close_function) { + os << ")"; + } + } + + private: + //! Communicates which dimensions in the output are broadcasted. + std::vector is_broadcast_dim_; +}; + template struct CastOpRecord : RecordFunctor { CastOpRecord( @@ -974,6 +1065,7 @@ struct TensorRecord : RecordFunctor { } } os << "], dtype=" << dtypeToPyString(dtype_); + os << ", is_cpu=" << (is_cpu_ ? "True" : "False"); if (close_function) { os << ")"; } @@ -993,41 +1085,6 @@ struct TensorRecord : RecordFunctor { bool is_cpu_; }; -struct NullTensorRecord : RecordFunctor { - NullTensorRecord(std::vector _outputs) - : RecordFunctor( - {}, - std::move(_outputs), - "null_tensor", - RecordType::NullTensor) {} - virtual ~NullTensorRecord() = default; - virtual RecordFunctor* clone() final { - return new NullTensorRecord(*this); - } - - //! Nothing extra necessary in hash - //! Child specific hash function in lower 32 bits. - //! | 31 --------------------------------------- 0 | - //! | None | - virtual size_t hash() const final { - auto result = RecordFunctor::hash(); - return result; - } - - virtual bool operator==(const RecordFunctor& other) const final { - auto result = false; - if (dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - } - return result; - } - - virtual void operator()(FusionDefinition& fd) final { - Nvf::TensorView* tv = nullptr; - fd.setFusionState(outputs_.at(0).index, tv); - } -}; - //! Specialized Record Functor for recording FusionDefinition outputs. template @@ -1482,12 +1539,18 @@ struct BatchNormOpRecord : RecordFunctor { void operator()(FusionDefinition& fd) final { auto x = fd.getFusionState(args_.at(0).index)->as(); - auto weight = fd.getFusionState(args_.at(1).index)->as(); - auto bias = fd.getFusionState(args_.at(2).index)->as(); - auto running_mean = - fd.getFusionState(args_.at(3).index)->as(); - auto running_var = - fd.getFusionState(args_.at(4).index)->as(); + auto weight = (args_.at(1).stype == StateType::Tensor) + ? fd.getFusionState(args_.at(1).index)->as() + : nullptr; + auto bias = (args_.at(2).stype == StateType::Tensor) + ? fd.getFusionState(args_.at(2).index)->as() + : nullptr; + auto running_mean = (args_.at(3).stype == StateType::Tensor) + ? fd.getFusionState(args_.at(3).index)->as() + : nullptr; + auto running_var = (args_.at(4).stype == StateType::Tensor) + ? fd.getFusionState(args_.at(4).index)->as() + : nullptr; auto momentum = fd.getFusionState(args_.at(5).index)->as(); auto eps = fd.getFusionState(args_.at(6).index)->as(); auto output = Nvf::batch_norm( @@ -1505,11 +1568,109 @@ struct BatchNormOpRecord : RecordFunctor { fd.setFusionState(outputs_.at(2).index, output.invstd); } + virtual void print(std::ostream& os, bool close_function = true) const final { + RecordFunctor::print(os, false); + os << ", training=" << (training_ ? "True" : "False"); + os << ", channels_last=" << (channels_last_ ? "True" : "False"); + if (close_function) { + os << ")"; + } + } + private: bool training_; bool channels_last_; }; +struct FullOpRecord : RecordFunctor { + FullOpRecord( + std::vector _args, + std::vector _outputs, + std::vector& shape, + Nvf::DataType dtype) + : RecordFunctor( + std::move(_args), + std::move(_outputs), + "ops.full", + RecordType::FullOp), + shape_(std::move(shape)), + dtype_(dtype) {} + virtual ~FullOpRecord() = default; + virtual RecordFunctor* clone() final { + return new FullOpRecord(*this); + } + + //! Child specific hash function in lower 32 bits. + //! | 31 --- 24 | 23 -------------------------- 0 | + //! | Dtype | Shape hash code | + virtual size_t hash() const final { + auto result = RecordFunctor::hash(); + size_t shape_hash = 0; + for (auto p : shape_) { + shape_hash ^= static_cast(p); + } + result |= ((static_cast(dtype_) & 0xff) << 24); + result |= (shape_hash & 0xffff); + return result; + } + + virtual bool operator==(const RecordFunctor& other) const final { + auto result = false; + if (auto child_ptr = dynamic_cast(&other)) { + result = RecordFunctor::operator==(other); + if (result) { + result = (shape_.size() == child_ptr->shape_.size()); + if (result) { + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] != child_ptr->shape_[i]) { + result = false; + break; + } + } + } + } + } + return result; + } + + void operator()(FusionDefinition& fd) final { + auto arg = fd.getFusionState(args_.at(0).index)->template as(); + + std::vector nvf_shape( + shape_.size(), nullptr); + for (const auto idx : c10::irange(shape_.size())) { + nvf_shape[idx] = Nvf::IrBuilder::create(shape_.at(idx)); + } + auto output = torch::jit::fuser::cuda::full(nvf_shape, arg, dtype_); + fd.setFusionState(outputs_.at(0).index, output); + } + + virtual void print(std::ostream& os, bool close_function = true) const { + RecordFunctor::print(os, false); + os << ", shape=["; + bool first_arg = true; + for (auto p : shape_) { + if (first_arg) { + first_arg = false; + } else { + os << ", "; + } + os << p; + } + os << "]"; + os << ", dtype=" << dtypeToPyString(dtype_); + if (close_function) { + os << ")"; + } + } + + private: + //! Represents shape of new tensor + std::vector shape_; + //! Type of output + Nvf::DataType dtype_; +}; + } // namespace nvfuser //! Creating the template specialized hash and equal_to functions for a diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index b8c799d00b90a..fc9d105100b9c 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -2,6 +2,7 @@ #ifdef USE_CUDA #include +#include #include #include #include @@ -39,6 +40,24 @@ void initNvFuserPythonBindings(PyObject* module) { .value("ComplexDouble", Nvf::DataType::ComplexDouble) .value("Null", Nvf::DataType::Null); + nvfuser.def( + "compute_contiguity", + [](const std::vector& sizes, + const std::vector& strides) { + py::tuple contiguity(sizes.size()); + TORCH_CHECK( + sizes.size() == strides.size(), + "compute_contiguity: Sizes and strides must have the same number of dimensions"); + if (sizes.size() == 0) { + return contiguity; + } + contiguity[sizes.size() - 1] = strides.back() == 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; --i) { + contiguity[i] = strides[i] == strides[i + 1] * sizes[i + 1]; + } + return contiguity; + }); + //! Binding the FusionCache that holds a cache of Fusions //! This is only bound to provide an interface to get the number of fusions //! that are cached. @@ -126,16 +145,6 @@ void initNvFuserPythonBindings(PyObject* module) { self.defineRecord(new nvfuser::OutputRecord( {self.recordingState(output())})); }) - .def( - "define_null_tensor", - [](nvfuser::FusionDefinition& self) -> nvfuser::Tensor { - FUSER_PERF_SCOPE("FusionDefinition.define_null_tensor"); - nvfuser::Tensor out = self.defineTensor(); - self.defineRecord( - new nvfuser::NullTensorRecord({self.recordingState(out())})); - return out; - }, - py::return_value_policy::reference) .def( "define_tensor", [](nvfuser::FusionDefinition& self, @@ -381,7 +390,7 @@ void initNvFuserPythonBindings(PyObject* module) { NVFUSER_PYTHON_BINDING_UNARY_OP("neg", neg) NVFUSER_PYTHON_BINDING_UNARY_OP("bitwise_not", bitwise_not) NVFUSER_PYTHON_BINDING_UNARY_OP("relu", relu) - NVFUSER_PYTHON_BINDING_UNARY_OP("rand_like", randlike) + NVFUSER_PYTHON_BINDING_UNARY_OP("rand_like", rand_like) NVFUSER_PYTHON_BINDING_UNARY_OP("reciprocal", reciprocal) NVFUSER_PYTHON_BINDING_UNARY_OP("round", round) NVFUSER_PYTHON_BINDING_UNARY_OP("rsqrt", rsqrt) @@ -1191,17 +1200,16 @@ void initNvFuserPythonBindings(PyObject* module) { "permute", [](nvfuser::FusionDefinition::Operators& self, nvfuser::Tensor arg, - std::vector& permutation) -> nvfuser::Tensor { + std::vector& dims) -> nvfuser::Tensor { nvfuser::FusionDefinition* fd = self.fusion_definition; nvfuser::Tensor output = fd->defineTensor(); self.fusion_definition->defineRecord(new nvfuser::PermuteOpRecord( - {fd->recordingState(arg())}, - {fd->recordingState(output())}, - permutation)); + {fd->recordingState(arg())}, {fd->recordingState(output())}, dims)); return output; }, + py::arg("arg"), + py::arg("dims"), py::return_value_policy::reference); - nvf_ops.def( "squeeze", [](nvfuser::FusionDefinition::Operators& self, @@ -1241,7 +1249,25 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("original_shape"), py::arg("new_shape"), py::return_value_policy::reference); - + nvf_ops.def( + "full", + [](nvfuser::FusionDefinition::Operators& self, + std::vector& size, + nvfuser::Scalar arg, + Nvf::DataType dtype) -> nvfuser::Tensor { + nvfuser::FusionDefinition* fd = self.fusion_definition; + nvfuser::Tensor output = fd->defineTensor(); + fd->defineRecord(new nvfuser::FullOpRecord( + {fd->recordingState(arg())}, + {fd->recordingState(output())}, + size, + dtype)); + return output; + }, + py::arg("size"), + py::arg("arg"), + py::arg("dtype"), + py::return_value_policy::reference); nvf_ops.def( "var", [](nvfuser::FusionDefinition::Operators& self, @@ -1292,26 +1318,38 @@ void initNvFuserPythonBindings(PyObject* module) { nvf_ops.def( "batch_norm", [](nvfuser::FusionDefinition::Operators& self, - nvfuser::Tensor x, - nvfuser::Tensor weight, - nvfuser::Tensor bias, - nvfuser::Tensor running_mean, - nvfuser::Tensor running_var, - bool training, + nvfuser::Tensor arg, + c10::optional weight, + c10::optional bias, + c10::optional running_mean, + c10::optional running_var, nvfuser::Scalar momentum, nvfuser::Scalar eps, + bool training, bool channels_last) -> decltype(auto) { FUSER_PERF_SCOPE("Operators.batch_norm"); nvfuser::FusionDefinition* fd = self.fusion_definition; nvfuser::Tensor output = fd->defineTensor(); nvfuser::Tensor mean = fd->defineTensor(); nvfuser::Tensor invstd = fd->defineTensor(); + auto weight_state = weight.has_value() + ? fd->recordingState(weight.value()()) + : nvfuser::State(0, nvfuser::StateType::None); + auto bias_state = bias.has_value() + ? fd->recordingState(bias.value()()) + : nvfuser::State(0, nvfuser::StateType::None); + auto running_mean_state = running_mean.has_value() + ? fd->recordingState(running_mean.value()()) + : nvfuser::State(0, nvfuser::StateType::None); + auto running_var_state = running_var.has_value() + ? fd->recordingState(running_var.value()()) + : nvfuser::State(0, nvfuser::StateType::None); fd->defineRecord(new nvfuser::BatchNormOpRecord( - {fd->recordingState(x()), - fd->recordingState(weight()), - fd->recordingState(bias()), - fd->recordingState(running_mean()), - fd->recordingState(running_var()), + {fd->recordingState(arg()), + weight_state, + bias_state, + running_mean_state, + running_var_state, fd->recordingState(momentum()), fd->recordingState(eps())}, {fd->recordingState(output()), @@ -1321,14 +1359,14 @@ void initNvFuserPythonBindings(PyObject* module) { channels_last)); return std::make_tuple(output, mean, invstd); }, - py::arg("x"), + py::arg("arg"), py::arg("weight").none(true), py::arg("bias").none(true), py::arg("running_mean").none(true), py::arg("running_var").none(true), - py::arg("training"), py::arg("momentum"), py::arg("eps"), + py::arg("training"), py::arg("channels_last") = false, py::return_value_policy::reference); nvf_ops.def( @@ -1343,7 +1381,7 @@ void initNvFuserPythonBindings(PyObject* module) { output_shape.size() >= broadcast_dims.size(), "broadcast_dims vector size is too big for output shape!"); nvfuser::Tensor output = fd->defineTensor(); - fd->defineRecord(new nvfuser::BroadcastOpRecord( + fd->defineRecord(new nvfuser::BroadcastInDimOpRecord( {fd->recordingState(arg())}, {fd->recordingState(output())}, "ops.broadcast_in_dim", @@ -1355,6 +1393,24 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("output_shape"), py::arg("broadcast_dims"), py::return_value_policy::reference); + nvf_ops.def( + "broadcast", + [](nvfuser::FusionDefinition::Operators& self, + nvfuser::Tensor arg, + std::vector& is_broadcast_dim) -> nvfuser::Tensor { + FUSER_PERF_SCOPE("Operators.broadcast"); + nvfuser::FusionDefinition* fd = self.fusion_definition; + nvfuser::Tensor output = fd->defineTensor(); + fd->defineRecord(new nvfuser::BroadcastOpRecord( + {fd->recordingState(arg())}, + {fd->recordingState(output())}, + "ops.broadcast", + is_broadcast_dim)); + return output; + }, + py::arg("arg"), + py::arg("is_broadcast_dim"), + py::return_value_policy::reference); } } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp index d1f9d8102a500..607c560dab74d 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_cache.cpp @@ -6,6 +6,7 @@ #include #include +#include // Tests go in torch::jit namespace torch { diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp index 84aa4da5909ae..bae9cf6def810 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_definition.cpp @@ -8,6 +8,7 @@ #include #include #include +#include // Tests go in torch::jit namespace torch { diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp index 47785156ef788..5ae2db7db8805 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/test/test_nvfuser_fusion_record.cpp @@ -6,6 +6,7 @@ #include #include +#include // Tests go in torch::jit namespace torch { diff --git a/torch/csrc/jit/codegen/cuda/reference_tensor.h b/torch/csrc/jit/codegen/cuda/reference_tensor.h deleted file mode 100644 index 07c83bb6ed74c..0000000000000 --- a/torch/csrc/jit/codegen/cuda/reference_tensor.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -struct ReferenceTensor { - TensorDomain* domain = nullptr; - - // Map from concrete iteration domains in ComputeAtMaps to iter domains - // including those used to construct domain. - std::unordered_map concrete_to_id; - // Map from reference iteration domains to concrete iteration domains. - std::unordered_map id_to_concrete; -}; - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index ff3ed11ae1902..235d257e2351d 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -186,27 +186,39 @@ auto ensureMapping( return it; } +TensorView* lookUpTv(const TensorDomain* td) { + Fusion* fusion = FusionGuard::getCurFusion(); + for (auto tv : ir_utils::filterByType(fusion->vals())) { + if (tv->domain() == td) { + return tv; + } + } + return nullptr; +} + } // namespace std::string DomainKey::toString() const { std::stringstream ss; - ss << "{"; - if (td()) { - ss << td() << " (root: " << td()->getRootDomain() - << ", maybe rfactor: " << td()->getMaybeRFactorDomain() << ")"; - } else { - ss << "null"; - } - ss << ", "; if (id()) { ss << id(); } else { ss << "null"; } if (concreteId()) { - ss << " (" << concreteId() << ")"; + ss << " (concrete: " << concreteId() << ")"; + } + ss << " in "; + if (td()) { + auto tv = lookUpTv(td()); + TORCH_INTERNAL_ASSERT(tv != nullptr, "No TV found for ", td()->toString()); + ss << "T" << tv->name() << "[ " << td()->getRootDomain() << " ]"; + if (td()->hasRFactor()) { + ss << " (Rfactor: [ " << td()->getMaybeRFactorDomain() << " ])"; + } + } else { + ss << "null"; } - ss << "}"; return ss.str(); } @@ -226,7 +238,7 @@ class FindInputDomains : BackwardVisitor { } DomainKeySet find() { - traverseFrom(tv_->fusion(), {tv_}); + traverseTo(tv_->fusion(), {tv_}); return input_keys_; } @@ -474,7 +486,7 @@ bool ComputeAtRootDomainMap::canMap( const IterDomain* id_b) const { TORCH_INTERNAL_ASSERT( id_b->definition() == nullptr || id_b->isRFactorProduct(), - "Non-root domain is not supproted: ", + "Non-root domain is not supported: ", id_b); if (!id_b->isBroadcast()) { @@ -685,7 +697,7 @@ ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder( map_through_reduction_(map_through_reduction) { Fusion* fusion = FusionGuard::getCurFusion(); TORCH_INTERNAL_ASSERT(fusion != nullptr); - traverseFrom(fusion, fusion->outputs(), false); + traverseTo(fusion, fusion->outputs(), false); if (!pending_map_.empty()) { std::stringstream ss; ss << "pending map:\n"; @@ -823,10 +835,6 @@ void ComputeAtRootDomainMapBuilder::setMaybeMapped( addToPendingList(producer_bcast_key, consumer_bcast_key); } } else { - TORCH_INTERNAL_ASSERT( - !consumer_id->isBroadcast(), - "No concrete domain found for a broadcast domain: ", - consumer_key.toString()); auto producer_concrete_key = producer_key; if (producer_id->isBroadcast()) { const auto concrete_id = consumer_id; @@ -862,7 +870,7 @@ void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) { const auto& out_root = out_td->getRootDomain(); // Record equalities from output to all the inputs - // ignores un-concretizable broadcasts + // ignores non-concretizable broadcasts for (auto* in_tv : ir_utils::filterByType(e->inputs())) { const TensorDomain* in_td = in_tv->domain(); std::vector in_root = @@ -878,15 +886,16 @@ void ComputeAtRootDomainMapBuilder::mapPointwiseOrReductionOp(Expr* e) { for (const auto it : c10::irange(in_root.size())) { if (e->outputs().size() > 1) { TORCH_INTERNAL_ASSERT( - e->isA() || e->isA(), - "Multi-output mapping assumes WelforddOp or GroupedReductionOp but, ", + e->isA() || e->isA() || + e->isA(), + "Unknown multi-output Expr type ", e->getExprType().value(), " is found"); - for (auto o : e->outputs()) { - auto o_tv = o->as(); - auto o_td = o_tv->domain(); - auto o_root = o_td->getRootDomain(); - setMaybeMapped(in_td, in_root[it], o_td, o_root[it]); + for (auto out : e->outputs()) { + auto out_tv = out->as(); + auto out_td = out_tv->domain(); + auto out_root = out_td->getRootDomain(); + setMaybeMapped(in_td, in_root[it], out_td, out_root[it]); } } else { setMaybeMapped(in_td, in_root[it], out_td, out_root[it]); @@ -1056,7 +1065,7 @@ void ComputeAtRootDomainMapBuilder::handle(TensorView* tv) { mapAllPendingMappings(td, id); } - // When tv has a rfactor domain, propagate the domain mappings from + // When tv has an rfactor domain, propagate the domain mappings from // each of the rfactor axes to the dependent root axes. if (td->hasViewLikeRFactor()) { std::unordered_set root_set( @@ -1114,7 +1123,7 @@ class ExactRootDomainMapBuilder : private IterVisitor { Fusion* fusion, DisjointSets& eq_sets) : eq_sets_(eq_sets) { - traverseFrom(fusion, fusion->outputs()); + traverseTo(fusion, fusion->outputs()); } private: diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.h b/torch/csrc/jit/codegen/cuda/root_domain_map.h index cf2becbd1c718..fa3d323ba6d21 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.h +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.h @@ -289,6 +289,8 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { const TensorDomain* producer, const TensorDomain* consumer) const; + std::string toString() const; + private: //! Returns if key_a and key(td_b, id_b) are mapped to eachother (equivalent), //! or are the same key. @@ -331,8 +333,6 @@ class TORCH_CUDA_CU_API ComputeAtRootDomainMap : public RootDomainMap { const std::unordered_set& root_dims_to_map, bool producer_to_consumer) const override; - std::string toString() const; - private: //! Disjoint set of all mapped keys to determine axes equivalency DisjointSets eq_set_; diff --git a/torch/csrc/jit/codegen/cuda/runtime/memory.cu b/torch/csrc/jit/codegen/cuda/runtime/memory.cu index bc275ec1cc40a..e064a43090fd7 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/memory.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/memory.cu @@ -152,6 +152,31 @@ DEVICE_INLINE void cpAsync( "n"(byte_size)); } +// Global to SMEM load that is asynchronous, +// not guaranteed to be completed until cpAsyncBarrier() is called. +template +DEVICE_INLINE void cpAsync( + Array* smem_ptr, + void const* gmem_ptr, + bool predicate) { + unsigned smem_addr = util::toSmem(&(smem_ptr->array[0])); + constexpr int byte_size = sizeof(dtype) * len; + + static_assert( + byte_size == 4 || byte_size == 8 || byte_size == 16, + "cp_async : unsupported byte size"); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + "@p cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem_addr), + "l"(gmem_ptr), + "n"(byte_size), + "r"((int)predicate)); +} + // TODO: Might have a different category of sync if we want to build out this: DEVICE_INLINE void cpAsyncBarrier() { asm volatile("cp.async.wait_all;"); diff --git a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu index 96cec63f8d9ee..75d39e7c0c4b6 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu @@ -67,3 +67,23 @@ __device__ double rng_uniform(const uint4& rng_result, int rng_component) { __device__ float rng_uniformf(const uint4& rng_result, int rng_component) { return uniformf((&rng_result.x)[rng_component]); } + +__device__ double rng_uniform_range( + const uint4& rng_result, + int rng_component, + double from, + double to) { + auto range = to - from; + auto uniform01 = rng_uniform(rng_result, rng_component); + return from + range * uniform01; +} + +__device__ float rng_uniform_rangef( + const uint4& rng_result, + int rng_component, + float from, + float to) { + auto range = to - from; + auto uniform01 = rng_uniformf(rng_result, rng_component); + return from + range * uniform01; +} diff --git a/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h b/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h index 90e64a284086c..d01d226efe42b 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h @@ -11,6 +11,7 @@ namespace cuda { enum class TORCH_CUDA_CU_API ScheduleHeuristic { None, + NoOp, PointWise, Reduction, Persistent, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h b/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h index c43ef64eac0a3..6453962bfec8a 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h @@ -26,14 +26,18 @@ namespace HeuristicCompileTime { //! Enum for all possible types of cached entries of compile-time info. enum class CompileTimeEntryType { DOMAIN_MAP, + TRANSPOSE_DOMAIN_MAP, REFERENCE_TENSORS, + REFERENCE_TENSORS_FOR_GROUPS, VECTORIZABLE_INPUTS_AND_OUTPUTS, INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS, UNROLLABLE_INPUTS_AND_OUTPUTS, REDUCTION_TVS, PERSISTENT_BUFFER_INFO, SCOPE_PERSISTENT_FACTOR_INFO, - BROADCAST_BYTE_MULTIPLES + BROADCAST_BYTE_MULTIPLES, + INNER_MOST_DIMS_INFO, + CAN_SCHEDULE_TRANSPOSE, }; //! Entry type definition class for `DOMAIN_MAP`, @@ -45,6 +49,15 @@ class DomainMap { CompileTimeEntryType::DOMAIN_MAP; }; +//! Entry type definition class for `DOMAIN_MAP`, +//! stores the domain map of a fusion, used by transpose scheduler. +class TransposeDomainMap { + public: + using DataType = pointwise_utils::DomainMap; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::TRANSPOSE_DOMAIN_MAP; +}; + //! Entry type definition class for `REFERENCE_TENSORS`, //! stores the the reference TensorViews used to schedule a fusion. class ReferenceTensors { @@ -54,6 +67,16 @@ class ReferenceTensors { CompileTimeEntryType::REFERENCE_TENSORS; }; +//! Entry type definition class for `REFERENCE_TENSORS`, +//! stores the the reference TensorViews used to schedule a fusion, used by +//! transpose scheduler. +class ReferenceTensorsForGroups { + public: + using DataType = std::vector; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::REFERENCE_TENSORS_FOR_GROUPS; +}; + //! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`, //! stores the vectorizable TensorViews on a fusion's inputs and outputs. class VectorizableInputsAndOutputs { @@ -99,6 +122,16 @@ class PersistentBufferInfo { CompileTimeEntryType::PERSISTENT_BUFFER_INFO; }; +//! Entry type definition class for `INNER_MOST_DIMS_INFO`, +//! Used in the transpose scheduler to store inner most IterDomains and their +//! position in reference1 of group 1 and group 2 +class InnerMostDimInfo { + public: + using DataType = std::vector; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::INNER_MOST_DIMS_INFO; +}; + //! Auxiliary data types for `SCOPE_PERSISTENT_FACTOR_INFO` entry type. using ScopedPersistenceBufferMap = std::unordered_map>; @@ -121,11 +154,20 @@ class ScopePersistentFactorInfo { //! information. class BroadcastMultiples { public: - using DataType = std::vector; + using DataType = scheduler_utils::BroadcastMultipleInformation; static const CompileTimeEntryType EntryType = CompileTimeEntryType::BROADCAST_BYTE_MULTIPLES; }; +//! Entry type definition class for `CAN_SCHEDULE_TRANSPOSE`, +//! stores if the transpose scheduler can scheduler this fusion +class CanScheduleTranspose { + public: + using DataType = bool; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::CAN_SCHEDULE_TRANSPOSE; +}; + //! Base abstract class for unified storage in `HeuristicSummary`, //! each entry in `HeuristicSummary` will be a subclass. class CompileTimeInfoBase : public PolymorphicBase { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h index 058c72e592ad1..a828d66fdf039 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/heuristic.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -9,7 +10,7 @@ namespace jit { namespace fuser { namespace cuda { -class HeuristicParams { +class HeuristicParams : public PolymorphicBase { public: std::string tag = ""; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index d0adc2aef6261..ddf1061591ed0 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -208,7 +208,7 @@ std::vector getMmaDomains(MmaOp* mma, MmaDimension dimension) { TORCH_CHECK( a_domain.size() == b_domain.size() && a_domain.size() == accumulator_domain.size(), - "Inconsisitent dimensions in mma op", + "Inconsistent dimensions in mma op", a_domain.size(), " ", b_domain.size(), @@ -274,10 +274,10 @@ std::unordered_set getMmaDomainSet( // optimizations. // // A concrete example: -// T0 [I0, I1, I2, I3, I4, I5] = mma(T1[I01, B11, B21, I31, I41, B51], T2[B02, +// T0 [I0, I1, I2, R3, I4, I5] = mma(T1[I01, B11, B21, I31, I41, B51], T2[B02, // I12, B22, I32, I42, I52], {3}; // In this case some example querries: -// K dimension of T0 = {I3} +// K dimension of T0 = {R3} // M dimension of T1 = {I01} // N dimension of T2 = {I52} // etc. diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index bf6768536dc24..459974b8d2884 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -909,7 +909,7 @@ TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( } // Try expanding vectorization to contig merged domains - vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, @@ -992,6 +992,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel( scheduler_utils::getReductionTvs(fusion /*, ignore_trivial = true */); TORCH_INTERNAL_ASSERT(reduction_tvs.size()); + // Registry assumes the reference tv is the first reduction_tv, if this + // changes registry needs to change. auto reduction_tv = reduction_tvs[0]; auto dim_analysis = scheduler_utils::canonicalDimReduction( diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index d404ab622a5c7..b40e6fbf7cf7a 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -52,12 +52,6 @@ class DomainMap : public pointwise_utils::DomainMap { return result; } - static bool hasReferenceTensorView(Fusion* fusion) { - FusionGuard fg(fusion); - DomainMap domain_map(fusion); - return domain_map.findReferenceTensorView() != nullptr; - } - private: bool hasMinimumSize(TensorView* tv, int num_axes) const { TORCH_INTERNAL_ASSERT(tv != nullptr); @@ -141,13 +135,11 @@ std::shared_ptr getPointwiseHeuristics( }); vectorizable_inputs_outputs_entry.get(); - auto broadcast_byte_multiples_entry = - HeuristicSummaryEntry( - data_cache, []() { - return std::make_unique< - std::vector>(); - }); - broadcast_byte_multiples_entry.get(); + auto broadcast_info = HeuristicSummaryEntry< + HeuristicCompileTime::BroadcastMultiples>(data_cache, []() { + return std::make_unique(); + }); + broadcast_info.get(); return std::make_shared("Pointwise heuristics"); } @@ -183,25 +175,7 @@ std::shared_ptr getPointwiseHeuristics( auto params = std::make_shared("Pointwise heuristics"); - /* - * 2D pointwise scheduling logic. What is expected is there's some - * broadcasting pattern which would make scheduling as a 2D problem more - * efficient than scheduling simply as a 1D problem. - * - * Mapping count holds how many bytes are in each dimension for both inputs - * and outputs relative to the reference tensor. What we're looking for is a - * break point in reference_tvs dimensions which separates the outer dimension - * and inner dimension of the problem mapped to 2D. - * - * break_point is computed assuming no reuse, ignoring parallelization - * limitations, and simply figures out which point best separates broadcasted - * dimensions. In other words, where's the point where we isolate the most - * broadcasted elements to one side. - * - * Once a break point is found, simply schedule the pointwise op as 2D - * balancing parallelization as best as possible. - */ - + // See pointwise.h to understand what we're doing for this 2D analysis. // Ideal break point location int break_point = 0; @@ -230,16 +204,15 @@ std::shared_ptr getPointwiseHeuristics( // break point. int64_t gdim_right = 1; - auto broadcast_byte_multiples_entry = - HeuristicSummaryEntry( - data_cache, [&largest_out, &index_type]() { - return std::make_unique< - std::vector>( - scheduler_utils::getBroadcastMultiples( - largest_out, index_type)); - }); + auto broadcast_info = HeuristicSummaryEntry< + HeuristicCompileTime::BroadcastMultiples>( + data_cache, [&largest_out, &index_type]() { + return std::make_unique( + scheduler_utils::getBroadcastMultiples(largest_out, index_type)); + }); - auto& broadcast_byte_multiples = broadcast_byte_multiples_entry.get(); + auto& view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids; + auto& broadcast_byte_multiples = broadcast_info.get().broadcast_multiples; TORCH_INTERNAL_ASSERT(broadcast_byte_multiples.size() == ref_root.size()); @@ -266,6 +239,12 @@ std::shared_ptr getPointwiseHeuristics( int64_t min_total_transfer = std::numeric_limits::max(); for (const auto break_point_i : c10::irange(ref_root.size())) { + // If break point is incoherent with view, don't consider breaking here. + if (!scheduler_utils::breakIsDisjoint( + view_disjoint_sets, break_point_i)) { + continue; + } + // Number of elements in the right side of reference tv with // break_point_i int64_t cur_right_elem_count = 1; @@ -362,8 +341,10 @@ std::shared_ptr getPointwiseHeuristics( } // Try expanding vectorization to contig merged domains + // TODO: This is an expensive function that shouldn't be in heuristics without + // caching. auto expanded_vector_word_size = - scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, @@ -435,8 +416,15 @@ LaunchParams schedulePointwise( return params->lparams; } +TensorView* getReferenceTensorView(Fusion* fusion) { + FusionGuard fg(fusion); + DomainMap domain_map(fusion); + auto reference_tv = domain_map.findReferenceTensorView(); + return reference_tv; +} + bool hasReferenceTensorView(Fusion* fusion) { - return DomainMap::hasReferenceTensorView(fusion); + return getReferenceTensorView(fusion) != nullptr; } // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized @@ -487,41 +475,142 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { return; } - DomainMap domain_map(fusion); - TensorView* reference_tv = - domain_map.findReferenceTensorView(params.break_point); + TensorView* reference_tv = getReferenceTensorView(fusion); TORCH_INTERNAL_ASSERT( reference_tv != nullptr, "Could not find a fully broadcasted output to reference schedule on."); - auto all_tvs = ir_utils::allTvs(fusion); - - // Merge right side of break point + // Positions of rhs and lhs after merging all dimensions. int rhs_i = -1; - for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) { - auto axis_i = i - 1; - if (rhs_i == -1) { - rhs_i = axis_i; - } else { - reference_tv->merge(axis_i, rhs_i); - rhs_i = axis_i; + int lhs_i = -1; + + auto view_ops = ir_utils::getViewOps(fusion); + + /* + * If there's no path from reference through producer paths only to a view, + * e.g.: input + * / \ + * view reference + * / + * output + * + * we need to propagate the view transformations to the reference tv before + * scheduling the reference tv. Since view ops have to be identical, if any + * path from reference tv through producers goes through a view, all paths + * from reference tv's to views should be through producers. + */ + bool needs_view_prop = + view_ops.size() > 0 && + !std::any_of( + view_ops.begin(), view_ops.end(), [&reference_tv](ViewOp* view) { + return DependencyCheck::isDependencyOf(view->out(), reference_tv) || + view->out()->sameAs(reference_tv); + }); + + if (needs_view_prop) { + auto first_view_op = *view_ops.begin(); + + // Propagate the view transformations + TransformPropagator propagator(first_view_op->out()); + MaxRootDomainInfoSpanningTree spanning_tree(first_view_op->out()); + spanning_tree.traverse(&propagator); + + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + reference_tv->reorder( + scheduler_utils::domainReorderAsRfactorMap(reference_tv)); + + // Break point is relative to rfactor domain, find the leaf domain ID's in + // the left/right side, we really need the values in domain, but easiest way + // to do this is with Dependency check which will grab all intermediate + // values too. + auto lhs_all_vals = DependencyCheck::getAllValsBetween( + {reference_tv->getMaybeRFactorDomain().begin(), + reference_tv->getMaybeRFactorDomain().begin() + params.break_point}, + {reference_tv->domain()->domain().begin(), + reference_tv->domain()->domain().end()}); + + std::unordered_set lhs_all_vals_set( + lhs_all_vals.begin(), lhs_all_vals.end()); + + auto rhs_all_vals = DependencyCheck::getAllValsBetween( + {reference_tv->getMaybeRFactorDomain().begin() + params.break_point, + reference_tv->getMaybeRFactorDomain().end()}, + {reference_tv->domain()->domain().begin(), + reference_tv->domain()->domain().end()}); + + std::unordered_set rhs_all_vals_set( + rhs_all_vals.begin(), rhs_all_vals.end()); + + // Make sure lhs and rhs groups are disjoint. + for (auto lhs_val : lhs_all_vals) { + TORCH_INTERNAL_ASSERT( + rhs_all_vals_set.count(lhs_val) == 0, + "Error in pointwise scheduler. LHS and RHS of the 2D scheduler are not disjoint."); } - } - if (rhs_i >= 0) { - // If there's an rhs - reference_tv->reorder({{rhs_i, -1}}); - } - // Merge left side of break point - int lhs_i = -1; - for (int i = (int)params.break_point; i > 0; i--) { - auto axis_i = i - 1; - if (lhs_i == -1) { - lhs_i = axis_i; - } else { - reference_tv->merge(axis_i, lhs_i); - lhs_i = axis_i; + // Merge rhs, then lhs. + IterDomain* rhs_id = nullptr; + IterDomain* lhs_id = nullptr; + auto ndims = reference_tv->nDims(); + for (auto i : c10::irange(ndims)) { + // Merge from right to left + auto pos = ndims - 1 - i; + auto id = reference_tv->axis(pos); + if (lhs_all_vals_set.count(id) > 0) { + if (lhs_id == nullptr) { + lhs_id = id; + lhs_i = pos; + } else { + reference_tv->merge(pos, lhs_i); + lhs_i = pos; + if (rhs_i > lhs_i) { + rhs_i--; + } + } + } else if (rhs_all_vals_set.count(id) > 0) { + if (rhs_id == nullptr) { + rhs_id = id; + rhs_i = pos; + } else { + reference_tv->merge(pos, rhs_i); + rhs_i = pos; + if (lhs_i > rhs_i) { + lhs_i--; + } + } + } + } + // Find the iter domains that should be in the lhs, and rhs. + } else { + // Don't need to worry about view transformations, just merge reference tv + // as we normally would. + + // Merge right side of break point + for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) { + auto axis_i = i - 1; + if (rhs_i == -1) { + rhs_i = axis_i; + } else { + reference_tv->merge(axis_i, rhs_i); + rhs_i = axis_i; + } + } + if (rhs_i >= 0) { + // If there's an rhs + reference_tv->reorder({{rhs_i, -1}}); + } + + // Merge left side of break point + for (int i = (int)params.break_point; i > 0; i--) { + auto axis_i = i - 1; + if (lhs_i == -1) { + lhs_i = axis_i; + } else { + reference_tv->merge(axis_i, lhs_i); + lhs_i = axis_i; + } } } @@ -716,9 +805,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { // get a higher position in later inline propagation. We need this separate // step because we were not using ParallelType::Unroll, so we have to do // unrolling manually. - InlinePropagator inline_unswitch( - reference_tv, unswitch_pos, ComputeAtMode::BestEffort); - spanning_tree.traverse(&inline_unswitch); + inlineAllAt(reference_tv, unswitch_pos, true); + + auto all_tvs = ir_utils::allTvs(fusion); // Inline at the inner most position. The CA position of all tensors except // inputs, cached inputs and outputs will be updated. @@ -731,9 +820,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { auto output = entry.second; inner_most_tensors.erase(output); } - InlinePropagator inline_inner_most( - reference_tv, -1, ComputeAtMode::BestEffort, inner_most_tensors); - spanning_tree.traverse(&inline_inner_most); + inlineMost(inner_most_tensors); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h index 6cba29cd6b4b9..f3a1da7bcff5f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.h @@ -10,6 +10,141 @@ namespace jit { namespace fuser { namespace cuda { +/* + * The 2D pointwise scheduling logic is a bit interesting. We'll start by giving + * motivation for what the scheduling is attempting to do. What we're going to + * do with the scheduling is attempt to make it two dimensional in a way that + * minimizes the refetching of broadcasted dimensions. If we think of the + * trivial case: + * T0[i0, b1] + * T1[b0, i1] + * T2[i0, i1] = T0 + T1 + * If we scheduled T2 as 1-dimensional we would do something along the lines of + * merging i0 and i1 then splitting out a block and thread dimension. If i1 is + * greater than the thread dimension, then all threads would pull the same value + * from T0. However, they would all be pulling different values from T1. In this + * case we have perfect reuse of the broadcast dimension T0 but potentially no + * reuse of the broadcast dimension of T1. "Potentially" because if i1 isn't too + * big it should be efficiently cached in L2. If i1 is big, then by the time we + * increment the i0 dimension the i1 dimension will be pushed out of cache. + * + * Instead what we do is we map this to a two dimensional problem. Instead of + * having the schedule that merges the two dimensions, we'll actually leave the + * dimensions separate and we'll take i0, split it to BIDy, TIDy, and take i1 + * and split it to BIDx and TIDx. Therefore we'll have a parallelization on T2 + * like [BIDy, TIDy | BIDx, TIDx], where | denotes the separation of the + * original i0 and i1. This helps because all threads in the TIDx dimension will + * reuse the same value in the i0 dimension (holding BIDy and TIDy constant), + * all the threads in the TIDy dimension (holding BIDx, and TIDx constant) will + * reuse the same value in the i1 dimension. This reuse of values reduces the + * number of redundant values pulled from T0 and T1. The same thing can be said + * for when incrementing BIDy, but since BIDy is strided on BIDx there's no + * effective increment of BIDy without incrementing BIDx. Since all threads are + * executed within a block we can effectively consider the block incrementing + * TIDx BDIMx times while holding TIDy constant and incrementing TIDy BDIMy + * times while holding TIDx constant. Since multiple BIDx's are running at the + * same time on the device we can consider a wave on the GPU of incrementing + * BIDx (wave number of times), while holding TIDy constant BDIMy * wave number + * of times. + * + * If instead we have a situation like: + * T0[i0, i1, b2] + * T1[i0, b1, i2] + * T2[i0, i1, i2] = T0 + T1 + * It makes sense that the break point would be in position 2, between i1 and + * i2. This is because when we map [i0, i1 | i2] to [BIDy, TIDy| BIDx, TIDx] + * BIDx, and TIDx will access the same elements of T0 on b2, and TIDy will + * likely access the same elements of T1 (as long as i1 > BDIMy). Even if i1 on + * the order of BDIMy we'll only access ~two unique elements per increment of + * BIDx or TIDx. This means we'll still reuse many of the same values and limit + * the amount we need to read duplicate values in T0 and T1. + * + * If instead we have: + * T0[i0, b1, i2] + * T1[b0, i1, i2] + * T2[i0, i1, i2] = T0 + T1 + * The analysis gets a bit more complicated. First if i2 is very large and i0 + * and i1 are relatively small it would make sense to have [i0, i1 | i2]. If b0 + * is very small it's unlikely beneficial to have [i0 | i1, i2] as there would + * be small reuse on b0, and potentially no reuse on b1. If i2 is very small it + * may be worthwhile to have [i0 | i1, i2]. If i1 and i2 are not small, and + * their product is relatively large (i.e. you can't fit T2[i, :, :] in L2) then + * it's unlikely we'll get any significant reuse across i0. + * + * What we should (but don't due to complexity) assume then, is that we will get + * strong reuse across TIDx and TIDy for dimensions that are on the inner + * portion of the 2D tile. + * + * For example if we have: + * T0[i0, b1, i2] + * T1[b0, b1, i2] + * T2[b0, i1, i2] + * T3[i0, i1, i2] = T0 + T1 + T2 + * We may want to break point at position 1 or position 2 (i.e. [i0 | i1, i2] or + * [i0, i1 | i2]). We can't immediately tell from the structure. + * + * If we choose [i0, i1 | i2] then we'll get: + * Strong reuse of T0 on TIDy (b1 dim) + * Perfect reuse across T1 on TIDy (b0 and b1) + * If BIDx is bound to the LHS of the tile we'll get: + * Maybe strong reuse of T0 on BIDx (b1 dim if it's large) + * Perfect reuse across T1 on BIDx + * Potentially no reuse on T2 if i1 is very large + * + * If we pick [i0 | i1, i2], then we'll get: + * We'll perfect reuse across TIDy on T1 and T2 on b0 + * Some reuse on T0 and T1 on b1 across BIDx if i2 is relatively small and BIDx + * is bound to the RHS of the 2D schedule Perfect reuse on T1 and T2 on b0 + * across BIDx if BIDx is bound to the LHS of the 2D schedule + * + * Materializing these benefits is dependent on the decisions the scheduler + * makes when parallelizing the problem. The heuristics logic at the moment is + * fairly simplistic where it assumes that there's only reuse across the break + * points for tensors that have no iteration domain on the entire side of the + * breakpoint. This is not optimal but for the time being it seems sufficient. + * We would ideally take into consideration the parallelization scheme and + * partial broadcasting on the lhs or rhs. + * + * An example of how this analysis is done is given the DAG: + * T0[i0, i1, b2] float + * T1[i0, b1, i2] half + * T2[i0, b1, i2] = cast(T1, float) + * T4[i0, i1, i2] float = T0 + T2 + * With values of 10, 100, 1000 as [i0, i1, i2] + * Our break point analysis for positions 0, 1, 2, 3 will be: + * + * 0: 10*10 * 100*10 * 1000*10 = 1e9 + * 1: 10*10 * 100*10 * 1000*10 = 1e9 + * 2: 10*10 * 100*10 * 1000*6 = 6e8 + * 3: 10*10 * 100*10 * 1000*10 = 1e9 + * + * Where for each computation the LHS of the * pairs is the number of elements + * in that dimension on the reference and the RHS of the * pairs is the + * broadcast multiple where any tensor that has all broadcasts on the rhs or lhs + * of the break point doesn't contribute to the broadcast multiple of the rhs or + * lhs. + * + * So we'll pick position 2 since we're confident we can get broadcast reuse on + * the rhs of tensor 0. As already mentioned this is a pretty big + * simplification/assumption and in reality it may be harder/easier to take + * advantage of broadcast on the inner or outer dimension. This is a reasonable + * way to make relative decisions on break points, however, this computation is + * ont doing an effective estimate of actual DRAM transfers which it should be + * modified to do so. + * + * For view schedules there can be some incoherent break points for example: + * T1[i0, i1*i2] = view(T0[i0, i1, i2]) + * would make the position 2 "incoherent". In otherwords we cannot replay + * through the view a schedule that tries to merge i0 and i1, without i2. So for + * positions that are incoherent we won't consider break point positions there. + * + * See FusionBroadcastViewMultiples_CUDA for what we expect with view handling. + * Shortly any dimensions that are inputs or outputs of view transformations are + * considered together, since it's hard to account for partial dimensions that + * are being broadcasted. So for view it's primarily an all or nothing situation + * when it comes to the 2D pointwise scheduler. + */ + class SchedulerRuntimeInfo; class HeuristicSummary; @@ -36,6 +171,9 @@ TORCH_CUDA_CU_API LaunchParams schedulePointwise( //! the pointwise scheduler. bool hasReferenceTensorView(Fusion* fusion); +// Return reference tensor view. +TensorView* getReferenceTensorView(Fusion* fusion); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h index 7947a27f48360..6cc4b1b8b93bd 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h @@ -20,10 +20,6 @@ class DomainMap { } virtual ~DomainMap() = default; - bool areExactMapped(IterDomain* id1, IterDomain* id2) const { - return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT); - } - const ComputeAtMap& getComputeAtMap() const { return ca_map_; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index b5940b1d4e1cb..3037f8469dad4 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -954,7 +954,7 @@ TORCH_CUDA_CU_API std::shared_ptr getReductionHeuristics( } // Try expanding vectorization to contig merged domains - vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains( + vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains( fusion, runtime_info, vectorizable_inputs_outputs, @@ -1010,6 +1010,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { TORCH_INTERNAL_ASSERT(reduction_tvs.size()); + // Registry assumes the reference tv is the first reduction_tv, if this + // changes registry needs to change. auto reduction_tv = reduction_tvs[0]; auto dim_analysis = scheduler_utils::canonicalDimReduction( diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp index 6bd4d4efba376..ae9ecd88bbdc3 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -336,14 +336,7 @@ void multiReductionInliner( scheduler_utils::getTrivialReductionMap(fusion); // Inline the schedule - InlinePropagator inline_propagator( - reference_tv, - -1, - ComputeAtMode::MostInlined, - {}, - mapped_to_trivial_reduction); - - MaxRootDomainInfoSpanningTree(reference_tv).traverse(&inline_propagator); + inlineMost(mapped_to_trivial_reduction); } namespace { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 8550bfc6bf0fa..5d5bc84ef3b4d 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -358,6 +358,45 @@ class SchedulerTopologyChecker { return true; } + + /* Returns if any non-trivial views are not before the reference. For example: + * t0 + * / \ + * view ref + * | + * t1 + * This could be important as transform propagation from a reference backwards + * through a view should always work, but transform propagation form a + * reference forward through a view could interfere with the view transforms. + */ + static bool hasViewNotBeforeRef( + Fusion* fusion, + std::vector reference_tvs) { + std::vector view_tvs; + auto view_ops = ir_utils::getViewOps(fusion); + for (auto view_op : view_ops) { + auto tv_outs = ir_utils::filterByType(view_op->outputs()); + for (auto entry : tv_outs) { + view_tvs.push_back(entry); + } + } + + if (view_tvs.empty()) { + return false; + } + + // Terrible complexity, may be worth improving, but is a compile time + // check. + for (auto ref_tv : reference_tvs) { + for (auto view_tv : view_tvs) { + if (!DependencyCheck::isDependencyOf(view_tv, ref_tv)) { + return true; + } + } + } + + return false; + } }; bool isConnectedFusionGraph(Fusion* fusion) { @@ -369,6 +408,11 @@ bool isConnectedFusionGraph(Fusion* fusion) { // A set of connected components on the fusion graph DisjointSets component_sets; + TORCH_INTERNAL_ASSERT( + !fusion->outputs().empty(), "Fusion without output is not supported"); + auto output0 = fusion->outputs()[0]; + component_sets.initializeSet(output0); + // Iterate through all used exprs for (auto expr : fusion->exprs()) { TORCH_INTERNAL_ASSERT( @@ -394,7 +438,6 @@ bool isConnectedFusionGraph(Fusion* fusion) { // If there is no independent compute flow // on this fusion graph, all outputs will be // equivalent/connected to the first output. - auto output0 = fusion->outputs()[0]; for (auto output : fusion->outputs()) { if (!component_sets.strictAreMapped(output0, output)) { return false; @@ -420,6 +463,24 @@ void SchedulerRuntimeInfo::initialize( auto fusion_inp = complete_fusion_->inputs()[inp_i]; auto data_ptr = tensor_arg_abstract->getPointer(); input_ptrs_[fusion_inp] = (size_t)data_ptr; + + // find and push discontiguous stride + auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType()); + input_discontig_strides_[fusion_inp] = {}; + auto dims = tensor_arg_abstract->getRank(); + auto expected_stride = 1; + for (auto dim = dims - 1; dim >= 0; dim--) { + auto size = tensor_arg_abstract->getSize(dim); + if (size <= 1) { + continue; + } + auto stride = tensor_arg_abstract->getStride(dim); + if (stride != expected_stride) { + input_discontig_strides_[fusion_inp].push_back(stride * dtype_size); + expected_stride = stride; + } + expected_stride *= size; + } } } @@ -486,6 +547,13 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) { } auto alignment_size = SchedulerRuntimeInfo::computeAlignmentSize(ptrOf(tv)); + auto strides_it = input_discontig_strides_.find(tv); + if (strides_it != input_discontig_strides_.end()) { + for (auto stride : strides_it->second) { + alignment_size = std::min( + alignment_size, SchedulerRuntimeInfo::computeAlignmentSize(stride)); + } + } alignment_map_[tv] = alignment_size; return alignment_size; } @@ -746,8 +814,7 @@ static bool checkPatternEquivalence( // being broadcasted to one size multiple times or different sizes. This is a // hard to optimize problem and likely indicates we shouldn't be fusing. bool hasNonUniqueBcast(Fusion* fusion) { - ConcretizedBroadcastDomains concretize_info; - concretize_info.build(fusion); + ConcretizedBroadcastDomains concretize_info(fusion); for (auto tv : ir_utils::allTvs(fusion)) { for (auto id : tv->getRootDomain()) { @@ -788,6 +855,119 @@ bool hasNonUniqueBcast(Fusion* fusion) { //! This function will be called when compiling a kernel. It should apply //! scheduling to the given fusion +//! NoOp scheduler represents the case where scheduler will +//! not do any scheduling operations and forward the un-scheduled +//! fusion directly to code generation and kernel compilation. +//! +//! Typical use case of this scheduler is to handle edge cases +//! such as where all tensors are size-1 or size-0. +class NoOpScheduler : public SchedulerEntry { + //! Provides a dummy heuristic type to ensure + //! unified interface on NoOp scheduler. + class NoOpHeuristic : public HeuristicParams { + public: + size_t hash() const override { + return 0; + } + std::shared_ptr clone() const override { + return std::make_shared(); + } + bool sameAs(const std::shared_ptr& other) const override { + auto other_casted = std::dynamic_pointer_cast(other); + return other_casted != nullptr; + }; + }; + + public: + explicit NoOpScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(ScheduleHeuristic::NoOp) { + params_ = std::make_shared(); + } + + //! Check if the no-op heuristics apply in given fusion + static bool canScheduleCompileTime(Fusion* fusion) { + // Check there're no non-trivial reduction ops. + for (auto reduction : + ir_utils::getReductionOps(fusion, true /* ignore_trivial */)) { + for (auto input : + ir_utils::filterByType(reduction->inputs())) { + auto root_dom = input->getRootDomain(); + auto all_nonzero = + std::none_of(root_dom.begin(), root_dom.end(), [](IterDomain* id) { + return id->extent()->isZeroInt(); + }); + if (all_nonzero) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::NoOp, + "reduction of non-zero elements is not supported"); + return false; + } + } + } + + // Check that all outputs are either broadcast or ignored reduction. + for (auto out_tv : ir_utils::filterByType(fusion->outputs())) { + auto non_zero_candidate_dimension = TensorDomain::noReductions( + TensorDomain::noBroadcasts(out_tv->domain()->domain())); + + // non_zero_candidate_dimension is empty would mean this out tv has only + // broadcast and trivial reduction axes, and this out tv would not + // require scheduling ops. + // If any of the dimensions in non_zero_candidate_dimension is compile + // time + // constant zero, this out tv also does not require any scheduling + // operation as it is essentially a scalar. + // TODO: + // There seems to be a runtime component to it + // too, i.e. if the runtime sizes are zero, then we should + // handle it through null scheduler. + if (!non_zero_candidate_dimension.empty() && + std::none_of( + non_zero_candidate_dimension.begin(), + non_zero_candidate_dimension.end(), + [](IterDomain* id) { return id->extent()->isZeroInt(); })) { + // We have found a out_tv with a dimension that NoOp scheduler couldn't + // handle and therefore reject this fusion. + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::NoOp, "output has a concrete dimension"); + return false; + } + } + + // We have verified that all iterdomains on all output tv's are trivial + // reductions, + // broadcasts or zero-sized. Therefore accepting this fusion for NoOp + // scheduling. + return true; + } + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + // TODO: + // Pipe through dynamic zero checks. + return true; + } + + void schedule(Fusion* fusion) override { + // Schedule is no-op. + return; + } + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + // Heuristics is no-op. + return; + } +}; + class ReductionScheduler : public SchedulerEntry { public: explicit ReductionScheduler( @@ -838,6 +1018,17 @@ class ReductionScheduler : public SchedulerEntry { return false; } + // Persistent scheduler simply uses reduction_tvs[0] as the reference, if + // that changes, this needs to be changed. Second check here may be overly + // conservative. + if (SchedulerTopologyChecker::hasViewNotBeforeRef( + fusion, {reduction_tvs[0]}) || + !scheduler_utils::allMatchingViews(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Reduction, "Unsupported view fusion."); + return false; + } + // Make sure reduction axes are consistent through the fusion auto reduction_ops = ir_utils::getReductionOps(fusion, false /* ignore_trivial */); @@ -937,6 +1128,84 @@ class ReductionScheduler : public SchedulerEntry { } }; +class TransposeScheduler : public SchedulerEntry { + public: + explicit TransposeScheduler( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) + : SchedulerEntry(ScheduleHeuristic::Transpose) { + computeHeuristics(fusion, runtime_info, data_cache); + } + + static bool canScheduleCompileTime(Fusion* fusion) { + // Temporarily disallow view in transpose scheduler + // TODO Add more testing before enabling + auto view_tvs = scheduler_utils::getViewTVs(fusion); + if (view_tvs.size() > 0) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, "No support for view op"); + return false; + } + + if (!hasAtLeastTwoValidGroups(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, + "cannot find two mismatching inner most dimensions"); + return false; + } + + // TODO: add support for trivial reduction + auto reduction_ops = + ir_utils::getReductionOps(fusion, false /* ignore_trivial */); + + if (!reduction_ops.empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, "no support for reduction ops"); + return false; + } + + if (hasNonUniqueBcast(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, + "Broadcasting dimension might be broadcasting to multiple sizes."); + return false; + } + + return true; + } + + static bool canScheduleRunTime( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + FUSER_PERF_SCOPE("TransposeScheduler::canScheduleRunTime"); + + auto reason = + getTransposeRuntimeRejectReason(fusion, data_cache, runtime_info); + if (!reason.empty()) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Transpose, reason); + return false; + } + return true; + } + + void schedule(Fusion* fusion) override { + FUSER_PERF_SCOPE("Schedule Transpose Fusion"); + scheduleTranspose(fusion, transposeParams()); + } + + private: + void computeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache = nullptr) { + params_ = getTransposeHeuristics(fusion, runtime_info, data_cache); + TORCH_INTERNAL_ASSERT(params_ != nullptr); + } +}; + class PointWiseScheduler : public SchedulerEntry { public: explicit PointWiseScheduler( @@ -957,6 +1226,14 @@ class PointWiseScheduler : public SchedulerEntry { return false; } + if (!scheduler_utils::allMatchingViews(fusion) && + SchedulerTopologyChecker::hasViewNotBeforeRef( + fusion, {getReferenceTensorView(fusion)})) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::PointWise, "Unsupported view fusion."); + return false; + } + auto reduction_ops = ir_utils::getReductionOps(fusion, true /* ignore_trivial */); @@ -980,6 +1257,18 @@ class PointWiseScheduler : public SchedulerEntry { Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache = nullptr) { + auto can_schedule_transpose_entry = + HeuristicSummaryEntry( + data_cache, [fusion]() { + return std::make_unique( + TransposeScheduler::canScheduleCompileTime(fusion)); + }); + if (can_schedule_transpose_entry.get()) { + auto reason = + getTransposeRuntimeRejectReason(fusion, data_cache, runtime_info); + return !reason.empty(); + } + return true; } @@ -1047,6 +1336,16 @@ class PersistentKernelScheduler : public SchedulerEntry { return false; } + // Persistent scheduler simply uses reduction_tvs[0] as the reference, if + // that changes, this needs to be changed. Second check here may be overly + // conservative. + if (SchedulerTopologyChecker::hasViewNotBeforeRef( + fusion, {reduction_tvs[0]}) || + !scheduler_utils::allMatchingViews(fusion)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, "Unsupported view fusion."); + } + if (findTransposeOps(fusion).size() > 0) { // Use pointwise logic scheduler_debug_utils::canScheduleRejectReason( @@ -1216,84 +1515,10 @@ class PersistentKernelScheduler : public SchedulerEntry { } }; -class TransposeScheduler : public SchedulerEntry { - public: - explicit TransposeScheduler( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr) - : SchedulerEntry(ScheduleHeuristic::Transpose) { - computeHeuristics(fusion, runtime_info, data_cache); - } - - static bool canScheduleCompileTime(Fusion* fusion) { - if (!isOptionEnabled(EnableOption::TransposeScheduler)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Transpose, "not enabled"); - return false; - } - - // Temporarily disallow view in transpose scheduler - // TODO Add more testing before enabling - auto view_tvs = scheduler_utils::getViewTVs(fusion); - if (view_tvs.size() > 0) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Transpose, "No support for view op"); - return false; - } - - if (!hasAtLeastTwoValidGroups(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Transpose, - "cannot find two mismatching inner most dimensions"); - return false; - } - - // TODO: add support for trivial reduction - auto reduction_ops = - ir_utils::getReductionOps(fusion, false /* ignore_trivial */); - - if (!reduction_ops.empty()) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Transpose, "no support for reduction ops"); - return false; - } - - if (hasNonUniqueBcast(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Transpose, - "Broadcasting dimension might be broadcasting to multiple sizes."); - return false; - } - - return true; - } - - static bool canScheduleRunTime( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr) { - return true; - } - - void schedule(Fusion* fusion) override { - FUSER_PERF_SCOPE("Schedule Transpose Fusion"); - scheduleTranspose(fusion, transposeParams()); - } - - private: - void computeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache = nullptr) { - params_ = getTransposeHeuristics(fusion, runtime_info, data_cache); - TORCH_INTERNAL_ASSERT(params_ != nullptr); - } -}; - // Schedule Table const std::vector& all_heuristics() { static const std::vector hlist = { + ScheduleHeuristic::NoOp, ScheduleHeuristic::Reduction, ScheduleHeuristic::Transpose, ScheduleHeuristic::PointWise, @@ -1316,6 +1541,9 @@ bool checkCanSchedule( if (!isConnectedFusionGraph(fusion)) { return false; } + if (IterDomainGraph(fusion, /*allow_self_mapping=*/true).hasSelfMapping()) { + return false; + } if (!SchedulerType::canScheduleCompileTime(fusion)) { return false; } @@ -1333,6 +1561,8 @@ bool SchedulerEntry::canSchedule( SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { switch (sh) { + case ScheduleHeuristic::NoOp: + return checkCanSchedule(fusion, runtime_info, data_cache); case ScheduleHeuristic::PointWise: return checkCanSchedule( fusion, runtime_info, data_cache); @@ -1359,6 +1589,10 @@ std::unique_ptr SchedulerEntry::makeEntry( HeuristicSummary* data_cache) { std::unique_ptr scheduler_entry = nullptr; switch (sh) { + case ScheduleHeuristic::NoOp: + scheduler_entry = + std::make_unique(fusion, runtime_info, data_cache); + break; case ScheduleHeuristic::PointWise: scheduler_entry = std::make_unique( fusion, runtime_info, data_cache); @@ -1402,6 +1636,8 @@ size_t SchedulerEntryHash::operator()(const SchedulerEntry& se) const { std::string toString(ScheduleHeuristic sh) { switch (sh) { + case ScheduleHeuristic::NoOp: + return "no-op"; case ScheduleHeuristic::PointWise: return "pointwise"; case ScheduleHeuristic::Reduction: @@ -1450,6 +1686,9 @@ HeuristicSummary::HeuristicSummary( : heuristic_(heuristic) { recording_ = true; switch (heuristic) { + case ScheduleHeuristic::NoOp: + NoOpScheduler::canScheduleRunTime(fusion, runtime_info, this); + break; case ScheduleHeuristic::PointWise: getPointwiseHeuristics(fusion, runtime_info, this); PointWiseScheduler::canScheduleRunTime(fusion, runtime_info, this); @@ -1475,14 +1714,39 @@ HeuristicSummary::HeuristicSummary( void HeuristicSummary::validate() const { switch (heuristic_) { + case ScheduleHeuristic::NoOp: { + // TODO: need to cache the dynamically zero inputs? + break; + } + case ScheduleHeuristic::Transpose: case ScheduleHeuristic::PointWise: { - TORCH_INTERNAL_ASSERT(entry_type_map_.count(EntryType::DOMAIN_MAP)); + if (heuristic_ == ScheduleHeuristic::PointWise) { + TORCH_INTERNAL_ASSERT(entry_type_map_.count(EntryType::DOMAIN_MAP)); + TORCH_INTERNAL_ASSERT( + entry_type_map_.count(EntryType::REFERENCE_TENSORS)); + TORCH_INTERNAL_ASSERT( + entry_type_map_.count(EntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS)); + TORCH_INTERNAL_ASSERT( + entry_type_map_.count(EntryType::BROADCAST_BYTE_MULTIPLES)); + TORCH_INTERNAL_ASSERT( + entry_type_map_.count(EntryType::CAN_SCHEDULE_TRANSPOSE)); + auto can_schedule_transpose = + entry_type_map_.at(EntryType::CAN_SCHEDULE_TRANSPOSE) + ->as>() + ->get(); + if (!*can_schedule_transpose) { + break; + } + } TORCH_INTERNAL_ASSERT( - entry_type_map_.count(EntryType::REFERENCE_TENSORS)); + entry_type_map_.count(EntryType::TRANSPOSE_DOMAIN_MAP)); + TORCH_INTERNAL_ASSERT(entry_type_map_.count( + EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS)); TORCH_INTERNAL_ASSERT( - entry_type_map_.count(EntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS)); + entry_type_map_.count(EntryType::REFERENCE_TENSORS_FOR_GROUPS)); TORCH_INTERNAL_ASSERT( - entry_type_map_.count(EntryType::BROADCAST_BYTE_MULTIPLES)); + entry_type_map_.count(EntryType::INNER_MOST_DIMS_INFO)); break; } case ScheduleHeuristic::Reduction: { @@ -1512,11 +1776,6 @@ void HeuristicSummary::validate() const { entry_type_map_.count(EntryType::SCOPE_PERSISTENT_FACTOR_INFO)); break; } - case ScheduleHeuristic::Transpose: { - TORCH_INTERNAL_ASSERT(entry_type_map_.count( - EntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS)); - break; - } default: TORCH_INTERNAL_ASSERT(false, "unknown heuristic"); } @@ -1553,7 +1812,10 @@ HeuristicSummaryEntry::HeuristicSummaryEntry( // Template instantiation for pre-defined cache entries template class HeuristicSummaryEntry; +template class HeuristicSummaryEntry; template class HeuristicSummaryEntry; +template class HeuristicSummaryEntry< + HeuristicCompileTime::ReferenceTensorsForGroups>; template class HeuristicSummaryEntry< HeuristicCompileTime::VectorizableInputsAndOutputs>; template class HeuristicSummaryEntry< @@ -1566,6 +1828,9 @@ template class HeuristicSummaryEntry< template class HeuristicSummaryEntry< HeuristicCompileTime::ScopePersistentFactorInfo>; template class HeuristicSummaryEntry; +template class HeuristicSummaryEntry; +template class HeuristicSummaryEntry< + HeuristicCompileTime::CanScheduleTranspose>; } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.h b/torch/csrc/jit/codegen/cuda/scheduler/registry.h index 7ed8474935c01..8b34094476349 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.h @@ -27,6 +27,7 @@ class ExpressionEvaluator; //! segmenter and schedulers. //! It is important that input id encoding should be up to date with any change //! of this class to avoid launching compiled kernels with illegal inputs. + class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { public: // Max vector size we will consider, in bytes, @@ -112,6 +113,9 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable { // TODO: Support output tensor pointers std::unordered_map input_ptrs_; + // Copy of aten input tensor strides (in bytes) + std::unordered_map> input_discontig_strides_; + // Cache for getAlignmentSize std::unordered_map alignment_map_; // Cache for getMaxVectorizableWidth diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp index 1bdd1d34a0a9a..b7e85cbc1c5e7 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -24,8 +24,6 @@ namespace cuda { namespace { -constexpr int64_t kMaxTileSize = 32; - // DomainMap uses the ComputeAtMap to find a reference TensorView // that maps to all iterDomains in the fusion. class DomainMap : public pointwise_utils::DomainMap { @@ -47,6 +45,20 @@ class DomainMap : public pointwise_utils::DomainMap { return result; } + IterDomain* getMappedRootDimIn(TensorView* tv, IterDomain* root_dim) const { + // Find the root id mapped to `root_dim` + const auto& root_dom = tv->getRootDomain(); + IterDomain* mapped_id = nullptr; + for (auto i : c10::irange(root_dom.size())) { + if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped( + root_dom[i], root_dim)) { + mapped_id = root_dom[i]; + break; + } + } + return mapped_id; + } + static bool hasAtLeastTwoValidGroups(Fusion* fusion) { FusionGuard fg(fusion); DomainMap domain_map(fusion); @@ -54,19 +66,51 @@ class DomainMap : public pointwise_utils::DomainMap { if (grouped_inputs_outputs.size() < 2) { return false; } - return domain_map.findReferenceFor(grouped_inputs_outputs[0]) != nullptr && - domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr; + auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]); + auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]); + if (ref1 == nullptr || ref2 == nullptr) { + return false; + } + // reference 1 is the global reference, so it must have dim mapped the + // innermost dim of both groups + auto innermost2 = scheduler_utils::innerMostRootDim(ref2); + return domain_map.getMappedRootDimIn(ref1, innermost2) != nullptr; } - int getPosMappedTo(TensorView* tv, IterDomain* id) const { + int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const { + auto mapped_id = getMappedRootDimIn(tv, root_dim); + TORCH_INTERNAL_ASSERT( + mapped_id != nullptr, + "Can not find ID mapped to ", + root_dim, + " in tensor ", + tv); + // Project the root id to leaf id + while (!mapped_id->uses().empty()) { + TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1); + auto expr = mapped_id->uses()[0]; + if (expr->isA()) { + mapped_id = expr->as()->inner(); + } else { + auto merge = expr->as(); + TORCH_INTERNAL_ASSERT( + mapped_id == merge->inner(), + "Can not find ID mapped to ", + root_dim, + " in tensor ", + tv); + mapped_id = merge->out(); + } + } + // Find the position of the leaf id const auto& dom = tv->domain()->domain(); for (auto i : c10::irange(dom.size())) { - if (areExactMapped(id, tv->axis(i))) { + if (dom[i] == mapped_id) { return i; } } TORCH_INTERNAL_ASSERT( - false, "Can not find ID mapped to ", id, " in tensor ", tv); + false, "Can not find ID mapped to ", root_dim, " in tensor ", tv); } // Group inputs and outputs of a fusion by its inner most domain. For example @@ -128,6 +172,12 @@ class DomainMap : public pointwise_utils::DomainMap { // Then we still want to T1 and T2 to be grouped together. auto group = scheduler_utils::getInputsOutputsWithInnerDim(tv, true, false); + if (group.empty()) { + // In case that the inner most dim of tv is not found (for example, tv + // is a fusion input with only reductions), we just return a null + // result which will tell the scheduler to reject the fusion + return {}; + } for (auto member_tv : group) { if (grouped.count(member_tv) == 0) { grouped.emplace(member_tv); @@ -178,12 +228,26 @@ class DomainMap : public pointwise_utils::DomainMap { // T0[I0*I1o*I5*I6{1024*1024/4*8}, I1i*I2*I3*I4{32}] void maybeBuildVirtualInnerDims( TransposeParams& params, + int64_t device_multiprocessor_count, + int64_t n_elems, const std::vector& shape_in_ref1, int64_t inner_most1, int64_t inner_most2) { int64_t merged_size1 = shape_in_ref1[inner_most1]; int64_t merged_size2 = shape_in_ref1[inner_most2]; + int64_t actual_tile_size1 = + std::min(merged_size1, params.tile_size1); + int64_t actual_tile_size2 = + std::min(merged_size2, params.tile_size2); + int64_t wave_elements = + device_multiprocessor_count * actual_tile_size1 * actual_tile_size2; + + if (wave_elements >= n_elems) { + // if one full wave can handle all elements, don't create virtual inner dims + return; + } + // merge inner_most1 and inner_most2 left until we are done or we can no // longer do so int64_t dim = inner_most1 - 1; @@ -240,22 +304,49 @@ void maybeBuildVirtualInnerDims( // both virtual innermost dim. // 2. The satisfied one did not merge in anything. For example, // T0[I0{1024*1024}, I1{2}] + // If this is the case, this means that we need to split the large + // inner-most dimension to satisfy the small innermost dimension int64_t large_dim; int64_t split_factor; + bool split_inner_most; if (merged_size1 < params.tile_size1) { if (params.dims_merged_with_2.empty()) { +#if SUPPORT_SPLITTING_INNERMOST_DIM + // https://github.com/csarofeen/pytorch/issues/1964 // case 2 + split_inner_most = true; + large_dim = inner_most2; + split_factor = params.tile_size2; +#else + // disabled due to indexing error return; +#endif + } else { + // case 1 + split_inner_most = false; + large_dim = params.dims_merged_with_2.back(); + auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim]; + split_factor = ceilDiv(params.tile_size2, prev_merged_size2); } - large_dim = params.dims_merged_with_2.back(); - split_factor = ceilDiv(params.tile_size1, merged_size1); } else { if (params.dims_merged_with_1.empty()) { +#if SUPPORT_SPLITTING_INNERMOST_DIM + // https://github.com/csarofeen/pytorch/issues/1964 // case 2 + split_inner_most = true; + large_dim = inner_most1; + split_factor = params.tile_size1; +#else + // disabled due to indexing error return; +#endif + } else { + // case 1 + split_inner_most = false; + large_dim = params.dims_merged_with_1.back(); + auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim]; + split_factor = ceilDiv(params.tile_size1, prev_merged_size1); } - large_dim = params.dims_merged_with_1.back(); - split_factor = ceilDiv(params.tile_size2, merged_size2); } params.split_before_tiling.push_back({large_dim, split_factor}); // adjust all dims to after-split @@ -271,61 +362,54 @@ void maybeBuildVirtualInnerDims( } // Give the split-out dim to the unsatisfied one, so that both are satisfied. if (merged_size1 < params.tile_size1) { - params.dims_merged_with_2.pop_back(); - params.dims_merged_with_2.push_back(large_dim + 1); + if (!split_inner_most) { + params.dims_merged_with_2.pop_back(); + params.dims_merged_with_2.push_back(large_dim + 1); + } params.dims_merged_with_1.push_back(large_dim); } else { - params.dims_merged_with_1.pop_back(); - params.dims_merged_with_1.push_back(large_dim + 1); + if (!split_inner_most) { + params.dims_merged_with_1.pop_back(); + params.dims_merged_with_1.push_back(large_dim + 1); + } params.dims_merged_with_2.push_back(large_dim); } } -} // namespace - -bool hasAtLeastTwoValidGroups(Fusion* fusion) { - return DomainMap::hasAtLeastTwoValidGroups(fusion); -} - -std::shared_ptr getTransposeHeuristics( - Fusion* fusion, - const at::ArrayRef& runtime_inputs, - HeuristicSummary* data_cache) { - SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); - return getTransposeHeuristics(fusion, runtime_info, data_cache); -} - -std::shared_ptr getTransposeHeuristics( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicSummary* data_cache) { - FUSER_PERF_SCOPE("getTransposeHeuristics"); - - FusionGuard fg(fusion); - - // Incase any buffer is of type DataType::Index - DataType index_type = indexModeToDtype(runtime_info.getIndexMode()); - +HeuristicSummaryEntry getDomainMap( + HeuristicSummary* data_cache, + Fusion* fusion) { auto domain_map_entry = - HeuristicSummaryEntry( + HeuristicSummaryEntry( data_cache, [fusion]() { return std::make_unique(fusion); }); - const auto& domain_map = dynamic_cast(domain_map_entry.get()); + return domain_map_entry; +} +HeuristicSummaryEntry +getInputsOutputsGroups(HeuristicSummary* data_cache, DomainMap& domain_map) { auto grouped_inputs_outputs_entry = HeuristicSummaryEntry( data_cache, [&domain_map]() { return std::make_unique>>( domain_map.groupInputsOutputsByInnerDim()); }); - auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); + auto& grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); TORCH_INTERNAL_ASSERT( grouped_inputs_outputs.size() >= 2, "Can not find mismatched inner most dim, should use pointwise scheduler."); + return grouped_inputs_outputs_entry; +} + +HeuristicSummaryEntry +getReferenceTensors( + HeuristicSummary* data_cache, + DomainMap& domain_map, + std::vector>& grouped_inputs_outputs) { auto reference_tensors_entry = - HeuristicSummaryEntry( + HeuristicSummaryEntry( data_cache, [&domain_map, &grouped_inputs_outputs]() { std::vector data{ domain_map.findReferenceFor(grouped_inputs_outputs[0]), @@ -340,13 +424,17 @@ std::shared_ptr getTransposeHeuristics( reference1 != nullptr, "Unable to find reference tensor for group 1"); TORCH_INTERNAL_ASSERT( reference2 != nullptr, "Unable to find reference tensor for group 2"); + return reference_tensors_entry; +} - const int64_t device_multiprocessor_count = - (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - - auto ref_root = reference1->getMaybeRFactorDomain(); - std::vector shape_in_ref1; - shape_in_ref1.reserve(reference1->nDims()); +std::pair, int64_t> getShapeInReference( + HeuristicSummary* data_cache, + SchedulerRuntimeInfo& runtime_info, + TensorView* reference, + DomainMap& domain_map) { + auto ref_root = reference->getMaybeRFactorDomain(); + std::vector shape_in_ref; + shape_in_ref.reserve(reference->nDims()); int64_t n_elems = 1; for (size_t ref_i = 0; ref_i < ref_root.size(); ref_i++) { auto id = ref_root[ref_i]; @@ -360,36 +448,175 @@ std::shared_ptr getTransposeHeuristics( ref_root[ref_i]->extent()->toInlineString()); int64_t size = inferred_val->as(); n_elems *= size; - shape_in_ref1.push_back(size); + shape_in_ref.push_back(size); } + return {shape_in_ref, n_elems}; +} - auto params = std::make_shared("Transpose heuristics"); +HeuristicSummaryEntry +getInnerMostDimInfoInReference( + HeuristicSummary* data_cache, + const std::vector& group_references, + TensorView* global_reference, + DomainMap& domain_map) { + auto innermost_info_entry = + HeuristicSummaryEntry( + data_cache, [&]() { + std::vector data; + data.reserve(group_references.size()); + for (auto ref_tv : group_references) { + auto inner_most_id = scheduler_utils::innerMostRootDim(ref_tv); + auto inner_most_pos_in_global_ref = + domain_map.getInnerLeafDim(global_reference, inner_most_id); + data.emplace_back(inner_most_pos_in_global_ref); + } + return std::make_unique>(std::move(data)); + }); + return innermost_info_entry; +} - // If the problem size is small use small tile sizes. - if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) { - params->tile_size1 = 8; - params->tile_size2 = 8; - // TODO: I was trying the following but I got silent wrong result - // params->tile_size1 = 8; - // params->tile_size2 = 4; - // This should not happen, because the correctness should be irrevalent to - // schedulers. We don't have to use tile size (8, 4), but we need to fix our - // bug in codegen. +} // namespace + +std::string getTransposeRuntimeRejectReason( + Fusion* fusion, + HeuristicSummary* data_cache, + SchedulerRuntimeInfo& runtime_info) { + auto domain_map_entry = getDomainMap(data_cache, fusion); + auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto grouped_inputs_outputs_entry = + getInputsOutputsGroups(data_cache, domain_map); + auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); + auto reference_tensors_entry = + getReferenceTensors(data_cache, domain_map, grouped_inputs_outputs); + auto reference_tensors = reference_tensors_entry.get(); + TensorView* reference1 = reference_tensors[0]; + + auto pair = + getShapeInReference(data_cache, runtime_info, reference1, domain_map); + auto& shape_in_ref1 = pair.first; + auto& n_elems = pair.second; + + auto innermost_info_entry = getInnerMostDimInfoInReference( + data_cache, reference_tensors, reference1, domain_map); + auto innermost_info = innermost_info_entry.get(); + + constexpr size_t default_tile_elements = + TransposeParams::getDefaultTileSize() * + TransposeParams::getDefaultTileSize(); + + // don't schedule with transpose scheduler if less than a full wave + const int64_t device_multiprocessor_count = + (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + auto elements_per_wave = device_multiprocessor_count * default_tile_elements; + if (elements_per_wave > n_elems) { + return "Transpose scheduler does not perform well on small problem sizes."; } - // Expand inner-most dims to virtual inner-most dims so that the inner-most - // dims has at least tile_size elements - auto inner_most_id1 = scheduler_utils::innerMostRootDim(reference1); - auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2); + auto inner_most_pos1_in_ref1 = innermost_info[0]; + auto inner_most_pos2_in_ref1 = innermost_info[1]; + + auto inner_size1 = shape_in_ref1[inner_most_pos1_in_ref1]; + auto inner_size2 = shape_in_ref1[inner_most_pos2_in_ref1]; + + // For cases like + // transpose(T0[1000000000, 2, 2], 1, 2) + // the pointwise scheduler should provide better performance, because it + // provides coalesced memory access + if (inner_size1 * inner_size2 < default_tile_elements) { + auto inner_elements = inner_size1 * inner_size2; + for (int64_t i = inner_most_pos2_in_ref1 + 1; i < inner_most_pos1_in_ref1; + i++) { + inner_elements *= shape_in_ref1[i]; + } + // note that the algorithm here is only an approximation because it only + // checks reference1. In principle, we need to check all inputs and outputs + // to get an accurate result, but that is too much work. I think checking + // only reference 1 is fine for now. Below is an example where the + // approximation here will not work: + // T0[10000000, 2, 3] (reference 1) + // T1[2, 10000000, 3] input/output + // T2[2, 10000000, 3] input/output + // T3[2, 10000000, 3] input/output + // T4[3, 10000000, 2] input/output + // T5[3, 10000000, 2] input/output + if (inner_elements < default_tile_elements) { + return "Inner transpose of small dimensions should be scheduled by the " + "pointwise scheduler because it provides better memory coalescing"; + } + } + +#if !SUPPORT_SPLITTING_INNERMOST_DIM + if (n_elems / inner_size1 < TransposeParams::getDefaultTileSize() || + n_elems / inner_size2 < TransposeParams::getDefaultTileSize()) { + return "Splitting of inner most dim for the creation of virtual inner most dim " + "is disabled due to indexing bug, skipping this case at runtime for now" + "See: https://github.com/csarofeen/pytorch/issues/1964"; + } +#endif + + return ""; +} + +bool hasAtLeastTwoValidGroups(Fusion* fusion) { + return DomainMap::hasAtLeastTwoValidGroups(fusion); +} + +std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + const at::ArrayRef& runtime_inputs, + HeuristicSummary* data_cache) { + SchedulerRuntimeInfo runtime_info(fusion, runtime_inputs, true); + return getTransposeHeuristics(fusion, runtime_info, data_cache); +} + +std::shared_ptr getTransposeHeuristics( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicSummary* data_cache) { + FUSER_PERF_SCOPE("getTransposeHeuristics"); + + FusionGuard fg(fusion); + + // Incase any buffer is of type DataType::Index + DataType index_type = indexModeToDtype(runtime_info.getIndexMode()); + + auto domain_map_entry = getDomainMap(data_cache, fusion); + auto& domain_map = dynamic_cast(domain_map_entry.get()); + auto grouped_inputs_outputs_entry = + getInputsOutputsGroups(data_cache, domain_map); + auto grouped_inputs_outputs = grouped_inputs_outputs_entry.get(); + auto reference_tensors_entry = + getReferenceTensors(data_cache, domain_map, grouped_inputs_outputs); + auto reference_tensors = reference_tensors_entry.get(); + TensorView* reference1 = reference_tensors[0]; + TensorView* reference2 = reference_tensors[1]; + auto pair = + getShapeInReference(data_cache, runtime_info, reference1, domain_map); + auto& shape_in_ref1 = pair.first; + auto& n_elems = pair.second; + + const int64_t device_multiprocessor_count = + (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - auto inner_most_pos1_in_ref1 = - domain_map.getPosMappedTo(reference1, inner_most_id1); - auto inner_most_pos2_in_ref1 = - domain_map.getPosMappedTo(reference1, inner_most_id2); + auto innermost_info_entry = getInnerMostDimInfoInReference( + data_cache, reference_tensors, reference1, domain_map); + auto innermost_info = innermost_info_entry.get(); + auto inner_most_pos1_in_ref1 = innermost_info[0]; + auto inner_most_pos2_in_ref1 = innermost_info[1]; + + auto params = std::make_shared("Transpose heuristics"); + + // Expand inner-most dims to virtual inner-most dims so that the inner-most + // dims has at least tile_size elements // See note [Supporting small transpose dimensions] maybeBuildVirtualInnerDims( - *params, shape_in_ref1, inner_most_pos1_in_ref1, inner_most_pos2_in_ref1); + *params, + device_multiprocessor_count, + n_elems, + shape_in_ref1, + inner_most_pos1_in_ref1, + inner_most_pos2_in_ref1); // Note [vectorization and unroll of input and output] // @@ -482,13 +709,20 @@ std::shared_ptr getTransposeHeuristics( std::cerr << "\n===== Transpose Stats ========\n" << "inputs: " << ir_utils::toString(fusion->inputs()) << "\n" << "outputs: " << ir_utils::toString(fusion->outputs()) << "\n" + << "shape: " << shape_in_ref1 << "\n" << "num_elems: " << n_elems << "\n" << "n_input_tensors: " << n_input_tensors << "\n" << "max_input_dtype_size: " << max_input_dtype_size << "\n" << "group 1: " << ir_utils::toString(grouped_inputs_outputs[0]) << "\n" + << "reference1: " << reference1 << "\n" + << "inner_most_id1 position: " << inner_most_pos1_in_ref1 + << " (in reference 1)\n" << "group 2: " << ir_utils::toString(grouped_inputs_outputs[1]) - << std::endl; + << "\n" + << "reference2: " << reference2 << "\n" + << "inner_most_id2 position: " << inner_most_pos2_in_ref1 + << " (in reference 1)" << std::endl; if (!params->split_before_tiling.empty() || !params->dims_merged_with_1.empty() || !params->dims_merged_with_2.empty()) { @@ -565,17 +799,19 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim(); TORCH_INTERNAL_ASSERT(grouped_inputs_outputs.size() >= 2); - // We need something similar to `cacheFork` for input tensors in group 2. We - // need this because we will want to propagate to the entire DAG except group - // 2 and its cached inputs, so we need to make sure the DAG is still connected - // if we remove group and its cached inputs. For example - // t0 - // | - // cache - // | | - // t1 t2 - // if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will - // make it disconnected. + /* + * We need something similar to `cacheFork` for input tensors in group 2. We + * need this because we will want to propagate to the entire DAG except group + * 2 and its cached inputs, so we need to make sure the DAG is still connected + * if we remove group and its cached inputs. For example + * t0 + * | + * cache + * / \ + * t1 t2 + * if groups = {{t1, t2}, {t0}}, then removing {t0, cache} from the DAG will + * make it disconnected. + */ std::unordered_set group2_and_cached_inputs( grouped_inputs_outputs[1].begin(), grouped_inputs_outputs[1].end()); for (auto tv : grouped_inputs_outputs[1]) { @@ -643,9 +879,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { // merge with inner most dims to get virtual inner most dims size_t inner_most_pos1_in_ref1 = - domain_map.getPosMappedTo(reference1, inner_most_id1); + domain_map.getInnerLeafDim(reference1, inner_most_id1); size_t inner_most_pos2_in_ref1 = - domain_map.getPosMappedTo(reference1, inner_most_id2); + domain_map.getInnerLeafDim(reference1, inner_most_id2); if (merged1.has_value()) { if (inner_most_pos1_in_ref1 < *merged1) { reference1->reorder( @@ -895,9 +1131,7 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) { } // Inline - InlinePropagator inline_propagator( - reference1, -1, ComputeAtMode::MostInlined); - entire_dag.traverse(&inline_propagator); + inlineMost(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.h b/torch/csrc/jit/codegen/cuda/scheduler/transpose.h index 0cf6920ea058b..c1a4ab6efb6ae 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.h @@ -5,6 +5,8 @@ #include #include +#define SUPPORT_SPLITTING_INNERMOST_DIM 0 + namespace torch { namespace jit { namespace fuser { @@ -100,6 +102,13 @@ TORCH_CUDA_CU_API LaunchParams scheduleTranspose( //! groups, each with a fully broadcasted reference tensor. TORCH_CUDA_CU_API bool hasAtLeastTwoValidGroups(Fusion* fusion); +// If can schedule at runtime, returns empty string, otherwise returns the +// reason why we should not schedule at runtime. +TORCH_CUDA_CU_API std::string getTransposeRuntimeRejectReason( + Fusion* fusion, + HeuristicSummary* data_cache, + SchedulerRuntimeInfo& runtime_info); + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h b/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h index d672b6dc965bd..5e56278a7f16b 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose_heuristic.h @@ -21,6 +21,10 @@ class TransposeParams : public HeuristicParams { return 128; } + static constexpr size_t getDefaultTileSize() { + return 32; + } + // See note [Supporting small transpose dimensions], all dims are positions in // reference1 std::vector> split_before_tiling = {}; @@ -37,10 +41,10 @@ class TransposeParams : public HeuristicParams { // https://github.com/csarofeen/pytorch/pull/1854#discussion_r928143729 // Tile size for the inner most dim of tensors in the first group - size_t tile_size1 = 32; + size_t tile_size1 = getDefaultTileSize(); // Tile size for the inner most dim of tensors in the second group - size_t tile_size2 = 32; + size_t tile_size2 = getDefaultTileSize(); using HeuristicParams::HeuristicParams; @@ -65,8 +69,7 @@ class TransposeParams : public HeuristicParams { std::stringstream ss; ss << "\n===== Transpose Parameters ========\n" << (tag == "" ? "" : "Tag: ") << tag << " Transpose Characteristics:\n" - << " Gridx: " << lparams.gdimx() << " BlckX: " << lparams.bdimx() - << "\n"; + << " BlckX: " << lparams.bdimx() << "\n"; ss << " input tile size: " << tile_size1 << "\n"; ss << " output tile size: " << tile_size2 << "\n"; int elements_per_tile = tile_size1 * tile_size2; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 6c0c8087270e9..4ba6b241e455c 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -298,21 +300,6 @@ void parallelizeAllLike( } } -void computeAtInputs(TensorView* consumer, int pos, ComputeAtMode mode) { - for (auto inp_tv : ir_utils::inputTvsOf(consumer)) { - inp_tv->computeAt(consumer, pos, mode); - } -} - -void computeWithOutputs(TensorView* producer, int pos, ComputeAtMode mode) { - for (auto out_tv : ir_utils::outputTvsOf(producer)) { - if (out_tv == producer) { - continue; - } - producer->computeWith(out_tv, pos, mode); - } -} - namespace { // Find the resolution points of the persistent buffers in the provided @@ -1409,7 +1396,109 @@ std::vector getInputsOutputsWithInnerDim( return vectorizable_tensors; } -std::vector getBroadcastMultiples( +namespace { +// Holder return struct for the below function. +struct DisjointViewSetInfo { + // const* to the disjoint set in disjoint_view_set passed in to + // getDisjointViewSetsOf each iterdomain in the rfactor of ref is mapped to. + // + // WARNING: these pointers are relative to the disjoint_view_set reference + // passed into getDisjointViewSetsOf it's the user's responsibillity to + // maintain the lifetime of that reference to match this vector. + std::vector*> disjoint_sets_of_ref; + + // Unique ID associated to the disjoint view group the rfactor id belongs to + // in disjoint_sets_of_ref. It's straight forward to map from + // disjoint_sets_of_ref to the vector, but not the other way around. + std::vector disjoint_set_ids; + + // TensorView reference the above vectors are relative to. + TensorView* ref; +}; + +// Returns disjoint view sets mapped onto the given reference. Returns a pair +// of vectors of size rfactorDomain of reference. Vector of +// VectorOfUniqueEntries returns a const* to the disjoint set in +// disjoint_view_set the iterdomain is mapped to. Integer vector represents +// which disjoint view group the rfactor id belongs to. It's straight forward +// to map from the former to the latter, but not the latter to former. +// +// Since we return a const* to entries in disjoint_view_set, it must be passed +// in as a reference. Algorithm is N^2 based on number of dims in reference, +// but generating the disjoint view set is likely the limiter on perf of this +// function. +DisjointViewSetInfo getDisjointViewSetsOf( + Fusion* fusion, + TensorView* of, + DisjointSets& disjoint_view_set) { + auto rfactor_dom = of->getMaybeRFactorDomain(); + if (rfactor_dom.size() == 0) { + return {}; + } + + // Start naming id's based on 0 so the inner most dimension will always be + // 0, then as groups are discovered marching to the left their id will + // increase. i.e. we could have something like [0, 3, 1, 2, 1, 0] as a + // result. + std::vector disjoint_group_ids(rfactor_dom.size(), -1); + std::vector*> disjoint_set_of_id( + rfactor_dom.size(), nullptr); + int current_group_id = 0; + int ref_dim_i = rfactor_dom.size() - 1; + + while (ref_dim_i >= 0) { + if (disjoint_group_ids[ref_dim_i] != -1) { + // Already put in a group, continue + ref_dim_i--; + continue; + } + + const auto& ref_group = + disjoint_view_set.getDisjointSetOf(rfactor_dom[ref_dim_i]); + + int other_dim_i = ref_dim_i; + while (other_dim_i >= 0) { + const auto& other_group = + disjoint_view_set.getDisjointSetOf(rfactor_dom[other_dim_i]); + if (&ref_group == &other_group) { + disjoint_group_ids[other_dim_i] = current_group_id; + disjoint_set_of_id[other_dim_i] = &ref_group; + } + other_dim_i--; + } + + ref_dim_i--; + current_group_id++; + } + + TORCH_INTERNAL_ASSERT( + std::none_of( + disjoint_group_ids.begin(), + disjoint_group_ids.end(), + [](int i) { return i == -1; }), + "Failed to generate the view disjoint groups of the reference ", + of->toString()); + + TORCH_INTERNAL_ASSERT( + std::none_of( + disjoint_set_of_id.begin(), + disjoint_set_of_id.end(), + [](const VectorOfUniqueEntries* ptr) { + return ptr == nullptr; + }), + "Failed to generate the view disjoint groups of the reference ", + of->toString()); + + DisjointViewSetInfo info; + info.disjoint_sets_of_ref = disjoint_set_of_id; + info.disjoint_set_ids = disjoint_group_ids; + info.ref = of; + + return info; +} +} // namespace + +BroadcastMultipleInformation getBroadcastMultiples( TensorView* reference_tv, DataType index_type) { auto fusion = reference_tv->fusion(); @@ -1418,6 +1507,13 @@ std::vector getBroadcastMultiples( std::vector multiples( reference_tv->getMaybeRFactorDomain().size()); + auto disjoint_view_sets = disjointViewSets(fusion); + auto disjoint_set_information = scheduler_utils::getDisjointViewSetsOf( + fusion, reference_tv, disjoint_view_sets); + + auto ref_disjoint_sets = disjoint_set_information.disjoint_sets_of_ref; + auto ref_disjoint_set_ids = disjoint_set_information.disjoint_set_ids; + // All input or output tensor views std::vector in_out_tvs; { @@ -1427,8 +1523,8 @@ std::vector getBroadcastMultiples( in_out_tvs.insert(in_out_tvs.end(), out_tvs.begin(), out_tvs.end()); } - // Shouldn't matter if we use EXACT or PERMISSIVE mapping mode for compute at - // map as we're just looking at the root mappings. + // Shouldn't matter if we use EXACT or PERMISSIVE mapping mode for compute + // at map as we're just looking at the root mappings. auto ca_map = ComputeAtMap(fusion); auto ref_root_domain = reference_tv->getMaybeRFactorDomain(); @@ -1448,35 +1544,60 @@ std::vector getBroadcastMultiples( if (ref_id->isBroadcast() || ref_id->isReduction()) { continue; } - auto map_it = std::find_if( - in_out_tv_domain_list.begin(), - in_out_tv_domain_list.end(), - [&ref_id, &ca_map](IterDomain* in_out_tv_id) { - return ca_map.areMapped(in_out_tv_id, ref_id, IdMappingMode::EXACT); - }); - if (map_it == in_out_tv_domain_list.end()) { + bool ref_id_has_view_transforms = std::count( + ref_disjoint_set_ids.begin(), + ref_disjoint_set_ids.end(), + ref_disjoint_set_ids[ref_i]) > 1; + + // Could have multiple mappings if there's view transforms + std::vector mapped_ids; + if (!ref_id_has_view_transforms) { + auto mapped_it = std::find_if( + in_out_tv_domain_list.begin(), + in_out_tv_domain_list.end(), + [&ref_id, &ca_map](IterDomain* in_out_tv_id) { + return ca_map.areMapped( + in_out_tv_id, ref_id, IdMappingMode::EXACT); + }); + if (mapped_it != in_out_tv_domain_list.end()) { + mapped_ids.push_back(*mapped_it); + } + } else { + for (auto in_out_id : in_out_tv_domain) { + if (ref_disjoint_sets[ref_i]->has(in_out_id)) { + mapped_ids.push_back(in_out_id); + } + } + } + + // Nothing maps to reference, no contribution to multiples for this dim + if (mapped_ids.empty()) { continue; } - // If input/output id is broadcast or reduction - if ((*map_it)->isBroadcast() || (*map_it)->isReduction()) { + if (std::all_of(mapped_ids.begin(), mapped_ids.end(), [](IterDomain* id) { + return id->isReduction() || id->isBroadcast(); + })) { continue; } + // If any iteration domain in the input or output that's mapped through + // the view disjoint set is not a reduction or broadcast, assume it's a + // full dimension for the sake of the pointwise scheduler. mapped_axes[ref_i] = true; - in_out_tv_domain_list.erase(map_it); } // For each break point position if there an lhs or rhs multiple based on - // this tensor add it to the global multiplier + // this tensor add it to the global multiplier. The only time we consider + // we can benefit from broadcast is if the entire left or right side the + // break point is all broadcasts. { bool rhs = false; bool lhs = false; auto dtype_size = dataTypeSize(in_out_tv->getDataType().value(), index_type); - for (size_t mapped_axes_i = 0; mapped_axes_i < mapped_axes.size(); - mapped_axes_i++) { + for (auto mapped_axes_i : c10::irange(mapped_axes.size())) { auto lhs_i = mapped_axes_i; auto rhs_i = mapped_axes.size() - 1 - mapped_axes_i; @@ -1493,91 +1614,10 @@ std::vector getBroadcastMultiples( } } } - - return multiples; -} - -size_t collectMaxVectorizeSizeWithContigMerge( - TensorView* tv, - IterDomain* leaf_merged_domain, - size_t max_vector_size_in_byte, - ExpressionEvaluator& expression_evaluator, - DataType index_type) { - // Maybe too conservative, but only handles fully contiguous tensors - // TODO: Relax the contiguity constraint to be similar to that in index - // computing. Just looking for all merged root domains in the right order, all - // merged root dimensions are contiguous, all merged root dimensions are next - // to eachother (exlcuding broadcast). - if (std::any_of( - tv->domain()->contiguity().begin(), - tv->domain()->contiguity().end(), - [](const auto contig) { return !contig; })) { - return 1; - } - - auto dtype_size = dataTypeSize(tv->dtype(), index_type); - const size_t max_vector_size = max_vector_size_in_byte / dtype_size; - - // Assume no halo-related expression appears in the fusion. No - // broadcast is merged, so indexability can be assumed to be true. - ContigIDs contigIds( - {leaf_merged_domain}, - tv->getMaybeRFactorDomain(), - tv->domain()->contiguity(), - {}, - {}, - true, - true); - - auto innermost_root_id = tv->getMaybeRFactorDomain().back(); - auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id); - - size_t merged_size = 1; - // If the indexed ID is a contig merged domain, i.e., it is - // different from innermost_root_id, we accumulate the extents of - // all the root domains covered by the contig indexed ID. Otherwise, - // just look at the extent of the innermost root ID. - if (indexed_id != innermost_root_id) { - const auto& within_root = contigIds.withinContigIDs().at(indexed_id); - for (auto root_id : tv->getMaybeRFactorDomain()) { - if (within_root.find(root_id) == within_root.end()) { - continue; - } - auto maybe_dimension_size = - expression_evaluator.evaluate(root_id->extent()); - TORCH_INTERNAL_ASSERT( - maybe_dimension_size.has_value(), - "Unknown extent of tv: ", - tv->toString(), - ", id: ", - root_id->toString()); - merged_size *= maybe_dimension_size->as(); - } - } else { - auto maybe_dimension_size = - expression_evaluator.evaluate(innermost_root_id->extent()); - TORCH_INTERNAL_ASSERT( - maybe_dimension_size.has_value(), - "Unknown extent of tv: ", - tv->toString(), - ", id: ", - innermost_root_id->toString()); - merged_size = maybe_dimension_size->as(); - } - - size_t vector_size = 1; - size_t next_vector_size = vector_size * 2; - - // Try until vector size exceeds the max allowed size - while (next_vector_size <= max_vector_size) { - if (merged_size % next_vector_size != 0) { - break; - } - vector_size = next_vector_size; - next_vector_size *= 2; - } - - return vector_size; + BroadcastMultipleInformation bcast_info; + bcast_info.view_disjoint_set_ids = ref_disjoint_set_ids; + bcast_info.broadcast_multiples = multiples; + return bcast_info; } namespace matmul_utils { @@ -1811,7 +1851,7 @@ c10::optional getMaybeRootIfInnermostTiled( } // namespace -TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv) { +void orderTiledConcreteIdAsRoot(TensorView* tv) { auto ndims = tv->nDims(); // Keep track of the left most position where we will @@ -1911,9 +1951,7 @@ TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv) { } // namespace matmul_utils //! Propagate current transformations on from_tv to all graphs -TORCH_CUDA_CU_API void transformPropagateToAllFrom( - TensorView* from_tv, - int pos) { +void transformPropagateToAllFrom(TensorView* from_tv, int pos) { TransformPropagator propagator(from_tv, pos); MaxRootDomainInfoSpanningTree(from_tv, nullptr).traverse(&propagator); } @@ -2139,181 +2177,218 @@ void BoundedDirectionalTransformPropagator::bothWays( propagate(from, pos, included_tvs, *options); } -// Grab all values and expressions used to make the merged_domain and remove -// them from the fusion -void cleanUpInnermostMergedDomains( - const std::vector& root_domain, - IterDomain* merged_domain) { - TORCH_INTERNAL_ASSERT(merged_domain != nullptr); - TORCH_INTERNAL_ASSERT(!root_domain.empty()); - - std::unordered_set root_set({root_domain.begin(), root_domain.end()}); +DisjointSets disjointViewSets(Fusion* fusion) { + // Start from the exact iter domain graph of the fusion + IterDomainGraph id_graph(fusion); + auto disjoint_view_ids = id_graph.exactNodes(); - auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain}); - - for (auto it = vals.rbegin(); it != vals.rend(); ++it) { - TORCH_INTERNAL_ASSERT((*it)->isA()); - auto id = (*it)->as(); - if (root_set.find(id) != root_set.end()) { - continue; + // If iter domains are involved in any transformation from root domains to + // rfactor domains they should be considered "contaminated". + for (auto tv : ir_utils::allTvs(fusion)) { + for (auto expr : StmtSort::getExprs( + fusion, + {tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end()})) { + if (expr->isA()) { + auto merge = expr->as(); + disjoint_view_ids.mapEntries(merge->inner(), merge->out()); + disjoint_view_ids.mapEntries(merge->outer(), merge->out()); + } else if (expr->isA()) { + auto split = expr->as(); + disjoint_view_ids.mapEntries(split->in(), split->inner()); + disjoint_view_ids.mapEntries(split->in(), split->outer()); + } else { + TORCH_INTERNAL_ASSERT( + false, "Expression type: ", expr->toString(), " not supported."); + } } - Fusion* fusion = id->container()->as(); - auto id_def = id->definition(); - TORCH_INTERNAL_ASSERT( - id_def->isA(), - "Invalid ID: ", - id->toString(), - ". Expected definition of a Merge expression: ", - (id_def != nullptr ? id_def->toString() : "nullptr")); - fusion->removeExpr(id_def); - fusion->removeVal(id); } + return disjoint_view_ids; } -// Merge innermost domains for finding the widest vectorizable -// size. Return the merged domain or nullptr if no merge is done. -IterDomain* mergeInnermostDomains( - const std::vector& domain, - int num_merged_domains) { - const auto ndims = domain.size(); - IterDomain* merged_id = nullptr; - bool is_merge_done = false; - for (const auto i : c10::irange(num_merged_domains)) { - auto id = domain.at(ndims - 1 - i); - // broadcast and trivial reductions are ignored - if (id->isBroadcast() || id->isTrivialReduction()) { - continue; - } - if (merged_id == nullptr) { - merged_id = id; - } else { - auto id_inner = merged_id; - auto id_outer = id; - merged_id = IterDomain::merge(id_outer, id_inner); - is_merge_done = true; - } - } - return is_merge_done ? merged_id : nullptr; -} +bool allMatchingViews(Fusion* fusion) { + // Start from the exact iter domain graph of the fusion + IterDomainGraph id_graph(fusion); + auto exact_disjoint_set = id_graph.exactNodes(); -//! Attempt to expand vectorized domains to contig merged domains. Break point -//! identifies the point in which you can't propagate contiguous merges. For -//! example in pointwise this is the point where we want to split the -//! parallelization to take advantage of broadcast, and for reduction schedulers -//! it's the point where we switch from a reduction domain to an iter domain (or -//! vice versa). -size_t expandVectorizationToContigMergedDomains( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - const std::vector vectorizable_inputs_outputs, - TensorView* reference_tv, - int break_point, - size_t default_word_size) { - size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; - size_t common_alignment_size = - SchedulerRuntimeInfo::max_alignment_size_in_byte; + auto view_exprs = ir_utils::getViewOps(fusion); + if (view_exprs.empty()) { + return true; + } - for (auto inp_out : vectorizable_inputs_outputs) { - auto dtype_size = dataTypeSize( - inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode())); + std::vector all_view_outs; - max_expand_size = std::min( - max_expand_size, - SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size); - max_expand_size = std::min( - max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out)); - common_alignment_size = - std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out)); + for (auto view_expr : view_exprs) { + auto outs = ir_utils::filterByType(view_expr->outputs()); + all_view_outs.insert(all_view_outs.end(), outs.begin(), outs.end()); } - // If there's no possibility to increase vector size of provided tensors, then - // don't bother doing a more complex analysis to try and do so, just return - // early. - if (max_expand_size == default_word_size) { - return default_word_size; - } + TORCH_INTERNAL_ASSERT( + all_view_outs.size() > 0, + "Found view operations but can't find any output tensor views."); - auto ca_map = ComputeAtMap(fusion); + auto first_out_tv = *all_view_outs.begin(); + auto first_root_dom = + TensorDomain::noReductions(first_out_tv->getRootDomain()); + auto first_rfactor_dom = + TensorDomain::noReductions(first_out_tv->getRFactorDomain()); - // Merge the domains right of the break point - const auto& ref_root = reference_tv->getMaybeRFactorDomain(); - const int num_merged_domains = - static_cast(ref_root.size()) - static_cast(break_point); + for (auto other_out_tv : all_view_outs) { + if (other_out_tv == first_out_tv) { + continue; + } - // No expansion with no merged domain - if (num_merged_domains == 0) { - return default_word_size; - } + auto other_root_dom = + TensorDomain::noReductions(other_out_tv->getRootDomain()); + auto other_rfactor_dom = + TensorDomain::noReductions(other_out_tv->getRFactorDomain()); - // Merge the domains but don't modify TensorDomain - auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains); + if (first_root_dom.size() != other_root_dom.size() || + first_rfactor_dom.size() != other_rfactor_dom.size()) { + return false; + } + { + std::vector> zipped_ids; + + std::transform( + first_root_dom.begin(), + first_root_dom.end(), + other_root_dom.begin(), + std::back_inserter(zipped_ids), + [](IterDomain* first, IterDomain* second) { + return std::make_pair(first, second); + }); - // No expansion is done if no merge is done. - if (merged_domain == nullptr) { - return default_word_size; - } + if (std::any_of( + zipped_ids.begin(), + zipped_ids.end(), + [&exact_disjoint_set]( + std::pair id_pair) { + return !exact_disjoint_set.strictAreMapped( + id_pair.first, id_pair.second); + })) { + return false; + } + } + { + std::vector> zipped_ids; + + std::transform( + first_rfactor_dom.begin(), + first_rfactor_dom.end(), + other_rfactor_dom.begin(), + std::back_inserter(zipped_ids), + [](IterDomain* first, IterDomain* second) { + return std::make_pair(first, second); + }); - // Find the vectorizable word size with the merged domains - size_t word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge( - reference_tv, - merged_domain, - common_alignment_size, - runtime_info.expressionEvaluator(), - indexModeToDtype(runtime_info.getIndexMode())); + if (std::any_of( + zipped_ids.begin(), + zipped_ids.end(), + [&exact_disjoint_set]( + std::pair id_pair) { + return !exact_disjoint_set.strictAreMapped( + id_pair.first, id_pair.second); + })) { + return false; + } + } + } + return true; +} - cleanUpInnermostMergedDomains(ref_root, merged_domain); +bool breakIsDisjoint(std::vector group_ids, int pos) { + if (pos < 0) { + pos += group_ids.size(); + } + TORCH_INTERNAL_ASSERT( + pos >= 0 && pos <= group_ids.size(), + "Invalid position, size of vec is ", + group_ids.size(), + " but position is ", + pos); - // Stop if the reference doesn't get a larger word size. - if (word_size <= default_word_size) { - return default_word_size; + if (pos == 0 || pos == group_ids.size()) { + return true; } - // Check the other TVs and take the minimum of the valid word sizes - for (const auto tv : vectorizable_inputs_outputs) { - if (tv == reference_tv) { - continue; - } + std::unordered_set left_ints(group_ids.begin(), group_ids.begin() + pos); - const auto& tv_root = tv->getMaybeRFactorDomain(); + for (auto i = pos; i < group_ids.size(); i++) { + if (left_ints.count(group_ids[i]) > 0) { + return false; + } + } + return true; +} - int tv_num_merged_domains = 0; - for (const auto i : c10::irange(num_merged_domains)) { - if (i == tv_root.size()) { - break; +std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { + FusionGuard fg(tv->fusion()); + auto transform_exprs = StmtSort::getExprs( + tv->fusion(), + {tv->domain()->domain().begin(), tv->domain()->domain().end()}); + // simply update this vector of id's as progressing through the transformation + // expressions. We'll always insert the result of split in the location of the + // input, and insert the merge result in the position of the inner dimension. + + auto reordered_ids = tv->getMaybeRFactorDomain(); + for (const auto* expr : transform_exprs) { + if (const Split* split = dynamic_cast(expr)) { + auto find_it = + std::find(reordered_ids.begin(), reordered_ids.end(), split->in()); + if (find_it == reordered_ids.end()) { + // Transformations before rfactor, ignore those. + continue; } - auto ref_id = ref_root.at(ref_root.size() - 1 - i); - IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i); - // If not mapped, stop expanding. - if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) { - break; - } else { - ++tv_num_merged_domains; + auto pos = std::distance(reordered_ids.begin(), find_it); + reordered_ids[pos] = split->inner(); + reordered_ids.insert(reordered_ids.begin() + pos, split->outer()); + } else if (const Merge* merge = dynamic_cast(expr)) { + auto find_it_0 = + std::find(reordered_ids.begin(), reordered_ids.end(), merge->outer()); + auto find_it_1 = + std::find(reordered_ids.begin(), reordered_ids.end(), merge->inner()); + if (find_it_0 == reordered_ids.end() && + find_it_1 == reordered_ids.end()) { + // Transformations before rfactor, ignore those. + continue; } - } - - size_t tv_word_size = 1; - if (tv_num_merged_domains > 1) { - auto tv_merged_domain = - mergeInnermostDomains(tv_root, tv_num_merged_domains); - if (tv_merged_domain == nullptr) { - tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); - } else { - tv_word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge( - tv, - tv_merged_domain, - common_alignment_size, - runtime_info.expressionEvaluator(), - indexModeToDtype(runtime_info.getIndexMode())); - cleanUpInnermostMergedDomains(tv_root, tv_merged_domain); + TORCH_INTERNAL_ASSERT( + find_it_0 != reordered_ids.end() && find_it_1 != reordered_ids.end(), + "Error in transformations of ", + tv->toString(), + "\nTransformations before rfactor should not mix with transformations after rfactor."); + auto pos0 = std::distance(reordered_ids.begin(), find_it_0); + auto pos1 = std::distance(reordered_ids.begin(), find_it_1); + if (pos0 > pos1) { + std::swap(pos0, pos1); } - } else { - tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); - } + // Should be impossible. + TORCH_INTERNAL_ASSERT( + pos0 != pos1, + "Didn't expect merge inputs to be the same iteration domain:\n", + merge->toString()); - word_size = std::min(word_size, tv_word_size); + reordered_ids.erase(reordered_ids.begin() + pos0); + pos1--; + reordered_ids[pos1] = merge->out(); + } } - return word_size; + std::unordered_map old2new; + for (auto id_i : c10::irange(tv->domain()->domain().size())) { + auto leaf_id = tv->axis(id_i); + auto find_it = + std::find(reordered_ids.begin(), reordered_ids.end(), leaf_id); + TORCH_INTERNAL_ASSERT( + find_it != reordered_ids.end(), + "Reordering map creation failed, uninitialized iterdomain,", + " likely something is wrong with the transformations between the rfactor domain and the leaves."); + int new_pos = (int)std::distance(reordered_ids.begin(), find_it); + int old_pos = (int)id_i; + old2new[old_pos] = new_pos; + } + return old2new; } } // namespace scheduler_utils diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 0eb08fb03ba15..373a879f740d5 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -115,16 +116,6 @@ TORCH_CUDA_CU_API inline void parallelizeAllLike( propagate_padding); } -TORCH_CUDA_CU_API void computeAtInputs( - TensorView* consumer, - int pos, - ComputeAtMode mode = ComputeAtMode::Standard); - -TORCH_CUDA_CU_API void computeWithOutputs( - TensorView* producer, - int pos, - ComputeAtMode mode = ComputeAtMode::Standard); - struct PersistentBufferInfo { std::vector persistent_buffers; std::unordered_set unmappable_dims; @@ -312,14 +303,26 @@ struct BroadcastMultiple { int64_t lhs_multiple = 0; }; -// Returns a vector of counts, size = reference_tv->getRootDomain().size(), each -// entry [i] is the number of inputs/outputs that have a non-broadcast dimension -// mapped to the corresponding dimension in reference_tv. Count includes -// reference_tv if reference_tv is an input or output. Count is multiplied by -// data type size. -std::vector getBroadcastMultiples( - TensorView* reference_tv, - DataType index_type); +struct BroadcastMultipleInformation { + std::vector view_disjoint_set_ids; + std::vector broadcast_multiples; +}; + +// Returns a vector of size reference_tv->getMaybeRFactorDomain().size() which +// is a view disjoint set id of each of those iter domains. If entries share the +// same value, they undergo view transformations in the fusion together. +// Broadcast multiples are also of size +// reference_tv->getMaybeRFactorDomain().size(), each entry [i] is the number of +// inputs/outputs that have a non-broadcast dimension mapped to the +// corresponding dimension in reference_tv. Broadcast multiples includes +// reference_tv if reference_tv is an input or output. Broadcast multiples is +// multiplied by data type size. In the case of view operations the broadcast +// multiple is the full multiple size if any domain in the group maps to a +// non-broadcast dimension in the given input/output. Otherwise if all +// dimensions are broadcast that input/output will not contribute to the +// multiple. +TORCH_CUDA_CU_API BroadcastMultipleInformation +getBroadcastMultiples(TensorView* reference_tv, DataType index_type); //! Collect maximum vectorization word size of a tensor whose //! innermost domain is leaf_merged_domain. Contig merging is taken @@ -492,6 +495,47 @@ struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator { Options options); }; +// Schedulers typically start by merging some axes together then splitting, +// and propagating those transformations through the dag. What we want to +// understand is if these merges can be supported through view operations. +// For example it could be problematic to support a reduction fusion: +// +// tv0[2, 3, 4] +// tv1 = sum(tv0, {1, 2}) +// tv2 = view(tv0, {6, 4}) +// +// Since the first step of the reduction scheduler would be tv1->merge(1, 2). +// If we tried to propagate this transformation through the view it would make +// the view invalid. If we tried to propagate the view through the reduction, +// it would attempt to merge a reduction and non-reduction dimension. So for +// these types of fusions we would like to understand that the view considers +// axis 1 and 2 of tv1 as "non-separable" axes. +// +// If IterDomains are disjoint in the returned set, then they are considered +// "separable". +// Warning: This pass generates the IdGraphs, not intended for use at runtime. +TORCH_CUDA_CU_API DisjointSets disjointViewSets(Fusion* fusion); + +// Return if all trasnformations in all views match. +// TODO: Should this be moved to registry.cpp/.h? +// Warning: This pass generates the IdGraphs, not intended for use at runtime. +TORCH_CUDA_CU_API bool allMatchingViews(Fusion* fusion); + +// Makes sure that there are no group id's left of pos that match right of pos. +// e.g. +// [1, 0, 0] pos 2 would return false +// [1, 0, 0] pos 1 would return true +TORCH_CUDA_CU_API bool breakIsDisjoint(std::vector group_ids, int pos); + +// Generates an old to new map to reorder tv's domain as the rfactor order. +// Priority is given to inner most dimensions for example: +// rfactor [i0, i1, i2] +// domain [i0*i2, i1] +// will produce the map {{0, 1}, {1, 0}} +// This is somewhat similar to orderTiledConcreteIdAsRoot +TORCH_CUDA_CU_API std::unordered_map domainReorderAsRfactorMap( + TensorView* tv); + } // namespace scheduler_utils } // namespace cuda } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp new file mode 100644 index 0000000000000..2c3c848c7f5c9 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp @@ -0,0 +1,286 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cuda { +namespace vectorize_helper { + +// Grab all values and expressions used to make the merged_domain and remove +// them from the fusion +void cleanUpInnermostMergedDomains( + const std::vector& root_domain, + IterDomain* merged_domain) { + TORCH_INTERNAL_ASSERT(merged_domain != nullptr); + TORCH_INTERNAL_ASSERT(!root_domain.empty()); + + std::unordered_set root_set({root_domain.begin(), root_domain.end()}); + + auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain}); + + for (auto it = vals.rbegin(); it != vals.rend(); ++it) { + TORCH_INTERNAL_ASSERT((*it)->isA()); + auto id = (*it)->as(); + if (root_set.find(id) != root_set.end()) { + continue; + } + Fusion* fusion = id->container()->as(); + auto id_def = id->definition(); + TORCH_INTERNAL_ASSERT( + id_def->isA(), + "Invalid ID: ", + id->toString(), + ". Expected definition of a Merge expression: ", + (id_def != nullptr ? id_def->toString() : "nullptr")); + fusion->removeExpr(id_def); + fusion->removeVal(id); + } +} + +// Merge innermost domains for finding the widest vectorizable +// size. Return the merged domain or nullptr if no merge is done. +IterDomain* mergeInnermostDomains( + const std::vector& domain, + int num_merged_domains) { + const auto ndims = domain.size(); + IterDomain* merged_id = nullptr; + bool is_merge_done = false; + for (const auto i : c10::irange(num_merged_domains)) { + auto id = domain.at(ndims - 1 - i); + // broadcast and trivial reductions are ignored + if (id->isBroadcast() || id->isTrivialReduction()) { + continue; + } + if (merged_id == nullptr) { + merged_id = id; + } else { + auto id_inner = merged_id; + auto id_outer = id; + merged_id = IterDomain::merge(id_outer, id_inner); + is_merge_done = true; + } + } + return is_merge_done ? merged_id : nullptr; +} + +size_t collectMaxVectorizeSizeWithContigMerge( + TensorView* tv, + IterDomain* leaf_merged_domain, + size_t max_vector_size_in_byte, + ExpressionEvaluator& expression_evaluator, + DataType index_type) { + auto dtype_size = dataTypeSize(tv->dtype(), index_type); + const size_t max_vector_size = max_vector_size_in_byte / dtype_size; + + // Assume no halo-related expression appears in the fusion. No + // broadcast is merged, so indexability can be assumed to be true. + // This is expensive, as ContigIDs builds other things like CAMap, + // HaloInfo, and ConcreteBroadcast info. We should explicitly build and reuse + // these as they're compile time information. + ContigIDs contigIds( + {leaf_merged_domain}, + tv->getMaybeRFactorDomain(), + tv->domain()->contiguity(), + {}, + {}, + getAllDivisibleSplits(tv->fusion()), + {}, + true); + + auto innermost_root_id = tv->getMaybeRFactorDomain().back(); + auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id); + + size_t merged_size = 1; + // If the indexed ID is a contig merged domain, i.e., it is + // different from innermost_root_id, we accumulate the extents of + // all the root domains covered by the contig indexed ID. Otherwise, + // just look at the extent of the innermost root ID. + if (indexed_id != innermost_root_id) { + const auto& within_root = contigIds.withinContigIDs().at(indexed_id); + for (auto root_id : tv->getMaybeRFactorDomain()) { + if (within_root.find(root_id) == within_root.end()) { + continue; + } + auto maybe_dimension_size = + expression_evaluator.evaluate(root_id->extent()); + TORCH_INTERNAL_ASSERT( + maybe_dimension_size.has_value(), + "Unknown extent of tv: ", + tv->toString(), + ", id: ", + root_id->toString()); + merged_size *= maybe_dimension_size->as(); + } + } else { + auto maybe_dimension_size = + expression_evaluator.evaluate(innermost_root_id->extent()); + TORCH_INTERNAL_ASSERT( + maybe_dimension_size.has_value(), + "Unknown extent of tv: ", + tv->toString(), + ", id: ", + innermost_root_id->toString()); + merged_size = maybe_dimension_size->as(); + } + + size_t vector_size = 1; + size_t next_vector_size = vector_size * 2; + + // Try until vector size exceeds the max allowed size + while (next_vector_size <= max_vector_size) { + if (merged_size % next_vector_size != 0) { + break; + } + vector_size = next_vector_size; + next_vector_size *= 2; + } + + return vector_size; +} + +//! Attempt to expand vectorized domains to contig merged domains. Break point +//! identifies the point in which you can't propagate contiguous merges. For +//! example in pointwise this is the point where we want to split the +//! parallelization to take advantage of broadcast, and for reduction +//! schedulers it's the point where we switch from a reduction domain to an +//! iter domain (or vice versa). +size_t expandVectorizationToContigMergedDomains( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + const std::vector vectorizable_inputs_outputs, + TensorView* reference_tv, + int break_point, + size_t default_word_size) { + size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; + size_t common_alignment_size = + SchedulerRuntimeInfo::max_alignment_size_in_byte; + + for (auto inp_out : vectorizable_inputs_outputs) { + auto dtype_size = dataTypeSize( + inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode())); + + max_expand_size = std::min( + max_expand_size, + SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size); + max_expand_size = std::min( + max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out)); + common_alignment_size = + std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out)); + } + + // If there's no possibility to increase vector size of provided tensors, + // then don't bother doing a more complex analysis to try and do so, just + // return early. + if (max_expand_size == default_word_size) { + return default_word_size; + } + + auto ca_map = ComputeAtMap(fusion); + + // Merge the domains right of the break point + const auto& ref_root = reference_tv->getMaybeRFactorDomain(); + const int max_num_merged_domains = + static_cast(ref_root.size()) - static_cast(break_point); + int64_t num_merged_domains = 0; + while (num_merged_domains < max_num_merged_domains) { + auto pos = (int64_t)ref_root.size() - 1 - num_merged_domains; + if (!reference_tv->domain()->contiguity()[pos]) { + break; + } + num_merged_domains++; + } + + // No expansion with no merged domain + if (num_merged_domains == 0) { + return default_word_size; + } + + // Merge the domains but don't modify TensorDomain + auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains); + + // No expansion is done if no merge is done. + if (merged_domain == nullptr) { + return default_word_size; + } + + // Find the vectorizable word size with the merged domains + size_t word_size = collectMaxVectorizeSizeWithContigMerge( + reference_tv, + merged_domain, + common_alignment_size, + runtime_info.expressionEvaluator(), + indexModeToDtype(runtime_info.getIndexMode())); + + cleanUpInnermostMergedDomains(ref_root, merged_domain); + + // Stop if the reference doesn't get a larger word size. + if (word_size <= default_word_size) { + return default_word_size; + } + + // Check the other TVs and take the minimum of the valid word sizes + for (const auto tv : vectorizable_inputs_outputs) { + if (tv == reference_tv) { + continue; + } + + const auto& tv_root = tv->getMaybeRFactorDomain(); + + int tv_num_merged_domains = 0; + for (const auto i : c10::irange(max_num_merged_domains)) { + if (i == tv_root.size()) { + break; + } + auto ref_id = ref_root.at(ref_root.size() - 1 - i); + auto pos = tv_root.size() - 1 - i; + IterDomain* tv_id = tv_root.at(pos); + // If not mapped, stop expanding. + if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT) || + !tv->domain()->contiguity()[pos]) { + break; + } else { + ++tv_num_merged_domains; + } + } + + size_t tv_word_size = 1; + if (tv_num_merged_domains > 1) { + auto tv_merged_domain = + mergeInnermostDomains(tv_root, tv_num_merged_domains); + if (tv_merged_domain == nullptr) { + tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); + } else { + tv_word_size = collectMaxVectorizeSizeWithContigMerge( + tv, + tv_merged_domain, + common_alignment_size, + runtime_info.expressionEvaluator(), + indexModeToDtype(runtime_info.getIndexMode())); + cleanUpInnermostMergedDomains(tv_root, tv_merged_domain); + } + } else { + tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv); + } + + word_size = std::min(word_size, tv_word_size); + } + + return word_size; +} + +} // namespace vectorize_helper +} // namespace cuda +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h index 0a67d00618e23..a9b959b495d60 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.h @@ -2,21 +2,15 @@ #include #include -#include #include -#include + +#include namespace torch { namespace jit { namespace fuser { namespace cuda { - -// TODO: Put implementations in a vectorize_helper.cpp -namespace scheduler_utils { - -// Moved the definition of these to -// torch/csrc/jit/codegen/cuda/scheduler/utils.cpp as making new CPP files is -// painful for multiple reasons. +namespace vectorize_helper { // Grab all values and expressions used to make the merged_domain and remove // them from the fusion @@ -44,7 +38,7 @@ size_t expandVectorizationToContigMergedDomains( int break_point, size_t default_word_size); -} // namespace scheduler_utils +} // namespace vectorize_helper } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index ba95d8fabdce9..85f320fef2e43 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -290,40 +291,115 @@ IterDomain* TensorView::axis(int pos) const { return domain()->axis(pos); } -void TensorView::setComputeAt(unsigned int pos, bool decrease) { +void TensorView::inlineAt( + int64_t pos, + bool best_effort, + MaxPosCalculator* calc) { TORCH_INTERNAL_ASSERT( !container()->isA(), "Function invalid for kernel container."); - if (pos <= compute_at_pos_ && !decrease) { - return; + + std::unique_ptr calc_owner; + if (calc == nullptr) { + calc_owner = std::make_unique(); + calc = calc_owner.get(); + } + + if (pos < 0) { + pos += int64_t(nDims()) + 1; } TORCH_INTERNAL_ASSERT( - (unsigned)pos <= nDims(), - "Invalid this computeAt position for T", + pos >= 0 && pos <= nDims(), + "Invalid inline position for T", name(), ": ", pos); - compute_at_pos_ = pos; -} + auto max_inline_pos = calc->getMaxPosAll(this, best_effort); -void TensorView::setMaxProducer(unsigned int pos, bool decrease) { - TORCH_INTERNAL_ASSERT( - !container()->isA(), - "Function invalid for kernel container."); - if (pos <= max_producer_pos_ && !decrease) { - return; + if (best_effort) { + pos = std::min(max_inline_pos, pos); + } + + // hoist inner most broadcast + while (pos > 0 && axis(pos - 1)->isBroadcast()) { + pos--; } TORCH_INTERNAL_ASSERT( - (unsigned)pos <= nDims(), - "Invalid max producer position for T", + pos <= max_inline_pos, + "Invalid inline position for T", name(), ": ", - pos); + pos, + ". Maximum allowed value:", + max_inline_pos); + + if (isFusionInput()) { + return; + } + + if (pos > compute_at_pos_) { + compute_at_pos_ = pos; + for (auto consumer : ir_utils::consumerTvsOf(this)) { + consumer->updateMaxProducerPosition(); + } + } +} + +namespace { + +// Try to find the aligned position on consumer's domain corresponding to the +// compute at position of producer domain. No checking on actual +// producer-consumer relationship. +unsigned int getConsumerPosAlignedToProducerCA( + TensorView* consumer, + TensorView* producer) { + // Locate consumer's position that aligns with + // the producer's new compute at axis. We need broadcast axes forwarded so we + // need to replay PasC as CasP will not forward braodcast dims. For example + // if we have: + // T2[ iS22{( 3 * 1 )} ] ca_pos( 1 ) = broadcast( T1[ iS1{3} ] ca_pos( 1 ) + // produce_pos( 1) ) CasP will have the mapping iS1{3} -> iS2{3} and PasC will + // have the mapping iS22{( 3 * 1 )} <- iS1{3} We need the latter. Refer to + // NVFuserTest.FusionComplexBCast1_CUDA + + auto disjoint_sets = + BestEffortReplay::replayPasC( + producer, consumer, -1, PairwiseRootDomainMap(producer, consumer)) + .getDisjointSets(); + + // Find the innermost position of consumer that has + // been mapped within the producer ca axis. + unsigned int consumer_pos = consumer->nDims(); + while (consumer_pos > 0) { + auto consumer_id = consumer->axis((int)consumer_pos - 1); + auto p_dom = producer->domain()->domain(); + if (std::any_of( + p_dom.begin(), + p_dom.begin() + producer->getComputeAtPosition(), + [&consumer_id, &disjoint_sets](IterDomain* p_id) { + return disjoint_sets.permissiveAreMapped(consumer_id, p_id); + })) { + break; + } + consumer_pos--; + } + + return consumer_pos; +} + +} // namespace - max_producer_pos_ = pos; +void TensorView::updateMaxProducerPosition() { + TORCH_INTERNAL_ASSERT( + !container()->isA(), + "Function invalid for kernel container."); + for (auto producer : ir_utils::producerTvsOf(this)) { + max_producer_pos_ = std::max( + max_producer_pos_, getConsumerPosAlignedToProducerCA(this, producer)); + } } TensorView* TensorView::computeAt( @@ -681,7 +757,7 @@ TensorView* TensorView::rFactor(const std::vector& axes) { TORCH_CHECK( !definition()->isA(), - "For GroupedReducitonOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); + "For GroupedReductionOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); // Split tensor view into 2 parts auto domain_pair = domain()->rFactor(axes); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp deleted file mode 100644 index db38bbfd3a92a..0000000000000 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp +++ /dev/null @@ -1,25813 +0,0 @@ -#if defined(USE_CUDA) -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include - -// Tests go in torch::jit -namespace torch { -namespace jit { - -using namespace torch::jit::fuser::cuda; -using namespace at::indexing; - -namespace { - -TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) { - auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as()); - TensorView* matching_tv = nullptr; - for (auto lowered_tv : used_tvs) { - if (lowered_tv->name() == tv->name()) { - matching_tv = lowered_tv; - } - } - TORCH_INTERNAL_ASSERT(matching_tv != nullptr); - return matching_tv; -} - -class PredicatedChecker : public kir::IrVisitor { - public: - // Checks if the provided tv is written to within a non-trivial conditional - static bool isPredicated(TensorView* tv, GpuLower& gpulw) { - PredicatedChecker checker( - loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); - return checker.is_predicated_; - } - - private: - PredicatedChecker() = delete; - - PredicatedChecker(TensorView* tv, std::vector exprs) : tv_(tv) { - kir::IrVisitor::handle(exprs); - } - - using kir::IrVisitor::handle; - bool is_predicated_ = false; - bool predicated_ite_ = false; - TensorView* tv_ = nullptr; - - void handle(kir::IfThenElse* ite) final { - auto prev_ite = predicated_ite_; - predicated_ite_ = !ite->predicate()->value()->isConstScalar(); - kir::IrVisitor::handle(ite); - predicated_ite_ = prev_ite; - } - - void handle(Expr* expr) final { - if (expr->outputs().size() && expr->outputs()[0]->isA()) { - auto ti = expr->outputs()[0]->as(); - if (ti->view() == tv_) { - is_predicated_ = is_predicated_ | predicated_ite_; - if (expr->predicate() != nullptr && - !expr->predicate()->value()->isConst()) { - is_predicated_ = true; - } - } - } - kir::IrVisitor::handle(expr); - } -}; - -class UnswitchInElseChecker : public kir::IrVisitor { - public: - // Checks if there are any unswitched for loops within an else clause - static bool check(GpuLower& gpulw) { - UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs()); - return checker.found_in_else_; - } - - private: - UnswitchInElseChecker() = delete; - UnswitchInElseChecker(std::vector exprs) { - kir::IrVisitor::handle(exprs); - } - - using kir::IrVisitor::handle; - bool within_else_ = false; - bool found_in_else_ = false; - - void handle(kir::IfThenElse* ite) final { - auto prev_within_else = within_else_; - within_else_ = true; - kir::IrVisitor::handle(ite->elseBody().exprs()); - within_else_ = prev_within_else; - } - - void handle(kir::ForLoop* for_loop) final { - if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) { - found_in_else_ = found_in_else_ || within_else_; - } - kir::IrVisitor::handle(for_loop); - } -}; - -class PredicateMagicZeroChecker : public kir::IrVisitor { - public: - // Checks if all predicated domains of the provided tv are protected with - // magic zero - static bool isProtected(TensorView* tv, GpuLower& gpulw) { - PredicateMagicZeroChecker checker( - loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); - return checker.is_protected_; - } - - private: - using kir::IrVisitor::handle; - - PredicateMagicZeroChecker(TensorView* tv, std::vector exprs) - : tv_(tv) { - handle(exprs); - } - - void handle(kir::IfThenElse* ite) final { - auto prev_predicate = predicate_; - predicate_ = ite->predicate()->value(); - kir::IrVisitor::handle(ite); - predicate_ = prev_predicate; - } - - void handle(Expr* expr) final { - if (expr->outputs().size() && expr->outputs()[0]->isA()) { - auto ti = expr->outputs()[0]->as(); - if (ti->view() == tv_) { - is_protected_ = checkPredicateOfTensor(predicate_); - return; - } - } - - if (expr->isA()) { - handle(expr->as()); - } else if (expr->isA()) { - handle(expr->as()); - } else { - for (auto input : expr->inputs()) { - handle(input); - } - } - } - - // Return true If all predicated domains are protected - bool checkPredicateOfTensor(Val* predicate) { - auto id_predicates = decomposeCompoundPredicate(predicate); - for (auto id_predicate : id_predicates) { - // Just check if nvfuser_zero is used. Not perfect but probably - // good enough. - is_magic_zero_found_ = false; - handle(id_predicate); - if (!is_magic_zero_found_) { - return false; - } - } - return true; - } - - // Decompose "X && Y" to a vector of {X, Y}. - std::vector decomposeCompoundPredicate(Val* predicate) { - if (auto binary_op = dynamic_cast(predicate->definition())) { - if (binary_op->getBinaryOpType() == BinaryOpType::And) { - auto pred = decomposeCompoundPredicate(binary_op->lhs()); - auto rhs_pred = decomposeCompoundPredicate(binary_op->rhs()); - pred.insert(pred.end(), rhs_pred.begin(), rhs_pred.end()); - return pred; - } - } - - return {predicate}; - } - - void handle(Val* val) final { - if (isMagicZero(val)) { - is_magic_zero_found_ = true; - return; - } - - auto def = val->definition(); - if (def != nullptr) { - handle(def); - } - } - - private: - bool is_protected_ = false; - Val* predicate_ = nullptr; - TensorView* tv_ = nullptr; - bool is_magic_zero_found_ = false; -}; - -// Basically just TransformPropagator, except that it checks the consistency -// replayPasC with getMatchedLeafPosWithoutReplayPasC, replayCasP with -// getMatchedLeafPosWithoutReplayCasP, and fullSelfReplay with fullSelfMatching: -// - After replayPasC, getMatchedLeafPosWithoutReplayPasC should return the same -// replayed position -// - After replayCasP, getMatchedLeafPosWithoutReplayCasP should return the same -// replayed position -// - After fullSelfReplay, fullSelfMatching should return true -struct TransformPropagatorWithCheck : public TransformPropagator { - public: - virtual void propagateC2P(TensorView* from, TensorView* to) override { - TransformPropagator::propagateC2P(from, to); - auto from_pos = replayed_pos_.at(from); - auto to_pos = replayed_pos_.at(to); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC( - to, from, from_pos) == to_pos); - } - virtual void propagateP2C(TensorView* from, TensorView* to) override { - TransformPropagator::propagateP2C(from, to); - auto from_pos = replayed_pos_.at(from); - auto to_pos = replayed_pos_.at(to); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP( - to, from, from_pos) == to_pos); - } - virtual void propagateSibling(TensorView* from, TensorView* to) override { - TransformPropagator::propagateSibling(from, to); - auto from_pos = replayed_pos_.at(from); - auto to_pos = replayed_pos_.at(to); - TORCH_CHECK(from_pos == to_pos); - TORCH_CHECK(TransformReplay::fullSelfMatching(from, to)); - } - using TransformPropagator::TransformPropagator; -}; - -} // namespace - -// 1. Test cases are void() functions. -// 2. They start with the prefix `test` - -// A few smoke tests for IrGraphGenerator -// (These tests exercise IrGraphGenerator through a non-trivial IR, -// to make sure that it runs w/o crashing. The actual output is not -// validated) -TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Make sure we can handle empty IRs - TORCH_CHECK(!IrGraphGenerator::toGraphviz( - &fusion, IrGraphGenerator::DetailLevel::Basic) - .empty()); - - // Construct an interesting IR - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv2 = add(tv0, IrBuilder::create(3.141)); - TensorView* tv3 = broadcast(tv0, {false, true, false, true}); - TensorView* tv4 = - reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv3); - TensorView* tv5 = clamp( - tv4, IrBuilder::create(0.f), IrBuilder::create(1.f)); - TensorView* tv6 = add(tv2, tv2); - - // Another checkpoint before adding outputs - TORCH_CHECK(!IrGraphGenerator::toGraphviz( - &fusion, IrGraphGenerator::DetailLevel::Explicit) - .empty()); - - fusion.addOutput(tv6); - - tv4->axis(2)->parallelize(ParallelType::BIDy); - tv6->merge(0); - tv6->split(0, 4); - tv6->axis(0)->parallelize(ParallelType::BIDx); - tv5->reorder({{-1, 0}}); - tv2->computeAt(tv6, 1); - - // Another checkpoint with more node types - TORCH_CHECK(!IrGraphGenerator::toGraphviz( - &fusion, IrGraphGenerator::DetailLevel::ComputeOnly) - .empty()); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - // Final IR graph - TORCH_CHECK(!IrGraphGenerator::toGraphviz( - &fusion, IrGraphGenerator::DetailLevel::Verbose) - .empty()); -} - -TEST_F(NVFuserTest, FusionDispatch_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Double* f = IrBuilder::create(2.f); - std::stringstream ss1, ss2, ss3; - ss1 << f; - ss2 << static_cast(f); - ss3 << static_cast(f); - TORCH_CHECK( - ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0, - "Error with dispatch system where results differ by passing Double* vs Val* vs Statement*."); -} - -// Evaluate basic scalar operations with constant values -TEST_F(NVFuserTest, FusionExprEvalConstants_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - ExpressionEvaluator evaluator(&fusion); - - auto* a = IrBuilder::create(7); - auto* b = IrBuilder::create(3); - - // Avoid div operation because it casts int operands to float - checkIntValue(evaluator, neg(a), -7); - checkIntValue(evaluator, add(a, b), 10); - checkIntValue(evaluator, neg(mul(sub(a, b), add(a, b))), -40); - checkIntValue(evaluator, mod(a, b), 1); - checkIntValue(evaluator, ceilDiv(a, b), 3); -} - -TEST_F(NVFuserTest, FusionExprEvalDouble_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto ten = IrBuilder::create(10); - auto two = IrBuilder::create(2); - auto three = IrBuilder::create(3); - auto val = castOp(DataType::Int, ceilDiv(sub(ten, two), three)); - auto reference = static_cast(std::ceil((10.0 - 2.0) / 3.0)); - TORCH_CHECK(reference == val->evaluateInt()); -} - -// Evaluate basic scalar operations with bound values -TEST_F(NVFuserTest, FusionExprEvalBindings_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - ExpressionEvaluator evaluator(&fusion); - - auto* a = IrBuilder::create(); - auto* b = IrBuilder::create(); - auto* c = add(a, b); - auto* d = neg(ceilDiv(c, b)); - auto* e = IrBuilder::create(0); - - // trying to evaluate before binding should give empty results - TORCH_CHECK(!evaluator.evaluate(a).has_value()); - TORCH_CHECK(!evaluator.evaluate(d).has_value()); - - evaluator.bind(a, 7); - evaluator.bind(b, 3); - - // can't bind to the results of expressions - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(evaluator.bind(c, 100)); - - // can't bind to concrete values - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(evaluator.bind(e, 100)); - - checkIntValue(evaluator, c, 10); - checkIntValue(evaluator, sub(a, b), 4); - checkIntValue(evaluator, mod(a, b), 1); - checkIntValue(evaluator, ceilDiv(a, b), 3); - checkIntValue(evaluator, d, -4); - - // Reset evaluation context - evaluator = ExpressionEvaluator(&fusion); - - evaluator.bind(a, 2); - evaluator.bind(b, 5); - - checkIntValue(evaluator, c, 7); - checkIntValue(evaluator, sub(a, b), -3); - checkIntValue(evaluator, mod(a, b), 2); - checkIntValue(evaluator, ceilDiv(a, b), 1); - checkIntValue(evaluator, d, -2); -} - -// Evaluate expressions in a simple IR -TEST_F(NVFuserTest, FusionExprEvalBasic_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Create a non-trivial IR - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - fusion.addOutput(tv3); - - tv3->split(0, 4); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - // 1. Create an evaluator - ExpressionEvaluator evaluator(&fusion); - - // 2. Bind values - // - // IMPORTANT: - // a. The bindings are only as stable as the Vals are in the fusion graph - // b. You must use the original (rootDomain) extents - // (ex. `tv0->getRootDomain()[0]->extent()` - // instead of `tv0->axis(0)->extent()`) - // - evaluator.bind(tv0->getRootDomain()[0]->extent(), 6); - evaluator.bind(tv0->getRootDomain()[1]->extent(), 128); - evaluator.bind(tv1->getRootDomain()[0]->extent(), 6); - evaluator.bind(tv1->getRootDomain()[1]->extent(), 128); - - // 3. Evaluate and check result values - TORCH_CHECK(tv2->domain()->nDims() == 3); - checkIntValue(evaluator, tv2->axis(0)->extent(), 2); - checkIntValue(evaluator, tv2->axis(1)->extent(), 4); - checkIntValue(evaluator, tv2->axis(2)->extent(), 128); - - TORCH_CHECK(tv3->domain()->nDims() == 3); - checkIntValue(evaluator, tv3->axis(0)->extent(), 2); - checkIntValue(evaluator, tv3->axis(1)->extent(), 4); - checkIntValue(evaluator, tv3->axis(2)->extent(), 128); -} - -// Evaluate expressions in a more complex IR -TEST_F(NVFuserTest, FusionExprEvalComplex_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); - TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); - TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); - TensorView* tv4 = add(tv2, tv1); - TensorView* tv5 = add(tv4, tv3); - TensorView* tv6 = add(tv0, tv3); - - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - tv5->reorder({{-1, 0}}); - - tv6->split(0, 5); - tv5->merge(0); - - // 1. Create an evaluator - ExpressionEvaluator evaluator(&fusion); - - // 2. Bind values - evaluator.bind(tv0->getRootDomain()[0]->extent(), 129); - evaluator.bind(tv0->getRootDomain()[1]->extent(), 127); - - // Evaluate and check extent values - TORCH_CHECK(tv0->domain()->nDims() == 2); - checkIntValue(evaluator, tv0->axis(0)->extent(), 129); - checkIntValue(evaluator, tv0->axis(1)->extent(), 127); - - TORCH_CHECK(tv3->domain()->nDims() == 2); - checkIntValue(evaluator, tv3->axis(0)->extent(), 129); - checkIntValue(evaluator, tv3->axis(1)->extent(), 127); - - TORCH_CHECK(tv4->domain()->nDims() == 2); - checkIntValue(evaluator, tv4->axis(0)->extent(), 129); - checkIntValue(evaluator, tv4->axis(1)->extent(), 127); - - TORCH_CHECK(tv5->domain()->nDims() == 1); - checkIntValue(evaluator, tv5->axis(0)->extent(), 16383); - - TORCH_CHECK(tv6->domain()->nDims() == 3); - checkIntValue(evaluator, tv6->axis(0)->extent(), 26); - checkIntValue(evaluator, tv6->axis(1)->extent(), 5); - checkIntValue(evaluator, tv6->axis(2)->extent(), 127); -} - -// Evaluate expressions post lowering -TEST_F(NVFuserTest, FusionExprEvalPostLower_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Create a non-trivial IR - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - fusion.addOutput(tv3); - - tv3->split(0, 4); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - auto* bid_x = add(tv3->axis(0)->extent(), IrBuilder::create(0)); - auto* tid_x = add(tv3->axis(-1)->extent(), IrBuilder::create(0)); - - // Lower - GpuLower gpulw(&fusion); - - // 1. Create an evaluation context - ExpressionEvaluator evaluator(&fusion); - - // 2. Bind values - evaluator.bind(tv0->getRootDomain()[0]->extent(), 6); - evaluator.bind(tv0->getRootDomain()[1]->extent(), 128); - evaluator.bind(tv1->getRootDomain()[0]->extent(), 6); - evaluator.bind(tv1->getRootDomain()[1]->extent(), 128); - - // 3. Evaluate and check result values - TORCH_CHECK(tv2->domain()->nDims() == 3); - checkIntValue(evaluator, tv2->axis(0)->extent(), 2); - checkIntValue(evaluator, tv2->axis(1)->extent(), 4); - checkIntValue(evaluator, tv2->axis(2)->extent(), 128); - - TORCH_CHECK(tv3->domain()->nDims() == 3); - checkIntValue(evaluator, tv3->axis(0)->extent(), 2); - checkIntValue(evaluator, tv3->axis(1)->extent(), 4); - checkIntValue(evaluator, tv3->axis(2)->extent(), 128); - - checkIntValue(evaluator, bid_x, 2); - checkIntValue(evaluator, tid_x, 128); -} - -// Kernel IR: Evaluate basic scalar operations with constant values -TEST_F(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { - Fusion fusion; - kir::Kernel kernel(&fusion); - FusionGuard fg((&kernel)->as()); - - auto a = IrBuilder::create(7); - auto b = IrBuilder::create(3); - auto c = IrBuilder::subExpr(a, b); - auto d = IrBuilder::divExpr(a, b); - auto e = IrBuilder::mulExpr(c, d); - - kir::ExpressionEvaluator evaluator; - - checkIntValue(evaluator, IrBuilder::negExpr(a), -7); - checkIntValue(evaluator, IrBuilder::addExpr(a, b), 10); - checkIntValue(evaluator, IrBuilder::negExpr(e), -8); - checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); - checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); -} - -// Kernel IR: Evaluate basic scalar operations with bound values -TEST_F(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { - Fusion fusion; - kir::Kernel kernel(&fusion); - FusionGuard fg((&kernel)->as()); - - kir::ExpressionEvaluator evaluator; - - auto a = IrBuilder::create(c10::nullopt); - auto b = IrBuilder::create(c10::nullopt); - auto c = IrBuilder::addExpr(a, b); - auto d = IrBuilder::negExpr(IrBuilder::ceilDivExpr(c, b)); - auto e = IrBuilder::create(0); - - // trying to evaluate before binding should give empty results - TORCH_CHECK(!evaluator.evaluate(a).has_value()); - TORCH_CHECK(!evaluator.evaluate(d).has_value()); - - evaluator.bind(a, 7); - evaluator.bind(b, 3); - - // can't bind to the results of expressions - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(evaluator.bind(c, 100)); - - // can't bind to concrete values - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(evaluator.bind(e, 100)); - - checkIntValue(evaluator, c, 10); - checkIntValue(evaluator, IrBuilder::subExpr(a, b), 4); - checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); - checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); - checkIntValue(evaluator, d, -4); - - // Reset the evaluation context - evaluator = kir::ExpressionEvaluator(); - - evaluator.bind(a, 2); - evaluator.bind(b, 5); - - checkIntValue(evaluator, c, 7); - checkIntValue(evaluator, IrBuilder::subExpr(a, b), -3); - checkIntValue(evaluator, IrBuilder::modExpr(a, b), 2); - checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 1); - checkIntValue(evaluator, d, -2); -} - -TEST_F(NVFuserTest, FusionClear_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // 1. Create a dummy IR - - { - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - fusion.addOutput(tv3); - - tv3->split(0, 4); - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - } - - // 2. Clear the IR - - fusion.clear(); - - TORCH_CHECK(fusion.unordered_exprs().empty()); - TORCH_CHECK(fusion.vals().empty()); - - TORCH_CHECK(fusion.inputs().empty()); - TORCH_CHECK(fusion.outputs().empty()); - - TORCH_CHECK(ir_utils::getReductionOps(&fusion).empty()); - - // 3. Rebuild the IR - - { - TensorView* tv0 = makeSymbolicTensor(3); - TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv3); - - // tv3 [i0, i1, i2] - tv3->reorder({{0, 2}, {2, 0}}); - // tv3 [i2, i1, i0] - tv3->split(-1, 4); - // tv3 [i2, i1, i0outer, i0inner{4}] - tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - // tv3 [i0outer, i0inner{4}, i1, i2] - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - tv3->axis(1)->parallelize(ParallelType::BIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({16, 8, 8}, options); - at::Tensor input2 = at::randn_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - at::Tensor tv2_ref = input2 + 2.0; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(outputs[0])); -} - -TEST_F(NVFuserTest, FusionCopy_CUDA) { - Fusion original_fusion; - - // Create the test IR - { - FusionGuard fg(&original_fusion); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, IrBuilder::create(2.0)); - auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); - - original_fusion.addInput(tv0); - original_fusion.addInput(tv1); - original_fusion.addOutput(tv3); - - tv3->reorder({{0, 2}, {2, 0}}); - tv3->split(-1, 4); - tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - } - - // Test copy before lowering - Fusion clone = original_fusion; - - // Compare IR dumps - std::stringstream original_ir; - std::stringstream clone_ir; - original_ir << original_fusion; - clone_ir << clone; - ASSERT_EQ(original_ir.str(), clone_ir.str()); - - // Lower original fusion - std::string original_kernel; - { - // TODO(kir): remove this guard once we implement the cuda codegen visitor - FusionGuard fg(&original_fusion); - original_kernel = - codegen::generateCudaKernel(GpuLower(&original_fusion).kernel()); - } - - // Make sure the "before lowering" clone was not mutated - // while lowering the original fusion IR - std::stringstream before_lowering_ir; - before_lowering_ir << clone; - ASSERT_EQ(original_ir.str(), before_lowering_ir.str()); - - // Test copy after lowering (including assignment operator) - Fusion before_lowering = clone; - clone = original_fusion; - - // Compare IR dumps - std::stringstream original_lowered_ir; - std::stringstream clone_lowered_ir; - original_lowered_ir << original_fusion; - clone_lowered_ir << clone; - ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str()); - - // Lower the "before lowering" and compare kernels - std::string clone_kernel; - { - // TODO(kir): remove this guard once we implement the cuda codegen visitor - FusionGuard fg(&before_lowering); - clone_kernel = - codegen::generateCudaKernel(GpuLower(&before_lowering).kernel()); - } - ASSERT_EQ(original_kernel, clone_kernel); -} - -TEST_F(NVFuserTest, FusionMove_CUDA) { - Fusion fusion; - - // Create the test IR - { - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); - auto tv2 = add(tv1, IrBuilder::create(2.0)); - auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv3); - - tv3->reorder({{0, 2}, {2, 0}}); - tv3->split(-1, 4); - tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - } - - std::stringstream original_ir; - original_ir << fusion; - - // Test move before lowering - Fusion another_fusion = std::move(fusion); - - // Check that the original fusion is "empty" - // - // IMPORTANT: these checks assume knowledge of the internal - // implementation of the move operations. General uses - // should only assume that the moved-from object is in - // a valid, but unspecified state. This is similar to the - // standard library containers: - // https://en.cppreference.com/w/cpp/utility/move - // - TORCH_CHECK(fusion.unordered_exprs().empty()); - TORCH_CHECK(fusion.vals().empty()); - TORCH_CHECK(fusion.inputs().empty()); - TORCH_CHECK(fusion.outputs().empty()); - - // clear() has no pre-conditions so it's valid to call on a moved-from object - fusion.clear(); - - // Compare IR dumps - std::stringstream another_ir; - another_ir << another_fusion; - ASSERT_EQ(original_ir.str(), another_ir.str()); - - // Lower the fusion IR - GpuLower lower(&another_fusion); - - std::stringstream lowered_ir; - lowered_ir << another_fusion; - - // Test move assignment after lowering - fusion = std::move(another_fusion); - - // Compare IR dumps - std::stringstream moved_lowered_ir; - moved_lowered_ir << fusion; - ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); -} - -TEST_F(NVFuserTest, FusionSimpleArith_CUDA) { - std::stringstream ss1, ss2; - - Fusion fusion; - FusionGuard fg(&fusion); - - Double* d1 = IrBuilder::create(1.f); - Double* d2 = IrBuilder::create(2.f); - Double* d3 = IrBuilder::create(); - - // Disrupt the fusion to make sure guard works well - { - Fusion fusion2; - FusionGuard fg(&fusion2); - - Double* d1 = IrBuilder::create(1.f); - Double* d2 = IrBuilder::create(2.f); - add(d1, d2); - ss2 << fusion2; - } - - IrBuilder::create(BinaryOpType::Add, d3, d1, d2); - ss1 << fusion; - - TORCH_CHECK( - ss1.str().compare(ss2.str()) == 0, - "Error where explicit add nodes don't match implicit add nodes."); -} - -TEST_F(NVFuserTest, FusionScalarTypePromote_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Bool* b = IrBuilder::create(true); - Double* d = IrBuilder::create(4.f); - Int* i = IrBuilder::create(3); - ComplexDouble* c = - IrBuilder::create(c10::complex(1, 2)); - - TORCH_CHECK(add(b, b)->getDataType() == DataType::Bool); - TORCH_CHECK(add(b, d)->getDataType() == DataType::Double); - TORCH_CHECK(add(b, i)->getDataType() == DataType::Int); - TORCH_CHECK(add(b, c)->getDataType() == DataType::ComplexDouble); - - TORCH_CHECK(add(d, b)->getDataType() == DataType::Double); - TORCH_CHECK(add(d, d)->getDataType() == DataType::Double); - TORCH_CHECK(add(d, i)->getDataType() == DataType::Double); - TORCH_CHECK(add(d, c)->getDataType() == DataType::ComplexDouble); - - TORCH_CHECK(add(i, b)->getDataType() == DataType::Int); - TORCH_CHECK(add(i, d)->getDataType() == DataType::Double); - TORCH_CHECK(add(i, i)->getDataType() == DataType::Int); - TORCH_CHECK(add(i, c)->getDataType() == DataType::ComplexDouble); - - TORCH_CHECK(add(c, b)->getDataType() == DataType::ComplexDouble); - TORCH_CHECK(add(c, d)->getDataType() == DataType::ComplexDouble); - TORCH_CHECK(add(c, i)->getDataType() == DataType::ComplexDouble); - TORCH_CHECK(add(c, c)->getDataType() == DataType::ComplexDouble); -} - -TEST_F(NVFuserTest, FusionComplexAbsTypes_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto options = at::TensorOptions().device(at::kCUDA, 0); - auto tensor_cf = at::randn({4, 4, 4}, options.dtype(at::kComplexFloat)); - auto tensor_cd = at::randn({4, 4, 4}, options.dtype(at::kComplexDouble)); - - auto type_cf = TensorType::create(tensor_cf); - auto tv_cf = IrBuilder::create(type_cf); - auto type_cd = TensorType::create(tensor_cd); - auto tv_cd = IrBuilder::create(type_cd); - - TORCH_CHECK( - tensor_cf.abs().scalar_type() == - data_type_to_aten(abs(tv_cf)->getDataType().value())); - TORCH_CHECK( - tensor_cd.abs().scalar_type() == - data_type_to_aten(abs(tv_cd)->getDataType().value())); -} - -TEST_F(NVFuserTest, FusionRegister_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - Double* v1 = IrBuilder::create(1.f); - Double* v2 = IrBuilder::create(2.f); - Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); - Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); - TORCH_CHECK(v1->name() + 1 == v2->name()); - TORCH_CHECK(v2->name() + 1 == v3->name()); - TORCH_CHECK(v3->name() + 1 == v4->name()); - TORCH_CHECK(v3->definition()->name() + 1 == v4->definition()->name()); -} - -// dummy expr with 2 outputs only for toposort test. -struct DummyExpr : public Expr { - ~DummyExpr() = default; - DummyExpr( - IrBuilderPasskey passkey, - Val* _outlhs, - Val* _outrhs, - Val* _lhs, - Val* _rhs) - : Expr(passkey, ExprType::UnaryOp) // Not terribly safe... - { - addOutput(_outlhs); - addOutput(_outrhs); - addInput(_lhs); - addInput(_rhs); - } - DummyExpr(const DummyExpr& other) = delete; - DummyExpr& operator=(const DummyExpr& other) = delete; - DummyExpr(DummyExpr&& other) = delete; - DummyExpr& operator=(DummyExpr&& other) = delete; -}; - -TEST_F(NVFuserTest, FusionTopoSort_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // e0: v3, v2 = dummy(v1, v0) - // e1: v4 = add(v3, v2) - // e2: v5 = add(v2, v4) - // e3: v6 = add(v5, v5) - Double* v0 = IrBuilder::create(); - Double* v1 = IrBuilder::create(); - Double* v2 = IrBuilder::create(); - Double* v3 = IrBuilder::create(); - Double* v4 = IrBuilder::create(); - Double* v5 = IrBuilder::create(); - Double* v6 = IrBuilder::create(); - - std::vector inputs = {v0, v1}; - for (auto val : inputs) { - fusion.addInput(val); - } - - Expr* e0 = IrBuilder::create(v3, v2, v1, v0); - Expr* e1 = IrBuilder::create(BinaryOpType::Add, v4, v3, v2); - Expr* e2 = IrBuilder::create(BinaryOpType::Add, v5, v2, v4); - Expr* e3 = IrBuilder::create(BinaryOpType::Add, v6, v5, v5); - - fusion.addOutput(v2); - fusion.addOutput(v3); - auto exprs = fusion.exprs(); - TORCH_CHECK(exprs.size() == 1, "Found ", exprs.size(), " but expecting 1"); - TORCH_CHECK(exprs[0] == e0); - - fusion.addOutput(v5); - exprs = fusion.exprs(); - TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3"); - TORCH_CHECK(exprs[0] == e0); - TORCH_CHECK(exprs[1] == e1); - TORCH_CHECK(exprs[2] == e2); - - fusion.addOutput(v4); - exprs = fusion.exprs(); - TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3"); - TORCH_CHECK(exprs[0] == e0); - TORCH_CHECK(exprs[1] == e1); - TORCH_CHECK(exprs[2] == e2); - - fusion.addOutput(v6); - exprs = fusion.exprs(); - TORCH_CHECK(exprs.size() == 4, "Found ", exprs.size(), " but expecting 4"); - TORCH_CHECK(exprs[0] == e0); - TORCH_CHECK(exprs[1] == e1); - TORCH_CHECK(exprs[2] == e2); - TORCH_CHECK(exprs[3] == e3); - - TORCH_CHECK(v2->definition()->name() == 0); - TORCH_CHECK(v3->definition()->name() == 0); - TORCH_CHECK(v4->definition()->name() == 1); - TORCH_CHECK(v5->definition()->name() == 2); - TORCH_CHECK(v6->definition()->name() == 3); -} - -TEST_F(NVFuserTest, FusionTensor_CUDA) { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - Fusion fusion; - FusionGuard fg(&fusion); - - { - auto tensor = at::randn({2, 3, 4, 5}, options); - auto tensor_type = TensorType::create(tensor); - auto fuser_tensor = IrBuilder::create(tensor_type); - TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); - TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); - TORCH_CHECK(fuser_tensor->domain() != nullptr); - for (const auto i : c10::irange(fuser_tensor->nDims())) { - // size 1 dimension are makred as broadcast - TORCH_CHECK( - fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1)); - // check contiguity information; - TORCH_CHECK(fuser_tensor->domain()->contiguity()[i]); - } - } - - // TensorType::create fills stride_properties, which helps us to mark - // IterDomain properly - // Note: implementation could change, depending on how much we want to invest - // in our home-brew contiguity coalescing. For now let's make sure that we - // properly test what we are using. - { - auto tensor = at::randn({4, 4, 4}, options); - auto sliced_tensor = tensor.slice(1, 0, -1, 2); - - auto tensor_type = TensorType::create(sliced_tensor); - auto fuser_tensor = IrBuilder::create(tensor_type); - TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); - TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); - TORCH_CHECK(fuser_tensor->domain() != nullptr); - for (const auto i : c10::irange(fuser_tensor->nDims())) { - // size 1 dimension are makred as broadcast - TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); - } - TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); - TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); - } - - { - auto tensor = at::randn({2, 3, 4, 5}, options); - auto permuted_tensor = tensor.permute({0, 3, 1, 2}); - auto tensor_type = TensorType::create(permuted_tensor); - auto fuser_tensor = IrBuilder::create(tensor_type); - TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); - TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); - TORCH_CHECK(fuser_tensor->domain() != nullptr); - for (const auto i : c10::irange(fuser_tensor->nDims())) { - // size 1 dimension are makred as broadcast - TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); - } - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); - TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); - TORCH_CHECK(!fuser_tensor->domain()->contiguity()[3]); - } -} - -TEST_F(NVFuserTest, FusionFilterVals_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - auto tv1 = makeSymbolicTensor(1); - auto scalar0 = IrBuilder::create(0); - auto scalar1 = IrBuilder::create(0); - auto scalar2 = IrBuilder::create(1); - - const std::vector vals = {tv0, scalar0, tv1, scalar1, scalar2}; - - std::vector tvs( - ir_utils::filterByType(vals).begin(), - ir_utils::filterByType(vals).end()); - TORCH_CHECK(tvs.size() == 2); - TORCH_CHECK(tvs[0] == tv0); - TORCH_CHECK(tvs[1] == tv1); - - std::vector floats( - ir_utils::filterByType(vals).begin(), - ir_utils::filterByType(vals).end()); - TORCH_CHECK(floats.size() == 1); - TORCH_CHECK(floats[0] == scalar0); - - std::vector ints( - ir_utils::filterByType(vals).begin(), - ir_utils::filterByType(vals).end()); - TORCH_CHECK(ints.size() == 2); - TORCH_CHECK(ints[0] == scalar1); - TORCH_CHECK(ints[1] == scalar2); - - TORCH_CHECK( - ir_utils::filterByType(vals).begin() == - ir_utils::filterByType(vals).end(), - "Not expecting any results"); -} - -TEST_F(NVFuserTest, FusionTVSplit_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv = makeSymbolicTensor(3); - - tv = tv->split(2, 2); - TORCH_CHECK(tv->nDims() == 4); - Expr* outer = tv->axis(2)->extent()->definition(); - - TORCH_CHECK( - outer->getExprType().value() == ExprType::BinaryOp && - static_cast(outer)->getBinaryOpType() == - BinaryOpType::CeilDiv && - static_cast(outer)->lhs()->sameAs( - tv->getRootDomain()[2]->extent()) && - static_cast(static_cast(outer)->rhs()) - ->sameAs(IrBuilder::create(2))); - - IterDomain* inner = static_cast(tv->axis(3)); - TORCH_CHECK( - inner->extent()->isScalar() && - static_cast(inner->extent())->isConst() && - static_cast(inner->extent())->value().value() == 2); -} - -TEST_F(NVFuserTest, FusionTVMerge_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv = makeSymbolicTensor(3); - - tv = tv->merge(1); - Expr* axisOp = tv->axis(1)->extent()->definition(); - - TORCH_CHECK( - tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp && - static_cast(axisOp)->getBinaryOpType() == BinaryOpType::Mul && - static_cast(axisOp)->lhs() == - tv->getRootDomain()[1]->extent() && - static_cast(axisOp)->rhs() == - tv->getRootDomain()[2]->extent()); -} - -TEST_F(NVFuserTest, FusionTVReorder_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::unordered_map shift_right{{-1, 0}}; - - std::unordered_map shift_left{{0, -1}}; - - std::unordered_map shift_left_2{{0, -1}, {1, 0}, {2, 1}}; - - std::unordered_map swap{{0, 2}, {2, 0}}; - - auto tv = makeSymbolicTensor(3); - std::vector ref; - ref = std::vector( - tv->domain()->domain().begin(), tv->domain()->domain().end()); - - tv->reorder(shift_left); - for (const auto i : c10::irange(tv->nDims())) { - TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); - } - - tv = makeSymbolicTensor(3); - ref = std::vector( - tv->domain()->domain().begin(), tv->domain()->domain().end()); - - tv->reorder(shift_left); - for (const auto i : c10::irange(tv->nDims())) { - TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); - } - - tv = makeSymbolicTensor(3); - ref = std::vector( - tv->domain()->domain().begin(), tv->domain()->domain().end()); - - tv->reorder(shift_right); - TORCH_CHECK(ref[ref.size() - 1]->sameAs(tv->axis(0))); - for (const auto i : c10::irange(1, tv->nDims())) { - TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i))); - } - - tv = makeSymbolicTensor(3); - ref = std::vector( - tv->domain()->domain().begin(), tv->domain()->domain().end()); - tv->reorder(swap); - TORCH_CHECK(ref[0]->sameAs(tv->axis(2))); - TORCH_CHECK(ref[2]->sameAs(tv->axis(0))); - TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); -} - -TEST_F(NVFuserTest, FusionEquality_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Double* fval1 = IrBuilder::create(); - Double* fval1_copy = fval1; - Double* fval2 = IrBuilder::create(); - Double* fone = IrBuilder::create(1.0); - - TORCH_CHECK(fval1->sameAs(fval1_copy)); - TORCH_CHECK(!fval1->sameAs(fval2)); - TORCH_CHECK(!fone->sameAs(fval1)); - TORCH_CHECK(fone->sameAs(IrBuilder::create(1.0))); - - Int* ival1 = IrBuilder::create(); - Int* ival1_copy = ival1; - Int* ival2 = IrBuilder::create(); - Int* ione = IrBuilder::create(1); - - TORCH_CHECK(ival1->sameAs(ival1_copy)); - TORCH_CHECK(!ival1->sameAs(ival2)); - TORCH_CHECK(!ione->sameAs(ival1)); - TORCH_CHECK(ione->sameAs(IrBuilder::create(1))); - - BinaryOp* add1 = IrBuilder::create( - BinaryOpType::Add, IrBuilder::create(), fval1, ival1); - BinaryOp* add1_copy = IrBuilder::create( - BinaryOpType::Add, IrBuilder::create(), fval1, ival1); - BinaryOp* sub1 = IrBuilder::create( - BinaryOpType::Sub, IrBuilder::create(), fval1, ival1); - - UnaryOp* neg1 = IrBuilder::create( - UnaryOpType::Neg, IrBuilder::create(), fval1); - UnaryOp* neg2 = IrBuilder::create( - UnaryOpType::Neg, IrBuilder::create(), fval2); - UnaryOp* neg1_copy = IrBuilder::create( - UnaryOpType::Neg, IrBuilder::create(), fval1); - - TORCH_CHECK(add1->sameAs(add1_copy)); - TORCH_CHECK(!add1->sameAs(sub1)); - - TORCH_CHECK(neg1->sameAs(neg1_copy)); - TORCH_CHECK(!static_cast(neg1)->sameAs(add1)); - TORCH_CHECK(!neg1->sameAs(neg2)); -} - -TEST_F(NVFuserTest, FusionDependency_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Double* d0 = IrBuilder::create(0.f); - Double* d1 = IrBuilder::create(1.f); - auto d2 = add(d0, d1); - - auto d3 = add(d2, d2); - - Double* d4 = IrBuilder::create(4.f); - Double* d5 = IrBuilder::create(5.f); - auto d6 = add(d4, d5); - - Double* d7 = IrBuilder::create(7.f); - Double* d8 = IrBuilder::create(8.f); - auto d9 = add(d7, d8); - - auto d10 = add(d6, d9); - - auto d11 = add(d3, d10); - - TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d1, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d3, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d6, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d9, d11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d2)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d3)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d4, d6)); - TORCH_CHECK(DependencyCheck::isDependencyOf(d8, d10)); - - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d0)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d1)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d2)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d3)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d4)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d5)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d2, d0)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d3, d2)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d6, d4)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(d10, d8)); - - auto dep_chain = DependencyCheck::getSingleDependencyChain(d0, d11); - TORCH_CHECK(dep_chain.back() == d11); - dep_chain.pop_back(); - TORCH_CHECK(dep_chain.back() == d3); - dep_chain.pop_back(); - TORCH_CHECK(dep_chain.back() == d2); - dep_chain.pop_back(); - - dep_chain = DependencyCheck::getSingleDependencyChain(d6, d11); - TORCH_CHECK(dep_chain.back() == d11); - dep_chain.pop_back(); - TORCH_CHECK(dep_chain.back() == d10); - dep_chain.pop_back(); - - dep_chain = DependencyCheck::getSingleDependencyChain(d4, d11); - TORCH_CHECK(dep_chain.back() == d11); - dep_chain.pop_back(); - TORCH_CHECK(dep_chain.back() == d10); - dep_chain.pop_back(); - TORCH_CHECK(dep_chain.back() == d6); - dep_chain.pop_back(); - - dep_chain = DependencyCheck::getSingleDependencyChain(d11, d2); - TORCH_CHECK(dep_chain.empty()); -} - -TEST_F(NVFuserTest, FusionParser_CUDA) { - // This test may not pass if using a custom block sync as there may - // be additional calls. Skip the test as it's not specifically - // relevant with block synchronizatin. - if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { - return; - } - auto g = std::make_shared(); - const auto graph0_string = R"IR( - graph(%0 : Float(2, strides=[1]), - %1 : Float(2, strides=[1])): - %c0 : Float(2, strides=[1]) = aten::mul(%0, %1) - %d0 : Float(2, strides=[1]) = aten::mul(%c0, %0) - return (%d0))IR"; - parseIR(graph0_string, g.get()); - - // strides are not yet supported in the irparser. - for (auto val : g->block()->inputs()) { - if (val->isCompleteTensor()) - val->setType(val->type()->castRaw()->contiguous()); - } - for (auto node : g->block()->nodes()) { - for (auto val : node->outputs()) { - if (val->isCompleteTensor()) - val->setType(val->type()->castRaw()->contiguous()); - } - } - - auto fusion = parseJitIR(g); - FusionGuard fg(fusion.get()); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - // Avoid vectorization here as those kernels can't be lowered twice at the - // moment - at::Tensor input1 = at::randn({16}, options); - at::Tensor input2 = at::randn({16}, options); - auto lparams = schedulePointwise(fusion.get(), {input1, input2}); - - // CONSIDER: - // 1. this can be moved to a dedicated "golden" file - // 2. use a fuzzy compare (ignore non-significant whitespaces for example) - const std::string expected_kernel = R"( -__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { - int64_t i50; - i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i50 < T0.size[0])) { - float T5[1]; - T5[0] = 0; - T5[0] - = T1[i50]; - float T4[1]; - T4[0] = 0; - T4[0] - = T0[i50]; - float T2[1]; - T2[0] - = T4[0] - * T5[0]; - float T6[1]; - T6[0] - = T2[0] - * T4[0]; - T3[i50] - = T6[0]; - } -} -)"; - - const std::string actual_kernel = - "\n" + codegen::generateCudaKernel(GpuLower(fusion.get()).kernel()); - if (expected_kernel.size() != actual_kernel.size() || - expected_kernel.compare(actual_kernel) != 0) { - std::cerr - << " Codegen mismatch, codegen possibly changed, or is incorrect. " - << " \n ========= EXPECTED ========= \n" - << expected_kernel << "\n========= ACTUAL ========== \n" - << actual_kernel << "\n=================" << std::endl; - auto it = std::mismatch( - expected_kernel.begin(), - expected_kernel.end(), - actual_kernel.begin(), - actual_kernel.end()); - std::string actual_mismatched_snippet(it.second, actual_kernel.end()); - actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10); - std::string expected_mismatched_snippet(it.first, expected_kernel.end()); - expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10); - std::cerr << "First mismatch found at: " << actual_mismatched_snippet - << ", expected: " << expected_mismatched_snippet << std::endl; - TORCH_CHECK(false); - } - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1, input2}, lparams); - auto outputs = fe.runFusion({input1, input2}, lparams); - at::Tensor output_ref = input1 * input2 * input1; - TORCH_CHECK(output_ref.equal(outputs[0])); -} - -TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(3); - - IrBuilder::create( - BinaryOpType::Add, - tv0, - IrBuilder::create(0.0), - IrBuilder::create(1.0)); - TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); - fusion.addOutput(tv2); - - //[I0, I1, I2] - tv2->split(-1, 4, false); - //[I0, I1, I2o{4}, I2i] - tv2->merge(0); - tv2->merge(0); - //[I0*I1*I2o{4}, I2i] - tv2->split(0, 2); - //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i] - tv2->reorder({{0, 1}, {1, 0}}); - // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i] - - tv0->computeAt(tv2, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor output = at::empty({2, 6, 32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion); - fe.runFusion({}, {output}); - - at::Tensor output_ref = at::zeros_like(output, options); - output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionCodeGen_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(3); - - IrBuilder::create( - BinaryOpType::Add, - tv0, - IrBuilder::create(0.0), - IrBuilder::create(1.0)); - TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); - fusion.addOutput(tv2); - - //[I0, I1, I2] - tv2 = tv2->split(0, 4); - //[I0o, I0i{4}, I1, I2] - tv2 = tv2->merge(1); - //[I0o, I0i{4}*I1, I2] - tv2 = tv2->split(-1, 2); - //[I0o, I0i{4}*I1, I2o, I2i{2}] - tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}}); - //[I0i{4}*I1, I0o, I2i{2}, I2o] - - tv0->computeAt(tv2, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor output = at::empty({16, 8, 8}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion); - fe.runFusion({}, {output}); - - at::Tensor output_ref = at::zeros_like(output, options); - output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(3); - TensorView* tv1 = makeSymbolicTensor(3); - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv3); - - //[I0, I1, I2] - tv3->reorder({{0, 2}, {2, 0}}); - //[I2, I1, I0] - tv3->split(-1, 4); - //[I2, I1, I0o, I0i{4}] - tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); - // I0o, I0i{4}, I1, I2] - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({16, 8, 8}, options); - at::Tensor input2 = at::randn_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - at::Tensor tv2_ref = input2 + 2.0; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(outputs[0])); -} - -TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - // dimensionality of the problem - int nDims = 3; - - // Set up your input tensor views - TensorView* tv0 = makeContigTensor(nDims); - TensorView* tv1 = makeContigTensor(nDims); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - // Do transformations, remember, transformations are outputs to inputs - // This doesn't have to be in this order - tv3->merge(1); - tv3->merge(0); - - // Split by n_threads - tv3->split(0, 128); - tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - // Parallelize TV3 - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-2)->parallelize(ParallelType::Unroll); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({64, 2, 128}, options); - at::Tensor input2 = at::rand_like(input1); - at::Tensor output = at::empty_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - fe.runFusion({input1, input2}, {output}); - - at::Tensor tv2_ref = input2 + 2.0; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - return; - } - - auto tv0 = makeContigTensor(1); - - fusion.addInput(tv0); - - auto tv1 = set(tv0); - - fusion.addOutput(tv1); - - auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync); - tv_cache->setMemoryType(MemoryType::Shared); - - tv1->split(0, 16); - tv0->computeAt(tv1, 1); - - tv_cache->circularBuffer(10); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({255}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1}); - auto cg_outputs = fe.runFusion({input1}); - - testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - // dimensionality of the problem - int nDims = 3; - - // Set up your input tensor views - TensorView* tv0 = makeContigTensor(nDims, DataType::ComplexFloat); - TensorView* tv1 = makeContigTensor(nDims, DataType::ComplexFloat); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - c10::complex scalar1(2.0, 3.0); - TensorView* tv2 = add(tv1, IrBuilder::create(scalar1)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - // Do transformations, remember, transformations are outputs to inputs - // This doesn't have to be in this order - tv3->merge(1); - tv3->merge(0); - - // Split by n_threads - tv3->split(0, 128); - tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - // Parallelize TV3 - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-2)->parallelize(ParallelType::Unroll); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = - at::TensorOptions().dtype(at::kComplexFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({64, 2, 128}, options); - at::Tensor input2 = at::rand_like(input1); - at::Tensor output = at::empty_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - fe.runFusion({input1, input2}, {output}); - - at::Tensor tv2_ref = input2 + scalar1; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionExecKernel_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - tv3->merge(0); - tv3->split(0, 128); - tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - // Parallelize TV3 - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::ones({1, 128}, options); - at::Tensor input2 = at::ones_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - at::Tensor check = at::full({1, 128}, 4, options); - ; - TORCH_CHECK(outputs[0].equal(check)); -} - -int ceilDiv_(int a, int b) { - return (a + b - 1) / b; -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { - // Case 1 - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv1 + 3 - // tv4 = tv1 * 2 - // tv5 = tv3 + tv2 - // tv6 = tv5 + tv4 - // tv7 = tv1 + tv4 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); - TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); - TensorView* tv5 = add(tv3, tv2); - - TensorView* tv6 = add(tv5, tv4); - TensorView* tv7 = add(tv1, tv4); - - fusion.addOutput(tv6); - fusion.addOutput(tv7); - - // Lets setup to actually run - tv7->merge(0); - tv7->split(0, 128); - tv7->split(0, 4); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv7, 1); - - ComputeAtMap ca_map(&fusion); - - // The this-position of the last tensor should be zero. - TORCH_CHECK( - tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && - tv7->getMaxProducerPosition() == 1); - TORCH_CHECK( - tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && - tv6->getMaxProducerPosition() == 1); - // The position of every other tensor should be 1. - for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { - TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - - TORCH_CHECK( - ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE)); - } - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({129, 127}, options); - - auto t1 = aten_input.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t1.add({3.0}); - auto t4 = t1.mul({2.0}); - auto t5 = t3.add(t2); - auto t6 = t5.add(t4); - auto t7 = t1.add(t4); - - std::vector aten_outputs = {t6, t7}; - std::vector cg_outputs = { - at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { - // Case 2 - // tv1 = tv0 * -1 - // tv2 = tv0 + 3 - // tv3 = tv0 * 2 - // tv4 = tv2 + tv1 - // tv5 = tv4 + tv3 - // tv6 = tv5 + tv3 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); - TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); - TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); - TensorView* tv4 = add(tv2, tv1); - - TensorView* tv5 = add(tv4, tv3); - TensorView* tv6 = add(tv5, tv3); - - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - // Lets setup to actually run - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv6, 1); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({129, 127}, options); - - auto t1 = input.mul({-1.0}); - auto t2 = input.add({3.0}); - auto t3 = input.mul({2.0}); - auto t4 = t2.add(t1); - auto t5 = t4.add(t3); - auto t6 = t5.add(t3); - - std::vector aten_outputs = {t5, t6}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { - // Case 3 - // T2 = T1 * 0.979361 - // T3 = T2 * T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); - TensorView* tv3 = mul(tv2, tv0); - - fusion.addOutput(tv3); - - // Lets setup to actually run - while (tv3->nDims() > 1) - tv3->merge(0); - tv3->split(0, 128); - tv3->split(0, 4); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t1.mul({0.979361}); - auto aten_output = t2.mul(t0); - - std::vector aten_inputs = {t0, t1}; - - at::Tensor cg_output = at::empty_like(t0, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { - // Case 4 - // T4 = T2 - T3 - // T5 = T1 + T4 - // T6 = T5 - T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - TensorView* tv2 = makeSymbolicTensor(4); - fusion.addInput(tv2); - - TensorView* tv3 = makeSymbolicTensor(4); - fusion.addInput(tv3); - - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); - - fusion.addOutput(tv6); - - // Lets setup to actually run - while (tv6->nDims() > 1) - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); - - tv0->computeAt(tv6, 1); - tv1->computeAt(tv6, 1); - tv2->computeAt(tv6, 1); - tv3->computeAt(tv6, 1); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - at::Tensor t2 = at::rand_like(t0, options); - at::Tensor t3 = at::rand_like(t0, options); - - auto t4 = t2.sub(t3); - auto t5 = t1.add(t4); - auto aten_output = t5.sub(t0); - - std::vector aten_inputs = {t0, t1, t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { - // Case 5 - // tv2 = tv0 + 2.0 - // tv3 = tv1 * tv2 - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv3->merge(0); - tv3->split(-1, 8); - tv3->split(-1, 4); - - tv2->computeAt(tv3, 1); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t0.add(2.0); - auto aten_output = t1.mul(t2); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv2->merge(0); - tv2->split(-1, 8); - tv2->split(-1, 4); - tv3->merge(0); - tv3->split(-1, 8); - - tv2->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t0.add(2.0); - auto aten_output = t1.mul(t2); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - - auto tv3 = add(tv2, IrBuilder::create(3.0)); - - auto tv4 = add(tv1, tv3); - fusion.addOutput(tv4); - - auto tv5 = broadcast(tv1, {false, true}); - - auto tv6 = makeSymbolicTensor(2); - fusion.addInput(tv6); - - auto tv7 = mul(tv5, tv6); - - fusion.addOutput(tv7); - - tv7->split(1, 2); - tv7->merge(0); - tv7->split(0, 4); - tv7->split(0, 128); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - tv7->axis(1)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv7, 1); - auto tv5_domain = tv5->domain()->domain(); - - // These computeAt transformations should not affect the TV5 domain - tv0->computeAt(tv4, -1); - tv2->computeAt(tv4, -1); - - auto tv5_domain_current = tv5->domain()->domain(); - TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain"); - - const int numel_x = 100; - const int numel_y = 200; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({numel_x}, options); - auto t2 = at::randn({numel_x}, options); - auto t6 = at::randn({numel_x, numel_y}, options); - - auto t1 = t0.add(1.0); - auto t3 = t2.add(3.0); - auto t4 = t1.add(t3); - auto t5 = t1.unsqueeze(1); - auto t7 = t5.mul(t6); - - std::vector aten_inputs = {t0, t2, t6}; - std::vector aten_outputs = {t4, t7}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - - auto tv3 = add(tv2, IrBuilder::create(3.0)); - - auto tv4 = add(tv1, tv3); - fusion.addOutput(tv4); - - auto tv5 = broadcast(tv1, {false, true}); - - auto tv6 = makeSymbolicTensor(2); - fusion.addInput(tv6); - - auto tv7 = mul(tv5, tv6); - - fusion.addOutput(tv7); - - tv7->split(1, 2); - tv7->merge(0); - tv7->split(0, 128, false); - tv7->split(0, 4, false); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - tv7->axis(1)->parallelize(ParallelType::TIDx); - - // Reverse computeAt structure from previous test - tv0->computeAt(tv4, -1); - tv2->computeAt(tv4, -1); - tv0->computeAt(tv7, -1); - - const int numel_x = 100; - const int numel_y = 200; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({numel_x}, options); - auto t2 = at::randn({numel_x}, options); - auto t6 = at::randn({numel_x, numel_y}, options); - - auto t1 = t0.add(1.0); - auto t3 = t2.add(3.0); - auto t4 = t1.add(t3); - auto t5 = t1.unsqueeze(1); - auto t7 = t5.mul(t6); - - std::vector aten_inputs = {t0, t2, t6}; - std::vector aten_outputs = {t4, t7}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { - // Case 1 - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv1 + 3 - // tv4 = tv1 * 2 - // tv5 = tv3 + tv2 - // tv6 = tv5 + tv4 - // tv7 = tv1 + tv4 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); - TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); - TensorView* tv5 = add(tv3, tv2); - - TensorView* tv6 = add(tv5, tv4); - TensorView* tv7 = add(tv1, tv4); - - fusion.addOutput(tv6); - fusion.addOutput(tv7); - - // Lets setup to actually run - tv0->merge(0); - tv0->split(0, 128); - tv0->split(0, 4); - - tv0->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeWith(tv7, 1); - - GpuLower gpulw(&fusion); - - // The this-position of the last tensor should be zero. - TORCH_CHECK( - tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && - tv7->getMaxProducerPosition() == 1); - TORCH_CHECK( - tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && - tv6->getMaxProducerPosition() == 1); - - ComputeAtMap ca_map(&fusion); - - // The position of every other tensor should be 1. - for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { - TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - TORCH_CHECK( - ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE)); - } - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({129, 127}, options); - - auto t1 = aten_input.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t1.add({3.0}); - auto t4 = t1.mul({2.0}); - auto t5 = t3.add(t2); - auto t6 = t5.add(t4); - auto t7 = t1.add(t4); - - std::vector aten_outputs = {t6, t7}; - std::vector cg_outputs = { - at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { - // Case 2 - // tv1 = tv0 * -1 - // tv2 = tv0 + 3 - // tv3 = tv0 * 2 - // tv4 = tv2 + tv1 - // tv5 = tv4 + tv3 - // tv6 = tv5 + tv3 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); - TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); - TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); - TensorView* tv4 = add(tv2, tv1); - - TensorView* tv5 = add(tv4, tv3); - TensorView* tv6 = add(tv5, tv3); - - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - // Lets setup to actually run - tv0->merge(0); - tv0->split(0, 128); - tv0->split(0, 4); - - tv0->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeWith(tv6, 1); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({129, 127}, options); - - auto t1 = input.mul({-1.0}); - auto t2 = input.add({3.0}); - auto t3 = input.mul({2.0}); - auto t4 = t2.add(t1); - auto t5 = t4.add(t3); - auto t6 = t5.add(t3); - - std::vector aten_outputs = {t5, t6}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { - // Case 3 - // T2 = T1 * 0.979361 - // T3 = T2 * T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); - TensorView* tv3 = mul(tv2, tv0); - - fusion.addOutput(tv3); - - // Lets setup to actually run - while (tv0->nDims() > 1) - tv0->merge(0); - tv0->split(0, 128); - tv0->split(0, 4); - - while (tv1->nDims() > 1) - tv1->merge(0); - tv1->split(0, 128); - tv1->split(0, 4); - - tv0->computeWith(tv3, 1); - tv1->computeWith(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t1.mul({0.979361}); - auto aten_output = t2.mul(t0); - - std::vector aten_inputs = {t0, t1}; - - at::Tensor cg_output = at::empty_like(t0, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { - // Case 4 - // T4 = T2 - T3 - // T5 = T1 + T4 - // T6 = T5 - T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - TensorView* tv2 = makeSymbolicTensor(4); - fusion.addInput(tv2); - - TensorView* tv3 = makeSymbolicTensor(4); - fusion.addInput(tv3); - - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); - - fusion.addOutput(tv6); - std::vector tvs = {tv0, tv1, tv2}; - for (auto tv : tvs) { - // Lets setup to actually run - while (tv->nDims() > 1) { - tv->merge(0); - } - tv->split(0, 128); - tv->split(0, 4); - tv->computeWith(tv6, 1); - } - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - at::Tensor t2 = at::rand_like(t0, options); - at::Tensor t3 = at::rand_like(t0, options); - - auto t4 = t2.sub(t3); - auto t5 = t1.add(t4); - auto aten_output = t5.sub(t0); - - std::vector aten_inputs = {t0, t1, t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { - // Case 5 - // tv2 = tv0 + 2.0 - // tv3 = tv1 * tv2 - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv2->merge(0); - tv2->split(-1, 8); - tv2->split(-1, 4); - - tv2->computeWith(tv3, 1); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t0.add(2.0); - auto aten_output = t1.mul(t2); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv2->merge(0); - tv2->split(-1, 8); - tv2->split(-1, 4); - tv3->merge(0); - tv3->split(-1, 8); - - tv2->computeWith(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t0.add(2.0); - auto aten_output = t1.mul(t2); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv2 * -2 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); - fusion.addOutput(tv2); - fusion.addOutput(tv3); - - // This computeAt will affect tv2 as well, even though tv2 is not in - // the data-flow path between tv1 and tv3. The reason is that tv1 is - // now computed at tv3, so tv2 must also be computed at the same - // location. Overall, what will happen is basically we merge - // expressions of all tensors and compute them in a single loop - // nest. - TensorView* computeAtTarget = tv3; - computeAtTarget->split(0, 128); - tv1->computeAt(computeAtTarget, 1); - - TensorView* affected_tensors[] = {tv1, tv2, tv3}; - for (auto tv : affected_tensors) { - TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); - } - - GpuLower gpulw(&fusion); - - TORCH_CHECK(tv1->getComputeAtPosition() == 1); - TORCH_CHECK( - tv2->getComputeAtPosition() == 0 && tv2->getMaxProducerPosition() == 1); - TORCH_CHECK( - tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1); - - ComputeAtMap ca_map(&fusion); - - // Note that tv2 is also computed at tv3. - for (auto tv : {tv1, tv2}) { - TORCH_CHECK(ca_map.areMapped( - tv->axis(0), computeAtTarget->axis(0), IdMappingMode::PERMISSIVE)); - } - - TORCH_CHECK(tv3->getComputeAtPosition() == 0); - - computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); - for (auto tv : affected_tensors) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({1000}, options); - - auto t1 = aten_input * 0.5; - auto t2 = t1 * -1.0; - auto t3 = t1 * -2.0; - - std::vector aten_outputs = {t2, t3}; - - std::vector cg_outputs = { - at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -// Similar to ComputeAtMultiConsumers, but with a common consumer. -TEST_F(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv2 * -2 - // tv4 = tv2 + tv3 - // tv5 = tv4 * 5 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); - TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); - fusion.addOutput(tv3); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - // Computing tv1 at tv3. This will affect tv2 as discussed in - // ComplexComputeAt1. Additionally, in this case, notice that tv4 is - // the common consumer of tv2 and tv3, so they are computed at - // tv4. The indirect propagation of the computeAt should stop at the - // common consumer, and no further change should occur. More - // specifically, the computeAT position of tv4 and tv5 should be zero. - TensorView* computeAtTarget = tv3; - computeAtTarget->split(0, 128); - tv1->computeAt(computeAtTarget, 1); - - TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4}; - for (auto tv : affected_tensors) { - TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); - } - - TORCH_CHECK(tv1->getComputeAtPosition() == 1); - TORCH_CHECK(tv2->getComputeAtPosition() == 1); - TORCH_CHECK(tv3->getComputeAtPosition() == 1); - TORCH_CHECK(tv4->getComputeAtPosition() == 0); - TORCH_CHECK(tv5->getComputeAtPosition() == 0); - - computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); - - for (auto tv : affected_tensors) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - // Transform tv5 to make it look like the rest - tv5->split(0, 128); - tv5->axis(1)->parallelize(ParallelType::TIDx); - tv5->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({1000}, options); - - auto t1 = aten_input * 0.5; - auto t2 = t1 * -1.0; - auto t3 = t1 * -2.0; - auto t4 = t2 + t3; - auto t5 = t4 * 5.0; - - std::vector aten_outputs = {t3, t4, t5}; - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv2 * -1 - // tv4 = tv1 + 4 - // tv5 = tv3 + tv4 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); - TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); - TensorView* tv5 = add(tv3, tv4); - - fusion.addOutput(tv5); - - TensorView* computeAtTarget = tv3; - - computeAtTarget->merge(0); - computeAtTarget->split(0, 128); - computeAtTarget->split(0, 4); - - computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); - - // This computeAt will affect all tensors including tv3, tv4 and - // tv5, even though it appears to impact only tv1 and tv2. The - // reason is that tv1 is now computed at tv3, so tv4 must also be - // computed at the same location. Similarly, the consumer of tv4, - // tv5, must also be computed at the same location. Overall, what - // will happen is basically we merge expressions of all tensors and - // compute them in a single loop nest. Internally, this will be - // realized by making all tensors, except for those in the path - // between tv1 and tv3, computed at tv5, which we call the common - // consumer. - tv1->computeAt(computeAtTarget, 1); - - // All tensors should have the same dimenionality as the target - for (Val* val : fusion.vals()) { - if (val->isFusionInput() || - val->getValType().value() != ValType::TensorView) { - continue; - } - TensorView* tv = val->as(); - TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); - if (tv == tv5) { - TORCH_CHECK(tv->getComputeAtPosition() == 0); - } else { - TORCH_CHECK(tv->getComputeAtPosition() == 1); - } - } - - for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (!tv->isFusionInput()) { - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({129, 127}, options); - - auto t1 = aten_input.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t2.mul({-1.0}); - auto t4 = t1.add({4.0}); - auto aten_output = t3 + t4; - - at::Tensor cg_output = at::empty_like(aten_input, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -// Similar to the above common consumer test but adds an additional -// tensor that has no common consumer with the other tensors. -TEST_F(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv2 * -1 - // tv4 = tv1 + 4 - // tv5 = tv2 + tv3 - // tv6 = tv1 + 6 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); - TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); - TensorView* tv5 = add(tv3, tv4); - TensorView* tv6 = add(tv1, IrBuilder::create(6.0)); - - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - TensorView* computeAtTarget = tv3; - - computeAtTarget->merge(0); - computeAtTarget->split(0, 128); - computeAtTarget->split(0, 4); - - computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); - - // This will have the same impact on the tensors except for tv5 and - // tv6. tv6 does not have any common consumer with the computeAt - // target, but since it uses tv1, it must be also computed at the - // same location as the other impacted tensors. We can either make - // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5 - // should be computed at tv6 just because the current implementation - // orders the computeAt relationship based on the order in which - // tensors are specified as outputs. - - tv1->computeAt(computeAtTarget, 1); - - // All tensors should have the same dimenionality as the target - for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (tv->isFusionInput()) { - continue; - } - TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); - if (tv == tv5 || tv == tv6) { - TORCH_CHECK(tv->getComputeAtPosition() == 0); - TORCH_CHECK(tv->getMaxProducerPosition() == 1); - } else { - TORCH_CHECK(tv->getComputeAtPosition() == 1); - } - } - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = val->as(); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({129, 127}, options); - - auto t1 = aten_input.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t2.mul({-1.0}); - auto t4 = t1.add({4.0}); - auto t5 = t3 + t4; - auto t6 = t1.add({6.0}); - - std::vector aten_outputs = {t5, t6}; - std::vector cg_outputs = { - at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor -// that does not have data dependency with the consumer. -TEST_F(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv1 * -2 - // tv4 = tv2 + tv3 - // tv5 = tv4 * 5 - // tv6 = tv1 * 6 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); - TensorView* tv4 = add(tv2, tv3); - TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); - // Notice that tv6 is not a consumer of tv4. - TensorView* tv6 = mul(tv1, IrBuilder::create(6.0)); - fusion.addOutput(tv3); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - TensorView* computeAtTarget = tv3; - computeAtTarget->split(0, 128); - tv1->computeAt(computeAtTarget, 1); - - TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv5, tv6}; - for (auto tv : affected_tensors) { - TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); - if (tv == tv6 || tv == tv5) { - TORCH_CHECK(tv->getComputeAtPosition() == 0); - } else { - TORCH_CHECK(tv->getComputeAtPosition() == 1); - } - } - - computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); - - for (auto tv : affected_tensors) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({1000}, options); - - auto t1 = aten_input * 0.5; - auto t2 = t1 * -1.0; - auto t3 = t1 * -2.0; - auto t4 = t2 + t3; - auto t5 = t4 * 5.0; - auto t6 = t1 * 6.0; - - std::vector aten_outputs = {t3, t4, t5, t6}; - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -namespace { - -void checkIdMapped( - ComputeAtRootDomainMap& root_map, - TensorView* v0, - IterDomain* id0, - TensorView* v1, - IterDomain* id1, - bool should_map) { - if (should_map) { - TORCH_CHECK( - root_map.canMap(v0->domain(), id0, v1->domain(), id1), - "Should be mappable: ", - id0, - " of ", - v0, - " and ", - id1, - " of ", - v1); - } else { - TORCH_CHECK( - !root_map.canMap(v0->domain(), id0, v1->domain(), id1), - "Should not be mappable: ", - id0, - " of ", - v0, - " and ", - id1, - " of ", - v1); - } -} - -void checkIdMapped( - TensorView* v0, - const std::vector& root0, - const std::vector should_map0, - TensorView* v1, - const std::vector& root1, - const std::vector should_map1) { - ComputeAtRootDomainMap map; - map.build(); - TORCH_INTERNAL_ASSERT(root0.size() == should_map0.size()); - TORCH_INTERNAL_ASSERT(root1.size() == should_map1.size()); - size_t idx0 = 0; - for (const auto i : c10::irange(root0.size())) { - size_t idx1 = 0; - for (const auto j : c10::irange(root1.size())) { - if (should_map0[i] && should_map1[j] && idx0 == idx1) { - checkIdMapped(map, v0, root0[i], v1, root1[j], true); - } else { - checkIdMapped(map, v0, root0[i], v1, root1[j], false); - } - if (should_map1[j]) - ++idx1; - } - if (should_map0[i]) - ++idx0; - } -} - -void checkIdMapped( - TensorView* v0, - const std::vector& root0, - TensorView* v1, - const std::vector& root1) { - checkIdMapped( - v0, - root0, - std::vector(root0.size(), true), - v1, - root1, - std::vector(root1.size(), true)); -} - -} // namespace - -TEST_F(NVFuserTest, FusionRootMappingBasic_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv3 = broadcast(tv0, {true, false, false}); - auto tv4 = broadcast(tv1, {false, true, false}); - auto tv5 = add(tv3, tv4); - fusion.addOutput(tv5); - - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv4, - tv4->getRootDomain(), - {false, true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, true}, - tv4, - tv4->getRootDomain(), - {true, false, true}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {false, true}, - tv1, - tv1->getRootDomain(), - {false, true}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv5, - tv5->getRootDomain(), - {false, true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, true}, - tv5, - tv5->getRootDomain(), - {true, false, true}); - checkIdMapped(tv3, tv3->getRootDomain(), tv4, tv4->getRootDomain()); - checkIdMapped(tv3, tv3->getRootDomain(), tv5, tv5->getRootDomain()); - checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain()); -} - -TEST_F(NVFuserTest, FusionRootMappingRfactor_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // [I,I] - TensorView* tv0 = makeSymbolicTensor(2); - // [I,I,I] - TensorView* tv1 = makeSymbolicTensor(3); - - //[I,I,R] - auto tv2 = sum(tv1, {2}); - auto tv3 = add(tv2, tv0); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv3); - - // scheduling: - //[B,I,R0,R1=128], root = [B,I,R] - tv2->split(2, 128); - - // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf] - auto tv4 = tv2->rFactor({3}); - - checkIdMapped(tv1, tv1->getRootDomain(), tv4, tv4->getRootDomain()); - checkIdMapped( - tv4, - tv4->getRFactorDomain(), - {true, true, true, false}, - tv2, - tv2->getRootDomain(), - {true, true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, true, false}, - tv2, - tv2->getRootDomain(), - {true, true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, true, false}, - tv3, - tv3->getRootDomain(), - {true, true}); - checkIdMapped( - tv2, - tv2->getRootDomain(), - {true, true, false}, - tv3, - tv3->getRootDomain(), - {true, true}); - checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain()); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv1, - tv1->getRootDomain(), - {true, true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv2, - tv2->getRootDomain(), - {true, true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv4, - tv4->getRFactorDomain(), - {true, true, false, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, true}, - tv4, - tv4->getRootDomain(), - {true, true, false}); -} - -TEST_F(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - fusion.addOutput(tv2); - - // The second dimension cannot be mapped as it would require recomputation. - checkIdMapped(tv0, tv0->getRootDomain(), tv1, tv1->getRootDomain()); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); -} - -TEST_F(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv1, - tv1->getRootDomain(), - {true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); -} - -TEST_F(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - fusion.addOutput(tv2); - - tv1->split(-1, 4); - auto tv3 = tv1->rFactor({-2}); - - checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain()); - checkIdMapped( - tv3, - tv3->getMaybeRFactorDomain(), - {true, false, true}, - tv1, - tv1->getRootDomain(), - {true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); -} - -TEST_F(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - tv1->split(-1, 4); - auto tv4 = tv1->rFactor({-2}); - - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv4, - tv4->getRootDomain(), - {true, false}); - checkIdMapped( - tv4, - tv4->getMaybeRFactorDomain(), - {true, false, true}, - tv1, - tv1->getRootDomain(), - {true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); -} - -// Reproducer of issue #749 -TEST_F(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - auto tv3 = broadcast(tv2, {false, true}); - auto tv4 = add(tv0, tv3); - auto tv5 = add(tv4, tv1); - fusion.addOutput(tv5); - - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv1, - tv1->getRootDomain(), - {true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv2, - tv2->getRootDomain(), - {true, false}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped( - tv3, - tv3->getRootDomain(), - {true, true}, - tv4, - tv4->getRootDomain(), - {true, true}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv4, - tv4->getRootDomain(), - {true, false}); - checkIdMapped( - tv4, - tv4->getRootDomain(), - {true, true}, - tv5, - tv5->getRootDomain(), - {true, true}); -} - -// Similar to RootMappingReductionDependency5 but with rFactor -TEST_F(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - auto tv3 = broadcast(tv2, {false, true}); - auto tv4 = add(tv0, tv3); - auto tv5 = add(tv4, tv1); - fusion.addOutput(tv5); - - tv2->split(1, 4); - auto tv6 = tv2->rFactor({-1}); - - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv1, - tv1->getRootDomain(), - {true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv6, - tv6->getRootDomain(), - {true, false}); - checkIdMapped( - tv6, - tv6->getMaybeRFactorDomain(), - {true, true, false}, - tv2, - tv2->getRootDomain(), - {true, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv2, - tv2->getRootDomain(), - {true, false}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped( - tv3, - tv3->getRootDomain(), - {true, true}, - tv4, - tv4->getRootDomain(), - {true, true}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true, false}, - tv4, - tv4->getRootDomain(), - {true, false}); - checkIdMapped( - tv4, - tv4->getRootDomain(), - {true, true}, - tv5, - tv5->getRootDomain(), - {true, true}); -} - -TEST_F(NVFuserTest, FusionRootMappingMultipleBroadcast_CUDA) { - if (at::cuda::getCurrentDeviceProperties()->major >= 8) { - GTEST_SKIP() << "Somehow it fails on sm_80+ GPUs" - << " See https://github.com/pytorch/pytorch/issues/86717"; - } - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - // tv0 cannot be mapped with the consumers as it would mean its only - // domain would be mapped to both the first and second domains of - // the two consumers, thus computing tv0 at both corresponding loops. - checkIdMapped( - tv0, - tv0->getRootDomain(), - {false}, - tv1, - tv1->getRootDomain(), - {false, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {false}, - tv2, - tv2->getRootDomain(), - {false, false}); - checkIdMapped(tv1, tv1->getRootDomain(), tv3, tv3->getRootDomain()); - checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {false}, - tv3, - tv3->getRootDomain(), - {false, false}); -} - -TEST_F( - NVFuserTest, - FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = broadcast(tv0, {true, false}); - fusion.addOutput(tv1); - fusion.addOutput(tv2); - - // If there is no common consumer, there is no recomputation constraint. - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv1, - tv1->getRootDomain(), - {true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv2, - tv2->getRootDomain(), - {false, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {false, true}); -} - -TEST_F(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - auto tv3 = broadcast(tv0, {false, true}); - auto tv4 = add(tv1, tv3); - fusion.addOutput(tv4); - auto tv5 = add(tv2, tv3); - fusion.addOutput(tv5); - - // Broadcast domains can be used with multiple domains with - // different sizes. In this test, the broadcast domain of tv3 has - // two consumers, tv4 and tv5, which may have different sizes. Each - // of the consumers is used with the broadcast domain of tv3, but - // the two consumers may not have the same size, it is not possible - // to map those domains. - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv1, - tv1->getRootDomain(), - {true, false}); - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv2, - tv2->getRootDomain(), - {true, false}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, false}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped( - tv2, - tv2->getRootDomain(), - {true, false}, - tv3, - tv3->getRootDomain(), - {true, false}); - checkIdMapped( - tv3, - tv3->getRootDomain(), - {true, false}, - tv4, - tv4->getRootDomain(), - {true, false}); - checkIdMapped( - tv3, - tv3->getRootDomain(), - {true, false}, - tv5, - tv5->getRootDomain(), - {true, false}); - checkIdMapped( - tv4, - tv4->getRootDomain(), - {true, false}, - tv5, - tv5->getRootDomain(), - {true, false}); -} - -TEST_F(NVFuserTest, FusionRootMappingBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - // tv0[I0] - fusion.addInput(tv0); - auto tv1 = broadcast(tv0, {true, false}); - // tv1[B1, I0] - auto tv2 = broadcast(tv1, {true, false, false}); - // tv2[B2, B1, I0] - fusion.addOutput(tv2); - - // In this case, tv1 and tv2 has one and two broadcast domains, - // respectively. It is the second broadcast domain that is mapped to - // the broadcast of tv1. - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv1, - tv1->getRootDomain(), - {false, true}); - checkIdMapped( - tv1, - tv1->getRootDomain(), - {true, true}, - tv2, - tv2->getRootDomain(), - {false, true, true}); // Not {true, false, true} - checkIdMapped( - tv0, - tv0->getRootDomain(), - {true}, - tv2, - tv2->getRootDomain(), - {false, false, true}); -} - -// Reproducer of issue #723 -TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - auto tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = sum(tv2, {0}); - auto tv4 = add(tv2, tv1); - - fusion.addOutput(tv3); - fusion.addOutput(tv4); - - ComputeAtRootDomainMap map; - map.build(); - - checkIdMapped( - map, tv2, tv2->getRootDomain()[0], tv4, tv4->getRootDomain()[0], true); - checkIdMapped( - map, tv2, tv2->getRootDomain()[0], tv3, tv3->getRootDomain()[0], true); - - tv2->computeAt(tv4, -1); - - const int x = 11; - const int y = 12; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x}, options); - at::Tensor t1 = at::randn({y, x}, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t3 = t0; - auto t4 = t0.unsqueeze(0).expand({y, x}) + t1; - - testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); -} - -// Repro of issue #1950 -TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); - auto tv2 = makeSymbolicTensor(3); - - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - auto tv3 = set(tv0); - auto tv4 = mul(tv1, tv3); - auto tv5 = mul(tv1, tv2); - auto tv6 = mul(tv5, tv3); - auto tv7 = sum(tv6, {2}); - auto tv8 = broadcast(tv7, {false, false, true}); - auto tv9 = mul(tv3, tv8); - - // Issue #1950 was caused by a particular traversal ordering based - // on the output tensor ordering as below - fusion.addOutput(tv9); - fusion.addOutput(tv5); - fusion.addOutput(tv4); - - ComputeAtRootDomainMap root_map; - root_map.build(); - - checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false); -} - -TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) { - if (at::cuda::getCurrentDeviceProperties()->major >= 8) { - GTEST_SKIP() << "Somehow it does not throw on sm_80+" - << " See https://github.com/pytorch/pytorch/issues/86714"; - } - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = broadcast(tv1, {true, false}); - auto tv3 = broadcast(tv1, {false, true}); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - // computeAt should fail as there is no valid root mapping. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(tv1->computeAt(tv4, 1)); -} - -TEST_F(NVFuserTest, FusionScalarInputs_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - - Double* d0 = IrBuilder::create(); - fusion.addInput(d0); - Double* d1 = IrBuilder::create(); - fusion.addInput(d1); - Double* d2 = IrBuilder::create(); - fusion.addInput(d2); - Double* d3 = IrBuilder::create(); - fusion.addInput(d3); - Val* d4 = mul(d0, d1); - Val* d5 = sub(d2, d3); - - TensorView* tv2 = sub(tv1, d4); - TensorView* tv3 = add(tv0, d5); - TensorView* tv4 = mul(tv3, tv2); - - fusion.addOutput(tv4); - - // Lets setup to actually run - while (tv4->nDims() > 1) - tv4->merge(0); - tv4->split(0, 128); - tv4->split(0, 4); - - tv0->computeAt(tv4, 1); - tv1->computeAt(tv4, 1); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - // d4 = d0 * d1 - // d5 = d2 - d3 - // t2 = t1 - d4 - // t3 = t0 + d5 - // t4 = t3 * t2 - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - float fl0 = 0.1; - float fl1 = -0.2; - float fl2 = 0.3; - float fl3 = -0.4; - float fl4 = fl0 * fl1; - float fl5 = fl2 - fl3; - - at::Tensor t0 = at::randn({129, 127}, options); - at::Tensor t1 = at::rand_like(t0, options); - - auto t2 = t1.sub(fl4); - auto t3 = t0.add(fl5); - auto aten_output = t3.mul(t2); - - at::Tensor cg_output = at::empty_like(t0, options); - - at::Scalar test(fl0); - - std::vector aten_inputs = { - t0, - t1, - at::Scalar(fl0), - at::Scalar(fl1), - at::Scalar(fl2), - at::Scalar(fl3)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionLoopUnroll_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(3); - TensorView* tv1 = makeSymbolicTensor(3); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - int block_size = 16; - - tv3->merge(0, 1); - tv3->merge(0, 1); - - tv3->split(0, block_size); - tv3->split(0, 4); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - // Parallelize - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input0 = at::randn({129, 13, 3}, options); - at::Tensor input1 = at::randn({129, 13, 3}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1}); - auto outputs = fe.runFusion({input0, input1}); - - TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0)))); -} - -/* - * Helper function for single op testing that generates a codegen operand - */ - -Val* gen_jit_operand(std::pair desc) { - if (desc.first == ValType::TensorView) { - return makeSymbolicTensor(2, desc.second); - } else if (desc.first == ValType::Scalar) { - if (desc.second == DataType::Float) { - return IrBuilder::create(); - } else if (desc.second == DataType::Double) { - return IrBuilder::create(); - } else if (desc.second == DataType::ComplexFloat) { - return IrBuilder::create(); - } else if (desc.second == DataType::ComplexDouble) { - return IrBuilder::create(); - } else if (desc.second == DataType::Int) { - return IrBuilder::create(); - } else { - TORCH_CHECK(false, "Not currently supported type: ", desc.first); - } - } else { - TORCH_CHECK(false, "Not currently supported type: ", desc.first); - } - return nullptr; -} - -/* - * Helper function for single op testing that generates an ATen operand - */ - -IValue gen_aten_operand( - std::pair desc, - int blocks, - int threads, - bool rand) { - if (desc.first == ValType::TensorView) { - if (desc.second == DataType::Double || desc.second == DataType::Float || - desc.second == DataType::ComplexDouble || - desc.second == DataType::ComplexFloat || - desc.second == DataType::Half || desc.second == DataType::BFloat16) { - auto options = at::TensorOptions() - .dtype(data_type_to_aten(desc.second)) - .device(at::kCUDA, 0); - if (rand) { - return IValue(at::rand({blocks, threads}, options)); - } else { - return IValue(at::empty({blocks, threads}, options)); - } - } else if (desc.second == DataType::Int || desc.second == DataType::Int32) { - auto dtype = desc.second == DataType::Int32 ? at::kInt : at::kLong; - if (rand) { - auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - return IValue(at::randn({blocks, threads}, options).mul(5).to(dtype)); - } else { - auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - return IValue(at::empty({blocks, threads}, options)); - } - } else if (desc.second == DataType::Bool) { - if (rand) { - auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - return IValue( - at::rand({blocks, threads}, options).round().to(at::kBool)); - } else { - auto options = - at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0); - return IValue(at::empty({blocks, threads}, options)); - } - } else { - TORCH_CHECK(false, "Not currently supported type: ", desc.second) - } - } else if (desc.first == ValType::Scalar) { - // IValue scalars can only be double int64 or bool - if (desc.second == DataType::ComplexDouble || - desc.second == DataType::ComplexFloat) { - return IValue(at::Scalar(c10::complex(1.0, 0.0))); - } else if ( - desc.second == DataType::Double || desc.second == DataType::Float || - desc.second == DataType::Half || desc.second == DataType::BFloat16) { - return IValue(at::Scalar(1.0)); - } else if (desc.second == DataType::Int) { - return IValue(at::Scalar(1)); - } else { - TORCH_CHECK(false, "Not currently supported type: ", desc.first); - } - } else { - TORCH_CHECK(false, "Not currently supported type: ", desc.first); - } - return nullptr; -} - -/* - * Templatized Helper Function To generate single Op comparison between the - * JIT codegen for Cuda and the ATen Library. - */ - -using OutputPair = std::pair; -template < - typename AtenFunc, - typename JitFunc, - typename InputTuple, - size_t... NumInputs> -void test_op( - int blocks, - int threads, - std::string op_str, - AtenFunc af, - JitFunc jf, - OutputPair op, - InputTuple it, - std::index_sequence) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Generate Input JIT function Inputs and add them as Inputs to the Fusion - // Graph - std::array jit_inputs = { - gen_jit_operand(std::get(it))...}; - std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) { - fusion.addInput(v); - }); - TensorView* out = - static_cast(jf(std::get(jit_inputs)...)); - fusion.addOutput(out); - - std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) { - if (v->getValType() == ValType::TensorView) - static_cast(v)->computeAt(out, -1); - }); - out->axis(0)->parallelize(ParallelType::BIDx); - out->axis(-1)->parallelize(ParallelType::TIDx); - - std::array aten_inputs = {gen_aten_operand( - std::get(it), blocks, threads, /*rand*/ true)...}; - const at::ArrayRef aten_inputs_ivalues(aten_inputs); - - at::Tensor cg_output = - gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor(); - std::vector output_vect = {cg_output}; - cudaDeviceSynchronize(); - if (fusion.isStochastic()) - at::manual_seed(0); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs_ivalues); - fe.runFusion(aten_inputs_ivalues, output_vect); - cudaDeviceSynchronize(); - - if (fusion.isStochastic()) - at::manual_seed(0); - at::Tensor aten_output = af(aten_inputs); - cudaDeviceSynchronize(); // This sync shouldn't be necessary; - - std::string op_msg = "Operation " + op_str; - - testValidate( - &fusion, - {cg_output}, - aten_inputs, - {aten_output}, - __LINE__, - __FILE__, - op_msg); -} - -/* - * Templatized Helper Function that uses variadic templates to - * process a variable length Input Tuple of different Operand Type. - */ -template -void test_op( - int blocks, - int threads, - std::string op_str, - AtenFunc af, - JitFunc jf, - OutputPair op, - InputTuple it) { - static constexpr auto size = std::tuple_size::value; - test_op( - blocks, - threads, - op_str, - af, - jf, - op, - it, - std::make_index_sequence{}); -} - -TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { - using OpTuple = - std::tuple; - - // [Note: explicit tuple type for uniform initialization list] - // Tuple type must be explicitly specified for each uniform initialization - // list within the vector to make this code compatible with some old env - // which we still need to support. eg. gcc 5.4 + cuda 9.2. - std::vector ops{ - OpTuple{at::acos, UnaryOpType::Acos, "acos"}, - OpTuple{at::asin, UnaryOpType::Asin, "asin"}, - OpTuple{at::atan, UnaryOpType::Atan, "atan"}, - // There does not appear to be an appropriate ATen function for atanh - // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" }, - OpTuple{at::cos, UnaryOpType::Cos, "cos"}, - OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"}, - OpTuple{at::exp, UnaryOpType::Exp, "exp"}, - // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"}, - OpTuple{at::log, UnaryOpType::Log, "log"}, - OpTuple{at::log10, UnaryOpType::Log10, "log10"}, - OpTuple{at::neg, UnaryOpType::Neg, "neg"}, - OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"}, - OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"}, - OpTuple{at::sin, UnaryOpType::Sin, "sin"}, - OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"}, - OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"}, - OpTuple{at::tan, UnaryOpType::Tan, "tan"}, - OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"}, - OpTuple{at::isfinite, UnaryOpType::IsFinite, "isfinite"}, - OpTuple{at::isinf, UnaryOpType::IsInf, "isinf"}, - OpTuple{at::isnan, UnaryOpType::IsNan, "isnan"}, - OpTuple{at::isreal, UnaryOpType::IsReal, "isreal"}, - }; - - // The following ops has no complex support in eager mode - std::vector ops_without_complex{ - OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"}, - OpTuple{at::floor, UnaryOpType::Floor, "floor"}, - OpTuple{at::frac, UnaryOpType::Frac, "frac"}, - OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}, - OpTuple{at::round, UnaryOpType::Round, "round"}, - OpTuple{at::relu, UnaryOpType::Relu, "relu"}, - OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"}, - OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"}, - OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"}, - OpTuple{at::erf, UnaryOpType::Erf, "erf"}, - OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"}, - OpTuple{at::isneginf, UnaryOpType::IsNegInf, "isneginf"}, - OpTuple{at::isposinf, UnaryOpType::IsPosInf, "isposinf"}, - }; - - // The following ops only supports complex - std::vector ops_complex_only{ - // real is supported via UnaryOpType::Set for non-complex types, and - // UnaryOpType::Real requires input to be complex - OpTuple{at::real, UnaryOpType::Real, "real"}, - OpTuple{at::imag, UnaryOpType::Imag, "imag"}, - }; - - // Complex support for the following op is not working in nvFuser yet - std::vector ops_skip_complex{ - // TODO: abs is actually supported in nvFuser, but it has bug!!! - // In eager mode, abs(complex_tensor) returns floating point tensor - // but in nvFuser, it wrongly returns complex tensor! - // We need to: - // 1. change our type promotion logic to make a special case for abs - // 2. why this bug is not detected here? we should bump up test coverage - OpTuple{at::abs, UnaryOpType::Abs, "abs"}, - // TODO: the following two ops fails with compilation error like - // "undefined function rsqrt(complex)", we could implement them in - // helpers.cu, but I think it is better to check with Jiterator first, - // because Jiterator uses the same string for complex support. - OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"}, - OpTuple{at::log2, UnaryOpType::Log2, "log2"}}; - - std::vector dtypes = { - DataType::Float, - DataType::Double, - DataType::ComplexFloat, - DataType::ComplexDouble}; - - for (auto dtype : dtypes) { - auto ops_to_test = ops; - if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { - ops_to_test.insert( - ops_to_test.end(), - ops_without_complex.begin(), - ops_without_complex.end()); - ops_to_test.insert( - ops_to_test.end(), ops_skip_complex.begin(), ops_skip_complex.end()); - } else { - ops_to_test.insert( - ops_to_test.end(), ops_complex_only.begin(), ops_complex_only.end()); - } - std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) { - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ std::get<2>(op), - /*Aten Func */ - [&op](std::array& vals) { - return std::get<0>(op)(vals[0].toTensor()); - }, - /*JIT Func */ - [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); }, - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple(std::make_pair(ValType::TensorView, dtype))); - }); - } - - dtypes = {DataType::Int, DataType::Int32, DataType::Bool}; - for (auto dtype : dtypes) { - test_op( - /*blocks*/ 128, - /*threads*/ 64, - /*name*/ "bitwise_not", - /*Aten Func */ - [](std::array& vals) { - return at::bitwise_not(vals[0].toTensor()); - }, - /*JIT Func */ - [](Val* in1) -> Val* { return unaryOp(UnaryOpType::Not, in1); }, - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple(std::make_pair(ValType::TensorView, dtype))); - } -} - -TEST_F(NVFuserTest, FusionBinaryOps_CUDA) { - using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&); - using OpTuple = std::tuple; - - std::vector dtypes = { - DataType::Double, - DataType::Float, - DataType::ComplexFloat, - DataType::ComplexDouble}; - - // see [Note: explicit tuple type for uniform initialization list] - std::vector equal_ops{ - OpTuple{at::eq, BinaryOpType::Eq, "eq"}, - OpTuple{at::ne, BinaryOpType::NE, "ne"}}; - - // Complex numbers are not ordered - std::vector order_ops{ - OpTuple{at::ge, BinaryOpType::GE, "ge"}, - OpTuple{at::gt, BinaryOpType::GT, "gt"}, - OpTuple{at::le, BinaryOpType::LE, "le"}, - OpTuple{at::lt, BinaryOpType::LT, "lt"}}; - - // see [Note: explicit tuple type for uniform initialization list] - std::vector math_ops{ - OpTuple{at::div, BinaryOpType::Div, "div"}, - OpTuple{at::mul, BinaryOpType::Mul, "mul"}, - OpTuple{at::pow, BinaryOpType::Pow, "pow"}}; - - // The following ops has no complex support in eager mode - std::vector math_ops_without_complex{ - OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"}, - OpTuple{at::max, BinaryOpType::Max, "max"}, - OpTuple{at::min, BinaryOpType::Min, "min"}, - OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"}, - // NOTE: Remainder does not match the Aten impl exactly - // despite using an identical function. - OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"}}; - - for (auto dtype : dtypes) { - auto logic_ops = equal_ops; - if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { - logic_ops.insert(logic_ops.end(), order_ops.begin(), order_ops.end()); - } - std::for_each(logic_ops.begin(), logic_ops.end(), [&](OpTuple& op) { - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ std::get<2>(op), - /*Aten Func */ - [&op](std::array& vals) { - return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor()); - }, - /*JIT Func */ - [&op](Val* in1, Val* in2) -> Val* { - return binaryOp(std::get<1>(op), in1, in2); - }, - /*Output */ std::make_pair(ValType::TensorView, DataType::Bool), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype))); - }); - - auto enabled_math_ops = math_ops; - if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { - enabled_math_ops.insert( - enabled_math_ops.end(), - math_ops_without_complex.begin(), - math_ops_without_complex.end()); - } - std::for_each( - enabled_math_ops.begin(), enabled_math_ops.end(), [&](OpTuple& op) { - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ std::get<2>(op), - /*Aten Func */ - [&op](std::array& vals) { - return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor()); - }, - /*JIT Func */ - [&op](Val* in1, Val* in2) -> Val* { - return binaryOp(std::get<1>(op), in1, in2); - }, - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype))); - }); - - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "add_alpha", - /*Aten Func */ - [](std::array& vals) { - return at::add( - vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar()); - }, - /*JIT Func */ static_cast(&add_alpha), - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::Scalar, dtype))); - - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "sub_alpha", - /*Aten Func */ - [](std::array& vals) { - return at::sub( - vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar()); - }, - /*JIT Func */ static_cast(&sub_alpha), - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::Scalar, dtype))); - } -} - -TEST_F(NVFuserTest, FusionTernaryOps_CUDA) { - std::vector dtypes = { - DataType::Double, - DataType::Float, - DataType::ComplexFloat, - DataType::ComplexDouble}; - - for (auto dtype : dtypes) { - // clamp and threshold are not supported for complex on eager mode - if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "clamp", - /*Aten Func */ - [](std::array& vals) { - return at::clamp(vals[0].toTensor(), 0.f, 1.f); - }, - /*JIT Func */ - [&](Val* in1) -> Val* { - if (dtype == DataType::Float) { - return clamp( - in1, - IrBuilder::create(0.f), - IrBuilder::create(1.f)); - } else { - return clamp( - in1, - IrBuilder::create(0.f), - IrBuilder::create(1.f)); - } - }, - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple(std::make_pair(ValType::TensorView, dtype))); - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "threshold", - /*Aten Func */ - [](std::array& vals) { - return at::threshold(vals[0].toTensor(), 0.f, 1.f); - }, - /*JIT Func */ - [&](Val* in1) -> Val* { - if (dtype == DataType::Float) { - return threshold( - in1, - IrBuilder::create(0.f), - IrBuilder::create(1.f)); - } else { - return threshold( - in1, - IrBuilder::create(0.f), - IrBuilder::create(1.f)); - } - }, - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple(std::make_pair(ValType::TensorView, dtype))); - } - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "where", - /*Aten Func */ - [](std::array& vals) { - return at::where( - vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor()); - }, - /*JIT Func */ static_cast(&where), - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, DataType::Bool), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype))); - } -} - -TEST_F(NVFuserTest, FusionCompoundOps_CUDA) { - std::vector dtypes = { - DataType::Double, - DataType::Float, - DataType::ComplexFloat, - DataType::ComplexDouble}; - - for (auto dtype : dtypes) { - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "lerp", - /*Aten Func */ - [](std::array& vals) { - return at::lerp( - vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor()); - }, - /*JIT Func */ static_cast(&lerp), - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype))); - test_op( - /*blocks*/ 640, - /*threads*/ 64, - /*name*/ "addcmul", - /*Aten Func */ - [](std::array& vals) { - return at::addcmul( - vals[0].toTensor(), - vals[1].toTensor(), - vals[2].toTensor(), - vals[3].toScalar()); - }, - /*JIT Func */ - static_cast(&addcmul), - /*Output */ std::make_pair(ValType::TensorView, dtype), - /*Inputs Tuple*/ - std::make_tuple( - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::TensorView, dtype), - std::make_pair(ValType::Scalar, dtype))); - } -} - -TEST_F(NVFuserTest, FusionCastOps_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2, DataType::Half); - - TensorView* intrm1 = castOp(DataType::Float, tv0); - TensorView* out = castOp(DataType::Half, intrm1); - - fusion.addInput(tv0); - fusion.addOutput(out); - tv0->computeAt(out, -1); - - out->axis(0)->parallelize(ParallelType::BIDx); - out->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({1, 4}, options); - at::Tensor ref_output = at::empty_like(input1); - - std::array inputs = {input1}; - const at::ArrayRef input_ivalues(inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, input_ivalues); - auto outputs = fe.runFusion(input_ivalues); - - ref_output = at::_cast_Half(at::_cast_Double(input1)); - - TORCH_CHECK( - outputs[0].equal(ref_output), - "\nOp Type: -- ", - "cast FP16->FP32->FP16", - " -- had a mismatch.\n", - "\nABS MAX DIFF: ", - outputs[0].sub(ref_output).abs().max(), - "\n"); -} - -// Start off simple, block on the outer dim -// block stride + thread all reduce + unrolling on inner dim -TEST_F(NVFuserTest, FusionReduction1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, 128); - // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] - tv1->split(1, 4); - // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1] - // tv1[I0, R1oi{4}, R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] - - TensorView* tv3 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1] - // tv3[I0, R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] - // tv1[I0, R1i{128}] = tv3[I0, R1oi{4}, Ir1i{128}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv3, 1); - tv3->computeAt(tv1, 1); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 1); - - tv2->axis(2)->parallelize(ParallelType::Unroll); - tv1->axis(0)->parallelize(ParallelType::BIDx); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 65000; - int numel_y = 1025; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduction2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - - fusion.addOutput(tv1); - - // switches to try some different scenarios. maybe we should iterate on all - // permutations. - bool bind_bidx = true; - bool bind_tidx = true; - bool bind_tidy = true; - bool bind_unroll = true; - - int numel_x = 1025; // Cannot exceed block dim max size / tidy - int numel_y = 129; - int tidx = 16; - int tidy = 8; - int unroll_factor = 4; - - tv1->split(1, tidx); - // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1] - - tv1->split(1, unroll_factor); - // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1] - - tv1->split(0, tidy); - - TensorView* tv2 = tv1->rFactor({-3}); - // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] - // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}] - - TensorView* tv3 = tv1->rFactor({-2}); - // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] - // tv3[I0, R1oi{unroll}, Ir1i{tidx}] - // tv1[I0o, I0i{tidy}, R1i{tidx}] - - tv0->computeAt(tv1, -2); - - if (bind_unroll) - tv2->axis(-2)->parallelize(ParallelType::Unroll); - if (bind_bidx) - tv1->axis(0)->parallelize(ParallelType::BIDx); - if (bind_tidy) - tv1->axis(1)->parallelize(ParallelType::TIDy); - - if (bind_tidx) { - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduction3_CUDA) { - // What if Z participates in the reduction with X? - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - - fusion.addOutput(tv1); - - int numel_x = 1025; // Cannot exceed block dim max size / tidy - int numel_y = 129; - int tidx = 16; - int tidz = 8; - - tv1->split(1, tidz); - // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1] - - tv1->split(1, tidx); - // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({-3}); - // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}] - // tv1[I0o, R1oi{tidx}, R1i{tidz}] - - tv0->computeAt(tv1, -3); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(-2)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDz); - - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDz); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - auto aten_output = aten_input.to(at::kDouble).sum({1}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduction4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - TensorView* tv2 = add(tv0, tv1); - // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv3 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); - // tv3[I0, R1] = tv2[I0, I1] - - TensorView* tv4 = makeSymbolicTensor(1); - fusion.addInput(tv4); - - // tv5[I0] = tv3[I0, R1] * tv4[I0] - TensorView* tv5 = mul(tv3, tv4); - fusion.addOutput(tv5); - - int tidx = 16; - - // RFactor the reduction - tv3->split(1, tidx); - // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1] - - TensorView* tv6 = tv3->rFactor({-2}); - // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1] - // tv3[I0, R1i{tidx}] = tv3[I0, I1] - tv2->computeAt(tv6, 2); - - // Compute at inline with tv5 (only 1D) - tv6->computeAt(tv3, 1); - tv3->computeAt(tv5, 1); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - - // Intermediate tensors only need this, but doesn't hurt to do on inputs - // tv0, 1, 4 - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 1025; - int numel_y = 129; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - at::Tensor t1 = at::randn({numel_x, numel_y}, options); - at::Tensor t4 = at::randn({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1, t4}); - auto cg_outputs = fe.runFusion({t0, t1, t4}); - - auto t2 = t0.add(t1); - auto t3 = t2.to(at::kDouble).sum({1}); - auto aten_output = t3.mul(t4); - - testValidate( - &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduction5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(3); - - fusion.addInput(tv0); - - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - - fusion.addOutput(tv1); - - int bidy = 2; - int tidy = 4; - int tidx = 5; - - int dim1 = 11; - - tv1->split(-2, tidy); - - TensorView* tv2 = tv1->rFactor({-3}); - - tv0->computeAt(tv1, 1); - tv1->axis(0)->parallelize(ParallelType::BIDy); - - for (auto* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - val->as()->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - tv2->axis(-2)->parallelize(ParallelType::TIDy); - tv1->axis(-2)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({bidy, dim1, tidx}, options); - - at::Tensor cg_output = at::empty({bidy, tidx}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduction6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int bdimx = 64; - const int bdimy = 8; - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(3); - fusion.addInput(tv0); - - // tv1[I0, R1, R2] = tv0[I0, I1, I2] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(2, bdimx); - // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] - tv1->split(1, bdimy); - // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2] - - TensorView* tv2 = tv1->rFactor({3}); - // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] - // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] - - TensorView* tv3 = tv1->rFactor({1}); - // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] - // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] - // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}] - - tv3->computeAt(tv1, 1); - tv2->computeAt(tv3, 2); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(-2)->parallelize(ParallelType::TIDy); - tv3->axis(-2)->parallelize(ParallelType::TIDy); - tv2->axis(-3)->parallelize(ParallelType::TIDy); - - int numel_x = 650; - int numel_y = 1000; - int numel_z = 4; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({1, 2}); - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMultiGridReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = max(tv0, {0}); - TensorView* tv2 = sum(tv0, {0}); - - fusion.addOutput(tv1); - fusion.addOutput(tv2); - - int numel_x = 4; - int numel_y = 2; - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - std::vector aten_outputs = { - std::get<0>(input.to(at::kDouble).max(0)), input.to(at::kDouble).sum(0)}; - testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMultiGridReduction2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {0}); - auto tv2 = sum(tv1, {0}); - fusion.addOutput(tv2); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(0)->parallelize(ParallelType::BIDy); - - FusionExecutor fe; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionReductionTFT_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - - fusion.addOutput(tv1); - - int numel_x = 1025; - int numel_y = 129; - int tidx = 16; - int tidy = 8; - int tidz = 8; - - tv1->split(1, tidx); - // tv1[I0, R1o, R1i{tidx}] - - tv1->split(1, tidz); - // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}] - - tv1->split(0, tidy); - // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}] - - TensorView* tv2 = tv1->rFactor({2}); - // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}] - // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}] - - tv2->computeAt(tv1, 2); - - tv1->axis(1)->parallelize(ParallelType::TIDy); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(-2)->parallelize(ParallelType::TIDz); - tv2->axis(-2)->parallelize(ParallelType::TIDz); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReductionOuterSplit_CUDA) { - // based off FusionReduction4 - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - TensorView* tv2 = add(tv0, tv1); - // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv3 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); - // tv3[I0, R1] = tv2[I0, I1] - - TensorView* tv4 = makeSymbolicTensor(1); - fusion.addInput(tv4); - - // tv5[I0] = tv3[I0, R1] * tv4[I0] - TensorView* tv5 = mul(tv3, tv4); - fusion.addOutput(tv5); - - // RFactor the reduction - tv3->split(1, 16, false); - // tv3[I0, R1o{16}, R1i{tidx}] = tv2[I0, I1] - - TensorView* tv6 = tv3->rFactor({-2}); - // tv6[I0, R1o{16}, iR1i{tidx}] = tv2[I0, I1] - // tv3[I0, R1i{tidx}] = tv3[I0, I1] - tv2->computeAt(tv6, 2); - - // Compute at inline with tv5 (only 1D) - tv6->computeAt(tv3, 1); - tv3->computeAt(tv5, 1); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - - // Intermediate tensors only need this, but doesn't hurt to do on inputs - // tv0, 1, 4 - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 1025; - int numel_y = 129; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - at::Tensor t1 = at::randn({numel_x, numel_y}, options); - at::Tensor t4 = at::randn({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1, t4}); - auto cg_outputs = fe.runFusion({t0, t1, t4}); - - auto t2 = t0.add(t1); - auto t3 = t2.to(at::kDouble).sum({1}); - auto aten_output = t3.mul(t4); - - testValidate( - &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBranches_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - TensorView* tv2 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - auto tv3 = add(tv0, IrBuilder::create(1.0)); - auto tv4 = add(tv3, tv1); - auto tv5 = add(tv3, tv2); - auto tv6 = add(tv4, tv5); - - fusion.addOutput(tv6); - - constexpr int x = 63, y = 33; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({x, y}, options); - at::Tensor t1 = at::randn({x, y}, options); - at::Tensor t2 = at::randn({x, y}, options); - - FusionExecutor fe; - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv6, 1); - tv1->computeAt(tv6, 1); - tv2->computeAt(tv6, 1); - - tv3->axis(-2)->parallelize(ParallelType::Unroll); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-2)->parallelize(ParallelType::Unroll); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv5->axis(-2)->parallelize(ParallelType::Unroll); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - std::vector aten_inputs = {t0, t1, t2}; - - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t3 = t0.add(1.0); - auto t4 = t3.add(t1); - auto t5 = t3.add(t2); - auto aten_output = t4.add(t5); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1.5)); - - TensorView* tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - TensorView* tv3 = makeSymbolicTensor(2); - fusion.addInput(tv3); - TensorView* tv4 = sub(tv2, tv3); - - TensorView* tv5 = broadcast(tv1, {false, false, true}); - TensorView* tv6 = broadcast(tv4, {true, false, false}); - - TensorView* tv7 = add(tv5, tv6); - fusion.addOutput(tv7); - - tv7->split(-1, 4); - tv7->split(0, 8); - - tv0->computeAt(tv7, -1); - tv2->computeAt(tv7, -1); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - tv7->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int x = 63, y = 33, z = 15; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({x, y}, options); - at::Tensor t1 = t0.add(1.5); - - at::Tensor t2 = at::randn({y, z}, options); - at::Tensor t3 = at::randn({y, z}, options); - - at::Tensor t4 = t2.sub(t3); - at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z}); - - at::Tensor t6 = t4.expand({x, y, z}); - - at::Tensor aten_output = t5.add(t6); - - std::vector aten_inputs = {t0, t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleBCast2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - - TensorView* tv3 = broadcast(tv2, {false, false, true}); - - TensorView* tv4 = makeSymbolicTensor(2); - fusion.addInput(tv4); - - TensorView* tv5 = sub(tv4, IrBuilder::create(0.1)); - - TensorView* tv6 = broadcast(tv5, {true, false, false}); - - TensorView* tv7 = add(tv3, tv6); - - fusion.addOutput(tv7); - - tv7->merge(0, 1); - - tv0->computeAt(tv7, -1); - tv4->computeAt(tv7, -1); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - tv7->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int x = 63, y = 33, z = 15; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({x, y}, options); - at::Tensor t1 = at::randn({x, y}, options); - at::Tensor t2 = t0.add(t1); - at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z}); - - at::Tensor t4 = at::randn({y, z}, options); - at::Tensor t5 = t4.sub(0.1); - at::Tensor t6 = t5.expand({x, y, z}); - at::Tensor aten_output = t3.add(t6); - - at::Tensor cg_output = at::empty({x, y, z}, options); - - std::vector aten_inputs = {t0, t1, t4}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up input tensor views - // tv0[I1, B{1}] - TensorView* tv0 = makeConcreteTensor({-1, 1}); - fusion.addInput(tv0); - - // tv1[I0, I1, I2] - TensorView* tv2 = makeSymbolicTensor(3); - fusion.addInput(tv2); - - TensorView* tv3 = add(tv0, tv2); - - fusion.addOutput(tv3); - - tv3->merge(0); - tv3->merge(0); - - tv0->computeAt(tv3, -1); - tv2->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - constexpr int x = 2, y = 3, z = 4; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({y, 1}, options); - at::Tensor t2 = at::randn({x, y, z}, options); - auto aten_output = t0.add(t2); - - std::vector aten_inputs = {t0, t2}; - at::Tensor cg_output = at::empty({x, y, z}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({1, -1}); - - TensorView* tv1 = makeSymbolicTensor(3); - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv3 = add(tv0, tv1); - - tv3->merge(0); - tv3->merge(0); - tv3->split(0, 128); - tv3->split(0, 4); - - fusion.addOutput(tv3); - - tv0->computeAt(tv3, -1); - tv1->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-2)->parallelize(ParallelType::Unroll); - - constexpr int x = 63, y = 33, z = 15; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({1, z}, options); - at::Tensor t1 = at::randn({x, y, z}, options); - - auto aten_output = t0.add(t1); - - at::Tensor cg_output = at::empty({x, y, z}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int m = 2, k = 3, n = 4; - auto tv0 = makeConcreteTensor({m, k}); - auto tv1 = makeConcreteTensor({k, n}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = broadcast(tv0, {false, false, true}); - TensorView* tv3 = broadcast(tv1, {true, false, false}); - - TensorView* tv4 = add(tv2, tv3); - - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->merge(0); - - tv0->computeAt(tv4, -1); - tv1->computeAt(tv4, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({m, k}, options); - at::Tensor t1 = at::randn({k, n}, options); - - auto t2 = t0.unsqueeze(-1).expand({m, k, n}); - auto t3 = t1.expand({m, k, n}); - auto aten_output = t2.add(t3); - - at::Tensor cg_output = at::empty({m, k, n}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComplexBCast1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int x = 2, y = 3, z = 4; - - auto tv0 = makeConcreteTensor({y}); - auto tv1 = div(tv0, IrBuilder::create(2.0)); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = makeConcreteTensor({y, z}); - auto tv4 = mul(tv2, tv3); - auto tv5 = broadcast(tv4, {true, false, false}); - auto tv6 = makeConcreteTensor({x, y, z}); - auto tv7 = add(tv5, tv6); - - // tv0[ i1 ] = input - // tv1[ i1 ] = tv0/2.0 - // tv2[ i1, b2] = bcast(tv1) - // tv3[ i1, i2] = input - // tv4[ i1, i2] = tv2 * tv3 - // tv5[b0, i1, i2] = bcast(tv4) - // tv6[i0, i1, i2] = input - // tv7[i0, i1, i2] = tv5 + tv6 - - // tv4 = bcast(tv1) * tv3 - // tv7 = bcast(tv4) + tv6 - - fusion.addInput(tv0); - fusion.addInput(tv3); - fusion.addInput(tv6); - - fusion.addOutput(tv7); - - tv7->merge(0); - tv7->merge(0); - tv0->computeAt(tv7, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({y}, options); - at::Tensor t3 = at::randn({y, z}, options); - at::Tensor t6 = at::randn({x, y, z}, options); - - auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3; - auto aten_output = t4.unsqueeze(0).expand({x, y, z}) + t6; - - std::vector aten_inputs = {t0, t3, t6}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComplexBCast2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int x = 2, y = 3, z = 4; - - auto tv0 = makeConcreteTensor({y, z}); - auto tv1 = div(tv0, IrBuilder::create(2.0)); - auto tv2 = sum(tv1, {1}); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = makeConcreteTensor({x, y}); - auto tv5 = add(tv3, tv4); - - // tv0[ i1, i2] = input - // tv1[ i1, i2] = tv0/2.0 - // tv2[ i1 ] = sum(tv1, 1) - // tv3[b0, i1 ] = bcast(tv2) - // tv4[i0, i1 ] = input - // tv5[i0, i1 ] = tv3 + tv4 - - // tv2 = sum(tv0/2.0, 1) - // tv5 = bcast(tv2) + tv4 - - fusion.addInput(tv0); - fusion.addInput(tv4); - - fusion.addOutput(tv5); - - tv5->merge(0); - tv0->computeAt(tv5, -1); - tv1->computeAt(tv2, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({y, z}, options); - at::Tensor t4 = at::randn({x, y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t4}); - auto cg_outputs = fe.runFusion({t0, t4}); - - auto t1 = t0.div(2.0); - auto t2 = t1.to(at::kDouble).sum(1); - auto t3 = t2.unsqueeze(0).expand({x, y}); - auto aten_output = t3.add(t4); - - testValidate( - &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1.0)); - auto tv3 = broadcast(tv2, {true, false, false, false}); - auto tv4 = add(tv3, tv1); - - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->merge(0); - tv4->merge(0); - - tv4->split(0, 128); - tv4->split(0, 4); - - tv2->computeAt(tv4, 1); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::Unroll); - tv4->axis(2)->parallelize(ParallelType::TIDx); - - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(2)->parallelize(ParallelType::TIDx); - - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - FusionExecutor fe; - - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); - - auto t3 = t0.add(1.0); - auto aten_output = t3.add(t1); - - std::vector aten_inputs = {t0, t1}; - - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1.0)); - auto tv3 = broadcast(tv2, {true, false, false, false}); - auto tv4 = add(tv3, tv1); - - fusion.addOutput(tv4); - - tv4->merge(-2); - tv4->merge(-2); - tv4->merge(-2); - - tv4->split(0, 128); - tv4->split(0, 4); - - tv2->computeAt(tv4, 1); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::Unroll); - tv4->axis(2)->parallelize(ParallelType::TIDx); - - tv3->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(2)->parallelize(ParallelType::TIDx); - - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - FusionExecutor fe; - - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); - - auto t3 = t0.add(1.0); - auto aten_output = t3.add(t1); - - std::vector aten_inputs = {t0, t1}; - - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int w = 3, x = 4, y = 7, z = 8; - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1.0)); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x, y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); - - auto t2 = t0.add(1.0); - auto aten_output = t2.add(t1); - - std::vector aten_inputs = {t0, t1}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({4, 8}); - fusion.addInput(tv0); - TensorView* tv1 = makeConcreteTensor({4, 4, 8}); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, IrBuilder::create(1)); - TensorView* tv3 = broadcast(tv2, {true, false, false}); - TensorView* tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({4, 8}, options); - at::Tensor t1 = at::randn({4, 4, 8}, options); - - auto t2 = t0.add(1.0); - auto aten_output = t2.add(t1); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(3); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, IrBuilder::create(1)); - TensorView* tv3 = broadcast(tv2, {true, false, true}); - TensorView* tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv3->merge(0)->merge(0)->split(0, 2)->split(0, 3); - tv4->merge(0)->merge(0)->split(0, 2)->split(0, 3); - - tv0->computeAt(tv4, 1); - tv1->computeAt(tv4, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({7}, options); - at::Tensor t1 = at::randn({5, 7, 11}, options); - - auto t2 = t0.add(1.0); - auto aten_output = t2.unsqueeze(-1).add(t1); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector tensor0_shape{7, 4, 7}; - std::vector tensor1_shape{4, 7}; - - TensorView* tv0 = makeSymbolicTensor(tensor0_shape.size()); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(tensor1_shape.size()); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - TensorView* tv3 = sum(tv2, {0, 1}); - fusion.addOutput(tv3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input0 = at::randn(tensor0_shape, options); - at::Tensor input1 = at::randn(tensor1_shape, options); - - std::vector reduction_axes{0, 1}; - auto reduction_params = getReductionHeuristics(&fusion, {input0, input1}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1}, reduction_params->lparams); - auto cg_outputs = fe.runFusion({input0, input1}, reduction_params->lparams); - - auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes); - - testValidate( - &fusion, - cg_outputs, - {input0, input1}, - {aten_output}, - __LINE__, - __FILE__, - "", - reduction_params->lparams); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) { - // Might be able to use this one without 6 as the heuristics in 6 may change - // and this test is to cover the same issue. - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - - auto tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - - auto tv3 = add(tv1, tv2); - auto tv4 = sum(tv3, {0, 1}); - fusion.addOutput(tv4); - - tv4->merge(0, 1); - tv4->split(0, 128); - tv4->split(0, 4); - - auto tv5 = tv4->rFactor({0, 1}); - - tv5->computeAt(tv4, -1); - tv0->computeAt(tv5, -1); - - tv4->axis(0)->parallelize(ParallelType::TIDx); - - const int numel_x = 100; - const int numel_y = 200; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_t0 = at::randn({numel_x}, options); - auto at_t1 = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_t0, at_t1}); - auto cg_outputs = fe.runFusion({at_t0, at_t1}); - - auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) - .to(at::kDouble) - .sum(); - - testValidate( - &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing8_CUDA) { - // Same as 7 but with outer splits instead of inner - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - - auto tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - - auto tv3 = add(tv1, tv2); - auto tv4 = sum(tv3, {0, 1}); - fusion.addOutput(tv4); - - tv4->merge(0, 1); - tv4->split(0, 128, false); - tv4->split(0, 4, false); - - auto tv5 = tv4->rFactor({0, 1}); - - tv5->computeAt(tv4, -1); - tv0->computeAt(tv5, -1); - - tv4->axis(0)->parallelize(ParallelType::TIDx); - - const int numel_x = 100; - const int numel_y = 200; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_t0 = at::randn({numel_x}, options); - auto at_t1 = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_t0, at_t1}); - auto cg_outputs = fe.runFusion({at_t0, at_t1}); - - auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) - .to(at::kDouble) - .sum(); - - testValidate( - &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing9_CUDA) { - // Same as 7 but with outer splits instead of inner - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - - auto tv2 = mul(tv1, IrBuilder::create(2)); - fusion.addOutput(tv2); - - auto tv3 = makeSymbolicTensor(3); - fusion.addInput(tv3); - - auto tv4 = add(tv3, tv2); - fusion.addOutput(tv4); - - const int numel_x = 200; - const int numel_y = 300; - const int numel_z = 400; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_t0 = at::randn({numel_y}, options); - auto at_t3 = at::randn({numel_x, numel_y, numel_z}, options); - std::vector aten_inputs = {at_t0, at_t3}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - auto at_t1 = at_t0.unsqueeze(-1); - auto at_t2 = at_t1.mul(2.0); - - auto at_t4 = at_t3.add(at_t2); - - testValidate( - &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing10_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeContigTensor(2); - TensorView* tv1 = makeContigTensor(2); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - auto tv0_cache = tv0->cacheAfter(); - auto tv1_cache = tv1->cacheAfter(); - - std::vector tvs = {tv0_cache, tv1_cache, tv2, tv3}; - - for (auto tv : tvs) { - tv->split(1, 2, false); - tv->split(1, 1); - tv->split(-1, 4); - // [I0, 2, 1, I1/2/4, 4] - tv->reorder({{1, 2}, {2, 3}, {3, 1}}); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::TIDx); - } - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); - tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({64, 128}, options); - at::Tensor input2 = at::rand_like(input1); - at::Tensor output = at::empty_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - fe.runFusion({input1, input2}, {output}); - - at::Tensor tv2_ref = input2 + 2.0; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionAdvancedIndexing11_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int w = 3, x = 4, y = 7, z = 8; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - auto tv0 = makeSymbolicTensor(4); - auto tv1 = makeSymbolicTensor(1); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv1, IrBuilder::create(1.0)); - auto tv3 = broadcast(tv2, {true, false, true, true}); - auto tv4 = add(tv3, tv0); - - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->merge(1); - - tv4->split(1, 32); - tv4->split(0, 1); - - tv4->reorder({{2, 1}}); - - tv2->computeAt(tv4, 3); - - tv2->setMemoryType(MemoryType::Global); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::BIDy); - tv4->axis(2)->parallelize(ParallelType::Unswitch); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - FusionExecutor fe; - - at::Tensor t0 = at::randn({w, x, y, z}, options); - at::Tensor t1 = at::randn({x}, options); - - auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1); - auto aten_output = t3.add(t0); - - std::vector aten_inputs = {t0, t1}; - - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -// Intended to stress the lowering of our code generator -TEST_F(NVFuserTest, FusionAdvancedLowering1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({9, 5}); - fusion.addInput(tv0); - - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv1, IrBuilder::create(3)); - TensorView* tv4 = sum(tv3, {1}); - - fusion.addOutput(tv2); - fusion.addOutput(tv4); - - tv4->split(1, 4); - auto tv5 = tv4->rFactor({2}); - - tv1->computeAt(tv5, 2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(1); - at::Tensor aten_input = at::randn({9, 5}, options); - - auto t1 = aten_input.add(1.0); - auto t2 = t1.add(2.0); - auto t3 = t1.add(3.0); - auto t4 = t3.sum(1); - - std::vector aten_outputs = {t2, t4}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedLowering2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Progressively broadcast tensors - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = makeSymbolicTensor(3); - fusion.addInput(tv2); - - TensorView* tv3 = add(tv0, IrBuilder::create(1)); - TensorView* tv4 = broadcast(tv3, {false, true}); - TensorView* tv5 = add(tv4, tv1); - TensorView* tv6 = add(tv5, tv2); - - fusion.addOutput(tv6); - - // Split inner dimension - tv6->split(1, 4); - // Merge middle dims with outer dimensions - tv6->merge(2); - tv6->merge(0); - - // tv6[I0*I1o, I1i*I2] - - // Compute everything inline - tv0->computeAt(tv6, -1); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - tv6->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - int x = 13, y = 9, z = 5; - at::Tensor t0 = at::randn({y}, options); - at::Tensor t1 = at::randn({y, z}, options); - at::Tensor t2 = at::randn({x, y, z}, options); - - auto t3 = t0.add(1.0); - auto t4 = t3.unsqueeze(-1); - auto t5 = t4.add(t1); - auto t6 = t5.add(t2); - - std::vector aten_inputs = {t0, t1, t2}; - std::vector aten_outputs = {t6}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -// TODO: Complete test -TEST_F(NVFuserTest, FusionAdvancedLowering3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({1, -1}); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // [b0, i1] - auto tv2 = add(tv0, IrBuilder::create(2.0)); - - // [i0, i1] - auto tv3 = add(tv1, IrBuilder::create(3.0)); - - // [b0, i1] - auto tv4 = add(tv2, IrBuilder::create(4.0)); - - // [io, i1] - auto tv5 = add(tv2, tv3); - - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - tv0->computeAt(tv4, -1); - - tv3->setMemoryType(MemoryType::Global); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - int x = 13, y = 9; - at::Tensor t0 = at::randn({1, y}, options); - at::Tensor t1 = at::randn({x, y}, options); - - auto t4 = t0 + 2 + 4; - auto t5 = t0 + 2 + t1 + 3; - - std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t4, t5}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -// This excercises indexing with broadcast root axes. Non-broadcast -// axes need to be preferred when propagating index exprs to root -// axes. See, e.g., Index::getConsumerIndex_impl. -TEST_F(NVFuserTest, FusionAdvancedLowering4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = broadcast(tv1, {false, false, true}); - auto tv3 = makeSymbolicTensor(3); - fusion.addInput(tv3); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv4->merge(1)->merge(0); - tv4->split(0, 8); - tv0->computeAt(tv4, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 10; - const int by = 20; - const int bz = 30; - at::Tensor t0 = at::randn({bx}, options); - at::Tensor t3 = at::randn({bx, by, bz}, options); - std::vector aten_inputs = {t0, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = - t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3; - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedLowering5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({5, 4, 3}); - fusion.addInput(tv0); - - TensorView* tv1 = makeConcreteTensor({5, 3}); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv1, {false, true, false}); - - auto tv3 = add(tv0, tv2); - - fusion.addOutput(tv3); - - tv2->merge(0); - tv1->computeAt(tv2, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(1); - at::Tensor t0 = at::randn({5, 4, 3}, options); - at::Tensor t1 = at::randn({5, 3}, options); - auto t2 = t1.unsqueeze(1); - auto t3 = t0 + t2; - - std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedLowering6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({5, 4, 3}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({4}); - fusion.addInput(tv1); - auto tv2 = unaryOp(UnaryOpType::Set, tv0); - auto tv3 = unaryOp(UnaryOpType::Set, tv1); - - auto tv4 = sum(tv2, {0, 2}); - auto tv5 = add(tv4, tv3); - fusion.addOutput(tv5); - - auto tv6 = broadcast(tv3, {true, false, true}); - auto tv7 = add(tv2, tv6); - fusion.addOutput(tv7); - - tv2->computeAt(tv4, -1, ComputeAtMode::BestEffort); - tv3->computeAt(tv7, -1, ComputeAtMode::BestEffort); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(1); - at::Tensor t0 = at::randn({5, 4, 3}, options); - at::Tensor t1 = at::randn({4}, options); - - auto t2 = t0; - auto t3 = t1; - - std::vector reduction_axes{0, 2}; - auto t4 = t2.sum(reduction_axes); - auto t5 = add(t4, t3); - auto t6 = t3.unsqueeze(0).unsqueeze(-1); - auto t7 = t2.add(t6); - - std::vector aten_inputs = {t0, t1}; - std::vector aten_outputs = {t5, t7}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -// Test a simple Gemm but also play around with fusion executor features -TEST_F(NVFuserTest, FusionSimpleGemm_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); // M, K - TensorView* tv1 = makeSymbolicTensor(2); // K, N - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = broadcast(tv0, {false, false, true}); - // tv2[I0, I1, B] = tv0[I0, I1] - - TensorView* tv3 = broadcast(tv1, {true, false, false}); - // tv3[B, I1, I2] = tv1[I1, I2] - - // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2] - TensorView* tv4 = mul(tv2, tv3); - // tv5[I0, R1, I2] = tv4[I0, I1, I2] - TensorView* tv5 = sum(tv4, {1}); - fusion.addOutput(tv5); - - tv5->split(1, 32); - // tv5[I0, R1o, R1i{32}, I2] - - auto tv6 = tv5->rFactor({1}); - // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2] - // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2] - - tv5->split(0, 4); - tv5->split(-1, 4); - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - - tv0->computeAt(tv5, -1); - tv1->computeAt(tv5, -1); - - // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] - // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] - //--> (line symbolizes compute at location) - // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] - // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv0->computeAt(tv6, -1); - tv1->computeAt(tv6, -1); - // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] - // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::TIDz); - - tv5->axis(-2)->parallelize(ParallelType::BIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - - tv5->axis(2)->parallelize(ParallelType::TIDx); - tv6->axis(2)->parallelize(ParallelType::TIDx); - - constexpr int M = 65, K = 33, N = 17; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); - // Lets specify a few bounds in launch params to make sure it works - fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); - - // Make sure bad launch params throws - // TODO: Re-enable once we have parallelization validation in. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6))); - - // Don't specify any launch params - auto cg_outputs = fe.runFusion({t0, t1}); - - auto aten_output = t0.to(at::kDouble).matmul(t1.to(at::kDouble)); - - testValidate( - &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); -} - -// Softmax with a 1D tensor. Parallelized only with a single thread block. -TEST_F(NVFuserTest, FusionSoftmax1D_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int tidx = 128; - const int dimx = 1000; - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(1); - fusion.addInput(input_tv0); - - TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); - TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); - TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true}); - - // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be - // computed at sum_exp_rf_tv8. - TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); - - TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); - - fusion.addOutput(output_tv4); - - bcast_sum_tv3->split(0, tidx); - - sum_exp_tv2->split(-1, tidx); - TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); - - output_tv4->split(-1, tidx); - - exp_tv1->computeAt(sum_exp_rf_tv5, -1); - exp_tv1_copy->computeAt(output_tv4, -1); - - TensorView* tensors_to_parallelize[] = { - sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; - - for (auto tv : tensors_to_parallelize) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({dimx}, options); - at::Tensor cg_output = at::empty({dimx}, options); - at::Tensor t3_output = at::empty_like(cg_output, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - fe.runFusion({t0}, {cg_output}); - - auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false); - - testValidate(&fusion, {cg_output}, {t0}, {aten_output}, __LINE__, __FILE__); -} - -// Softmax with a 1D tensor with input normalization. -TEST_F(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int tidx = 128; - const int dimx = 1000; - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(1); - fusion.addInput(input_tv0); - - // Normalize with the max value before computing exp. - TensorView* max_val_tv1 = reductionOp( - BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); - TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); - TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); - TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); - TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); - TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true}); - - // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be - // computed at sum_exp_rf_tv8. - TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); - TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); - - TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); - - fusion.addOutput(output_tv7); - bcast_max_tv2->split(0, tidx); - bcast_sum_tv6->split(0, tidx); - - max_val_tv1->split(-1, tidx); - TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); - - sum_exp_tv5->split(-1, tidx); - TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); - - output_tv7->split(-1, tidx); - - sub_tv3->computeAt(sum_exp_rf_tv9, -1); - sub_tv3_copy->computeAt(output_tv7, -1); - - TensorView* tensors_to_parallelize[] = { - max_val_tv1, - bcast_max_tv2, - sum_exp_tv5, - bcast_sum_tv6, - output_tv7, - max_val_rf_tv8, - sum_exp_rf_tv9}; - - for (auto tv : tensors_to_parallelize) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({dimx}, options); - at::Tensor t3_output = at::empty({dimx}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Softmax with a 3D tensor, where the inner-most 3rd dimension is -// normalized. Pallelized with multiple thread blocks. -TEST_F(NVFuserTest, FusionSoftmax3D_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int tidx = 32; - const int dimx = 32; - const int dimy = 16; - const int dimz = 130; - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(3); - fusion.addInput(input_tv0); - - TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); - TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); - TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); - - // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be - // computed at sum_exp_rf_tv8. - TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); - - TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); - - fusion.addOutput(output_tv4); - - bcast_sum_tv3->split(-1, tidx); - - sum_exp_tv2->split(-1, tidx); - TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); - - output_tv4->split(-1, tidx); - - exp_tv1->computeAt(sum_exp_rf_tv5, -1); - exp_tv1_copy->computeAt(output_tv4, -1); - - TensorView* tensors_to_parallelize[] = { - sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; - - for (auto tv : tensors_to_parallelize) { - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::BIDy); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({dimx, dimy, dimz}, options); - - at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Softmax with a 3D tensor with input normalization. -TEST_F(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int tidx = 32; - const int dimx = 32; - const int dimy = 16; - const int dimz = 130; - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(3); - fusion.addInput(input_tv0); - - // Normalize with the max value before computing exp. - TensorView* max_val_tv1 = reductionOp( - BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); - TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); - TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); - TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); - TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); - TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true}); - - // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be - // computed at sum_exp_rf_tv8. - TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); - TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); - - TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); - - fusion.addOutput(output_tv7); - - bcast_max_tv2->split(-1, tidx); - bcast_sum_tv6->split(-1, tidx); - - max_val_tv1->split(-1, tidx); - TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); - - sum_exp_tv5->split(-1, tidx); - TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); - - output_tv7->split(-1, tidx); - - sub_tv3->computeAt(sum_exp_rf_tv9, -1); - sub_tv3_copy->computeAt(output_tv7, -1); - - TensorView* tensors_to_parallelize[] = { - max_val_tv1, - bcast_max_tv2, - sum_exp_tv5, - bcast_sum_tv6, - output_tv7, - max_val_rf_tv8, - sum_exp_rf_tv9}; - - for (auto tv : tensors_to_parallelize) { - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::BIDy); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({dimx, dimy, dimz}, options); - at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - - auto tv3 = add(tv0, IrBuilder::create(1.0)); - - auto tv4 = mul(tv2, tv3); - - auto tv5 = sum(tv4, {1}); - auto tv6 = broadcast(tv5, {false, true}); - - auto tv7 = sub(tv6, tv4); - fusion.addOutput(tv7); - - tv1->computeAt(tv7, 1); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(tv1->computeAt(tv7, -1)); -} - -// Similar to FusionReduction but uses grid reduction -TEST_F(NVFuserTest, FusionGridReduction1_CUDA) { - const int gdimx = 32; - const int bdimx = 128; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, bdimx); - // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] - tv1->split(1, gdimx); - // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] - // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv1, 1); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 1); - - tv1->axis(0)->parallelize(ParallelType::BIDy); - tv1->axis(1)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::BIDx); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 10000; - int numel_y = 65000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Same test as the above but uses BIDy and TIDx for reduction -TEST_F(NVFuserTest, FusionGridReduction2_CUDA) { - const int gdimy = 32; - const int bdimx = 128; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, bdimx); - // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] - tv1->split(1, gdimy); - // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] - // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv1, 1); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 1); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(2)->parallelize(ParallelType::BIDy); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 10000; - int numel_y = 65000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({1}); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Same test but uses BIDy and BIDz for reduction. No TID used. -TEST_F(NVFuserTest, FusionGridReduction3dim1_CUDA) { - // Grid reductions when there aren't any threads are serial reductions - // keep these numbers low so our error isn't too high compared to normal cuda - // reductions - const int gdimz = 15; - const int gdimy = 9; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, gdimy); - // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] - tv1->split(1, gdimz); - // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] - // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv1, 1); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 1); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDz); - tv2->axis(2)->parallelize(ParallelType::BIDz); - tv1->axis(-1)->parallelize(ParallelType::BIDy); - tv2->axis(-1)->parallelize(ParallelType::BIDy); - - int numel_x = 100; - int numel_y = 6500; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 -TEST_F(NVFuserTest, FusionGridReduction3dim0_CUDA) { - // Grid reductions when there aren't any threads are serial reductions - // keep these numbers low so our error isn't too high compared to normal cuda - // reductions - const int gdimz = 15; - const int gdimy = 9; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[R0, I1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {0}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(0, gdimy); - // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] - tv1->split(0, gdimz); - // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({0}); - // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1] - // tv1[ R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1] - - // Note that computeAt isn't going to make anything better as there - // is no dynamically sized dimension. - - // Map parallelism as [Serial, BIDz, BIDy, BIDx] - tv1->axis(-1)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::BIDx); - tv1->axis(-2)->parallelize(ParallelType::BIDy); - tv2->axis(-2)->parallelize(ParallelType::BIDy); - tv1->axis(-3)->parallelize(ParallelType::BIDz); - tv2->axis(-3)->parallelize(ParallelType::BIDz); - - int numel_x = 6500; - int numel_y = 100; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({0}); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -// This is similar to the FusionReduction, but swaps BIDx and TIDx -TEST_F(NVFuserTest, FusionGridReduction4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int bdimx = 128; - const int gdimx = 1024; - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, gdimx); - // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] - tv1->split(1, 4); - // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] - // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] - - TensorView* tv3 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] - // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] - // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv3, 1); - tv3->computeAt(tv1, 1); - - // Re do it all at once, because why not. - tv0->computeAt(tv1, 1); - - tv2->axis(2)->parallelize(ParallelType::Unroll); - tv1->axis(0)->parallelize(ParallelType::TIDx); - - tv1->axis(-1)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::BIDx); - - int numel_x = bdimx; - int numel_y = 65000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Grid reduction with 2D thread blocks but only TIDx and BIDx are -// mapped to a reduction dim -TEST_F(NVFuserTest, FusionGridReduction5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int bdimx = 64; - const int bdimy = 16; - const int gdimx = 4; - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - tv1->split(1, bdimx); - // tv1[I0, R1o, R1i{64}] = tv0[I0, I1] - tv1->split(1, gdimx); - // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1] - // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] - - tv0->computeAt(tv1, 1); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(-2)->parallelize(ParallelType::BIDx); - tv2->axis(-2)->parallelize(ParallelType::BIDx); - - tv1->axis(0)->parallelize(ParallelType::TIDy); - - int numel_x = bdimy; - int numel_y = 6500; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Similar to FusionGridReduction1 but with 3D tensors -TEST_F(NVFuserTest, FusionGridReduction6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(3); - fusion.addInput(tv0); - - // tv1[I0, R1, R2] = tv0[I0, I1, I2] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion).size(), - "Could not detect reduction in fusion."); - - // Splitting for TID - tv1->split(2, 128); - // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] - - // Splitting for BID - tv1->split(1, 128); - - // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2] - - TensorView* tv2 = tv1->rFactor({3}); - // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] - // tv1[I0, R1o, R1i{128}, R2i{128}] - - TensorView* tv3 = tv1->rFactor({1}); - // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] - // tv3[I0, R1o, I1i{128}, I2i{128}] - // tv1[I0, R1i{128}, R2i{128}] - - tv3->computeAt(tv1, 1); - tv2->computeAt(tv3, 3); - - tv1->axis(0)->parallelize(ParallelType::BIDy); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(-2)->parallelize(ParallelType::BIDx); - tv2->axis(-3)->parallelize(ParallelType::BIDx); - tv3->axis(-2)->parallelize(ParallelType::BIDx); - - int numel_x = 6500; - int numel_y = 200; - int numel_z = numel_y; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({1, 2}); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -// See issue #1049 -TEST_F(NVFuserTest, FusionGridReduction7_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - tv1->split(0, 1000); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDy); - - const int numel_x = 1; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = input.sum({0}); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridReduction8_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - const int numel_x = 2; - const int numel_y = 4; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = input.sum({0}); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridReduction9_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); - - tv1->split(1, 2); - - tv1->axis(1)->parallelize(ParallelType::BIDx); - tv1->axis(2)->parallelize(ParallelType::BIDy); - - tv1->computeAt(tv3, 1); - - const int numel_x = 4; - const int numel_y = 10; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - at::Tensor t2 = at::randn({numel_x}, options); - - std::vector aten_inputs = {t0, t2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_output = fe.runFusion(aten_inputs); - - auto aten_output = t0.sum({1}).add(t2); - - testValidate(&fusion, cg_output, {t0, t2}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridReduction10_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {-1}); - auto tv2 = sum(tv1, {-1}); - auto tv3 = sum(tv2, {-1}); - - fusion.addOutput(tv3); - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv1->axis(1)->parallelize(ParallelType::BIDx); - tv1->axis(2)->parallelize(ParallelType::TIDy); - tv1->axis(3)->parallelize(ParallelType::TIDz); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDy); - - tv3->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv3, 1); - - const int numel_w = 2; - const int numel_x = 3; - const int numel_y = 4; - const int numel_z = 5; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_w, numel_x, numel_y, numel_z}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_output = fe.runFusion({t0}); - - auto aten_output = t0.sum({1, 2, 3}); - - testValidate(&fusion, cg_output, {t0}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionNonRedAxisBind_CUDA) { - int bid_x = 3; - int tid_x = 2; - int red_dim = 0; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - tv1->split(-1, tid_x); - tv1->axis(-2)->parallelize(ParallelType::BIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({16, bid_x * tid_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = input.to(at::kDouble).sum({red_dim}); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSplitBCast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(3); - TensorView* input_tv1 = makeSymbolicTensor(3); - fusion.addInput(input_tv0); - fusion.addInput(input_tv1); - - TensorView* sum_tv2 = reductionOp( - BinaryOpType::Add, {2}, IrBuilder::create(0), input_tv0); - TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true}); - TensorView* output_tv4 = div(input_tv1, bcast_tv3); - - sum_tv2->split(-1, 32); - TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2}); - - bcast_tv3->split(-1, 32); - output_tv4->split(-1, 32); - - sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx); - sum_tv2->axis(0)->parallelize(ParallelType::BIDx); - bcast_tv3->axis(0)->parallelize(ParallelType::BIDx); - output_tv4->axis(0)->parallelize(ParallelType::BIDx); - - sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy); - sum_tv2->axis(1)->parallelize(ParallelType::BIDy); - bcast_tv3->axis(1)->parallelize(ParallelType::BIDy); - output_tv4->axis(1)->parallelize(ParallelType::BIDy); - - sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx); - sum_tv2->axis(-1)->parallelize(ParallelType::TIDx); - bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx); - output_tv4->axis(-1)->parallelize(ParallelType::TIDx); - - fusion.addOutput(output_tv4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({32, 32, 128}, options); - at::Tensor t1 = at::randn({32, 32, 128}, options); - at::Tensor cg_output = at::empty({32, 32, 128}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - fe.runFusion({t0, t1}, {cg_output}); -} - -TEST_F(NVFuserTest, FusionBCastInnerDim_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // reduce then broadcast - auto tv1 = sum(tv0, {0}); - auto tv2 = broadcast(tv1, {false, true}); - - TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast()); -} - -TEST_F(NVFuserTest, FusionBCastReduce_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - - auto tv1 = broadcast(tv0, {true, false, false}); - auto tv2 = sum(tv1, {1}); - TORCH_CHECK( - tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() && - !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction()); -} - -// Multiple consumer reduction with computeAt -// https://github.com/csarofeen/pytorch/issues/110 -TEST_F(NVFuserTest, FusionReductionMultiConsumer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = unaryOp(UnaryOpType::Exp, tv0); - auto tv2 = - reductionOp(BinaryOpType::Max, {-1}, IrBuilder::create(0), tv1); - auto tv3 = - reductionOp(BinaryOpType::Min, {-1}, IrBuilder::create(0), tv1); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort); - - TORCH_CHECK(tv1->getComputeAtPosition() == 2); -} - -TEST_F(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { - for (const auto i : c10::irange(2)) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv0, IrBuilder::create(1)); - TensorView* tv3 = add(tv1, tv2); - // Set outputs tv2 or tv1 and then tv3 - if (i == 0) { - fusion.addOutput(tv2); - } else { - fusion.addOutput(tv1); - } - fusion.addOutput(tv3); - - if (i == 0) { - tv1->computeAt(tv3, -1); - } else { - tv2->computeAt(tv3, -1); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - std::vector aten_outputs = { - aten_input + 1, (aten_input + 1) * 2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv0, IrBuilder::create(1)); - TensorView* tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - tv3->split(-1, 32); - - tv1->computeAt(tv3, -1); - tv2->computeAt(tv3, -2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100, 100}, options); - auto aten_output = (aten_input + 1) * 2; - - at::Tensor cg_output = at::empty_like(aten_input, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int64_t dimx = 13; - const int64_t dimy = 15; - - TensorView* tv0 = makeConcreteTensor({dimx, dimy}); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv2, IrBuilder::create(3)); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - TensorView* tv5 = mul(tv2, tv4); - fusion.addOutput(tv5); - - tv1->computeAt(tv2, 2); - tv3->computeAt(tv4, 1); - tv4->computeAt(tv5, 2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({dimx, dimy}, options); - auto t1 = aten_input.add(1.); - auto t2 = t1.add(2.); - auto t3 = t2.add(3.); - auto t4 = t3.add(4.); - auto aten_output = t2.mul(t4); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionZeroDimComputeAt_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - TORCH_CHECK(tv2->nDims() == 0); - tv1->computeAt(tv2, 0); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - auto aten_output = aten_input.to(at::kDouble).sum() + 1; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionZeroDimBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(0); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {true, true}); - TORCH_CHECK(tv1->nDims() == 2); - - TensorView* tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - - auto tv3 = add(tv1, tv2); - auto tv4 = sum(tv3, {0, 1}); - fusion.addOutput(tv4); - - tv3->computeAt(tv4, -1); - tv3->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({}, options); - at::Tensor t1 = at::randn({10, 10}, options); - - auto aten_output = (t0.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + t1) - .to(at::kDouble) - .sum(); - - std::vector aten_inputs = {t0, t1}; - at::Tensor cg_output = at::empty({}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, {cg_output}); - - testValidate( - &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionZeroDimReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int bdimx = 32; - const int gdimx = 32; - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - tv1->split(0, bdimx); - tv1->split(0, gdimx); - auto tv2 = tv1->rFactor({0}); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-2)->parallelize(ParallelType::BIDx); - tv2->axis(-2)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({1000}, options); - auto aten_output = aten_input.to(at::kDouble).sum(); - - at::Tensor cg_output = at::empty({}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBCastAfterReduce_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - const int tidx = 128; - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - - tv1->split(1, tidx); - auto tv3 = tv1->rFactor({-2}); - - TensorView* tv4 = makeSymbolicTensor(2); - fusion.addInput(tv4); - - auto tv5 = add(tv2, tv4); - fusion.addOutput(tv5); - tv5->split(1, tidx); - - tv3->computeAt(tv5, 1); - - tv2->split(1, tidx); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - - int x = 63, y = 200; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({x, y}, options); - at::Tensor t4 = at::randn({x, y}, options); - - auto t3 = t0.to(at::kDouble).sum({1}).unsqueeze(-1).expand({x, y}); - auto aten_output = t3.add(t4); - - std::vector aten_inputs = {t0, t4}; - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t4}); - auto cg_outputs = fe.runFusion({t0, t4}); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionOutputBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({2, 3}); - fusion.addInput(tv0); - - TensorView* tv1 = broadcast(tv0, {true, false, true, false, true}); - - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({2, 3}, options); - auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({2, 3, 4, 5, 6}); - fusion.addInput(tv0); - - TensorView* tv1 = sum(tv0, {0, 2, -1}, /*keep_dim=*/true); - - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({2, 3, 4, 5, 6}, options); - auto aten_output = - aten_input.to(at::kDouble).sum({0, 2, -1}, /*keepdim=*/true); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { - constexpr int bid_x = 80; - constexpr int tid_x = 4096; - constexpr int red_dim = 1; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({bid_x, tid_x}); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, - {red_dim}, - IrBuilder::create(0), - tv0, - /*keep_dim=*/true); - - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({bid_x, tid_x}, options); - auto aten_output = - aten_input.to(at::kDouble).sum({red_dim}, /*keepdim=*/true); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionSumTo_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector tensor_shape{2, 3, 4, 5, 6}; - std::vector sum_to_shape{1, 5, 6}; - - std::vector tensor_shape_ref{2, 3, 4, 5, 6}; - std::vector sum_to_shape_ref{1, 5, 6}; - - std::vector sum_to_symb; - std::transform( - sum_to_shape.begin(), - sum_to_shape.end(), - std::back_inserter(sum_to_symb), - [](int s) -> Int* { return IrBuilder::create(s); }); - - TensorView* tv0 = makeConcreteTensor(tensor_shape); - fusion.addInput(tv0); - - TensorView* tv1 = sum_to(tv0, sum_to_symb); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn(tensor_shape_ref, options); - auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - TORCH_CHECK( - cg_outputs[0].dim() == static_cast(sum_to_shape.size()), - "sum_to not keeping the final dimension"); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSumToNoop_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector tensor_shape{4, 5, 6}; - std::vector sum_to_shape{4, 5, 6}; - - std::vector tensor_shape_ref{4, 5, 6}; - std::vector sum_to_shape_ref{4, 5, 6}; - - std::vector sum_to_symb; - std::transform( - sum_to_shape.begin(), - sum_to_shape.end(), - std::back_inserter(sum_to_symb), - [](int s) -> Int* { return IrBuilder::create(s); }); - - TensorView* tv0 = makeConcreteTensor(tensor_shape); - fusion.addInput(tv0); - - TensorView* tv1 = sum_to(tv0, sum_to_symb); - - // Dummy operator to avoid tv0 both input and output - TensorView* tv2 = add(tv1, IrBuilder::create(0)); - fusion.addOutput(tv2); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn(tensor_shape_ref, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); - - TORCH_CHECK( - cg_outputs[0].dim() == static_cast(sum_to_shape.size()), - "sum_to not keeping the final dimension"); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) { - constexpr int bid_x = 80; - constexpr int tid_x = 4096; - constexpr int red_dim = 1; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({bid_x, tid_x}, options); - auto aten_output = aten_input.to(at::kDouble).sum({red_dim}); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -// Simple reduction parallelized on a symbolic size. -TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - // tv1[I0, R1] = tv0[I0, I1] - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - // Interface should just be a direct split with a Parallel type. We can - // include the parallelize call if we do this. - tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); - // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({1}); - // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1] - // tv1[I0, R1oi{4}, R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] - - // Incrementally, can print in between for debugging - tv0->computeAt(tv2, 1); - tv2->computeAt(tv1, 1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 65000; - int numel_y = 1025; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({numel_x, numel_y}, options); - auto aten_output = aten_input.to(at::kDouble).sum({1}); - - // How many threads to use for the block reduction - int runtime_threadIdx_dim = 128; - - LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { - const std::vector red_dims = {0, 2}; - // Copy is because CodeGen requires int and Pytorch requires int64_t - // for a vector of reduction dimensions - const std::vector red_dims64 = {0, 2}; - const std::vector tensor_dims_in = {5, 10, 15, 20}; - const std::vector tensor_dims_out = {10, 20}; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(tensor_dims_in, options); - auto aten_output = aten_input.to(at::kDouble).sum(red_dims64); - at::Tensor cg_output = at::empty(tensor_dims_out, options); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - fe.runFusion({aten_input}, {cg_output}, lparams); - - testValidate( - &fusion, - {cg_output}, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { - const std::vector red_dims = {1, 3}; - // Copy is because CodeGen requires int and Pytorch requires int64_t - // for a vector of reduction dimensions - const std::vector red_dims64 = {1, 3}; - const std::vector tensor_dims_in = {5, 10, 15, 20}; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(tensor_dims_in, options); - auto aten_output = aten_input.to(at::kDouble).sum(red_dims64); - - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { - std::vector dtypes = { - DataType::Double, DataType::Float, DataType::Half}; - // TODO: add test for complex. Currently complex fails with the following - // NVRTC compilation error message: - // error: no suitable user-defined conversion from - // "CudaCodeGen::std::complex" to "CudaCodeGen::std::complex" - // exists -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - dtypes.insert(dtypes.end(), DataType::BFloat16); - } -#endif - - std::vector red_dims; - - // Tried to cut down the number iterations with just - // doing every other power of 2. - for (int i = 1; i <= 1024 * 1024; i <<= 2) { - red_dims.push_back(i); - } - - for (auto dtype : dtypes) { - at::ScalarType aten_dtype = data_type_to_aten(dtype); - for (auto& rdim : red_dims) { - Fusion fusion; - FusionGuard fg(&fusion); - - bool is_fp16 = dtype == DataType::Half; - bool is_bf16 = dtype == DataType::BFloat16; - - TensorView* tv0 = makeSymbolicTensor(1, dtype); - fusion.addInput(tv0); - - TensorView* tv0_cast = tv0; - if (is_fp16 || is_bf16) { - tv0_cast = castOp(DataType::Float, tv0); - } - - TensorView* tv1 = sum(tv0_cast, {0}); - - TensorView* tv1_cast = tv1; - if (is_fp16) { - tv1_cast = castOp(DataType::Half, tv1); - } - if (is_bf16) { - tv1_cast = castOp(DataType::BFloat16, tv1); - } - - fusion.addOutput(tv1_cast); - - auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({rdim}, options); - auto aten_output = aten_input.to(at::kDouble).sum({0}); - - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); - } - } -} - -TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { - std::vector dtypes = { - DataType::Double, DataType::Float, DataType::Half}; - // TODO: add complex support. Currently, complex fails with the following - // NVRTC compilation error: - // error: no instance of overloaded function "__shfl_xor_sync" matches the - // argument list -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - dtypes.insert(dtypes.end(), DataType::BFloat16); - } -#endif - - std::vector red_axis = {1, 0}; - std::vector output_dims = {160, 320}; - std::vector red_dims; - - // Tried to cut down the number iterations with just - // doing every other power of 2. - for (int i = 1; i <= 1024 * 1024; i <<= 2) { - red_dims.push_back(i); - } - - for (auto dtype : dtypes) { - at::ScalarType aten_dtype = data_type_to_aten(dtype); - for (auto& axis : red_axis) { - for (auto& odim : output_dims) { - for (auto& rdim : red_dims) { - Fusion fusion; - FusionGuard fg(&fusion); - - bool is_fp16 = dtype == DataType::Half; - bool is_bf16 = dtype == DataType::BFloat16; - - TensorView* tv0 = makeSymbolicTensor(2, dtype); - fusion.addInput(tv0); - - TensorView* tv0_cast = tv0; - if (is_fp16 || is_bf16) { - tv0_cast = castOp(DataType::Float, tv0); - } - - TensorView* tv1 = sum(tv0_cast, {axis}); - - TensorView* tv1_cast = tv1; - if (is_fp16) { - tv1_cast = castOp(DataType::Half, tv1); - } - if (is_bf16) { - tv1_cast = castOp(DataType::BFloat16, tv1); - } - fusion.addOutput(tv1_cast); - - auto options = - at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); - - at::Tensor aten_input = - (axis ? at::randn({odim, rdim}, options) - : at::randn({rdim, odim}, options)); - - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - auto aten_output = aten_input.to(at::kDouble).sum({axis}); - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); - } - } - } - } -} - -TEST_F(NVFuserTest, FusionCacheBefore_CUDA) { - // TVM Cache Write - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); - TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); - fusion.addInput(tv0); - fusion.addOutput(tv2); - - // Before: TV2 = TV1 * 3 - // After: TV3 = TV1 * 3; - // TV2 = TV3; - TensorView* tv3 = tv2->cacheBefore(); - - constexpr int BSX = 32; - tv2->split(-1, BSX); - tv0->computeAt(tv2, -1); - - // Thread and Block binding - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 32, N = 750; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({M, N}, options); - at::Tensor aten_output = (aten_input + 1.0) * 3.0; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheAfter_CUDA) { - // TVM Cache Read - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); - TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); - fusion.addInput(tv0); - fusion.addOutput(tv2); - - // Before: TV1 = TV0 + 1 - // After: TV3 = TV0; - // TV1 = TV3 + 1 - TensorView* tv3 = tv0->cacheAfter(); - - constexpr int BSX = 32; - tv2->split(-1, BSX); - tv0->computeAt(tv2, -1); - - // Thread and Block binding - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 32, N = 457; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({M, N}, options); - at::Tensor aten_output = (aten_input + 1.0) * 3.0; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheFork_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); - TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); - fusion.addInput(tv0); - fusion.addOutput(tv1); - fusion.addOutput(tv2); - // Before: TV1 = TV0 + 1 - // TV2 = TV1 * 1 - // Output: TV1, TV2 - - // After: TV1 = TV0 + 1 - // TV3 = TV1 - // TV2 = TV1 * 1 - // Output: TV3, TV2 - - // cacheFork !!does not!! automatically apply ComputeAt to the cache - auto tv3 = tv1->cacheFork(); - - constexpr int BSX = 32; - tv2->split(-1, BSX); - tv0->computeAt(tv2, -1); - - // Thread and Block binding - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 32, N = 457; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({M, N}, options); - at::Tensor aten_output1 = aten_input + 1.0; - at::Tensor aten_output2 = aten_output1 * 3.0; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output1, aten_output2}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheIndirect_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - TensorView* tv2 = makeSymbolicTensor(2); - TensorView* tv3 = makeSymbolicTensor(2); - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - fusion.addInput(tv3); - fusion.addOutput(tv6); - // t6 = ((t1 + (t2 - t3)) - t0) - - tv5->cacheAfter(); - tv5->cacheBefore(); - - // cacheAfter on inputs placed before schedule - constexpr int BSX = 32; - tv6->split(-1, BSX); - tv2->computeAt(tv6, -1); - - // Thread and Block binding - tv6->axis(0)->parallelize(ParallelType::BIDx); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 32, N = 810; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t1 = at::randn({M, N}, options); - at::Tensor t2 = at::randn({M, N}, options); - at::Tensor t3 = at::randn({M, N}, options); - - std::vector aten_inputs = {t0, t1, t2, t3}; - at::Tensor aten_output = (t1 + (t2 - t3)) - t0; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheBcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(1); // (M, 1) - TensorView* tv1 = broadcast(tv0, {false, true}); - TensorView* tv2 = makeSymbolicTensor(1); // (1, N) - TensorView* tv3 = broadcast(tv2, {true, false}); - TensorView* tv4 = mul(tv1, tv3); - fusion.addInput(tv0); - fusion.addInput(tv2); - fusion.addOutput(tv4); - - // Case 1 - tv0->cacheAfter(); - - // Case 2 - tv1->cacheBefore(); - - // Case 3 - tv1->cacheAfter(); - - // Case 4 - TensorView* tv8 = tv4->cacheBefore(); - - constexpr int BSX = 128; - tv4->split(0, BSX); - tv4->split(-1, BSX); - tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); - // M/BSX, N/BSY, BSX, BSY - tv0->computeAt(tv4, 2); - tv2->computeAt(tv4, 2); - // 0, 1 | 2, 3, 4 - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::BIDy); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Replay on TV3 - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv8->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 92, N = 500; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M}, options); - at::Tensor t1 = at::randn({N}, options); - std::vector aten_inputs = {t0, t1}; - at::Tensor aten_output = - t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0)); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheMultiConsumer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, IrBuilder::create(1)); - TensorView* tv4 = add(tv3, IrBuilder::create(2)); - - fusion.addInput(tv0); - fusion.addOutput(tv2); - fusion.addOutput(tv4); - - auto tv5 = tv1->cacheBefore(); - auto tv6 = tv3->cacheBefore(); - tv5->setMemoryType(MemoryType::Shared); - tv6->setMemoryType(MemoryType::Shared); - - tv1->computeAt(tv2, -1); - tv3->computeAt(tv4, -1); - - // Fails because tensor must be recomputed twice - // auto tv7 = tv0->cacheAfter(); - - constexpr int N = 800; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({N}, options); - auto aten_output = (aten_input + 1) + 2; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output, aten_output}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionSmem_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(2); // (M, N) - TensorView* tv1 = makeSymbolicTensor(2); // (M, N) - TensorView* tv2 = mul(tv0, tv1); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv2); - - // Schedule - TensorView* tv3 = tv0->cacheAfter(); - TensorView* tv4 = tv1->cacheAfter(); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Shared); - - constexpr int BSY = 32; - constexpr int BSX = 128; - tv2->split(0, BSY); - tv2->split(2, BSX); - // M/BSX, BSX, N/BSX, BSX - tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); - // M/BSX, N/BSX, BSX, BSX - - tv0->computeAt(tv2, 2); - tv1->computeAt(tv2, 2); - - // Thread and Block binding - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Binding - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 128, N = 10240; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t1 = at::randn({M, N}, options); - at::Tensor aten_output = mul(t0, t1); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemReduce_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(3); // M, K, N - TensorView* tv1 = sum(tv0, {1}); // M, R, N - fusion.addInput(tv0); - fusion.addOutput(tv1); - - TensorView* tv2 = tv0->cacheAfter(); - tv2->setMemoryType(MemoryType::Shared); - - // Schedule - constexpr int BSX = 32; - tv1->split(2, BSX); - tv1->split(1, 128); - tv1->split(0, BSX); - // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX - tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}}); - TensorView* tv3 = tv1->rFactor({-2}); - - tv0->computeAt(tv1, -2); - tv0->computeAt(tv3, -2); - - // Thread and Block binding - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDy); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Binding - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 154, K = 45, N = 1524; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({M, K, N}, options); - at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemBlockGemm_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(2); // (M, K) - TensorView* tv1 = makeSymbolicTensor(2); // (K, N) - TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) - TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) - TensorView* tv4 = mul(tv2, tv3); // M, K, N - TensorView* tv5 = sum(tv4, {1}); // M, R, N - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - // Schedule - constexpr int BSX = 16; - tv5->split(2, BSX - 1); - tv5->split(1, BSX); - tv5->split(0, BSX + 1); - // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX - tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}}); - // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX - TensorView* tv6 = tv5->rFactor({-1}); - - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Shared); - tv6->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv5, 3); - tv1->computeAt(tv5, 3); - - // Thread and Block binding - tv5->axis(0)->parallelize(ParallelType::BIDx); - tv5->axis(1)->parallelize(ParallelType::BIDy); - tv5->axis(-2)->parallelize(ParallelType::TIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Binding - tv2->axis(-3)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-3)->parallelize(ParallelType::TIDy); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv6->axis(-3)->parallelize(ParallelType::TIDy); - tv6->axis(-2)->parallelize(ParallelType::TIDx); - - // Make sure BIDx is makred as exact (see issue #1119) - GpuLower gpulw(&fusion); - TORCH_CHECK(gpulw.parallelDimensionMap().isExact(ParallelType::BIDx)); - - constexpr int M = 154, K = 45, N = 1524; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - std::vector aten_inputs = {t0, t1}; - at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(2); // (M, K) - TensorView* tv1 = makeSymbolicTensor(2); // (K, N) - TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) - TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) - TensorView* tv4 = mul(tv2, tv3); // M, K, N - TensorView* tv5 = sum(tv4, {1}); // M, R, N - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - // Schedule - // Remove reduction axis from tv5 - // tv6 = (M, R, N) - // tv5 = (M, N) - TensorView* tv6 = tv5->cacheBefore(); - - constexpr int BSX = 16; - tv5->split(1, BSX); - tv5->split(0, BSX); - // M/BSX, BSX, N/BSX, BSX - tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); - // tv5 = M/BSX, N/BSX, MSX, NSX - - tv6->computeAt(tv5, 2); - tv6->computeAt(tv5, 2); - - tv6->split(-1, BSX); - // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX - tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}}); - // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX - TensorView* tv7 = tv6->rFactor({-1}); - // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr - // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX - - tv0->computeAt(tv6, 3); - tv1->computeAt(tv6, 3); - - tv0->computeAt(tv7, 3); - tv1->computeAt(tv7, 3); - - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Shared); - tv6->setMemoryType(MemoryType::Shared); - tv7->setMemoryType(MemoryType::Shared); - // Memory Type - - // Thread and Block binding - tv5->axis(0)->parallelize(ParallelType::BIDx); - tv5->axis(1)->parallelize(ParallelType::BIDy); - tv5->axis(-2)->parallelize(ParallelType::TIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Binding - tv2->axis(-3)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-3)->parallelize(ParallelType::TIDy); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - - tv7->axis(-3)->parallelize(ParallelType::TIDy); - tv7->axis(-2)->parallelize(ParallelType::TIDx); - - tv6->axis(-2)->parallelize(ParallelType::TIDy); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 154, K = 45, N = 1524; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* x = makeSymbolicTensor(2); - fusion.addInput(x); - TensorView* max_val = reductionOp( - BinaryOpType::Max, - {-1}, - IrBuilder::create(std::numeric_limits::lowest()), - x); // (M) - TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) - TensorView* x_max_sub = sub(x, bcast_max); // (M, N) - TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N) - TensorView* sum_exp = sum(exp, {-1}); // (M, R) - TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) - TensorView* softmax = div(exp, bcast_sum); // (M, N) - fusion.addOutput(softmax); - - // Read Input into Shared Memory - // Load Input + Pwise into shared memory - auto cache_x = x->cacheAfter(); - cache_x->setMemoryType(MemoryType::Shared); - exp->setMemoryType(MemoryType::Shared); - - std::vector all_tensors( - {x, - cache_x, - max_val, - bcast_max, - x_max_sub, - exp, - sum_exp, - bcast_sum, - softmax}); - - auto tidx = IrBuilder::create(); - fusion.addInput(tidx); - - for (auto tensor : all_tensors) { - tensor->split(-1, tidx); - } - - auto sum_exp_rf = sum_exp->rFactor({1}); - all_tensors.push_back(sum_exp_rf); - - // computeAt - x->computeAt(x_max_sub, 1); - exp->computeAt(softmax, 1); - x_max_sub->computeAt(exp, 2); - - softmax->axis(0)->parallelize(ParallelType::BIDx); - for (auto tensor : all_tensors) { - tensor->axis(-1)->parallelize(ParallelType::TIDx); - } - - const int64_t dimx = 1024; - const int64_t dimy = 4096; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({dimx, dimy}, options); - auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input, 128}); - auto cg_outputs = fe.runFusion({aten_input, 128}); - - testValidate( - &fusion, - cg_outputs, - {aten_input, 128}, - {aten_output}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int kReductionAxis = 3; - std::vector input_shape{10, 10, 10, 67}; - TensorView* input = makeSymbolicTensor(input_shape.size()); - fusion.addInput(input); - - auto output = softmax(input, kReductionAxis); - - fusion.addOutput(output); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(input_shape, options); - auto aten_output = - at::_softmax(aten_input.to(at::kDouble), kReductionAxis, false); - - auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - schedulePersistentKernel(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) { - // This test is testing the usage of all padding tokens - // with softmax like Bert might might use in a full padding - // sequence. - Fusion fusion; - FusionGuard fg(&fusion); - - const int kReductionAxis = 3; - std::vector input_shape{256, 16, 128, 128}; - TensorView* input = makeSymbolicTensor(input_shape.size()); - TensorView* mask = makeSymbolicTensor(input_shape.size()); - fusion.addInput(input); - fusion.addInput(mask); - - auto out1 = add(input, mask); - auto output = softmax(out1, kReductionAxis); - - fusion.addOutput(output); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(input_shape, options); - at::Tensor aten_mask = at::ones(input_shape, options); - // -10,000 is used here as a magic number because the padding - // tokens need to be a value that gives a value close to zero - // as to not influence softmax. Bert, in particular, does - // not use -Infinity because sometimes it will have a - // softmax of all padding tokkens that can result a divide by - // zero that creates NaN result. - aten_mask = aten_mask * -10000.0; - auto aten_out1 = aten_input + aten_mask; - auto aten_output = at::_softmax(aten_out1, kReductionAxis, false); - - auto reduction_params = - getPersistentHeuristics(&fusion, {aten_input, aten_mask}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - schedulePersistentKernel(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams); - auto cg_outputs = fe.runFusion({aten_input, aten_mask}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input, aten_mask}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - std::vector shape{20, 100, 35, 67}; - std::vector norm_shape{67}; - - const size_t kM = shape.size(); - const size_t kN = norm_shape.size(); - const size_t kOuterNumDims = kM - kN; - - std::vector outer_shape; - for (const auto idx : c10::irange(kOuterNumDims)) { - outer_shape.push_back(shape[idx]); - } - for (const auto idx : c10::irange(kOuterNumDims, kM)) { - outer_shape.push_back(1); - } - - auto grad_out = makeSymbolicTensor(shape.size()); - auto input = makeSymbolicTensor(shape.size()); - auto mean = makeConcreteTensor(outer_shape); - auto rstd = makeConcreteTensor(outer_shape); - auto weight = makeSymbolicTensor(norm_shape.size()); - auto bias = makeSymbolicTensor(norm_shape.size()); - fusion.addInput(grad_out); - fusion.addInput(input); - fusion.addInput(mean); - fusion.addInput(rstd); - fusion.addInput(weight); - fusion.addInput(bias); - - auto grads = layer_norm_backward( - grad_out, - input, - norm_shape, - mean, - rstd, - weight, - bias, - {true, true, true}); - - fusion.addOutput(grads.grad_input); - fusion.addOutput(grads.grad_weight); - fusion.addOutput(grads.grad_bias); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_grad_out = at::randn(shape, options); - at::Tensor aten_input = at::randn(shape, options); - at::Tensor aten_weight = at::randn(norm_shape, options); - at::Tensor aten_bias = at::randn(norm_shape, options); - auto at_weight = c10::optional(aten_weight); - auto at_bias = c10::optional(aten_bias); - - const float kEps = 1e-5; - auto aten_results = - at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps); - auto aten_output = std::get<0>(aten_results); - auto aten_mean = std::get<1>(aten_results); - auto aten_rstd = std::get<2>(aten_results); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector aten_inputs = { - aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias}; - auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - - auto aten_gradients = at::native_layer_norm_backward( - aten_grad_out.to(at::kDouble), - aten_input.to(at::kDouble), - norm_shape, - aten_mean.to(at::kDouble), - aten_rstd.to(at::kDouble), - c10::optional(aten_weight.to(at::kDouble)), - c10::optional(aten_bias.to(at::kDouble)), - {true, true, true}); - - testValidate( - &fusion, - cg_outputs, - aten_inputs, - {std::get<0>(aten_gradients), - std::get<1>(aten_gradients), - std::get<2>(aten_gradients)}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormBackward_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - const int64_t NORM_SIZE = 1024; - std::vector shape{8, 56, NORM_SIZE}; - std::vector norm_shape{NORM_SIZE}; - - const size_t kM = shape.size(); - const size_t kN = norm_shape.size(); - const size_t kOuterNumDims = kM - kN; - - std::vector outer_shape; - for (const auto idx : c10::irange(kOuterNumDims)) { - outer_shape.push_back(shape[idx]); - } - for (const auto idx : c10::irange(kOuterNumDims, kM)) { - outer_shape.push_back(1); - } - - auto grad_out = makeContigTensor(shape.size()); - auto input = makeContigTensor(shape.size()); - auto rstd = makeConcreteTensor(outer_shape); - auto weight = makeContigTensor(norm_shape.size()); - fusion.addInput(grad_out); - fusion.addInput(input); - fusion.addInput(rstd); - fusion.addInput(weight); - - auto grads = rms_norm_backward( - grad_out, input, norm_shape, rstd, weight, {true, true}); - - fusion.addOutput(grads.grad_input); - fusion.addOutput(grads.grad_weight); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_grad_out = at::randn(shape, options); - at::Tensor aten_input = at::randn(shape, options); - at::Tensor aten_weight = at::randn(norm_shape, options); - auto at_weight = c10::optional(aten_weight); - - const float kEps = 1e-6; - auto pow2 = at::pow(aten_input, 2); - auto sum = at::sum(pow2, -1, true); - auto var = at::mul(sum, 1.0 / NORM_SIZE); - auto aten_rstd = at::pow(at::add(var, kEps), -0.5); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector aten_inputs = { - aten_grad_out, aten_input, aten_rstd, aten_weight}; - auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - - auto in_mul_rstd = at::mul(aten_input, aten_rstd); - auto grad_out_mul = at::mul(aten_grad_out, in_mul_rstd); - auto aten_grad_weight = at::sum(grad_out_mul, c10::IntArrayRef{0, 1}); - auto sum_loss1 = at::sum(at::mul(aten_grad_out, aten_weight), -1, true); - auto sum_loss2 = at::sum( - at::mul( - at::mul(at::mul(aten_grad_out, aten_weight), aten_input), aten_rstd), - -1, - true); - - const float fH = NORM_SIZE; - auto term1 = at::mul(aten_rstd, 1.0 / fH); - auto aten_grad_input = at::mul(at::mul(aten_grad_out, fH), aten_weight); - aten_grad_input = at::sub(aten_grad_input, sum_loss1); - aten_grad_input = at::sub( - aten_grad_input, at::mul(at::mul(aten_input, aten_rstd), sum_loss2)); - aten_grad_input = at::mul(aten_grad_input, term1); - testValidate( - &fusion, - cg_outputs, - aten_inputs, - {aten_grad_input, aten_grad_weight}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - const float kEps = 1e-5; - Double* eps_ptr = IrBuilder::create(kEps); - - std::vector input_shape{20, 100, 35, 67}; - std::vector norm_shape{67}; - - auto input = makeSymbolicTensor(input_shape.size()); - fusion.addInput(input); - - auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr); - - fusion.addOutput(result.output); - fusion.addOutput(result.mean); - fusion.addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(input_shape, options); - c10::optional aten_weight = c10::nullopt; - c10::optional aten_bias = c10::nullopt; - auto aten_outputs = at::native_layer_norm( - aten_input, norm_shape, aten_weight, aten_bias, kEps); - - // Check reduction axis is same for all reductions - // Generate Launch Parameters - auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - FusionExecutorCache fec(std::move(fusion_ptr)); - auto cg_outputs = fec.runFusionWithInputs({aten_input}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {std::get<0>(aten_outputs), - std::get<1>(aten_outputs), - std::get<2>(aten_outputs)}, - __LINE__, - __FILE__, - ""); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormalization_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int64_t NORM_SIZE = 1024; - const float kEps = 1e-6; - Double* eps_ptr = IrBuilder::create(kEps); - - std::vector input_shape{8, 56, NORM_SIZE}; - std::vector norm_shape{NORM_SIZE}; - - auto input = makeContigTensor(input_shape.size()); - fusion.addInput(input); - auto result = rms_norm(input, norm_shape, nullptr, eps_ptr); - - fusion.addOutput(result.output); - fusion.addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(input_shape, options); - c10::optional aten_weight = c10::nullopt; - - auto pow2 = at::pow(aten_input, 2); - - auto sum = at::sum(pow2, -1, true); - auto var = at::mul(sum, 1.0 / NORM_SIZE); - auto invstd = at::pow(at::add(var, kEps), -0.5); - auto output = at::mul(aten_input, invstd); - //// Check reduction axis is same for all reductions - //// Generate Launch Parameters - auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - FusionExecutorCache fec(std::move(fusion_ptr)); - auto cg_outputs = fec.runFusionWithInputs({aten_input}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {output, invstd}, - __LINE__, - __FILE__, - ""); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const float kMomentum = 0.1; - const float kEps = 1e-5; - const bool kTraining = true; - std::vector input_shape{20, 100, 35, 45}; - - auto input = makeSymbolicTensor(input_shape.size()); - auto weight = makeSymbolicTensor(1); - auto bias = makeSymbolicTensor(1); - auto running_mean = makeSymbolicTensor(1); - auto running_var = makeSymbolicTensor(1); - fusion->addInput(input); - fusion->addInput(weight); - fusion->addInput(bias); - fusion->addInput(running_mean); - fusion->addInput(running_var); - - Double* momentum = IrBuilder::create(kMomentum); - Double* eps = IrBuilder::create(kEps); - - auto result = batch_norm( - input, weight, bias, running_mean, running_var, kTraining, momentum, eps); - - fusion->addOutput(result.output); - fusion->addOutput(result.mean); - fusion->addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_input = at::randn(input_shape, options); - auto at_weight = at::ones({input_shape[1]}, options); - auto at_bias = at::zeros({input_shape[1]}, options); - auto at_run_mean = at::zeros({input_shape[1]}, options); - auto at_run_var = at::ones({input_shape[1]}, options); - - std::vector aten_inputs = { - at_input, at_weight, at_bias, at_run_mean, at_run_var}; - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto aten_outputs = at::native_batch_norm( - at_input, - c10::optional(at_weight), - c10::optional(at_bias), - c10::optional(at_run_mean), - c10::optional(at_run_var), - kTraining, - kMomentum, - kEps); - - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {std::get<0>(aten_outputs), - std::get<1>(aten_outputs), - std::get<2>(aten_outputs)}, - __LINE__, - __FILE__, - ""); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalization_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const float kMomentum = 0.1; - const float kEps = 1e-5; - const bool kUseInputStats = true; - std::vector input_shape{20, 100, 35, 45}; - - auto input = makeSymbolicTensor(input_shape.size()); - auto weight = makeSymbolicTensor(1); - auto bias = makeSymbolicTensor(1); - auto running_mean = makeSymbolicTensor(1); - auto running_var = makeSymbolicTensor(1); - fusion->addInput(input); - fusion->addInput(weight); - fusion->addInput(bias); - fusion->addInput(running_mean); - fusion->addInput(running_var); - - Double* momentum = IrBuilder::create(kMomentum); - Double* eps = IrBuilder::create(kEps); - - auto result = instance_norm( - input, - weight, - bias, - running_mean, - running_var, - kUseInputStats, - momentum, - eps); - - fusion->addOutput(result.output); - // fusion->addOutput(result.mean); - // fusion->addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_input = at::randn(input_shape, options); - auto at_weight = at::ones({input_shape[1]}, options); - auto at_bias = at::zeros({input_shape[1]}, options); - auto at_run_mean = at::zeros({input_shape[1]}, options); - auto at_run_var = at::ones({input_shape[1]}, options); - - std::vector aten_inputs = { - at_input, at_weight, at_bias, at_run_mean, at_run_var}; - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - auto cg_outputs_full = {at_run_mean, at_run_var, cg_outputs[0]}; - - auto aten_outputs = at::instance_norm( - at_input, - c10::optional(at_weight), - c10::optional(at_bias), - c10::optional(at_run_mean), - c10::optional(at_run_var), - kUseInputStats, - kMomentum, - kEps, - false); - - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - // TODO: can run_mean/run_var be checked here? - // fusion_outputs.size() == aten_outputs.size() && aten_outputs.size() == - // fusion->outputs().size() - output_alias_indices.size() - {aten_outputs}, - __LINE__, - __FILE__, - ""); -} - -TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalizationBackward_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - auto fusion_forward = std::make_unique(); - FusionGuard fg_forward(fusion_forward.get()); - - const float kMomentum = 0.1; - const float kEps = 1e-5; - const bool kUseInputStats = true; - const bool channels_last = true; - const int B = 2; - const int C = 5; - const int S = 3; - std::vector input_shape{B, C, S, S, S}; - // explicit channels-last for NVFuser - std::vector nvfuser_input_shape{B, S, S, S, C}; - - auto input = makeContigTensor(input_shape.size()); - auto weight = makeContigTensor(1); - auto bias = makeContigTensor(1); - fusion_forward->addInput(input); - fusion_forward->addInput(weight); - fusion_forward->addInput(bias); - - Double* momentum = IrBuilder::create(kMomentum); - Double* eps = IrBuilder::create(kEps); - auto result_forward = instance_norm( - input, - weight, - bias, - nullptr, - nullptr, - kUseInputStats, - momentum, - eps, - channels_last); - fusion_forward->addOutput(result_forward.output); - fusion_forward->addOutput(result_forward.mean); - fusion_forward->addOutput(result_forward.invstd); - - FusionExecutorCache executor_cache_forward(std::move(fusion_forward)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto at_input = at::randn(input_shape, options) - .to(at::MemoryFormat::ChannelsLast3d) - .set_requires_grad(true); - auto at_input_nvfuser = at_input.clone().detach().permute({0, 2, 3, 4, 1}); - auto at_weight = at::ones({input_shape[1]}, options).set_requires_grad(true); - auto at_weight_nvfuser = at_weight.clone().detach(); - auto at_bias = at::zeros({input_shape[1]}, options).set_requires_grad(true); - auto at_bias_nvfuser = at_bias.clone().detach(); - std::vector aten_inputs_forward = { - at_input_nvfuser, at_weight_nvfuser, at_bias_nvfuser}; - // out, mean, invstd - auto outputs_forward = - executor_cache_forward.runFusionWithInputs(aten_inputs_forward); - auto at_out = at::instance_norm( - at_input, - c10::optional(at_weight), - c10::optional(at_bias), - c10::optional(c10::nullopt), - c10::optional(c10::nullopt), - kUseInputStats, - kMomentum, - kEps, - false); - auto at_grad = - at::randn(input_shape, options).to(at::MemoryFormat::ChannelsLast3d); - auto at_grad_nvfuser = at_grad.clone().detach().permute({0, 2, 3, 4, 1}); - at_out.backward(at_grad); - auto fusion_backward = std::make_unique(); - FusionGuard fg_backward(fusion_backward.get()); - - input = makeContigTensor(input_shape.size()); - auto grad_output = makeContigTensor(input_shape.size()); - weight = makeContigTensor(1); - auto save_mean = makeContigTensor(2); - auto save_invstd = makeContigTensor(2); - auto dummy = makeContigTensor(0); - - fusion_backward->addInput(input); - fusion_backward->addInput(grad_output); - fusion_backward->addInput(weight); - fusion_backward->addInput(dummy); // dummy for run_mean - fusion_backward->addInput(dummy); // dummy for run_var - fusion_backward->addInput(save_mean); - fusion_backward->addInput(save_invstd); - - auto result_backward = instance_norm_backward( - input, - grad_output, - weight, - nullptr, - nullptr, - save_mean, - save_invstd, - kUseInputStats, - eps, - {true, true, true}, - channels_last); - - fusion_backward->addOutput(result_backward.grad_input); - fusion_backward->addOutput(result_backward.grad_weight); - fusion_backward->addOutput(result_backward.grad_bias); - - FusionExecutorCache executor_cache_backward(std::move(fusion_backward)); - std::vector aten_inputs_backward = { - at_input_nvfuser, - at_grad_nvfuser, - at_weight_nvfuser, - at::empty({}), - at::empty({}), - outputs_forward[1], - outputs_forward[2]}; - auto outputs_backward = - executor_cache_backward.runFusionWithInputs(aten_inputs_backward); - outputs_backward[0] = outputs_backward[0].permute({0, 4, 1, 2, 3}); - testValidate( - executor_cache_backward.fusion(), - outputs_backward, - aten_inputs_backward, - {at_input.grad(), at_weight.grad(), at_bias.grad()}, - __LINE__, - __FILE__, - ""); -} - -TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalSmem_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int pixels_per_thread = 64; - const int TIDX = 128; - const int static_size = pixels_per_thread * TIDX; - - TensorView* sx = makeConcreteTensor({-1, static_size}); - TensorView* dx = makeSymbolicTensor(2); - fusion.addInput(sx); - fusion.addInput(dx); - - TensorView* max_sx = reductionOp( - BinaryOpType::Max, - {-1}, - IrBuilder::create(std::numeric_limits::lowest()), - sx); // (M) - TensorView* max_dx = reductionOp( - BinaryOpType::Max, - {-1}, - IrBuilder::create(std::numeric_limits::lowest()), - dx); // (M) - - // Reduction => merge local and shared memory TensorViews - TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx); - TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) - - TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N) - TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N) - - TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N) - TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N) - - TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R) - TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R) - - // Reduction => merge local and shared memory TensorViews - TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp); - TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) - - TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N) - TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N) - fusion.addOutput(sx_softmax); - fusion.addOutput(dx_softmax); - - auto sx_cache = sx->cacheAfter(); - auto dx_cache = dx->cacheAfter(); - dx_cache->setMemoryType(MemoryType::Shared); - dx_exp->setMemoryType(MemoryType::Shared); - - // Reduction and Broadcast Tensors common to both memory TVs - std::vector common_tensors( - {max_val, sum_exp, bcast_max, bcast_sum}); - - // Static Local Memory TVs - std::vector static_tensors( - {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax}); - - // Dynamic Local Memory TVs - std::vector dynamic_tensors( - {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax}); - - std::vector all_tensors; - all_tensors.insert( - all_tensors.end(), common_tensors.begin(), common_tensors.end()); - all_tensors.insert( - all_tensors.end(), static_tensors.begin(), static_tensors.end()); - all_tensors.insert( - all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); - - // M => M - // M, N => M, N/128, 128 - for (auto tensor : all_tensors) { - if (tensor->nDims() > 1) { - tensor->split(-1, TIDX); - } - } - - auto sx_sum_exp_rf = sx_sum_exp->rFactor({1}); - auto dx_sum_exp_rf = dx_sum_exp->rFactor({1}); - all_tensors.push_back(sx_sum_exp_rf); - all_tensors.push_back(dx_sum_exp_rf); - - // computeAt - sx->computeAt(sx_max_sub, 1); - dx->computeAt(dx_max_sub, 1); - - sx_exp->computeAt(sx_softmax, 1); - dx_exp->computeAt(dx_softmax, 1); - - sx_max_sub->computeAt(sx_exp, 2); - dx_max_sub->computeAt(dx_exp, 2); - - sx_softmax->axis(0)->parallelize(ParallelType::BIDx); - dx_softmax->axis(0)->parallelize(ParallelType::BIDx); - for (auto tensor : all_tensors) { - if (tensor->nDims() > 1) { - tensor->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - const int64_t dimx = 1024; - const int64_t dimy = 16384; - - auto properties = at::cuda::getDeviceProperties(0); - // Require 70KB of smem to run test - const size_t required_smem_size = 70 << 10; - if (properties->sharedMemPerBlockOptin < required_smem_size) { - GTEST_SKIP() << "not enough shared memory space on device to run test"; - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({dimx, dimy}, options); - at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size); - at::Tensor aten_dynamic_in = - aten_input.narrow(1, static_size, dimy - static_size); - - at::Tensor out = at::zeros({dimx, dimy}, options); - at::Tensor cg_static_out = out.narrow(1, 0, static_size); - at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size); - - std::vector aten_outputs; - - auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); - at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size); - at::Tensor aten_dynamic_out = - aten_output.narrow(1, static_size, dimy - static_size); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_static_in, aten_dynamic_in}); - fe.runFusion( - {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out}); - - testValidate( - &fusion, - {cg_static_out, cg_dynamic_out}, - {aten_static_in, aten_dynamic_in}, - {cg_static_out, cg_dynamic_out}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int pixels_per_thread = 64; - const int TIDX = 128; - const int static_size = pixels_per_thread * TIDX; - - TensorView* sx = makeConcreteTensor({-1, static_size}); - TensorView* dx = makeSymbolicTensor(2); - fusion.addInput(sx); - fusion.addInput(dx); - - Double* gamma = IrBuilder::create(); - Double* beta = IrBuilder::create(); - Double* eps = IrBuilder::create(); - Int* N = IrBuilder::create(); - fusion.addInput(gamma); - fusion.addInput(beta); - fusion.addInput(eps); - fusion.addInput(N); - - // Reduction - auto sx_sum = sum(sx, {-1}); // (M, R) - auto dx_sum = sum(dx, {-1}); // (M, R) - // Reduction => merge local and shared memory TensorViews - auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum); - - // Broadcast - auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) - // Pwise - auto x_mean = div(x_sum_bcast, N); // (M, B) - - auto sx_mean_sub = sub(sx, x_mean); // (M, N) - auto dx_mean_sub = sub(dx, x_mean); // (M, N) - - auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N) - auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N) - - // Reduction - auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R) - auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R) - // Reduction => merge local and shared memory TensorViews - auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum); - - // Broadcast - auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) - // Pwise - auto var = div(var_sum_bcast, N); // (M, B) - auto var_eps = add(var, eps); // (M, B) - auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) - - auto sx_norm = mul(sx_mean_sub, rvar); - auto dx_norm = mul(dx_mean_sub, rvar); - - auto sx_norm_gamma = mul(sx_norm, gamma); - auto dx_norm_gamma = mul(dx_norm, gamma); - - auto sx_norm_gamma_beta = add(sx_norm_gamma, beta); - auto dx_norm_gamma_beta = add(dx_norm_gamma, beta); - - fusion.addOutput(sx_norm_gamma_beta); - fusion.addOutput(dx_norm_gamma_beta); - - sx_norm_gamma_beta->setContiguity(false); - dx_norm_gamma_beta->setContiguity(false); - - // Read Input into Shared Memory - // Read Input minus Input_Mean into Shared Memory - auto sx_cache = sx->cacheAfter(); - auto dx_cache = dx->cacheAfter(); - dx_cache->setMemoryType(MemoryType::Shared); - dx_mean_sub->setMemoryType(MemoryType::Shared); - - std::vector common_tensors( - {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar}); - - std::vector static_tensors( - {sx, - sx_cache, - sx_sum, - sx_mean_sub, - sx_mean_sub_pow, - sx_var_sum, - sx_norm, - sx_norm_gamma, - sx_norm_gamma_beta}); - - std::vector dynamic_tensors( - {dx, - dx_cache, - dx_sum, - dx_mean_sub, - dx_mean_sub_pow, - dx_var_sum, - dx_norm, - dx_norm_gamma, - dx_norm_gamma_beta}); - - std::vector all_tensors; - all_tensors.insert( - all_tensors.end(), common_tensors.begin(), common_tensors.end()); - all_tensors.insert( - all_tensors.end(), static_tensors.begin(), static_tensors.end()); - all_tensors.insert( - all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); - - // M => M - // M, N => M, N/128, 128 - for (auto tensor : all_tensors) { - if (tensor->nDims() > 1) { - tensor->split(-1, TIDX); - } - } - - // Local Sum => Block Broadcast - TensorView* sx_sum_rf = sx_sum->rFactor({1}); - TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1}); - TensorView* dx_sum_rf = dx_sum->rFactor({1}); - TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1}); - all_tensors.push_back(sx_sum_rf); - all_tensors.push_back(sx_var_sum_rf); - all_tensors.push_back(dx_sum_rf); - all_tensors.push_back(dx_var_sum_rf); - - // ComputeAt - sx->computeAt(sx_mean_sub_pow, 1); - dx->computeAt(dx_mean_sub_pow, 1); - - var_sum->computeAt(rvar, 1); - - sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2); - dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2); - - sx_norm->computeAt(sx_norm_gamma_beta, 2); - dx_norm->computeAt(dx_norm_gamma_beta, 2); - - sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); - dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); - for (auto tensor : all_tensors) { - if (tensor->nDims() > 1) { - tensor->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - const int dimx = 1024; - const int dimy = 16384; - const float kGamma = 1.0f; - const float kBeta = 0.0f; - const float kEps = 1e-5; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({dimx, dimy}, options); - at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size); - at::Tensor aten_dynamic_in = - aten_input.narrow(1, static_size, dimy - static_size); - - at::Tensor out = at::zeros({dimx, dimy}, options); - at::Tensor cg_static_out = out.narrow(1, 0, static_size); - at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size); - - std::vector aten_inputs = { - aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy}; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - - auto properties = at::cuda::getDeviceProperties(0); - // Require 70KB of smem to run test - const size_t required_smem_size = 70 << 10; - if (properties->sharedMemPerBlockOptin < required_smem_size) { - GTEST_SKIP() << "not enough shared memory space on device to run test"; - } - - fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out}); - - auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); - auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1); - auto at_rvar = at::rsqrt(at::add(at_var, kEps)); - auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar); - auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta); - at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size); - at::Tensor aten_dynamic_out = - aten_output.narrow(1, static_size, dimy - static_size); - - testValidate( - &fusion, - {cg_static_out, cg_dynamic_out}, - aten_inputs, - {aten_static_out, aten_dynamic_out}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - auto x = makeSymbolicTensor(2); - Double* gamma = IrBuilder::create(); - Double* beta = IrBuilder::create(); - Double* eps = IrBuilder::create(); - Int* N = IrBuilder::create(); - fusion.addInput(x); - fusion.addInput(gamma); - fusion.addInput(beta); - fusion.addInput(eps); - fusion.addInput(N); - - // Reduction - auto x_sum = sum(x, {-1}); // (M, R) - // Broadcast - auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) - // Pwise - auto x_mean = div(x_sum_bcast, N); // (M, B) - auto x_mean_sub = sub(x, x_mean); // (M, N) - auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N) - // Reduction - auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R) - // Broadcast - auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) - // Pwise - auto var = div(var_sum_bcast, N); // (M, B) - auto var_eps = add(var, eps); // (M, B) - auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) - auto norm = mul(x_mean_sub, rvar); - auto norm_gamma = mul(norm, gamma); - auto norm_gamma_beta = add(norm_gamma, beta); - fusion.addOutput(norm_gamma_beta); - - // Read Input into Shared Memory - // Read Input minus Input_Mean into Shared Memory - auto cache_x = x->cacheAfter(); - cache_x->setMemoryType(MemoryType::Shared); - x_mean_sub->setMemoryType(MemoryType::Shared); - - std::vector all_tensors( - {x_sum, - x_mean, - cache_x, - x_sum_bcast, - x_mean_sub, - x_mean_sub_pow, - var_sum, - var_sum_bcast, - var, - var_eps, - rvar, - norm, - norm_gamma, - norm_gamma_beta}); - - auto tidx = IrBuilder::create(); - fusion.addInput(tidx); - - for (auto tensor : all_tensors) { - tensor->split(-1, tidx); - } - - // Local Sum => Block Broadcast - TensorView* x_sum_rf = x_sum->rFactor({1}); - TensorView* var_sum_rf = var_sum->rFactor({1}); - all_tensors.push_back(x_sum_rf); - all_tensors.push_back(var_sum_rf); - - // ComputeAt - x->computeAt(x_mean_sub_pow, 1); - var_sum->computeAt(rvar, 1); - x_mean_sub_pow->computeAt(var_sum_rf, 2); - norm->computeAt(norm_gamma_beta, 2); - - for (auto tv : all_tensors) { - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - const int dimx = 128; - const int dimy = 2048; - const float kGamma = 1.0f; - const float kBeta = 0.0f; - const float kEps = 1e-5; - const int TIDX = 128; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({dimx, dimy}, options); - auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); - auto at_var = at::var(aten_input.to(at::kDouble), -1).unsqueeze(1); - auto at_rvar = at::rsqrt(at::add(at_var, kEps)); - auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar); - auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta); - - std::vector aten_inputs = { - aten_input, kGamma, kBeta, kEps, dimy, TIDX}; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addInput(tv0); - fusion.addOutput(tv1); - // tv1[I0, R1] = tv0[I0, I1] - - // Interface should just be a direct split with a Parallel type. We can - // include the parallelize call if we do this. - tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); - // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({2}); - tv2->setMemoryType(MemoryType::Shared); - // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1] - // tv1[I0, R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}] - - tv0->computeAt(tv1, 1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::BIDx); - - constexpr int numel_x = 65000, numel_y = 1024; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({numel_x, numel_y}, options); - auto aten_output = aten_input.to(at::kDouble).sum({1}); - - // How many threads to use for the block reduction - constexpr int runtime_threadIdx_dim = 128; - - LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - Int* sym_bsx = IrBuilder::create(); - TensorView* tv0 = makeSymbolicTensor(3); // M, K, N - fusion.addInput(tv0); - fusion.addInput(sym_bsx); - - TensorView* tv1 = sum(tv0, {1}); // M, R, N - fusion.addOutput(tv1); - - TensorView* tv2 = tv0->cacheAfter(); - tv2->setMemoryType(MemoryType::Shared); - - // Schedule - constexpr int BSX = 32; - tv1->split(2, BSX); - tv1->split(1, sym_bsx); - tv1->split(0, BSX); - // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX - tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}}); - TensorView* tv3 = tv1->rFactor({-2}); - - tv0->computeAt(tv1, -2); - tv0->computeAt(tv3, -2); - - // Thread and Block binding - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::BIDy); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - // Manual Binding - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - constexpr int M = 154, K = 45, N = 1524; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({M, K, N}, options); - at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); - - // How many threads to use for the block reduction - constexpr int runtime_threadIdx_dim = 128; - - auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input, runtime_threadIdx_dim}, lparams); - auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams); - - testValidate( - &fusion, - cg_outputs, - {aten_input, runtime_threadIdx_dim}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); -} - -TEST_F(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Int* sym_bsx = IrBuilder::create(); - TensorView* tv0 = makeSymbolicTensor(2); // (M, K) - TensorView* tv1 = makeSymbolicTensor(2); // (K, N) - TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) - TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) - TensorView* tv4 = mul(tv2, tv3); // M, K, N - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(sym_bsx); - fusion.addOutput(tv4); - // Algorithm - - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - - constexpr int BSX = 32; - tv4->split(2, BSX); - tv4->split(1, sym_bsx); - tv4->split(0, BSX); - // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX - tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}}); - // M/BSX, K/BSX, N/BSX, MSX, KSX, NSX - - tv0->computeAt(tv4, 3); - tv1->computeAt(tv4, 3); - // Schedule - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(2)->parallelize(ParallelType::BIDy); - // Manual Binding - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - // Thread and Block binding - - constexpr int M = 128, K = 457, N = 1024; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)); - std::vector aten_inputs = {t0, t1, BSX}; - - LaunchParams lparams(-1, -1, -1, BSX, -1, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, - cg_outputs, - aten_inputs, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); -} - -TEST_F(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = IrBuilder::create(); // bound to threadIdx.z - Int* symbolic_split_k_tile_dim = - IrBuilder::create(); // bound to blockIdx.x - Int* symbolic_block_k_tile_dim = - IrBuilder::create(); // bound to threadIdx.x - // Compile-time integer for tiling - int n_smem_tile = 8; // bound to threadIdx.y - - // Symbolic 2D tensors TV0[M, K], TV1[K, N] - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - // Broadcast tv0 to [M, K, *] - TensorView* tv2 = broadcast(tv0, {false, false, true}); - // Broadcast tv1 to [*, K, N] - TensorView* tv3 = broadcast(tv1, {true, false, false}); - - // Pointwise multiplication resulting in tv3[M, K, N] - TensorView* tv4 = mul(tv2, tv3); - - // Turn the K-dimension of tv4 into a reduction dimension - TensorView* tv5 = sum(tv4, {1}); - - // Register inputs and outputs - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - // Register runtime tile dims as inputs - fusion.addInput(symbolic_m_tile_dim); - fusion.addInput(symbolic_split_k_tile_dim); - fusion.addInput(symbolic_block_k_tile_dim); - - // Make a 3D tile, mix of symbolic and constant, do in reverse order because - // dims are inserted - // [M, K, N] - tv5->split(2, n_smem_tile); - tv5->split(1, symbolic_block_k_tile_dim); - tv5->split(1, symbolic_split_k_tile_dim); - tv5->split(0, symbolic_m_tile_dim); - // [Mo, Mi, Koo, Koi, Ki, No, Ni] - - // Reorder so all outer tiles are in the leftmost 3 positions - tv5->reorder({{1, 5}, {5, 1}}); - // [Mo, No, Koo, Koi, Ki, Mi, Ni] - - // Factor out the outer reduction IterDomain, then run the inter-cta - // reduction, and intra-cta reduction - auto tv6 = tv5->rFactor({2}); - // [Mo, No, rKoo, rKoi, rKi, Mi, Ni] - // [Mo, No, rKoi, rKi, Mi, Ni] - - // Scope computations - tv6->computeAt(tv5, 2); - // [Mo, No, rKoo, Koi, Ki, Mi, Ni] - // [Mo, No, rKoi, rKi, Mi, Ni] - - // Setup compute at schedule - tv0->computeAt(tv6, 3); - tv1->computeAt(tv6, 3); - tv4->computeAt(tv6, -1); - // - // T2[Mo, bNo, Koo, Koi, Kii, Mi, bNi] CA(4, 3) - // T3[bMo, No, Koo, Koi, Kii, bMi, Ni] CA(4, 3) - // T4[ Mo, No, Koo, Koi, Kii, Mi, Ni] - // T6[ Mo, No, rKoo, Koi, Kii, Mi, Ni] - // T5[ Mo, No, rKoi, rKii, Mi, Ni] - - // Cache smem tiles - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Local); - tv6->setMemoryType(MemoryType::Local); - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::BIDy); - - std::vector tv_list = {tv2, tv3, tv4, tv5, tv6}; - for (auto tv : tv_list) { - tv->axis(-2)->parallelize(ParallelType::TIDz); - tv->axis(-1)->parallelize(ParallelType::TIDy); - } - tv2->axis(3)->parallelize(ParallelType::TIDx); - tv3->axis(3)->parallelize(ParallelType::TIDx); - tv4->axis(3)->parallelize(ParallelType::TIDx); - tv6->axis(3)->parallelize(ParallelType::TIDx); - tv5->axis(2)->parallelize(ParallelType::TIDx); - - tv2->axis(4)->parallelize(ParallelType::BIDx); - tv3->axis(4)->parallelize(ParallelType::BIDx); - tv4->axis(4)->parallelize(ParallelType::BIDx); - tv6->axis(4)->parallelize(ParallelType::BIDx); - tv5->axis(3)->parallelize(ParallelType::BIDx); - - constexpr int M = 31, K = 65, N = 33; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - // Runtime tiling - int m_tile = 4; // bound to threadIdx.z - int split_k = 7; // bound to blockIdx.x - int intra_cta = 8; // bound to threadIdx.x - - std::vector aten_inputs = {t0, t1, m_tile, split_k, intra_cta}; - at::Tensor aten_output = - mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); - - FusionExecutor fe; - // Generate CUDA and compile with nvRTC - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); -} - -TEST_F(NVFuserTest, FusionGlobalIntermediate_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - fusion.addInput(tv0); - fusion.addOutput(tv1); - // tv1[I0, R1] = tv0[I0, I1] - - // Interface should just be a direct split with a Parallel type. We can - // include the parallelize call if we do this. - tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); - // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] - - TensorView* tv2 = tv1->rFactor({2}); - tv2->setMemoryType(MemoryType::Global); - // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1] - // tv1[I0, R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}] - - tv0->computeAt(tv1, 1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::BIDx); - - constexpr int numel_x = 65000, numel_y = 1024; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - // How many threads to use for the block reduction - constexpr int runtime_threadIdx_dim = 128; - - auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}, lparams); - auto cg_outputs = fe.runFusion({input}, lparams); - - auto aten_output = input.to(at::kDouble).sum({1}); - testValidate( - &fusion, - cg_outputs, - {input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - TensorView* tv2 = makeSymbolicTensor(2); - TensorView* tv3 = makeSymbolicTensor(2); - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - fusion.addInput(tv3); - fusion.addOutput(tv6); - // t6 = ((t1 + (t2 - t3)) - t0) - - tv4->setMemoryType(MemoryType::Global); - tv5->setMemoryType(MemoryType::Global); - tv6->setMemoryType(MemoryType::Global); - - constexpr int M = 32, N = 810; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t1 = at::randn({M, N}, options); - at::Tensor t2 = at::randn({M, N}, options); - at::Tensor t3 = at::randn({M, N}, options); - - at::Tensor aten_output = (t1 + (t2 - t3)) - t0; - - std::vector aten_inputs = {t0, t1, t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1, t2, t3}); - auto cg_outputs = fe.runFusion({t0, t1, t2, t3}); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionConstCheck_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto one = IrBuilder::create(1); - TORCH_CHECK(one->isConstScalar()); - - auto one_x2 = mul(one, one); - TORCH_CHECK(one_x2->isConstScalar()); - - auto one_x3 = mul(one_x2, one); - TORCH_CHECK(one_x3->isConstScalar()); - - auto one_x4 = mul(one_x3, one); - TORCH_CHECK(one_x4->isConstScalar()); -} - -TEST_F(NVFuserTest, FusionUnrollWithAlloc_CUDA) { - const std::vector tensor_dims_in = {128, 128}; - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); - fusion.addInput(tv0); - - TensorView* tv1 = add(tv0, IrBuilder::create(0)); - TensorView* tv2 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); - fusion.addOutput(tv2); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn(tensor_dims_in, options); - at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options); - - // Schedule - tv2->split(1, 32); - tv2->split(1, 4); // unroll - - auto tv2_rf = tv2->rFactor({-3, -2}); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv2_rf->axis(0)->parallelize(ParallelType::BIDx); - tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv2_rf->axis(-2)->parallelize(ParallelType::Unroll); - - tv1->computeAt(tv2_rf, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto aten_output = (input + 0).to(at::kDouble).sum(1); - - testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); -} - -// Test isZeroInt -TEST_F(NVFuserTest, FusionIsZeroInt_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Int* x = IrBuilder::create(0); - Int* y = IrBuilder::create(1); - Val* z = mul(x, y); - TORCH_CHECK(x->isZeroInt()); - TORCH_CHECK(!y->isZeroInt()); - TORCH_CHECK(!z->isZeroInt()); -} - -// Test isOneInt -TEST_F(NVFuserTest, FusionIsOneInt_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - Int* x = IrBuilder::create(1); - Int* y = IrBuilder::create(1); - Val* z = mul(x, y); - TORCH_CHECK(x->isOneInt()); - TORCH_CHECK(y->isOneInt()); - TORCH_CHECK(!z->isOneInt()); -} - -// This is to verify no cycle of computeAt is created. A more complex -// variation of this pattern appears in one of the Python tests -// (test_random_topo). -TEST_F(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - // Common intermediate tensor - auto tv1 = add(tv0, IrBuilder::create(1)); - // tv1 -> tv2 - auto tv2 = add(tv1, IrBuilder::create(2)); - // tv1 -> tv3 -> tv4 - auto tv3 = add(tv1, IrBuilder::create(3)); - auto tv4 = add(tv3, IrBuilder::create(4)); - - // NOTE: This should no longer occur as of PR #201. - // The order of adding outputs matters. If tv3 is added before tv4, - // it should be fine. However, if tv4 is added before tv3, there - // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created - // first, and then tv4->tv3 is created at the final phase of - // computeAt (ComputeAt::setupOutputs). - fusion.addOutput(tv2); - fusion.addOutput(tv4); - fusion.addOutput(tv3); - - tv0->computeAt(tv2, -1); - - TORCH_CHECK(tv3->hasComputeAt()); - TORCH_CHECK(!tv4->hasComputeAt()); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn(100, options); - - auto t1 = aten_input + 1; - auto t2 = t1 + 2; - auto t3 = t1 + 3; - auto t4 = t3 + 4; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - std::vector aten_outputs = {t2, t4, t3}; - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv0, IrBuilder::create(2)); - TensorView* tv3 = add(tv1, IrBuilder::create(3)); - TensorView* tv4 = add(tv1, IrBuilder::create(4)); - - fusion.addOutput(tv2); - fusion.addOutput(tv3); - fusion.addOutput(tv4); - - tv1->computeAt(tv3, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({10, 10}, options); - - auto t1 = aten_input + 1; - auto t2 = aten_input + 2; - auto t3 = t1 + 3; - auto t4 = t1 + 4; - - std::vector aten_outputs = {t2, t3, t4}; - - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - - TensorView* tv3 = add(tv0, IrBuilder::create(3)); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - - TensorView* tv5 = add(tv1, tv3); - - fusion.addOutput(tv2); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - tv1->computeAt(tv5, -1); - tv3->computeAt(tv5, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({10, 10}, options); - - auto t1 = aten_input + 1; - auto t2 = t1 + 2; - auto t3 = aten_input + 3; - auto t4 = t3 + 4; - auto t5 = t1 + t3; - - std::vector aten_outputs = {t2, t4, t5}; - - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder3_CUDA) { - for (const auto i : c10::irange(2)) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - - TensorView* tv3 = add(tv0, IrBuilder::create(3)); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - - TensorView* tv5 = add(tv1, tv3); - - fusion.addOutput(tv2); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - const int tile = 32; - - tv1->split(-1, tile); - tv2->split(-1, tile); - tv3->split(-1, tile); - tv4->split(-1, tile); - tv5->split(-1, tile); - - auto compute_at_outer = tv1; - auto compute_at_inner = tv3; - if (i == 1) { - std::swap(compute_at_inner, compute_at_outer); - } - - compute_at_outer->computeAt(tv5, -2); - compute_at_inner->computeAt(tv5, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - auto t1 = aten_input + 1; - auto t2 = t1 + 2; - auto t3 = aten_input + 3; - auto t4 = t3 + 4; - auto t5 = t1 + t3; - - std::vector aten_outputs = {t2, t4, t5}; - - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionTraversalOrder4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // First tree - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv1, IrBuilder::create(3)); - fusion.addOutput(tv2); - fusion.addOutput(tv3); - - // Second tree - TensorView* tv4 = makeSymbolicTensor(1); - fusion.addInput(tv4); - TensorView* tv5 = add(tv4, IrBuilder::create(5)); - TensorView* tv6 = add(tv5, IrBuilder::create(6)); - TensorView* tv7 = add(tv5, IrBuilder::create(7)); - fusion.addOutput(tv6); - fusion.addOutput(tv7); - - tv1->computeAt(tv2, -1); - tv5->computeAt(tv6, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({100}, options); - at::Tensor t4 = at::rand_like(t0, options); - - auto t1 = t0 + 1; - auto t2 = t1 + 2; - auto t3 = t1 + 3; - auto t5 = t4 + 5; - auto t6 = t5 + 6; - auto t7 = t5 + 7; - - std::vector aten_outputs = {t2, t3, t6, t7}; - std::vector aten_inputs = {t0, t4}; - std::vector cg_outputs = { - at::empty_like(t0, options), - at::empty_like(t0, options), - at::empty_like(t0, options), - at::empty_like(t0, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - fe.runFusion(aten_inputs, cg_outputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, IrBuilder::create(3)); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - TensorView* tv5 = add(tv2, tv4); - - fusion.addOutput(tv1); - fusion.addOutput(tv3); - fusion.addOutput(tv5); - - tv2->computeAt(tv5, -1); - tv4->computeAt(tv5, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - std::vector cg_outputs = { - at::empty_like(aten_input, options), - at::empty_like(aten_input, options), - at::empty_like(aten_input, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - auto t1 = aten_input + 1; - auto t2 = t1 + 2; - auto t3 = aten_input + 3; - auto t4 = t3 + 4; - auto t5 = t2 + t4; - - std::vector aten_outputs = {t1, t3, t5}; - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv0, IrBuilder::create(2)); - TensorView* tv3 = add(tv1, tv2); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - - fusion.addOutput(tv4); - - tv1->split(0, 32); - tv2->split(0, 32); - tv3->split(0, 32); - tv4->split(0, 32); - - tv3->computeAt(tv4, -2); - tv1->computeAt(tv3, -1); - tv2->computeAt(tv3, -2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - - auto t1 = aten_input + 1; - auto t2 = aten_input + 2; - auto t3 = t1 + t2; - auto aten_output = t3 + 4; - - at::Tensor cg_output = at::empty_like(aten_input, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTraversalOrder7_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(2)); - TensorView* tv3 = add(tv0, IrBuilder::create(3)); - TensorView* tv4 = add(tv3, IrBuilder::create(4)); - TensorView* tv5 = add(tv2, tv4); - - fusion.addOutput(tv5); - - TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5}; - for (auto tv : tvs) { - tv->split(0, 2); - tv->split(0, 4); - tv->split(0, 8); - } - - // computeAt into inner loop nests - tv1->computeAt(tv2, -1); - tv3->computeAt(tv4, -2); - - tv2->computeAt(tv5, -4); - tv4->computeAt(tv5, -3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100}, options); - - auto t1 = aten_input + 1; - auto t2 = t1 + 2; - auto t3 = aten_input + 3; - auto t4 = t3 + 4; - auto aten_output = t2 + t4; - - at::Tensor cg_output = at::empty_like(aten_input, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -// Test predication of grid reduction -TEST_F(NVFuserTest, FusionThreadPredicate_CUDA) { - const int gdimx = 4; - const int bdimx = 128; - - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); - TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1); - TensorView* tv3 = add(tv0, IrBuilder::create(2)); - - fusion.addOutput(tv3); - fusion.addOutput(tv2); - - tv1->split(1, bdimx); - tv1->split(1, gdimx); - tv3->split(1, bdimx); - tv3->split(1, gdimx); - - TensorView* tv1_rf = tv1->rFactor({1}); - - tv1->computeAt(tv2, -1); - - tv1->axis(0)->parallelize(ParallelType::BIDy); - tv1_rf->axis(0)->parallelize(ParallelType::BIDy); - tv2->axis(0)->parallelize(ParallelType::BIDy); - tv1->axis(-2)->parallelize(ParallelType::BIDx); - tv1_rf->axis(-2)->parallelize(ParallelType::BIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - - tv3->axis(3)->parallelize(ParallelType::TIDx); - tv3->axis(2)->parallelize(ParallelType::BIDx); - tv3->axis(0)->parallelize(ParallelType::BIDy); - - int numel_x = 100; - int numel_y = 1000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({numel_x, numel_y}, options); - - auto t2 = -aten_input.to(at::kDouble).sum({1}); - auto t3 = aten_input + 2.0; - - std::vector aten_outputs = {t3, t2}; - - std::vector cg_outputs = { - at::empty_like(aten_input, options), at::empty({numel_x}, options)}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, cg_outputs); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { - const int hidden_features = 512; - const int batch_size = 64; - - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tvs[16]; - for (const auto i : c10::irange(16)) { - tvs[i] = makeSymbolicTensor(2); - fusion.addInput(tvs[i]); - } - - auto ingate = unaryOp( - UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3])); - - auto forgetgate = unaryOp( - UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7])); - - auto cellgate = unaryOp( - UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11])); - - auto outgate = unaryOp( - UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15])); - - auto cx = makeContigTensor(2); - fusion.addInput(cx); - - auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate)); - - auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy)); - - fusion.addOutput(cy); - fusion.addOutput(hy); - - std::vector aten_inputs; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor large_tensor0 = - at::randn({batch_size, hidden_features * 4}, options); - at::Tensor large_tensor1 = - at::randn({batch_size, hidden_features * 4}, options); - at::Tensor large_tensor2 = - at::randn({batch_size, hidden_features * 4}, options); - at::Tensor large_tensor3 = - at::randn({batch_size, hidden_features * 4}, options); - - auto chunked0 = large_tensor0.chunk(4, 1); - auto chunked1 = large_tensor1.chunk(4, 1); - auto chunked2 = large_tensor2.chunk(4, 1); - auto chunked3 = large_tensor3.chunk(4, 1); - - aten_inputs.insert(aten_inputs.end(), chunked0.begin(), chunked0.end()); - aten_inputs.insert(aten_inputs.end(), chunked1.begin(), chunked1.end()); - aten_inputs.insert(aten_inputs.end(), chunked2.begin(), chunked2.end()); - aten_inputs.insert(aten_inputs.end(), chunked3.begin(), chunked3.end()); - - auto at_ingate = - chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid(); - auto at_forgetgate = - chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid(); - auto at_cellgate = - chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh(); - auto at_outgate = - chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid(); - - auto at_cx = at::randn({batch_size, hidden_features}, options); - aten_inputs.push_back(at_cx); - auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate)); - auto at_hy = at_outgate.mul(at_cy.tanh()); - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionComputeAtMultiBCast_CUDA) { - if (at::cuda::getCurrentDeviceProperties()->major >= 8) { - GTEST_SKIP() << "Somehow it fails on sm_80+ GPUs" - << " See https://github.com/pytorch/pytorch/issues/86717"; - } - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = broadcast(tv1, {true, false}); - TensorView* tv3 = broadcast(tv1, {false, true}); - TensorView* tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - // Not possible to do computeAt at position -1 as recomputation - // would be required. An exception should be thrown. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(tv1->computeAt(tv3, -1)); -} - -TEST_F(NVFuserTest, FusionReductionHalf_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(3, DataType::Half); - fusion.addInput(tv0); - - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = add(tv1, IrBuilder::create(1.0)); - auto tv3 = sum(tv2, {2}); - auto tv4 = castOp(DataType::Half, tv3); - - fusion.addOutput(tv4); - - const auto options = - at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({8, 8, 16}, options); - - auto reduction_tv = tv3; - - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}, lparams); - - auto aten_output = aten_input.add(1.0).to(at::kDouble).sum({2}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReduceSingle_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({100, 1}); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({100, 1}, options); - - // Grab only tensor views, though there shouldn't be any other type - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}); - - auto aten_output = aten_input.to(at::kDouble).sum({1}); - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { - constexpr int bid_x = 80; - constexpr int tid_x = 4096; - constexpr int red_dim = 1; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim, 2}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}, lparams); - auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { - constexpr int bid_x = 80; - constexpr int tid_x = 4096; - constexpr int red_dim = 1; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); - fusion.addInput(tv0); - - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); - - TensorView* tv2 = reductionOp( - BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv1); - fusion.addOutput(tv2); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}, lparams); - auto aten_output = aten_input.to(at::kDouble).sum({1, 2}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { - constexpr int bid_x = 80; - constexpr int tid_x = 4096; - constexpr int red_dim = 1; - - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); - fusion.addInput(tv0); - - TensorView* tv1 = reductionOp( - BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); - - TensorView* tv2 = - reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); - fusion.addOutput(tv2); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); - - // Apply reduction heuristic - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - // no broadcasting needed, omitting the last optional argument; - auto cg_outputs = fe.runFusion({aten_input}, lparams); - auto aten_output = aten_input.to(at::kDouble).sum({2, 1}); - - testValidate( - &fusion, - cg_outputs, - {aten_input}, - {aten_output}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeConcreteTensor({10, 20, 1}); - fusion.addInput(tv0); - TensorView* tv1 = - reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); - fusion.addOutput(tv1); - - TORCH_CHECK( - ir_utils::getReductionOps(&fusion, true /* ignore_trivial */).empty(), - "Trivial reduction picked up by fusion"); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({10, 20, 1}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - auto aten_output = aten_input.to(at::kDouble).sum({2}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTrivialReduction2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int w = 1, x = 1, y = 7, z = 8; - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeConcreteTensor({w, x, y, z}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = sum(tv1, {0}); - auto tv3 = sum(tv2, {0}); - auto tv4 = add(tv3, tv0); - - fusion.addOutput(tv4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y, z}, options); - at::Tensor t1 = at::randn({w, x, y, z}, options); - auto aten_output = t1.to(at::kDouble).sum({0}).sum({0}).add(t0); - - std::vector aten_inputs = {t0, t1}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTrivialReduction3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int v = 1, w = 1, x = 1, y = 7, z = 8; - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeConcreteTensor({v, w, x, y, z}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = sum(tv1, {0, 1, 2}); - auto tv3 = add(tv2, tv0); - - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y, z}, options); - at::Tensor t1 = at::randn({v, w, x, y, z}, options); - auto aten_output = t1.sum({0, 1, 2}).add(t0); - - std::vector aten_inputs = {t0, t1}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -// Make sure trivial reductions are correctly detected even with -// scheduling applied. -TEST_F(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = sum(tv1, {1}); - fusion.addOutput(tv2); - - tv2->split(1, 4); - tv2->split(1, 8); - auto tv3 = tv2->rFactor({-1}); - auto tv4 = tv2->rFactor({-1}); - - auto tv5 = broadcast(tv0, {true, false}); - auto tv6 = add(tv5, IrBuilder::create(1)); - auto tv7 = sub(tv6, IrBuilder::create(1)); - auto tv8 = sum(tv7, {0}); - fusion.addOutput(tv8); - - auto tv9 = broadcast(tv0, {false, true, true}); - auto tv10 = sum(tv9, {1}); - auto tv11 = sum(tv10, {1}); - fusion.addOutput(tv11); - - tv8->split(0, 3); - tv10->split(1, 4); - tv11->split(1, 5); - - tv0->computeAt(tv2, -1); - tv0->computeAt(tv8, -1); - tv0->computeAt(tv11, 1); - - // Test indexing to gmem-backed tensors - tv3->setMemoryType(MemoryType::Global); - tv8->setMemoryType(MemoryType::Global); - - GpuLower gpulw(&fusion); - - // No ReductionOp should be generated as all the reduction - // exprs should be replaced with a unary set op. - for (const auto expr : gpulw.kernel()->as()->exprs()) { - TORCH_CHECK(!expr->isA()); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({100}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {t0, t0, t0}, __LINE__, __FILE__); -} - -// Test detection of partially trivial reduction -TEST_F(NVFuserTest, FusionDetectTrivialReduction2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->split(1, 1); - // tv1->axis(1): non-trivial - // tv1->axis(2): trivial - - auto tv3 = tv1->rFactor({-1}); - - // Just to suppress register-allocation warning - tv0->computeAt(tv2, 1); - tv3->computeAt(tv1, -1); - - GpuLower gpulw(&fusion); - - // tv3's reduction axis is a trivial reduction. The only - // ReductionOp should be for tv1. - for (const auto expr : gpulw.kernel()->as()->exprs()) { - if (expr->isA()) { - auto reduction_out = - expr->as()->outputs()[0]->as(); - TORCH_CHECK(reduction_out->name() == 1); - } - } -} - -TEST_F(NVFuserTest, FusionInputsIdLookup_CUDA) { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({16, 8, 8}, options); - at::Tensor t1 = at::randn({8, 8}, options); - at::Tensor t2 = at::randn({6, 4}, options); - - // create a cache with max size 2; - torch::jit::fuser::cuda::InputsIdLookup inputs_id_lookup(2); - - // testing basic function, same encoding for identical inputs - auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0}); - auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5}); - TORCH_CHECK(id_0.id == id_0_lookup.id); - TORCH_CHECK(inputs_id_lookup.size() == 1); - TORCH_CHECK(id_0.eviction == false); - - // new input (even tho same shape, but we have different signature because of - // missing scalar input - auto id_1 = inputs_id_lookup.lookupId({t0, t1}); - auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1}); - TORCH_CHECK(id_1.id == id_1_lookup.id); - TORCH_CHECK(inputs_id_lookup.size() == 2); - TORCH_CHECK(id_1.eviction == false); - - // eviction should happen at this point - auto id_2 = inputs_id_lookup.lookupId({t2, t1}); - TORCH_CHECK(id_2.id != id_0.id); - TORCH_CHECK(id_2.id != id_1.id); - TORCH_CHECK(inputs_id_lookup.size() == 2); - TORCH_CHECK(id_2.eviction == true); - TORCH_CHECK(id_2.evict_id == id_0.id); - - // look at input 1 again - auto id_1_relook = inputs_id_lookup.lookupId({t0, t1}); - TORCH_CHECK(id_1_relook.id == id_1.id); - TORCH_CHECK(id_1_relook.eviction == false); -} - -TEST_F(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { - std::vector sizes_vec({16, 8, 8}); - std::vector strides_vec({64, 8, 1}); - auto tensor_type = TensorType::create( - at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // pass with identical shape - auto t0 = at::randn({16, 8, 8}, options); - TORCH_CHECK(complyWith(t0, tensor_type)); - - // pass with dynamic shape - auto t1 = at::randn({16, 16, 8}, options); - TORCH_CHECK(complyWith(t1, tensor_type)); - - // broadcasting semantic change failure - auto t2 = at::randn({16, 1, 8}, options); - TORCH_CHECK(!complyWith(t2, tensor_type)); - - // contiguity failure via slicing - auto t3 = t0.slice(1, 0, 8, 2); - TORCH_CHECK(!complyWith(t3, tensor_type)); - - // contiguity failure via slicing - auto t4 = t0.slice(2, 0, 8, 2); - TORCH_CHECK(!complyWith(t4, tensor_type)); - - // rank failure - auto t5 = at::randn({16, 8, 8, 8}, options); - TORCH_CHECK(!complyWith(t5, tensor_type)); - - // contiguity on stride 1 dimension with implicit broadcasting - auto t = at::randn({4}, options); - auto t6 = t.unsqueeze(1).expand({4, 8}); - TORCH_CHECK(complyWith(t6, TensorType::create(t6))); -} - -TEST_F(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { - std::vector sizes_vec({16, 1, 8}); - std::vector strides_vec({8, 8, 1}); - auto tensor_type = TensorType::create( - at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // broadcasting semantic change - auto t0 = at::randn({16, 8, 8}, options); - TORCH_CHECK(!complyWith(t0, tensor_type)); - - // dtype failure - auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf)); - TORCH_CHECK(!complyWith(t1, tensor_type)); - - // dtype failure - auto t2 = at::randn({16, 1, 8}, options); - TORCH_CHECK(complyWith(t2, tensor_type)); - - // device inconsistency shouldn't fail - auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0)); - TORCH_CHECK(complyWith(t3, tensor_type)); -} - -TEST_F(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { - std::vector sizes_vec({16, 8, 8}); - std::vector strides_vec({64, 1, 8}); - auto tensor_type = TensorType::create( - at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // failing permutation - auto t0 = at::randn({16, 8, 8}, options); - TORCH_CHECK(!complyWith(t0, tensor_type)); - - // passing with dynamic shape - auto t1 = t0.permute({0, 2, 1}); - TORCH_CHECK(complyWith(t1, tensor_type)); -} - -TEST_F(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { - std::vector sizes_vec({16, 8, 8}); - std::vector strides_vec({128, 16, 1}); - auto tensor_type = TensorType::create( - at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // contiguity check passes although it differs - auto t0 = at::randn({16, 16, 8}, options); - TORCH_CHECK(complyWith(t0, tensor_type)); - - // passing with dynamic shape - auto t1 = t0.slice(1, 0, 16, 2); - TORCH_CHECK(complyWith(t1, tensor_type)); -} - -TEST_F(NVFuserTest, FusionDisjointSet_CUDA) { - DisjointSets set; - - const std::set group_x({0, 1, 2}); - const std::set group_y({3, 4, 5}); - const std::set group_z({6, 7, 8}); - const std::vector> groups({group_x, group_y, group_z}); - std::set group_all; - std::for_each(groups.begin(), groups.end(), [&](const auto& g) { - group_all.insert(g.begin(), g.end()); - }); - - // Initially, nothing should be considered equivalent - for (auto i : group_all) { - for (auto j : group_all) { - TORCH_CHECK(!set.permissiveAreMapped(i, j)); - } - } - - // Sets values in group_x are equivalent - for (auto i : group_x) { - for (auto j : group_x) { - set.mapEntries(i, j); - TORCH_CHECK(set.mappingExists(i)); - TORCH_CHECK(set.mappingExists(j)); - } - } - - // All values in group_x shoudl be equivalent with each other - for (auto i : group_x) { - for (auto j : group_x) { - TORCH_CHECK(set.permissiveAreMapped(i, j)); - } - } - // But nothing else should be equivalent - for (auto i : group_all) { - for (auto j : group_y) { - TORCH_CHECK(!set.permissiveAreMapped(i, j)); - } - for (auto j : group_z) { - TORCH_CHECK(!set.permissiveAreMapped(i, j)); - } - } - - // Sets values in group_y are equivalent - for (auto i : group_y) { - for (auto j : group_y) { - set.mapEntries(i, j); - TORCH_CHECK(set.mappingExists(i)); - TORCH_CHECK(set.mappingExists(j)); - } - } - - // group_x should be still equivalent - for (auto i : group_x) { - for (auto j : group_x) { - TORCH_CHECK(set.permissiveAreMapped(i, j)); - } - } - // group_y should be now equivalent - for (auto i : group_y) { - for (auto j : group_y) { - TORCH_CHECK(set.permissiveAreMapped(i, j)); - } - } - // But group_z should not be equivalent with anything yet - for (auto i : group_all) { - for (auto j : group_z) { - TORCH_CHECK(!set.permissiveAreMapped(i, j)); - } - } - - // Sets values in group_z are equivalent - for (auto i : group_z) { - for (auto j : group_z) { - set.mapEntries(i, j); - TORCH_CHECK(set.mappingExists(i)); - TORCH_CHECK(set.mappingExists(j)); - } - } - - // Now each of the three groups should be equivalent within each - // group - for (const auto gi : c10::irange(groups.size())) { - for (const auto gj : c10::irange(groups.size())) { - for (auto i : groups[gi]) { - for (auto j : groups[gj]) { - TORCH_CHECK( - (gi == gj && set.permissiveAreMapped(i, j)) || - (gi != gj && !set.permissiveAreMapped(i, j))); - } - } - } - } - - std::vector all_elements = set.getAllElements().vector(); - std::sort(all_elements.begin(), all_elements.end()); - std::vector group_all_vec(group_all.begin(), group_all.end()); - std::sort(group_all_vec.begin(), group_all_vec.end()); - TORCH_CHECK(all_elements == group_all_vec); - - set.clear(); - TORCH_CHECK(set.getAllElements().vector().size() == 0); - - // All cleared. Nothing should be considered equivalent. - for (auto i : group_all) { - for (auto j : group_all) { - TORCH_CHECK(!set.permissiveAreMapped(i, j)); - } - } -} - -TEST_F(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - auto tv1 = makeSymbolicTensor(2); - auto tv2 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - auto tv3 = broadcast(tv0, {false, true}); - auto tv4 = add(tv3, tv1); - auto tv5 = add(tv3, tv2); - - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - // In order to do this, tv1->axis(1) and tv2->axis(1) must have the - // same size, but we can't prove it, so this should throw an error. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(tv3->computeAt(tv4, -1)); -} - -TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const float k_079 = 0.79788456; - const float k_004 = 0.044715; - - // bias vector - auto t0 = makeSymbolicTensor(1, DataType::Half); - fusion.addInput(t0); - auto t1 = castOp(DataType::Float, t0); - // input tensor - auto t2 = makeSymbolicTensor(3, DataType::Half); - fusion.addInput(t2); - auto t3 = castOp(DataType::Float, t2); - auto t4 = broadcast(t1, {true, true, false}); - auto t5 = add(t4, t3); - auto t6 = mul(t5, IrBuilder::create(0.5)); - auto t7 = mul(t5, IrBuilder::create(k_079)); - auto t8 = mul(t5, IrBuilder::create(k_004)); - auto t9 = mul(t8, t5); - auto t10 = add(t9, IrBuilder::create(1)); - auto t11 = mul(t7, t10); - auto t12 = unaryOp(UnaryOpType::Tanh, t11); - auto t13 = add(t12, IrBuilder::create(1)); - auto t14 = mul(t6, t13); - auto t15 = castOp(DataType::Half, t14); - fusion.addOutput(t15); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::manual_seed(0); - std::vector input_shape{6, 512, 4096}; - std::vector bias_shape{4096}; - - auto at_input = at::randn(input_shape, options); - auto at_bias = at::randn(bias_shape, options); - - auto at_x = - at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); - auto aten_output_float = - at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh()); - auto aten_output = aten_output_float.to(c10::ScalarType::Half); - - std::vector aten_inputs = {at_bias, at_input}; - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { - if (at::cuda::getDeviceProperties(0)->major < 6) { - return; - } - Fusion fusion; - FusionGuard fg(&fusion); - - const float k_079 = 0.79788456; - const float k_004 = 0.044715; - const float k_010 = 0.1070322243; - - // gradient tensor - auto t0 = makeSymbolicTensor(3, DataType::Half); - fusion.addInput(t0); - auto t1 = castOp(DataType::Float, t0); - // bias tensor - auto t2 = makeSymbolicTensor(1, DataType::Half); - fusion.addInput(t2); - auto t3 = castOp(DataType::Float, t2); - // input tensor - auto t4 = makeSymbolicTensor(3, DataType::Half); - fusion.addInput(t4); - auto t5 = castOp(DataType::Float, t4); - auto t6 = broadcast(t3, {true, true, false}); - auto t7 = add(t6, t5); - auto t8 = mul(t7, IrBuilder::create(k_079)); - auto t9 = mul(t7, IrBuilder::create(k_004)); - auto t10 = mul(t9, t7); - auto t11 = add(t10, IrBuilder::create(1)); - auto t12 = mul(t8, t11); - auto t13 = unaryOp(UnaryOpType::Tanh, t12); - auto t14 = mul(t7, IrBuilder::create(0.5)); - auto t15 = mul(t13, t13); - auto t16 = unaryOp(UnaryOpType::Neg, t15); - auto t17 = add(t16, IrBuilder::create(1)); - auto t18 = mul(t7, IrBuilder::create(k_010)); - auto t19 = mul(t18, t7); - auto t20 = add(t19, IrBuilder::create(k_079)); - auto t21 = mul(t17, t20); - auto t22 = mul(t14, t21); - auto t23 = add(t13, IrBuilder::create(1)); - auto t24 = mul(t23, IrBuilder::create(0.5)); - auto t25 = add(t22, t24); - auto t26 = mul(t25, t1); - // Save float output for validation - fusion.addOutput(t26); - auto t27 = castOp(DataType::Half, t26); - fusion.addOutput(t27); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::manual_seed(1); - std::vector input_shape{6, 512, 4096}; - std::vector bias_shape{4096}; - auto at_input = at::randn(input_shape, options); - auto at_bias = at::randn(bias_shape, options); - auto at_grad = at::randn(input_shape, options); - - auto at_x = - at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); - auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh(); - auto at_ff = 0.5 * at_x * - ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) + - 0.5 * (1 + at_tanh_out); - auto at_out = at_ff * at_grad; - auto at_out_half = at_out.to(c10::ScalarType::Half); - - std::vector aten_inputs = {at_grad, at_bias, at_input}; - std::vector aten_outputs = {at_out, at_out_half}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -// Reproducer of issue #459 -TEST_F(NVFuserTest, FusionIssue459_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv1, tv3); - - // Create two outputs from the final arithmetic result - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - auto tv6 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv6); - - // Scheduling - for (auto output : ir_utils::filterByType(fusion.outputs())) { - output->merge(-2, -1); - } - for (auto output : ir_utils::filterByType(fusion.outputs())) { - output->split(0, 128); - } - - tv0->computeAt(tv5, -1); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - tv6->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - const int numel_x = 10; - const int numel_y = 20; - auto t0 = at::randn({numel_x}, options); - auto t1 = at::randn({numel_y, numel_x}, options); - auto aten_output = (t0 + 1).unsqueeze(0) + t1 + 1; - - std::vector aten_inputs = {t0, t1}; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, - cg_outputs, - aten_inputs, - {aten_output, aten_output}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionSmemIndexingSimple_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv3, -1); - - tv1->setMemoryType(MemoryType::Shared); - tv2->setMemoryType(MemoryType::Global); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - auto aten_input = at::randn({12, 34}, options); - at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSmemIndexing_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = IrBuilder::create(); - Int* symbolic_split_k_tile_dim = IrBuilder::create(); - Int* symbolic_block_k_tile_dim = IrBuilder::create(); - // Compile-time integer for tiling - int n_smem_tile = 32; - - // Symbolic 2D tensors TV0[M, K], TV1[K, N] - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - // Broadcast tv0 to [M, K, *] - TensorView* tv2 = broadcast(tv0, {false, false, true}); - // Broadcast tv1 to [*, K, N] - TensorView* tv3 = broadcast(tv1, {true, false, false}); - - // Pointwise multiplication resulting in tv3[M, K, N] - TensorView* tv4 = mul(tv2, tv3); - - // Sum the K-dim - TensorView* tv5 = sum(tv4, {1}); - - // Register inputs and outputs - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - // Register runtime tile dims as inputs - fusion.addInput(symbolic_m_tile_dim); - fusion.addInput(symbolic_split_k_tile_dim); - fusion.addInput(symbolic_block_k_tile_dim); - - // Make a 3D tile, mix of symbolic and constant, do in reverse order because - // dims are inserted - // [M, rK, N] - tv5->split(2, n_smem_tile); - // [M, rK, No, Ni{32}] - tv5->split(1, symbolic_block_k_tile_dim); - // [M, rKo, rKi{i2}, No, Ni{32}] - tv5->split(1, symbolic_split_k_tile_dim); - // [M, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] - tv5->split(0, symbolic_m_tile_dim); - // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] - - // Reorder so all outer tiles are in the leftmost 3 positions - // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] - // [Mo, No, rKoo, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}] - tv5->reorder({{1, 5}, {5, 1}}); - - // Factor out the outer reduction IterDomain, then run the inter-cta - // reduction, and intra-cta reduction - // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}] - // [Mo, No, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}] - auto tv6 = tv5->rFactor({2}); - - // Scope computations - tv6->computeAt(tv5, 2); - - // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}] - // [Mo, No, Ki{i2}, Mi{i0}, Ni{32}, rKoo, Koi{i1}] - tv6->reorder({ - {5, -2}, - {6, -1}, - {2, 2}, - {3, 3}, - {4, 4}, - }); - - // Setup compute at schedule - tv0->computeAt(tv6, 3); - tv1->computeAt(tv6, 3); - tv4->computeAt(tv6, -1); - - // Cache smem tiles - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Shared); - tv6->setMemoryType(MemoryType::Shared); - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::BIDy); - - std::vector tv_list = {tv2, tv3, tv4, tv5, tv6}; - for (auto tv : tv_list) { - tv->axis(-2)->parallelize(ParallelType::TIDz); - tv->axis(-1)->parallelize(ParallelType::TIDy); - } - - constexpr int M = 31, K = 65, N = 32; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - at::Tensor aten_output = - mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); - - // A, B, m_tile_dim, split_k, intra_cta_tile - std::vector aten_inputs = {t0, t1, 3, 4, 5}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -// Reproducer of issue 408 -TEST_F(NVFuserTest, FusionCacheBeforeReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - fusion.addOutput(tv2); - - tv2->split(0, 4); - - auto tv3 = tv2->cacheBefore(); - - tv0->computeAt(tv3, -1); - tv3->computeAt(tv2, -1); - - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - const int numel_x = 100; - const int numel_y = 200; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_x}, options); - - auto aten_output = (aten_input + 1).to(at::kDouble).sum({1}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - fe.runFusion({aten_input}, {cg_output}); - - testValidate( - &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(3); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv2); - fusion.addOutput(tv3); - - auto tv4 = tv2->cacheBefore(); - - tv4->computeAt(tv3, 1); - tv0->computeAt(tv4, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - - const int numel_x = 10; - const int numel_y = 20; - const int numel_z = 30; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({numel_x, numel_y, numel_z}, options); - auto t2 = (aten_input + 1).to(at::kDouble).sum({1}); - auto t3 = t2 + 1; - std::vector aten_outputs = {t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue367_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Symbolic integers we will use for runtime tiling - Int* symbolic_m_tile_dim = IrBuilder::create(); - Int* symbolic_split_k_tile_dim = IrBuilder::create(); - Int* symbolic_block_k_tile_dim = IrBuilder::create(); - // Compile-time integer for tiling - int n_smem_tile = 32; - - // Symbolic 2D tensors TV0[M, K], TV1[K, N] - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - // Broadcast tv0 to [M, K, *] - TensorView* tv2 = broadcast(tv0, {false, false, true}); - // Broadcast tv1 to [*, K, N] - TensorView* tv3 = broadcast(tv1, {true, false, false}); - - // Pointwise multiplication resulting in tv3[M, K, N] - TensorView* tv4 = mul(tv2, tv3); - - // Sum the K-dim - TensorView* tv5 = sum(tv4, {1}); - - // Register inputs and outputs - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - // Register runtime tile dims as inputs - fusion.addInput(symbolic_m_tile_dim); - fusion.addInput(symbolic_split_k_tile_dim); - fusion.addInput(symbolic_block_k_tile_dim); - - // Make a 3D tile, mix of symbolic and constant, do in reverse order because - // dims are inserted - // [M, K, N] - tv5->split(2, n_smem_tile); - tv5->split(1, symbolic_block_k_tile_dim); - tv5->split(1, symbolic_split_k_tile_dim); - tv5->split(0, symbolic_m_tile_dim); - // [Mo, Mi, Koo, Koi, Ki, No, Ni] - tv5->reorder({{1, 5}, {5, 1}}); - // [Mo, No, Koo, Koi, Ki, Mi, Ni] - - auto tv6 = tv5->rFactor({2}); - auto tv7 = tv5->rFactor({2}); - // [Mo, No, rKoo, Koi, Ki, Mi, Ni] - // [Mo, No, rKoi, rKi, Mi, Ni] - - // Scope computations - tv6->computeAt(tv5, 2); - - tv0->computeAt(tv6, 3); - tv1->computeAt(tv6, 3); - tv4->computeAt(tv6, -1); - - // Cache smem tiles - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Local); - tv6->setMemoryType(MemoryType::Local); - tv7->setMemoryType(MemoryType::Local); - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::BIDy); - - std::vector tv_list = {tv2, tv3, tv4, tv5, tv6, tv7}; - for (auto tv : tv_list) { - tv->axis(-2)->parallelize(ParallelType::TIDz); - tv->axis(-1)->parallelize(ParallelType::TIDy); - } - tv2->axis(3)->parallelize(ParallelType::TIDx); - tv3->axis(3)->parallelize(ParallelType::TIDx); - tv4->axis(3)->parallelize(ParallelType::TIDx); - tv6->axis(3)->parallelize(ParallelType::TIDx); - tv7->axis(2)->parallelize(ParallelType::TIDx); - - tv2->axis(4)->parallelize(ParallelType::BIDx); - tv3->axis(4)->parallelize(ParallelType::BIDx); - tv4->axis(4)->parallelize(ParallelType::BIDx); - tv6->axis(4)->parallelize(ParallelType::BIDx); - tv7->axis(3)->parallelize(ParallelType::BIDx); - tv5->axis(2)->parallelize(ParallelType::BIDx); - - constexpr int M = 3, K = 6, N = 16; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - // A, B, m, split_k, block_k - std::vector aten_inputs = {t0, t1, 2, 2, 3}; - at::Tensor aten_output = - mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue468_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = sum(tv1, {0}); - fusion.addOutput(tv2); - - tv1->axis(0)->parallelize(ParallelType::TIDy); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - tv2->axis(0)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({10, 100}, options); - at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue363_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Symbolic 2D tensors TV0[M, K], TV1[K, N] - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(2); - - // Broadcast tv0 to [M, K, *] - TensorView* tv2 = broadcast(tv0, {false, false, true}); - // Broadcast tv1 to [*, K, N] - TensorView* tv3 = broadcast(tv1, {true, false, false}); - - // Pointwise multiplication resulting in tv3[M, K, N] - TensorView* tv4 = mul(tv2, tv3); - - // Sum the K-dim - TensorView* tv5 = sum(tv4, {1}); - - // Register inputs and outputs - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - tv2->setMemoryType(MemoryType::Global); - tv3->setMemoryType(MemoryType::Global); - tv4->setMemoryType(MemoryType::Global); - - tv0->computeAt(tv5, -1); - tv1->computeAt(tv5, -1); - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::BIDy); - - tv5->axis(2)->parallelize(ParallelType::BIDx); - - constexpr int M = 3, K = 6, N = 16; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - at::Tensor aten_output = - mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); - - std::vector aten_inputs = {t0, t1}; - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue484_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, IrBuilder::create(0)); - fusion.addOutput(tv2); - - tv1->setMemoryType(MemoryType::Global); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - constexpr int M = 100; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({M, M}, options); - at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); - - torch::jit::fuser::cuda::FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue329_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - fusion.addOutput(tv2); - auto tv3 = sum(tv1, {1}); - fusion.addOutput(tv3); - - tv1->computeAt(tv2, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector t0_shape{17, 19}; - auto aten_input = at::randn(t0_shape, options); - auto t2 = (aten_input + 1).to(at::kDouble).sum({1}); - auto t3 = (aten_input + 1).to(at::kDouble).sum({1}); - std::vector aten_outputs = {t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue382_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = broadcast(tv1, {false, false, true}); - auto tv3 = makeSymbolicTensor(3); - fusion.addInput(tv3); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv2->merge(1); - tv4->merge(1); - - tv1->computeAt(tv4, 1); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - - tv1->setMemoryType(MemoryType::Global); - tv2->setMemoryType(MemoryType::Global); - - const int numel_x = 12; - const int numel_y = 34; - const int numel_z = 56; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({numel_x, numel_y}, options); - auto t3 = at::randn({numel_x, numel_y, numel_z}, options); - - std::vector aten_inputs = {t0, t3}; - auto aten_output = (t0 + 1).unsqueeze(-1) + t3; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue507_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->setMemoryType(MemoryType::Shared); - - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector t0_shape{17, 19}; - auto aten_input = at::randn(t0_shape, options); - auto t1 = (aten_input + 1); - auto aten_output = (t1 + 1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue532_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(1)); - fusion.addInput(tv0); - fusion.addOutput(tv2); - - const int M_BLOCK = 64; - const int M_THREAD = 4; - - tv2->split(0, M_BLOCK); - // tv2: [M/M_BLOCK, M_BLOCK] - tv1->computeAt(tv2, 1); - // tv1: [M/M_BLOCK, M_BLOCK] - - tv1->split(-1, M_BLOCK / M_THREAD); - // tv1: [M/M_BLOCK, M_THREAD, M_BLOCK / M_THREAD] - - tv2->split(-1, M_THREAD); - // tv2: [M/M_BLOCK, M_BLOCK / M_THREAD, M_THREAD] - - constexpr int M = 1000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - at::Tensor aten_output = t0 + 1 + 1; - - testValidate( - &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionLoopUnswitch_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(1); - TensorView* tv1 = add(tv0, IrBuilder::create(1)); - TensorView* tv2 = add(tv1, IrBuilder::create(1)); - fusion.addInput(tv0); - fusion.addOutput(tv2); - - tv2->split(0, 32); - tv1->computeAt(tv2, -1); - - tv2->axis(1)->parallelize(ParallelType::Unswitch); - - constexpr int M = 1000; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - at::Tensor aten_output = t0 + 1 + 1; - - testValidate( - &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue549_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); // M, K - TensorView* tv1 = makeSymbolicTensor(2); // K, N - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - - TensorView* tv3 = broadcast(tv2, {false, false, true}); - // tv3[I0, I1, B] = tv0[I0, I1] - - TensorView* tv4 = broadcast(tv1, {true, false, false}); - // tv4[B, I1, I2] = tv1[I1, I2] - - // tv5[I0, I1, I2] = tv3[I0, I1, B] * tv4[B, I1, I2] - TensorView* tv5 = mul(tv3, tv4); - // tv6[I0, R1, I2] = tv5[I0, I1, I2] - TensorView* tv6 = sum(tv5, {1}); - fusion.addOutput(tv6); - - tv6->split(1, 32); - // tv6[I0, R1o, R1i{32}, I2] - - auto tv7 = tv6->rFactor({1}); - // tv7[I0, R1o, I1i{32}, I2] = tv5[I0, I1, I2] - // tv6[I0, , R1i{32}, I2] = tv7[I0, R1o, I1i{32}, I2] - - tv6->split(0, 4); - tv6->split(-1, 4); - // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - - tv0->computeAt(tv6, -1); - tv1->computeAt(tv6, -1); - - // tv7[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] - // tv6[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] - //--> (line symbolizes compute at location) - // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] - // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] - // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv0->computeAt(tv7, -1); - tv1->computeAt(tv7, -1); - // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] - // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] - // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv6->axis(0)->parallelize(ParallelType::BIDz); - tv6->axis(1)->parallelize(ParallelType::TIDz); - - tv6->axis(-2)->parallelize(ParallelType::BIDy); - tv6->axis(-1)->parallelize(ParallelType::TIDy); - - tv6->axis(2)->parallelize(ParallelType::TIDx); - tv7->axis(2)->parallelize(ParallelType::TIDx); - - constexpr int M = 65, K = 33, N = 17; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - - // Lets specify a few bounds in launch params to make sure it works - LaunchParams lparams(1, -1, -1, 32, 4, 4); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, lparams); - fe.runFusion({t0, t1}, lparams); - - // Make sure bad launch params throws - // TODO: Re-enable once we have parallelization validation in. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6))); - - // Don't specify any launch params - auto cg_outputs = fe.runFusion({t0, t1}); - - auto aten_output = (t0 + 1).to(at::kDouble).matmul(t1.to(at::kDouble)); - - testValidate( - &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleCompileRtc_CUDA) { - FusionExecutor fe; - std::string kernel = R"( -__global__ void kernel1(Tensor T0, Tensor T1) { - if(threadIdx.x==0){ - for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) { - T1[ki28*T1.stride[0]] = T0[ki28*T0.stride[0]]*2; - } - } -} - )"; - fe.compileRtc(kernel, "CudaCodeGen::kernel1"); - LaunchParams lp( - 256, // gdimx - 1, // gdimy - 1, // gdimz - 1, // bdimx - 1, // bdimy - 1 // bdimz - ); - lp.setSmem(0); - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const std::vector tensor_dims = {8}; - auto in0 = at::randn(tensor_dims, options); - auto out0 = at::empty_like(in0); - fe.runRtc(lp, {in0, out0}); - - auto out_ref = in0 * 2; - TORCH_CHECK(out_ref.allclose(out0)); -} - -TEST_F(NVFuserTest, FusionSerialWelford_CUDA) { - FusionExecutor fe; - int x = 128, y = 64, z = 64; - - std::string kernel = R"( -__global__ void kernel1( - Tensor inp, - Tensor out_var, - Tensor out_avg -){ - for(int i0=0;i0 tensor_dims = {x, y, z}; - auto in0 = at::randn(tensor_dims, options); - auto out_var = at::empty({x}, options); - auto out_avg = at::empty({x}, options); - fe.runRtc(lp, {in0, out_var, out_avg}); - - TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var)); - TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); -} - -TEST_F(NVFuserTest, FusionBlockWelford_CUDA) { - FusionExecutor fe; - int x = 7, y = 8, z = 9; - - std::string kernel = R"( -__global__ void kernel1( - Tensor inp, - Tensor out_avg, - Tensor out_var, - Tensor init_avg, - Tensor init_var, - Tensor init_N -){ - //actual generated kernel will use dynamic shared mem, - // here is just for prototype - __shared__ float mem_avg[512]; - __shared__ float mem_M2[512]; - __shared__ long mem_N[512]; - float in=inp[threadIdx.x*inp.stride[0]+ - threadIdx.y*inp.stride[1]]; - float tmp_avg=0; - float tmp_M2=0; - long tmp_N=0; - blockWelford( - tmp_avg, - tmp_M2, - tmp_N, - in, - 0.f, - (long)1, - threadIdx, - blockDim, - (float*)mem_avg, - (float*)mem_M2, - (long*)mem_N, - (bool)(threadIdx.x tensor_dims = {x, y}; - const std::vector init_dims = {x, z}; - - // generate initial values - auto init_in = at::randn(init_dims, options); - auto init_var = init_in.var({1}, false); - auto init_avg = init_in.mean({1}); - auto init_N = - at::tensor(z, at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0)); - - auto in0 = at::randn(tensor_dims, options); - - // run kernel - auto out_var = at::zeros({x}, options); - auto out_avg = at::zeros({x}, options); - fe.runRtc(lp, {in0, out_avg, out_var, init_avg, init_var, init_N}); - - // compare with reference output - auto cat_tensor = at::cat({init_in, in0}, 1); - TORCH_CHECK(cat_tensor.var({1}, false).allclose(out_var)); - TORCH_CHECK( - cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); -} - -TEST_F(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { - FusionExecutor fe; - int x = 7, y = 8, z = 9; - - // need support IValue for integer input as initial count - std::string kernel = R"( -__global__ void kernel1( - Tensor inp, - Tensor out_avg, - Tensor out_var -){ - //actual generated kernel will use dynamic shared mem, - // here is just for prototype - __shared__ float mem_avg[512]; - __shared__ float mem_M2[512]; - __shared__ long mem_N[512]; - float in=inp[threadIdx.x*inp.stride[0]+ - threadIdx.y*inp.stride[1]+ - threadIdx.z*inp.stride[2]]; - float tmp_avg=0; - float tmp_M2=0; - long tmp_N=0; - block_sync::init(); - blockWelford( - tmp_avg, - tmp_M2, - tmp_N, - in, - 0.f, - (long) 1, - threadIdx, - blockDim, - (float*)mem_avg, - (float*)mem_M2, - (long*)mem_N, - (bool)(threadIdx.x tensor_dims = {x, y, z}; - auto in0 = at::randn(tensor_dims, options); - auto out_var = at::empty({x}, options); - auto out_avg = at::empty({x}, options); - fe.runRtc(lp, {in0, out_avg, out_var}); - - TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var)); - TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); -} - -TEST_F(NVFuserTest, FusionGridWelfordNoInit_CUDA) { - FusionExecutor fe; - int x = 128, y = 64, z = 128; - - std::string kernel = R"( -__global__ void kernel1( - Tensor inp, - Tensor out_avg, - Tensor out_var, - Tensor work_buf_avg, - Tensor work_buf_M2, - Tensor work_buf_N, - Tensor sync_flag -){ - __shared__ float shared_buf_avg[512]; - __shared__ float shared_buf_M2[512]; - __shared__ long shared_buf_N[512]; - float tmp_avg=0; - float tmp_M2=0; - long tmp_N=0; - float in = inp[ blockIdx.x * inp.stride[0]+ - blockIdx.y * inp.stride[1]+ - threadIdx.x * inp.stride[2]]; - block_sync::init(); - welford::gridWelford< - true,true,false, - true,false,false, - false - >( - tmp_avg, - tmp_M2, - tmp_N, - in, - 0.f, - (long) 1, - &work_buf_avg[0], - &work_buf_M2[0], - &work_buf_N[0], - sync_flag, - (float*)shared_buf_avg, - (float*)shared_buf_M2, - (long*)shared_buf_N, - threadIdx.x tensor_dims = {x, y, z}; - auto in0 = at::randn(tensor_dims, options); - - auto out_avg = at::empty({z}, options); - auto out_var = at::empty({z}, options); - auto work_buf_avg = at::empty({x * y * z}, options); - auto work_buf_var = at::empty({x * y * z}, options); - auto work_buf_N = at::empty({x * y * z}, options_int); - auto sync_flag = at::zeros({1}, options_int); - fe.runRtc( - lp, - {in0, - out_avg, - out_var, - work_buf_avg, - work_buf_var, - work_buf_N, - sync_flag}); - std::vector dims{0, 1}; - - TORCH_CHECK(in0.mean(dims).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); - TORCH_CHECK(in0.var(dims, false).allclose(out_var)); -} - -TEST_F(NVFuserTest, FusionWelfordOp_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = mul(tv0, IrBuilder::create(1)); - auto tvs = Welford(tv1, {1}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - fusion.addOutput(tv_N); - - tv_avg->split(1, 32); - tv_avg->split(0, 32); - tv_avg->split(0, 4); - tv_avg->reorder({{-1, -3}, {-3, -1}}); - tv1->computeAt(tv_avg, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - // by default Welford outputs sum of square diff so need to divide to get var - outputs[1] /= N; - - testValidate( - fe.kernel(), - outputs, - {t0}, - {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionBlockWelfordOp_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = mul(tv0, IrBuilder::create(1)); - auto tvs = Welford(tv1, {1}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - fusion.addOutput(tv_N); - - tv_avg->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->computeAt(tv_avg, -1); - - // - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t_var = at::empty({M}, options); - at::Tensor t_avg = at::empty({M}, options); - at::Tensor t_N = at::empty({M}, options_int); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - // by default Welford outputs sum of square diff so need to divide to get var - outputs[1] /= N; - - testValidate( - fe.kernel(), - outputs, - {t0}, - {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionGridWelfordOp_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = mul(tv0, IrBuilder::create(1)); - auto tvs = Welford(tv1, {1}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - fusion.addOutput(tv_N); - - tv_avg->axis(0)->parallelize(ParallelType::TIDx); - tv_avg->axis(-1)->parallelize(ParallelType::BIDx); - - tv1->computeAt(tv_avg, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t_avg = at::empty({M}, options); - at::Tensor t_var = at::empty({M}, options); - at::Tensor t_N = at::empty({M}, options_int); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - // by default Welford outputs sum of square diff so need to divide to get var - outputs[1] /= N; - - testValidate( - fe.kernel(), - outputs, - {t0}, - {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionRfactorWelfordOp_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = mul(tv0, IrBuilder::create(1)); - auto tvs = Welford(tv1, {1}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - fusion.addOutput(tv_N); - - tv_avg->split(1, 4); - ir_utils::rfactorHelper(tvs.avg, {2}); - tv1->computeAt(tv_avg, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - at::Tensor t_avg = at::empty({M}, options); - at::Tensor t_var = at::empty({M}, options); - at::Tensor t_N = at::empty({M}, options_int); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - // by default Welford outputs sum of square diff so need to divide to get var - outputs[1] /= N; - - testValidate( - fe.kernel(), - outputs, - {t0}, - {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = mul(tv0, IrBuilder::create(1)); - auto tvs = Welford(tv1, {1}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - fusion.addOutput(tv_N); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - // TODO: Why do we use launch params from here, but not scheduling??? - auto reduction_params = getReductionHeuristics(&fusion, {t0}); - scheduleReduction(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}, lparams); - auto outputs = fe.runFusion({t0}, lparams); - - // by default Welford outputs sum of square diff so need to divide to get var - outputs[1] /= N; - - auto at_avg = t0.mean({1}); - auto at_var = t0.var({1}, false); - auto at_n = at::ones({M}, options_int) * N; - - testValidate( - fe.kernel(), - outputs, - {t0}, - {at_avg, at_var, at_n}, - __LINE__, - __FILE__, - "validate welford", - reduction_params->lparams); -} - -namespace { -void testWelford(DataType dtype, int red_axis, int odim, int rdim) { - const int axis = red_axis; - at::ScalarType aten_dtype = data_type_to_aten(dtype); - - Fusion fusion; - FusionGuard fg(&fusion); - TensorView* tv0 = makeSymbolicTensor(2, dtype); - bool is_fp16 = dtype == DataType::Half; - bool is_bf16 = dtype == DataType::BFloat16; - TensorView* tv0_cast = tv0; - if (is_fp16 || is_bf16) { - tv0_cast = castOp(DataType::Float, tv0); - } - fusion.addInput(tv0); - auto tv1 = mul(tv0_cast, IrBuilder::create(1)); - auto tvs = Welford(tv1, {axis}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - - TensorView* avg_cast = tv_avg; - TensorView* M2_cast = tv_M2; - - if (is_fp16) { - avg_cast = castOp(DataType::Half, tv_avg); - M2_cast = castOp(DataType::Half, tv_M2); - } - if (is_bf16) { - avg_cast = castOp(DataType::BFloat16, tv_avg); - M2_cast = castOp(DataType::BFloat16, tv_M2); - } - - fusion.addOutput(avg_cast); - fusion.addOutput(M2_cast); - fusion.addOutput(tv_N); - - auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - std::vector outputs_of_red; - at::Tensor aten_input = - (axis ? at::randn({odim, rdim}, options) - : at::randn({rdim, odim}, options)); - - if (is_fp16 || is_bf16) { - outputs_of_red.push_back(avg_cast); - outputs_of_red.push_back(M2_cast); - } - - auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); - scheduleReduction(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}, lparams); - auto outputs = fe.runFusion({aten_input}, lparams); - - // by default Welford outputs sum of square diff so need to divide to - // get var - - outputs[1] /= rdim; - - auto at_avg = aten_input.mean({axis}); - auto at_var = aten_input.var({axis}, false); - auto at_n = - (axis ? at::ones({odim, rdim}, options) - : at::ones({rdim, odim}, options)); - at_n = at_n.sum({axis}); - - testValidate( - fe.kernel(), - outputs, - {aten_input}, - {at_avg, at_var, at_n}, - __LINE__, - __FILE__, - "validate welford", - reduction_params->lparams); -} -} // namespace - -TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) { - std::vector dtypes = { - DataType::Double, DataType::Float, DataType::Half}; - // TODO: enable this for complex. Currently, complex yields - // silent wrong results: - // Detected abs error of: 3.8062 - // absolute tolerance was set to 2.23704e-06 - // and relative tolerance set to 2.23704e-08 -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - if (at::cuda::getDeviceProperties(0)->major >= 8) { - dtypes.insert(dtypes.end(), DataType::BFloat16); - } -#endif - - std::vector red_axis = {1, 0}; - std::vector output_dims = {160, 320}; - std::vector red_dims; - - // Tried to cut down the number iterations with just - // doing every other power of 2. - for (int i = 1; i <= 1024 * 1024; i <<= 2) { - red_dims.push_back(i); - } - - for (auto dtype : dtypes) { - for (auto& axis : red_axis) { - for (auto& odim : output_dims) { - for (auto& rdim : red_dims) { - // TODO: original welford algorithm actually keeps a running sum of - // squares, i.e. M_{2n} in the - // cf: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - // algorithm notation, and it can reach inf for large numbers - // with half precision. skipping too large volumes for half for - // nwo might need further numerical experiments to re-design - // this. - if (rdim > 32768 && - (dtype == DataType::Half || dtype == DataType::BFloat16)) { - continue; - } - testWelford(dtype, axis, odim, rdim); - } - } - } - } -} - -namespace { -void testVarMean(at::ScalarType dtype, int correction, bool keepdim) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - int M = 64, N = 128; - - auto tv0 = makeSymbolicTensor(2, aten_to_data_type(dtype)); - fusion->addInput(tv0); - auto tvs = variance_mean(tv0, {1}, correction, keepdim); - auto tv_mean = tvs.mean; - auto tv_var = tvs.var; - fusion->addOutput(tv_var); - fusion->addOutput(tv_mean); - - auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({t0}); - - auto at_var_mean = at::var_mean(t0, {1}, correction, keepdim); - std::vector aten_outputs = { - std::get<0>(at_var_mean), std::get<1>(at_var_mean)}; - - testValidate( - executor_cache.fusion(), outputs, {t0}, aten_outputs, __LINE__, __FILE__); -} -} // namespace - -TEST_F(NVFuserTest, FusionVarMean_CUDA) { - std::vector dtypes = {at::kFloat, at::kDouble}; - std::vector corrections = {0, 1}; - std::vector keepdims = {false, true}; - for (auto correction : corrections) { - for (auto keepdim : keepdims) { - for (auto dtype : dtypes) { - testVarMean(dtype, correction, keepdim); - } - } - } -} - -TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - - TensorView* tv0 = makeSymbolicTensor(2); // K, M - TensorView* tv1 = makeSymbolicTensor(2); // N, K - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv0_t = transpose(tv0); - TensorView* tv1_t = transpose(tv1); - - TensorView* tv2 = broadcast(tv0_t, {false, false, true}); - // tv2[I0, I1, B] = tv0[I0, I1] - - TensorView* tv3 = broadcast(tv1_t, {true, false, false}); - // tv3[B, I1, I2] = tv1[I1, I2] - - // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2] - TensorView* tv4 = mul(tv2, tv3); - // tv5[I0, R1, I2] = tv4[I0, I1, I2] - TensorView* tv5 = sum(tv4, {1}); - fusion.addOutput(tv5); - - tv5->split(1, 32); - // tv5[I0, R1o, R1i{32}, I2] - - auto tv6 = tv5->rFactor({1}); - // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2] - // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2] - - tv5->split(0, 4); - tv5->split(-1, 4); - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] - - tv0_t->computeAt(tv5, -1); - tv1_t->computeAt(tv5, -1); - - // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] - // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] - //--> (line symbolizes compute at location) - // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] - // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv0_t->computeAt(tv6, -1); - tv1_t->computeAt(tv6, -1); - // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] - // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] - // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] - - tv5->axis(0)->parallelize(ParallelType::BIDz); - tv5->axis(1)->parallelize(ParallelType::TIDz); - - tv5->axis(-2)->parallelize(ParallelType::BIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - - tv5->axis(2)->parallelize(ParallelType::TIDx); - tv6->axis(2)->parallelize(ParallelType::TIDx); - - constexpr int M = 65, K = 33, N = 17; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({K, M}, options); - at::Tensor t1 = at::randn({N, K}, options); - - // Lets specify a few bounds in launch params to make sure it works - LaunchParams lparams(1, -1, -1, 32, 4, 4); - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, lparams); - fe.runFusion({t0, t1}, lparams); - - // Don't specify any launch params - auto cg_outputs = fe.runFusion({t0, t1}); - - auto aten_output = t0.t().to(at::kDouble).matmul(t1.t().to(at::kDouble)); - - testValidate( - &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int tidx = 32; - const int dimx = 32; - const int dimy = 16; - const int dimz = 130; - - // Set up your input tensor views - TensorView* input_tv0 = makeSymbolicTensor(3); - fusion.addInput(input_tv0); - - TensorView* input_t = transpose(input_tv0, 1, 2); - - TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_t); - TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); - TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); - - // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be - // computed at sum_exp_rf_tv8. - TensorView* input_t_copy = transpose(input_tv0, 1, 2); - TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_t_copy); - - TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); - - fusion.addOutput(output_tv4); - - bcast_sum_tv3->split(-1, tidx); - - sum_exp_tv2->split(-1, tidx); - TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); - - output_tv4->split(-1, tidx); - - input_t->computeAt(sum_exp_rf_tv5, -1); - input_t_copy->computeAt(output_tv4, -1); - - TensorView* tensors_to_parallelize[] = { - sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; - - for (auto tv : tensors_to_parallelize) { - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::BIDy); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({dimx, dimz, dimy}, options); - - at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_input_t = at::transpose(input, 1, 2); - auto aten_output = at::_softmax(aten_input_t.to(at::kDouble), -1, false); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { - // Case 1 - // tv1 = tv0 * 0.5 - // tv2 = tv1 * -1 - // tv3 = tv1 + 3 - // tv4 = tv1 * 2 - // tv5 = tv3 + tv2 - // tv6 = tv5 + tv4 - // tv7 = tv1 + tv4 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - tv0 = transpose(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); - TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); - TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); - TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); - TensorView* tv5 = add(tv3, tv2); - - TensorView* tv6 = add(tv5, tv4); - TensorView* tv7 = add(tv1, tv4); - - fusion.addOutput(tv6); - fusion.addOutput(tv7); - - // Lets setup to actually run - tv7->merge(0); - tv7->split(0, 128); - tv7->split(0, 4); - - tv7->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv7, 1); - - // The this-position of the last tensor should be zero. - TORCH_CHECK( - tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && - tv7->getMaxProducerPosition() == 1); - TORCH_CHECK( - tv6->nDims() == 3 && tv6->getComputeAtPosition() == 0 && - tv6->getMaxProducerPosition() == 1); - // The position of every other tensor should be 1. - for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { - TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); - } - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::randn({129, 127}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - at::Tensor aten_input_t = aten_input.t(); - - auto t1 = aten_input_t.mul({0.5}); - auto t2 = t1.mul({-1.0}); - auto t3 = t1.add({3.0}); - auto t4 = t1.mul({2.0}); - auto t5 = t3.add(t2); - auto t6 = t5.add(t4); - auto t7 = t1.add(t4); - - std::vector aten_outputs = {t6, t7}; - - testValidate( - &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { - // Case 2 - // tv1 = tv0 * -1 - // tv2 = tv0 + 3 - // tv3 = tv0 * 2 - // tv4 = tv2 + tv1 - // tv5 = tv4 + tv3 - // tv6 = tv5 + tv3 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - tv0 = transpose(tv0); - - TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); - TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); - TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); - TensorView* tv4 = add(tv2, tv1); - - TensorView* tv5 = add(tv4, tv3); - TensorView* tv6 = add(tv5, tv3); - - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - // Lets setup to actually run - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv6, 1); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({129, 127}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - auto input_t = input.t(); - auto t1 = input_t.mul({-1.0}); - auto t2 = input_t.add({3.0}); - auto t3 = input_t.mul({2.0}); - auto t4 = t2.add(t1); - auto t5 = t4.add(t3); - auto t6 = t5.add(t3); - - std::vector aten_outputs = {t5, t6}; - - testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { - // Case 3 - // T2 = T1 * 0.979361 - // T3 = T2 * T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - tv0 = permute(tv0, {3, 0, 1, 2}); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - tv1 = permute(tv1, {3, 0, 1, 2}); - - TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); - TensorView* tv3 = mul(tv2, tv0); - - fusion.addOutput(tv3); - - // Lets setup to actually run - while (tv3->nDims() > 1) - tv3->merge(0); - tv3->split(0, 128); - tv3->split(0, 4); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t0_t = t0.permute({3, 0, 1, 2}); - auto t1_t = t1.permute({3, 0, 1, 2}); - auto t2 = t1_t.mul({0.979361}); - auto aten_output = t2.mul(t0_t); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { - // Case 4 - // T4 = T2 - T3 - // T5 = T1 + T4 - // T6 = T5 - T0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(4); - fusion.addInput(tv0); - - tv0 = permute(tv0, {3, 0, 1, 2}); - - TensorView* tv1 = makeSymbolicTensor(4); - fusion.addInput(tv1); - - tv1 = permute(tv1, {3, 0, 1, 2}); - - TensorView* tv2 = makeSymbolicTensor(4); - fusion.addInput(tv2); - - tv2 = permute(tv2, {3, 0, 1, 2}); - - TensorView* tv3 = makeSymbolicTensor(4); - fusion.addInput(tv3); - - tv3 = permute(tv3, {3, 0, 1, 2}); - - TensorView* tv4 = sub(tv2, tv3); - TensorView* tv5 = add(tv1, tv4); - TensorView* tv6 = sub(tv5, tv0); - - fusion.addOutput(tv6); - - // Lets setup to actually run - while (tv6->nDims() > 1) - tv6->merge(0); - tv6->split(0, 128); - tv6->split(0, 4); - - tv0->computeAt(tv6, 1); - tv1->computeAt(tv6, 1); - tv2->computeAt(tv6, 1); - tv3->computeAt(tv6, 1); - - tv6->axis(0)->parallelize(ParallelType::BIDx); - - for (Val* val : fusion.vals()) { - if (!val->isFusionInput() && - val->getValType().value() == ValType::TensorView) { - TensorView* tv = static_cast(val); - - tv->axis(1)->parallelize(ParallelType::Unroll); - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({129, 127, 63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - at::Tensor t2 = at::rand_like(t0, options); - at::Tensor t3 = at::rand_like(t0, options); - - std::vector aten_inputs = {t0, t1, t2, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t0_t = t0.permute({3, 0, 1, 2}); - auto t1_t = t1.permute({3, 0, 1, 2}); - auto t2_t = t2.permute({3, 0, 1, 2}); - auto t3_t = t3.permute({3, 0, 1, 2}); - auto t4 = t2_t.sub(t3_t); - auto t5 = t1_t.add(t4); - auto aten_output = t5.sub(t0_t); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { - // Case 5 - // tv2 = tv0 + 2.0 - // tv3 = tv1 * tv2 - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - tv0 = transpose(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - tv1 = transpose(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv3->merge(0); - tv3->split(-1, 8); - tv3->split(-1, 4); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t2 = t0.t().add(2.0); - auto aten_output = t1.t().mul(t2); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - tv0 = transpose(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - tv1 = transpose(tv1); - TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); - TensorView* tv3 = mul(tv1, tv2); - fusion.addOutput(tv3); - - tv2->merge(0); - tv2->split(-1, 8); - tv2->split(-1, 4); - tv3->merge(0); - tv3->split(-1, 8); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({63, 65}, options); - at::Tensor t1 = at::rand_like(t0, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t2 = t0.t().add(2.0); - auto aten_output = t1.t().mul(t2); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSegmentReducePointwise_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(1); - TensorView* tv2 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - - TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 - TensorView* tv4 = - max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) - TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, - // keeps normalization scheduler away) - TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) - - fusion->addOutput(tv6); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, 65}, options); - at::Tensor t1 = at::randn({65}, options); - at::Tensor t2 = at::randn({128, 65}, options); - - auto t3 = t0.add(1.0); - auto t4 = std::get<0>(at::max(t3, 0)); - auto t5 = t4.add(t1); - auto t6 = t5.add(t2); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); - - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation didn't happen"); - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime() - ->fusionSegments() - ->groups() - .size() == 2, - "segmentation didn't happen as expected"); - - testValidate( - executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* tv0 = makeContigTensor(1); - TensorView* tv1 = makeContigTensor(1); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - TensorView* tv3 = add(tv0, tv1); - fusion->addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({40960}, options); - at::Tensor t1 = at::randn({40960}, options); - auto t2 = t0 + t1; - - FusionExecutorCache executor_cache(std::move(fusion)); - executor_cache.profile(true); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto runtime1 = executor_cache.getMostRecentKernelRuntime(); - auto log1 = std::dynamic_pointer_cast( - executor_cache.getMostRecentExecutorInfo().params); - TORCH_CHECK(log1 != nullptr); - TORCH_CHECK(log1->vectorize); - - testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); - - t0 = at::randn({40964}, options); - t1 = at::randn({40964}, options); - t2 = t0 + t1; - - outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto runtime2 = executor_cache.getMostRecentKernelRuntime(); - auto log2 = std::dynamic_pointer_cast( - executor_cache.getMostRecentExecutorInfo().params); - TORCH_CHECK(log2 != nullptr); - TORCH_CHECK(log2->vectorize); - - testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); - - t0 = at::randn({40962}, options); - t1 = at::randn({40962}, options); - t2 = t0 + t1; - - outputs = executor_cache.runFusionWithInputs({t0, t1}); - auto runtime3 = executor_cache.getMostRecentKernelRuntime(); - auto log3 = std::dynamic_pointer_cast( - executor_cache.getMostRecentExecutorInfo().params); - TORCH_CHECK(log3 != nullptr); - TORCH_CHECK(log3->vectorize); - - testValidate( - executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); - - TORCH_CHECK(runtime1 == runtime2); - TORCH_CHECK(runtime1 != runtime3); -} - -TEST_F(NVFuserTest, FusionVectorizeSimple_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(3); - - fusion.addInput(tv0); - - auto tv1 = unaryOp(UnaryOpType::Sin, tv0); - - fusion.addOutput(tv1); - - auto tv0_cache = tv0->cacheAfter(); - - auto tv1_cache = tv1->cacheBefore(); - - tv1->merge(0); - tv1->merge(0); - tv1->split(0, 4); - tv1->split(0, 128); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv1, 2); - - tv0_cache->axis(2)->parallelize(ParallelType::Vectorize); - tv1->axis(2)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor aten_input = at::empty({2, 6, 32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {aten_input}); - auto cg_outputs = fe.runFusion({aten_input}); - - at::Tensor aten_output = aten_input.sin(); - - testValidate( - &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - // dimensionality of the problem - int nDims = 3; - - // Set up your input tensor views - TensorView* tv0 = makeContigTensor(nDims); - TensorView* tv1 = makeContigTensor(nDims); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); - TensorView* tv3 = add(tv0, tv2); - - // Register your outputs - fusion.addOutput(tv3); - - auto tv0_cache = tv0->cacheAfter(); - auto tv1_cache = tv1->cacheAfter(); - auto tv3_cache = tv3->cacheBefore(); - - // Do transformations, remember, transformations are outputs to inputs - // This doesn't have to be in this order - tv3->merge(1); - - // Split by n_threads - tv3->split(1, 2); - tv3->split(0, 3); - tv3->split(0, 1); - - // [bidx, unswitch, unroll{2}, tidx, vectorize{2}] - - // Parallelize TV3 - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::Unswitch); - tv3->axis(2)->parallelize(ParallelType::Unroll); - tv3->axis(3)->parallelize(ParallelType::TIDx); - - tv3->reorder({{4, 2}}); - // [bidx, unswitch, vectorize{2}, unroll{2}, tidx] - - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3); - - tv0_cache->axis(2)->parallelize(ParallelType::Vectorize); - tv1_cache->axis(2)->parallelize(ParallelType::Vectorize); - tv3->axis(2)->parallelize(ParallelType::Vectorize); - - // For all inputs, computeAt the output inline, temporaries should be squeezed - // between them - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - tv1->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({64, 2, 128}, options); - at::Tensor input2 = at::rand_like(input1); - at::Tensor output = at::empty_like(input1); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - fe.runFusion({input1, input2}, {output}); - - at::Tensor tv2_ref = input2 + 2.0; - at::Tensor output_ref = input1 + tv2_ref; - - TORCH_CHECK(output_ref.equal(output)); -} - -TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - std::vector input_shape{32, 64, 8}; - const int kReductionAxis = 1; - - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Double) - .build(); - - fusion->addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = sum(tv1, {2}); // Group 0 - - auto output = softmax(tv2, kReductionAxis); // Group 1 - fusion->addOutput(output); - - auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); - at::Tensor at_x = at::randn(input_shape, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto outputs = executor_cache.runFusionWithInputs({at_x}); - - auto t1 = at_x.add(1.0); - auto t2 = t1.sum({2}); - auto t3 = at::_softmax(t2.to(at::kDouble), -1, false); - - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen"); - TORCH_CHECK( - optimized_fusion->fusionSegments()->groups().size() == 2, - "segmentation didn't happen as expected"); - - testValidate( - executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSwizzle1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = mul(tv1, IrBuilder::create(2)); - fusion.addOutput(tv2); - - tv2->split(0, 7); - tv2->split(0, 9); - - tv0->computeAt(tv2, 1); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - - tv1->setMemoryType(MemoryType::Shared); - tv1->swizzle(SwizzleType::Transpose, {1, 2}); - - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv1->axis(2)->parallelize(ParallelType::TIDy); - - tv2->axis(1)->parallelize(ParallelType::TIDx); - tv2->axis(2)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({100}, options); - - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = (t0 + 1) * 2; - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSwizzle2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = mul(tv1, IrBuilder::create(2)); - fusion.addOutput(tv2); - - tv1->split(-1, 4); - tv1->split(-2, 4); - - tv2->split(-1, 4); - tv2->split(-2, 4); - - tv0->computeAt(tv2, 1); - - tv2->reorder({{-1, -2}}); - - tv1->setMemoryType(MemoryType::Shared); - tv1->swizzle(SwizzleType::Transpose, {-2, -1}); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDy); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-2)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({123}, options); - - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = (t0 + 1) * 2; - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = broadcast(tv1, {true}); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - std::vector tvs = {tv1, tv2, tv3}; - for (auto tv : tvs) { - tv->split(0, 2); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::BIDy); - } - - const int numel_x = 10; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = input.sum({0}).unsqueeze(-1).add(input); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridPersistence2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = broadcast(tv1, {true, false}); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - std::vector tvs = {tv1, tv2, tv3}; - for (auto tv : tvs) { - tv->split(0, 2); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::TIDy); - tv->axis(2)->parallelize(ParallelType::TIDx); - } - - const int numel_x = 10; - const int numel_y = 3; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = input.sum({0}).unsqueeze(0).add(input); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWelfordPersistence_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tvs = Welford(tv0, {0}); - auto tv4 = add(tvs.avg, tvs.var_sum); - auto tv5 = broadcast(tv4, {true}); - auto tv6 = add(tv0, tv5); - fusion.addOutput(tv6); - - std::vector schedule_tvs = { - tvs.avg, tvs.var_sum, tvs.n, tv5, tv6}; - - for (auto tv : schedule_tvs) { - tv->split(0, 2); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::BIDy); - } - - const int numel_x = 10; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) - .unsqueeze(-1) - .add(input); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWelfordPersistence2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tvs = Welford(tv0, {0}); - auto tv4 = add(tvs.avg, tvs.var_sum); - auto tv5 = broadcast(tv4, {true, false}); - auto tv6 = add(tv0, tv5); - fusion.addOutput(tv6); - - std::vector schedule_tvs = { - tvs.avg, tvs.var_sum, tvs.n, tv5, tv6}; - for (auto tv : schedule_tvs) { - tv->split(0, 2); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::TIDy); - tv->axis(2)->parallelize(ParallelType::TIDx); - } - tv4->axis(0)->parallelize(ParallelType::TIDx); - - const int numel_x = 10; - const int numel_y = 3; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto out = fe.runFusion({input}); - - auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) - .unsqueeze(0) - .add(input); - - testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue633_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int dx = 10; - const int dy = 11; - const int dz = 12; - - auto tv0 = makeConcreteTensor({dx, dy, dz}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({dx, dy, 1}); - fusion.addInput(tv1); - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->merge(1); - tv2->merge(0); - tv2->split(-1, 128); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({dx, dy, dz}, options); - at::Tensor t1 = at::randn({dx, dy, 1}, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape{17, 19}; - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - tv3->split(1, 128); - tv0->computeAt(tv3, 2); - - for (auto tv : {tv2, tv3}) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({shape[0]}, options); - at::Tensor t1 = at::randn(shape, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto t3 = t0.unsqueeze(-1).expand(shape) + t1; - - testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - const int kTDX = 64; - const int kVecSize = 4; - const int kNumElems = kTDX * kVecSize; - - tv2->split(1, kNumElems); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - tv2->split(-1, kVecSize); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 457; - at::Tensor t0 = at::randn({bx, by}, options); - at::Tensor t1 = at::randn({bx, by}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(4); - auto tv1 = makeContigTensor(4); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->reorder({{0, 1}, {1, 0}}); - tv2->merge(-2); - - const int kTDX = 64; - const int kVecSize = 2; - const int kNumElems = kTDX * kVecSize; - - tv2->split(-1, kNumElems); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - tv2->split(0, 128); - tv2->split(-1, kVecSize); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int n = 32; - const int c = 127; - const int h = 51; - const int w = 23; - at::Tensor t0 = at::randn({n, c, h, w}, options); - at::Tensor t1 = at::randn({n, c, h, w}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int kNumDims = 4; - constexpr int kTDX = 64; - constexpr int kVecSize = 2; - constexpr int kNumElems = kTDX * kVecSize; - - auto tv0 = makeSymbolicTensor(kNumDims); - auto tv1 = makeSymbolicTensor(kNumDims); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - // Create caches for vectorization - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - // Merge all dimensions together except inner-most dim - for (const auto idx : c10::irange(kNumDims - 2)) { - tv2->merge(0); - } - // Split inner-most dim - tv2->split(-1, kNumElems); - tv2->split(-1, kVecSize); - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - // Parallelization Strategy - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int n = 5; - const int c = 3; - const int h = 51; - const int w = 257; - at::Tensor t0 = at::randn({n, c, h, w}, options); - at::Tensor t1 = at::randn({n, c, h, w}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int kNumDims = 4; - constexpr int kTDX = 64; - constexpr int kVecSize = 2; - constexpr int kNumElems = kTDX * kVecSize; - std::vector bcast_shape{1, 1, 1, -1}; - - auto tv0 = makeContigTensor(kNumDims); - auto tv1 = TensorViewBuilder().shape(bcast_shape).build(); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - // Create caches for vectorization - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - // Merge all dimensions together - // Backward merge order is necessary for vectorize validation - for (int idx = kNumDims - 1; idx > 0; --idx) { - tv2->merge(idx - 1); - } - tv2->split(-1, kNumElems); - tv2->split(-1, kVecSize); - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - // Parallelization Strategy - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int n = 32; - const int c = 128; - const int h = 51; - const int w = 23; - at::Tensor t0 = at::randn({n, c, h, w}, options); - at::Tensor t1 = at::randn({1, 1, 1, w}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - // TODO: throw assertion - cannot merge non-contiguous vectorization axes - // Make sure compilation fails - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - auto tv1 = makeContigTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - - auto tv3 = sum(tv2, {-1}); - - fusion.addOutput(tv3); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - - tv3->split(-1, 128 * 4); - tv3->split(-1, 4); - // Reduce outer dim first - auto tv4 = tv3->rFactor({-3, -1}); - // Tv3 will reduce threads - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv4, -2); - tv1->computeAt(tv4, -2); - - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv4->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - tv2->computeAt(tv4, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2050; - at::Tensor t0 = at::randn({bx, by}, options); - at::Tensor t1 = at::randn({bx, by}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0.add(t1).sum(1); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - auto tv1 = makeContigTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->split(1, 16); - tv2->split(1, 64); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - std::vector vectorized_tvs = {c0, c1, tv2}; - for (auto tv : vectorized_tvs) { - tv->split(-1, 4); - // Vectorize the wrong dimension - tv->axis(-2)->parallelize(ParallelType::MisalignedVectorize); - } - - FusionExecutor fe; - // Make sure compilation fails - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - const int kTDX = 64; - const int kVecSize = 4; - const int kNumElems = kTDX * kVecSize; - - tv2->split(1, kNumElems); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - - tv2->split(-1, kVecSize); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2049; - at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); - at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - const int kTDX = 64; - const int kVecSize = 4; - const int kNumElems = kTDX * kVecSize; - - tv2->split(1, kNumElems); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - tv2->split(-1, kVecSize); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2049; - at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); - at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - - // Failure because the input + output tensors do not have the same stride - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); -} - -TEST_F(NVFuserTest, FusionVectorization1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->split(1, 16); - tv2->split(1, 64); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - std::vector vectorized_tvs = {c0, c1, tv2}; - for (auto tv : vectorized_tvs) { - tv->split(-1, 4); - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2048; - at::Tensor t0 = at::randn({bx, by}, options); - at::Tensor t1 = at::randn({bx, by}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorization2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->split(1, 16); - tv2->split(1, 64); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - std::vector vectorized_tvs = {c0, c1, tv2}; - for (auto tv : vectorized_tvs) { - tv->split(-1, 4); - // Vectorize the wrong dimension - tv->axis(-2)->parallelize(ParallelType::Vectorize); - } - - FusionExecutor fe; - // Make sure compilation fails - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionVectorization3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - tv2->split(1, 16); - tv2->split(1, 64); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::TIDx); - - auto c0 = tv0->cacheAfter(); - auto c1 = tv1->cacheAfter(); - auto c2 = tv2->cacheBefore(); - - c0->computeAt(tv2, -2); - c1->computeAt(tv2, -2); - - std::vector vectorized_tvs = {c0, c1, tv2}; - for (auto tv : vectorized_tvs) { - tv->split(-1, 4); - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2049; - at::Tensor t0 = at::randn({bx, by}, options); - at::Tensor t1 = at::randn({bx, by}, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); - - aten_inputs[0] = t0.index({"...", Slice(1)}); - aten_inputs[1] = t1.index({"...", Slice(1)}); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); - - t0 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); - t1 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); - aten_inputs = {t0, t1}; - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0 + t1; - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizationRFactor_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, tv1); - - auto tv3 = sum(tv2, {-1}); - - fusion.addOutput(tv3); - - tv3->split(-1, 128 * 4); - tv3->split(-1, 4); - // Reduce outer dim first - auto tv4 = tv3->rFactor({-3, -1}); - // Tv3 will reduce threads - - auto tv6 = tv0->cacheAfter(); - auto tv7 = tv1->cacheAfter(); - - tv0->computeAt(tv3, 1); - tv1->computeAt(tv3, 1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - - tv0->computeAt(tv4, -2); - tv1->computeAt(tv4, -2); - - tv6->axis(-1)->parallelize(ParallelType::Vectorize); - tv7->axis(-1)->parallelize(ParallelType::Vectorize); - - tv4->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const int bx = 128; - const int by = 2048; - at::Tensor t0 = at::randn({bx, by}, options); - at::Tensor t1 = at::randn({bx, by}, options); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto aten_output = t0.add(t1).sum(1); - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - - auto t3 = t0.add(t1).sum(1); - - testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); -} - -// Unswitched loops with extent one may omit else clause. -TEST_F(NVFuserTest, FusionSizeOneLoop1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Progressively broadcast tensors - TensorView* tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - TensorView* tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - TensorView* tv2 = makeSymbolicTensor(3); - fusion.addInput(tv2); - - TensorView* tv3 = broadcast(tv0, {false, true}); - TensorView* tv4 = add(tv3, tv1); - TensorView* tv5 = add(tv4, tv2); - - fusion.addOutput(tv5); - - // Split inner dimension - tv5->split(1, 8); - // Merge middle dims with outer dimensions - tv5->merge(2); - tv5->merge(0); - - // tv5[I0*I1o, I1i*I2] - // Get a dim of size 1 to unswitch - tv5->split(0, 1, false); - - // Compute everything inline - tv0->computeAt(tv5, -1); - - tv5->axis(0)->parallelize(ParallelType::Unswitch); - tv5->axis(1)->parallelize(ParallelType::BIDx); - tv5->axis(2)->parallelize(ParallelType::TIDx); - - // Make sure the unswitched loop does not have an else clause. - GpuLower gpulw(&fusion); - TORCH_CHECK(!UnswitchInElseChecker::check(gpulw)); - - const int x = 11; - const int y = 12; - const int z = 13; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x}, options); - at::Tensor t1 = at::randn({x, y}, options); - at::Tensor t2 = at::randn({z, x, y}, options); - std::vector aten_inputs = {t0, t1, t2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2; - - testValidate(&fusion, cg_outputs, aten_inputs, {t6}, __LINE__, __FILE__); -} - -// The unswitched loop has extent one but inner loops don't. The else -// part should not be omitted. -TEST_F(NVFuserTest, FusionSizeOneLoop2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int x = 15; - auto tv0 = makeConcreteTensor({x}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - fusion.addOutput(tv1); - - tv1->split(-1, 4); - tv1->split(-2, 1); - - tv1->axis(-2)->parallelize(ParallelType::Unswitch); - - // Make sure the size-one unswitched loop does not omit the else clause. - GpuLower gpulw(&fusion); - TORCH_CHECK(UnswitchInElseChecker::check(gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - auto t1 = t0 + 1; - - testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionValidateParallelize1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - - // Invalid as tv1 and tv2 do have the same ParallelType - FusionExecutor fe; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionValidateParallelize2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - tv1->setMemoryType(MemoryType::Shared); - - // tv1 and tv2 do have the same ParallelType, but tv1 is on shared - // memory, so it is valid - FusionExecutor fe; - fe.compileFusion(&fusion); -} - -TEST_F(NVFuserTest, FusionValidateParallelize3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->split(-1, 4); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->split(-1, 4); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->setMemoryType(MemoryType::Global); - - // tv1 and tv2 have the same shape and ParallelType - FusionExecutor fe; - fe.compileFusion(&fusion); -} - -TEST_F(NVFuserTest, FusionValidateParallelize4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->split(-1, 4); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->split(-1, 8); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->setMemoryType(MemoryType::Global); - - // tv1 and tv2 do not have the same shape but global memory comm is supported. - FusionExecutor fe; - fe.compileFusion(&fusion); -} - -TEST_F(NVFuserTest, FusionValidateParallelize5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv1->split(-1, 4); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->setMemoryType(MemoryType::Shared); - - tv2->split(-1, 8); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - // tv1 and tv2 do not have the same shape, but tv1 is on shared - // memory, so it is valid - FusionExecutor fe; - fe.compileFusion(&fusion); -} - -// See issue #995 -TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int64_t W = 5, X = 6, Y = 7, Z = 8; - - auto tv0 = makeConcreteTensor({X, Y, Z}); - auto tv1 = makeConcreteTensor({W, X, Y, Z}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - auto tv3 = broadcast(tv2, {true, false, false, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->merge(0); - tv4->merge(0); - tv4->split(0, 4); - tv4->split(0, 3); - tv4->split(0, 2); - - TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - - tv0->computeAt(tv2, 2); - tv3->computeAt(tv4, 2); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - // Validation should throw an exception saying the first axes of tv2 - // and tv3 have incompatible parallelization. See also issue #995. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fusion.printKernel()); -} - -TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(5); - auto tv1 = makeSymbolicTensor(1); - fusion.addInput(tv0); - fusion.addInput(tv1); - - // Branch 0 - auto tv2 = sum(tv0, {0}); // 0 - auto tv3 = sum(tv2, {0}); // 1 - auto tv4 = sum(tv3, {0}); // 2 - auto tv5 = sum(tv4, {0}); // 3 - - // Branch 1 - auto tv6 = add(tv1, IrBuilder::create(1)); // 4 - - // Merge - auto tv7 = add(tv6, tv5); // 5 - - // Maximum expected output groups (can improve overtime): - // {0}, {1}, {2}, {3,4,5} - // without final merge would have been {0}, {1}, {2}, {3,4}, {5} - - fusion.addOutput(tv7); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 2, 2, 2, 2}, options); - at::Tensor t1 = at::randn({2}, options); - - std::vector aten_inputs = {t0, t1}; - - KernelArgumentHolder args(KernelIndexMode::INT32); - args.setDeviceIndex(0); - args.push(aten_inputs); - - auto fusion_segments = fusion.segment(args); - TORCH_CHECK(fusion_segments->groups().size() <= 4); -} - -TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - auto i0 = IrBuilder::create(); - - fusion->addInput(tv0); - fusion->addInput(i0); - - auto i1 = add(i0, IrBuilder::create(1.0)); - auto i2 = mul(i1, i1); - auto i3 = add(i2, i1); - - // Branch 0 - auto tv1 = sum(tv0, {0}); // 0 - auto tv2 = add(tv1, i2); - // Branch 1 - auto tv3 = sum(tv2, {0}); // 1 - auto tv4 = add(tv3, i3); - - auto tv5 = add(tv4, i0); - - fusion->addOutput(tv5); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({16, 16, 16}, options); - double s0 = 0.5; - - auto s1 = s0 + 1.0; - auto s2 = s1 * s1; - auto s3 = s2 + s1; - auto t1 = t0.sum({0}); - auto t2 = t1 + s2; - auto t3 = sum(t2, {0}); - auto t4 = t3 + s3; - auto t5 = t4 + s0; - - auto outputs = executor_cache.runFusionWithInputs({t0, s0}); - - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation didn't happen"); - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime() - ->fusionSegments() - ->groups() - .size() == 2, - "segmentation didn't happen as expected"); - - testValidate( - executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int M = 10; - constexpr int N = 20; - constexpr int K = 20; - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = sum(tv0, {{1, 2}}); - fusion.addInput(tv0); - fusion.addOutput(tv1); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N, K}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - at::Tensor aten_output = t0.sum({1, 2}); - testValidate( - &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int M = 10; - constexpr int N = 20; - constexpr int K = 20; - - auto tv0 = makeSymbolicTensor(3); - auto tvs = Welford(tv0, {{1, 2}}); - fusion.addInput(tv0); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - - tv_avg->axis(-1)->parallelize(ParallelType::TIDx); - tv_avg->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M, N, K}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - at::Tensor aten_avg = t0.mean({1, 2}); - at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K; - testValidate( - &fusion, outputs, aten_inputs, {aten_avg, aten_M2}, __LINE__, __FILE__); -} - -// See Issue #716 -TEST_F(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - constexpr int M = 10; - constexpr int N = 11; - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - std::vector reduction_axes = {1}; - std::vector broadcast_mask = {false, true}; - - auto tv0_bcast = broadcast(tv0, broadcast_mask); - auto path1_bcast = add(tv0_bcast, IrBuilder::create(1.0)); - auto path1 = sum(path1_bcast, reduction_axes); - fusion.addOutput(path1); - - auto p = path1->split(1, 1); - path1->rFactor({1}); - path1->axis(0)->parallelize(ParallelType::BIDx); - tv0->computeAt(path1, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({M}, options); - at::Tensor t0_ref = t0.clone(); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - - // inplace op, we are adding t0 to itself - auto outputs = fe.runFusion(aten_inputs, {t0}); - - TORCH_CHECK(outputs[0].allclose(t0_ref.add(1))); -} - -TEST_F(NVFuserTest, FusionReductionPredicate_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - auto tv2 = tv0->cacheAfter(); - - const int bdimx = 128; - tv1->split(1, bdimx); - tv1->split(1, 4); - tv1->split(1, 1); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(2)->parallelize(ParallelType::Unroll); - tv1->split(0, 10); - tv0->computeAt(tv1, 4); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 650; - int numel_y = 102; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({numel_x, numel_y}, options); - at::Tensor cg_output = at::empty({numel_y}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - fe.runFusion({input}, {cg_output}); - - auto aten_output = input.to(at::kDouble).sum({0}); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue728_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addOutput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion.addOutput(tv1); - auto tv2 = makeSymbolicTensor(1); - fusion.addOutput(tv2); - - auto tv3 = add(tv0, IrBuilder::create(1)); - auto tv4 = add(tv3, tv1); - auto tv5 = add(tv4, IrBuilder::create(1)); - auto tv6 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - // tv0 -> tv3 -+ - // tv1 --------+-> tv4 -> tv5 - // - // tv2 -> tv6 - - auto all_vals_under_tv3 = - DependencyCheck::getAllValsBetween({tv3}, fusion.outputs()); - std::unordered_set included_tensors({tv3, tv4, tv5}); - for (auto tv : included_tensors) { - TORCH_CHECK( - std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) != - all_vals_under_tv3.end(), - "TV", - tv->name(), - " not found"); - } - for (auto tv : ir_utils::filterByType(fusion.vals())) { - if (included_tensors.find(tv) == included_tensors.end()) { - TORCH_CHECK( - std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) == - all_vals_under_tv3.end(), - "TV", - tv->name(), - " should not be found"); - } - } - - auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs()); - TORCH_CHECK(no_dependency.empty(), "No val should be returned"); - - auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6}); - TORCH_CHECK(no_dep_path.empty(), "No val should be returned"); - - auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5}); - TORCH_CHECK(no_dep_path2.empty(), "No val should be returned"); - - auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3}); - TORCH_CHECK( - just_tv3.size() == 1 && *(just_tv3.begin()) == tv3, - "Only tv3 should be included"); -} - -TEST_F(NVFuserTest, FusionIssue757_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = makeSymbolicTensor(2); - fusion.addInput(tv3); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv1->computeAt(tv4, -1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 650; - int numel_y = 102; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - at::Tensor t3 = at::randn({numel_x, numel_y}, options); - std::vector inputs = {t0, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto outputs = fe.runFusion(inputs); - - auto t1 = t0.sum({1}); - auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y}); - auto t4 = t2 + t3; - - testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); -} - -// See issue #759 -TEST_F(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = makeSymbolicTensor(2); - fusion.addInput(tv3); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv4->split(0, 4); - tv1->computeAt(tv4, -1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::TIDy); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(1)->parallelize(ParallelType::TIDy); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - int numel_x = 100; - int numel_y = 101; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - at::Tensor t3 = at::randn({numel_x, numel_y}, options); - std::vector inputs = {t0, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto outputs = fe.runFusion(inputs); - - auto t1 = t0.sum({1}); - auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y}); - auto t4 = t2 + t3; - - testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - // {first kernel} - auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv1, tv0); - auto tv3 = sum(tv2, {0}); - auto tv4 = add(tv3, tv0); - auto tv5 = sum(tv4, {0}); - auto tv6 = sum(tv5, {0}); - // {second kernel} - auto tv7 = add(tv6, tv5); - auto tv8 = add(tv7, tv5); - auto tv9 = sum(tv8, {0}); - - fusion->addOutput(tv9); - - SegmentCandidateFinderOptions segment_options; - segment_options.run_herrmann_merge = false; - segment_options.run_final_merge = false; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 2, 2}, options); - - KernelArgumentHolder args(KernelIndexMode::INT32); - args.setDeviceIndex(0); - args.push(t0); - - auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); - - TORCH_CHECK(segmented_fusion->groups().size() == 2); -} - -TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - auto i0 = IrBuilder::create(); - - fusion->addInput(tv0); - fusion->addInput(i0); - - // Branch 0 {first kernel} - auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv0, i0); - auto tv3 = unaryOp(UnaryOpType::Rsqrt, tv2); - auto tv4 = sum(tv3, {0}); - - // Branch 1 {first kernel} - auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv3); - auto tv6 = sum(tv5, {0}); - - // Incompatible {second kernel} - auto tv7 = sum(tv6, {0}); - - fusion->addOutput(tv1); - fusion->addOutput(tv4); - fusion->addOutput(tv7); - - SegmentCandidateFinderOptions segment_options; - segment_options.run_herrmann_merge = false; - segment_options.run_final_merge = false; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 2, 2}, options); - - KernelArgumentHolder args(KernelIndexMode::INT32); - args.setDeviceIndex(0); - args.push(t0); - c10::IValue scalar = 1.0; - args.push(scalar); - - auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); - - TORCH_CHECK(segmented_fusion->groups().size() == 2); -} - -TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - - // def of tv1 in kernel 1 through horizontal - auto tv1 = sum(tv0, {0, 1}); - // kernel 2 - auto tv2 = sum(tv0, {2}); - auto tv3 = broadcast(tv2, {false, false, true}); - auto tv4 = add(tv0, tv3); - auto tv5 = sum(tv4, {2}); - // end of kernel 2 - // kernel 1 - auto tv6 = unaryOp(UnaryOpType::Rsqrt, tv0); - auto tv7 = sum(tv6, {0, 1}); - auto tv8 = sum(tv6, {0, 1}); - - fusion->addOutput(tv1); - fusion->addOutput(tv5); - fusion->addOutput(tv7); - fusion->addOutput(tv8); - - SegmentCandidateFinderOptions segment_options; - segment_options.run_herrmann_merge = false; - segment_options.run_final_merge = false; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 2, 2}, options); - - KernelArgumentHolder args(KernelIndexMode::INT32); - args.setDeviceIndex(0); - args.push(t0); - - auto segmented_fusion = - SegmentCandidateFinder::segment(fusion.get(), args, segment_options); - - TORCH_CHECK(segmented_fusion->groups().size() <= 2); -} - -TEST_F(NVFuserTest, FusionSBAR_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // N, H, W, C format - std::vector input_shape{656, 7, 7, 64}; - - auto x = makeContigTensor(4); - auto y = makeContigTensor(4); - auto weight = makeContigTensor(1); - auto bias = makeContigTensor(1); - - fusion.addInput(x); - fusion.addInput(y); - fusion.addInput(weight); - fusion.addInput(bias); - - const size_t kNumberOfDims = x->nDims(); - std::vector broadcast_mask(kNumberOfDims, false); - for (const auto axis : c10::irange(kNumberOfDims - 1)) { - broadcast_mask[axis] = true; - } - - auto weight_bcast = broadcast(weight, broadcast_mask); - auto scale = mul(x, weight_bcast); - auto bias_bcast = broadcast(bias, broadcast_mask); - auto scale_bias = add(scale, bias_bcast); - auto scale_bias_add = add(scale_bias, y); - auto scale_bias_add_relu = unaryOp(UnaryOpType::Relu, scale_bias_add); - - fusion.addOutput(scale_bias_add_relu); - - // inputs - at::manual_seed(0); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_x = at::randn(input_shape, options); - at::Tensor at_y = at::randn(input_shape, options); - at::Tensor at_weight = at::ones({input_shape[3]}, options); - at::Tensor at_bias = at::zeros({input_shape[3]}, options); - - // inputs - std::vector inputs = {at_x, at_y, at_weight, at_bias}; - - // outputs - std::vector outputs; - - auto lparams = schedulePointwise(&fusion, inputs); - - FusionExecutor executor; - executor.compileFusion(&fusion, inputs, lparams); - outputs = executor.runFusion(inputs, lparams); - - auto at_scale = at::mul(at_x, at_weight); - auto at_scale_bias = at::add(at_scale, at_bias); - auto pwise_add = at::add(at_scale_bias, at_y); - auto output = at::relu(pwise_add); - - testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSingleElement_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(0); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(2.5)); - - auto tv2 = add(tv1, IrBuilder::create(3.5)); - fusion.addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({}, options); - - at::Tensor cg_output = at::empty({}, options); - - auto lparams = schedulePointwise(&fusion, {input}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}, lparams); - fe.runFusion({input}, {cg_output}, lparams); - - auto aten_output = input.add(2.5).add(3.5); - - testValidate( - &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBNBackwardRepro_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int batch = 4; - int c = 4; - int h = 4; - int w = 4; - int numDims = 4; - - auto input = makeSymbolicTensor(numDims); - fusion.addInput(input); - auto weight = makeSymbolicTensor(1); - fusion.addInput(weight); - auto running_mean = makeSymbolicTensor(1); - fusion.addInput(running_mean); - auto running_var = makeSymbolicTensor(1); - fusion.addInput(running_var); - auto save_mean = makeSymbolicTensor(1); - fusion.addInput(save_mean); - auto save_invstd = makeSymbolicTensor(1); - fusion.addInput(save_invstd); - - auto grad_out_prev = makeSymbolicTensor(numDims); - fusion.addInput(grad_out_prev); - auto gt_0 = - makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. - fusion.addInput(gt_0); - - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); - auto gt_float = castOp(DataType::Float, gt_bool); - - auto grad_out = mul(grad_out_prev, gt_float); - - Val* eps_ptr = IrBuilder::create(1e-5); - - auto grads = batch_norm_backward( - input, - grad_out, - weight, - running_mean, - running_var, - save_mean, - save_invstd, - true, - eps_ptr, - {true, true, true}); - - fusion.addOutput(grads.grad_input); - fusion.addOutput(grads.grad_weight); - fusion.addOutput(grads.grad_bias); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({batch, c, h, w}, options); - at::Tensor input1 = at::randn({c}, options); - at::Tensor input2 = at::randn_like(input1); - at::Tensor input3 = at::randn_like(input1); - at::Tensor input4 = at::randn_like(input1); - at::Tensor input5 = at::randn_like(input1); - at::Tensor input6 = at::randn_like(input0); - at::Tensor input7 = at::randn_like(input0); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector inputs = { - input0, input1, input2, input3, input4, input5, input6, input7}; - auto outputs = fec.runFusionWithInputs(inputs); -} - -// TODO: We only changed inputs, merge this with the test above. -TEST_F(NVFuserTest, FusionBNBackwardRepro2_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int batch = 2; - int c = 81; - int h = 1; - int w = 1; - int numDims = 4; - - // auto input = makeSymbolicTensor(numDims); - auto input = makeConcreteTensor({-1, -1, 1, 1}); - fusion.addInput(input); - auto weight = makeSymbolicTensor(1); - fusion.addInput(weight); - auto running_mean = makeSymbolicTensor(1); - fusion.addInput(running_mean); - auto running_var = makeSymbolicTensor(1); - fusion.addInput(running_var); - auto save_mean = makeSymbolicTensor(1); - fusion.addInput(save_mean); - auto save_invstd = makeSymbolicTensor(1); - fusion.addInput(save_invstd); - - // auto grad_out_prev = makeSymbolicTensor(numDims); - auto grad_out_prev = makeConcreteTensor({-1, -1, 1, 1}); - fusion.addInput(grad_out_prev); - // auto gt_0 = - // makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. - auto gt_0 = makeConcreteTensor({-1, -1, 1, 1}); - fusion.addInput(gt_0); - - auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); - auto gt_float = castOp(DataType::Float, gt_bool); - - auto grad_out = mul(grad_out_prev, gt_float); - - Val* eps_ptr = IrBuilder::create(1e-5); - - auto grads = batch_norm_backward( - input, - grad_out, - weight, - running_mean, - running_var, - save_mean, - save_invstd, - true, - eps_ptr, - {true, true, true}); - - fusion.addOutput(grads.grad_input); - fusion.addOutput(grads.grad_weight); - fusion.addOutput(grads.grad_bias); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({batch, c, h, w}, options); - at::Tensor input1 = at::randn({c}, options); - at::Tensor input2 = at::randn_like(input1); - at::Tensor input3 = at::randn_like(input1); - at::Tensor input4 = at::randn_like(input1); - at::Tensor input5 = at::randn_like(input1); - at::Tensor input6 = at::randn_like(input0); - at::Tensor input7 = at::randn_like(input0); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector inputs = { - input0, input1, input2, input3, input4, input5, input6, input7}; - auto outputs = fec.runFusionWithInputs(inputs); -} - -TEST_F(NVFuserTest, FusionBNRepro_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - const bool kTraining = true; - const float kMomentum = 0.1; - const float kEps = 1e-5; - - int batch = 14; - int c = 65; - int h = 7; - int w = 7; - int numDims = 4; - - auto input = makeSymbolicTensor(numDims); - fusion.addInput(input); - auto weight = makeSymbolicTensor(1); - fusion.addInput(weight); - auto bias = makeSymbolicTensor(1); - fusion.addInput(bias); - auto running_mean = makeSymbolicTensor(1); - fusion.addInput(running_mean); - auto running_var = makeSymbolicTensor(1); - fusion.addInput(running_var); - - auto momentum_ptr = IrBuilder::create(kMomentum); - auto eps_ptr = IrBuilder::create(kEps); - - auto result = batch_norm( - input, - weight, - bias, - running_mean, - running_var, - kTraining, - momentum_ptr, - eps_ptr); - - fusion.addOutput(result.output); - fusion.addOutput(result.mean); - fusion.addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({batch, c, h, w}, options); - at::Tensor input2 = at::randn({c}, options); - at::Tensor input3 = at::randn_like(input2); - at::Tensor input4 = at::randn_like(input2); - at::Tensor input5 = at::randn_like(input2); - - auto input1_ref = input1.clone(); - auto input2_ref = input2.clone(); - auto input3_ref = input3.clone(); - auto input4_ref = input4.clone(); - auto input5_ref = input5.clone(); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector aten_inputs = {input1, input2, input3, input4, input5}; - auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - - auto at_results = at::native_batch_norm( - input1_ref, - input2_ref, - input3_ref, - input4_ref, - input5_ref, - kTraining, - kMomentum, - kEps); - - auto at_output = std::get<0>(at_results); - auto at_mean = std::get<1>(at_results); - auto at_invstd = std::get<2>(at_results); - - std::vector aten_outputs = {at_output, at_mean, at_invstd}; - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBNRepro2_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - const bool kTraining = true; - const float kMomentum = 0.1; - const float kEps = 1e-5; - - int batch = 2; - int c = 4; - int h = 17; - int w = 17; - int numDims = 4; - - auto input = makeSymbolicTensor(numDims); - fusion.addInput(input); - - Val* momentum_ptr = IrBuilder::create(kMomentum); - Val* eps_ptr = IrBuilder::create(kEps); - - auto result = batch_norm( - input, - nullptr, - nullptr, - nullptr, - nullptr, - kTraining, - momentum_ptr, - eps_ptr); - - fusion.addOutput(result.output); - fusion.addOutput(result.mean); - fusion.addOutput(result.invstd); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({batch, c, h, w}, options); - - auto input1_ref = input1.clone(); - at::Tensor r_m; - at::Tensor r_v; - at::Tensor weight; - at::Tensor bias; - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector aten_inputs = {input1}; - auto cg_outputs = fec.runFusionWithInputs(aten_inputs); - - auto at_results = at::native_batch_norm( - input1_ref, r_m, r_v, weight, bias, kTraining, kMomentum, kEps); - - auto at_output = std::get<0>(at_results); - auto at_mean = std::get<1>(at_results); - auto at_invstd = std::get<2>(at_results); - - std::vector aten_outputs = {at_output, at_mean, at_invstd}; - - testValidate( - &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = makeConcreteTensor({0}); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(2.5)); - fusion.addOutput(tv2); - - // This test used to just have: - // auto tv3 = makeConcreteTensor({0}); - // and somehow that was running through our system fine, but size-0 tensors - // are not supported, so making sure this fails. - auto tv3 = set(tv1); - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input0 = at::randn({2}, options); - at::Tensor input1 = at::randn({0}, options); - at::Tensor cg_output2 = at::empty({2}, options); - at::Tensor cg_output3 = at::empty({0}, options); - - // Fails at schedule pointwise because our (maybe only) size-0 check is in - // binding input sizes which the scheduler ends up calling. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(schedulePointwise(&fusion, {input0, input1})); -} - -TEST_F(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = makeConcreteTensor({0}); - fusion.addInput(tv1); - - auto tv2 = sum(tv0, {1}); - fusion.addOutput(tv2); - - auto tv3 = makeConcreteTensor({0}); - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input0 = at::randn({2, 4}, options); - at::Tensor input1 = at::randn({0}, options); - at::Tensor cg_output2 = at::empty({2}, options); - at::Tensor cg_output3 = at::empty({0}, options); - - auto reduction_params = getReductionHeuristics(&fusion, {input0, input1}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - scheduleReduction(&fusion, *reduction_params); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - - auto lparams = reduction_params->lparams; - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1}, lparams); - auto cg_outputs = fe.runFusion({input0, input1}, lparams); - auto aten_output2 = input0.sum({1}); - at::Tensor aten_output3 = at::empty({0}, options); - - testValidate( - &fusion, - cg_outputs, - {input0, input1}, - {aten_output2, aten_output3}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = makeConcreteTensor({0}); - fusion.addInput(tv1); - - auto tv2 = sum(tv0, {0}); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv0, tv3); - fusion.addOutput(tv4); - - auto tv5 = makeConcreteTensor({0}); - fusion.addOutput(tv5); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input0 = at::randn({2, 4}, options); - at::Tensor input1 = at::randn({0}, options); - at::Tensor cg_output2 = at::empty({2, 4}, options); - at::Tensor cg_output3 = at::empty({0}, options); - - auto reduction_params = getPersistentHeuristics(&fusion, {input0, input1}); - TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); - schedulePersistentKernel(&fusion, *reduction_params); - - auto lparams = reduction_params->lparams; - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1}, lparams); - auto cg_outputs = fe.runFusion({input0, input1}, lparams); - auto aten_output2 = input0.sum({0}).add(input0); - at::Tensor aten_output3 = at::empty({0}, options); - - testValidate( - &fusion, - cg_outputs, - {input0, input1}, - {aten_output2, aten_output3}, - __LINE__, - __FILE__, - "", - lparams); -} - -TEST_F(NVFuserTest, FusionSegmentIoAlias_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(1); - TensorView* tv2 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - - TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 - TensorView* tv4 = - max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) - TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, - // keeps normalization scheduler away) - TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) - - // Note: test alias; - fusion->aliasOutputToInput(tv6, tv0); - // TODO: support output on aliased fusion #1488 - // remove tv7 after #1488 - // fusion->addOutput(tv6); - TensorView* tv7 = add(tv6, IrBuilder::create(1)); // Group 0 - fusion->addOutput(tv7); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, 65}, options); - at::Tensor t1 = at::randn({65}, options); - at::Tensor t2 = at::randn({128, 65}, options); - - auto t3 = t0.add(1.0); - auto t4 = std::get<0>(at::max(t3, 0)); - auto t5 = t4.add(t1); - auto t6 = t5.add(t2); - auto t7 = t6.add(1.0); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); - - // TODO: support output on aliased fusion #1488 - // validating aliasing - // TORCH_INTERNAL_ASSERT(outputs[0].data_ptr() == t0.data_ptr()); - - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation didn't happen"); - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime() - ->fusionSegments() - ->groups() - .size() == 2, - "segmentation didn't happen as expected"); - - testValidate( - executor_cache.fusion(), outputs, {t0, t1, t2}, {t7}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWelford1Output_CUDA) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tvs = Welford(tv0, {1}); - fusion->addOutput(tvs.var_sum); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, 65}, options); - auto outputs = executor_cache.runFusionWithInputs({t0}); - - auto t1 = t0.var({1}, false) * 65; - testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTranslate1Welford_CUDA) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tvs = Welford(tv0, {1}); - auto tv_out = add(tv0, broadcast(tvs.avg, {false, true})); - fusion->addOutput(tv_out); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto run_test = [&executor_cache, - fusion](auto inner_size) -> FusionKernelRuntime* { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, inner_size}, options); - auto outputs = executor_cache.runFusionWithInputs({t0}); - // Square sums does not fit well in the testValidate assumptions, - // so we just compare the divided output here. - testValidate( - fusion, - outputs, - {t0}, - {t0.add(t0.mean({1}).unsqueeze(1))}, - __LINE__, - __FILE__); - - return executor_cache.getMostRecentKernelRuntime(); - }; - - // Run a translated welford - auto runtime1 = run_test(64); - // Check it was translated - TORCH_CHECK( - runtime1->fusionSegments()->groups().size() == 1 && - runtime1->fusionSegments()->groups()[0]->exprs().size() > 2); - - // Run an un-translated welford - auto runtime2 = run_test(65536); - - bool found_welford = false; - for (auto group : runtime2->fusionSegments()->groups()) { - for (auto expr : group->exprs()) { - if (expr->isA()) { - found_welford = true; - } - } - } - TORCH_CHECK(found_welford); -} - -TEST_F(NVFuserTest, FusionTranslate2Welford_CUDA) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tvs1 = Welford(tv0, {1}); - auto tv_out1 = add(tv0, broadcast(tvs1.avg, {false, true})); - fusion->addOutput(tv_out1); - - auto tvs2 = Welford(tv0, {1}); - auto tv_out2 = add(tv0, broadcast(tvs2.avg, {false, true})); - fusion->addOutput(tv_out2); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto run_test = [&executor_cache, - fusion](auto inner_size) -> FusionKernelRuntime* { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, inner_size}, options); - auto outputs = executor_cache.runFusionWithInputs({t0}); - - // Square sums does not fit well in the testValidate assumptions, - // so we just compare the divided output here. - auto out = t0.add(t0.mean({1}).unsqueeze(1)); - testValidate(fusion, outputs, {t0}, {out, out}, __LINE__, __FILE__); - - return executor_cache.getMostRecentKernelRuntime(); - }; - - // Run a translated welford - auto runtime1 = run_test(64); - // Check it was translated - TORCH_CHECK( - runtime1->fusionSegments()->groups().size() == 1 && - runtime1->fusionSegments()->groups()[0]->exprs().size() > 4); - - // Run an un-translated welford - auto runtime2 = run_test(65536); - // // Check it was not translated - bool found_welford = false; - for (auto group : runtime2->fusionSegments()->groups()) { - for (auto expr : group->exprs()) { - if (expr->isA()) { - found_welford = true; - } - } - } - TORCH_CHECK(found_welford); -} - -TEST_F(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tvs1 = Welford(tv0, {1}); - auto sum_of_tv0 = sum(tv0, {1}); - - fusion->addOutput(tvs1.var_sum); - fusion->addOutput(sum_of_tv0); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto run_test = [&executor_cache, - fusion](auto inner_size) -> FusionKernelRuntime* { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, inner_size}, options); - auto outputs = executor_cache.runFusionWithInputs({t0}); - - auto t1 = t0.var({1}, false) * inner_size; - auto t2 = t0.sum({1}); - testValidate(fusion, outputs, {t0}, {t1, t2}, __LINE__, __FILE__); - - return executor_cache.getMostRecentKernelRuntime(); - }; - - auto runtime = run_test(65536); - TORCH_CHECK(!runtime->isSegmented()); -} - -TEST_F(NVFuserTest, FusionWelfordOuterPersistence_CUDA) { - auto fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tvs1 = Welford(tv0, {1}); - auto sum_of_tv0 = sum(tv0, {1}); - auto sum_bcasted = broadcast(sum_of_tv0, {false, true}); - auto avg_bcasted = broadcast(tvs1.avg, {false, true}); - auto tv0_plus_sum = add(tv0, sum_bcasted); - auto tv0_plus_avg = add(tv0, avg_bcasted); - - fusion->addOutput(tv0_plus_sum); - fusion->addOutput(tv0_plus_avg); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - - auto run_test = [&executor_cache, - fusion](auto inner_size) -> FusionKernelRuntime* { - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({128, inner_size}, options); - auto outputs = executor_cache.runFusionWithInputs({t0}); - - auto t1 = t0.to(c10::kDouble).mean({1}).unsqueeze(1) + t0; - auto t2 = t0.to(c10::kDouble).sum({1}).unsqueeze(1) + t0; - testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__); - - return executor_cache.getMostRecentKernelRuntime(); - }; - - for (auto inner_size : {4096, 8192, 32768}) { - auto runtime = run_test(inner_size); - TORCH_CHECK(!runtime->isSegmented()); - } -} - -TEST_F(NVFuserTest, FusionSegmentIslands_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = sum(tv0, {0}); - auto tv3 = sum(tv1, {1}); - fusion->addOutput(tv2); - fusion->addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({16, 16}, options); - at::Tensor t1 = at::randn({16, 16}, options); - - FusionExecutorCache fusion_executor_cache(std::move(fusion)); - fusion_executor_cache.runFusionWithInputs({t0, t1}); -} - -TEST_F(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - auto tv1 = makeSymbolicTensor(2); - auto tv2 = makeSymbolicTensor(4); - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv3 = broadcast(tv0, {false, true, true, true}); - auto tv4 = broadcast(tv1, {false, false, true, true}); - auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv2); - - auto tv6 = add(tv3, tv5); - auto tv7 = add(tv4, tv5); - auto tv8 = add(tv3, tv4); - - auto tv9 = add(tv6, tv7); - auto tv10 = add(tv9, tv8); - - fusion->addOutput(tv10); - - tv0->computeAt(tv10, -2); - tv1->computeAt(tv10, -2); - tv2->computeAt(tv10, -2); - - TORCH_CHECK(tv3->getComputeAtPosition() == 1); - TORCH_CHECK(tv4->getComputeAtPosition() == 2); - TORCH_CHECK(tv5->getComputeAtPosition() == 3); - - TORCH_CHECK(tv6->getMaxProducerPosition() == 3); - TORCH_CHECK(tv7->getMaxProducerPosition() == 3); - TORCH_CHECK(tv8->getMaxProducerPosition() == 2); -} - -TEST_F(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(3); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = broadcast(tv0, {false, false, true}); - auto tv3 = add(tv2, tv1); - - fusion->addOutput(tv3); - tv3->split(-2, 4); - tv3->reorder({{-1, -2}}); - tv0->computeAt(tv3, -2); - tv1->computeAt(tv3, -2); - TORCH_CHECK(tv2->getComputeAtPosition() == 2); - TORCH_CHECK(tv3->getMaxProducerPosition() == 2); -} - -TEST_F(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(4); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = broadcast(tv0, {false, false, true}); - auto tv3 = broadcast(tv2, {false, true, false, false}); - auto tv4 = add(tv3, tv1); - - fusion->addOutput(tv4); - tv0->computeAt(tv4, -1); - tv1->computeAt(tv4, -1); - TORCH_CHECK(tv2->getComputeAtPosition() == 2); - TORCH_CHECK(tv3->getMaxProducerPosition() == 3); -} - -TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - tv1->split(1, 32); - auto tv1_rf = tv1->rFactor({1}); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 128}, options); - - auto at_output = input1.sum({1}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->split(1, 8, false); - auto tv1_rf = tv1->rFactor({1}); - tv1_rf->axis(0)->parallelize(ParallelType::BIDx); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv1_rf->axis(-1)->padToMultipleOfWarp(32); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv0->axis(-1)->padToMultipleOfWarp(32); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-1)->padToMultipleOfWarp(32); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->padToMultipleOfWarp(32); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->padToMultipleOfWarp(32); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 127}, options); - - auto at_output = input1.sum({1}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1, 2}); - auto tv2 = broadcast(tv1, {false, true, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->merge(1); - tv1->split(1, 8, false); - - auto tv1_rf = tv1->rFactor({1}); - tv1_rf->axis(0)->parallelize(ParallelType::BIDx); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 17, 128}, options); - - auto at_output = input1.sum({1, 2}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1, 2}); - auto tv2 = broadcast(tv1, {false, true, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->merge(1); - tv1->split(1, 8, false); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(); - TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 17, 128}, options); - - auto at_output = input1.sum({1, 2}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeConcreteTensor({17, 18, 128, 1}); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1, 2, 3}); - auto tv2 = broadcast(tv1, {false, true, true, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->merge(1); - tv1->split(1, 8, false); - - auto tv1_rf = tv1->rFactor({1}); - tv1_rf->axis(0)->parallelize(ParallelType::BIDx); - tv1_rf->axis(-2)->parallelize(ParallelType::TIDx); - tv1->axis(-2)->parallelize(ParallelType::TIDx); - tv1->axis(-2)->padToMultipleOfWarp(); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv0->axis(-2)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-2)->parallelize(ParallelType::TIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(-2)->parallelize(ParallelType::TIDx); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({17, 18, 128, 1}, options); - - auto at_output = input1.sum({1, 2, 3}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - auto tv_add = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv_add); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - auto tv4 = add(tv0, tv_add); - - fusion->addOutput(tv3); - fusion->addOutput(tv4); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->split(1, 8, false); - auto tv1_rf = tv1->rFactor({1}); - tv1_rf->axis(0)->parallelize(ParallelType::BIDx); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv1_rf->axis(-1)->padToMultipleOfWarp(32); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(32); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv0->axis(-1)->padToMultipleOfWarp(32); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(-1)->padToMultipleOfWarp(32); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->padToMultipleOfWarp(32); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->padToMultipleOfWarp(32); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->padToMultipleOfWarp(64); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 128}, options); - at::Tensor input2 = at::randn({16, 128}, options); - - auto at_output = input1.sum({1}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - testValidate( - fusion.get(), - outputs, - {input1, input2}, - {at_output, input1 + input2}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionPadNoWarpReduce_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv1->axis(0)->parallelize(ParallelType::TIDy); - tv2->axis(0)->parallelize(ParallelType::TIDy); - tv3->axis(0)->parallelize(ParallelType::TIDy); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 31}, options); - - auto at_output = input1.sum({1}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - fusion->addOutput(tv2); - - tv2->split(1, 8); - auto tv2_rf = tv2->rFactor({-1}); - tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv2_rf->axis(-1)->padToMultipleOfWarp(); - - TransformPropagatorWithCheck propagator(tv2_rf); - MaxRootDomainInfoSpanningTree(tv2_rf).traverse(&propagator); - - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::TIDy); - tv0->computeAt(tv2, 2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 31}, options); - - auto at_output = (input1 + 1).sum({1}); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - - fusion->addOutput(tv3); - - // Schedule a persistent kernel - auto tv0_cache = tv0->cacheAfter(); - tv1->split(1, 8, false); - tv1->split(0, 4); - auto tv1_rf = tv1->rFactor({2}); - - tv1_rf->axis(0)->parallelize(ParallelType::BIDx); - tv1_rf->axis(1)->parallelize(ParallelType::Unroll); - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->padToMultipleOfWarp(); - tv1->axis(1)->parallelize(ParallelType::Unroll); - TransformPropagatorWithCheck propagator(tv1_rf); - MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv0->axis(1)->parallelize(ParallelType::Unroll); - tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); - tv0_cache->axis(1)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::Unroll); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::Unroll); - - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({16, 128}, options); - - auto at_output = input1.sum({1}, true).add(input1); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); -} - -// Repro of issue #1579 -TEST_F(NVFuserTest, FusionWarpReducePredication_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape1 = {1024}; - std::vector shape2 = {50}; - - auto tv0 = makeConcreteTensor(shape1); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - auto tv2 = makeConcreteTensor(shape2); - fusion.addInput(tv2); - auto tv3 = add(tv2, IrBuilder::create(1)); - auto tv4 = sum(tv3, {0}); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - // Just to fill the smem buffer by a thread block of 1024 threads - // with some values - tv1->axis(-1)->parallelize(ParallelType::TIDx); - - // Make the tv4_rf reduction a warp reduction to trigger the - // bug. Since the smem buffer is filled with some values due to the - // reduction of tv1, those values would be used by predicated-out - // threads. - tv4->split(-1, 10); - auto tv4_rf = tv4->rFactor({-1}); - tv4_rf->axis(-1)->parallelize(ParallelType::TIDx); - tv4_rf->axis(-1)->padToMultipleOfWarp(); - - tv4_rf->computeAt(tv4, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape1, options); - auto t2 = at::randn(shape2, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t2}); - auto cg_outputs = fe.runFusion({t0, t2}); - - auto t1 = t0.sum({0}); - auto t4 = (t2 + 1).sum({0}) + 1; - - testValidate(&fusion, cg_outputs, {t0, t2}, {t1, t4}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSegfaultReduction_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int batch = 2; - int c = 1; - int h = 1; - int w = 1; - int numDims = 4; - - auto input = makeConcreteTensor({-1, 1, 1, 1}); - fusion.addInput(input); - auto bcast_bias = makeConcreteTensor({-1, 1, 1, 1}); - fusion.addInput(bcast_bias); - - std::vector at_sum_axes; - std::vector outer_reduction_axes; - std::vector outer_broadcast_mask(numDims, false); - Val* N = IrBuilder::create(1); - for (const auto axis : c10::irange(numDims)) { - if (axis != 1) { - outer_reduction_axes.push_back(axis); - at_sum_axes.push_back(axis); - outer_broadcast_mask[axis] = true; - N = mul(N, input->domain()->domain()[axis]->extent()); - } - } - - auto output0 = mul(input, bcast_bias); - fusion.addOutput(output0); - auto output1 = sum(output0, outer_reduction_axes); - fusion.addOutput(output1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({batch, c, h, w}, options); - at::Tensor input1 = at::randn({batch, c, h, w}, options); - - auto at_output0 = input0.mul(input1); - auto at_output1 = at_output0.sum(at_sum_axes); - - FusionExecutorCache fec(std::move(fusion_ptr)); - std::vector inputs = {input0, input1}; - auto outputs = fec.runFusionWithInputs(inputs); - - testValidate( - &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPredicateElimination1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(2)); - auto tv3 = add(tv2, IrBuilder::create(3)); - - fusion.addOutput(tv3); - - tv3->split(0, 32); - tv0->computeAt(tv3, 1); - - tv2->axis(1)->parallelize(ParallelType::Unswitch); - - { - GpuLower gpulw(&fusion); - TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); - } - - tv2->axis(1)->parallelize(ParallelType::Serial); - tv2->split(1, 5); - - { - GpuLower gpulw(&fusion); - TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); - } -} - -// Repro of issue #1571 -TEST_F(NVFuserTest, FusionPredicateElimination2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape({10, 11}); - - auto tv0 = makeConcreteTensor(shape); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - tv1->split(1, 4); - tv1->split(0, 4); - tv2->split(1, 4); - tv2->split(0, 4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = (t0 + 1).sum({1}) + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - auto tv3 = tv0->cacheAfter(); - - tv1->split(0, 10); - tv1->split(0, 33); - TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); - - auto tv4 = tv1->rFactor({-1}); - auto tv5 = tv1->rFactor({-1}); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv4); - - GpuLower gpulw(&fusion); - - // The fusion has three reductions: one within each thread, one - // within each block, and another with the whole grid. All of them - // should not need to be predicated as they use the same init value - // and same reduction op. - TORCH_CHECK(!PredicatedChecker::isPredicated(tv4, gpulw)); - TORCH_CHECK(!PredicatedChecker::isPredicated(tv5, gpulw)); - TORCH_CHECK(!PredicatedChecker::isPredicated(tv1, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - for (auto size : {1, 2, 999, 1001, 1234, 10000}) { - auto t0 = at::randn({size}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = sum(t0) + 1; - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - - auto tv2 = sum(tv1, {0}); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - auto tv4 = max(tv1, {0}); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tv1->split(1, 7); - tv1->split(0, 11); - tv1->reorder({{1, 2}, {2, 1}}); - TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); - - tv1->axis(0)->parallelize(ParallelType::TIDy); - tv1->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv1); - - GpuLower gpulw(&fusion); - - // tv2 uses the same op and init with tv1, so tv2 should be fine - // without a predicate. However, tv4, while it uses the tv1 as its - // input, the reduction op and init value is different from those of - // tv1, so tv4 needs to be predicated. - TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); - TORCH_CHECK(PredicatedChecker::isPredicated(tv4, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector sizes = {1, 2, 33, 34, 64, 99}; - for (auto s0 : sizes) { - for (auto s1 : sizes) { - auto t0 = at::randn({s0, s1}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto t1 = t0.sum({1}); - auto t3 = t1.sum({0}) + 1; - auto t5 = std::get<0>(t1.max(0)) + 1; - - testValidate(&fusion, cg_outputs, {t0}, {t3, t5}, __LINE__, __FILE__); - } - } -} - -TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tvs2 = Welford(tv1, {0}); - auto tv3 = set(tvs2.avg); - fusion.addOutput(tv3); - - tvs2.avg->split(0, 4); - TransformPropagatorWithCheck propagator(tvs2.avg); - MaxRootDomainInfoSpanningTree(tvs2.avg).traverse(&propagator); - auto avg_rf = ir_utils::rfactorHelper(tvs2.avg, {1}); - - avg_rf->axis(0)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(avg_rf); - - GpuLower gpulw(&fusion); - - // The first per-thread welford needs to be predicated as the N - // input is different from its init value. The second welford op - // does not need a predicate. - TORCH_CHECK(PredicatedChecker::isPredicated(avg_rf, gpulw)); - TORCH_CHECK(!PredicatedChecker::isPredicated(tvs2.avg, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - std::vector sizes = {1, 2, 33, 34, 64, 99}; - for (auto s0 : sizes) { - auto t0 = at::randn({s0}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0.mean({0}); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionPredicateElimination6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - auto tv4 = add(tv3, IrBuilder::create(1)); - fusion.addOutput(tv4); - - tv4->split(1, 5); - TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - - tv4->reorder({{0, 1}, {1, 0}}); - tv3->computeAt(tv4, 1); - - GpuLower gpulw(&fusion); - - // The expression for tv2 is a local-to-local expression. It - // satisfies all the requirements of predicate elimination, except - // for the on on split root domains. As the second root axis of tv2 - // is split, its index exceeds its extent (i.e., 3 in this case) - // without its predicate. - TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); - - // Unlike tv2, tv3 is computed at tv4, so the second root axis does - // have a zero domain. Its index should look like "i * 5 + j", where - // i comes from the first root domain and j comes from the split - // inner domain. - TORCH_CHECK(!PredicatedChecker::isPredicated(tv3, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({2, 3}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 4; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPredicateElimination7_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - tv3->split(-1, 5); - tv3->split(-1, 4); - tv3->split(-1, 3); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, 1); - - // The last split of tv2 is a non-divisible split, and omitting it - // is invalid. - GpuLower gpulw(&fusion); - TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({123}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 3; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionForceFp16Simple_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - // Group 1 - auto tv2 = sum(tv0, {1}); - auto tv3 = broadcast(tv2, {false, true}); - - // Group 2 - auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast - auto tv5 = castOp(DataType::Half, tv4); - - fusion->addOutput(tv5); - - FusionExecutorCache fec(std::move(fusion_ptr)); - - std::vector shape{15, 16}; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - // Check the segmented edge is fp16 - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - TORCH_CHECK(edge_tv->getDataType() == DataType::Half); - } -} - -TEST_F(NVFuserTest, FusionForceBf16Simple_CUDA) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - return; - } - - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - // Group 1 - auto tv2 = sum(tv0, {1}); - auto tv3 = broadcast(tv2, {false, true}); - - // Group 2 - auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast - auto tv5 = castOp(DataType::BFloat16, tv4); - - fusion->addOutput(tv5); - - FusionExecutorCache fec(std::move(fusion_ptr)); - - std::vector shape{15, 16}; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - // Check the segmented edge is bf16 - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); - } -#else - GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; -#endif -} - -TEST_F(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - // Group 1 - auto tv3 = sum(tv0, {1}); - auto tv4 = broadcast(tv3, {false, true, false}); - auto tv5 = sum(tv0, {1}); - - // Group 2 - auto tv6 = add(tv4, tv1); // edge tv4, expect cast - auto tv7 = castOp(DataType::Half, tv6); - - // Group 3 - auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast - - fusion->addOutput(tv7); - fusion->addOutput(tv8); - - FusionExecutorCache fec(std::move(fusion_ptr)); - - std::vector shape{16, 16, 16}; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - auto complete_fusion = segmented_fusion->completeFusion(); - - // Check that the edge that wasn't fp16 is the producer of the - // reduction op, i.e. tv8 = sum(tv5,{1});. - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - if (edge_tv->getDataType() == DataType::Float) { - auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); - TORCH_CHECK(consumer->isA()); - } - } -} - -TEST_F(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - return; - } - - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeSymbolicTensor(3); - auto tv1 = makeSymbolicTensor(3); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - // Group 1 - auto tv3 = sum(tv0, {1}); - auto tv4 = broadcast(tv3, {false, true, false}); - auto tv5 = sum(tv0, {1}); - - // Group 2 - auto tv6 = add(tv4, tv1); // edge tv4, expect cast - auto tv7 = castOp(DataType::BFloat16, tv6); - - // Group 3 - auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast - - fusion->addOutput(tv7); - fusion->addOutput(tv8); - - FusionExecutorCache fec(std::move(fusion_ptr)); - - std::vector shape{16, 16, 16}; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn(shape, options); - auto in1 = at::randn(shape, options); - fec.runFusionWithInputs({in0, in1}); - - auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); - auto complete_fusion = segmented_fusion->completeFusion(); - - // Check that the edge that wasn't fp16 is the producer of the - // reduction op, i.e. tv8 = sum(tv5,{1});. - for (auto edge : segmented_fusion->edges()) { - auto edge_tv = edge->val->as(); - if (edge_tv->getDataType() == DataType::Float) { - auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); - TORCH_CHECK(consumer->isA()); - } - } -#else - GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; -#endif -} - -TEST_F(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({2, 2}); - auto tv1 = makeConcreteTensor({2, 2, 2}); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = mul(tv0, IrBuilder::create(2)); - auto tv3 = broadcast(tv2, {false, false, true}); - auto tv4 = add(tv3, tv1); - auto tv5 = mul(tv4, IrBuilder::create(3)); - fusion->addOutput(tv5); - - // t4 cannot inner re-use t2, because there's a broadcast - // between them. - tv0->computeAt(tv5, 1, ComputeAtMode::BestEffort); - tv3->computeAt(tv5, 2, ComputeAtMode::BestEffort); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({2, 2}, options); - auto in1 = at::randn({2, 2, 2}, options); - - auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; - FusionExecutor fe; - fe.compileFusion(fusion, {in0, in1}); - auto outputs = fe.runFusion({in0, in1}); - - testValidate(fusion, outputs, {in0, in1}, {at_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseStressTest_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({2, 2}); - auto tv1 = makeConcreteTensor({2, 2, 2}); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = mul(tv0, IrBuilder::create(2)); - auto tv3 = mul(tv0, IrBuilder::create(3)); - auto tv4 = mul(tv2, tv3); - // Broadcast buffer can be reused through outer sharing - auto tv5 = broadcast(tv4, {true, false, false}); - auto tv6 = mul(tv5, IrBuilder::create(5)); - auto tv7 = mul(tv6, tv1); - auto tv8 = mul(tv7, IrBuilder::create(7)); - // tv9 shouldn't alias to avoid buffer over-subscription - auto tv9 = broadcast(tv4, {true, false, false}); - auto tv10 = mul(tv9, IrBuilder::create(9)); - auto tv11 = add(tv5, tv9); - fusion->addOutput(tv7); - fusion->addOutput(tv11); - - tv0->computeAt(tv5, 1, ComputeAtMode::BestEffort); - tv0->computeAt(tv9, 1, ComputeAtMode::BestEffort); - - tv5->computeAt(tv7, 1, ComputeAtMode::BestEffort); - tv5->computeAt(tv11, 1, ComputeAtMode::BestEffort); - tv9->computeAt(tv11, 1, ComputeAtMode::BestEffort); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({2, 2}, options); - auto in1 = at::randn({2, 2, 2}, options); - auto t2 = in0 * 2; - auto t3 = in0 * 3; - auto t4 = t2 * t3; - auto t5 = t4.unsqueeze(0); - auto t6 = t5 * 5; - auto t7 = t6 * in1; - auto t8 = t7 * 7; - auto t9 = t4.unsqueeze(0); - auto t10 = t9 * 9; - auto t11 = t5 + t9; - FusionExecutor fe; - fe.compileFusion(fusion, {in0, in1}); - - auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; - auto outputs = fe.runFusion({in0, in1}); - - testValidate(fusion, outputs, {in0, in1}, {t7, t11}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({256, 512}); - - fusion->addInput(tv0); - - auto tv1 = mul(tv0, IrBuilder::create(2)); - auto tv2 = mul(tv1, IrBuilder::create(2)); - auto tv3 = mul(tv2, IrBuilder::create(2)); - auto tv4 = mul(tv3, IrBuilder::create(2)); - auto tv5 = mul(tv4, IrBuilder::create(2)); - auto tv6 = mul(tv5, IrBuilder::create(2)); - - fusion->addOutput(tv6); - - tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); - tv6->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({256, 512}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {in0}); - auto outputs = fe.runFusion({in0}); - - auto at_out = in0.mul(2).mul(2).mul(2).mul(2).mul(2).mul(2); - - testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({2, 2}); - auto tv1 = makeConcreteTensor({2, 2, 2}); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = mul(tv0, IrBuilder::create(2)); - auto tv3 = broadcast(tv2, {false, false, true}); - auto tv4 = add(tv3, tv1); // T4 to be inner aliased first, and - // shouldn't outer alias on top - auto tv5 = mul(tv4, IrBuilder::create(3)); - auto tv6 = mul(tv5, IrBuilder::create(3)); - fusion->addOutput(tv6); - - tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); - tv4->computeAt(tv6, 2, ComputeAtMode::BestEffort); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({2, 2}, options); - auto in1 = at::randn({2, 2, 2}, options); - FusionExecutor fe; - fe.compileFusion(fusion, {in0, in1}); - auto outputs = fe.runFusion({in0, in1}); - - auto at_out = (in0.mul(2.0).unsqueeze(2) + in1).mul(3.0).mul(3.0); - - testValidate(fusion, outputs, {in0, in1}, {at_out}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({3, 3, 3}); - - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = mul(tv1, IrBuilder::create(2)); - auto tv3 = mul(tv2, IrBuilder::create(2)); - - fusion->addOutput(tv3); - - // In this case tv1 "reuses" allocation of tv2 - // due to the switched allocation order - tv1->computeAt(tv2, 1, ComputeAtMode::BestEffort); - - tv0->axis(0)->parallelize(ParallelType::TIDx); - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({3, 3, 3}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {in0}); - auto outputs = fe.runFusion({in0}); - - auto at_out = in0.sum(1).mul(2).mul(2); - - testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({16, 16}); - - fusion->addInput(tv0); - - auto tv1 = mul(tv0, IrBuilder::create(3)); - auto tv2 = mul(tv1, IrBuilder::create(2)); - auto tv3 = mul(tv2, IrBuilder::create(2)); - // tv1 used till here, cannot be reused by tv2 or tv3 - auto tv4 = mul(tv3, tv1); - - fusion->addOutput(tv4); - - tv0->computeAt(tv4, 1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({16, 16}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {in0}); - auto cg_outputs = fe.runFusion({in0}); - - auto at_t0 = in0 * 3.0; - auto at_out = at_t0 * 2.0 * 2.0 * at_t0; - - testValidate(fusion, cg_outputs, {in0}, {at_out}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto tv0 = makeConcreteTensor({2, 2}); - auto tv1 = makeConcreteTensor({2, 2, 2}); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = mul(tv0, IrBuilder::create(2)); - auto tv3 = mul(tv0, IrBuilder::create(3)); - auto tv4 = mul(tv2, tv3); - auto tv5 = broadcast(tv4, {false, false, true}); - auto tv6 = mul(tv5, tv1); - auto tv7 = mul(tv6, IrBuilder::create(7)); - fusion->addOutput(tv7); - - // tv6 shouldn't re-use t2 or t3 because of - // the broadcast in between - tv0->computeAt(tv4, 1, ComputeAtMode::BestEffort); - tv4->computeAt(tv7, 2, ComputeAtMode::BestEffort); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto in0 = at::randn({2, 2}, options); - auto in1 = at::randn({2, 2, 2}, options); - FusionExecutor fe; - fe.compileFusion(fusion, {in0, in1}); - auto outputs = fe.runFusion({in0, in1}); - - auto t2 = in0 * 2; - auto t3 = in0 * 3; - auto t4 = t2 * t3; - auto t5 = t4.unsqueeze(2); - auto t6 = t5 * in1; - auto t7 = t6 * 7; - testValidate(fusion, outputs, {in0, in1}, {t7}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue970_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int nelm = 10; - - // tv3 = tv0 + sum(tv0) - auto tv0 = makeConcreteTensor({nelm, nelm}); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - fusion.addOutput(tv3); - - tv1->split(1, 4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({nelm, nelm}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0; - - testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Reproducer of #1016 -TEST_F(NVFuserTest, FusionIssue1016_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(2)); - - fusion.addOutput(tv2); - - tv1->setMemoryType(MemoryType::Shared); - - tv2->split(-1, 8); - - int numel_x = 10; - int numel_y = 11; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({numel_x, numel_y}, options); - std::vector inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto outputs = fe.runFusion(inputs); - - auto ref = t0 + 1 + 2; - - testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Reproducer of #1021 -TEST_F(NVFuserTest, FusionIssue1021_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = broadcast(tv1, {false, true}); - fusion.addOutput(tv2); - - auto tv3 = tv2->cacheBefore(); - - tv2->split(0, 2); - - tv1->computeAt(tv2, 1); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10}, options); - std::vector inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto outputs = fe.runFusion(inputs); - - auto ref = (t0 + 1).unsqueeze(-1); - - testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); -} - -// Reproducer of issue #1053 -TEST_F(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = sum(tv0, {0}); - fusion->addOutput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv2); - - tv1->split(0, 8); - auto tv1_rf = tv1->rFactor({-1}); - - tv1_rf->computeAt(tv1, 1); - - tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({32}, options); - - auto at_tv1 = (input1).sum({0}); - auto at_tv2 = input1 + 1; - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - testValidate( - fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionParallelDimensionMap1_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv1); - fusion->addOutput(tv2); - - tv1->split(0, 8, false); - tv1->axis(1)->parallelize(ParallelType::TIDx); - tv2->split(0, 8, false); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - // The extents of tv1 and tv2 axes are equal even though their - // actual values are not statically known - GpuLower gpulw(fusion.get()); - const auto& pdmap = gpulw.parallelDimensionMap(); - for (const auto i : c10::irange(tv1->domain()->domain().size())) { - auto dom1 = tv1->domain()->domain()[i]; - auto dom2 = tv2->domain()->domain()[i]; - TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent())); - } - - TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({32}, options); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - - testValidate( - fusion.get(), - outputs, - {input1}, - {input1 + 1, input1 + 1}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionParallelDimensionMap2_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion->addInput(tv1); - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv1, tv2); - fusion->addOutput(tv3); - - tv3->split(-1, 8, false); - tv2->computeAt(tv3, -1); - - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - GpuLower gpulw(fusion.get()); - const auto& pdmap = gpulw.parallelDimensionMap(); - TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({11}, options); - at::Tensor input2 = at::randn({11, 13}, options); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - auto ref = input1.unsqueeze(-1) + input2; - - testValidate( - fusion.get(), outputs, {input1, input2}, {ref}, __LINE__, __FILE__); -} - -// Mix symbolic and concrete tensors -TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv2 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv2); - auto tv3 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv3); - - tv2->split(0, 10); - tv3->split(0, 20); - - auto tv4 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv4); - auto tv5 = add(tv0, IrBuilder::create(1)); - fusion->addOutput(tv5); - - // Not mapped but equal extent - tv4->split(0, 10); - tv5->split(0, 10); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv4->axis(-1)->parallelize(ParallelType::TIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - - GpuLower gpulw(fusion.get()); - const auto& pdmap = gpulw.parallelDimensionMap(); - TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); - TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDy)->isConst() && - pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({13}, options); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {input1}); - auto outputs = fe.runFusion({input1}); - - testValidate( - fusion.get(), - outputs, - {input1}, - {input1 + 1, input1 + 1, input1 + 1, input1 + 1}, - __LINE__, - __FILE__); -} - -// Parallelizing merged broadcast domains -TEST_F(NVFuserTest, FusionParallelDimensionMap4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv2 = add(tv0, IrBuilder::create(1)); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->split(1, 4); - tv4->reorder({{1, 2}, {2, 1}}); - tv4->merge(0); - tv0->computeAt(tv4, 1); - tv1->computeAt(tv4, 1); - - // TIDx is mapped to tv4.axis(0) as well as tv2.axis(0), so it's not - // exact. - tv4->axis(0)->parallelize(ParallelType::TIDx); - - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - - GpuLower gpulw(&fusion); - const auto& pdmap = gpulw.parallelDimensionMap(); - TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isA() && - pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({13}, options); - at::Tensor input2 = at::randn({15, 13}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - auto ref = (input1 + 1).unsqueeze(0) + input2; - - testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionParallelDimensionMap5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv3 = broadcast(tv0, {false, true}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->split(1, 4); - tv0->computeAt(tv4, -1); - tv1->computeAt(tv4, -1); - - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv4->axis(-2)->parallelize(ParallelType::TIDy); - tv3->axis(-2)->parallelize(ParallelType::TIDy); - - GpuLower gpulw(&fusion); - const auto& pdmap = gpulw.parallelDimensionMap(); - TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); - TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); - TORCH_CHECK( - pdmap.get(ParallelType::TIDx)->isConst() && - pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); - TORCH_CHECK( - pdmap.get(ParallelType::TIDy)->isA() && - pdmap.get(ParallelType::TIDy)->as()->name() == "blockDim.y"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::randn({13}, options); - at::Tensor input2 = at::randn({13, 15}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input1, input2}); - auto outputs = fe.runFusion({input1, input2}); - - auto ref = (input1).unsqueeze(-1) + input2; - - testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { - auto fusion_ptr = std::make_unique(); - auto& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - auto t0 = makeSymbolicTensor(3, DataType::Float); - auto t1 = makeSymbolicTensor(3, DataType::Half); - auto t3 = makeSymbolicTensor(3, DataType::Half); - auto t5 = makeSymbolicTensor(3, DataType::Half); - auto t7 = makeSymbolicTensor(1, DataType::Half); - auto t11 = makeSymbolicTensor(3, DataType::Half); - auto t13 = makeSymbolicTensor(3, DataType::Half); - auto t15 = makeSymbolicTensor(3, DataType::Half); - auto t17 = makeSymbolicTensor(3, DataType::Half); - auto d56 = IrBuilder::create(); - - fusion.addInput(t0); - fusion.addInput(t1); - fusion.addInput(t3); - fusion.addInput(t5); - fusion.addInput(t7); - fusion.addInput(t11); - fusion.addInput(t13); - fusion.addInput(t15); - fusion.addInput(t17); - fusion.addInput(d56); - - auto t2 = castOp(DataType::Float, t1); - auto t4 = castOp(DataType::Float, t3); - auto t22 = sub(t2, t4); - auto t6 = castOp(DataType::Float, t5); - auto t23 = mul(t22, t6); - auto t16 = castOp(DataType::Float, t15); - auto t18 = castOp(DataType::Float, t17); - auto t19 = add(t16, t18); - auto t14 = castOp(DataType::Float, t13); - auto t20 = add(t19, t14); - auto t12 = castOp(DataType::Float, t11); - auto t21 = add(t20, t12); - auto t8 = castOp(DataType::Float, t7); - auto t24 = broadcast(t8, {true, true, false}); - auto t25 = mul(t21, t24); - auto t27 = sum(t25, {2}); - auto t28 = broadcast(t27, {false, false, true}); - auto t29 = mul(t25, t23); - auto t30 = sum(t29, {2}); - auto t31 = broadcast(t30, {false, false, true}); - auto d59 = - mul(t1->getRootDomain()[2]->extent(), IrBuilder::create(1)); - auto t26 = mul(d59, t25); - auto txx = mul(t26, IrBuilder::create(1)); - auto t33 = sub(txx, t28); - auto d70 = unaryOp(UnaryOpType::Reciprocal, d59); - auto t35 = mul(d70, t6); - auto t39 = sum(t21, {0, 1}); - auto t47 = castOp(DataType::Half, t39); - auto t37 = mul(t21, t23); - auto t38 = sum(t37, {0, 1}); - auto t46 = castOp(DataType::Half, t38); - auto t32 = mul(t23, t31); - auto t34 = sub(t33, t32); - auto t36 = mul(t35, t34); - auto t45 = castOp(DataType::Half, t36); - auto t40 = mul(t36, t0); - auto t41 = mul(t40, d56); - auto t44 = castOp(DataType::Half, t41); - auto t42 = sum(t41, {0, 1}); - auto t43 = castOp(DataType::Half, t42); - - fusion.addOutput(t43); - fusion.addOutput(t44); - fusion.addOutput(t45); - fusion.addOutput(t46); - fusion.addOutput(t47); - - auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - auto options_float = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_t0 = at::randn({128, 64, 1024}, options_float); - at::Tensor at_t1 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t3 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t5 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t7 = at::randn({1024}, options_half); - at::Tensor at_t11 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t13 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t15 = at::randn({128, 64, 1024}, options_half); - at::Tensor at_t17 = at::randn({128, 64, 1024}, options_half); - double at_d56 = 1.1111; - - std::vector aten_inputs = { - at_t0, at_t1, at_t3, at_t5, at_t7, at_t11, at_t13, at_t15, at_t17}; - - c10::IValue val = at_d56; - - KernelArgumentHolder args(KernelIndexMode::INT32); - args.setDeviceIndex(0); - args.push(aten_inputs); - args.push(val); - - for (auto _ : c10::irange(5)) { - auto segmented_fusion = - SegmentCandidateFinder::segment(fusion_ptr.get(), args); - } -} - -TEST_F(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - auto tv3 = add(tv0, IrBuilder::create(1)); - auto tv4 = add(tv3, IrBuilder::create(1)); - fusion.addOutput(tv4); - - auto tv5 = add(tv0, IrBuilder::create(1)); - auto tv6 = add(tv5, IrBuilder::create(1)); - fusion.addOutput(tv6); - - // Case 1: local memory tensor computed serially and used by - // parallel threads - tv2->split(-1, 4); - tv1->computeAt(tv2, -2); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - // Case 2: shared memory tensor computed serially and used by BID - tv4->split(-1, 4); - tv3->computeAt(tv4, -2); - tv4->axis(-1)->parallelize(ParallelType::BIDx); - tv3->setMemoryType(MemoryType::Shared); - - // Case 3: shared memory tensor computed by TID and used by BID - tv6->split(-1, 4); - tv5->computeAt(tv6, -2); - tv6->axis(-1)->parallelize(ParallelType::BIDx); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - tv5->setMemoryType(MemoryType::Shared); - - const int nx = 11; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({nx}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref = t0 + 2; - - testValidate( - &fusion, outputs, aten_inputs, {ref, ref, ref}, __LINE__, __FILE__); -} - -// Repro of issue #1105 -TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - tv1->setMemoryType(MemoryType::Shared); - tv2->setMemoryType(MemoryType::Shared); - - tv3->split(0, 4); - tv0->computeAt(tv3, 1); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - tv3->axis(-1)->parallelize(ParallelType::TIDz); - - // Make sure a WAR sync is inserted at the end of the outer loop - GpuLower gpulw(&fusion); - for (const auto& kir_node : gpulw.kernel()->topLevelExprs()) { - if (auto loop = dynamic_cast(kir_node)) { - const auto& body = loop->body().exprs(); - TORCH_CHECK(!body.empty()); - auto last_expr = dynamic_cast(body.back()); - TORCH_CHECK(last_expr != nullptr, "Invalid expr found"); - TORCH_CHECK(last_expr->isWarHazardSync(), "Not a sync for WAR hazard"); - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 3; - - testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1099_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - auto tv3 = makeSymbolicTensor(1); - fusion.addInput(tv3); - - // Just to make TIDx/y/z non-exact - auto tv4 = add(tv3, IrBuilder::create(1)); - auto tv5 = add(tv4, IrBuilder::create(1)); - auto tv6 = add(tv5, IrBuilder::create(1)); - fusion.addOutput(tv6); - - tv2->split(0, 4); - tv0->computeAt(tv2, 1); - - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv1->axis(-1)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDz); - tv2->axis(0)->parallelize(ParallelType::BIDx); - - tv1->setMemoryType(MemoryType::Shared); - - tv4->split(0, 5); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv4->setMemoryType(MemoryType::Shared); - tv5->split(0, 6); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - tv5->setMemoryType(MemoryType::Shared); - tv6->split(0, 7); - tv6->axis(-1)->parallelize(ParallelType::TIDz); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - at::Tensor t3 = at::randn({19}, options); - std::vector aten_inputs = {t0, t3}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref_t2 = t0 + 2; - auto ref_t3 = t3 + 3; - - testValidate( - &fusion, outputs, aten_inputs, {ref_t2, ref_t3}, __LINE__, __FILE__); -} - -// Repro of issue #1080 -TEST_F(NVFuserTest, FusionUnswitchPredicate_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv2->split(0, 4); - tv0->computeAt(tv2, 2); - - tv2->split(-1, 8); - tv1->split(-1, 8); - - tv2->axis(1)->parallelize(ParallelType::Unswitch); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-2)->parallelize(ParallelType::TIDy); - - // swap TIDx and TIDy - tv1->axis(-1)->parallelize(ParallelType::TIDy); - tv1->axis(-2)->parallelize(ParallelType::TIDx); - - tv1->setMemoryType(MemoryType::Shared); - - const int nx = 4; - const int ny = 10; - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({nx, ny}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref = t0 + 2; - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1189_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({16, 16}); - auto tv1 = makeConcreteTensor({16, 16}); - - auto tv0b = broadcast(tv0, {false, false, true}); - auto tv1b = broadcast(tv1, {false, false, true}); - - fusion.addInput(tv0b); - fusion.addInput(tv1b); - - auto tv2 = add(tv0b, tv1b); - auto tv3 = sum(tv2, {1}); - fusion.addOutput(tv3); - - auto parallelize = [](auto tv) { - tv->axis(0)->parallelize(ParallelType::TIDx); - tv->axis(1)->parallelize(ParallelType::BIDx); - tv->axis(2)->parallelize(ParallelType::BIDy); - }; - - parallelize(tv0b); - parallelize(tv1b); - parallelize(tv2); - parallelize(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({16, 16, 1}, options); - at::Tensor t1 = at::randn({16, 16, 1}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto outputs = fe.runFusion({t0, t1}); - - auto ref = (t0 + t1).sum({1}); - - testValidate(&fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1052_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - fusion.addOutput(tv2); - - auto tv3 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv3); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - scheduler_utils::parallelizeAllLike(tv2, {tv0}); - scheduler_utils::parallelizeAllLike(tv3, {tv1}); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10}, options); - at::Tensor t1 = at::randn({100}, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref_t2 = t0 + 1; - auto ref_t3 = t1 + 1; - - testValidate( - &fusion, outputs, aten_inputs, {ref_t2, ref_t3}, __LINE__, __FILE__); -} - -// Repro of issue #1115 -TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector input_shape{3, 17, 80}; - std::vector output_shape{3, 17, 1, 80}; - - TensorView* x = makeSymbolicTensor(input_shape.size()); - TensorView* bias = makeSymbolicTensor(input_shape.size()); - fusion.addInput(x); - fusion.addInput(bias); - - auto x_add_bias = add(x, bias); - auto x_bcast = broadcast(x_add_bias, {false, false, true, false}); - auto y = gelu(x_bcast); - fusion.addOutput(y); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_x = at::randn(input_shape, options); - at::Tensor at_bias = at::randn(input_shape, options); - std::vector aten_inputs = {at_x, at_bias}; - - schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto at_x_add_bias = at_x + at_bias; - auto at_x_view = at::native::view(at_x_add_bias, output_shape); - auto aten_y = at::gelu(at_x_view); - - testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int size = 1024 * 64; - - TensorView* x = makeContigTensor(1); - fusion.addInput(x); - auto y = sin(x); - fusion.addOutput(y); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // PyTorch's CUDA caching allocator should always return aligned pointer for - // freshly allocated tensor - at::Tensor at_x = at::randn({size}, options); - - schedulePointwise(&fusion, {at_x}); - - for (auto x_consumer : ir_utils::consumerTvsOf(x)) { - bool found_vec_in_input = false; - for (auto id : x_consumer->domain()->domain()) { - if (isParallelTypeVectorize(id->getParallelType())) { - found_vec_in_input = true; - break; - } - } - TORCH_CHECK(found_vec_in_input, "Expect input to be vectorized"); - } - - for (auto id : y->domain()->domain()) { - if (isParallelTypeVectorize(id->getParallelType())) { - return; - } - } - TORCH_CHECK(false, "Expect output to be vectorized"); -} - -TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - // Just set the dimension of TIDx - auto tv4 = makeSymbolicTensor(1); - fusion.addInput(tv4); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tv1->setMemoryType(MemoryType::Shared); - tv2->setMemoryType(MemoryType::Shared); - - tv5->axis(0)->parallelize(ParallelType::TIDx); - - // tv1 and tv2 are on shared memory and are not parallelized with - // TIDx. They should be predicated as they are redundant and can - // interfere with smem aliasing (issue #1100). - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10}, options); - at::Tensor t4 = at::randn({1024}, options); - std::vector aten_inputs = {t0, t4}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 3; - auto ref2 = t4 + 1; - - testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - fusion.addOutput(tv1); - - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - auto tv3 = sum(tv2, {0}); - fusion.addOutput(tv3); - - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - at::Tensor t2 = at::randn({19}, options); - std::vector aten_inputs = {t0, t2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 1; - auto ref2 = sum(t2); - - testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - fusion.addOutput(tv1); - - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - auto tv3 = Welford(tv2, {0}).avg; - fusion.addOutput(tv3); - - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(0)->parallelize(ParallelType::BIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - at::Tensor t2 = at::randn({19}, options); - std::vector aten_inputs = {t0, t2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 1; - auto ref2 = mean(t2, {0}); - - testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0, 1}); - fusion.addOutput(tv1); - - auto tv2 = makeSymbolicTensor(3); - fusion.addInput(tv2); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - auto tv4 = makeSymbolicTensor(3); - fusion.addInput(tv4); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::TIDy); - tv3->axis(2)->parallelize(ParallelType::TIDz); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - tv5->axis(1)->parallelize(ParallelType::BIDy); - tv5->axis(2)->parallelize(ParallelType::BIDz); - - // TODO: This needs a fix for issue #1102. - // Also, need to allow predicated grid reductions. -#if 0 - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 3}, options); - at::Tensor t2 = at::randn({5, 6, 7}, options); - at::Tensor t4 = at::randn({8, 9, 10}, options); - std::vector aten_inputs = {t0, t2, t4}; - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0.sum(at::IntArrayRef{0, 1}); - auto ref2 = t2 + 1; - auto ref3 = t4 + 1; - - testValidate( - &fusion, outputs, aten_inputs, {ref1, ref2, ref3}, __LINE__, __FILE__); -#endif -} - -TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tvs = Welford(tv0, {0, 1}); - fusion.addOutput(tvs.avg); - - auto tv2 = makeSymbolicTensor(3); - fusion.addInput(tv2); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - auto tv4 = makeSymbolicTensor(3); - fusion.addInput(tv4); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tvs.avg->axis(0)->parallelize(ParallelType::BIDx); - tvs.avg->axis(1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::TIDy); - tv3->axis(2)->parallelize(ParallelType::TIDz); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - tv5->axis(1)->parallelize(ParallelType::BIDy); - tv5->axis(2)->parallelize(ParallelType::BIDz); - - // TODO: needs a fix for issue #1102 - // Also, need to allow predicated grid reductions. -#if 0 - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({2, 3}, options); - at::Tensor t2 = at::randn({5, 6, 7}, options); - at::Tensor t4 = at::randn({8, 9, 10}, options); - std::vector aten_inputs = {t0, t2, t4}; - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0.mean(at::IntArrayRef{0, 1}); - auto ref2 = t2 + 1; - auto ref3 = t4 + 1; - - testValidate( - &fusion, outputs, aten_inputs, {ref1, ref2, ref3}, __LINE__, __FILE__); -#endif -} - -// Repro of issue #1102 -TEST_F(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - // Just to make TIDx/y/z non-exact - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - auto tv4 = makeSymbolicTensor(1); - fusion.addInput(tv4); - - auto tv5 = add(tv4, IrBuilder::create(1)); - auto tv6 = add(tv5, IrBuilder::create(1)); - auto tv7 = add(tv6, IrBuilder::create(1)); - auto tv8 = add(tv7, IrBuilder::create(1)); - auto tv9 = sum(tv8, {0}); - fusion.addOutput(tv9); - - tv1->split(0, 5); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv1->setMemoryType(MemoryType::Shared); - tv2->split(0, 6); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - tv2->setMemoryType(MemoryType::Shared); - tv3->split(0, 7); - tv3->axis(-1)->parallelize(ParallelType::TIDz); - - tv9->split(0, 4); - tv4->computeAt(tv9, 1); - - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - tv6->axis(-1)->parallelize(ParallelType::TIDz); - tv7->axis(-1)->parallelize(ParallelType::TIDz); - tv8->axis(-1)->parallelize(ParallelType::TIDz); - tv9->axis(-1)->parallelize(ParallelType::TIDz); - tv9->axis(0)->parallelize(ParallelType::BIDx); - - tv5->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - at::Tensor t4 = at::randn({19}, options); - std::vector aten_inputs = {t0, t4}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 3; - auto ref2 = sum(t4 + 4); - - testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); -} - -// Repro of #1102 and #1129 -TEST_F(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - auto tv4 = add(tv3, IrBuilder::create(1)); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - // Just to make TIDx/y/z non-exact - auto tvx = add(tv1, IrBuilder::create(1)); - auto tvy = add(tvx, IrBuilder::create(1)); - auto tvz = add(tvy, IrBuilder::create(1)); - fusion.addOutput(tvz); - - tv5->split(0, 4); - tv0->computeAt(tv5, 1); - - tv0->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - tv3->axis(-1)->parallelize(ParallelType::TIDz); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv5->axis(-1)->parallelize(ParallelType::TIDy); - tv5->axis(0)->parallelize(ParallelType::Unswitch); - - tvx->split(0, 5); - tvx->axis(-1)->parallelize(ParallelType::TIDx); - tvy->split(0, 6); - tvy->axis(-1)->parallelize(ParallelType::TIDy); - tvz->split(0, 7); - tvz->axis(-1)->parallelize(ParallelType::TIDz); - - for (auto tv : {tv2, tv3, tv4, tvx, tvy}) { - tv->setMemoryType(MemoryType::Shared); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({17}, options); - at::Tensor t1 = at::randn({19}, options); - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref1 = t0 + 4; - auto ref2 = t1 + 3; - - testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); -} - -// Repro of issue #1136 -TEST_F(NVFuserTest, FusionFloatPow_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(4)); - // To check if pow(tv0, 2) is replaced with tv0 * tv0 - auto tv2 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); - // To check if pow(tv0, 2.0) is replaced with tv0 * tv0 - auto tv3 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); - auto tv4 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); - auto tv5 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); - auto s = binaryOp( - BinaryOpType::Pow, - IrBuilder::create(3), - IrBuilder::create(3)); - auto tv6 = add(tv0, s); - - fusion.addOutput(tv1); - fusion.addOutput(tv2); - fusion.addOutput(tv3); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - tv1->split(0, 32); - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - - TransformPropagatorWithCheck propagator(tv1); - MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({1000}, options); - // Negative inputs cause nan in Fuesr as use_fast_math is enabled - t0 = abs(t0); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto p4 = at::pow(t0, 4); - auto p2 = at::pow(t0, 2); - auto p3 = at::pow(t0, 3); - auto t6 = t0 + std::pow(3, 3); - - testValidate( - &fusion, - outputs, - aten_inputs, - {p4, p2, p2, p3, p3, t6}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1127_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - const int numel = 4; - - auto tv0 = makeConcreteTensor({numel}); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = broadcast(tv1, {true}); - - auto tv3 = makeConcreteTensor({numel, numel}); - fusion.addInput(tv3); - - auto tv4 = sum(tv3, {1}); - - auto tv5 = add(tv2, tv4); - fusion.addOutput(tv5); - - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv4->axis(1)->parallelize(ParallelType::TIDx); - tv5->axis(0)->parallelize(ParallelType::TIDx); - - // Lowering should fail since tv5 is predicated and paralellized with TIDx. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fusion.printKernel()); -} - -TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { - // This test may not pass if using a custom block sync as there may - // be additional calls. Skip the test as it's not specifically - // relevant with block synchronizatin. - if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { - return; - } - auto g = std::make_shared(); - const auto graph0_string = R"IR( - graph(%0 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]), - %1 : Half(8, 4, 10, 16, strides=[640, 160, 16, 1])): - %o.1 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]) = aten::mul(%0, %1) # sum_dyn.py:5:6 - %3 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]) = aten::relu(%o.1) # sum_dyn.py:6:9 - return (%3))IR"; - parseIR(graph0_string, g.get()); - - // strides are not yet supported in the irparser. - { - auto val = g->block()->inputs()[0]; - val->setType(val->type()->castRaw()->withSizesStrides( - {8, 4, 10, 16}, {640, 1, 64, 4})); - } - - { - auto val = g->block()->inputs()[1]; - val->setType(val->type()->castRaw()->withSizesStrides( - {8, 4, 10, 16}, {640, 160, 16, 1})); - } - - for (auto node : g->block()->nodes()) { - for (auto val : node->outputs()) { - if (val->isCompleteTensor()) - val->setType(val->type()->castRaw()->withSizesStrides( - {8, 4, 10, 16}, {640, 1, 64, 4})); - } - } - - auto fusion = parseJitIR(g); - FusionGuard fg(fusion.get()); - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor input0 = - at::randn({2, 2, 2, 16}, options).clone(c10::MemoryFormat::ChannelsLast); - at::Tensor input1 = at::randn({2, 2, 2, 16}, options); - auto lparams = schedulePointwise(fusion.get(), {input0, input1}); - - // CONSIDER: - // 1. this can be moved to a dedicated "golden" file - // 2. use a fuzzy compare (ignore non-significant whitespaces for example) - const std::string expected_kernel = R"( -__global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { - int64_t i171; - i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); - if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { - __half T9[1]; - T9[0] = 0; - T9[0] - = T2[((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * ((T0.size[2] * T0.size[1]) * T0.size[3])) + ((((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * (T0.size[2] * T0.size[1])) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * T0.size[2]) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3])]; - __half T8[1]; - T8[0] = 0; - T8[0] - = T0[i171]; - float T3[1]; - T3[0] - = __half2float(T9[0]); - float T4[1]; - T4[0] - = T3[0]; - float T1[1]; - T1[0] - = __half2float(T8[0]); - float T5[1]; - T5[0] - = T1[0] - * T4[0]; - float T6[1]; - T6[0] - = relu(T5[0]); - __half T10[1]; - T10[0] - = __float2half(T6[0]); - T7[i171] - = T10[0]; - } -} -)"; - - const std::string actual_kernel = - "\n" + codegen::generateCudaKernel(GpuLower(fusion.get()).kernel()); - - if (expected_kernel.size() != actual_kernel.size() || - expected_kernel.compare(actual_kernel) != 0) { - std::cerr - << " Codegen mismatch, codegen possibly changed, or is incorrect. " - << " \n ========= EXPECTED ========= \n" - << expected_kernel << "\n========= ACTUAL ========== \n" - << actual_kernel << "\n=================" << std::endl; - auto it = std::mismatch( - expected_kernel.begin(), - expected_kernel.end(), - actual_kernel.begin(), - actual_kernel.end()); - std::string actual_mismatched_snippet(it.second, actual_kernel.end()); - actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10); - std::string expected_mismatched_snippet(it.first, expected_kernel.end()); - expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10); - std::cerr << "First mismatch found at: " << actual_mismatched_snippet - << ", expected: " << expected_mismatched_snippet << std::endl; - TORCH_CHECK(false); - } - - // TODO: runFusion hits assertion. I'm probably doing something wrong here. - // FusionExecutor fe; - // fe.compileFusion(fusion.get()); - // auto outputs = fe.runFusion({input0, input1}, lparams); - // at::Tensor output_ref = (input0 * input1).relu(); - // TORCH_CHECK(output_ref.equal(outputs[0])); -} - -TEST_F(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({10, 1024}); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->computeAt(tv3, -1); - tv3->axis(0)->parallelize(ParallelType::Unswitch); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10, 1024}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref = sum(t0, {1}) + 2; - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionNonContigOutputs_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - fusion.addOutput(tv1); - - tv1->setContiguity(false); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_input = at::randn({10}, options); - at::Tensor at_output = at::empty_strided({10}, {2}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_input}); - auto returned_outputs = fe.runFusion({at_input}, {at_output}); - - // Returned outputs should only contain one tensor that is the same - // as the output tensor given to runFusion - TORCH_CHECK(returned_outputs.size() == 1); - TORCH_CHECK(returned_outputs[0].is_same(at_output)); - TORCH_CHECK(!returned_outputs[0].is_contiguous()); - - auto at_ref = at_input + 1; - - testValidate(&fusion, {at_output}, {at_input}, {at_ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTestWarpSoftMax_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Setup softmax fusion - auto input = makeContigTensor(2); - fusion.addInput(input); - auto output = softmax(input, 1); - fusion.addOutput(output); - - // Setup runtime input - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_input = at::randn({8, 16 * 197}, options); - std::vector aten_inputs({aten_input}); - - // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs, true); - TORCH_CHECK(SchedulerEntry::canSchedule( - ScheduleHeuristic::Persistent, &fusion, runtime_info)); - auto scheduler = SchedulerEntry::makeEntry( - ScheduleHeuristic::Persistent, &fusion, runtime_info); - scheduler->schedule(&fusion); - - // Modify the schedule to use warp reduction - auto used_vals = fusion.usedMathVals(); - for (auto tv : ir_utils::filterByType(used_vals)) { - for (IterDomain* id : tv->domain()->domain()) { - if (id->getParallelType() == ParallelType::TIDx) { - id->padToMultipleOfWarp(); - } - } - } - - // Test result - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - auto ref_output = at::_softmax(aten_input, 1, false); - testValidate(&fusion, outputs, aten_inputs, {ref_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1133_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {1}); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - tv0->computeAt(tv3, 1); - - const int split_factor = 32; - - tv2->split(-1, split_factor); - tv1->computeAt(tv2, -2); - - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::Unswitch); - - tv1->setMemoryType(MemoryType::Shared); - tv2->setMemoryType(MemoryType::Shared); - - // Both tv1 and tv2 should be allocated at the top-level scope - GpuLower gpulw(&fusion); - bool tv1_validated = false; - bool tv2_validated = false; - for (const auto& kir_node : gpulw.kernel()->topLevelExprs()) { - if (auto alloc = dynamic_cast(kir_node)) { - auto size = alloc->size(); - if (!(alloc->buffer()->name() == 1 || alloc->buffer()->name() == 2)) { - // There should be no allocation other than those for tv1 and tv2 - TORCH_CHECK(false, "Invalid allocation detected"); - } - TORCH_CHECK(size->isA(), "Invalid allocation size"); - TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); - auto size_int = size->as()->value().value(); - if (alloc->buffer()->name() == 1) { - TORCH_CHECK( - size_int == split_factor, - "Invalid allocation size: ", - size->as()->value().value()); - tv1_validated = true; - } else { - TORCH_CHECK( - size_int == 1, - "Invalid allocation size: ", - size->as()->value().value()); - tv2_validated = true; - } - } - } - - TORCH_CHECK(tv1_validated, "Failed to validate tv1 allocation"); - TORCH_CHECK(tv2_validated, "Failed to validate tv2 allocation"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({99, 101}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref = (t0 + 1).sum({1}) + 1; - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionRfactorContigIDs_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - fusion.addOutput(tv1); - - tv1->split(1, 32); - - auto tv2 = tv1->rFactor({1}); - - // This merged domain is not contiguous. - tv2->merge(0, 2); - - tv2->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({99, 101}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto ref = t0.sum({1}); - - testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = sum(tv1, {1}); - auto tv3 = broadcast(tv2, {false, true}); - auto tv4 = set(tv1); - auto tv5 = add(tv3, tv4); - fusion.addOutput(tv5); - - auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); - - auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { - return std::find(vec.begin(), vec.end(), tv) != vec.end(); - }; - - auto tvEntryInVecVec = [](std::vector>& vec_o_vec, - std::vector& buffer_vec, - TensorView* tv) { - auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); - return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); - }; - - auto& buffers = persistent_buffer_info.persistent_buffers; - auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; - auto& projectable = persistent_buffer_info.projectable_persistent_buffers; - auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; - - TORCH_INTERNAL_ASSERT(buffers.size() == 1); - TORCH_INTERNAL_ASSERT(resolution.size() == 1 && resolution[0].size() == 1); - TORCH_INTERNAL_ASSERT(projectable.size() == 1); - TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); - - TORCH_INTERNAL_ASSERT(isTvWithinVec(buffers, tv1)); - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable, tv1)); - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); - - auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); - TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) - - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv5)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor aten_t0 = at::randn({99, 101}, options); - - // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); - auto persistent_buffer_size = - persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); - - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.persistent_buffer_size == - static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.projected_persistent_buffer_size == - static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); -} - -TEST_F(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2, DataType::Half); - fusion.addInput(tv0); - - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = sum(tv1, {1}); - auto tv3 = broadcast(tv2, {false, true}); - auto tv4 = set(tv1); - auto tv5 = add(tv3, tv4); - auto tv6 = castOp(DataType::Half, tv5); - fusion.addOutput(tv6); - - auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); - - auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { - return std::find(vec.begin(), vec.end(), tv) != vec.end(); - }; - - auto tvEntryInVecVec = [](std::vector>& vec_o_vec, - std::vector& buffer_vec, - TensorView* tv) { - auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); - return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); - }; - - auto& buffers = persistent_buffer_info.persistent_buffers; - auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; - auto& projectable = persistent_buffer_info.projectable_persistent_buffers; - auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; - - TORCH_INTERNAL_ASSERT(buffers.size() == 1); - TORCH_INTERNAL_ASSERT(resolution.size() == 1 && resolution[0].size() == 1); - TORCH_INTERNAL_ASSERT(projectable.size() == 1); - TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); - - TORCH_INTERNAL_ASSERT(isTvWithinVec(buffers, tv1)); - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable, tv1)); - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); - - auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); - TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) - - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv5)); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor aten_t0 = at::randn({99, 101}, options); - - // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); - auto persistent_buffer_size = - persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); - - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.persistent_buffer_size == - static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.projected_persistent_buffer_size == - static_cast(aten_t0.size(1) * dataTypeSize(DataType::Half))); -} - -TEST_F(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2, DataType::Half); - fusion.addInput(tv0); - - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = set(tv1); - auto tv3 = sum(tv2, {1}); - auto tv4 = broadcast(tv3, {false, true}); - - auto tv5 = makeSymbolicTensor(2, DataType::Half); - fusion.addInput(tv5); - - auto tv6 = castOp(DataType::Float, tv5); - - auto tv7 = add(tv6, tv4); - auto tv8 = set(tv1); - auto tv9 = add(tv7, tv8); - auto tv10 = sum(tv9, {1}); - auto tv11 = broadcast(tv10, {false, true}); - auto tv12 = set(tv7); - auto tv13 = add(tv12, tv11); - - fusion.addOutput(tv13); - - auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); - - auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { - return std::find(vec.begin(), vec.end(), tv) != vec.end(); - }; - - auto tvEntryInVecVec = [](std::vector>& vec_o_vec, - std::vector& buffer_vec, - TensorView* tv) { - auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); - return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); - }; - - auto& buffers = persistent_buffer_info.persistent_buffers; - auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; - auto& projectable = persistent_buffer_info.projectable_persistent_buffers; - auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; - - TORCH_INTERNAL_ASSERT(buffers.size() == 2); - TORCH_INTERNAL_ASSERT( - resolution.size() == 2 && resolution[0].size() == 1 && - resolution[1].size() == 1); - TORCH_INTERNAL_ASSERT(projectable.size() == 1); - TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); - - TORCH_INTERNAL_ASSERT( - isTvWithinVec(buffers, tv1) && isTvWithinVec(buffers, tv7)); - TORCH_INTERNAL_ASSERT( - isTvWithinVec(projectable, tv1) && !isTvWithinVec(projectable, tv7)); - - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); - - auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); - TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv9)); - - auto tv7_resolution_it = tvEntryInVecVec(resolution, buffers, tv7); - TORCH_INTERNAL_ASSERT(tv7_resolution_it != resolution.end()) - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv7_resolution_it, tv13)); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor aten_t0 = at::randn({99, 101}, options); - at::Tensor aten_t5 = at::randn({99, 101}, options); - - // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0, aten_t5}, true); - auto persistent_buffer_size = - persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); - - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.persistent_buffer_size == - static_cast( - aten_t0.size(1) * dataTypeSize(DataType::Float) * 2)); - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.projected_persistent_buffer_size == - static_cast( - aten_t0.size(1) * - (dataTypeSize(DataType::Half) + dataTypeSize(DataType::Float)))); -} - -TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2, DataType::Half); - fusion.addInput(tv0); - - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = set(tv1); - auto tv3 = sum(tv2, {1}); - auto tv4 = broadcast(tv3, {false, true}); - auto tv5 = set(tv1); - auto tv6 = add(tv4, tv5); - auto tv7 = set(tv2); - auto tv8 = add(tv7, tv6); - auto tv9 = castOp(DataType::Half, tv8); - - fusion.addOutput(tv9); - - auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); - - auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { - return std::find(vec.begin(), vec.end(), tv) != vec.end(); - }; - - auto tvEntryInVecVec = [](std::vector>& vec_o_vec, - std::vector& buffer_vec, - TensorView* tv) { - auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); - return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); - }; - - auto& buffers = persistent_buffer_info.persistent_buffers; - auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; - auto& projectable = persistent_buffer_info.projectable_persistent_buffers; - auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; - - TORCH_INTERNAL_ASSERT(buffers.size() == 2); - TORCH_INTERNAL_ASSERT( - resolution.size() == 2 && resolution[0].size() == 1 && - resolution[1].size() == 1); - - TORCH_INTERNAL_ASSERT(projectable.size() == 2); - TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); - - TORCH_INTERNAL_ASSERT( - isTvWithinVec(buffers, tv1) && isTvWithinVec(buffers, tv2)); - TORCH_INTERNAL_ASSERT( - isTvWithinVec(projectable, tv1) && isTvWithinVec(projectable, tv2)); - - TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); - - auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); - TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv6)); - - auto tv2_resolution_it = tvEntryInVecVec(resolution, buffers, tv2); - TORCH_INTERNAL_ASSERT(tv2_resolution_it != resolution.end()) - TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv2_resolution_it, tv8)); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor aten_t0 = at::randn({99, 101}, options); - - // Schedule through magic scheduler - SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); - auto persistent_buffer_size = - persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); - - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.persistent_buffer_size == - static_cast( - aten_t0.size(1) * dataTypeSize(DataType::Float) * 2)); - - TORCH_INTERNAL_ASSERT( - persistent_buffer_size.projected_persistent_buffer_size == - static_cast(aten_t0.size(1) * dataTypeSize(DataType::Half))); -} - -TEST_F(NVFuserTest, FusionPersistentBufferProjection_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2, DataType::Half); - fusion.addInput(tv0); - - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = set(tv1); - auto tv3 = sum(tv2, {1}); - auto tv4 = broadcast(tv3, {false, true}); - auto tv5 = set(tv1); - auto tv6 = add(tv4, tv5); - auto tv7 = set(tv2); - auto tv8 = add(tv7, tv6); - auto tv9 = castOp(DataType::Half, tv8); - - fusion.addOutput(tv9); - - reduction_scheduler_utils::projectPersistentBuffers(&fusion); - - auto tv5_producers = ir_utils::producerTvsOf(tv5); - auto tv7_producers = ir_utils::producerTvsOf(tv7); - - // Projection should have broken these dependencies - - TORCH_INTERNAL_ASSERT( - std::find(tv5_producers.begin(), tv5_producers.end(), tv1) == - tv5_producers.end()); - TORCH_INTERNAL_ASSERT( - std::find(tv7_producers.begin(), tv7_producers.end(), tv2) == - tv7_producers.end()); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor aten_t0 = at::randn({99, 101}, options); - - FusionExecutorCache fec(std::move(fusion_ptr)); - auto cg_outputs = fec.runFusionWithInputs({aten_t0}); - - auto aten_t1 = aten_t0.to(c10::kDouble); - auto aten_t3 = aten_t1.sum({1}); - auto aten_t4 = aten_t3.unsqueeze(1); - auto aten_t7 = aten_t4.add(aten_t1).add(aten_t1); - - testValidate(&fusion, cg_outputs, {aten_t0}, {aten_t7}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1223_CUDA) { - if (!deviceMajorMinorCheck(7)) { - GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; - return; - } - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {0, 1}); - fusion.addOutput(tv2); - - auto tv3 = add(tv0, IrBuilder::create(0)); - fusion.addOutput(tv3); - - tv2->split(0, 4); - tv2->split(1, 1, false); - tv2->split(-1, 4); - - tv2->axis(1)->parallelize(ParallelType::Unswitch); - tv2->axis(-3)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDy); - - tv1->computeAt(tv2, -1); - - // Make TIDx and TIDy non-exact - tv3->split(0, 32); - tv3->split(-1, 32); - tv3->axis(1)->parallelize(ParallelType::TIDx); - tv3->axis(3)->parallelize(ParallelType::TIDy); - - // The second axis of both tv1 and tv2 are fully unswitched, so they - // don't need to predicate the parallel type usage of TIDy, whereas - // the first axis is only partially unswitched, i.e., part of its - // split output domains is outside the unswitched axis, so the first - // axis, which uses TIDx, needs to predicate the parallel - // dimension. Previously, as reported in issue #1223, unswitched - // expressions didn't predicate parallel dimensions. It should be - // fixed by PR #1222. - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_t0 = at::ones({11, 10}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_t0}); - auto cg_outputs = fe.runFusion({at_t0}); - - auto at_t1 = (at_t0 + 1).sum(); - - testValidate( - &fusion, cg_outputs, {at_t0}, {at_t1, at_t0}, __LINE__, __FILE__); -} - -// See #1247 and #1250 -TEST_F(NVFuserTest, FusionRfactorPredication1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = min(tv1, {0}); - - fusion.addOutput(tv2); - - // Make TIDx non-exact - auto tv3 = makeContigTensor(1); - fusion.addInput(tv3); - - auto tv4 = add(tv3, IrBuilder::create(1)); - fusion.addOutput(tv4); - - tv2->split(0, 4); - auto tv5 = tv2->rFactor({1}); - - tv0->computeAt(tv2, 1); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - - tv4->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_t0 = at::randn({9}, options); - at_t0 = at::abs(at_t0); - at::Tensor at_t3 = at::randn({128}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_t0, at_t3}); - auto cg_outputs = fe.runFusion({at_t0, at_t3}); - - auto at_t2 = (at_t0 + 1).min(); - auto at_t4 = at_t3 + 1; - - testValidate( - &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionRfactorPredication2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = min(tv0, {0}); - fusion.addOutput(tv1); - - // Make TIDx non-exact - auto tv2 = makeContigTensor(1); - fusion.addInput(tv2); - - auto tv3 = add(tv2, IrBuilder::create(1)); - fusion.addOutput(tv3); - - tv1->split(0, 4); - auto tv4 = tv1->rFactor({0}); - - tv1->split(0, 3); - - // tv0->computeAt(tv1, 3); - tv4->reorder({{0, 1}}); - tv4->split(0, 3); - tv4->setMemoryType(MemoryType::Shared); - - // tv0: [I] - // tv4: [4/3, 3, I/4] - // tv1: [4/3, 3] - - tv1->axis(0)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv1, {tv4}); - - tv3->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor at_t0 = at::randn({9}, options); - at_t0 = at::abs(at_t0); - at::Tensor at_t3 = at::randn({128}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_t0, at_t3}); - auto cg_outputs = fe.runFusion({at_t0, at_t3}); - - auto at_t2 = std::get<0>(at_t0.min(0)); - auto at_t4 = at_t3 + 1; - - testValidate( - &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionRfactorIndirectRoot_CUDA) { - // https://github.com/csarofeen/pytorch/issues/1692 - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(3); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1, 2}); - fusion.addOutput(tv1); - - tv1->split(2, 4); - tv1->split(1, 3); - tv1->merge(2, 3); - auto rf = tv1->rFactor({-1}); - - tv1->split(0, 256); - tv1->axis(0)->parallelize(ParallelType::BIDx); - tv1->axis(1)->parallelize(ParallelType::TIDx); - rf->computeAt(tv1, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - - auto at_in = at::randn({6, 6, 6}, options); - auto at_out = at_in.sum({1, 2}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {at_in}); - auto cg_outputs = fe.runFusion({at_in}); - - testValidate(&fusion, cg_outputs, {at_in}, {at_out}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - // [I] - tv1->split(0, 5); - // [ceilDiv(I, 5), 5] - - // This second split is non-divisible. The split domain must be predicated. - tv1->split(1, 3); - // [ceilDiv(I, 5), 2, 3] - - auto tv2 = sum(tv0, {0}); - fusion.addOutput(tv2); - - // tv2 shouldn't need to have another predicate - tv2->split(0, 4); - tv2->split(1, 2); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), - "There must be no split to validate"); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, - "Only tv1 should have a non-divisible predicate."); - for (auto tv : {loweredTv(tv1, gpulw)}) { - auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); - TORCH_CHECK( - it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), - "No info found for ", - tv); - const auto& splits_to_predicate = it->second; - TORCH_CHECK( - splits_to_predicate.size() == 1, - "There must be one split to predicate"); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({24}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0.sum(); - - testValidate(&fusion, cg_outputs, {t0}, {ref, ref}, __LINE__, __FILE__); -} - -// Repro of issue #1074 -TEST_F(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - tv2->split(0, 2); - tv2->split(-1, 4); - tv2->reorder({{1, 2}, {2, 1}}); - tv0->computeAt(tv2, 2); - - tv2->split(-1, 3); - - // To make the sanitizer catch the invalid accesses. Not necessary - // to expose the bug. - tv1->setMemoryType(MemoryType::Shared); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), - "There must be no split to validate"); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, - "Only tv2 should have a non-divisible predicate."); - for (auto tv : {loweredTv(tv2, gpulw)}) { - auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); - TORCH_CHECK( - it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), - "No info found for ", - tv); - const auto& splits_to_predicate = it->second; - TORCH_CHECK( - splits_to_predicate.size() == 1, - "There must be one split to predicate"); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({13, 17}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 2; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Similar to FusionNonDivisibleSplit1 but with unswitch -TEST_F(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {0}); - fusion.addOutput(tv2); - - tv2->split(0, 5); - tv2->split(1, 3); - - tv0->computeAt(tv2, -1); - - tv2->axis(0)->parallelize(ParallelType::Unswitch); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), - "There must be no split to validate"); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, - "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { - auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); - TORCH_CHECK( - it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), - "No info found for ", - tv); - const auto& splits_to_predicate = it->second; - TORCH_CHECK( - splits_to_predicate.size() == 1, - "There must be one split to predicate"); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({24}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = (t0 + 1).sum(); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Non-divisible split through merge -TEST_F(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {0, 1}); - fusion.addOutput(tv2); - - tv2->split(0, 5); - tv2->merge(1, 2); - tv2->split(1, 3); - - tv0->computeAt(tv2, -1); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), - "There must be no split to validate"); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, - "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { - auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); - TORCH_CHECK( - it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), - "No info found for ", - tv); - const auto& splits_to_predicate = it->second; - TORCH_CHECK( - splits_to_predicate.size() == 1, - "There must be one split to predicate"); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({24, 2}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = (t0 + 1).sum(); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Nested splits -TEST_F(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = sum(tv1, {0}); - fusion.addOutput(tv2); - - // [I] - tv2->split(0, 8); - // [I/8, 8] - tv2->split(1, 2); - // [I/8, 4, 2] - tv2->split(1, 3); // non-divisible split of outer output - // [I/8, 2, 3, 2] - - tv0->computeAt(tv2, -1); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), - "There must be no split to validate"); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, - "Both tv1 and tv2 should have a non-divisible predicate."); - for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { - auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); - TORCH_CHECK( - it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), - "No info found for ", - tv); - const auto& splits_to_predicate = it->second; - TORCH_CHECK( - splits_to_predicate.size() == 1, - "There must be one split to predicate"); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({24}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = (t0 + 1).sum(); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Vectorized non-divisible split. Must be validated at run time -TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - fusion.addOutput(tv1); - - tv1->split(0, 8, false); - tv1->split(1, 4); - - tv1->axis(-1)->parallelize(ParallelType::Vectorize); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().size() == 1, - "There should be one split to validate"); - for (const auto& kv : gpulw.nonDivisibleSplitInfo().splitsToPredicate()) { - const auto& splits_to_predicate = kv.second; - TORCH_CHECK( - splits_to_predicate.empty(), - "There must be no split to predicate, but tensor t", - kv.first->name(), - " has:", - splits_to_predicate); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); - - auto t0_non_divisible = at::randn({8}, options); - // Since ceilDiv(8, 8) is not divisible by 4, the vectorization is - // illegal. The run-time validation of vectorization should throw an error. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0_non_divisible})); -} - -// If a split is validated at run time, it's not necessary to predicate. -TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = sum(tv2, {0}); - fusion.addOutput(tv3); - - tv3->split(0, 8, false); - tv3->split(1, 4); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv3->axis(1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2}); - - tv1->axis(2)->parallelize(ParallelType::Vectorize); - - GpuLower gpulw(&fusion); - TORCH_CHECK( - gpulw.nonDivisibleSplitInfo().splitsToValidate().size() == 1, - "There should be one split to validate"); - for (const auto& kv : gpulw.nonDivisibleSplitInfo().splitsToPredicate()) { - const auto& splits_to_predicate = kv.second; - TORCH_CHECK( - splits_to_predicate.empty(), - "There must be no split to predicate, but tensor t", - kv.first->name(), - " has:", - splits_to_predicate); - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - - auto t0 = at::randn({1024}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = (t0 + 1).sum(); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1284Repro_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - std::vector input_shape_0 = {10, 20}; - std::vector input_shape_1 = {15}; - - TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); - TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); - fusion.addInput(in_0); - fusion.addInput(in_1); - - TensorView* out_0 = add(in_0, IrBuilder::create(0.f)); - TensorView* out_1 = add(in_1, IrBuilder::create(2.f)); - - fusion.addOutput(out_0); - fusion.addOutput(out_1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_in_0 = at::randn(input_shape_0, options); - at::Tensor at_in_1 = at::randn(input_shape_1, options); - std::vector aten_inputs = {at_in_0, at_in_1}; - - FusionExecutorCache fec(std::move(fusion_ptr)); - auto outputs = fec.runFusionWithInputs(aten_inputs); - - auto t1 = at_in_1 + 2; - - auto runtime = fec.getMostRecentKernelRuntime(); - TORCH_INTERNAL_ASSERT(runtime->isSegmented()); - TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); - - testValidate( - &fusion, outputs, {at_in_0, at_in_1}, {at_in_0, t1}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1284Repro2_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - std::vector input_shape_0 = {4, 4}; - std::vector input_shape_1 = {3, 4, 4}; - std::vector input_shape_2 = {2, 8, 4, 4}; - - TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); - TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); - TensorView* in_2 = makeSymbolicTensor(input_shape_2.size()); - - fusion.addInput(in_0); - fusion.addInput(in_1); - fusion.addInput(in_2); - - TensorView* out_0 = add(in_0, in_1); - TensorView* out_1 = add(in_0, in_2); - - fusion.addOutput(out_0); - fusion.addOutput(out_1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at_in_0 = at::randn(input_shape_0, options); - at::Tensor at_in_1 = at::randn(input_shape_1, options); - at::Tensor at_in_2 = at::randn(input_shape_2, options); - - std::vector aten_inputs = {at_in_0, at_in_1, at_in_2}; - - FusionExecutorCache fec(std::move(fusion_ptr)); - auto outputs = fec.runFusionWithInputs(aten_inputs); - - auto t0 = at_in_0 + at_in_1; - auto t1 = at_in_0 + at_in_2; - - auto runtime = fec.getMostRecentKernelRuntime(); - TORCH_INTERNAL_ASSERT(runtime->isSegmented()); - TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); - - testValidate( - &fusion, - outputs, - {at_in_0, at_in_1, at_in_2}, - {t0, t1}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionIssue1305Repro_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - auto t0 = makeContigTensor(1); - auto t1 = makeContigTensor(2); - - fusion.addInput(t0); - fusion.addInput(t1); - - auto t2 = broadcast(t0, {true, false}); - auto t3 = add(t1, t2); - auto t4 = add(t3, t2); - auto t5 = sum(t4, {1}); - auto t6 = broadcast(t5, {false, true}); - auto t7 = add(t3, t6); - - fusion.addOutput(t7); - - t3->computeAt(t7, -1, ComputeAtMode::MostInlined); - - TORCH_INTERNAL_ASSERT(t3->getComputeAtPosition() == 1); -} - -TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = add(tv1, IrBuilder::create(1.0)); - auto tv3 = set(tv2); - fusion.addOutput(tv3); - - tv1->setMemoryType(MemoryType::Shared); - - tv3->split(-1, 128); - tv3->split(-1, 32); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, 1); - - tv3->axis(-2)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3); - - tv1->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1000}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = add(tv1, IrBuilder::create(1.0)); - auto tv3 = set(tv2); - fusion.addOutput(tv3); - - tv3->split(-1, 128); - tv3->split(-1, 32); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, -1); - - tv3->axis(-2)->parallelize(ParallelType::BIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3); - - tv1->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1000}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = set(tv1); - auto tv3 = add(tv2, IrBuilder::create(1.0)); - fusion.addOutput(tv3); - - tv1->setMemoryType(MemoryType::Shared); - - tv3->split(-1, 128); - tv3->split(-1, 32); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, 1); - - // tv2 is invalid to double-buffer as its producer, tv1, is - // computed inside the double-buffering loop. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(tv2->doubleBuffer()); - - // Moving tv2 inner makes tv1 large enough to double-buffer tv2 - tv2->computeAt(tv3, 2); - - tv2->doubleBuffer(); - - tv3->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1000}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 2; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Double buffering smem to local and unswitch -TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = set(tv1); - auto tv3 = add(tv2, IrBuilder::create(1.0)); - fusion.addOutput(tv3); - - tv1->setMemoryType(MemoryType::Shared); - - tv3->split(-1, 128); - tv3->split(-1, 32); - tv3->split(-1, 8); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, 2); - tv2->computeAt(tv3, -1); - - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::Unswitch); - scheduler_utils::parallelizeAllLike(tv3); - - tv2->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1000}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 2; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Double buffering gmem to shared and unswitch -TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = add(tv1, IrBuilder::create(1.0)); - fusion.addOutput(tv2); - - tv1->setMemoryType(MemoryType::Shared); - - tv2->split(-1, 128); - tv2->split(-1, 32); - tv2->split(-1, 8); - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - tv0->computeAt(tv2, 2); - tv1->computeAt(tv2, -1); - - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::Unswitch); - scheduler_utils::parallelizeAllLike(tv2); - - tv1->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1000}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Double buffering smem to local and unroll -TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = set(tv1); - auto tv3 = add(tv2, IrBuilder::create(1.0)); - fusion.addOutput(tv3); - - tv1->setMemoryType(MemoryType::Shared); - - tv3->split(-1, 128); - tv3->split(-1, 16); - tv3->split(-2, 4); - tv3->split(-2, 2); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv0->computeAt(tv3, 1); - tv2->computeAt(tv3, -1); - - tv3->axis(2)->parallelize(ParallelType::Unroll); - tv3->axis(4)->parallelize(ParallelType::TIDx); - - tv2->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({199}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 2; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Double buffering and vectorize -TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = add(tv1, IrBuilder::create(1.0)); - fusion.addOutput(tv2); - - tv2->split(-1, 128); - tv2->split(-1, 4); - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - tv1->computeAt(tv2, 2); - - tv2->axis(-2)->parallelize(ParallelType::TIDx); - - tv1->axis(-1)->parallelize(ParallelType::Vectorize); - - tv1->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({200}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Multiple tensors to double-buffer -TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - auto tv1 = makeContigTensor(1); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv4->split(0, 32); - tv4->split(0, 4); - TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - - tv0->computeAt(tv4, 1); - tv1->computeAt(tv4, 1); - - tv4->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(tv4); - - tv2->doubleBuffer(); - tv3->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({100}, options); - auto t1 = at::randn({100}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Nested double buffering from gmem to smem and smem to register -TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1)); - auto out = tv1; - fusion.addOutput(out); - - auto tv2 = tv0->cacheAfter(); - auto tv3 = tv2->cacheAfter(); - - out->split(0, 32); - out->split(0, 4); - TransformPropagatorWithCheck propagator(out); - MaxRootDomainInfoSpanningTree(out).traverse(&propagator); - - tv2->setMemoryType(MemoryType::Shared); - - tv2->computeAt(out, 1); - tv3->computeAt(out, -1); - - out->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike(out); - - tv2->doubleBuffer(); - tv3->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1001}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0 + 1; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// FusionSmemBlockGemmCache + double buffering at both smem and local -TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Algorithm - TensorView* tv0 = makeSymbolicTensor(2); // (M, K) - TensorView* tv1 = makeSymbolicTensor(2); // (K, N) - TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) - TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) - TensorView* tv4 = mul(tv2, tv3); // M, K, N - TensorView* tv5 = sum(tv4, {1}); // M, R, N - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addOutput(tv5); - - TensorView* tv6 = tv5->cacheBefore(); - - // For smem double buffering - auto tv0_cache_local = tv0->cacheAfter(); - auto tv1_cache_local = tv1->cacheAfter(); - - // For register double buffering - auto tv0_cache_smem = tv0->cacheAfter(); - auto tv1_cache_smem = tv1->cacheAfter(); - - const int BSX = 32; - const int TSX = 8; - - // [M, K, N] - tv6->split(-1, BSX); - tv6->split(-1, TSX); - tv6->split(1, BSX); - tv6->split(0, BSX); - tv6->split(1, TSX); - // [M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] - tv6->reorder( - {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); - // [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] - - auto tv6_rf = tv6->rFactor({-1}); - - TransformPropagatorWithCheck propagator(tv6_rf); - MaxRootDomainInfoSpanningTree(tv6_rf).traverse(&propagator); - - tv0->computeAt(tv6, 3); - tv1->computeAt(tv6, 3); - - tv6_rf->computeAt(tv6, -1); - tv0_cache_local->computeAt(tv6_rf, -1); - tv1_cache_local->computeAt(tv6_rf, -1); - - tv0_cache_smem->setMemoryType(MemoryType::Shared); - tv1_cache_smem->setMemoryType(MemoryType::Shared); - - tv5->axis(0)->parallelize(ParallelType::BIDx); - tv5->axis(1)->parallelize(ParallelType::BIDy); - tv5->axis(-3)->parallelize(ParallelType::TIDy); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - - scheduler_utils::parallelizeAllLike(tv5); - - tv0_cache_local->doubleBuffer(); - tv1_cache_local->doubleBuffer(); - - tv0_cache_smem->doubleBuffer(); - tv1_cache_smem->doubleBuffer(); - - constexpr int M = 154, K = 45, N = 1524; - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({M, K}, options); - at::Tensor t1 = at::randn({K, N}, options); - at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); - - std::vector aten_inputs = {t0, t1}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); - // The smem cache write in this test case is redundant predicated, - // and also double buffered. Currently we are relying on WAR sync - // insertion to ensure ordering of double buffered tensor access. - // The check below makes sure that the sync is inserted so that the - // test isn't running on a race condition. - TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count > 0); -} - -TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { - std::vector mem_types = {MemoryType::Shared, MemoryType::Local}; - - for (auto mem_type : mem_types) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = set(tv1); - auto tv3 = set(tv2); - fusion.addOutput(tv3); - - tv1->setMemoryType(mem_type); - - tv3->split(-1, 4); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv1->computeAt(tv3, -2); - - tv2->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({15}, options); - FusionExecutor fe; - fe.compileFusion(&fusion); - - // This should throw an exception as the extent of t0 is not - // divisible by the vector width - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0})); - - auto t1 = at::randn({16}, options); - auto cg_outputs = fe.runFusion({t1}); - - auto ref = t1; - - testValidate(&fusion, cg_outputs, {t1}, {ref}, __LINE__, __FILE__); - } -} - -TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({10, 1}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({10, 20}); - fusion.addInput(tv1); - auto tv2 = makeConcreteTensor({10, 10}); - fusion.addInput(tv2); - - // Not concretized - auto tv3 = sum(tv2, {1}); - auto tv4 = broadcast(tv3, {false, true}); - auto tv5 = add(tv0, tv4); - fusion.addOutput(tv5); - - // Concretized - auto tv6 = sum(tv2, {1}); - auto tv7 = broadcast(tv6, {false, true}); - auto tv8 = add(tv1, tv7); - fusion.addOutput(tv8); - - for (auto tv : {tv3, tv4, tv5, tv6, tv7, tv8}) { - tv->axis(1)->parallelize(ParallelType::TIDx); - } - - GpuLower gpulw(&fusion); - TORCH_CHECK(!gpulw.concretizedBroadcastDomains().isConcretized( - loweredTv(tv4, gpulw)->axis(1))); - TORCH_CHECK(gpulw.concretizedBroadcastDomains().isConcretized( - loweredTv(tv7, gpulw)->axis(1))); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({10, 1}, options); - auto t1 = at::randn({10, 20}, options); - auto t2 = at::randn({10, 10}, options); - std::vector aten_inputs = {t0, t1, t2}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t5 = t0 + t2.sum({1}).unsqueeze(-1); - auto t8 = t1 + t2.sum({1}).unsqueeze(-1); - - testValidate(&fusion, outputs, aten_inputs, {t5, t8}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBroadcastConcretization2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0, 1}); - auto tv2 = broadcast(tv1, {true}); - auto tv3 = broadcast(tv2, {false, true}); - fusion.addOutput(tv3); - - // tv1 is thread-predicated with TIDx and TIDy - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv1->axis(1)->parallelize(ParallelType::TIDy); - // tv2 broadcasts along TIDx - tv2->axis(0)->parallelize(ParallelType::TIDx); - // tv3 broadcasts along TIDy - tv3->axis(0)->parallelize(ParallelType::TIDx); - tv3->axis(1)->parallelize(ParallelType::TIDy); - - // Both tv2 and tv3 broadcast along predicated TID dimensions, but - // since the broadcast domains are not concretized, there should be - // no actual parallel broadcast - - GpuLower gpulw(&fusion); - TORCH_CHECK( - !gpulw.kernel()->summary().has_block_broadcasts && - !gpulw.kernel()->summary().has_grid_broadcasts, - "There must be no parallel broadcast in this fusion"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({10, 11}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t3 = t0.sum().unsqueeze(-1).unsqueeze(-1); - - testValidate(&fusion, outputs, aten_inputs, {t3}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionBroadcastConcretization3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector input_shape({10, 4, 8}); - std::vector output_shape({8, 4, 1}); - - auto tv0 = makeConcreteTensor(input_shape); - fusion.addInput(tv0); - - auto tv2 = sum(tv0, {0}); - auto tv3 = set(tv2); - auto tv4 = - view(tv3, {input_shape.begin() + 1, input_shape.end()}, output_shape); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv5->axis(-1)->parallelize(ParallelType::TIDx); - - // The view op adds a broadcast domain in tv4, which is - // parallelized. Howver, it is never materialized, so there should - // be no parallel broadcast. - - GpuLower gpulw(&fusion); - TORCH_CHECK( - !gpulw.kernel()->summary().has_block_broadcasts && - !gpulw.kernel()->summary().has_grid_broadcasts, - "There must be no parallel broadcast in this fusion"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn(input_shape, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto outputs = fe.runFusion(aten_inputs); - - auto t5 = at::native::view(t0.sum(0), output_shape) + 1; - - testValidate(&fusion, outputs, aten_inputs, {t5}, __LINE__, __FILE__); -} - -// Merging non-broadcast and broadcast domains -// TODO: Fix use case see issue https://github.com/csarofeen/pytorch/issues/1418 -// validateParallelize does not pass. Even if it's skipped, -// generated code is invalid as blockBroadcast is not used. -#if 0 -TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = broadcast(tv1, {false, true}); - auto tv3 = add(tv2, tv0); - fusion.addOutput(tv3); - - tv1->axis(1)->parallelize(ParallelType::TIDx); - - tv2->merge(0, 1); - tv2->axis(0)->parallelize(ParallelType::TIDx); - // TODO: When set to shared memory, this kernel should be correct, but fails - // validation and when skipped produces incorrect code - tv2->setMemoryType(MemoryType::Shared); - - tv3->merge(0, 1); - tv3->axis(0)->parallelize(ParallelType::TIDx); - - fusion.printMath(); - fusion.printKernel(); -} -#endif - -TEST_F(NVFuserTest, FusionBroadcastConcretization5_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion.addInput(tv1); - auto tv2 = makeSymbolicTensor(1); - fusion.addInput(tv2); - auto tv3 = makeSymbolicTensor(1); - fusion.addInput(tv3); - - // Assert tv2 and tv3 have the same shape - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - // Concretize a broadcast domain to multiple non-concrete domains - // through a multi-output expression. It should be considered to be - // non-uniquely concretized. - auto tv5 = broadcast(tv0, {false, true}); - // Reduce only the non-broadcast domain. - auto tvs = Welford(tv5, {0}); - auto tv9 = add(tvs.avg, tv1); - auto tv10 = add(tvs.var_sum, tv2); - fusion.addOutput(tv9); - fusion.addOutput(tv10); - - // Same pattern as the above, but concretize the broadcast domain - // with tv2 and tv3, which have the exactly same shape, so the - // broadcast should be considered uniquely concretized. - auto tv11 = broadcast(tv0, {false, true}); - // Reduce only the non-broadcast domain. - auto tvs2 = Welford(tv11, {0}); - auto tv15 = add(tvs2.avg, tv2); - auto tv16 = add(tvs2.var_sum, tv3); - fusion.addOutput(tv15); - fusion.addOutput(tv16); - - // Reduce only the broadcast domain. Since it's reduced, it should - // not be considered to be concretized. - auto tv17 = broadcast(tv0, {false, true}); - auto tvs3 = Welford(tv17, {1}); - fusion.addOutput(tvs3.avg); - - ConcretizedBroadcastDomains bcast_concretization_info; - bcast_concretization_info.build(&fusion); - - TORCH_CHECK( - bcast_concretization_info.maybeNonUniquelyConcretized(tv5->axis(1)), - "Failed to detect non-unique concretization of ", - tv5->toString()); - - TORCH_CHECK( - bcast_concretization_info.isUniquelyConcretized(tv11->axis(1)), - "Failed to detect unique concretization of ", - tv11->toString()); - - TORCH_CHECK( - !bcast_concretization_info.isConcretized(tv17->axis(1)), - "Failed to detect non-concretization of ", - tv17->toString()); -} - -TEST_F(NVFuserTest, FusionIssue1430_CUDA) { - // Derived from an expression sorting issue when using loop map, now expr - // sorting uses parallel map. - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int V = 2, W = 3, X = 4, Y = 5, Z = 6; - - // setup fusion - auto tv0 = TensorViewBuilder() - .ndims(5) - .dtype(DataType::Half) - .contiguity(std::vector(5, true)) - .shape({V, W, X, Y, Z}) - .build(); - - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = castOp(DataType::Float, tv1); - - auto tvs = Welford(tv2, {1, 2, 3, 4}); - auto tv3 = tvs.avg; - auto tv4 = tvs.var_sum; - auto tv5 = tvs.n; - - // avg - auto tv6 = broadcast(tvs.avg, {false, true, true, true, true}); - - // var - auto tv7 = mul(tv4, IrBuilder::create(1. / (W * X * Y * Z))); - auto tv8 = add(tv7, IrBuilder::create(1.e-6)); - auto tv9 = broadcast(tv8, {false, true, true, true, true}); - auto tv10 = rsqrt(tv9); - - auto tv11 = castOp(DataType::Float, tv1); - auto tv12 = sub(tv11, tv6); - auto tv13 = mul(tv12, tv10); - - auto tv14 = set(tv13); - fusion.addOutput(tv14); - - tv3->axis(0)->parallelize(ParallelType::BIDy); - tv3->axis(2)->parallelize(ParallelType::BIDx); - tv3->axis(3)->parallelize(ParallelType::TIDx); - tv3->axis(4)->parallelize(ParallelType::Vectorize); - - // tv3->reorder({{1, -2}}); - - auto rfactor = ir_utils::rfactorHelper(tv3, {1, 4}); - - scheduler_utils::parallelizeAllLike(rfactor); - - for (auto tv : ir_utils::allTvs(&fusion)) { - if (tv != tv1 || tv != tv3) { - for (auto i : c10::irange(tv->nDims())) { - if (isParallelTypeVectorize(tv->axis(i)->getParallelType())) { - tv->axis(i)->parallelize(ParallelType::Serial); - } - } - } - } - - tv0->computeAt(tv14, 1); - tv13->computeAt(tv14, -2); - tv2->computeAt(tv14, -1, ComputeAtMode::MostInlined); - tv11->computeAt(tv14, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({V, W, X, Y, Z}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion); - auto cg_outputs = fe.runFusion({t0}, LaunchParams(X, V, -1, Y, -1, -1)); - - auto t0_double = t0.to(at::kDouble); - - auto at_mu = at::mean(t0_double, {1, 2, 3, 4}) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1); - auto at_var = at::var(t0_double, {1, 2, 3, 4}, false) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1); - - auto at_out = t0_double.sub(at_mu).div(at_var.add(1.e-6).sqrt()); - - testValidate( - &fusion, - cg_outputs, - {t0}, - {at_out}, - __LINE__, - __FILE__, - "", - LaunchParams(X, V, -1, Y, -1, -1)); -} - -// Test code generation of allocated scalars -TEST_F(NVFuserTest, FusionCodegenAllocatedScalars_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Fusion is just a dummy container in this test, just used for - // getting a Kernel container - auto tv0 = makeSymbolicTensor(0); - fusion.addInput(tv0); - auto tv1 = set(tv0); - fusion.addOutput(tv1); - - GpuLower gpulw(&fusion); - auto kernel = gpulw.kernel(); - - // Set the kernel as the current fusion - FusionGuard kg(kernel); - - // Create alocated scalars - auto ks0 = add(kernel->zeroVal(), kernel->oneVal()); - auto ks0_alloc = IrBuilder::create( - ks0, MemoryType::Local, kernel->oneVal()); - - auto ks1 = add(ks0, kernel->oneVal()); - auto ks1_alloc = IrBuilder::create( - ks1, MemoryType::Local, kernel->oneVal()); - - auto tk0 = kernel->inputs()[0]->as(); - auto tki0 = IrBuilder::create(tk0, std::vector{ks0}); - auto tki1 = IrBuilder::create(tk0, std::vector{ks1}); - auto tk0_expr = IrBuilder::create(UnaryOpType::Set, tki0, tki1); - - // Insert the scalar expression and the allocation of the - // output directly to the kernel - auto proxy = kir::KernelInternalProxy(kernel); - - const auto indent = " "; - const auto ks0_name = "i" + std::to_string(ks0->name()); - const auto ks1_name = "i" + std::to_string(ks1->name()); - const auto tk0_name = "T" + std::to_string(tk0->name()); - - auto& exprs = proxy.topLevelExprs(); - exprs.push_back(tk0_expr); - - // Invalid code gen - const auto no_alloc_code = codegen::generateCudaKernel(kernel); - - // Without alloc, Int vals are just inlined, resulting in: - // t0[(0 + 1)] = t0[((0 + 1) + 1)] - std::stringstream no_alloc_ref; - no_alloc_ref << "\n" - << indent << tk0_name << "[(0 + 1)]\n" - << indent << indent << " = " << tk0_name << "[((0 + 1) + 1)];\n"; - - TORCH_CHECK( - no_alloc_code.find(no_alloc_ref.str()) != std::string::npos, - "Invalid code generation. Expected:", - no_alloc_ref.str(), - "Actual:\n", - no_alloc_code); - - // Insert proper allocations and definitions - exprs.insert(std::find(exprs.begin(), exprs.end(), tk0_expr), ks0_alloc); - exprs.insert( - std::find(exprs.begin(), exprs.end(), tk0_expr), ks0->definition()); - exprs.insert(std::find(exprs.begin(), exprs.end(), tk0_expr), ks1_alloc); - exprs.insert( - std::find(exprs.begin(), exprs.end(), tk0_expr), ks1->definition()); - - const auto valid_code = codegen::generateCudaKernel(kernel); - - std::stringstream valid_ref; - valid_ref << "\n" - << indent << tk0_name << "[" << ks0_name << "]\n" - << indent << indent << " = " << tk0_name << "[" << ks1_name - << "];\n"; - - TORCH_CHECK( - valid_code.find(valid_ref.str()) != std::string::npos, - "Invalid code generation. Expected:", - valid_ref.str(), - "Actual:\n", - valid_code); -} - -TEST_F(NVFuserTest, FusionIndexHoist1_CUDA) { - if (isOptionDisabled(DisableOption::IndexHoist)) { - GTEST_SKIP() << "Index hoisting disabled"; - } - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = set(tv1); - auto tv3 = set(tv2); - auto tv4 = set(tv3); - auto tv5 = set(tv4); - fusion.addOutput(tv5); - - tv1->split(-1, 4); - tv2->split(-1, 4); - tv3->merge(0, 1); - tv3->split(0, 8); - tv5->merge(0, 1); - tv5->split(0, 8); - tv4->computeAt(tv5, -1); - - tv1->setMemoryType(MemoryType::Global); - tv2->setMemoryType(MemoryType::Global); - tv3->setMemoryType(MemoryType::Global); - - // Use Int32 as the index type to verify Int32 is used as the type - // of hoisted indices - GpuLower gpulw(&fusion, DataType::Int32); - auto kernel = gpulw.kernel(); - - auto is_index_times_ns = [](Val* val, Val* index, std::string name) -> bool { - auto def = dynamic_cast(val->definition()); - if (def == nullptr) { - return false; - } - return def->getBinaryOpType() == BinaryOpType::Mul && - def->rhs()->isA() && - def->rhs()->as()->name() == name && def->lhs() == index; - }; - - // Validate indices in the kernel are hoisted as - // intended. Validation could be also done by just string comparison - // as the parser test, but updating such tests would be tedious. - for (auto top_level_loop : - ir_utils::filterByType(kernel->topLevelExprs())) { - auto innermost_loop = top_level_loop; - while (auto first_expr_loop = dynamic_cast( - innermost_loop->body().exprs().at(0))) { - innermost_loop = first_expr_loop; - } - const auto& exprs = innermost_loop->body().exprs(); - TORCH_CHECK(!exprs.empty(), "No expression found"); - TORCH_CHECK( - exprs.at(0)->isA(), - "Invalid expression: ", - exprs.at(0)->toString()); - auto hoisted_index = exprs.at(0)->as()->buffer(); - TORCH_CHECK( - hoisted_index->dtype() == DataType::Int32, - "Invalid data type of hoisted indices. Should be Int32 but: ", - hoisted_index->dtype()); - kir::Predicate* pred = nullptr; - for (auto expr : exprs) { - if (expr->isA()) { - pred = expr->as()->predicate(); - auto arith_expr = expr->as()->thenBody().exprs().at(0); - auto out_ti = arith_expr->outputs()[0]->as(); - if (out_ti->view()->name() == 1) { - // Ref: T1[*, hoisted_index] = T0[*, hoisted_index * T0.stride]; - auto t1_index = out_ti->index(1); - TORCH_CHECK( - t1_index == hoisted_index, - "Invalid index: ", - t1_index->toInlineString()); - // Pred: hoisted_index < T0.size[1] - TORCH_CHECK( - pred->value()->definition()->as()->lhs() == - hoisted_index, - "Invalid predicate: ", - pred->value()->toInlineString()); - TORCH_CHECK(arith_expr->inputs().size() == 1); - auto in0 = arith_expr->inputs().front()->as(); - TORCH_CHECK(in0->view()->name() == 0); - // hoisted_index * T0.stride[1] - auto t0_index = in0->index(1); - TORCH_CHECK( - is_index_times_ns(t0_index, hoisted_index, "T0.stride[1]"), - "Invalid index: ", - t0_index->toInlineString()); - } else if (out_ti->view()->name() == 2) { - // Ref: T3[*, hoisted_index] = T2[*, hoisted_index]; - auto out_index = out_ti->index(1); - TORCH_CHECK( - out_index == hoisted_index, - "Invalid index: ", - out_index->toInlineString()); - TORCH_CHECK( - pred->value()->definition()->as()->lhs() == - hoisted_index, - "Invalid predicate: ", - pred->value()->toInlineString()); - TORCH_CHECK(arith_expr->inputs().size() == 1); - auto in0 = arith_expr->inputs().front()->as(); - TORCH_CHECK(in0->view()->name() == 1); - auto in0_index = in0->index(1); - TORCH_CHECK( - in0_index == hoisted_index, - "Invalid index: ", - in0_index->toInlineString()); - } else if (out_ti->view()->name() == 3) { - // Ref: T3[hoisted_index] = T2[hoisted_index]; - auto out_index = out_ti->index(0); - TORCH_CHECK( - out_index == hoisted_index, - "Invalid index: ", - out_index->toInlineString()); - TORCH_CHECK( - pred->value()->definition()->as()->lhs() == - hoisted_index, - "Invalid predicate: ", - pred->value()->toInlineString()); - TORCH_CHECK(arith_expr->inputs().size() == 1); - auto in0 = arith_expr->inputs().front()->as(); - TORCH_CHECK(in0->view()->name() == 2); - auto in0_index = in0->index(0); - TORCH_CHECK( - in0_index == hoisted_index, - "Invalid index: ", - in0_index->toInlineString()); - } else if (out_ti->view()->name() == 4) { - // Ref: T4[0] = T3[hoisted_index]; - TORCH_CHECK( - pred->value()->definition()->as()->lhs() == - hoisted_index, - "Invalid predicate: ", - pred->value()->toInlineString()); - TORCH_CHECK(arith_expr->inputs().size() == 1); - auto in0 = arith_expr->inputs().front()->as(); - TORCH_CHECK(in0->view()->name() == 3); - auto in0_index = in0->index(0); - TORCH_CHECK( - in0_index == hoisted_index, - "Invalid index: ", - in0_index->toInlineString()); - } else if (out_ti->view()->name() == 5) { - // Ref: T5[hoisted_index] = T4[0] - auto out_index = out_ti->index(0); - TORCH_CHECK( - out_index == hoisted_index, - "Invalid index: ", - out_index->toInlineString()); - TORCH_CHECK( - pred->value()->definition()->as()->lhs() == - hoisted_index, - "Invalid predicate: ", - pred->value()->toInlineString()); - } - } - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({15, 17}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Hoist indices for vectorized tensors -TEST_F(NVFuserTest, FusionIndexHoist2_CUDA) { - if (isOptionDisabled(DisableOption::IndexHoist)) { - GTEST_SKIP() << "Index hoisting disabled"; - } - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - auto tv1 = makeContigTensor(1); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = add(tv2, tv3); - auto tv5 = set(tv4); - fusion.addOutput(tv5); - - tv5->split(-1, 4); - TransformPropagatorWithCheck propagator(tv5); - MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); - - tv4->split(-1, 3); - - tv0->computeAt(tv5, 1); - tv1->computeAt(tv5, 1); - - tv2->axis(-1)->parallelize(ParallelType::Vectorize); - tv3->axis(-1)->parallelize(ParallelType::Vectorize); - tv5->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({16}, options); - auto t1 = at::randn({16}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionTestGridComm_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - int X = 3, Y = 4, Z = 2; - auto tv0 = makeConcreteTensor({X, Y, Z}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({X, Y, Z}); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = add(tv2, tv1); - auto tv4 = set(tv3); - auto tv5 = set(tv4); - fusion.addOutput(tv5); - - tv2->setMemoryType(MemoryType::Global); - tv3->setMemoryType(MemoryType::Global); - tv4->setMemoryType(MemoryType::Global); - - tv2->axis(0)->parallelize(ParallelType::BIDy); - tv2->axis(1)->parallelize(ParallelType::BIDx); - tv2->axis(2)->parallelize(ParallelType::Vectorize); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::BIDy); - - tv4->axis(0)->parallelize(ParallelType::BIDy); - tv4->axis(1)->parallelize(ParallelType::BIDx); - - tv5->axis(0)->parallelize(ParallelType::BIDy); - tv5->axis(1)->parallelize(ParallelType::BIDx); - tv5->axis(2)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({X, Y, Z}, options); - auto t1 = at::randn({X, Y, Z}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// See issue https://github.com/csarofeen/pytorch/issues/1497 -TEST_F(NVFuserTest, FusionTestGridComm2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int64_t W = 3, X = 4; - - auto tv0 = makeConcreteTensor({X}); - auto tv1 = makeConcreteTensor({W, X}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = add(tv0, IrBuilder::create(1)); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->split(0, 2); - - TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - - tv3->computeAt(tv4, 1); - - tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - - tv2->setMemoryType(MemoryType::Global); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({X}, options); - auto t1 = at::randn({W, X}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1 + 1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Vectorized reset test for double buffered registers -TEST_F(NVFuserTest, FusionDoubleBufferVector_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = sum(tv1, {0}); - auto tv2c = tv2->cacheBefore(); - - fusion.addOutput(tv2); - - auto tv1cw = tv1->cacheAfter(); - auto tv1cr = tv1cw->cacheAfter(); - - tv1cw->split(-1, 32); - tv1cr->split(-1, 32); - tv1cr->split(-1, 4); - tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); - - tv1cw->computeAt(tv1cr, 1); - tv0->computeAt(tv1cw, -1); - tv2c->split(-1, 32); - tv2c->split(-1, 4); - tv1cr->computeAt(tv2c, 2); - - tv1cw->setMemoryType(MemoryType::Shared); - tv1cr->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::manual_seed(0); - auto t0 = at::randn({200}, options); - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - auto ref = (t0 + 1).sum({0}); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Request 48KB of data in shared mem, -// should be large enough not to fit in -// static allocations, but small enough -// to fit in supported devices (sm70+). -TEST_F(NVFuserTest, FusionLargeSmem_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = add(tv1, IrBuilder::create(2.0)); - fusion.addOutput(tv2); - - tv2->split(0, 12288); - tv2->split(1, 128); - tv1->computeAt(tv2, 1); - tv1->split(1, 128); - tv0->computeAt(tv1, -1); - tv1->setMemoryType(MemoryType::Shared); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::manual_seed(0); - auto t0 = at::randn({12288 * 4}, options); - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - auto ref = t0 + 1 + 2; - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Request a smem allocation that is equal to the device limit -TEST_F(NVFuserTest, FusionTooLargeSmem_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto properties = at::cuda::getDeviceProperties( - c10::Device(c10::DeviceType::CUDA, 0).index()); - int device_limit = properties->sharedMemPerBlockOptin; - - auto tv0 = makeContigTensor(1); - fusion.addInput(tv0); - auto tv1 = add(tv0, IrBuilder::create(1.0)); - auto tv2 = add(tv1, IrBuilder::create(2.0)); - fusion.addOutput(tv2); - - // 4 byte per float - tv2->split(0, device_limit / 4); - tv2->split(1, 128); - tv1->computeAt(tv2, 1); - tv1->split(1, 128); - tv0->computeAt(tv1, -1); - tv1->setMemoryType(MemoryType::Shared); - tv1->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::manual_seed(0); - auto t0 = at::randn({12288 * 4}, options); - FusionExecutor fe; - - // First compile gets a compiled kernel - fe.compileFusion(&fusion, {t0}); - - // Should be throwing because the kernel - // requested absolute device limit - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0})); -} - -// Try to test alignment when multiple tensors are -// in shared mem. -TEST_F(NVFuserTest, FusionSmemAlignment_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({3, 4, 7, 2, 5}); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {4}); - auto tv2 = sum(tv1, {3}); - auto tv3 = sum(tv2, {2}); - auto tv4 = sum(tv3, {1}); - fusion.addOutput(tv4); - - auto tv0c = tv0->cacheAfter(); - auto tv1bc = tv1->cacheBefore(); - auto tv2bc = tv2->cacheBefore(); - auto tv3bc = tv3->cacheBefore(); - auto tv4bc = tv4->cacheBefore(); - - tv0c->setMemoryType(MemoryType::Shared); - tv1bc->setMemoryType(MemoryType::Shared); - tv2bc->setMemoryType(MemoryType::Shared); - tv3bc->setMemoryType(MemoryType::Shared); - tv4bc->setMemoryType(MemoryType::Shared); - - tv1->axis(-1)->parallelize(ParallelType::Vectorize); - tv3->axis(-1)->parallelize(ParallelType::Vectorize); - tv0->computeAt(tv4, 0); - tv0->computeAt(tv2, 2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::manual_seed(0); - auto t0 = at::randn({3, 4, 7, 2, 5}, options); - FusionExecutor fe; - - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - auto tref = t0.sum({1, 2, 3, 4}); - - testValidate(&fusion, cg_outputs, {t0}, {tref}, __LINE__, __FILE__); -} - -// Repro of #1521 -TEST_F(NVFuserTest, FusionImmediateValueAsInput_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto immediate_scalr = IrBuilder::create(0.1); - // Adding an immediate scalar value as an input is not allowed - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fusion.addInput(immediate_scalr)); - - // Instead, use a symbolic value - auto symbolic_scalar = IrBuilder::create(); - fusion.addInput(symbolic_scalar); - - auto tv1 = add(tv0, symbolic_scalar); - fusion.addOutput(tv1); - - // Make sure the kernel is compiled. - FusionExecutor fe; - fe.compileFusion(&fusion); -} - -// Repro of #1506 -TEST_F(NVFuserTest, FusionVectorizeContigIndex_CUDA) { - std::vector shape{14, 14}; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = set(tv1); - fusion.addOutput(tv2); - - tv2->merge(0); - - // Vectorize by 4 should be allowed - tv2->split(0, 4); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv0->computeAt(tv2, 1); - - tv1->axis(1)->parallelize(ParallelType::Vectorize); - tv2->axis(1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - TORCH_CHECK(t0.equal(cg_outputs[0])); -} - -// Make sure the same fusion as FusionVectorizeContigIndex fails if -// not contig. -TEST_F(NVFuserTest, FusionVectorizeContigIndexFail_CUDA) { - std::vector shape{14, 14}; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = set(tv1); - fusion.addOutput(tv2); - - tv2->merge(0); - - tv2->split(0, 4); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv0->computeAt(tv2, 1); - - tv1->axis(1)->parallelize(ParallelType::Vectorize); - tv2->axis(1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - - // This should fail at the launch time as 14 is not divisible by the - // vector word size. The two domains are merged, but they are not - // contiguous, so contig indexing is not involved in this case. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0})); -} - -TEST_F(NVFuserTest, FusionVectorizeInputToOutput_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = set(tv0); - fusion.addOutput(tv1); - - tv1->split(0, 4); - - tv1->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - - const int n = 12; - auto t0 = at::randn({n}, options); - // Shift by one to make it non-aligned - auto t0_misaligned = at::randn({n + 1}, options).index({Slice(1)}); - auto t1_misaligned = at::empty({n + 1}, options).index({Slice(1)}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - TORCH_CHECK(t0.equal(cg_outputs[0])); - - // Pass misaligned input. This must fail. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0_misaligned})); - - // Pass misaligned output. This must fail too. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0}, {t1_misaligned})); -} - -// Repro of issue #1530 -TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail_CUDA) { - std::vector shape{1, 2, 1}; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(shape.size()); - fusion.addInput(tv0); - auto tv1 = set(tv0); - fusion.addOutput(tv1); - - tv1->merge(1); - tv1->merge(0); - - auto invalid_vec_size = shape[0] * shape[1] * shape[2]; - invalid_vec_size *= invalid_vec_size; - - tv1->split(0, invalid_vec_size); - - tv1->axis(1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0})); -} - -TEST_F(NVFuserTest, FusionContigIndexingWithBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({4}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({3, 4}); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); - - tv3->merge(0); - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv2->setMemoryType(MemoryType::Local); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({4}, options); - auto t1 = at::randn({3, 4}, options); - - auto t3 = t0.unsqueeze(0).add(t1); - { - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - testValidate(&fusion, cg_outputs, {t0, t1}, {t3}, __LINE__, __FILE__); - } - - // Make sure tv2 indexing also works when it's stored in global memory - tv2->setMemoryType(MemoryType::Global); - { - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - testValidate(&fusion, cg_outputs, {t0, t1}, {t3}, __LINE__, __FILE__); - } -} - -// Repro of #1534. Validation should detect invalid vectorization. -TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail2_CUDA) { - std::vector shape1{2, 3, 2}; - std::vector shape2{2, 2}; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigConcreteTensor(shape1); - fusion.addInput(tv0); - auto tv1 = makeContigConcreteTensor(shape2); - fusion.addInput(tv1); - - auto tv2 = set(tv1); - auto tv3 = broadcast(tv2, {false, true, false}); - auto tv4 = add(tv0, tv3); - fusion.addOutput(tv4); - - tv4->merge(1, 2); - tv4->merge(0, 1); - tv4->split(0, 4); - TransformPropagatorWithCheck propagator(tv4); - MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); - - tv0->computeAt(tv4, -2); - tv1->computeAt(tv4, -2); - - tv2->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape1, options); - auto t1 = at::randn(shape2, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - // Vectorization of tv2 should be detected as invalid. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fe.runFusion({t0, t1})); -} - -TEST_F(NVFuserTest, FusionVectorizeContigIndexWithBroadcast_CUDA) { - std::vector shape1{2, 2, 2}; - std::vector shape2{1, 2, 2}; - - Fusion fusion; - FusionGuard fg(&fusion); - - // [I0, I1, I2] - auto tv0 = makeContigTensor(shape1.size()); - fusion.addInput(tv0); - - // [B3, I1, I2] - auto tv1 = makeContigConcreteTensor(shape2); - fusion.addInput(tv1); - - auto tv2 = set(tv1); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - tv3->merge(1, 2); - tv3->merge(0, 1); - tv3->split(0, 4); - - // Don't modify tv1 so that it's replayed as tv2 with actual - // transformations. It would create temporary IterDomains, and the - // validation should still be able to detect vectorization by 4 is valid. - // TransformPropagatorWithCheck propagator(tv3); - // MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - - tv2->merge(1, 2); - tv2->merge(0, 1); - tv2->split(0, 4); - - tv2->computeAt(tv3, -2); - - tv2->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape1, options); - auto t1 = at::randn(shape2, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionVectorizeContigIndexPointwiseSchedule_CUDA) { - std::vector shape0{100, 14, 2, 14}; - std::vector shape1{100, 2, 14}; - - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(shape0.size()); - fusion.addInput(tv0); - auto tv1 = makeContigTensor(shape1.size()); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv1, {false, true, false, false}); - auto tv3 = add(tv0, tv2); - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn(shape0, options); - auto t1 = at::randn(shape1, options); - - auto lparams = schedulePointwise(&fusion, {t0, t1}); - - GpuLower gpulw(&fusion); - auto kernel = gpulw.kernel(); - - // The innermost two dimensions are merged and contiguous, so - // vectorization can be done against 2*14=28 rather than 14, so - // vector word size should be 4. Broadcasting of tv1 should not - // matter. - for (const auto& vec_info : kernel->summary().vectorized_set_info) { - TORCH_CHECK( - vec_info.word_size == 4, - "Invalid vector word size: ", - vec_info.word_size); - } - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}, lparams); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1.unsqueeze(-3); - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Repro of issue #1539. -TEST_F(NVFuserTest, FusionTrivialReductionForwarding1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false}); - auto tv2 = sum(tv1, {0}); - auto tv3 = set(tv2); - fusion.addOutput(tv3); - - tv2->merge(0); - tv2->split(0, 4); - - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - // All tensors must be transformed to a 2D tensor with each axis - // mapped with each other in the LOOP map. - ComputeAtMap ca_map(&fusion); - for (auto tv : ir_utils::allTvs(&fusion)) { - TORCH_CHECK( - tv->nDims() == 2, "Expected to be a 2D tensor but: ", tv->toString()); - for (const auto i : c10::irange(2)) { - TORCH_CHECK(ca_map.areMapped( - tv->axis(i), tv3->axis(i), IdMappingMode::PERMISSIVE)); - } - } -} - -TEST_F(NVFuserTest, FusionTrivialReductionForwarding2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false}); - auto tv2 = sum(tv1, {0}); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - // Merging a trivial reduction with a non-reduction domain - tv2->merge(0, 1); - tv2->split(0, 4); - - tv3->split(0, 4); - - // tv2 and tv3 are different as tv3 lacks the trivial reduction, but - // they are mapped with each other by BestEffortReplay as the merge - // of trivial reduciton dim is forwarded. - - PairwiseRootDomainMap root_map(tv2, tv3); - - auto p2c = BestEffortReplay::replayCasP(tv3, tv2, 2, root_map).getReplay(); - for (const auto i : c10::irange(tv2->nDims())) { - auto tv2_id = tv2->axis(i); - auto it = p2c.find(tv2_id); - TORCH_CHECK( - it != p2c.end(), - "Expected mapped consumer ID but not found: ", - tv2_id->toString()); - auto tv3_mapped_id = it->second; - TORCH_CHECK( - tv3_mapped_id == tv3->axis(i), - "Unexpected mapped consumer ID: ", - tv3_mapped_id->toString()); - } - - auto c2p = BestEffortReplay::replayPasC(tv2, tv3, 2, root_map).getReplay(); - for (const auto i : c10::irange(tv3->nDims())) { - auto tv3_id = tv3->axis(i); - auto it = c2p.find(tv3_id); - TORCH_CHECK( - it != c2p.end(), - "Expected mapped producer ID but not found: ", - tv3_id->toString()); - auto tv2_mapped_id = it->second; - TORCH_CHECK( - tv2_mapped_id == tv2->axis(i), - "Unexpected mapped consumer ID: ", - tv2_mapped_id->toString()); - } -} - -TEST_F(NVFuserTest, FusionTrivialReductionForwarding3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - auto tv2 = add(tv1, IrBuilder::create(1)); - fusion.addOutput(tv2); - - // Similar pattern as FusionTrivialReductionForwarding2 but no - // trivial reduciton at the root domain - - // Create a trivial reduction by splitting with a factor of 1 - tv1->split(1, 1, false); - // Merging with a trivial reduction - tv1->merge(0, 1); - tv1->split(0, 5); - - tv2->split(0, 5); - - // While the merge of tv1 is done with a trivial reduciton, it's not - // a root domain, so forwarding is not enabled. BestEffortReplay - // should only map the first axis of each tensor. - - PairwiseRootDomainMap root_map(tv1, tv2); - auto p2c = BestEffortReplay::replayCasP(tv2, tv1, 2, root_map).getReplay(); - TORCH_CHECK(p2c.size() == 1, "Expected only one mapping found"); - TORCH_CHECK(p2c.begin()->first == tv1->getRootDomain().at(0)); - TORCH_CHECK(p2c.begin()->second == tv2->getRootDomain().at(0)); -} - -TEST_F(NVFuserTest, FusionTrivialReductionForwarding4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - // tv4 has a trivial reduction axis - auto tv4 = sum(tv2, {0}); - auto tv5 = add(tv4, IrBuilder::create(1)); - fusion.addOutput(tv5); - - tv3->merge(0, 1); - tv3->split(0, 32); - - // This causes the trivial reduction of tv4 to be merged with - // another axis of tv4, and then forward computeAt is done from tv4 - // to tv5. The split of the merged id of tv4 should be done on tv5 - // by forwarding the merge of the trivial reduction. - tv0->computeAt(tv3, -1); - - tv3->axis(0)->parallelize(ParallelType::BIDx); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({111}, options); - auto t1 = at::randn({123, 111}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto t2 = t0.unsqueeze(0); - auto t3 = t1 + t2; - auto t5 = sum(t2, {0}) + 1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {t3, t5}, __LINE__, __FILE__); -} - -// See issue #1598 -TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - // Place tv2 on shared memory - tv2->split(0, 2); - tv2->split(-1, 4); - tv2->setMemoryType(MemoryType::Shared); - tv2->axis(-2)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv3->split(0, 2); - tv3->split(-1, 4); - // swap tidx and tidy - tv3->axis(-2)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDy); - - tv4->split(0, 2); - tv4->split(-1, 4); - tv4->axis(-2)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDy); - - tv0->computeAt(tv4, 1); - tv3->computeAt(tv4, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({10, 64}, options); - auto t1 = at::randn({10, 64}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// See issue #1598 -TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - tv2->split(0, 2); - tv2->split(-1, 4); - tv2->setMemoryType(MemoryType::Shared); - - tv2->axis(-2)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - tv4->split(0, 2); - tv4->split(-1, 4); - // Also do unroll for tv3 and tv4 - tv4->split(-2, 8, false); - tv4->axis(-3)->parallelize(ParallelType::Unroll); - // swap tidx and tidy - tv4->axis(-2)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDy); - - tv0->computeAt(tv4, 1); - tv3->computeAt(tv4, -1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({10, 64}, options); - auto t1 = at::randn({10, 64}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// See issue #1599 -TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = add(tv2, tv3); - fusion.addOutput(tv4); - - // Use unroll where a RAW-sync tensor is stored - - tv4->split(0, 2); - tv4->split(0, 3); - tv4->split(-1, 4); - tv4->axis(1)->parallelize(ParallelType::Unroll); - tv4->axis(-2)->parallelize(ParallelType::TIDx); - tv4->axis(-1)->parallelize(ParallelType::TIDy); - - tv0->computeAt(tv4, 3); - tv3->computeAt(tv4, -1); - - tv2->split(-1, 4); - tv2->axis(-2)->parallelize(ParallelType::TIDy); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->setMemoryType(MemoryType::Shared); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({50, 64}, options); - auto t1 = at::randn({50, 64}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// See #1618 -TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({16, 128}); - auto tv1 = makeConcreteTensor({16, 128}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - auto tv3 = set(tv1); - auto tv4 = set(tv2); - auto tv5 = set(tv3); - auto tv6 = add(tv4, tv5); - fusion.addOutput(tv6); - - tv2->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - - tv2->computeAt(tv6, 0); - tv3->computeAt(tv6, 1); - tv4->computeAt(tv6, 1); - tv5->computeAt(tv6, -1); - tv2->split(1, 64); - tv3->split(1, 64); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv3->axis(-1)->parallelize(ParallelType::TIDx); - tv6->axis(-1)->parallelize(ParallelType::TIDx); - - // Check the block sync is inserted at the correct location. - // There is exactly one block sync needed in this test case - // and the sync needs to be after the 2 expressions - // that modify shared memory. - class SyncInsertionPointChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - - private: - void handle(UnaryOp* uop) final { - // Record number of unary ops that modifies shared memory. - if (uop->out()->isA() && - uop->out()->as()->view()->getMemoryType() == - MemoryType::Shared && - // Filter out initialization expressions - uop->in()->isA()) { - number_of_writes_++; - } - } - void handle(kir::BlockSync* bsync) final { - // Make sure both shared memory modifying expressions - // have been observed at the sync insertion point. - TORCH_INTERNAL_ASSERT( - number_of_writes_ == 2, - "FusionRAWSyncInsertionPlace4 test fail:", - "only 1 sync after the 2 shared mem writes is needed in this test," - "either a redundant sync has been inserted or the block sync is not inserted at the right place"); - } - - private: - int number_of_writes_ = 0; - } sync_insertion_checker; - GpuLower gpulw(&fusion); - sync_insertion_checker.handle(gpulw.kernel()->topLevelExprs()); -} - -// Test serial write and parallel read of shared mem: mapped case -TEST_F(NVFuserTest, FusionSerialSmemWriteParallelRead1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({128, 6}); - TensorView* tv1 = makeConcreteTensor({128, 6}); - TensorView* tv2 = makeConcreteTensor({128, 6}); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - TensorView* tv3 = add(tv0, tv1); - TensorView* tv4 = add(tv3, tv2); - - fusion.addOutput(tv4); - - // Use shared memory - tv3->setMemoryType(MemoryType::Shared); - - // Parallelize t4, in this case dim 0 on tv3 will - // not be parallelized but dim0 of t4 will be. - // We will need to make sure a sync is inserted - // even if these dimensions are mapped. - tv4->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({128, 6}, options); - at::Tensor t1 = at::randn({128, 6}, options); - at::Tensor t2 = at::randn({128, 6}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1, t2}); - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - auto ref = t0 + t1 + t2; - - testValidate(&fusion, cg_outputs, {t0, t1, t2}, {ref}, __LINE__, __FILE__); -} - -// Test serial write and parallel read of shared mem: un-mapped case -TEST_F(NVFuserTest, FusionSerialSmemWriteParallelRead2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({128, 6}); - TensorView* tv1 = makeConcreteTensor({128, 6}); - TensorView* tv2 = makeConcreteTensor({128, 6}); - fusion.addInput(tv0); - fusion.addInput(tv1); - fusion.addInput(tv2); - - TensorView* tv3 = add(tv0, tv1); - TensorView* tv4 = add(tv3, tv2); - - fusion.addOutput(tv4); - - // Use shared memory - tv3->setMemoryType(MemoryType::Shared); - - // Split and parallelize t4, - // the parallelized dimension in t4 will not - // map across to the shared mem tensor, t3. So - // there will need to be a sync before use of t3. - tv4->split(0, 2); - tv4->axis(0)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({128, 6}, options); - at::Tensor t1 = at::randn({128, 6}, options); - at::Tensor t2 = at::randn({128, 6}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1, t2}); - auto cg_outputs = fe.runFusion({t0, t1, t2}); - - auto ref = t0 + t1 + t2; - - testValidate(&fusion, cg_outputs, {t0, t1, t2}, {ref}, __LINE__, __FILE__); -} - -// Simple test of async copy primitive -TEST_F(NVFuserTest, FusionSimpleCpAsync_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int m = 33, n = 31; - - TensorView* tv0 = makeConcreteTensor({m, n}); - TensorView* tv1 = makeConcreteTensor({m, n}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - - fusion.addOutput(tv2); - - auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); - tv0_shared->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv2, 1); - tv0_shared->axis(1)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({m, n}, options); - at::Tensor t1 = at::randn({m, n}, options); - - FusionExecutor fe; - - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - } - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Simple test of async copy primitive: double buffered -// Double buffer case 1, both block sync and async wait -// are needed. -TEST_F(NVFuserTest, FusionDoubleBufferCpAsync1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Using vectorization so need to keep n multiple of 4. - int m = 33, n = 48; - - TensorView* tv0 = makeConcreteTensor({m, n}); - TensorView* tv1 = makeConcreteTensor({m, n}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - - fusion.addOutput(tv2); - - auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); - tv0_shared->setMemoryType(MemoryType::Shared); - tv0->computeAt(tv2, 1); - - // Asynchronously load a tile in one schedule - tv0_shared->split(1, 4); - tv0_shared->axis(-1)->parallelize(ParallelType::Vectorize); - tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); - - // Consume the loaded tile in another schedule, - // triggering the need for a sync. - tv2->split(1, 12); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - - // Double buffer the shared mem tensor. - tv0_shared->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({m, n}, options); - at::Tensor t1 = at::randn({m, n}, options); - - FusionExecutor fe; - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - } - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Simple test of async copy primitive: double buffered -// Double buffer case 2, only async wait is needed -TEST_F(NVFuserTest, FusionDoubleBufferCpAsync2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Using vectorization so need to keep n multiple of 4. - int m = 33, n = 48; - - TensorView* tv0 = makeConcreteTensor({m, n}); - TensorView* tv1 = makeConcreteTensor({m, n}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - - fusion.addOutput(tv2); - - auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); - tv0_shared->setMemoryType(MemoryType::Shared); - tv0->computeAt(tv2, 1); - - // Asynchronously load a tile in one schedule - tv0_shared->split(1, 4); - tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); - - // Consume the loaded tile in another schedule, - // triggering the need for a sync. - tv2->split(1, 4); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - - // Double buffer the shared mem tensor. - tv0_shared->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({m, n}, options); - at::Tensor t1 = at::randn({m, n}, options); - - FusionExecutor fe; - // requires ampere+ GPU - if (!deviceMajorMinorCheck(8)) { - ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - } - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Simple test for double buffer in shared mem, -// where we should not insert redundant syncs when -// they are not needed. -TEST_F(NVFuserTest, FusionDoubleBufferNoSync_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Using vectorization so need to keep n multiple of 4. - int m = 33, n = 48; - - TensorView* tv0 = makeConcreteTensor({m, n}); - TensorView* tv1 = makeConcreteTensor({m, n}); - - fusion.addInput(tv0); - fusion.addInput(tv1); - - TensorView* tv2 = add(tv0, tv1); - - fusion.addOutput(tv2); - - auto tv0_shared = tv0->cacheAfter(); - tv0_shared->setMemoryType(MemoryType::Shared); - tv0->computeAt(tv2, 1); - - // Asynchronously load a tile in one schedule - tv0_shared->split(1, 4); - tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); - - // Consume the loaded tile in another schedule, - // triggering the need for a sync. - tv2->split(1, 4); - tv2->axis(-2)->parallelize(ParallelType::TIDx); - - // Double buffer the shared mem tensor. - tv0_shared->doubleBuffer(); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({m, n}, options); - at::Tensor t1 = at::randn({m, n}, options); - - GpuLower gpulw(&fusion); - auto flattened_exprs = - ir_utils::flattenScopedExprs(gpulw.kernel()->topLevelExprs()); - bool sync_inserted = std::any_of( - flattened_exprs.begin(), flattened_exprs.end(), [](Expr* expr) { - return expr->isA(); - }); - TORCH_INTERNAL_ASSERT(!sync_inserted, "Un-expected block sync inserted"); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Test predicate inversion for cp.async -TEST_F(NVFuserTest, FusionCpAsyncPredicate_CUDA) { - // requires ampere+ GPU - - Fusion fusion; - FusionGuard fg(&fusion); - - // Using vectorization so need to keep n multiple of 4. - int m = 33, n = 48; - - TensorView* tv0 = makeConcreteTensor({m, n}); - - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1}); - fusion.addOutput(tv1); - - auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); - auto tv0_reg = tv0_shared->cacheAfter(); - tv0_shared->setMemoryType(MemoryType::Shared); - tv0->computeAt(tv1, 1); - - tv0_shared->split(-1, 32); - tv0_shared->split(-1, 4); - tv0_shared->axis(-1)->parallelize(ParallelType::Vectorize); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({m, n}, options); - - FusionExecutor fe; - if (!deviceMajorMinorCheck(8)) { - ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0})); - GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; - } - - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0.sum({1}); - - testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Test predicate removal on reg-to-reg expressions -TEST_F(NVFuserTest, FusionPredRemovalCheck_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(2); - fusion.addInput(tv0); - - TensorView* tv1 = set(tv0); - TensorView* tv2 = set(tv1); - TensorView* tv3 = set(tv2); - TensorView* tv4 = set(tv3); - - fusion.addOutput(tv4); - tv4->split(1, 4); - tv0->computeAt(tv4, -2); - tv3->axis(-1)->parallelize(ParallelType::Vectorize); - - class PredicateRemovalChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - - private: - void handle(UnaryOp* uop) final { - assertOnLocalToLocal(uop); - } - - // Utility to assert any local-to-local expr is only trivially predicated. - void assertOnLocalToLocal(Expr* expr) { - bool is_local = true; - for (auto in : ir_utils::filterByType(expr->inputs())) { - if (in->view()->getMemoryType() != MemoryType::Local) { - is_local = false; - } - } - for (auto in : - ir_utils::filterByType(expr->outputs())) { - if (in->view()->getMemoryType() != MemoryType::Local) { - is_local = false; - } - } - - if (is_local) { - if (auto ite = dynamic_cast(scope_exprs_.back())) { - TORCH_INTERNAL_ASSERT( - ite->predicate()->value()->isConst(), - "redundant predicate on: ", - expr); - } - } - } - - private: - bool within_ite_ = false; - } pred_checker; - - GpuLower gpulw(&fusion); - pred_checker.handle(gpulw.kernel()->topLevelExprs()); -} - -TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tvs = Welford(tv0, {0}); - auto tv_avg = tvs.avg; - fusion.addOutput(tv_avg); - - tv_avg->split(0, 128); - TransformPropagatorWithCheck propagator(tv_avg); - MaxRootDomainInfoSpanningTree(tv_avg).traverse(&propagator); - - tv_avg->axis(0)->parallelize(ParallelType::BIDx); - tv_avg->axis(1)->parallelize(ParallelType::TIDx); - - // Make sure the parallelization of tv_avg is propagated to the var - // and count tensors. - GpuLower gpulw(&fusion); - for (const auto expr : gpulw.kernel()->exprs()) { - auto wop = dynamic_cast(expr); - if (wop == nullptr) { - continue; - } - auto ref = wop->outAvg()->as(); - for (auto sibling : ir_utils::filterByType(wop->outputs())) { - if (ref == sibling) { - continue; - } - TORCH_CHECK( - ref->nDims() == sibling->nDims(), - "Invalid sibling: ", - sibling->toString()); - for (const auto i : c10::irange(ref->nDims())) { - TORCH_CHECK( - ref->axis(i)->getParallelType() == - sibling->axis(i)->getParallelType(), - "Mismatched parallel types between siblings. ", - ref->toString(), - ", ", - sibling->toString()); - } - } - } - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::manual_seed(0); - at::Tensor t0 = at::randn({9999}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto outputs = fe.runFusion({t0}); - - testValidate(fe.kernel(), outputs, {t0}, {t0.mean({0})}, __LINE__, __FILE__); -} - -// Test ExactRootDomainMap -TEST_F(NVFuserTest, FusionExactRootDomainMap_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = transpose(tv2); - auto tv4 = add(tv2, tv1); - auto tv5 = add(tv2, tv3); - auto tv6 = add(tv3, tv1); - fusion.addOutput(tv4); - fusion.addOutput(tv5); - fusion.addOutput(tv6); - - const auto exact_map = ExactRootDomainMap(&fusion); - - // In the exact mapping, the broadcast domain introduced at tv2 is - // only mapped with the another one in tv3, which is just transposed - // from tv2. Any other domain, including the second domain of tv4, - // must not be mapped. - - auto tv2_bc = tv2->axis(1); - auto tv3_bc = tv3->axis(0); - - TORCH_CHECK( - exact_map.areMapped(tv2_bc, tv3_bc), - "Invalid exact root domain map: ", - exact_map.toString()); - - // They must not be mapped with anything else. - for (auto tv : ir_utils::allTvs(&fusion)) { - for (auto root_id : tv->getRootDomain()) { - if (root_id == tv2_bc || root_id == tv3_bc) { - continue; - } - TORCH_CHECK( - !exact_map.areMapped(root_id, tv2_bc), - "Invalid exact root domain map: ", - exact_map.toString()); - TORCH_CHECK( - !exact_map.areMapped(root_id, tv3_bc), - "Invalid exact root domain map: ", - exact_map.toString()); - } - } -} - -class NVFuserMultithreadedTest : public ::testing::Test { - protected: - bool was_enabled = false; - - void SetUp() override { - was_enabled = fuser::cuda::setEnabled(true); - } - - void TearDown() override { - fuser::cuda::setEnabled(was_enabled); - } -}; - -TEST_F(NVFuserMultithreadedTest, SingleFunction_CUDA) { - std::string ir = R"IR( -graph(%x.1 : Tensor, - %y.1 : Tensor): - %12 : NoneType = prim::Constant() - %11 : bool = prim::Constant[value=0]() - %9 : int = prim::Constant[value=1]() - %3 : Tensor = aten::exp(%x.1) - %5 : Tensor = aten::relu(%y.1) - %6 : Tensor = aten::sin(%5) - %8 : Tensor = aten::add(%3, %6, %9) - %10 : int[] = prim::ListConstruct(%9) - %13 : Tensor = aten::sum(%8, %10, %11, %12) - return (%13) -)IR"; - auto g = std::make_shared(); - torch::jit::parseIR(ir, g.get()); - GraphFunction fn("nvfuser_test", g, nullptr); - - auto run_kernel = [&fn]() { - auto x = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); - auto y = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); - std::vector results; - for (const auto& _ : c10::irange(10)) { - auto stack = createStack({x.clone(), y.clone()}); - fn.run(stack); - results.push_back(stack.back()); - } - for (const auto& i : c10::irange(1, 10)) { - auto t0 = results[0].toTensor(); - auto ti = results[i].toTensor(); - ASSERT_TRUE(at::allclose(t0, ti)); - } - }; - - constexpr size_t kNumThreads = 4; - std::vector threads; - for (size_t id = 0; id < kNumThreads; ++id) { - threads.emplace_back(run_kernel); - } - for (auto& t : threads) { - t.join(); - } -} - -TEST_F(NVFuserMultithreadedTest, MultipleFunctions_CUDA) { - auto run_kernel = []() { - const std::string ir = R"IR( - graph(%x.1 : Tensor, - %y.1 : Tensor): - %12 : NoneType = prim::Constant() - %11 : bool = prim::Constant[value=0]() - %9 : int = prim::Constant[value=1]() - %3 : Tensor = aten::exp(%x.1) - %5 : Tensor = aten::relu(%y.1) - %6 : Tensor = aten::sin(%5) - %8 : Tensor = aten::add(%3, %6, %9) - %10 : int[] = prim::ListConstruct(%9) - %13 : Tensor = aten::sum(%8, %10, %11, %12) - return (%13) - )IR"; - auto g = std::make_shared(); - torch::jit::parseIR(ir, g.get()); - GraphFunction fn("nvfuser_test", g, nullptr); - - auto x = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); - auto y = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); - std::vector results; - constexpr size_t numRuns = 10; - for (const auto& _ : c10::irange(numRuns)) { - auto stack = createStack({x.clone(), y.clone()}); - fn.run(stack); - results.push_back(stack.back()); - } - for (const auto& i : c10::irange(1, numRuns)) { - auto t0 = results[0].toTensor(); - auto ti = results[i].toTensor(); - ASSERT_TRUE(at::allclose(t0, ti)); - } - }; - - constexpr size_t kNumThreads = 4; - std::vector threads; - for (size_t id = 0; id < kNumThreads; ++id) { - threads.emplace_back(run_kernel); - } - for (auto& t : threads) { - t.join(); - } -} - -// Repro of issue #1655 -TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - auto tv1 = makeSymbolicTensor(2); - fusion.addInput(tv1); - auto tv2 = makeSymbolicTensor(2); - fusion.addInput(tv2); - - auto tv3 = broadcast(tv0, {true, true, false}); - auto tv4 = broadcast(tv1, {false, true, false}); - auto tv5 = broadcast(tv2, {true, false, false}); - - auto tv6 = add(tv3, tv4); - auto tv7 = add(tv3, tv5); - - fusion.addOutput(tv6); - fusion.addOutput(tv7); - - tv6->merge(0); - tv6->merge(0); - - TransformPropagatorWithCheck propagator(tv6); - MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); - - tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined); - tv1->computeAt(tv6, -1, ComputeAtMode::MostInlined); - tv2->computeAt(tv7, -1, ComputeAtMode::MostInlined); - - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(fusion.printKernel()); -} - -TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(&fusion); - - int X = 256, Y = 7, Z = 2048; - - // setup fusion - auto tv0 = makeContigTensor(4, DataType::Half); - fusion.addInput(tv0); - auto tv1 = castOp(DataType::Float, tv0); - - auto tvs = Welford(tv1, {0, 1, 2}); - auto tv_avg = tvs.avg; - auto tv_M2 = tvs.var_sum; - auto tv_N = tvs.n; - fusion.addOutput(tv_avg); - fusion.addOutput(tv_M2); - - auto cached_input = tv0->cacheAfter(); - auto cached_avg = tv_avg->cacheBefore(); - auto cached_M2 = tv_M2->cacheBefore(); - - auto reduction_tv = scheduler_utils::getReductionTvs(&fusion)[0]; - - reduction_tv->merge(0); - reduction_tv->merge(0); - - int TIDx = 16; - int vec = 4; - - int TIDy = 16; - int outer_tidy_fact = 16; - - reduction_tv->split(-1, TIDx * vec); - reduction_tv->split(-1, vec); - reduction_tv->axis(-2)->parallelize(ParallelType::TIDx); - reduction_tv->axis(-1)->parallelize(ParallelType::Vectorize); - reduction_tv->axis(-3)->parallelize(ParallelType::BIDx); - - reduction_tv->split(0, TIDy); - reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - reduction_tv->split(0, outer_tidy_fact); - reduction_tv->axis(0)->parallelize(ParallelType::BIDy); - - // T2_g[ rblockIdx.y, rS{16}, rthreadIdx.y, iblockIdx.x, ithreadIdx.x24, - // iV25{4} ] - reduction_tv->reorder({{3, 0}, {4, 1}, {0, 2}, {2, 3}, {1, 4}, {5, 5}}); - // T2_g[iblockIdx.x, ithreadIdx.x24, rblockIdx.y, rthreadIdx.y, rS{16}, - // iV25{4}] - - TransformPropagatorWithCheck propagator(reduction_tv); - MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); - auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4}); - scheduler_utils::parallelizeAllLike(rfactor_tv); - - tv0->computeAt(tv_avg, 2); - tv0->computeAt(cached_input, -2); - - cached_input->computeAt(rfactor_tv, 4, ComputeAtMode::BestEffort); - - for (auto tv : ir_utils::allTvs(&fusion)) { - if (tv == cached_input || tv == tv_avg || tv == tv_M2) { - continue; - } - tv->axis(-1)->parallelize(ParallelType::Serial); - } - - FusionExecutor fe; - fe.compileFusion(&fusion, {}, LaunchParams()); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({X, Y, Y, Z}, options); - - auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, -1, -1, -1, -1, -1)); - - // by default Welford outputs sum of square diff so need to divide to get var - cg_outputs[1] = cg_outputs[1].div((float)(X * Y * Y)); - - auto at_mu = at::mean(t0.to(at::kDouble), {0, 1, 2}); - auto at_var = at::var(t0.to(at::kDouble), {0, 1, 2}, false); - - testValidate( - &fusion, - cg_outputs, - {t0}, - {at_mu, at_var}, - __LINE__, - __FILE__, - "", - LaunchParams(-1, -1, -1, -1, -1, -1)); -} - -// Test sync insertion with redundant predicates -TEST_F(NVFuserTest, FusionRedundantPredSync_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({32}); - TensorView* tv1 = makeConcreteTensor({32, 32}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = add(tv2, tv1); - - fusion.addOutput(tv3); - - auto tv0c = tv0->cacheAfter(); - - // Make a redundant write through smem - tv0c->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv3, 0); - tv1->computeAt(tv3, 0); - - tv0c->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::TIDy); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::TIDy); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - GpuLower gpulw(&fusion); - auto flattened_exprs = - ir_utils::flattenScopedExprs(gpulw.kernel()->topLevelExprs()); - bool sync_inserted = std::any_of( - flattened_exprs.begin(), flattened_exprs.end(), [](Expr* expr) { - return expr->isA(); - }); - TORCH_INTERNAL_ASSERT(sync_inserted, "Expected block sync not inserted"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({32}, options); - at::Tensor t1 = at::randn({32, 32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Test case for removing syncs on chain of redundant uses. -TEST_F(NVFuserTest, FusionRedundantPredSync2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({32}); - TensorView* tv1 = makeConcreteTensor({32, 32}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = add(tv2, tv1); - - fusion.addOutput(tv3); - - auto tv0c = tv0->cacheAfter(); - - // Make a redundant write through smem - tv0c->setMemoryType(MemoryType::Shared); - tv2->setMemoryType(MemoryType::Shared); - - tv0->computeAt(tv3, 0); - tv1->computeAt(tv3, 0); - - tv0c->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::TIDy); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::TIDy); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - // Utility class to make sure one block sync - // is inserted by RAW pass. - class SyncChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - int result() { - return sync_seen_; - } - - private: - void handle(kir::BlockSync*) final { - sync_seen_++; - } - - private: - int sync_seen_ = 0; - } checker; - - GpuLower gpulw(&fusion); - checker.handle(gpulw.kernel()->topLevelExprs()); - TORCH_INTERNAL_ASSERT( - checker.result() < 2, "More syncs were inserted than expected"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({32}, options); - at::Tensor t1 = at::randn({32, 32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); -} - -// Test case for sync insertion after redundant predicated smem write -// Check that syncs are removed only when all paths are redundant. -TEST_F(NVFuserTest, FusionRedundantPredSync3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({32}); - TensorView* tv1 = makeConcreteTensor({32, 32}); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {true, false}); - auto tv3 = set(tv2); - auto tv4 = add(tv3, tv1); - auto tv5 = add(tv2, tv1); - - fusion.addOutput(tv4); - fusion.addOutput(tv5); - - auto tv0c = tv0->cacheAfter(); - - // In this scheduling config, - // tv0c -> tv2 -> tv3 is a redundant path for tidy - // tv0c -> tv2 -> tv5 is not. - // So we need a RAW sync in tv0c->tv2 to make sure - // tv2 has the correct value to produce tv5. - tv0c->setMemoryType(MemoryType::Shared); - tv3->setMemoryType(MemoryType::Shared); - - tv0c->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(0)->parallelize(ParallelType::TIDy); - tv2->axis(1)->parallelize(ParallelType::TIDx); - - tv3->axis(0)->parallelize(ParallelType::TIDy); - tv3->axis(1)->parallelize(ParallelType::TIDx); - - tv5->axis(0)->parallelize(ParallelType::TIDy); - tv5->axis(1)->parallelize(ParallelType::TIDx); - - // Utility class to make sure one block sync - // is inserted by RAW pass. - class SyncChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - int result() { - return sync_seen_; - } - - private: - void handle(kir::BlockSync* sync) final { - if (!sync->isWarHazardSync()) { - sync_seen_++; - } - } - - private: - int sync_seen_ = 0; - } checker; - - GpuLower gpulw(&fusion); - checker.handle(gpulw.kernel()->topLevelExprs()); - - // This is implicit checking. There are exactly 2 places - // where RAW hazards happen: one producing tv2 and the other - // producing tv3. This test case expect syncs in both of - // these places so we check that 2 RAW syncs are inserted. - TORCH_INTERNAL_ASSERT( - checker.result() == 2, - "Exactly 2 RAW sync expected for the two shared memory transfers"); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({32}, options); - at::Tensor t1 = at::randn({32, 32}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - - auto ref = t0 + t1; - - testValidate(&fusion, cg_outputs, {t0, t1}, {ref, ref}, __LINE__, __FILE__); -} - -// Unit test case for detecting thread redundant usage of shared tensors. -TEST_F(NVFuserTest, FusionRedundantUseCheck_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeConcreteTensor({32, 32}); - fusion.addInput(tv0); - - auto tv1 = set(tv0); - auto tv2 = set(tv1); - auto tv3 = set(tv2); - auto tv4 = set(tv3); - - auto tv5 = set(tv4); - - auto tv6 = set(tv4); - auto tv7 = set(tv6); - - fusion.addOutput(tv5); - fusion.addOutput(tv7); - - tv2->setMemoryType(MemoryType::Shared); - tv4->setMemoryType(MemoryType::Shared); - - tv7->axis(-1)->parallelize(ParallelType::TIDx); - - // Thread pred map cannot be built without an active lower - // object. So would need to lower the whole fusion for - // testing. However, lower also keeps an copy of the fusion - // so the original pointers cannot be used to querry the - // thread pred map. So have to traverse the new expr list - // to find the pointers; - GpuLower gpulw(&fusion); - - TensorView *lowered_tv2 = nullptr, *lowered_tv4 = nullptr; - auto used_vals = gpulw.kernel()->usedMathVals(); - - for (auto tv : ir_utils::filterByType(used_vals)) { - if (tv->name() == 2) { - lowered_tv2 = tv; - } - if (tv->name() == 4) { - lowered_tv4 = tv; - } - } - - TORCH_INTERNAL_ASSERT( - lowered_tv2 != nullptr && lowered_tv4 != nullptr, - "tv2 or tv4 not lowered or mangled"); - - auto tv2_info = gpulw.threadPredMap().getPredicateInfo(lowered_tv2); - auto tv4_info = gpulw.threadPredMap().getPredicateInfo(lowered_tv4); - - // tv2 -> tv3 -> tv4 (shared) is the only use chain for tv2, - // and tv4 is redundantly written in tidx so tv2 is redundantly - // consumed in tidx. - TORCH_INTERNAL_ASSERT( - tv2_info.redundant_use_types.get(ParallelType::TIDx), - "TV2 is redundantly used but not detected."); - - // tv4->tv5 (global) is a redundant use chain, but - // tv4->tv6->tv7 is not, so tv4 should not be detected as - // a redundant used tensor in tidx. - TORCH_INTERNAL_ASSERT( - !tv4_info.redundant_use_types.get(ParallelType::TIDx), - "TV4 is not redundantly used but not detected."); -} - -// Test a basic swizzle pattern -TEST_F(NVFuserTest, FusionSimpleSwizzle0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - - fusion.addOutput(tv2); - - // Make a 2x8 Zshape tile - tv1->split(-1, 16); - tv1->split(-1, 8); - // [O, 2, 8] - - tv2->split(-1, 16); - tv2->split(-1, 4); - //[O, 4, 4] - - tv1->computeAt(tv2, 1); - tv1->swizzle(Swizzle2DType::ZShape, -2, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({2, 32}, options); - auto t2 = t0 + 2.0; - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -// Test swizzle inlining -TEST_F(NVFuserTest, FusionSimpleSwizzle1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - // Make a 2x8 Zshape tile - tv2->split(-1, 16); - tv2->split(-1, 8); - // [O, 2, 8] - - tv3->split(-1, 16); - tv3->split(-1, 4); - //[O, 4, 4] - - tv2->computeAt(tv3, 1); - tv2->swizzle(Swizzle2DType::ZShape, -2, -1); - - // Inlining a producer into a swizzled consumer is ok - tv1->computeAt(tv2, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({2, 32}, options); - auto t3 = t0 + 3.0; - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {t3}, __LINE__, __FILE__); -} - -// Test sync insertion and memory check in parallelized swizzles. -// In this test, data is parallel written into smem in zcurve -// pattern and then read out and output to global mem unswizzled. -TEST_F(NVFuserTest, FusionSimpleSwizzle2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({32, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - - fusion.addOutput(tv2); - - tv1->swizzle(Swizzle2DType::ZShape, -2, -1); - - tv1->axis(0)->parallelize(ParallelType::TIDx); - tv1->axis(1)->parallelize(ParallelType::TIDy); - - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::TIDy); - - // Validation should fail since TV1 is not in shared - // memory as required by sync info pass. - ASSERT_ANY_THROW(GpuLower gpulw_throw(&fusion)); - - tv1->setMemoryType(MemoryType::Shared); - - // Make sure that a sync is inserted: - bool sync_found = false; - GpuLower gpu_lw(&fusion); - auto flattened_exps = - ir_utils::flattenScopedExprs(gpu_lw.kernel()->topLevelExprs()); - - for (auto expr : flattened_exps) { - if (expr->isA()) { - sync_found = true; - } - // Will require a sync thread before any shared memory read. - for (auto inp_tv : ir_utils::filterByType(expr->inputs())) { - if (inp_tv->getMemoryType() == MemoryType::Shared) { - TORCH_INTERNAL_ASSERT( - sync_found, "Block sync required but not inserted"); - } - } - } - - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({32, 32}, options); - auto t2 = t0 + 2.0; - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -// Test BestEffortReplay behavior with swizzle op -TEST_F(NVFuserTest, FusionSwizzleMapping_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - // Make a 2x8 Zshape tile - tv2->split(-1, 16); - tv2->split(-1, 8); - // [O, 2, 8] - - tv3->split(-1, 16); - tv3->split(-1, 4); - //[O, 4, 4] - - tv2->computeAt(tv3, 1); - tv2->swizzle(Swizzle2DType::ZShape, -2, -1); - - // Inlining a producer into a swizzled consumer is ok - tv1->computeAt(tv2, -1); - - // Check BestEffortReplay behavior with skip swizzles option on. - PairwiseRootDomainMap root_map(tv1, tv2); - - // Check producer to consumer map, - // i.e. unswizzled tensor to swizzled tensor map - //---------------------------------------------------------- - auto p2c = BestEffortReplay::replayCasP(tv2, tv1, -1, root_map).getReplay(); - auto swizzle_x_it0 = p2c.find(tv1->axis(-2)); - auto swizzle_y_it0 = p2c.find(tv1->axis(-1)); - // P2C map should exist and both the x and y map should - // map to the output of the swizzle op. - TORCH_INTERNAL_ASSERT( - swizzle_x_it0 != p2c.end() && swizzle_y_it0 != p2c.end()); - TORCH_INTERNAL_ASSERT( - swizzle_x_it0->second == tv2->axis(-2) && - swizzle_y_it0->second == tv2->axis(-1)); - - // Check consumer to producer map, - // i.e. swizzled tensor to unswizzled tensor map - //---------------------------------------------------------- - auto c2p = BestEffortReplay::replayPasC(tv1, tv2, -1, root_map).getReplay(); - - auto swizzle_op = tv2->axis(-1)->definition()->as(); - - // Find mapping for swizzle inputs - auto swizzle_x_it1 = c2p.find(swizzle_op->inX()); - auto swizzle_y_it1 = c2p.find(swizzle_op->inY()); - - // Find mapping for swizzle outputs - auto swizzle_x_it2 = c2p.find(swizzle_op->outX()); - auto swizzle_y_it2 = c2p.find(swizzle_op->outY()); - - // Input of swizzle ops will not be mapped to any - // by BestEffortReplay, as BestEffortReplay has to be - // one to one. IdGraph will further map them together. - TORCH_INTERNAL_ASSERT( - swizzle_x_it1 == c2p.end() && swizzle_y_it1 == c2p.end()); - - // Mapping for swizzle outputs should be mapped and should - // also map to the corresponding axes on the unswizzled tensor. - TORCH_INTERNAL_ASSERT( - swizzle_x_it2 != c2p.end() && swizzle_y_it2 != c2p.end()); - TORCH_INTERNAL_ASSERT( - swizzle_x_it2->second == tv1->axis(-2) && - swizzle_y_it2->second == tv1->axis(-1)); - - // Check id graph behavior - //---------------------------------------------------------- - ComputeAtMap ca_map(&fusion); - // Corresponding inputs and outputs of swizzle ops are - // map through by exact and permissive map. - TORCH_INTERNAL_ASSERT( - ca_map.areMapped(tv1->axis(-2), swizzle_op->inX(), IdMappingMode::EXACT)); - TORCH_INTERNAL_ASSERT( - ca_map.areMapped(tv1->axis(-1), swizzle_op->inY(), IdMappingMode::EXACT)); - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-2), swizzle_op->outX(), IdMappingMode::EXACT)); - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-1), swizzle_op->outY(), IdMappingMode::EXACT)); - - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-2), swizzle_op->inX(), IdMappingMode::PERMISSIVE)); - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-1), swizzle_op->inY(), IdMappingMode::PERMISSIVE)); - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-2), swizzle_op->outX(), IdMappingMode::PERMISSIVE)); - TORCH_INTERNAL_ASSERT(ca_map.areMapped( - tv1->axis(-1), swizzle_op->outY(), IdMappingMode::PERMISSIVE)); -} - -// Test a basic loop swizzle pattern -TEST_F(NVFuserTest, FusionLoopSwizzle0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - - fusion.addOutput(tv2); - - tv2->split(-1, 16); - tv2->split(-1, 4); - //[O, 4, 4] - - tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); - - tv0->computeAt(tv2, -1); - - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({2, 32}, options); - auto t2 = t0 + 2.0; - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -// Outer block zshape pattern -TEST_F(NVFuserTest, FusionLoopSwizzle1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - - fusion.addOutput(tv2); - - tv2->split(-2, 8); - tv2->split(-1, 4); - //[I0o, I0i, I1o, I1i] - tv2->reorder({{1, 2}, {2, 1}}); - //[I0o, I1o, I0i, I1i] - - tv2->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop); - tv0->computeAt(tv2, -1); - - tv2->axis(0)->parallelize(ParallelType::BIDx); - tv2->axis(1)->parallelize(ParallelType::BIDy); - - FusionExecutor fe; - fe.compileFusion(&fusion); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({45, 77}, options); - auto t2 = t0 + 2.0; - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -// Test assertion in unsupported pattern: non-leaf loop swizzle. -TEST_F(NVFuserTest, FusionLoopSwizzleCheck0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - - fusion.addOutput(tv2); - - tv2->split(-1, 16); - tv2->split(-1, 4); - //[O, 4, 4] - - // Swizzle the inner tile. - tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); - - // Make swizzle output not a leaf domain. - tv2->merge(-2); - - tv0->computeAt(tv2, -1); - - FusionExecutor fe; - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -// Test assertion in unsupported pattern: half-inlined loop swizzle. -TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 32}); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = add(tv1, IrBuilder::create(1)); - auto tv3 = add(tv2, IrBuilder::create(1)); - - fusion.addOutput(tv3); - - //[O, 4, 4] - tv2->split(-1, 16); - tv2->split(-1, 4); - - //[O, 4, 4] - tv3->split(-1, 16); - tv3->split(-1, 4); - - // Swizzle inner tile of tv2 - tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); - - // Make tv2 swizzled and half-inlined (unsupported). - tv0->computeAt(tv3, -2); - - FusionExecutor fe; - ASSERT_ANY_THROW(fe.compileFusion(&fusion)); -} - -TEST_F(NVFuserTest, FusionUnsqueeze1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape({10, 11}); - - auto tv0 = makeConcreteTensor(shape); - fusion.addInput(tv0); - - // [I, R] - auto tv1 = sum(tv0, {1}); - // [I, B] - auto tv2 = unsqueeze(tv1, -1); - fusion.addOutput(tv2); - - TORCH_CHECK( - tv2->nDims() == 2, "Unpected unsqueeze result: ", tv2->toString()); - TORCH_CHECK( - tv2->axis(1)->isBroadcast(), - "Unexpected unsqueeze result: ", - tv2->toString()); - - // tv1 has only one non-reduction axis. An exception should be - // thrown. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(unsqueeze(tv1, 2)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10, 11}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto ref = t0.sum(1).unsqueeze(-1); - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSqueeze1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - std::vector shape({10, 11}); - - auto tv0 = makeConcreteTensor(shape); - fusion.addInput(tv0); - - // [I, B] - auto tv1 = sum(tv0, {1}, true); - // [I] - auto tv2 = squeeze(tv1, {shape[0], 1}); - fusion.addOutput(tv2); - - TORCH_CHECK( - tv2->nDims() == 2, "Unexpected squeeze result: ", tv2->toString()); - - // [I, R] - auto tv3 = sum(tv0, {1}); - // tv3 has only one non-reduction axis. The extent of the first axis - // is not one, so squeeze should fail. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(squeeze(tv3, {shape[0], 1})); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10, 11}, options); - std::vector aten_inputs = {t0}; - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs); - auto cg_outputs = fe.runFusion(aten_inputs); - - auto ref = t0.sum(1, true).squeeze(-1); - - testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionContigPredicate_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = broadcast(tv1, {false, true, false}); - fusion.addOutput(tv2); - - tv2->merge(-2, -1); - tv2->merge(-2, -1); - tv2->split(-1, 100); - tv0->computeAt(tv2, -1); - - GpuLower gpulw(&fusion); - TORCH_CHECK(PredicatedChecker::isPredicated(tv1, gpulw)); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({3, 4}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - auto ref = t0.unsqueeze(1); - - testValidate(fe.kernel(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Repro of https://github.com/csarofeen/pytorch/issues/1777 -TEST_F(NVFuserTest, FusionDivScalarLhs_CUDA) { - // tv1 = 2.0 / tv0 - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - TensorView* tv1 = div(IrBuilder::create(2.0), tv0); - fusion.addOutput(tv1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({3, 3}, options); - // There's no overload div(Scalar, Tensor) in ATen - auto aten_output = at::div( - at::native::wrapped_scalar_tensor(at::Scalar(2.0), options.device()), t0); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - - testValidate(&fusion, cg_outputs, {t0}, {aten_output}, __LINE__, __FILE__); -} - -// Repro of an issue of the reduction scheduler with a broadcast -// domain concretized to multiple domains that are not proven to have -// the same extent -TEST_F(NVFuserTest, FusionRepro1713_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(2); - auto tv1 = makeSymbolicTensor(2); - auto tv2 = makeSymbolicTensor(1); - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - auto tv3 = broadcast(tv2, {false, true}); - - auto tv4 = add(tv3, tv0); - - auto tv5 = add(tv3, tv1); - auto tv6 = sum(tv5, {0}); - fusion->addOutput(tv4); - fusion->addOutput(tv6); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({1024, 204800}, options); - // Original repro had the same shape as t0, but this should work - // with a different extent at the second axis - at::Tensor t1 = at::randn({1024, 123}, options); - at::Tensor t2 = at::randn({1024}, options); - std::vector aten_inputs({t0, t1, t2}); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto t3 = t2.unsqueeze(-1); - auto t4 = t3 + t0; - auto t5 = t3 + t1; - auto t6 = sum(t5, {0}); - - testValidate( - executor_cache.fusion(), - cg_outputs, - {t0, t1, t2}, - {t4, t6}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionExpand_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto w = 2, x = 3, y = 4, z = 5; - - // Test - // a simple expand - // Expand that's propagated - // expand_as - // symbolic expand - - // x - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {false, true}); - auto tv2 = expand(tv1, {tv0->axis(0)->extent(), IrBuilder::create(y)}); - - // x - auto tv3 = makeSymbolicTensor(1); - fusion->addInput(tv3); - auto tv4 = broadcast(tv3, {false, true}); - auto tv5 = add(tv4, tv2); - // [x, e_y] - - // [x, y, z] - auto tv6 = makeSymbolicTensor(3); - fusion->addInput(tv6); - - // Disjoint set op will cause a segmentation for just this op. - auto tmp_7 = set(tv6); - fusion->addOutput(tmp_7); - - auto tv7 = broadcast(tv5, {false, false, true}); - - auto tv8 = expand_as(tv7, tv6); - // [x, e_y, e_z] - - auto w_symbolic = IrBuilder::create(); - fusion->addInput(w_symbolic); - - auto tv9 = broadcast(tv8, {true, false, false, false}); - //[1, x, e_y, e_z] - - auto tv10 = expand( - tv9, - {w_symbolic, - tv9->axis(1)->extent(), - tv9->axis(2)->expandedExtent(), - tv9->axis(3)->expandedExtent()}); - - fusion->addOutput(tv10); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({x}, options); - at::Tensor t3 = at::randn({x}, options); - at::Tensor t6 = at::randn({x, y, z}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3, t6, w}); - auto cg_out = cg_outputs[1]; - - TORCH_INTERNAL_ASSERT(cg_out.size(0) == w); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(2) == y); - TORCH_INTERNAL_ASSERT(cg_out.size(3) == z); - TORCH_INTERNAL_ASSERT(cg_out.stride(0) == 0); - TORCH_INTERNAL_ASSERT(cg_out.stride(1) == 1); - TORCH_INTERNAL_ASSERT(cg_out.stride(2) == 0); - TORCH_INTERNAL_ASSERT(cg_out.stride(3) == 0); - - auto t10 = t0.unsqueeze(-1) - .expand({x, y}) - .add(t3.unsqueeze(-1)) - .unsqueeze(-1) - .expand_as(t6) - .unsqueeze(0) - .expand({w, x, y, z}); - - testValidate( - executor_cache.fusion(), - cg_outputs, - {t0, t3, t6, w}, - {t6, t10}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionExpandIssue1751_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto x = 3, y = 4, z = 5; - - // y, z - auto tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false, false}); - - // Two ways to propagate extents as is: use -1 or explicitly pass - // the extent vals. - - auto tv2 = expand( - tv1, - {IrBuilder::create(x), - IrBuilder::create(-1), - IrBuilder::create(-1)}); - - auto tv3 = expand( - tv1, - {IrBuilder::create(x), - tv0->axis(0)->extent(), - tv0->axis(1)->extent()}); - - fusion->addOutput(tv2); - fusion->addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y, z}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - for (const auto& cg_out : cg_outputs) { - TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); - TORCH_INTERNAL_ASSERT(cg_out.size(2) == z); - } - - auto t2 = t0.expand({x, y, z}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t2, t2}, __LINE__, __FILE__); -} - -// TODO: Make sure the kernel uses the expanded concrete size instead -// of the symbolic size -TEST_F(NVFuserTest, FusionExpandToConcrete_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto x = 3, y = 4; - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = broadcast(tv0, {true, false}); - - auto tv2 = - expand(tv1, {IrBuilder::create(x), IrBuilder::create(y)}); - - fusion->addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({y}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - for (const auto& cg_out : cg_outputs) { - TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); - TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); - } - - auto t2 = t0.expand({x, y}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {t2}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({4, 32, 16, 112, 112}, options).transpose(-1, -2); - at::Tensor t1 = at::randn({32, 1, 112, 1}, options).transpose(-1, -2); - - auto tv0 = TensorViewBuilder() - .ndims(5) - .contiguity({true, true, false, false, false}) // ttfff - .shape({-1, -1, -1, -1, -1}) - .dtype(DataType::Half) - .build(); - auto tv1 = TensorViewBuilder() - .ndims(4) - .contiguity({true, false, false, true}) // tfft - .shape({-1, 1, 1, -1}) - .dtype(DataType::Half) - .build(); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = add(tv0, tv1); - - fusion->addOutput(tv2); - - std::vector aten_inputs({t0, t1}); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto t2 = t0 + t1; - - testValidate( - executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); -} - -namespace { - -// check that the resulting sibling are identical -void checkSiblingConsistency(TensorView* replay, TensorView* target) { - auto replay_root = replay->getRootDomain(); - auto replay_dom = replay->domain()->domain(); - auto target_root = target->getRootDomain(); - auto target_dom = target->domain()->domain(); - std::unordered_map target2replay_map; - TORCH_CHECK(replay_root.size() == target_root.size()); - target2replay_map.reserve(replay_root.size()); - std::transform( - target_root.begin(), - target_root.end(), - replay_root.begin(), - std::inserter(target2replay_map, target2replay_map.begin()), - [](auto a, auto b) { return std::make_pair(a, b); }); - BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); - auto r = replay_.getReplay(); - for (int64_t i = 0; i < replay_dom.size(); i++) { - auto target_id = target_dom[i]; - auto replay_it = r.find(target_id); - TORCH_CHECK(replay_it != r.end()); - TORCH_CHECK( - replay_it->second == replay_dom[i], - "IterDomain mismatch when checking ", - replay, - " and ", - target, - " at ", - i, - ", got ", - replay_it->second, - " and ", - replay_dom[i]); - } -}; - -} // namespace - -TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { - // https://github.com/csarofeen/pytorch/issues/1760 - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tvs = Welford(tv0, {1}); - fusion.addOutput(tvs.var_sum); - - tvs.avg->split(1, 1); - tvs.avg->split(1, 2); - tvs.avg->split(1, 3); - tvs.var_sum->split(1, 1); - tvs.var_sum->split(1, 2); - tvs.var_sum->split(1, 3); - tvs.n->split(1, 1); - tvs.n->split(1, 2); - tvs.n->split(1, 3); - - auto var_sum_rf = ir_utils::rfactorHelper(tvs.var_sum, {1, 4}); - - TransformPropagatorWithCheck propagator(var_sum_rf); - MaxRootDomainInfoSpanningTree(var_sum_rf).traverse(&propagator); - - auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); - - std::vector siblings[] = {{tvs.avg, tvs.var_sum, tvs.n}, rf_tvs}; - for (auto tensors : siblings) { - for (auto t1 : tensors) { - for (auto t2 : tensors) { - TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); - } - } - } -} - -TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tvs = Welford(tv0, {1}); - fusion.addOutput(tvs.var_sum); - - tvs.avg->split(1, 1); - tvs.avg->split(1, 2); - tvs.avg->split(1, 3); - tvs.var_sum->split(1, 1); - tvs.var_sum->split(1, 2); - tvs.var_sum->split(1, 3); - tvs.n->split(1, 1); - tvs.n->split(1, 2); - tvs.n->split(1, 3); - - auto var_sum_rf = ir_utils::rfactorHelper(tvs.var_sum, {1, 4}); - - struct DisableTv0 : public MaxInfoSpanningTree::Selector { - TensorView* tv0; - virtual bool allowC2P(TensorView* from, TensorView* to) override { - return from != tv0 && to != tv0; - }; - virtual bool allowP2C(TensorView* from, TensorView* to) override { - return from != tv0 && to != tv0; - }; - virtual bool allowSibling(TensorView* from, TensorView* to) override { - return true; - } - DisableTv0(TensorView* tv0) : tv0(tv0) {} - } selector1(tv0); - - struct DisableTv0AndSibling : public DisableTv0 { - virtual bool allowSibling(TensorView* from, TensorView* to) override { - return false; - } - using DisableTv0::DisableTv0; - } selector2(tv0); - - TransformPropagatorWithCheck propagator(var_sum_rf); - MaxRootDomainInfoSpanningTree good_path(var_sum_rf, &selector1); - MaxRootDomainInfoSpanningTree bad_path(var_sum_rf, &selector2); - - auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); - - auto check = [&]() { - std::vector siblings[] = { - {tvs.avg, tvs.var_sum, tvs.n}, rf_tvs}; - for (auto tensors : siblings) { - for (auto t1 : tensors) { - for (auto t2 : tensors) { - TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); - } - } - } - }; - - bad_path.traverse(&propagator); - ASSERT_ANY_THROW(check()); - good_path.traverse(&propagator); - check(); -} - -TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(4); - auto tv1 = makeSymbolicTensor(6); - fusion.addInput(tv0); - - auto tv2 = broadcast(tv0, {false, false, true, false, false, true}); - auto tv3 = add(tv1, tv2); - fusion.addOutput(tv3); - - tv0->merge(2); - tv0->merge(0); - TransformPropagatorWithCheck propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); - - TORCH_CHECK(tv1->nDims() == 4); -} - -TEST_F(NVFuserTest, FusionIgnoreZeroDimReduction_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = sum(tv0, {0}); - // tv1 is effectively a zero-dim tensor as it only has a reduction - // axis. - // Reducing it further is converted to just a set op. - auto tv2 = sum(tv1, {0}); - fusion->addOutput(tv2); - - auto tv2_def = dynamic_cast(tv2->definition()); - TORCH_CHECK( - tv2_def != nullptr, - "Expected UnaryOp but found ", - tv2->definition()->toString()); - - TORCH_CHECK( - tv2_def->getUnaryOpType() == UnaryOpType::Set, - "Expected UnaryOpType::Set but found ", - tv2_def->getUnaryOpType()); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({12345}, options); - std::vector aten_inputs({t0}); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto ref = sum(t0, {0}); - - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); -} - -// Repro of issue #1770 -TEST_F(NVFuserTest, FusionIssue1770Repro_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion->addInput(tv1); - - auto tv2 = ge(tv0, tv1); - auto tv3 = - where(tv2, IrBuilder::create(1), IrBuilder::create(2)); - fusion->addOutput(tv3); - - std::vector shape({999}); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn(shape, options); - at::Tensor t1 = at::randn(shape, options); - std::vector aten_inputs({t0, t1}); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto ref = where(t0 >= t1, 1.0, 2.0); - - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {ref}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionTransformPropagatorSelector_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = makeSymbolicTensor(1); - fusion->addInput(tv1); - - auto tv2 = add(tv0, tv1); - - auto tv3 = sin(tv2); - auto tv4 = cos(tv2); - - fusion->addOutput(tv3); - fusion->addOutput(tv4); - - tv2->split(0, 10); - - struct Selector : public MaxInfoSpanningTree::Selector { - TensorView* tv0; - TensorView* tv3; - virtual bool allowC2P(TensorView* from, TensorView* to) override { - return to == tv0; - } - virtual bool allowP2C(TensorView* from, TensorView* to) override { - return to == tv3; - } - virtual bool allowSibling(TensorView* from, TensorView* to) override { - return false; - } - Selector(TensorView* tv0, TensorView* tv3) : tv0(tv0), tv3(tv3) {} - } selector(tv0, tv3); - - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2, &selector).traverse(&propagator); - - TORCH_CHECK(tv0->nDims() == 2); - TORCH_CHECK(tv1->nDims() == 1); - TORCH_CHECK(tv2->nDims() == 2); - TORCH_CHECK(tv3->nDims() == 2); - TORCH_CHECK(tv4->nDims() == 1); -} - -TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeConcreteTensor({22, 105}); - fusion->addInput(tv0); - - auto tv1 = sin(tv0); - fusion->addOutput(tv1); - - tv1->split(0, 2); - tv1->split(-1, 3); - tv1->split(-1, 5); - - TransformPropagatorWithCheck propagator(tv1, 2); - MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator); - - auto expect = makeConcreteTensor({22, 105}); - expect->split(0, 2); - TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0)); -} - -TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(3); - fusion->addInput(tv0); - - auto tv1 = sum(tv0, {0}); - auto tv2 = neg(tv1); - - fusion->addOutput(tv2); - - tv1->split(0, 10); - - struct Printer : public MaxInfoSpanningTree::Propagator { - std::stringstream ss; - virtual void propagateC2P(TensorView* from, TensorView* to) override { - ss << "propagateC2P" << std::endl; - ss << "from: " << from->name() << std::endl; - ss << "to: " << to->name() << std::endl; - } - virtual void propagateP2C(TensorView* from, TensorView* to) override { - ss << "propagateP2C" << std::endl; - ss << "from: " << from->name() << std::endl; - ss << "to: " << to->name() << std::endl; - } - virtual void propagateSibling(TensorView* from, TensorView* to) override { - ss << "propagateSibling" << std::endl; - ss << "from: " << from->name() << std::endl; - ss << "to: " << to->name() << std::endl; - } - } printer1, printer2; - printer1.ss << std::endl; - printer2.ss << std::endl; - - MaxRootDomainInfoSpanningTree path(tv1); - path.traverse(&printer1); - path.traverse(&printer2); - - auto expect = R"ESCAPE( -propagateC2P -from: 1 -to: 0 -propagateP2C -from: 1 -to: 2 -)ESCAPE"; - TORCH_CHECK(printer1.ss.str() == expect); - TORCH_CHECK(printer2.ss.str() == expect); -} - -TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - auto tv1 = broadcast(tv0, {true, false, true}); - auto tv2 = sin(tv1); - fusion->addOutput(tv2); - - tv0->split(0, 2); - tv2->split(1, 2); - tv2->split(0, 4); - - MaxRootDomainInfoSpanningTree path1(tv2); - TransformPropagatorWithCheck propagator1(tv2); - path1.traverse(&propagator1); - - MaxRootDomainInfoSpanningTree path2(tv0); - TransformPropagatorWithCheck propagator2(tv0); - path2.traverse(&propagator2); - - TORCH_CHECK(tv1->axis(0)->isBroadcast()); - TORCH_CHECK(tv1->axis(1)->isBroadcast()); - TORCH_CHECK(!tv1->axis(2)->isBroadcast()); - TORCH_CHECK(!tv1->axis(3)->isBroadcast()); - TORCH_CHECK(tv1->axis(4)->isBroadcast()); - - auto expect = makeSymbolicTensor(3); - expect->split(1, 2); - expect->split(0, 4); - TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1)); -} - -TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeContigTensor(1); - TensorView* tv1 = makeContigTensor(2); - - // Register your inputs - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = set(tv0); - // [B, I] - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - auto tv5 = set(tv4); - - // Register your outputs - fusion.addOutput(tv5); - - tv5->split(0, 8); - tv5->split(-1, 8); - - // [Serial, TIDy, TIDX, Serial] - - tv4->computeAt(tv5, -2); - tv3->computeAt(tv4, -1); - tv2->computeAt(tv3, 0); - tv2->split(0, 8); - tv2->axis(0)->parallelize(ParallelType::TIDx); - tv1->computeAt(tv5, -2); - - tv5->axis(1)->parallelize(ParallelType::TIDy); - tv5->axis(2)->parallelize(ParallelType::TIDx); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor in1 = at::randn({16}, options); - at::Tensor in2 = at::randn({12, 16}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {in1, in2}); - auto cg_outputs = fe.runFusion({in1, in2}); - - auto tv_ref = in1 + in2; - - testValidate(&fusion, cg_outputs, {in1, in2}, {tv_ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { - { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(1); - TensorView* tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - - auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); - - tv3->split(1, 2, false); - - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - } - - { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(3); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0, 2}); - auto tv2 = sin(tv1); - fusion.addOutput(tv2); - - tv0->split(1, 2, false); - - TransformPropagatorWithCheck propagator(tv0); - MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); - } -} - -TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(2); - - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tvs = Welford(tv1, {1}); - auto tvo = set(tvs.var_sum); - fusion.addOutput(tvo); - - tvo->split(0, 16); - tvo->axis(1)->parallelize(ParallelType::Unroll); - - tv0->computeAt(tvo, -1, ComputeAtMode::BestEffort); - - TORCH_CHECK( - tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition()); - TORCH_CHECK( - tvs.var_sum->getComputeAtPosition() == tvs.n->getComputeAtPosition()); - TORCH_CHECK(tvs.var_sum->getComputeAtPosition() == 1); -} - -// Unit test for the transform selection logic -TEST_F(NVFuserTest, FusionBoundedDirectionSelection1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - TensorView* tv0 = makeContigTensor(2); - - fusion.addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = set(tv1); - auto tv3 = add(tv2, tv1); - fusion.addOutput(tv3); - - tv3->split(-1, 5); - tv3->split(-1, 8); - - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - tv3, -1, {tv0, tv2}); - - // Check that the splits are replayed on tv1, even though tv2 - // is part of the boundary. - TORCH_INTERNAL_ASSERT( - tv2->nDims() == 4, "Propagator didn't propagate to tv2"); -} - -TEST_F(NVFuserTest, FusionIssueRepro1844_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - std::vector shape = {2, 1, 768}; - std::vector sum_to_shape = {768}; - std::vector sum_to_axes = {0, 1}; - double kProb = 0.5; - - std::vector sum_to_symb; - std::transform( - sum_to_shape.begin(), - sum_to_shape.end(), - std::back_inserter(sum_to_symb), - [](int s) -> Int* { return IrBuilder::create(s); }); - - TensorView* tv0 = makeContigConcreteTensor(shape); - TensorView* tv1 = makeContigConcreteTensor(shape); - TensorView* tv2 = makeContigConcreteTensor(shape, DataType::Bool); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - - Double* prob = IrBuilder::create(kProb); - auto grad_input = dropout_backward(tv1, tv2, prob); - auto grad_gelu = gelu_backward(grad_input, tv0); - auto grad_bias = sum_to(grad_gelu, sum_to_symb); - - fusion->addOutput(grad_gelu); - fusion->addOutput(grad_bias); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - const auto mask_options = - at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0); - at::manual_seed(0); - - at::Tensor a = at::randn(shape, options); - at::Tensor b = at::randn(shape, options); - at::Tensor c = at::randn(shape, options); - auto mask = at::gt(c, 0.0f); - std::vector aten_inputs = {a, b, mask}; - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); - - auto dinput = at::native_dropout_backward(b, mask, kProb); - auto dgelu = at::gelu_backward(dinput, a, "none"); - auto dbias = dgelu.sum(sum_to_axes); - - testValidate( - executor_cache.fusion(), - cg_outputs, - aten_inputs, - {dgelu, dbias}, - __LINE__, - __FILE__); -} - -TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = add(tv0, IrBuilder::create(1)); - auto tv2 = set(tv1); - fusion.addOutput(tv2); - - tv2->split(0, 32); - tv2->split(-1, 2); - tv2->reorder({{1, 2}, {2, 1}}); - tv2->merge(0); - - TransformPropagatorWithCheck propagator(tv2); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - tv0->computeAt(tv2, 1); - - // The predicate of tv2 should be protected with magic zero - GpuLower gpulw(&fusion); - TORCH_CHECK( - PredicateMagicZeroChecker::isProtected(tv2, gpulw), - "Failed to protect the predicates of ", - tv2->toString()); -} - -TEST_F(NVFuserTest, FusionRepro1860_CUDA) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr; - FusionGuard fg(&fusion); - std::vector contiguity{true, false, false}; - - std::vector shape{1, -1, -1}; - TensorView* tv0 = makeContigConcreteTensor(shape); - fusion.addInput(tv0); - TensorView* tv1 = makeContigConcreteTensor(shape); - fusion.addInput(tv1); - TensorView* tv2 = makeContigConcreteTensor(shape); - fusion.addInput(tv2); - - std::vector domain1(3, nullptr); - for (const auto i : c10::irange(3)) { - if (i == 0) { - domain1[i] = - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) - .iter_type(IterType::Broadcast) - .build(); - } else { - domain1[i] = - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) - .expanded_extent(IrBuilder::create(1 + i)) - .iter_type(IterType::Broadcast) - .build(); - } - } - - TensorView* tv22 = IrBuilder::create( - IrBuilder::create(domain1, contiguity), DataType::Float); - - fusion.addInput(tv22); - - auto tv3 = add(tv0, tv1); - auto tv4 = softmax(tv3, 0); - auto tv5 = add(tv4, tv22); - fusion.addOutput(tv5); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor input1 = at::randn({1, 2, 3}, options); - at::Tensor input2 = at::randn({1, 2, 3}, options); - at::Tensor input3 = at::randn({1, 2, 3}, options); - at::Tensor input4 = at::randn({1, 1, 1}, options).expand({1, 2, 3}); - std::vector aten_inputs = {input1, input2, input3, input4}; - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs(aten_inputs); -} - -TEST_F(NVFuserTest, FusionExpandReduce_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeConcreteTensor({1, 8}); - fusion->addInput(tv0); - - auto tv1 = - expand(tv0, {IrBuilder::create(12), IrBuilder::create(8)}); - - auto tv2 = sum(tv1, {0}); - fusion->addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1, 8}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - auto ref = t0.expand({12, 8}).sum({0}); - - testValidate( - executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); -} - -// Predicate elimination issue repro: -TEST_F(NVFuserTest, FusionExpandReduce2_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeConcreteTensor({1, 4}); - fusion->addInput(tv0); - - auto tv1 = - expand(tv0, {IrBuilder::create(3), IrBuilder::create(4)}); - - auto tv2 = sum(tv1, {0}); - fusion->addOutput(tv2); - - // tv2[r{3}, i{4}] - tv2->split(0, NamedScalar::getParallelDim(ParallelType::TIDy)); - tv2->axis(1)->parallelize(ParallelType::TIDy); - tv2->split(0, NamedScalar::getParallelDim(ParallelType::BIDy), false); - tv2->axis(0)->parallelize(ParallelType::BIDy); - tv2->split(-1, NamedScalar::getParallelDim(ParallelType::TIDx)); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - tv2->axis(-2)->parallelize(ParallelType::BIDx); - // [rBIDy, rO, rTIDy, iBIDx, iTIDx] - tv2->reorder({{-2, 0}, {-1, 1}, {2, 2}}); - // [iBIDx, iTIDx, rTIDy, rBIDy, rO] - auto tv3 = tv2->rFactor({-1}); - - TransformPropagatorWithCheck propagator(tv3); - MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); - scheduler_utils::parallelizeAllLike(tv3); - tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::manual_seed(0); - auto t0 = at::randn({1, 4}, options); - - FusionExecutor fe; - fe.compileFusion(fusion.get(), {t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); - auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); - - auto ref = t0.expand({3, 4}).sum({0}); - - testValidate( - fusion.get(), - cg_outputs, - {t0}, - {ref}, - __LINE__, - __FILE__, - "", - LaunchParams(-1, 2, -1, 4, 2, 1)); -} - -TEST_F(NVFuserTest, FusionExpandBadShapeTest_CUDA) { - auto fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr; - FusionGuard fg(&fusion); - std::vector contiguity{false, false}; - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - std::vector domains = { - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()) - .build(), - IterDomainBuilder( - FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) - .expanded_extent(IrBuilder::create(10)) - .iter_type(IterType::Broadcast) - .build()}; - - // expand to 10 - TensorView* tv22 = IrBuilder::create( - IrBuilder::create(domains, contiguity), DataType::Float); - - fusion.addInput(tv22); - - auto tv3 = add(tv0, tv22); - fusion.addOutput(tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - // Incompatible shapes - at::Tensor input1 = at::randn({2, 3}, options); - // Passing expand size of 5, not 10. Should cause an error - at::Tensor input4 = at::randn({2, 1}, options).expand({2, 5}); - - std::vector aten_inputs = {input1, input4}; - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - ASSERT_ANY_THROW(executor_cache.runFusionWithInputs(aten_inputs)); -} - -TEST_F( - NVFuserTest, - FusionPointwiseScheduleWithBroadcastAndTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(3); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = broadcast(tv0, {false, true, false, true, false, true}); - auto tv3 = sin(tv2); - auto tv4 = add(tv3, tv1); - auto tv5 = sum(tv4, {1}); - fusion.addOutput(tv5); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({100, 100, 10}, options); - at::Tensor t1 = at::randn({10, 20}, options); - - auto aten_output = (t0.view({100, 1, 100, 1, 10, 1}).sin() + t1).squeeze(1); - - std::vector aten_inputs = {t0, t1}; - - auto lparams = schedulePointwise(&fusion, aten_inputs); - - FusionExecutor fe; - fe.compileFusion(&fusion, aten_inputs, lparams); - auto cg_outputs = fe.runFusion(aten_inputs, lparams); - - testValidate( - &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - auto tv2 = cos(tv1); - auto tv3 = transpose(tv2, 1, 2); - auto tv4 = exp(tv3); - auto tv5 = tan(tv4); - fusion.addOutput(tv5); - - InlinePropagator inline_propagator(tv5, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv5).traverse(&inline_propagator); - - TORCH_CHECK(tv5->getComputeAtPosition() == 3); - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 3); - TORCH_CHECK(tv2->getComputeAtPosition() == 1); - TORCH_CHECK(tv1->getComputeAtPosition() == 3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().cos().transpose(1, 2).exp().tan(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims2_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - auto tv2 = cos(tv1); - auto tv3 = transpose(tv2, 1, 2); - auto tv4 = exp(tv3); - auto tv5 = tan(tv4); - fusion.addOutput(tv5); - - InlinePropagator inline_propagator(tv5, -1, ComputeAtMode::BestEffort); - MaxRootDomainInfoSpanningTree(tv5).traverse(&inline_propagator); - - TORCH_CHECK(tv5->getComputeAtPosition() == 3); - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 3); - TORCH_CHECK(tv2->getComputeAtPosition() == 1); - TORCH_CHECK(tv1->getComputeAtPosition() == 1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().cos().transpose(1, 2).exp().tan(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims3_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - // broadcasting - auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); - auto tv3 = relu(tv2); - // trivial reduction - auto tv4 = sum(tv3, {1, 3, 5}); - auto tv5 = cos(tv4); - auto tv6 = transpose(tv5, 1, 2); - auto tv7 = exp(tv6); - auto tv8 = tan(tv7); - fusion.addOutput(tv8); - - for (auto tv : {tv2, tv3, tv4}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - InlinePropagator inline_propagator(tv8, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv8).traverse(&inline_propagator); - - TORCH_CHECK(tv8->getComputeAtPosition() == 3); - TORCH_CHECK(tv7->getComputeAtPosition() == 3); - TORCH_CHECK(tv6->getComputeAtPosition() == 3); - TORCH_CHECK(tv5->getComputeAtPosition() == 1); - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 3); - TORCH_CHECK(tv2->getComputeAtPosition() == 3); - TORCH_CHECK(tv1->getComputeAtPosition() == 3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().relu().cos().transpose(1, 2).exp().tan(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims4_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - auto tv2 = exp(tv1); - auto tv3 = relu(tv2); - auto tv4 = cos(tv3); - auto tv5 = tan(tv4); - fusion.addOutput(tv5); - - tv3->merge(1); - InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator); - - TORCH_CHECK(tv5->getComputeAtPosition() == 3); - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 1); - TORCH_CHECK(tv2->getComputeAtPosition() == 1); - TORCH_CHECK(tv1->getComputeAtPosition() == 3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().exp().relu().cos().tan(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - // broadcasting - auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); - auto tv3 = cos(tv2); - auto tv4 = tan(tv3); - fusion.addOutput(tv4); - - for (auto tv : {tv2, tv3, tv4}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - InlinePropagator inline_propagator(tv0, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv0).traverse(&inline_propagator); - - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 3); - TORCH_CHECK(tv2->getComputeAtPosition() == 3); - TORCH_CHECK(tv1->getComputeAtPosition() == 3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().view({2, 1, 3, 1, 4, 1}).cos().tan(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionInlinePropagatorBroadcastTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = sin(tv0); - // broadcasting - auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); - auto tv3 = tan(tv2); - // trivial reduction - auto tv4 = sum(tv3, {1, 3, 5}); - auto tv5 = cos(tv4); - auto tv6 = exp(tv5); - fusion.addOutput(tv6); - - for (auto tv : {tv2, tv3, tv4}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - InlinePropagator inline_propagator(tv6, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv6).traverse(&inline_propagator); - - TORCH_CHECK(tv6->getComputeAtPosition() == 3); - TORCH_CHECK(tv5->getComputeAtPosition() == 3); - TORCH_CHECK(tv4->getComputeAtPosition() == 3); - TORCH_CHECK(tv3->getComputeAtPosition() == 3); - TORCH_CHECK(tv2->getComputeAtPosition() == 3); - TORCH_CHECK(tv1->getComputeAtPosition() == 3); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input = at::randn({2, 3, 4}, options); - auto output = input.sin().tan().cos().exp(); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input}); - auto cg_outputs = fe.runFusion({input}); - - testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 1, 3, 1, 4, 1}); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {1, 3, 5}); - auto tv2 = sin(tv1); - fusion.addOutput(tv1); - - for (auto tv : {tv0, tv1}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3); -} - -TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = broadcast(tv0, {false, true, false, true, false, true}); - auto tv2 = sin(tv1); - fusion.addOutput(tv2); - - for (auto tv : {tv1, tv2}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3); - TORCH_CHECK( - TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3); -} - -TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeConcreteTensor({2, 3, 4}); - fusion.addInput(tv0); - auto tv1 = broadcast(tv0, {false, true, false, true, false, true}); - auto tv2 = sum(tv1, {1, 3, 5}); - auto tv3 = sin(tv2); - fusion.addOutput(tv3); - - for (auto tv : {tv1, tv2}) { - tv->merge(0); - tv->merge(1); - tv->merge(2); - } - - InlinePropagator inline_propagator(tv3, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv3).traverse(&inline_propagator); - - ComputeAtMap ca_map(&fusion); - - auto all_tvs = ir_utils::allTvs(&fusion); - for (auto tv1 : all_tvs) { - for (auto tv2 : all_tvs) { - if (tv1->isFusionInput() || tv2->isFusionInput()) { - continue; - } - for (int i : c10::irange(3)) { - auto id1 = tv1->axis(i); - auto id2 = tv2->axis(i); - TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::LOOP)); - TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE)); - } - } - } -} - -TEST_F(NVFuserTest, FusionPrint_CUDA) { - auto dtypes = { - at::kFloat, - at::kDouble, - at::kHalf, - at::kBFloat16, - at::kInt, - at::kLong, - at::kBool}; - for (auto dtype : dtypes) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - auto tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype)); - fusion->addInput(tv0); - auto tv1 = print(tv0); - auto tv2 = sin(tv1); - fusion->addOutput(tv2); - - // There is no way to check if anything is printed to the console, but we - // can validate that when print exist, compilation and computation are not - // broken. - auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); - at::Tensor t0 = at::arange(2, options).to(dtype); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs({t0}); - - testValidate( - executor_cache.fusion(), - cg_outputs, - {t0}, - {t0.sin()}, - __LINE__, - __FILE__); - } -} - -TEST_F(NVFuserTest, FusionCheckedSymbolicShape_CUDA) { - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor a = at::randn({123, 456}, options); - at::Tensor b = at::randn({123, 456}, options); - at::Tensor c = at::randn({321, 654}, options); - - using return_t = - std::pair, std::vector>; - auto matched_add = [](at::Tensor a, at::Tensor b) -> return_t { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - Val* s1 = IrBuilder::create(); - Val* s2 = IrBuilder::create(); - auto builder = TensorViewBuilder().shape(std::vector{s1, s2}); - TensorView* tv0 = builder.build(); - TensorView* tv1 = builder.build(); - - fusion->addInput(tv0); - fusion->addInput(tv1); - - auto tv2 = add(tv0, tv1); - - fusion->addOutput(tv2); - - auto executor_cache = - std::make_unique(std::move(fusion)); - auto cg_outputs = executor_cache->runFusionWithInputs({a, b}); - return {std::move(executor_cache), std::move(cg_outputs)}; - }; - - { - auto ret1 = matched_add(a, b); - testValidate( - ret1.first->fusion(), ret1.second, {a, b}, {a + b}, __LINE__, __FILE__); - } - - { - EXPECT_THAT( - [&]() { matched_add(a, c); }, - ::testing::ThrowsMessage( - ::testing::HasSubstr("Attempting to bind"))); - } -} - -TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - Val* s1 = IrBuilder::create(); - auto builder = TensorViewBuilder().shape(std::vector{s1}); - TensorView* tv0 = builder.build(); - - fusion->addInput(tv0); - - auto tv1 = add(tv0, s1); - - fusion->addOutput(tv1); - - const auto options = - at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor a = at::zeros({123}, options); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto cg_outputs = executor_cache.runFusionWithInputs({a}); - - testValidate( - executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__); -} - -// Repro for issue #1925 -TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(4); - auto tv1 = makeConcreteTensor({-1, -1, -1, 1}); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = add(tv0, tv1); - fusion.addOutput(tv2); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input0 = at::randn({1, 1, 333, 1}, options); - at::Tensor input1 = at::randn({1, 1, 333, 1}, options); - - auto lparams = scheduleTranspose(&fusion, {input0, input1}); - - FusionExecutor fe; - fe.compileFusion(&fusion, {input0, input1}, lparams); - auto outputs = fe.runFusion({input0, input1}, lparams); - - auto tv_ref = input0 + input1; - - testValidate( - &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__); -} - -// Repro for issue #1873 -TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->split(0, 32); - - tv0->computeAt(tv4, 1); - - tv2->split(-1, 8); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({123}, options); - at::Tensor t1 = at::randn({3, 123}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - auto outputs = fe.runFusion({t0, t1}); - - auto tv_ref = t0 + t1; - - testValidate(&fusion, outputs, {t0, t1}, {tv_ref}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { - // https://github.com/csarofeen/pytorch/issues/1926 - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeSymbolicTensor(2); - fusion->addInput(tv0); - auto tv1 = set(tv0); - auto tv2 = set(tv1); - fusion->addOutput(tv2); - - tv1->setMemoryType(MemoryType::Shared); - for (auto tv : {tv1, tv2}) { - tv->split(0, 4); - tv->reorder({{1, -1}}); - tv->split(1, 8); - tv->merge(0); - tv->split(0, 1); - tv->axis(0)->parallelize(ParallelType::BIDx); - tv->axis(1)->parallelize(ParallelType::Unswitch); - } - tv1->merge(2); - tv2->reorder({{2, 3}}); - tv2->merge(2); - for (auto tv : {tv1, tv2}) { - tv->axis(-1)->parallelize(ParallelType::TIDx); - } - - InlinePropagator propagator(tv2, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); - - auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({5, 5}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {t0}); - auto cg_outputs = fe.runFusion({t0}); - auto out = cg_outputs[0]; - - testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, AsyncCompilation_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - TensorView* tv0 = makeSymbolicTensor(2); - TensorView* tv1 = makeSymbolicTensor(1); - TensorView* tv2 = makeSymbolicTensor(2); - - fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addInput(tv2); - - TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 - TensorView* tv4 = - max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) - TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, - // keeps normalization scheduler away) - TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) - - fusion->addOutput(tv6); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - - at::Tensor t0 = at::randn({8, 5}, options); - at::Tensor t1 = at::randn({5}, options); - at::Tensor t2 = at::randn({8, 5}, options); - - auto t3 = t0.add(1.0); - auto t4 = std::get<0>(at::max(t3, 0)); - auto t5 = t4.add(t1); - auto t6 = t5.add(t2); - - FusionExecutorCache executor_cache(std::move(fusion)); - - std::vector aten_inputs = {t0, t1, t2}; - - executor_cache.compileFusionAsync(aten_inputs); - - while (!executor_cache.isCompiled(aten_inputs)) { - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - printf("."); - } - - auto outputs = executor_cache.runFusionWithInputs(aten_inputs); - - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime()->isSegmented(), - "segmentation didn't happen"); - TORCH_CHECK( - executor_cache.getMostRecentKernelRuntime() - ->fusionSegments() - ->groups() - .size() == 2, - "segmentation didn't happen as expected"); - - testValidate( - executor_cache.fusion(), outputs, aten_inputs, {t6}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeConcreteTensor({1, 1}); - TensorView* tv1 = makeConcreteTensor({-1}); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = sum(tv0, {1}); - auto tv3 = add(tv2, tv1); - fusion->addOutput(tv3); - - tv0->merge(0); - - MaxRootDomainInfoSpanningTree tree(tv0); - TransformPropagatorWithCheck tp(tv0); - tree.traverse(&tp); - - InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined); - tree.traverse(&ip); - - auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({1, 1}, options); - at::Tensor t1 = at::randn({10}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - auto out = cg_outputs[0]; - - testValidate( - fusion, {out}, {t0, t1}, {t1 + t0.flatten()}, __LINE__, __FILE__); -} - -TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - auto fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeConcreteTensor({-1, 1, 1}); - TensorView* tv1 = makeConcreteTensor({-1, -1}); - fusion->addInput(tv0); - fusion->addInput(tv1); - auto tv2 = sum(tv0, {1}); - auto tv3 = add(tv2, tv1); - fusion->addOutput(tv3); - - tv2->merge(1); - tv2->merge(0); - - MaxRootDomainInfoSpanningTree tree(tv0); - TransformPropagatorWithCheck tp(tv0); - tree.traverse(&tp); - - InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined); - tree.traverse(&ip); - - auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({10, 1, 1}, options); - at::Tensor t1 = at::randn({10, 10}, options); - - FusionExecutor fe; - fe.compileFusion(fusion, {t0, t1}); - auto cg_outputs = fe.runFusion({t0, t1}); - auto out = cg_outputs[0]; - - testValidate( - fusion, {out}, {t0, t1}, {t1 + t0.squeeze(-1)}, __LINE__, __FILE__); -} - -} // namespace jit -} // namespace torch -#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp new file mode 100644 index 0000000000000..2a14695b53ff2 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp @@ -0,0 +1,9985 @@ +#if defined(USE_CUDA) +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; +using namespace at::indexing; + +// A few smoke tests for IrGraphGenerator +// (These tests exercise IrGraphGenerator through a non-trivial IR, +// to make sure that it runs w/o crashing. The actual output is not +// validated) +TEST_F(NVFuserTest, FusionIrGraphGenerator_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Make sure we can handle empty IRs + TORCH_CHECK(!IrGraphGenerator::toGraphviz( + &fusion, IrGraphGenerator::DetailLevel::Basic) + .empty()); + + // Construct an interesting IR + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv2 = add(tv0, IrBuilder::create(3.141)); + TensorView* tv3 = broadcast(tv0, {false, true, false, true}); + TensorView* tv4 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv3); + TensorView* tv5 = clamp( + tv4, IrBuilder::create(0.f), IrBuilder::create(1.f)); + TensorView* tv6 = add(tv2, tv2); + + // Another checkpoint before adding outputs + TORCH_CHECK(!IrGraphGenerator::toGraphviz( + &fusion, IrGraphGenerator::DetailLevel::Explicit) + .empty()); + + fusion.addOutput(tv6); + + tv4->axis(2)->parallelize(ParallelType::BIDy); + tv6->merge(0); + tv6->split(0, 4); + tv6->axis(0)->parallelize(ParallelType::BIDx); + tv5->reorder({{-1, 0}}); + tv2->computeAt(tv6, 1); + + // Another checkpoint with more node types + TORCH_CHECK(!IrGraphGenerator::toGraphviz( + &fusion, IrGraphGenerator::DetailLevel::ComputeOnly) + .empty()); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + // Final IR graph + TORCH_CHECK(!IrGraphGenerator::toGraphviz( + &fusion, IrGraphGenerator::DetailLevel::Verbose) + .empty()); +} + +TEST_F(NVFuserTest, FusionDispatch_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Double* f = IrBuilder::create(2.f); + std::stringstream ss1, ss2, ss3; + ss1 << f; + ss2 << static_cast(f); + ss3 << static_cast(f); + TORCH_CHECK( + ss1.str().compare(ss2.str()) == 0 && ss1.str().compare(ss3.str()) == 0, + "Error with dispatch system where results differ by passing Double* vs Val* vs Statement*."); +} + +// Evaluate basic scalar operations with constant values +TEST_F(NVFuserTest, FusionExprEvalConstants_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + ExpressionEvaluator evaluator(&fusion); + + auto* a = IrBuilder::create(7); + auto* b = IrBuilder::create(3); + + // Avoid div operation because it casts int operands to float + checkIntValue(evaluator, neg(a), -7); + checkIntValue(evaluator, add(a, b), 10); + checkIntValue(evaluator, neg(mul(sub(a, b), add(a, b))), -40); + checkIntValue(evaluator, mod(a, b), 1); + checkIntValue(evaluator, ceilDiv(a, b), 3); +} + +TEST_F(NVFuserTest, FusionExprEvalDouble_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto ten = IrBuilder::create(10); + auto two = IrBuilder::create(2); + auto three = IrBuilder::create(3); + auto val = castOp(DataType::Int, ceilDiv(sub(ten, two), three)); + auto reference = static_cast(std::ceil((10.0 - 2.0) / 3.0)); + TORCH_CHECK(reference == val->evaluateInt()); +} + +// Evaluate basic scalar operations with bound values +TEST_F(NVFuserTest, FusionExprEvalBindings_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + ExpressionEvaluator evaluator(&fusion); + + auto* a = IrBuilder::create(); + auto* b = IrBuilder::create(); + auto* c = add(a, b); + auto* d = neg(ceilDiv(c, b)); + auto* e = IrBuilder::create(0); + + // trying to evaluate before binding should give empty results + TORCH_CHECK(!evaluator.evaluate(a).has_value()); + TORCH_CHECK(!evaluator.evaluate(d).has_value()); + + evaluator.bind(a, 7); + evaluator.bind(b, 3); + + // can't bind to the results of expressions + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(evaluator.bind(c, 100)); + + // can't bind to concrete values + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(evaluator.bind(e, 100)); + + checkIntValue(evaluator, c, 10); + checkIntValue(evaluator, sub(a, b), 4); + checkIntValue(evaluator, mod(a, b), 1); + checkIntValue(evaluator, ceilDiv(a, b), 3); + checkIntValue(evaluator, d, -4); + + // Reset evaluation context + evaluator = ExpressionEvaluator(&fusion); + + evaluator.bind(a, 2); + evaluator.bind(b, 5); + + checkIntValue(evaluator, c, 7); + checkIntValue(evaluator, sub(a, b), -3); + checkIntValue(evaluator, mod(a, b), 2); + checkIntValue(evaluator, ceilDiv(a, b), 1); + checkIntValue(evaluator, d, -2); +} + +// Evaluate expressions in a simple IR +TEST_F(NVFuserTest, FusionExprEvalBasic_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a non-trivial IR + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv3->split(0, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + // 1. Create an evaluator + ExpressionEvaluator evaluator(&fusion); + + // 2. Bind values + // + // IMPORTANT: + // a. The bindings are only as stable as the Vals are in the fusion graph + // b. You must use the original (rootDomain) extents + // (ex. `tv0->getRootDomain()[0]->extent()` + // instead of `tv0->axis(0)->extent()`) + // + evaluator.bind(tv0->getRootDomain()[0]->extent(), 6); + evaluator.bind(tv0->getRootDomain()[1]->extent(), 128); + evaluator.bind(tv1->getRootDomain()[0]->extent(), 6); + evaluator.bind(tv1->getRootDomain()[1]->extent(), 128); + + // 3. Evaluate and check result values + TORCH_CHECK(tv2->domain()->nDims() == 3); + checkIntValue(evaluator, tv2->axis(0)->extent(), 2); + checkIntValue(evaluator, tv2->axis(1)->extent(), 4); + checkIntValue(evaluator, tv2->axis(2)->extent(), 128); + + TORCH_CHECK(tv3->domain()->nDims() == 3); + checkIntValue(evaluator, tv3->axis(0)->extent(), 2); + checkIntValue(evaluator, tv3->axis(1)->extent(), 4); + checkIntValue(evaluator, tv3->axis(2)->extent(), 128); +} + +// Evaluate expressions in a more complex IR +TEST_F(NVFuserTest, FusionExprEvalComplex_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); + TensorView* tv4 = add(tv2, tv1); + TensorView* tv5 = add(tv4, tv3); + TensorView* tv6 = add(tv0, tv3); + + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + tv5->reorder({{-1, 0}}); + + tv6->split(0, 5); + tv5->merge(0); + + // 1. Create an evaluator + ExpressionEvaluator evaluator(&fusion); + + // 2. Bind values + evaluator.bind(tv0->getRootDomain()[0]->extent(), 129); + evaluator.bind(tv0->getRootDomain()[1]->extent(), 127); + + // Evaluate and check extent values + TORCH_CHECK(tv0->domain()->nDims() == 2); + checkIntValue(evaluator, tv0->axis(0)->extent(), 129); + checkIntValue(evaluator, tv0->axis(1)->extent(), 127); + + TORCH_CHECK(tv3->domain()->nDims() == 2); + checkIntValue(evaluator, tv3->axis(0)->extent(), 129); + checkIntValue(evaluator, tv3->axis(1)->extent(), 127); + + TORCH_CHECK(tv4->domain()->nDims() == 2); + checkIntValue(evaluator, tv4->axis(0)->extent(), 129); + checkIntValue(evaluator, tv4->axis(1)->extent(), 127); + + TORCH_CHECK(tv5->domain()->nDims() == 1); + checkIntValue(evaluator, tv5->axis(0)->extent(), 16383); + + TORCH_CHECK(tv6->domain()->nDims() == 3); + checkIntValue(evaluator, tv6->axis(0)->extent(), 26); + checkIntValue(evaluator, tv6->axis(1)->extent(), 5); + checkIntValue(evaluator, tv6->axis(2)->extent(), 127); +} + +// Evaluate expressions post lowering +TEST_F(NVFuserTest, FusionExprEvalPostLower_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a non-trivial IR + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv3->split(0, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto* bid_x = add(tv3->axis(0)->extent(), IrBuilder::create(0)); + auto* tid_x = add(tv3->axis(-1)->extent(), IrBuilder::create(0)); + + // Lower + GpuLower gpulw(&fusion); + + // 1. Create an evaluation context + ExpressionEvaluator evaluator(&fusion); + + // 2. Bind values + evaluator.bind(tv0->getRootDomain()[0]->extent(), 6); + evaluator.bind(tv0->getRootDomain()[1]->extent(), 128); + evaluator.bind(tv1->getRootDomain()[0]->extent(), 6); + evaluator.bind(tv1->getRootDomain()[1]->extent(), 128); + + // 3. Evaluate and check result values + TORCH_CHECK(tv2->domain()->nDims() == 3); + checkIntValue(evaluator, tv2->axis(0)->extent(), 2); + checkIntValue(evaluator, tv2->axis(1)->extent(), 4); + checkIntValue(evaluator, tv2->axis(2)->extent(), 128); + + TORCH_CHECK(tv3->domain()->nDims() == 3); + checkIntValue(evaluator, tv3->axis(0)->extent(), 2); + checkIntValue(evaluator, tv3->axis(1)->extent(), 4); + checkIntValue(evaluator, tv3->axis(2)->extent(), 128); + + checkIntValue(evaluator, bid_x, 2); + checkIntValue(evaluator, tid_x, 128); +} + +// Kernel IR: Evaluate basic scalar operations with constant values +TEST_F(NVFuserTest, FusionKernelExprEvalConstants_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); + + auto a = IrBuilder::create(7); + auto b = IrBuilder::create(3); + auto c = IrBuilder::subExpr(a, b); + auto d = IrBuilder::divExpr(a, b); + auto e = IrBuilder::mulExpr(c, d); + + kir::ExpressionEvaluator evaluator; + + checkIntValue(evaluator, IrBuilder::negExpr(a), -7); + checkIntValue(evaluator, IrBuilder::addExpr(a, b), 10); + checkIntValue(evaluator, IrBuilder::negExpr(e), -8); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); +} + +// Kernel IR: Evaluate basic scalar operations with bound values +TEST_F(NVFuserTest, FusionKernelExprEvalBindings_CUDA) { + Fusion fusion; + kir::Kernel kernel(&fusion); + FusionGuard fg((&kernel)->as()); + + kir::ExpressionEvaluator evaluator; + + auto a = IrBuilder::create(c10::nullopt); + auto b = IrBuilder::create(c10::nullopt); + auto c = IrBuilder::addExpr(a, b); + auto d = IrBuilder::negExpr(IrBuilder::ceilDivExpr(c, b)); + auto e = IrBuilder::create(0); + + // trying to evaluate before binding should give empty results + TORCH_CHECK(!evaluator.evaluate(a).has_value()); + TORCH_CHECK(!evaluator.evaluate(d).has_value()); + + evaluator.bind(a, 7); + evaluator.bind(b, 3); + + // can't bind to the results of expressions + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(evaluator.bind(c, 100)); + + // can't bind to concrete values + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(evaluator.bind(e, 100)); + + checkIntValue(evaluator, c, 10); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), 4); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 1); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 3); + checkIntValue(evaluator, d, -4); + + // Reset the evaluation context + evaluator = kir::ExpressionEvaluator(); + + evaluator.bind(a, 2); + evaluator.bind(b, 5); + + checkIntValue(evaluator, c, 7); + checkIntValue(evaluator, IrBuilder::subExpr(a, b), -3); + checkIntValue(evaluator, IrBuilder::modExpr(a, b), 2); + checkIntValue(evaluator, IrBuilder::ceilDivExpr(a, b), 1); + checkIntValue(evaluator, d, -2); +} + +TEST_F(NVFuserTest, FusionClear_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // 1. Create a dummy IR + + { + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv3->split(0, 4); + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + } + + // 2. Clear the IR + + fusion.clear(); + + TORCH_CHECK(fusion.unordered_exprs().empty()); + TORCH_CHECK(fusion.vals().empty()); + + TORCH_CHECK(fusion.inputs().empty()); + TORCH_CHECK(fusion.outputs().empty()); + + TORCH_CHECK(ir_utils::getReductionOps(&fusion).empty()); + + // 3. Rebuild the IR + + { + TensorView* tv0 = makeSymbolicTensor(3); + TensorView* tv1 = makeSymbolicTensor(3); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); + + // tv3 [i0, i1, i2] + tv3->reorder({{0, 2}, {2, 0}}); + // tv3 [i2, i1, i0] + tv3->split(-1, 4); + // tv3 [i2, i1, i0outer, i0inner{4}] + tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); + // tv3 [i0outer, i0inner{4}, i1, i2] + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + tv3->axis(1)->parallelize(ParallelType::BIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({16, 8, 8}, options); + at::Tensor input2 = at::randn_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + at::Tensor tv2_ref = input2 + 2.0; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionCopy_CUDA) { + Fusion original_fusion; + + // Create the test IR + { + FusionGuard fg(&original_fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto tv2 = add(tv1, IrBuilder::create(2.0)); + auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); + + original_fusion.addInput(tv0); + original_fusion.addInput(tv1); + original_fusion.addOutput(tv3); + + tv3->reorder({{0, 2}, {2, 0}}); + tv3->split(-1, 4); + tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); + + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + } + + // Test copy before lowering + Fusion clone = original_fusion; + + // Compare IR dumps + std::stringstream original_ir; + std::stringstream clone_ir; + original_ir << original_fusion; + clone_ir << clone; + ASSERT_EQ(original_ir.str(), clone_ir.str()); + + // Lower original fusion + std::string original_kernel; + { + // TODO(kir): remove this guard once we implement the cuda codegen visitor + FusionGuard fg(&original_fusion); + original_kernel = + codegen::generateCudaKernel(GpuLower(&original_fusion).kernel()); + } + + // Make sure the "before lowering" clone was not mutated + // while lowering the original fusion IR + std::stringstream before_lowering_ir; + before_lowering_ir << clone; + ASSERT_EQ(original_ir.str(), before_lowering_ir.str()); + + // Test copy after lowering (including assignment operator) + Fusion before_lowering = clone; + clone = original_fusion; + + // Compare IR dumps + std::stringstream original_lowered_ir; + std::stringstream clone_lowered_ir; + original_lowered_ir << original_fusion; + clone_lowered_ir << clone; + ASSERT_EQ(original_lowered_ir.str(), clone_lowered_ir.str()); + + // Lower the "before lowering" and compare kernels + std::string clone_kernel; + { + // TODO(kir): remove this guard once we implement the cuda codegen visitor + FusionGuard fg(&before_lowering); + clone_kernel = + codegen::generateCudaKernel(GpuLower(&before_lowering).kernel()); + } + ASSERT_EQ(original_kernel, clone_kernel); +} + +TEST_F(NVFuserTest, FusionMove_CUDA) { + Fusion fusion; + + // Create the test IR + { + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto tv2 = add(tv1, IrBuilder::create(2.0)); + auto tv3 = sub(add(tv0, mul(tv2, tv2)), tv2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); + + tv3->reorder({{0, 2}, {2, 0}}); + tv3->split(-1, 4); + tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); + + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + } + + std::stringstream original_ir; + original_ir << fusion; + + // Test move before lowering + Fusion another_fusion = std::move(fusion); + + // Check that the original fusion is "empty" + // + // IMPORTANT: these checks assume knowledge of the internal + // implementation of the move operations. General uses + // should only assume that the moved-from object is in + // a valid, but unspecified state. This is similar to the + // standard library containers: + // https://en.cppreference.com/w/cpp/utility/move + // + TORCH_CHECK(fusion.unordered_exprs().empty()); + TORCH_CHECK(fusion.vals().empty()); + TORCH_CHECK(fusion.inputs().empty()); + TORCH_CHECK(fusion.outputs().empty()); + + // clear() has no pre-conditions so it's valid to call on a moved-from object + fusion.clear(); + + // Compare IR dumps + std::stringstream another_ir; + another_ir << another_fusion; + ASSERT_EQ(original_ir.str(), another_ir.str()); + + // Lower the fusion IR + GpuLower lower(&another_fusion); + + std::stringstream lowered_ir; + lowered_ir << another_fusion; + + // Test move assignment after lowering + fusion = std::move(another_fusion); + + // Compare IR dumps + std::stringstream moved_lowered_ir; + moved_lowered_ir << fusion; + ASSERT_EQ(lowered_ir.str(), moved_lowered_ir.str()); +} + +TEST_F(NVFuserTest, FusionSimpleArith_CUDA) { + std::stringstream ss1, ss2; + + Fusion fusion; + FusionGuard fg(&fusion); + + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); + Double* d3 = IrBuilder::create(); + + // Disrupt the fusion to make sure guard works well + { + Fusion fusion2; + FusionGuard fg(&fusion2); + + Double* d1 = IrBuilder::create(1.f); + Double* d2 = IrBuilder::create(2.f); + add(d1, d2); + ss2 << fusion2; + } + + IrBuilder::create(BinaryOpType::Add, d3, d1, d2); + ss1 << fusion; + + TORCH_CHECK( + ss1.str().compare(ss2.str()) == 0, + "Error where explicit add nodes don't match implicit add nodes."); +} + +TEST_F(NVFuserTest, FusionScalarTypePromote_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Bool* b = IrBuilder::create(true); + Double* d = IrBuilder::create(4.f); + Int* i = IrBuilder::create(3); + ComplexDouble* c = + IrBuilder::create(c10::complex(1, 2)); + + TORCH_CHECK(add(b, b)->getDataType() == DataType::Bool); + TORCH_CHECK(add(b, d)->getDataType() == DataType::Double); + TORCH_CHECK(add(b, i)->getDataType() == DataType::Int); + TORCH_CHECK(add(b, c)->getDataType() == DataType::ComplexDouble); + + TORCH_CHECK(add(d, b)->getDataType() == DataType::Double); + TORCH_CHECK(add(d, d)->getDataType() == DataType::Double); + TORCH_CHECK(add(d, i)->getDataType() == DataType::Double); + TORCH_CHECK(add(d, c)->getDataType() == DataType::ComplexDouble); + + TORCH_CHECK(add(i, b)->getDataType() == DataType::Int); + TORCH_CHECK(add(i, d)->getDataType() == DataType::Double); + TORCH_CHECK(add(i, i)->getDataType() == DataType::Int); + TORCH_CHECK(add(i, c)->getDataType() == DataType::ComplexDouble); + + TORCH_CHECK(add(c, b)->getDataType() == DataType::ComplexDouble); + TORCH_CHECK(add(c, d)->getDataType() == DataType::ComplexDouble); + TORCH_CHECK(add(c, i)->getDataType() == DataType::ComplexDouble); + TORCH_CHECK(add(c, c)->getDataType() == DataType::ComplexDouble); +} + +TEST_F(NVFuserTest, FusionComplexAbsTypes_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto tensor_cf = at::randn({4, 4, 4}, options.dtype(at::kComplexFloat)); + auto tensor_cd = at::randn({4, 4, 4}, options.dtype(at::kComplexDouble)); + + auto type_cf = TensorType::create(tensor_cf); + auto tv_cf = IrBuilder::create(type_cf); + auto type_cd = TensorType::create(tensor_cd); + auto tv_cd = IrBuilder::create(type_cd); + + TORCH_CHECK( + tensor_cf.abs().scalar_type() == + data_type_to_aten(abs(tv_cf)->getDataType().value())); + TORCH_CHECK( + tensor_cd.abs().scalar_type() == + data_type_to_aten(abs(tv_cd)->getDataType().value())); +} + +TEST_F(NVFuserTest, FusionRegister_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + Double* v1 = IrBuilder::create(1.f); + Double* v2 = IrBuilder::create(2.f); + Val* v3 = binaryOp(BinaryOpType::Add, v1, v2); + Val* v4 = binaryOp(BinaryOpType::Add, v1, v2); + TORCH_CHECK(v1->name() + 1 == v2->name()); + TORCH_CHECK(v2->name() + 1 == v3->name()); + TORCH_CHECK(v3->name() + 1 == v4->name()); + TORCH_CHECK(v3->definition()->name() + 1 == v4->definition()->name()); +} + +// dummy expr with 2 outputs only for toposort test. +struct DummyExpr : public Expr { + ~DummyExpr() = default; + DummyExpr( + IrBuilderPasskey passkey, + Val* _outlhs, + Val* _outrhs, + Val* _lhs, + Val* _rhs) + : Expr(passkey, ExprType::UnaryOp) // Not terribly safe... + { + addOutput(_outlhs); + addOutput(_outrhs); + addInput(_lhs); + addInput(_rhs); + } + DummyExpr(const DummyExpr& other) = delete; + DummyExpr& operator=(const DummyExpr& other) = delete; + DummyExpr(DummyExpr&& other) = delete; + DummyExpr& operator=(DummyExpr&& other) = delete; + Expr* shallowCopy() const override { + return nullptr; + } +}; + +TEST_F(NVFuserTest, FusionTopoSort_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // e0: v3, v2 = dummy(v1, v0) + // e1: v4 = add(v3, v2) + // e2: v5 = add(v2, v4) + // e3: v6 = add(v5, v5) + Double* v0 = IrBuilder::create(); + Double* v1 = IrBuilder::create(); + Double* v2 = IrBuilder::create(); + Double* v3 = IrBuilder::create(); + Double* v4 = IrBuilder::create(); + Double* v5 = IrBuilder::create(); + Double* v6 = IrBuilder::create(); + + std::vector inputs = {v0, v1}; + for (auto val : inputs) { + fusion.addInput(val); + } + + Expr* e0 = IrBuilder::create(v3, v2, v1, v0); + Expr* e1 = IrBuilder::create(BinaryOpType::Add, v4, v3, v2); + Expr* e2 = IrBuilder::create(BinaryOpType::Add, v5, v2, v4); + Expr* e3 = IrBuilder::create(BinaryOpType::Add, v6, v5, v5); + + fusion.addOutput(v2); + fusion.addOutput(v3); + auto exprs = fusion.exprs(); + TORCH_CHECK(exprs.size() == 1, "Found ", exprs.size(), " but expecting 1"); + TORCH_CHECK(exprs[0] == e0); + + fusion.addOutput(v5); + exprs = fusion.exprs(); + TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3"); + TORCH_CHECK(exprs[0] == e0); + TORCH_CHECK(exprs[1] == e1); + TORCH_CHECK(exprs[2] == e2); + + fusion.addOutput(v4); + exprs = fusion.exprs(); + TORCH_CHECK(exprs.size() == 3, "Found ", exprs.size(), " but expecting 3"); + TORCH_CHECK(exprs[0] == e0); + TORCH_CHECK(exprs[1] == e1); + TORCH_CHECK(exprs[2] == e2); + + fusion.addOutput(v6); + exprs = fusion.exprs(); + TORCH_CHECK(exprs.size() == 4, "Found ", exprs.size(), " but expecting 4"); + TORCH_CHECK(exprs[0] == e0); + TORCH_CHECK(exprs[1] == e1); + TORCH_CHECK(exprs[2] == e2); + TORCH_CHECK(exprs[3] == e3); + + TORCH_CHECK(v2->definition()->name() == 0); + TORCH_CHECK(v3->definition()->name() == 0); + TORCH_CHECK(v4->definition()->name() == 1); + TORCH_CHECK(v5->definition()->name() == 2); + TORCH_CHECK(v6->definition()->name() == 3); +} + +TEST_F(NVFuserTest, FusionTensor_CUDA) { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + Fusion fusion; + FusionGuard fg(&fusion); + + { + auto tensor = at::randn({2, 3, 4, 5}, options); + auto tensor_type = TensorType::create(tensor); + auto fuser_tensor = IrBuilder::create(tensor_type); + TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); + TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); + TORCH_CHECK(fuser_tensor->domain() != nullptr); + for (const auto i : c10::irange(fuser_tensor->nDims())) { + // size 1 dimension are makred as broadcast + TORCH_CHECK( + fuser_tensor->axis(i)->isBroadcast() == (tensor.sizes()[i] == 1)); + // check contiguity information; + TORCH_CHECK(fuser_tensor->domain()->contiguity()[i]); + } + } + + // TensorType::create fills stride_properties, which helps us to mark + // IterDomain properly + // Note: implementation could change, depending on how much we want to invest + // in our home-brew contiguity coalescing. For now let's make sure that we + // properly test what we are using. + { + auto tensor = at::randn({4, 4, 4}, options); + auto sliced_tensor = tensor.slice(1, 0, -1, 2); + + auto tensor_type = TensorType::create(sliced_tensor); + auto fuser_tensor = IrBuilder::create(tensor_type); + TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); + TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); + TORCH_CHECK(fuser_tensor->domain() != nullptr); + for (const auto i : c10::irange(fuser_tensor->nDims())) { + // size 1 dimension are makred as broadcast + TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); + } + TORCH_CHECK(fuser_tensor->domain()->contiguity()[0]); + TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); + TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); + } + + { + auto tensor = at::randn({2, 3, 4, 5}, options); + auto permuted_tensor = tensor.permute({0, 3, 1, 2}); + auto tensor_type = TensorType::create(permuted_tensor); + auto fuser_tensor = IrBuilder::create(tensor_type); + TORCH_CHECK((int64_t)fuser_tensor->nDims() == tensor.dim()); + TORCH_CHECK(fuser_tensor->getDataType().value() == DataType::Float); + TORCH_CHECK(fuser_tensor->domain() != nullptr); + for (const auto i : c10::irange(fuser_tensor->nDims())) { + // size 1 dimension are makred as broadcast + TORCH_CHECK(fuser_tensor->axis(i)->isBroadcast() == false); + } + TORCH_CHECK(!fuser_tensor->domain()->contiguity()[0]); + TORCH_CHECK(!fuser_tensor->domain()->contiguity()[1]); + TORCH_CHECK(fuser_tensor->domain()->contiguity()[2]); + TORCH_CHECK(!fuser_tensor->domain()->contiguity()[3]); + } +} + +TEST_F(NVFuserTest, FusionFilterVals_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(1); + auto scalar0 = IrBuilder::create(0); + auto scalar1 = IrBuilder::create(0); + auto scalar2 = IrBuilder::create(1); + + const std::vector vals = {tv0, scalar0, tv1, scalar1, scalar2}; + + std::vector tvs( + ir_utils::filterByType(vals).begin(), + ir_utils::filterByType(vals).end()); + TORCH_CHECK(tvs.size() == 2); + TORCH_CHECK(tvs[0] == tv0); + TORCH_CHECK(tvs[1] == tv1); + + std::vector floats( + ir_utils::filterByType(vals).begin(), + ir_utils::filterByType(vals).end()); + TORCH_CHECK(floats.size() == 1); + TORCH_CHECK(floats[0] == scalar0); + + std::vector ints( + ir_utils::filterByType(vals).begin(), + ir_utils::filterByType(vals).end()); + TORCH_CHECK(ints.size() == 2); + TORCH_CHECK(ints[0] == scalar1); + TORCH_CHECK(ints[1] == scalar2); + + TORCH_CHECK( + ir_utils::filterByType(vals).begin() == + ir_utils::filterByType(vals).end(), + "Not expecting any results"); +} + +TEST_F(NVFuserTest, FusionTVSplit_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeSymbolicTensor(3); + + tv = tv->split(2, 2); + TORCH_CHECK(tv->nDims() == 4); + Expr* outer = tv->axis(2)->extent()->definition(); + + TORCH_CHECK( + outer->getExprType().value() == ExprType::BinaryOp && + static_cast(outer)->getBinaryOpType() == + BinaryOpType::CeilDiv && + static_cast(outer)->lhs()->sameAs( + tv->getRootDomain()[2]->extent()) && + static_cast(static_cast(outer)->rhs()) + ->sameAs(IrBuilder::create(2))); + + IterDomain* inner = static_cast(tv->axis(3)); + TORCH_CHECK( + inner->extent()->isScalar() && + static_cast(inner->extent())->isConst() && + static_cast(inner->extent())->value().value() == 2); +} + +TEST_F(NVFuserTest, FusionTVMerge_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeSymbolicTensor(3); + + tv = tv->merge(1); + Expr* axisOp = tv->axis(1)->extent()->definition(); + + TORCH_CHECK( + tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp && + static_cast(axisOp)->getBinaryOpType() == BinaryOpType::Mul && + static_cast(axisOp)->lhs() == + tv->getRootDomain()[1]->extent() && + static_cast(axisOp)->rhs() == + tv->getRootDomain()[2]->extent()); +} + +TEST_F(NVFuserTest, FusionTVReorder_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::unordered_map shift_right{{-1, 0}}; + + std::unordered_map shift_left{{0, -1}}; + + std::unordered_map shift_left_2{{0, -1}, {1, 0}, {2, 1}}; + + std::unordered_map swap{{0, 2}, {2, 0}}; + + auto tv = makeSymbolicTensor(3); + std::vector ref; + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + + tv->reorder(shift_left); + for (const auto i : c10::irange(tv->nDims())) { + TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); + } + + tv = makeSymbolicTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + + tv->reorder(shift_left); + for (const auto i : c10::irange(tv->nDims())) { + TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); + } + + tv = makeSymbolicTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + + tv->reorder(shift_right); + TORCH_CHECK(ref[ref.size() - 1]->sameAs(tv->axis(0))); + for (const auto i : c10::irange(1, tv->nDims())) { + TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i))); + } + + tv = makeSymbolicTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + tv->reorder(swap); + TORCH_CHECK(ref[0]->sameAs(tv->axis(2))); + TORCH_CHECK(ref[2]->sameAs(tv->axis(0))); + TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); +} + +TEST_F(NVFuserTest, FusionEquality_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Double* fval1 = IrBuilder::create(); + Double* fval1_copy = fval1; + Double* fval2 = IrBuilder::create(); + Double* fone = IrBuilder::create(1.0); + + TORCH_CHECK(fval1->sameAs(fval1_copy)); + TORCH_CHECK(!fval1->sameAs(fval2)); + TORCH_CHECK(!fone->sameAs(fval1)); + TORCH_CHECK(fone->sameAs(IrBuilder::create(1.0))); + + Int* ival1 = IrBuilder::create(); + Int* ival1_copy = ival1; + Int* ival2 = IrBuilder::create(); + Int* ione = IrBuilder::create(1); + + TORCH_CHECK(ival1->sameAs(ival1_copy)); + TORCH_CHECK(!ival1->sameAs(ival2)); + TORCH_CHECK(!ione->sameAs(ival1)); + TORCH_CHECK(ione->sameAs(IrBuilder::create(1))); + + BinaryOp* add1 = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* add1_copy = IrBuilder::create( + BinaryOpType::Add, IrBuilder::create(), fval1, ival1); + BinaryOp* sub1 = IrBuilder::create( + BinaryOpType::Sub, IrBuilder::create(), fval1, ival1); + + UnaryOp* neg1 = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval1); + UnaryOp* neg2 = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval2); + UnaryOp* neg1_copy = IrBuilder::create( + UnaryOpType::Neg, IrBuilder::create(), fval1); + + TORCH_CHECK(add1->sameAs(add1_copy)); + TORCH_CHECK(!add1->sameAs(sub1)); + + TORCH_CHECK(neg1->sameAs(neg1_copy)); + TORCH_CHECK(!static_cast(neg1)->sameAs(add1)); + TORCH_CHECK(!neg1->sameAs(neg2)); +} + +TEST_F(NVFuserTest, FusionDependency_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Double* d0 = IrBuilder::create(0.f); + Double* d1 = IrBuilder::create(1.f); + auto d2 = add(d0, d1); + + auto d3 = add(d2, d2); + + Double* d4 = IrBuilder::create(4.f); + Double* d5 = IrBuilder::create(5.f); + auto d6 = add(d4, d5); + + Double* d7 = IrBuilder::create(7.f); + Double* d8 = IrBuilder::create(8.f); + auto d9 = add(d7, d8); + + auto d10 = add(d6, d9); + + auto d11 = add(d3, d10); + + TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d1, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d3, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d6, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d9, d11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d0, d2)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d2, d3)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d4, d6)); + TORCH_CHECK(DependencyCheck::isDependencyOf(d8, d10)); + + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d0)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d1)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d2)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d3)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d4)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d11, d5)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d2, d0)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d3, d2)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d6, d4)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(d10, d8)); + + auto dep_chain = DependencyCheck::getSingleDependencyChain(d0, d11); + TORCH_CHECK(dep_chain.back() == d11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == d3); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == d2); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(d6, d11); + TORCH_CHECK(dep_chain.back() == d11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == d10); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(d4, d11); + TORCH_CHECK(dep_chain.back() == d11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == d10); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == d6); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(d11, d2); + TORCH_CHECK(dep_chain.empty()); +} + +TEST_F(NVFuserTest, FusionParser_CUDA) { + // This test may not pass if using a custom block sync as there may + // be additional calls. Skip the test as it's not specifically + // relevant with block synchronizatin. + if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { + return; + } + auto g = std::make_shared(); + const auto graph0_string = R"IR( + graph(%0 : Float(2, strides=[1]), + %1 : Float(2, strides=[1])): + %c0 : Float(2, strides=[1]) = aten::mul(%0, %1) + %d0 : Float(2, strides=[1]) = aten::mul(%c0, %0) + return (%d0))IR"; + parseIR(graph0_string, g.get()); + + // strides are not yet supported in the irparser. + for (auto val : g->block()->inputs()) { + if (val->isCompleteTensor()) + val->setType(val->type()->castRaw()->contiguous()); + } + for (auto node : g->block()->nodes()) { + for (auto val : node->outputs()) { + if (val->isCompleteTensor()) + val->setType(val->type()->castRaw()->contiguous()); + } + } + + auto fusion = parseJitIR(g); + FusionGuard fg(fusion.get()); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + // Avoid vectorization here as those kernels can't be lowered twice at the + // moment + at::Tensor input1 = at::randn({16}, options); + at::Tensor input2 = at::randn({16}, options); + auto lparams = schedulePointwise(fusion.get(), {input1, input2}); + + // CONSIDER: + // 1. this can be moved to a dedicated "golden" file + // 2. use a fuzzy compare (ignore non-significant whitespaces for example) + const std::string expected_kernel = R"( +__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3) { + int64_t i50; + i50 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i50 < T0.size[0])) { + float T5[1]; + T5[0] = 0; + T5[0] + = T1[i50]; + float T4[1]; + T4[0] = 0; + T4[0] + = T0[i50]; + float T2[1]; + T2[0] + = T4[0] + * T5[0]; + float T6[1]; + T6[0] + = T2[0] + * T4[0]; + T3[i50] + = T6[0]; + } +} +)"; + + const std::string actual_kernel = + "\n" + codegen::generateCudaKernel(GpuLower(fusion.get()).kernel()); + if (expected_kernel.size() != actual_kernel.size() || + expected_kernel.compare(actual_kernel) != 0) { + std::cerr + << " Codegen mismatch, codegen possibly changed, or is incorrect. " + << " \n ========= EXPECTED ========= \n" + << expected_kernel << "\n========= ACTUAL ========== \n" + << actual_kernel << "\n=================" << std::endl; + auto it = std::mismatch( + expected_kernel.begin(), + expected_kernel.end(), + actual_kernel.begin(), + actual_kernel.end()); + std::string actual_mismatched_snippet(it.second, actual_kernel.end()); + actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10); + std::string expected_mismatched_snippet(it.first, expected_kernel.end()); + expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10); + std::cerr << "First mismatch found at: " << actual_mismatched_snippet + << ", expected: " << expected_mismatched_snippet << std::endl; + TORCH_CHECK(false); + } + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1, input2}, lparams); + auto outputs = fe.runFusion({input1, input2}, lparams); + at::Tensor output_ref = input1 * input2 * input1; + TORCH_CHECK(output_ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionOuterSplit_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(3); + + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); + fusion.addOutput(tv2); + + //[I0, I1, I2] + tv2->split(-1, 4, false); + //[I0, I1, I2o{4}, I2i] + tv2->merge(0); + tv2->merge(0); + //[I0*I1*I2o{4}, I2i] + tv2->split(0, 2); + //[I0*I1*I2o{4}o, I0*I1*I2o{4}i{2}, I2i] + tv2->reorder({{0, 1}, {1, 0}}); + // I0*I1*I2o{4}i{2}, [I0*I1*I2o{4}o, I2i] + + tv0->computeAt(tv2, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor output = at::empty({2, 6, 32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({}, {output}); + + at::Tensor output_ref = at::zeros_like(output, options); + output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionCodeGen_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(3); + + IrBuilder::create( + BinaryOpType::Add, + tv0, + IrBuilder::create(0.0), + IrBuilder::create(1.0)); + TensorView* tv1 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv2 = add(tv1, IrBuilder::create(3.0)); + fusion.addOutput(tv2); + + //[I0, I1, I2] + tv2 = tv2->split(0, 4); + //[I0o, I0i{4}, I1, I2] + tv2 = tv2->merge(1); + //[I0o, I0i{4}*I1, I2] + tv2 = tv2->split(-1, 2); + //[I0o, I0i{4}*I1, I2o, I2i{2}] + tv2 = tv2->reorder({{0, 1}, {1, 0}, {3, 2}}); + //[I0i{4}*I1, I0o, I2i{2}, I2o] + + tv0->computeAt(tv2, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor output = at::empty({16, 8, 8}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + fe.runFusion({}, {output}); + + at::Tensor output_ref = at::zeros_like(output, options); + output_ref = output_ref + 0.0 + 1.0 + 2.0 + 3.0; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionCodeGen2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(3); + TensorView* tv1 = makeSymbolicTensor(3); + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); + + //[I0, I1, I2] + tv3->reorder({{0, 2}, {2, 0}}); + //[I2, I1, I0] + tv3->split(-1, 4); + //[I2, I1, I0o, I0i{4}] + tv3->reorder({{2, 0}, {3, 1}, {0, 3}}); + // I0o, I0i{4}, I1, I2] + + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({16, 8, 8}, options); + at::Tensor input2 = at::randn_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + at::Tensor tv2_ref = input2 + 2.0; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionSimplePWise_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // dimensionality of the problem + int nDims = 3; + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(nDims); + TensorView* tv1 = makeContigTensor(nDims); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + // Do transformations, remember, transformations are outputs to inputs + // This doesn't have to be in this order + tv3->merge(1); + tv3->merge(0); + + // Split by n_threads + tv3->split(0, 128); + tv3->split(0, 4); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + // Parallelize TV3 + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-2)->parallelize(ParallelType::Unroll); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({64, 2, 128}, options); + at::Tensor input2 = at::rand_like(input1); + at::Tensor output = at::empty_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + fe.runFusion({input1, input2}, {output}); + + at::Tensor tv2_ref = input2 + 2.0; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // dimensionality of the problem + int nDims = 3; + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(nDims, DataType::ComplexFloat); + TensorView* tv1 = makeContigTensor(nDims, DataType::ComplexFloat); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + c10::complex scalar1(2.0, 3.0); + TensorView* tv2 = add(tv1, IrBuilder::create(scalar1)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + // Do transformations, remember, transformations are outputs to inputs + // This doesn't have to be in this order + tv3->merge(1); + tv3->merge(0); + + // Split by n_threads + tv3->split(0, 128); + tv3->split(0, 4); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + // Parallelize TV3 + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-2)->parallelize(ParallelType::Unroll); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = + at::TensorOptions().dtype(at::kComplexFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({64, 2, 128}, options); + at::Tensor input2 = at::rand_like(input1); + at::Tensor output = at::empty_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + fe.runFusion({input1, input2}, {output}); + + at::Tensor tv2_ref = input2 + scalar1; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionExecKernel_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->split(0, 128); + tv3->split(0, 4); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + // Parallelize TV3 + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::ones({1, 128}, options); + at::Tensor input2 = at::ones_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + at::Tensor check = at::full({1, 128}, 4, options); + ; + TORCH_CHECK(outputs[0].equal(check)); +} + +int ceilDiv_(int a, int b) { + return (a + b - 1) / b; +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt1_CUDA) { + // Case 1 + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv1 + 3 + // tv4 = tv1 * 2 + // tv5 = tv3 + tv2 + // tv6 = tv5 + tv4 + // tv7 = tv1 + tv4 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); + TensorView* tv5 = add(tv3, tv2); + + TensorView* tv6 = add(tv5, tv4); + TensorView* tv7 = add(tv1, tv4); + + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + // Lets setup to actually run + tv7->merge(0); + tv7->split(0, 128); + tv7->split(0, 4); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv7, 1); + + ComputeAtMap ca_map(&fusion); + + // The this-position of the last tensor should be zero. + TORCH_CHECK( + tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && + tv7->getMaxProducerPosition() == 1); + TORCH_CHECK( + tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && + tv6->getMaxProducerPosition() == 1); + // The position of every other tensor should be 1. + for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { + TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); + + TORCH_CHECK( + ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE)); + } + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({129, 127}, options); + + auto t1 = aten_input.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t1.add({3.0}); + auto t4 = t1.mul({2.0}); + auto t5 = t3.add(t2); + auto t6 = t5.add(t4); + auto t7 = t1.add(t4); + + std::vector aten_outputs = {t6, t7}; + std::vector cg_outputs = { + at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt2_CUDA) { + // Case 2 + // tv1 = tv0 * -1 + // tv2 = tv0 + 3 + // tv3 = tv0 * 2 + // tv4 = tv2 + tv1 + // tv5 = tv4 + tv3 + // tv6 = tv5 + tv3 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); + TensorView* tv4 = add(tv2, tv1); + + TensorView* tv5 = add(tv4, tv3); + TensorView* tv6 = add(tv5, tv3); + + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + // Lets setup to actually run + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv6, 1); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({129, 127}, options); + + auto t1 = input.mul({-1.0}); + auto t2 = input.add({3.0}); + auto t3 = input.mul({2.0}); + auto t4 = t2.add(t1); + auto t5 = t4.add(t3); + auto t6 = t5.add(t3); + + std::vector aten_outputs = {t5, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt3_CUDA) { + // Case 3 + // T2 = T1 * 0.979361 + // T3 = T2 * T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); + TensorView* tv3 = mul(tv2, tv0); + + fusion.addOutput(tv3); + + // Lets setup to actually run + while (tv3->nDims() > 1) + tv3->merge(0); + tv3->split(0, 128); + tv3->split(0, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t1.mul({0.979361}); + auto aten_output = t2.mul(t0); + + std::vector aten_inputs = {t0, t1}; + + at::Tensor cg_output = at::empty_like(t0, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt4_CUDA) { + // Case 4 + // T4 = T2 - T3 + // T5 = T1 + T4 + // T6 = T5 - T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + TensorView* tv2 = makeSymbolicTensor(4); + fusion.addInput(tv2); + + TensorView* tv3 = makeSymbolicTensor(4); + fusion.addInput(tv3); + + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); + + fusion.addOutput(tv6); + + // Lets setup to actually run + while (tv6->nDims() > 1) + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + + tv0->computeAt(tv6, 1); + tv1->computeAt(tv6, 1); + tv2->computeAt(tv6, 1); + tv3->computeAt(tv6, 1); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + at::Tensor t2 = at::rand_like(t0, options); + at::Tensor t3 = at::rand_like(t0, options); + + auto t4 = t2.sub(t3); + auto t5 = t1.add(t4); + auto aten_output = t5.sub(t0); + + std::vector aten_inputs = {t0, t1, t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt5_CUDA) { + // Case 5 + // tv2 = tv0 + 2.0 + // tv3 = tv1 * tv2 + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->split(-1, 8); + tv3->split(-1, 4); + + tv2->computeAt(tv3, 1); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto aten_output = t1.mul(t2); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(-1, 8); + tv2->split(-1, 4); + tv3->merge(0); + tv3->split(-1, 8); + + tv2->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto aten_output = t1.mul(t2); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + + auto tv3 = add(tv2, IrBuilder::create(3.0)); + + auto tv4 = add(tv1, tv3); + fusion.addOutput(tv4); + + auto tv5 = broadcast(tv1, {false, true}); + + auto tv6 = makeSymbolicTensor(2); + fusion.addInput(tv6); + + auto tv7 = mul(tv5, tv6); + + fusion.addOutput(tv7); + + tv7->split(1, 2); + tv7->merge(0); + tv7->split(0, 4); + tv7->split(0, 128); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + tv7->axis(1)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv7, 1); + auto tv5_domain = tv5->domain()->domain(); + + // These computeAt transformations should not affect the TV5 domain + tv0->computeAt(tv4, -1); + tv2->computeAt(tv4, -1); + + auto tv5_domain_current = tv5->domain()->domain(); + TORCH_CHECK(tv5_domain == tv5_domain_current, "Invalid TV5 domain"); + + const int numel_x = 100; + const int numel_y = 200; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({numel_x}, options); + auto t2 = at::randn({numel_x}, options); + auto t6 = at::randn({numel_x, numel_y}, options); + + auto t1 = t0.add(1.0); + auto t3 = t2.add(3.0); + auto t4 = t1.add(t3); + auto t5 = t1.unsqueeze(1); + auto t7 = t5.mul(t6); + + std::vector aten_inputs = {t0, t2, t6}; + std::vector aten_outputs = {t4, t7}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAt8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + + auto tv3 = add(tv2, IrBuilder::create(3.0)); + + auto tv4 = add(tv1, tv3); + fusion.addOutput(tv4); + + auto tv5 = broadcast(tv1, {false, true}); + + auto tv6 = makeSymbolicTensor(2); + fusion.addInput(tv6); + + auto tv7 = mul(tv5, tv6); + + fusion.addOutput(tv7); + + tv7->split(1, 2); + tv7->merge(0); + tv7->split(0, 128, false); + tv7->split(0, 4, false); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + tv7->axis(1)->parallelize(ParallelType::TIDx); + + // Reverse computeAt structure from previous test + tv0->computeAt(tv4, -1); + tv2->computeAt(tv4, -1); + tv0->computeAt(tv7, -1); + + const int numel_x = 100; + const int numel_y = 200; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({numel_x}, options); + auto t2 = at::randn({numel_x}, options); + auto t6 = at::randn({numel_x, numel_y}, options); + + auto t1 = t0.add(1.0); + auto t3 = t2.add(3.0); + auto t4 = t1.add(t3); + auto t5 = t1.unsqueeze(1); + auto t7 = t5.mul(t6); + + std::vector aten_inputs = {t0, t2, t6}; + std::vector aten_outputs = {t4, t7}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith1_CUDA) { + // Case 1 + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv1 + 3 + // tv4 = tv1 * 2 + // tv5 = tv3 + tv2 + // tv6 = tv5 + tv4 + // tv7 = tv1 + tv4 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); + TensorView* tv5 = add(tv3, tv2); + + TensorView* tv6 = add(tv5, tv4); + TensorView* tv7 = add(tv1, tv4); + + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + // Lets setup to actually run + tv0->merge(0); + tv0->split(0, 128); + tv0->split(0, 4); + + tv0->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeWith(tv7, 1); + + GpuLower gpulw(&fusion); + + // The this-position of the last tensor should be zero. + TORCH_CHECK( + tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && + tv7->getMaxProducerPosition() == 1); + TORCH_CHECK( + tv7->nDims() == 3 && tv6->getComputeAtPosition() == 0 && + tv6->getMaxProducerPosition() == 1); + + ComputeAtMap ca_map(&fusion); + + // The position of every other tensor should be 1. + for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { + TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); + TORCH_CHECK( + ca_map.areMapped(tv7->axis(0), tv->axis(0), IdMappingMode::PERMISSIVE)); + } + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({129, 127}, options); + + auto t1 = aten_input.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t1.add({3.0}); + auto t4 = t1.mul({2.0}); + auto t5 = t3.add(t2); + auto t6 = t5.add(t4); + auto t7 = t1.add(t4); + + std::vector aten_outputs = {t6, t7}; + std::vector cg_outputs = { + at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith2_CUDA) { + // Case 2 + // tv1 = tv0 * -1 + // tv2 = tv0 + 3 + // tv3 = tv0 * 2 + // tv4 = tv2 + tv1 + // tv5 = tv4 + tv3 + // tv6 = tv5 + tv3 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); + TensorView* tv4 = add(tv2, tv1); + + TensorView* tv5 = add(tv4, tv3); + TensorView* tv6 = add(tv5, tv3); + + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + // Lets setup to actually run + tv0->merge(0); + tv0->split(0, 128); + tv0->split(0, 4); + + tv0->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeWith(tv6, 1); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({129, 127}, options); + + auto t1 = input.mul({-1.0}); + auto t2 = input.add({3.0}); + auto t3 = input.mul({2.0}); + auto t4 = t2.add(t1); + auto t5 = t4.add(t3); + auto t6 = t5.add(t3); + + std::vector aten_outputs = {t5, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith3_CUDA) { + // Case 3 + // T2 = T1 * 0.979361 + // T3 = T2 * T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); + TensorView* tv3 = mul(tv2, tv0); + + fusion.addOutput(tv3); + + // Lets setup to actually run + while (tv0->nDims() > 1) + tv0->merge(0); + tv0->split(0, 128); + tv0->split(0, 4); + + while (tv1->nDims() > 1) + tv1->merge(0); + tv1->split(0, 128); + tv1->split(0, 4); + + tv0->computeWith(tv3, 1); + tv1->computeWith(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t1.mul({0.979361}); + auto aten_output = t2.mul(t0); + + std::vector aten_inputs = {t0, t1}; + + at::Tensor cg_output = at::empty_like(t0, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith4_CUDA) { + // Case 4 + // T4 = T2 - T3 + // T5 = T1 + T4 + // T6 = T5 - T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + TensorView* tv2 = makeSymbolicTensor(4); + fusion.addInput(tv2); + + TensorView* tv3 = makeSymbolicTensor(4); + fusion.addInput(tv3); + + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); + + fusion.addOutput(tv6); + std::vector tvs = {tv0, tv1, tv2}; + for (auto tv : tvs) { + // Lets setup to actually run + while (tv->nDims() > 1) { + tv->merge(0); + } + tv->split(0, 128); + tv->split(0, 4); + tv->computeWith(tv6, 1); + } + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + at::Tensor t2 = at::rand_like(t0, options); + at::Tensor t3 = at::rand_like(t0, options); + + auto t4 = t2.sub(t3); + auto t5 = t1.add(t4); + auto aten_output = t5.sub(t0); + + std::vector aten_inputs = {t0, t1, t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith5_CUDA) { + // Case 5 + // tv2 = tv0 + 2.0 + // tv3 = tv1 * tv2 + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(-1, 8); + tv2->split(-1, 4); + + tv2->computeWith(tv3, 1); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto aten_output = t1.mul(t2); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeWith6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(-1, 8); + tv2->split(-1, 4); + tv3->merge(0); + tv3->split(-1, 8); + + tv2->computeWith(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t0.add(2.0); + auto aten_output = t1.mul(t2); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionComputeAtMultiConsumers_CUDA) { + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv2 * -2 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); + fusion.addOutput(tv2); + fusion.addOutput(tv3); + + // This computeAt will affect tv2 as well, even though tv2 is not in + // the data-flow path between tv1 and tv3. The reason is that tv1 is + // now computed at tv3, so tv2 must also be computed at the same + // location. Overall, what will happen is basically we merge + // expressions of all tensors and compute them in a single loop + // nest. + TensorView* computeAtTarget = tv3; + computeAtTarget->split(0, 128); + tv1->computeAt(computeAtTarget, 1); + + TensorView* affected_tensors[] = {tv1, tv2, tv3}; + for (auto tv : affected_tensors) { + TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); + } + + GpuLower gpulw(&fusion); + + TORCH_CHECK(tv1->getComputeAtPosition() == 1); + TORCH_CHECK( + tv2->getComputeAtPosition() == 0 && tv2->getMaxProducerPosition() == 1); + TORCH_CHECK( + tv3->getComputeAtPosition() == 0 && tv3->getMaxProducerPosition() == 1); + + ComputeAtMap ca_map(&fusion); + + // Note that tv2 is also computed at tv3. + for (auto tv : {tv1, tv2}) { + TORCH_CHECK(ca_map.areMapped( + tv->axis(0), computeAtTarget->axis(0), IdMappingMode::PERMISSIVE)); + } + + TORCH_CHECK(tv3->getComputeAtPosition() == 0); + + computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); + for (auto tv : affected_tensors) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({1000}, options); + + auto t1 = aten_input * 0.5; + auto t2 = t1 * -1.0; + auto t3 = t1 * -2.0; + + std::vector aten_outputs = {t2, t3}; + + std::vector cg_outputs = { + at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +// Similar to ComputeAtMultiConsumers, but with a common consumer. +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer1_CUDA) { + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv2 * -2 + // tv4 = tv2 + tv3 + // tv5 = tv4 * 5 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); + TensorView* tv4 = add(tv2, tv3); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); + fusion.addOutput(tv3); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + // Computing tv1 at tv3. This will affect tv2 as discussed in + // ComplexComputeAt1. Additionally, in this case, notice that tv4 is + // the common consumer of tv2 and tv3, so they are computed at + // tv4. The indirect propagation of the computeAt should stop at the + // common consumer, and no further change should occur. More + // specifically, the computeAT position of tv4 and tv5 should be zero. + TensorView* computeAtTarget = tv3; + computeAtTarget->split(0, 128); + tv1->computeAt(computeAtTarget, 1); + + TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4}; + for (auto tv : affected_tensors) { + TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); + } + + TORCH_CHECK(tv1->getComputeAtPosition() == 1); + TORCH_CHECK(tv2->getComputeAtPosition() == 1); + TORCH_CHECK(tv3->getComputeAtPosition() == 1); + TORCH_CHECK(tv4->getComputeAtPosition() == 0); + TORCH_CHECK(tv5->getComputeAtPosition() == 0); + + computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); + + for (auto tv : affected_tensors) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + // Transform tv5 to make it look like the rest + tv5->split(0, 128); + tv5->axis(1)->parallelize(ParallelType::TIDx); + tv5->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({1000}, options); + + auto t1 = aten_input * 0.5; + auto t2 = t1 * -1.0; + auto t3 = t1 * -2.0; + auto t4 = t2 + t3; + auto t5 = t4 * 5.0; + + std::vector aten_outputs = {t3, t4, t5}; + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer2_CUDA) { + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv2 * -1 + // tv4 = tv1 + 4 + // tv5 = tv3 + tv4 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); + TensorView* tv5 = add(tv3, tv4); + + fusion.addOutput(tv5); + + TensorView* computeAtTarget = tv3; + + computeAtTarget->merge(0); + computeAtTarget->split(0, 128); + computeAtTarget->split(0, 4); + + computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); + + // This computeAt will affect all tensors including tv3, tv4 and + // tv5, even though it appears to impact only tv1 and tv2. The + // reason is that tv1 is now computed at tv3, so tv4 must also be + // computed at the same location. Similarly, the consumer of tv4, + // tv5, must also be computed at the same location. Overall, what + // will happen is basically we merge expressions of all tensors and + // compute them in a single loop nest. Internally, this will be + // realized by making all tensors, except for those in the path + // between tv1 and tv3, computed at tv5, which we call the common + // consumer. + tv1->computeAt(computeAtTarget, 1); + + // All tensors should have the same dimenionality as the target + for (Val* val : fusion.vals()) { + if (val->isFusionInput() || + val->getValType().value() != ValType::TensorView) { + continue; + } + TensorView* tv = val->as(); + TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); + if (tv == tv5) { + TORCH_CHECK(tv->getComputeAtPosition() == 0); + } else { + TORCH_CHECK(tv->getComputeAtPosition() == 1); + } + } + + for (auto tv : ir_utils::filterByType(fusion.vals())) { + if (!tv->isFusionInput()) { + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({129, 127}, options); + + auto t1 = aten_input.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t2.mul({-1.0}); + auto t4 = t1.add({4.0}); + auto aten_output = t3 + t4; + + at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +// Similar to the above common consumer test but adds an additional +// tensor that has no common consumer with the other tensors. +TEST_F(NVFuserTest, FusionComputeAtCommonConsumer3_CUDA) { + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv2 * -1 + // tv4 = tv1 + 4 + // tv5 = tv2 + tv3 + // tv6 = tv1 + 6 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv2, IrBuilder::create(-1.0)); + TensorView* tv4 = add(tv1, IrBuilder::create(4.0)); + TensorView* tv5 = add(tv3, tv4); + TensorView* tv6 = add(tv1, IrBuilder::create(6.0)); + + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + TensorView* computeAtTarget = tv3; + + computeAtTarget->merge(0); + computeAtTarget->split(0, 128); + computeAtTarget->split(0, 4); + + computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); + + // This will have the same impact on the tensors except for tv5 and + // tv6. tv6 does not have any common consumer with the computeAt + // target, but since it uses tv1, it must be also computed at the + // same location as the other impacted tensors. We can either make + // tv5 computed at tv6 or tv6 computed at tv5. In this case, tv5 + // should be computed at tv6 just because the current implementation + // orders the computeAt relationship based on the order in which + // tensors are specified as outputs. + + tv1->computeAt(computeAtTarget, 1); + + // All tensors should have the same dimenionality as the target + for (auto tv : ir_utils::filterByType(fusion.vals())) { + if (tv->isFusionInput()) { + continue; + } + TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); + if (tv == tv5 || tv == tv6) { + TORCH_CHECK(tv->getComputeAtPosition() == 0); + TORCH_CHECK(tv->getMaxProducerPosition() == 1); + } else { + TORCH_CHECK(tv->getComputeAtPosition() == 1); + } + } + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = val->as(); + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({129, 127}, options); + + auto t1 = aten_input.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t2.mul({-1.0}); + auto t4 = t1.add({4.0}); + auto t5 = t3 + t4; + auto t6 = t1.add({6.0}); + + std::vector aten_outputs = {t5, t6}; + std::vector cg_outputs = { + at::empty_like(aten_input, options), at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +// Similar to ComputeAtCommonConsumer1 but with an addtiona ltensor +// that does not have data dependency with the consumer. +TEST_F(NVFuserTest, FusionComputeAtNoCommonConsumer_CUDA) { + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv1 * -2 + // tv4 = tv2 + tv3 + // tv5 = tv4 * 5 + // tv6 = tv1 * 6 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = mul(tv1, IrBuilder::create(-2.0)); + TensorView* tv4 = add(tv2, tv3); + TensorView* tv5 = mul(tv4, IrBuilder::create(5.0)); + // Notice that tv6 is not a consumer of tv4. + TensorView* tv6 = mul(tv1, IrBuilder::create(6.0)); + fusion.addOutput(tv3); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + TensorView* computeAtTarget = tv3; + computeAtTarget->split(0, 128); + tv1->computeAt(computeAtTarget, 1); + + TensorView* affected_tensors[] = {tv1, tv2, tv3, tv4, tv5, tv6}; + for (auto tv : affected_tensors) { + TORCH_CHECK(tv->nDims() == computeAtTarget->nDims()); + if (tv == tv6 || tv == tv5) { + TORCH_CHECK(tv->getComputeAtPosition() == 0); + } else { + TORCH_CHECK(tv->getComputeAtPosition() == 1); + } + } + + computeAtTarget->axis(0)->parallelize(ParallelType::BIDx); + + for (auto tv : affected_tensors) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({1000}, options); + + auto t1 = aten_input * 0.5; + auto t2 = t1 * -1.0; + auto t3 = t1 * -2.0; + auto t4 = t2 + t3; + auto t5 = t4 * 5.0; + auto t6 = t1 * 6.0; + + std::vector aten_outputs = {t3, t4, t5, t6}; + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +namespace { + +void checkIdMapped( + ComputeAtRootDomainMap& root_map, + TensorView* v0, + IterDomain* id0, + TensorView* v1, + IterDomain* id1, + bool should_map) { + if (should_map) { + TORCH_CHECK( + root_map.canMap(v0->domain(), id0, v1->domain(), id1), + "Should be mappable: ", + id0, + " of ", + v0, + " and ", + id1, + " of ", + v1); + } else { + TORCH_CHECK( + !root_map.canMap(v0->domain(), id0, v1->domain(), id1), + "Should not be mappable: ", + id0, + " of ", + v0, + " and ", + id1, + " of ", + v1); + } +} + +void checkIdMapped( + TensorView* v0, + const std::vector& root0, + const std::vector should_map0, + TensorView* v1, + const std::vector& root1, + const std::vector should_map1) { + ComputeAtRootDomainMap map; + map.build(); + TORCH_INTERNAL_ASSERT(root0.size() == should_map0.size()); + TORCH_INTERNAL_ASSERT(root1.size() == should_map1.size()); + size_t idx0 = 0; + for (const auto i : c10::irange(root0.size())) { + size_t idx1 = 0; + for (const auto j : c10::irange(root1.size())) { + if (should_map0[i] && should_map1[j] && idx0 == idx1) { + checkIdMapped(map, v0, root0[i], v1, root1[j], true); + } else { + checkIdMapped(map, v0, root0[i], v1, root1[j], false); + } + if (should_map1[j]) + ++idx1; + } + if (should_map0[i]) + ++idx0; + } +} + +void checkIdMapped( + TensorView* v0, + const std::vector& root0, + TensorView* v1, + const std::vector& root1) { + checkIdMapped( + v0, + root0, + std::vector(root0.size(), true), + v1, + root1, + std::vector(root1.size(), true)); +} + +} // namespace + +TEST_F(NVFuserTest, FusionRootMappingBasic_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv3 = broadcast(tv0, {true, false, false}); + auto tv4 = broadcast(tv1, {false, true, false}); + auto tv5 = add(tv3, tv4); + fusion.addOutput(tv5); + + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv4, + tv4->getRootDomain(), + {false, true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, true}, + tv4, + tv4->getRootDomain(), + {true, false, true}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {false, true}, + tv1, + tv1->getRootDomain(), + {false, true}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv5, + tv5->getRootDomain(), + {false, true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, true}, + tv5, + tv5->getRootDomain(), + {true, false, true}); + checkIdMapped(tv3, tv3->getRootDomain(), tv4, tv4->getRootDomain()); + checkIdMapped(tv3, tv3->getRootDomain(), tv5, tv5->getRootDomain()); + checkIdMapped(tv4, tv4->getRootDomain(), tv5, tv5->getRootDomain()); +} + +TEST_F(NVFuserTest, FusionRootMappingRfactor_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // [I,I] + TensorView* tv0 = makeSymbolicTensor(2); + // [I,I,I] + TensorView* tv1 = makeSymbolicTensor(3); + + //[I,I,R] + auto tv2 = sum(tv1, {2}); + auto tv3 = add(tv2, tv0); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv3); + + // scheduling: + //[B,I,R0,R1=128], root = [B,I,R] + tv2->split(2, 128); + + // root=[B,I,Irf], rfactor=[B,I,Irf,Rrf] + auto tv4 = tv2->rFactor({3}); + + checkIdMapped(tv1, tv1->getRootDomain(), tv4, tv4->getRootDomain()); + checkIdMapped( + tv4, + tv4->getRFactorDomain(), + {true, true, true, false}, + tv2, + tv2->getRootDomain(), + {true, true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, true, false}, + tv2, + tv2->getRootDomain(), + {true, true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, true, false}, + tv3, + tv3->getRootDomain(), + {true, true}); + checkIdMapped( + tv2, + tv2->getRootDomain(), + {true, true, false}, + tv3, + tv3->getRootDomain(), + {true, true}); + checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain()); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv1, + tv1->getRootDomain(), + {true, true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv2, + tv2->getRootDomain(), + {true, true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv4, + tv4->getRFactorDomain(), + {true, true, false, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, true}, + tv4, + tv4->getRootDomain(), + {true, true, false}); +} + +TEST_F(NVFuserTest, FusionRootMappingReductionDependency1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + fusion.addOutput(tv2); + + // The second dimension cannot be mapped as it would require recomputation. + checkIdMapped(tv0, tv0->getRootDomain(), tv1, tv1->getRootDomain()); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); +} + +TEST_F(NVFuserTest, FusionRootMappingReductionDependency2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv1, + tv1->getRootDomain(), + {true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); +} + +TEST_F(NVFuserTest, FusionRootMappingReductionDependency3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + fusion.addOutput(tv2); + + tv1->split(-1, 4); + auto tv3 = tv1->rFactor({-2}); + + checkIdMapped(tv0, tv0->getRootDomain(), tv3, tv3->getRootDomain()); + checkIdMapped( + tv3, + tv3->getMaybeRFactorDomain(), + {true, false, true}, + tv1, + tv1->getRootDomain(), + {true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); +} + +TEST_F(NVFuserTest, FusionRootMappingReductionDependency4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + tv1->split(-1, 4); + auto tv4 = tv1->rFactor({-2}); + + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv4, + tv4->getRootDomain(), + {true, false}); + checkIdMapped( + tv4, + tv4->getMaybeRFactorDomain(), + {true, false, true}, + tv1, + tv1->getRootDomain(), + {true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped(tv2, tv2->getRootDomain(), tv3, tv3->getRootDomain()); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); +} + +// Reproducer of issue #749 +TEST_F(NVFuserTest, FusionRootMappingReductionDependency5_CUDA_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true}); + auto tv4 = add(tv0, tv3); + auto tv5 = add(tv4, tv1); + fusion.addOutput(tv5); + + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv1, + tv1->getRootDomain(), + {true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv2, + tv2->getRootDomain(), + {true, false}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped( + tv3, + tv3->getRootDomain(), + {true, true}, + tv4, + tv4->getRootDomain(), + {true, true}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv4, + tv4->getRootDomain(), + {true, false}); + checkIdMapped( + tv4, + tv4->getRootDomain(), + {true, true}, + tv5, + tv5->getRootDomain(), + {true, true}); +} + +// Similar to RootMappingReductionDependency5 but with rFactor +TEST_F(NVFuserTest, FusionRootMappingReductionDependency6_CUDA_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true}); + auto tv4 = add(tv0, tv3); + auto tv5 = add(tv4, tv1); + fusion.addOutput(tv5); + + tv2->split(1, 4); + auto tv6 = tv2->rFactor({-1}); + + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv1, + tv1->getRootDomain(), + {true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv6, + tv6->getRootDomain(), + {true, false}); + checkIdMapped( + tv6, + tv6->getMaybeRFactorDomain(), + {true, true, false}, + tv2, + tv2->getRootDomain(), + {true, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv2, + tv2->getRootDomain(), + {true, false}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped( + tv3, + tv3->getRootDomain(), + {true, true}, + tv4, + tv4->getRootDomain(), + {true, true}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true, false}, + tv4, + tv4->getRootDomain(), + {true, false}); + checkIdMapped( + tv4, + tv4->getRootDomain(), + {true, true}, + tv5, + tv5->getRootDomain(), + {true, true}); +} + +TEST_F( + NVFuserTest, + FusionRootMappingMultipleBroadcastWithNoCommonConsumer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + auto tv1 = broadcast(tv0, {false, true}); + auto tv2 = broadcast(tv0, {true, false}); + fusion.addOutput(tv1); + fusion.addOutput(tv2); + + // If there is no common consumer, there is no recomputation constraint. + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv1, + tv1->getRootDomain(), + {true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv2, + tv2->getRootDomain(), + {false, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {false, true}); +} + +TEST_F(NVFuserTest, FusionRootMappingBroadcastNonUniqueSize_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + auto tv3 = broadcast(tv0, {false, true}); + auto tv4 = add(tv1, tv3); + fusion.addOutput(tv4); + auto tv5 = add(tv2, tv3); + fusion.addOutput(tv5); + + // Broadcast domains can be used with multiple domains with + // different sizes. In this test, the broadcast domain of tv3 has + // two consumers, tv4 and tv5, which may have different sizes. Each + // of the consumers is used with the broadcast domain of tv3, but + // the two consumers may not have the same size, it is not possible + // to map those domains. + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv1, + tv1->getRootDomain(), + {true, false}); + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv2, + tv2->getRootDomain(), + {true, false}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, false}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped( + tv2, + tv2->getRootDomain(), + {true, false}, + tv3, + tv3->getRootDomain(), + {true, false}); + checkIdMapped( + tv3, + tv3->getRootDomain(), + {true, false}, + tv4, + tv4->getRootDomain(), + {true, false}); + checkIdMapped( + tv3, + tv3->getRootDomain(), + {true, false}, + tv5, + tv5->getRootDomain(), + {true, false}); + checkIdMapped( + tv4, + tv4->getRootDomain(), + {true, false}, + tv5, + tv5->getRootDomain(), + {true, false}); +} + +TEST_F(NVFuserTest, FusionRootMappingBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + // tv0[I0] + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {true, false}); + // tv1[B1, I0] + auto tv2 = broadcast(tv1, {true, false, false}); + // tv2[B2, B1, I0] + fusion.addOutput(tv2); + + // In this case, tv1 and tv2 has one and two broadcast domains, + // respectively. It is the second broadcast domain that is mapped to + // the broadcast of tv1. + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv1, + tv1->getRootDomain(), + {false, true}); + checkIdMapped( + tv1, + tv1->getRootDomain(), + {true, true}, + tv2, + tv2->getRootDomain(), + {false, true, true}); // Not {true, false, true} + checkIdMapped( + tv0, + tv0->getRootDomain(), + {true}, + tv2, + tv2->getRootDomain(), + {false, false, true}); +} + +// Reproducer of issue #723 +TEST_F(NVFuserTest, FusionRootMappingTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = sum(tv2, {0}); + auto tv4 = add(tv2, tv1); + + fusion.addOutput(tv3); + fusion.addOutput(tv4); + + ComputeAtRootDomainMap map; + map.build(); + + checkIdMapped( + map, tv2, tv2->getRootDomain()[0], tv4, tv4->getRootDomain()[0], true); + checkIdMapped( + map, tv2, tv2->getRootDomain()[0], tv3, tv3->getRootDomain()[0], true); + + tv2->computeAt(tv4, -1); + + const int x = 11; + const int y = 12; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x}, options); + at::Tensor t1 = at::randn({y, x}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t3 = t0; + auto t4 = t0.unsqueeze(0).expand({y, x}) + t1; + + testValidate(&fusion, outputs, aten_inputs, {t3, t4}, __LINE__, __FILE__); +} + +// Repro of issue #1950 +TEST_F(NVFuserTest, FusionRootMappingRepro1950_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto tv2 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + auto tv4 = mul(tv1, tv3); + auto tv5 = mul(tv1, tv2); + auto tv6 = mul(tv5, tv3); + auto tv7 = sum(tv6, {2}); + auto tv8 = broadcast(tv7, {false, false, true}); + auto tv9 = mul(tv3, tv8); + + // Issue #1950 was caused by a particular traversal ordering based + // on the output tensor ordering as below + fusion.addOutput(tv9); + fusion.addOutput(tv5); + fusion.addOutput(tv4); + + ComputeAtRootDomainMap root_map; + root_map.build(); + + checkIdMapped(root_map, tv4, tv4->axis(-1), tv9, tv9->axis(-1), false); +} + +TEST_F(NVFuserTest, FusionDetectSelfMappedDomains_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + // [I1] + auto tv1 = add(tv0, IrBuilder::create(1)); + // [B2, I2] + auto tv2 = broadcast(tv1, {true, false}); + // [I3, B3] + auto tv3 = broadcast(tv1, {false, true}); + // [I4, I5] + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + // IterDomainGraph maps B2, I3 and I4 together, and similarly I2, + // B3 and I5. The problem is I1 is mapped with both of the ID + // groups, so eventually all of the IDs are mapped + // together. IterDomainGraph should throw an exception as this + // pattern of domain mappings is not supported. + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW({ IterDomainGraph id_graph(&fusion); }); +} + +TEST_F(NVFuserTest, FusionScalarInputs_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + Double* d0 = IrBuilder::create(); + fusion.addInput(d0); + Double* d1 = IrBuilder::create(); + fusion.addInput(d1); + Double* d2 = IrBuilder::create(); + fusion.addInput(d2); + Double* d3 = IrBuilder::create(); + fusion.addInput(d3); + Val* d4 = mul(d0, d1); + Val* d5 = sub(d2, d3); + + TensorView* tv2 = sub(tv1, d4); + TensorView* tv3 = add(tv0, d5); + TensorView* tv4 = mul(tv3, tv2); + + fusion.addOutput(tv4); + + // Lets setup to actually run + while (tv4->nDims() > 1) + tv4->merge(0); + tv4->split(0, 128); + tv4->split(0, 4); + + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + // d4 = d0 * d1 + // d5 = d2 - d3 + // t2 = t1 - d4 + // t3 = t0 + d5 + // t4 = t3 * t2 + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + float fl0 = 0.1; + float fl1 = -0.2; + float fl2 = 0.3; + float fl3 = -0.4; + float fl4 = fl0 * fl1; + float fl5 = fl2 - fl3; + + at::Tensor t0 = at::randn({129, 127}, options); + at::Tensor t1 = at::rand_like(t0, options); + + auto t2 = t1.sub(fl4); + auto t3 = t0.add(fl5); + auto aten_output = t3.mul(t2); + + at::Tensor cg_output = at::empty_like(t0, options); + + at::Scalar test(fl0); + + std::vector aten_inputs = { + t0, + t1, + at::Scalar(fl0), + at::Scalar(fl1), + at::Scalar(fl2), + at::Scalar(fl3)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionLoopUnroll_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(3); + TensorView* tv1 = makeSymbolicTensor(3); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + int block_size = 16; + + tv3->merge(0, 1); + tv3->merge(0, 1); + + tv3->split(0, block_size); + tv3->split(0, 4); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + // Parallelize + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input0 = at::randn({129, 13, 3}, options); + at::Tensor input1 = at::randn({129, 13, 3}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input0, input1}); + auto outputs = fe.runFusion({input0, input1}); + + TORCH_CHECK(outputs[0].equal(input0.add(input1.add(2.0)))); +} + +/* + * Helper function for single op testing that generates a codegen operand + */ + +Val* gen_jit_operand(std::pair desc) { + if (desc.first == ValType::TensorView) { + return makeSymbolicTensor(2, desc.second); + } else if (desc.first == ValType::Scalar) { + if (desc.second == DataType::Float) { + return IrBuilder::create(); + } else if (desc.second == DataType::Double) { + return IrBuilder::create(); + } else if (desc.second == DataType::ComplexFloat) { + return IrBuilder::create(); + } else if (desc.second == DataType::ComplexDouble) { + return IrBuilder::create(); + } else if (desc.second == DataType::Int) { + return IrBuilder::create(); + } else { + TORCH_CHECK(false, "Not currently supported type: ", desc.first); + } + } else { + TORCH_CHECK(false, "Not currently supported type: ", desc.first); + } + return nullptr; +} + +/* + * Helper function for single op testing that generates an ATen operand + */ + +IValue gen_aten_operand( + std::pair desc, + int blocks, + int threads, + bool rand) { + if (desc.first == ValType::TensorView) { + if (desc.second == DataType::Double || desc.second == DataType::Float || + desc.second == DataType::ComplexDouble || + desc.second == DataType::ComplexFloat || + desc.second == DataType::Half || desc.second == DataType::BFloat16) { + auto options = at::TensorOptions() + .dtype(data_type_to_aten(desc.second)) + .device(at::kCUDA, 0); + if (rand) { + return IValue(at::rand({blocks, threads}, options)); + } else { + return IValue(at::empty({blocks, threads}, options)); + } + } else if (desc.second == DataType::Int || desc.second == DataType::Int32) { + auto dtype = desc.second == DataType::Int32 ? at::kInt : at::kLong; + if (rand) { + auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + return IValue(at::randn({blocks, threads}, options).mul(5).to(dtype)); + } else { + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + return IValue(at::empty({blocks, threads}, options)); + } + } else if (desc.second == DataType::Bool) { + if (rand) { + auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + return IValue( + at::rand({blocks, threads}, options).round().to(at::kBool)); + } else { + auto options = + at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0); + return IValue(at::empty({blocks, threads}, options)); + } + } else { + TORCH_CHECK(false, "Not currently supported type: ", desc.second) + } + } else if (desc.first == ValType::Scalar) { + // IValue scalars can only be double int64 or bool + if (desc.second == DataType::ComplexDouble || + desc.second == DataType::ComplexFloat) { + return IValue(at::Scalar(c10::complex(1.0, 0.0))); + } else if ( + desc.second == DataType::Double || desc.second == DataType::Float || + desc.second == DataType::Half || desc.second == DataType::BFloat16) { + return IValue(at::Scalar(1.0)); + } else if (desc.second == DataType::Int) { + return IValue(at::Scalar(1)); + } else { + TORCH_CHECK(false, "Not currently supported type: ", desc.first); + } + } else { + TORCH_CHECK(false, "Not currently supported type: ", desc.first); + } + return nullptr; +} + +/* + * Templatized Helper Function To generate single Op comparison between the + * JIT codegen for Cuda and the ATen Library. + */ + +using OutputPair = std::pair; +template < + typename AtenFunc, + typename JitFunc, + typename InputTuple, + size_t... NumInputs> +void test_op( + int blocks, + int threads, + std::string op_str, + AtenFunc af, + JitFunc jf, + OutputPair op, + InputTuple it, + std::index_sequence) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Generate Input JIT function Inputs and add them as Inputs to the Fusion + // Graph + std::array jit_inputs = { + gen_jit_operand(std::get(it))...}; + std::for_each(jit_inputs.begin(), jit_inputs.end(), [&fusion](Val* v) { + fusion.addInput(v); + }); + TensorView* out = + static_cast(jf(std::get(jit_inputs)...)); + fusion.addOutput(out); + + std::for_each(jit_inputs.begin(), jit_inputs.end(), [out](Val* v) { + if (v->getValType() == ValType::TensorView) + static_cast(v)->computeAt(out, -1); + }); + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(-1)->parallelize(ParallelType::TIDx); + + std::array aten_inputs = {gen_aten_operand( + std::get(it), blocks, threads, /*rand*/ true)...}; + const at::ArrayRef aten_inputs_ivalues(aten_inputs); + + at::Tensor cg_output = + gen_aten_operand(op, blocks, threads, /*rand*/ false).toTensor(); + std::vector output_vect = {cg_output}; + cudaDeviceSynchronize(); + if (fusion.isStochastic()) + at::manual_seed(0); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs_ivalues); + fe.runFusion(aten_inputs_ivalues, output_vect); + cudaDeviceSynchronize(); + + if (fusion.isStochastic()) + at::manual_seed(0); + at::Tensor aten_output = af(aten_inputs); + cudaDeviceSynchronize(); // This sync shouldn't be necessary; + + std::string op_msg = "Operation " + op_str; + + testValidate( + &fusion, + {cg_output}, + aten_inputs, + {aten_output}, + __LINE__, + __FILE__, + op_msg); +} + +/* + * Templatized Helper Function that uses variadic templates to + * process a variable length Input Tuple of different Operand Type. + */ +template +void test_op( + int blocks, + int threads, + std::string op_str, + AtenFunc af, + JitFunc jf, + OutputPair op, + InputTuple it) { + static constexpr auto size = std::tuple_size::value; + test_op( + blocks, + threads, + op_str, + af, + jf, + op, + it, + std::make_index_sequence{}); +} + +TEST_F(NVFuserTest, FusionUnaryOps_CUDA) { + using OpTuple = + std::tuple; + + // [Note: explicit tuple type for uniform initialization list] + // Tuple type must be explicitly specified for each uniform initialization + // list within the vector to make this code compatible with some old env + // which we still need to support. eg. gcc 5.4 + cuda 9.2. + std::vector ops{ + OpTuple{at::acos, UnaryOpType::Acos, "acos"}, + OpTuple{at::asin, UnaryOpType::Asin, "asin"}, + OpTuple{at::atan, UnaryOpType::Atan, "atan"}, + // There does not appear to be an appropriate ATen function for atanh + // OpTuple{at::atanh, UnaryOpType::Atanh, "atanh" }, + OpTuple{at::cos, UnaryOpType::Cos, "cos"}, + OpTuple{at::cosh, UnaryOpType::Cosh, "cosh"}, + OpTuple{at::exp, UnaryOpType::Exp, "exp"}, + // OpTuple{at::gelu, UnaryOpType::Gelu, "gelu"}, + OpTuple{at::log, UnaryOpType::Log, "log"}, + OpTuple{at::log10, UnaryOpType::Log10, "log10"}, + OpTuple{at::neg, UnaryOpType::Neg, "neg"}, + OpTuple{at::reciprocal, UnaryOpType::Reciprocal, "reciprocal"}, + OpTuple{at::sigmoid, UnaryOpType::Sigmoid, "sigmoid"}, + OpTuple{at::sin, UnaryOpType::Sin, "sin"}, + OpTuple{at::sinh, UnaryOpType::Sinh, "sinh"}, + OpTuple{at::sqrt, UnaryOpType::Sqrt, "sqrt"}, + OpTuple{at::tan, UnaryOpType::Tan, "tan"}, + OpTuple{at::tanh, UnaryOpType::Tanh, "tanh"}, + OpTuple{at::isfinite, UnaryOpType::IsFinite, "isfinite"}, + OpTuple{at::isinf, UnaryOpType::IsInf, "isinf"}, + OpTuple{at::isnan, UnaryOpType::IsNan, "isnan"}, + OpTuple{at::isreal, UnaryOpType::IsReal, "isreal"}, + }; + + // The following ops has no complex support in eager mode + std::vector ops_without_complex{ + OpTuple{at::ceil, UnaryOpType::Ceil, "ceil"}, + OpTuple{at::floor, UnaryOpType::Floor, "floor"}, + OpTuple{at::frac, UnaryOpType::Frac, "frac"}, + OpTuple{at::trunc, UnaryOpType::Trunc, "trunc"}, + OpTuple{at::round, UnaryOpType::Round, "round"}, + OpTuple{at::relu, UnaryOpType::Relu, "relu"}, + OpTuple{at::expm1, UnaryOpType::Expm1, "expm1"}, + OpTuple{at::log1p, UnaryOpType::Log1p, "log1p"}, + OpTuple{at::lgamma, UnaryOpType::Lgamma, "lgamma"}, + OpTuple{at::erf, UnaryOpType::Erf, "erf"}, + OpTuple{at::erfc, UnaryOpType::Erfc, "erfc"}, + OpTuple{at::isneginf, UnaryOpType::IsNegInf, "isneginf"}, + OpTuple{at::isposinf, UnaryOpType::IsPosInf, "isposinf"}, + }; + + // The following ops only supports complex + std::vector ops_complex_only{ + // real is supported via UnaryOpType::Set for non-complex types, and + // UnaryOpType::Real requires input to be complex + OpTuple{at::real, UnaryOpType::Real, "real"}, + OpTuple{at::imag, UnaryOpType::Imag, "imag"}, + }; + + // Complex support for the following op is not working in nvFuser yet + std::vector ops_skip_complex{ + // TODO: abs is actually supported in nvFuser, but it has bug!!! + // In eager mode, abs(complex_tensor) returns floating point tensor + // but in nvFuser, it wrongly returns complex tensor! + // We need to: + // 1. change our type promotion logic to make a special case for abs + // 2. why this bug is not detected here? we should bump up test coverage + OpTuple{at::abs, UnaryOpType::Abs, "abs"}, + // TODO: the following two ops fails with compilation error like + // "undefined function rsqrt(complex)", we could implement them in + // helpers.cu, but I think it is better to check with Jiterator first, + // because Jiterator uses the same string for complex support. + OpTuple{at::rsqrt, UnaryOpType::Rsqrt, "rsqrt"}, + OpTuple{at::log2, UnaryOpType::Log2, "log2"}}; + + std::vector dtypes = { + DataType::Float, + DataType::Double, + DataType::ComplexFloat, + DataType::ComplexDouble}; + + for (auto dtype : dtypes) { + auto ops_to_test = ops; + if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { + ops_to_test.insert( + ops_to_test.end(), + ops_without_complex.begin(), + ops_without_complex.end()); + ops_to_test.insert( + ops_to_test.end(), ops_skip_complex.begin(), ops_skip_complex.end()); + } else { + ops_to_test.insert( + ops_to_test.end(), ops_complex_only.begin(), ops_complex_only.end()); + } + std::for_each(ops.begin(), ops.end(), [&](OpTuple& op) { + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ std::get<2>(op), + /*Aten Func */ + [&op](std::array& vals) { + return std::get<0>(op)(vals[0].toTensor()); + }, + /*JIT Func */ + [&op](Val* in1) -> Val* { return unaryOp(std::get<1>(op), in1); }, + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple(std::make_pair(ValType::TensorView, dtype))); + }); + } + + dtypes = {DataType::Int, DataType::Int32, DataType::Bool}; + for (auto dtype : dtypes) { + test_op( + /*blocks*/ 128, + /*threads*/ 64, + /*name*/ "bitwise_not", + /*Aten Func */ + [](std::array& vals) { + return at::bitwise_not(vals[0].toTensor()); + }, + /*JIT Func */ + [](Val* in1) -> Val* { return unaryOp(UnaryOpType::Not, in1); }, + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple(std::make_pair(ValType::TensorView, dtype))); + } +} + +TEST_F(NVFuserTest, FusionBinaryOps_CUDA) { + using AtenFuncSig = at::Tensor (*)(const at::Tensor&, const at::Tensor&); + using OpTuple = std::tuple; + + std::vector dtypes = { + DataType::Double, + DataType::Float, + DataType::ComplexFloat, + DataType::ComplexDouble}; + + // see [Note: explicit tuple type for uniform initialization list] + std::vector equal_ops{ + OpTuple{at::eq, BinaryOpType::Eq, "eq"}, + OpTuple{at::ne, BinaryOpType::NE, "ne"}}; + + // Complex numbers are not ordered + std::vector order_ops{ + OpTuple{at::ge, BinaryOpType::GE, "ge"}, + OpTuple{at::gt, BinaryOpType::GT, "gt"}, + OpTuple{at::le, BinaryOpType::LE, "le"}, + OpTuple{at::lt, BinaryOpType::LT, "lt"}}; + + // see [Note: explicit tuple type for uniform initialization list] + std::vector math_ops{ + OpTuple{at::div, BinaryOpType::Div, "div"}, + OpTuple{at::mul, BinaryOpType::Mul, "mul"}, + OpTuple{at::pow, BinaryOpType::Pow, "pow"}}; + + // The following ops has no complex support in eager mode + std::vector math_ops_without_complex{ + OpTuple{at::atan2, BinaryOpType::Atan2, "atan2"}, + OpTuple{at::max, BinaryOpType::Max, "max"}, + OpTuple{at::min, BinaryOpType::Min, "min"}, + OpTuple{at::fmod, BinaryOpType::Fmod, "fmod"}, + // NOTE: Remainder does not match the Aten impl exactly + // despite using an identical function. + OpTuple{at::remainder, BinaryOpType::Remainder, "remainder"}}; + + for (auto dtype : dtypes) { + auto logic_ops = equal_ops; + if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { + logic_ops.insert(logic_ops.end(), order_ops.begin(), order_ops.end()); + } + std::for_each(logic_ops.begin(), logic_ops.end(), [&](OpTuple& op) { + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ std::get<2>(op), + /*Aten Func */ + [&op](std::array& vals) { + return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor()); + }, + /*JIT Func */ + [&op](Val* in1, Val* in2) -> Val* { + return binaryOp(std::get<1>(op), in1, in2); + }, + /*Output */ std::make_pair(ValType::TensorView, DataType::Bool), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype))); + }); + + auto enabled_math_ops = math_ops; + if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { + enabled_math_ops.insert( + enabled_math_ops.end(), + math_ops_without_complex.begin(), + math_ops_without_complex.end()); + } + std::for_each( + enabled_math_ops.begin(), enabled_math_ops.end(), [&](OpTuple& op) { + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ std::get<2>(op), + /*Aten Func */ + [&op](std::array& vals) { + return std::get<0>(op)(vals[0].toTensor(), vals[1].toTensor()); + }, + /*JIT Func */ + [&op](Val* in1, Val* in2) -> Val* { + return binaryOp(std::get<1>(op), in1, in2); + }, + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype))); + }); + + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "add_alpha", + /*Aten Func */ + [](std::array& vals) { + return at::add( + vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar()); + }, + /*JIT Func */ static_cast(&add_alpha), + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::Scalar, dtype))); + + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "sub_alpha", + /*Aten Func */ + [](std::array& vals) { + return at::sub( + vals[0].toTensor(), vals[1].toTensor(), vals[2].toScalar()); + }, + /*JIT Func */ static_cast(&sub_alpha), + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::Scalar, dtype))); + } +} + +TEST_F(NVFuserTest, FusionTernaryOps_CUDA) { + std::vector dtypes = { + DataType::Double, + DataType::Float, + DataType::ComplexFloat, + DataType::ComplexDouble}; + + for (auto dtype : dtypes) { + // clamp and threshold are not supported for complex on eager mode + if (dtype != DataType::ComplexFloat && dtype != DataType::ComplexDouble) { + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "clamp", + /*Aten Func */ + [](std::array& vals) { + return at::clamp(vals[0].toTensor(), 0.f, 1.f); + }, + /*JIT Func */ + [&](Val* in1) -> Val* { + if (dtype == DataType::Float) { + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); + } else { + return clamp( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); + } + }, + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple(std::make_pair(ValType::TensorView, dtype))); + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "threshold", + /*Aten Func */ + [](std::array& vals) { + return at::threshold(vals[0].toTensor(), 0.f, 1.f); + }, + /*JIT Func */ + [&](Val* in1) -> Val* { + if (dtype == DataType::Float) { + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); + } else { + return threshold( + in1, + IrBuilder::create(0.f), + IrBuilder::create(1.f)); + } + }, + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple(std::make_pair(ValType::TensorView, dtype))); + } + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "where", + /*Aten Func */ + [](std::array& vals) { + return at::where( + vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor()); + }, + /*JIT Func */ static_cast(&where), + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, DataType::Bool), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype))); + } +} + +TEST_F(NVFuserTest, FusionCompoundOps_CUDA) { + std::vector dtypes = { + DataType::Double, + DataType::Float, + DataType::ComplexFloat, + DataType::ComplexDouble}; + + for (auto dtype : dtypes) { + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "lerp", + /*Aten Func */ + [](std::array& vals) { + return at::lerp( + vals[0].toTensor(), vals[1].toTensor(), vals[2].toTensor()); + }, + /*JIT Func */ static_cast(&lerp), + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype))); + test_op( + /*blocks*/ 640, + /*threads*/ 64, + /*name*/ "addcmul", + /*Aten Func */ + [](std::array& vals) { + return at::addcmul( + vals[0].toTensor(), + vals[1].toTensor(), + vals[2].toTensor(), + vals[3].toScalar()); + }, + /*JIT Func */ + static_cast(&addcmul), + /*Output */ std::make_pair(ValType::TensorView, dtype), + /*Inputs Tuple*/ + std::make_tuple( + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::TensorView, dtype), + std::make_pair(ValType::Scalar, dtype))); + } +} + +TEST_F(NVFuserTest, FusionCastOps_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2, DataType::Half); + + TensorView* intrm1 = castOp(DataType::Float, tv0); + TensorView* out = castOp(DataType::Half, intrm1); + + fusion.addInput(tv0); + fusion.addOutput(out); + tv0->computeAt(out, -1); + + out->axis(0)->parallelize(ParallelType::BIDx); + out->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({1, 4}, options); + at::Tensor ref_output = at::empty_like(input1); + + std::array inputs = {input1}; + const at::ArrayRef input_ivalues(inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, input_ivalues); + auto outputs = fe.runFusion(input_ivalues); + + ref_output = at::_cast_Half(at::_cast_Double(input1)); + + TORCH_CHECK( + outputs[0].equal(ref_output), + "\nOp Type: -- ", + "cast FP16->FP32->FP16", + " -- had a mismatch.\n", + "\nABS MAX DIFF: ", + outputs[0].sub(ref_output).abs().max(), + "\n"); +} + +// Start off simple, block on the outer dim +// block stride + thread all reduce + unrolling on inner dim +TEST_F(NVFuserTest, FusionReduction1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, 128); + // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] + tv1->split(1, 4); + // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1] + // tv1[I0, R1oi{4}, R1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] + + TensorView* tv3 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] = tv0[I0, I1] + // tv3[I0, R1oi{4}, Ir1i{128}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{128}] + // tv1[I0, R1i{128}] = tv3[I0, R1oi{4}, Ir1i{128}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv3, 1); + tv3->computeAt(tv1, 1); + + // Re do it all at once, because why not. + tv0->computeAt(tv1, 1); + + tv2->axis(2)->parallelize(ParallelType::Unroll); + tv1->axis(0)->parallelize(ParallelType::BIDx); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 65000; + int numel_y = 1025; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduction2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + + fusion.addOutput(tv1); + + // switches to try some different scenarios. maybe we should iterate on all + // permutations. + bool bind_bidx = true; + bool bind_tidx = true; + bool bind_tidy = true; + bool bind_unroll = true; + + int numel_x = 1025; // Cannot exceed block dim max size / tidy + int numel_y = 129; + int tidx = 16; + int tidy = 8; + int unroll_factor = 4; + + tv1->split(1, tidx); + // tv1[I0, R1o, R1i{tidx}] = tv0[I0, I1] + + tv1->split(1, unroll_factor); + // tv1[I0, R1oo, R1oi{unroll}, R1i{tidx}] = tv0[I0, I1] + + tv1->split(0, tidy); + + TensorView* tv2 = tv1->rFactor({-3}); + // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] + // tv1[I0o, I0i{tidy}, R1oi{unroll}, R1i{tidx}] + + TensorView* tv3 = tv1->rFactor({-2}); + // tv2[I0, >R1oo<, Ir1oi{unroll}, Ir1i{tidx}] + // tv3[I0, R1oi{unroll}, Ir1i{tidx}] + // tv1[I0o, I0i{tidy}, R1i{tidx}] + + tv0->computeAt(tv1, -2); + + if (bind_unroll) + tv2->axis(-2)->parallelize(ParallelType::Unroll); + if (bind_bidx) + tv1->axis(0)->parallelize(ParallelType::BIDx); + if (bind_tidy) + tv1->axis(1)->parallelize(ParallelType::TIDy); + + if (bind_tidx) { + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduction3_CUDA) { + // What if Z participates in the reduction with X? + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + + fusion.addOutput(tv1); + + int numel_x = 1025; // Cannot exceed block dim max size / tidy + int numel_y = 129; + int tidx = 16; + int tidz = 8; + + tv1->split(1, tidz); + // tv1[I0, R1o, R1i{tidz}] = tv0[I0, I1] + + tv1->split(1, tidx); + // tv1[I0, R1oo, R1oi{tidx}, R1i{tidz}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({-3}); + // tv2[I0, >R1oo<, Ir1oi{tidx}, Ir1i{tidz}] + // tv1[I0o, R1oi{tidx}, R1i{tidz}] + + tv0->computeAt(tv1, -3); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(-2)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDz); + + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDz); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + auto aten_output = aten_input.to(at::kDouble).sum({1}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduction4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + TensorView* tv2 = add(tv0, tv1); + // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); + // tv3[I0, R1] = tv2[I0, I1] + + TensorView* tv4 = makeSymbolicTensor(1); + fusion.addInput(tv4); + + // tv5[I0] = tv3[I0, R1] * tv4[I0] + TensorView* tv5 = mul(tv3, tv4); + fusion.addOutput(tv5); + + int tidx = 16; + + // RFactor the reduction + tv3->split(1, tidx); + // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1] + + TensorView* tv6 = tv3->rFactor({-2}); + // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1] + // tv3[I0, R1i{tidx}] = tv3[I0, I1] + tv2->computeAt(tv6, 2); + + // Compute at inline with tv5 (only 1D) + tv6->computeAt(tv3, 1); + tv3->computeAt(tv5, 1); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + + // Intermediate tensors only need this, but doesn't hurt to do on inputs + // tv0, 1, 4 + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 1025; + int numel_y = 129; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + at::Tensor t1 = at::randn({numel_x, numel_y}, options); + at::Tensor t4 = at::randn({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1, t4}); + auto cg_outputs = fe.runFusion({t0, t1, t4}); + + auto t2 = t0.add(t1); + auto t3 = t2.to(at::kDouble).sum({1}); + auto aten_output = t3.mul(t4); + + testValidate( + &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduction5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + + fusion.addOutput(tv1); + + int bidy = 2; + int tidy = 4; + int tidx = 5; + + int dim1 = 11; + + tv1->split(-2, tidy); + + TensorView* tv2 = tv1->rFactor({-3}); + + tv0->computeAt(tv1, 1); + tv1->axis(0)->parallelize(ParallelType::BIDy); + + for (auto* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + val->as()->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + tv2->axis(-2)->parallelize(ParallelType::TIDy); + tv1->axis(-2)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({bidy, dim1, tidx}, options); + + at::Tensor cg_output = at::empty({bidy, tidx}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduction6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int bdimx = 64; + const int bdimy = 8; + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + + // tv1[I0, R1, R2] = tv0[I0, I1, I2] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(2, bdimx); + // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] + tv1->split(1, bdimy); + // tv1[I0, R1o, R1i{8}, R2o, R2i{128}] = tv0[I0, I1, I2] + + TensorView* tv2 = tv1->rFactor({3}); + // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] + // tv1[I0, R1o, R1i{8}, R2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] + + TensorView* tv3 = tv1->rFactor({1}); + // tv2[I0, I1o, I1i{8}, R2o, I2i{128}] = tv0[I0, I1, I2] + // tv3[I0, R1o, I1i{8}, I2i{128}] = tv2[I0, I1o, I1i{8}, R2o, I2i{128}] + // tv1[I0, R1i{8}, R2i{128}] = tv3[I0, R1o, I1i{8}, I2i{128}] + + tv3->computeAt(tv1, 1); + tv2->computeAt(tv3, 2); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(-2)->parallelize(ParallelType::TIDy); + tv3->axis(-2)->parallelize(ParallelType::TIDy); + tv2->axis(-3)->parallelize(ParallelType::TIDy); + + int numel_x = 650; + int numel_y = 1000; + int numel_z = 4; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({1, 2}); + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMultiGridReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = max(tv0, {0}); + TensorView* tv2 = sum(tv0, {0}); + + fusion.addOutput(tv1); + fusion.addOutput(tv2); + + int numel_x = 4; + int numel_y = 2; + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + std::vector aten_outputs = { + std::get<0>(input.to(at::kDouble).max(0)), input.to(at::kDouble).sum(0)}; + testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMultiGridReduction2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(0)->parallelize(ParallelType::BIDy); + + FusionExecutor fe; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionReductionTFT_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + + fusion.addOutput(tv1); + + int numel_x = 1025; + int numel_y = 129; + int tidx = 16; + int tidy = 8; + int tidz = 8; + + tv1->split(1, tidx); + // tv1[I0, R1o, R1i{tidx}] + + tv1->split(1, tidz); + // tv1[I0, R1oo, R1Oi{tidz}, R1R1i{tidx}] + + tv1->split(0, tidy); + // tv1[I0o, I0i, R1oo, R1Oi{tidz}, R1R1i{tidx}] + + TensorView* tv2 = tv1->rFactor({2}); + // tv2[I0o, I0i, R1oo, I1Oi{tidz}, I11i{tidx}] + // tv1[I0o, I0i, R1Oi{tidz}, R1R1i{tidx}] + + tv2->computeAt(tv1, 2); + + tv1->axis(1)->parallelize(ParallelType::TIDy); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(-2)->parallelize(ParallelType::TIDz); + tv2->axis(-2)->parallelize(ParallelType::TIDz); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReductionOuterSplit_CUDA) { + // based off FusionReduction4 + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + TensorView* tv2 = add(tv0, tv1); + // tv2[I0, I1] = tv0[I0, I1] + tv1[I0, I1] + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv3 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv2); + // tv3[I0, R1] = tv2[I0, I1] + + TensorView* tv4 = makeSymbolicTensor(1); + fusion.addInput(tv4); + + // tv5[I0] = tv3[I0, R1] * tv4[I0] + TensorView* tv5 = mul(tv3, tv4); + fusion.addOutput(tv5); + + // RFactor the reduction + tv3->split(1, 16, false); + // tv3[I0, R1o{16}, R1i{tidx}] = tv2[I0, I1] + + TensorView* tv6 = tv3->rFactor({-2}); + // tv6[I0, R1o{16}, iR1i{tidx}] = tv2[I0, I1] + // tv3[I0, R1i{tidx}] = tv3[I0, I1] + tv2->computeAt(tv6, 2); + + // Compute at inline with tv5 (only 1D) + tv6->computeAt(tv3, 1); + tv3->computeAt(tv5, 1); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + + // Intermediate tensors only need this, but doesn't hurt to do on inputs + // tv0, 1, 4 + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 1025; + int numel_y = 129; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + at::Tensor t1 = at::randn({numel_x, numel_y}, options); + at::Tensor t4 = at::randn({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1, t4}); + auto cg_outputs = fe.runFusion({t0, t1, t4}); + + auto t2 = t0.add(t1); + auto t3 = t2.to(at::kDouble).sum({1}); + auto aten_output = t3.mul(t4); + + testValidate( + &fusion, cg_outputs, {t0, t1, t4}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBranches_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + TensorView* tv2 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + auto tv3 = add(tv0, IrBuilder::create(1.0)); + auto tv4 = add(tv3, tv1); + auto tv5 = add(tv3, tv2); + auto tv6 = add(tv4, tv5); + + fusion.addOutput(tv6); + + constexpr int x = 63, y = 33; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t1 = at::randn({x, y}, options); + at::Tensor t2 = at::randn({x, y}, options); + + FusionExecutor fe; + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv6, 1); + tv1->computeAt(tv6, 1); + tv2->computeAt(tv6, 1); + + tv3->axis(-2)->parallelize(ParallelType::Unroll); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-2)->parallelize(ParallelType::Unroll); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-2)->parallelize(ParallelType::Unroll); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + std::vector aten_inputs = {t0, t1, t2}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t3 = t0.add(1.0); + auto t4 = t3.add(t1); + auto t5 = t3.add(t2); + auto aten_output = t4.add(t5); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleBCast1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1.5)); + + TensorView* tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + TensorView* tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + TensorView* tv4 = sub(tv2, tv3); + + TensorView* tv5 = broadcast(tv1, {false, false, true}); + TensorView* tv6 = broadcast(tv4, {true, false, false}); + + TensorView* tv7 = add(tv5, tv6); + fusion.addOutput(tv7); + + tv7->split(-1, 4); + tv7->split(0, 8); + + tv0->computeAt(tv7, -1); + tv2->computeAt(tv7, -1); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + tv7->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int x = 63, y = 33, z = 15; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t1 = t0.add(1.5); + + at::Tensor t2 = at::randn({y, z}, options); + at::Tensor t3 = at::randn({y, z}, options); + + at::Tensor t4 = t2.sub(t3); + at::Tensor t5 = t1.unsqueeze(-1).expand({x, y, z}); + + at::Tensor t6 = t4.expand({x, y, z}); + + at::Tensor aten_output = t5.add(t6); + + std::vector aten_inputs = {t0, t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleBCast2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + TensorView* tv3 = broadcast(tv2, {false, false, true}); + + TensorView* tv4 = makeSymbolicTensor(2); + fusion.addInput(tv4); + + TensorView* tv5 = sub(tv4, IrBuilder::create(0.1)); + + TensorView* tv6 = broadcast(tv5, {true, false, false}); + + TensorView* tv7 = add(tv3, tv6); + + fusion.addOutput(tv7); + + tv7->merge(0, 1); + + tv0->computeAt(tv7, -1); + tv4->computeAt(tv7, -1); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + tv7->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int x = 63, y = 33, z = 15; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t1 = at::randn({x, y}, options); + at::Tensor t2 = t0.add(t1); + at::Tensor t3 = t2.unsqueeze(-1).expand({x, y, z}); + + at::Tensor t4 = at::randn({y, z}, options); + at::Tensor t5 = t4.sub(0.1); + at::Tensor t6 = t5.expand({x, y, z}); + at::Tensor aten_output = t3.add(t6); + + at::Tensor cg_output = at::empty({x, y, z}, options); + + std::vector aten_inputs = {t0, t1, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleBCast3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up input tensor views + // tv0[I1, B{1}] + TensorView* tv0 = makeConcreteTensor({-1, 1}); + fusion.addInput(tv0); + + // tv1[I0, I1, I2] + TensorView* tv2 = makeSymbolicTensor(3); + fusion.addInput(tv2); + + TensorView* tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->merge(0); + + tv0->computeAt(tv3, -1); + tv2->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + constexpr int x = 2, y = 3, z = 4; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y, 1}, options); + at::Tensor t2 = at::randn({x, y, z}, options); + auto aten_output = t0.add(t2); + + std::vector aten_inputs = {t0, t2}; + at::Tensor cg_output = at::empty({x, y, z}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleBCast4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({1, -1}); + + TensorView* tv1 = makeSymbolicTensor(3); + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv3 = add(tv0, tv1); + + tv3->merge(0); + tv3->merge(0); + tv3->split(0, 128); + tv3->split(0, 4); + + fusion.addOutput(tv3); + + tv0->computeAt(tv3, -1); + tv1->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-2)->parallelize(ParallelType::Unroll); + + constexpr int x = 63, y = 33, z = 15; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({1, z}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + auto aten_output = t0.add(t1); + + at::Tensor cg_output = at::empty({x, y, z}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleBCast5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int m = 2, k = 3, n = 4; + auto tv0 = makeConcreteTensor({m, k}); + auto tv1 = makeConcreteTensor({k, n}); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = broadcast(tv0, {false, false, true}); + TensorView* tv3 = broadcast(tv1, {true, false, false}); + + TensorView* tv4 = add(tv2, tv3); + + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->merge(0); + + tv0->computeAt(tv4, -1); + tv1->computeAt(tv4, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({m, k}, options); + at::Tensor t1 = at::randn({k, n}, options); + + auto t2 = t0.unsqueeze(-1).expand({m, k, n}); + auto t3 = t1.expand({m, k, n}); + auto aten_output = t2.add(t3); + + at::Tensor cg_output = at::empty({m, k, n}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionComplexBCast1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 2, y = 3, z = 4; + + auto tv0 = makeConcreteTensor({y}); + auto tv1 = div(tv0, IrBuilder::create(2.0)); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = makeConcreteTensor({y, z}); + auto tv4 = mul(tv2, tv3); + auto tv5 = broadcast(tv4, {true, false, false}); + auto tv6 = makeConcreteTensor({x, y, z}); + auto tv7 = add(tv5, tv6); + + // tv0[ i1 ] = input + // tv1[ i1 ] = tv0/2.0 + // tv2[ i1, b2] = bcast(tv1) + // tv3[ i1, i2] = input + // tv4[ i1, i2] = tv2 * tv3 + // tv5[b0, i1, i2] = bcast(tv4) + // tv6[i0, i1, i2] = input + // tv7[i0, i1, i2] = tv5 + tv6 + + // tv4 = bcast(tv1) * tv3 + // tv7 = bcast(tv4) + tv6 + + fusion.addInput(tv0); + fusion.addInput(tv3); + fusion.addInput(tv6); + + fusion.addOutput(tv7); + + tv7->merge(0); + tv7->merge(0); + tv0->computeAt(tv7, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y}, options); + at::Tensor t3 = at::randn({y, z}, options); + at::Tensor t6 = at::randn({x, y, z}, options); + + auto t4 = t0.div(2.0).unsqueeze(-1).expand({y, z}) * t3; + auto aten_output = t4.unsqueeze(0).expand({x, y, z}) + t6; + + std::vector aten_inputs = {t0, t3, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionComplexBCast2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 2, y = 3, z = 4; + + auto tv0 = makeConcreteTensor({y, z}); + auto tv1 = div(tv0, IrBuilder::create(2.0)); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = makeConcreteTensor({x, y}); + auto tv5 = add(tv3, tv4); + + // tv0[ i1, i2] = input + // tv1[ i1, i2] = tv0/2.0 + // tv2[ i1 ] = sum(tv1, 1) + // tv3[b0, i1 ] = bcast(tv2) + // tv4[i0, i1 ] = input + // tv5[i0, i1 ] = tv3 + tv4 + + // tv2 = sum(tv0/2.0, 1) + // tv5 = bcast(tv2) + tv4 + + fusion.addInput(tv0); + fusion.addInput(tv4); + + fusion.addOutput(tv5); + + tv5->merge(0); + tv0->computeAt(tv5, -1); + tv1->computeAt(tv2, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y, z}, options); + at::Tensor t4 = at::randn({x, y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t4}); + auto cg_outputs = fe.runFusion({t0, t4}); + + auto t1 = t0.div(2.0); + auto t2 = t1.to(at::kDouble).sum(1); + auto t3 = t2.unsqueeze(0).expand({x, y}); + auto aten_output = t3.add(t4); + + testValidate( + &fusion, {cg_outputs}, {t0, t4}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1.0)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); + + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->merge(0); + tv4->merge(0); + + tv4->split(0, 128); + tv4->split(0, 4); + + tv2->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(2)->parallelize(ParallelType::TIDx); + + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(2)->parallelize(ParallelType::TIDx); + + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + FusionExecutor fe; + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + auto t3 = t0.add(1.0); + auto aten_output = t3.add(t1); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1.0)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); + + fusion.addOutput(tv4); + + tv4->merge(-2); + tv4->merge(-2); + tv4->merge(-2); + + tv4->split(0, 128); + tv4->split(0, 4); + + tv2->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(2)->parallelize(ParallelType::TIDx); + + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(2)->parallelize(ParallelType::TIDx); + + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + FusionExecutor fe; + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + auto t3 = t0.add(1.0); + auto aten_output = t3.add(t1); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 3, x = 4, y = 7, z = 8; + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1.0)); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + auto t2 = t0.add(1.0); + auto aten_output = t2.add(t1); + + std::vector aten_inputs = {t0, t1}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({4, 8}); + fusion.addInput(tv0); + TensorView* tv1 = makeConcreteTensor({4, 4, 8}); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, IrBuilder::create(1)); + TensorView* tv3 = broadcast(tv2, {true, false, false}); + TensorView* tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 8}, options); + at::Tensor t1 = at::randn({4, 4, 8}, options); + + auto t2 = t0.add(1.0); + auto aten_output = t2.add(t1); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(3); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, IrBuilder::create(1)); + TensorView* tv3 = broadcast(tv2, {true, false, true}); + TensorView* tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv3->merge(0)->merge(0)->split(0, 2)->split(0, 3); + tv4->merge(0)->merge(0)->split(0, 2)->split(0, 3); + + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({7}, options); + at::Tensor t1 = at::randn({5, 7, 11}, options); + + auto t2 = t0.add(1.0); + auto aten_output = t2.unsqueeze(-1).add(t1); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector tensor0_shape{7, 4, 7}; + std::vector tensor1_shape{4, 7}; + + TensorView* tv0 = makeSymbolicTensor(tensor0_shape.size()); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(tensor1_shape.size()); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + TensorView* tv3 = sum(tv2, {0, 1}); + fusion.addOutput(tv3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input0 = at::randn(tensor0_shape, options); + at::Tensor input1 = at::randn(tensor1_shape, options); + + std::vector reduction_axes{0, 1}; + auto reduction_params = getReductionHeuristics(&fusion, {input0, input1}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input0, input1}, reduction_params->lparams); + auto cg_outputs = fe.runFusion({input0, input1}, reduction_params->lparams); + + auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes); + + testValidate( + &fusion, + cg_outputs, + {input0, input1}, + {aten_output}, + __LINE__, + __FILE__, + "", + reduction_params->lparams); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing7_CUDA) { + // Might be able to use this one without 6 as the heuristics in 6 may change + // and this test is to cover the same issue. + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {false, true}); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = add(tv1, tv2); + auto tv4 = sum(tv3, {0, 1}); + fusion.addOutput(tv4); + + tv4->merge(0, 1); + tv4->split(0, 128); + tv4->split(0, 4); + + auto tv5 = tv4->rFactor({0, 1}); + + tv5->computeAt(tv4, -1); + tv0->computeAt(tv5, -1); + + tv4->axis(0)->parallelize(ParallelType::TIDx); + + const int numel_x = 100; + const int numel_y = 200; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_t0 = at::randn({numel_x}, options); + auto at_t1 = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); + auto cg_outputs = fe.runFusion({at_t0, at_t1}); + + auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) + .to(at::kDouble) + .sum(); + + testValidate( + &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing8_CUDA) { + // Same as 7 but with outer splits instead of inner + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {false, true}); + + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = add(tv1, tv2); + auto tv4 = sum(tv3, {0, 1}); + fusion.addOutput(tv4); + + tv4->merge(0, 1); + tv4->split(0, 128, false); + tv4->split(0, 4, false); + + auto tv5 = tv4->rFactor({0, 1}); + + tv5->computeAt(tv4, -1); + tv0->computeAt(tv5, -1); + + tv4->axis(0)->parallelize(ParallelType::TIDx); + + const int numel_x = 100; + const int numel_y = 200; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_t0 = at::randn({numel_x}, options); + auto at_t1 = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t1}); + auto cg_outputs = fe.runFusion({at_t0, at_t1}); + + auto aten_output = (at_t0.unsqueeze(-1).expand({numel_x, numel_y}) + at_t1) + .to(at::kDouble) + .sum(); + + testValidate( + &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing9_CUDA) { + // Same as 7 but with outer splits instead of inner + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {false, true}); + + auto tv2 = mul(tv1, IrBuilder::create(2)); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(3); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + fusion.addOutput(tv4); + + const int numel_x = 200; + const int numel_y = 300; + const int numel_z = 400; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_t0 = at::randn({numel_y}, options); + auto at_t3 = at::randn({numel_x, numel_y, numel_z}, options); + std::vector aten_inputs = {at_t0, at_t3}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + auto at_t1 = at_t0.unsqueeze(-1); + auto at_t2 = at_t1.mul(2.0); + + auto at_t4 = at_t3.add(at_t2); + + testValidate( + &fusion, cg_outputs, aten_inputs, {at_t2, at_t4}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing10_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + auto tv0_cache = tv0->cacheAfter(); + auto tv1_cache = tv1->cacheAfter(); + + std::vector tvs = {tv0_cache, tv1_cache, tv2, tv3}; + + for (auto tv : tvs) { + tv->split(1, 2, false); + tv->split(1, 1); + tv->split(-1, 4); + // [I0, 2, 1, I1/2/4, 4] + tv->reorder({{1, 2}, {2, 3}, {3, 1}}); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::TIDx); + } + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv0_cache->axis(-1)->parallelize(ParallelType::Vectorize); + tv1_cache->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({64, 128}, options); + at::Tensor input2 = at::rand_like(input1); + at::Tensor output = at::empty_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + fe.runFusion({input1, input2}, {output}); + + at::Tensor tv2_ref = input2 + 2.0; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionAdvancedIndexing11_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto tv0 = makeSymbolicTensor(4); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = broadcast(tv2, {true, false, true, true}); + auto tv4 = add(tv3, tv0); + + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->merge(1); + + tv4->split(1, 32); + tv4->split(0, 1); + + tv4->reorder({{2, 1}}); + + tv2->computeAt(tv4, 3); + + tv2->setMemoryType(MemoryType::Global); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::BIDy); + tv4->axis(2)->parallelize(ParallelType::Unswitch); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + FusionExecutor fe; + + at::Tensor t0 = at::randn({w, x, y, z}, options); + at::Tensor t1 = at::randn({x}, options); + + auto t3 = t1.add(1.0).unsqueeze(-1).unsqueeze(-1); + auto aten_output = t3.add(t0); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +// Intended to stress the lowering of our code generator +TEST_F(NVFuserTest, FusionAdvancedLowering1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({9, 5}); + fusion.addInput(tv0); + + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); + TensorView* tv4 = sum(tv3, {1}); + + fusion.addOutput(tv2); + fusion.addOutput(tv4); + + tv4->split(1, 4); + auto tv5 = tv4->rFactor({2}); + + tv1->computeAt(tv5, 2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(1); + at::Tensor aten_input = at::randn({9, 5}, options); + + auto t1 = aten_input.add(1.0); + auto t2 = t1.add(2.0); + auto t3 = t1.add(3.0); + auto t4 = t3.sum(1); + + std::vector aten_outputs = {t2, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedLowering2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Progressively broadcast tensors + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = makeSymbolicTensor(3); + fusion.addInput(tv2); + + TensorView* tv3 = add(tv0, IrBuilder::create(1)); + TensorView* tv4 = broadcast(tv3, {false, true}); + TensorView* tv5 = add(tv4, tv1); + TensorView* tv6 = add(tv5, tv2); + + fusion.addOutput(tv6); + + // Split inner dimension + tv6->split(1, 4); + // Merge middle dims with outer dimensions + tv6->merge(2); + tv6->merge(0); + + // tv6[I0*I1o, I1i*I2] + + // Compute everything inline + tv0->computeAt(tv6, -1); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int x = 13, y = 9, z = 5; + at::Tensor t0 = at::randn({y}, options); + at::Tensor t1 = at::randn({y, z}, options); + at::Tensor t2 = at::randn({x, y, z}, options); + + auto t3 = t0.add(1.0); + auto t4 = t3.unsqueeze(-1); + auto t5 = t4.add(t1); + auto t6 = t5.add(t2); + + std::vector aten_inputs = {t0, t1, t2}; + std::vector aten_outputs = {t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +// TODO: Complete test +TEST_F(NVFuserTest, FusionAdvancedLowering3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1, -1}); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [b0, i1] + auto tv2 = add(tv0, IrBuilder::create(2.0)); + + // [i0, i1] + auto tv3 = add(tv1, IrBuilder::create(3.0)); + + // [b0, i1] + auto tv4 = add(tv2, IrBuilder::create(4.0)); + + // [io, i1] + auto tv5 = add(tv2, tv3); + + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + tv0->computeAt(tv4, -1); + + tv3->setMemoryType(MemoryType::Global); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int x = 13, y = 9; + at::Tensor t0 = at::randn({1, y}, options); + at::Tensor t1 = at::randn({x, y}, options); + + auto t4 = t0 + 2 + 4; + auto t5 = t0 + 2 + t1 + 3; + + std::vector aten_inputs = {t0, t1}; + std::vector aten_outputs = {t4, t5}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +// This excercises indexing with broadcast root axes. Non-broadcast +// axes need to be preferred when propagating index exprs to root +// axes. See, e.g., Index::getConsumerIndex_impl. +TEST_F(NVFuserTest, FusionAdvancedLowering4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {false, true}); + auto tv2 = broadcast(tv1, {false, false, true}); + auto tv3 = makeSymbolicTensor(3); + fusion.addInput(tv3); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->merge(1)->merge(0); + tv4->split(0, 8); + tv0->computeAt(tv4, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 10; + const int by = 20; + const int bz = 30; + at::Tensor t0 = at::randn({bx}, options); + at::Tensor t3 = at::randn({bx, by, bz}, options); + std::vector aten_inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = + t0.unsqueeze(-1).expand({bx, by}).unsqueeze(-1).expand({bx, by, bz}) + t3; + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedLowering5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({5, 4, 3}); + fusion.addInput(tv0); + + TensorView* tv1 = makeConcreteTensor({5, 3}); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv1, {false, true, false}); + + auto tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv2->merge(0); + tv1->computeAt(tv2, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(1); + at::Tensor t0 = at::randn({5, 4, 3}, options); + at::Tensor t1 = at::randn({5, 3}, options); + auto t2 = t1.unsqueeze(1); + auto t3 = t0 + t2; + + std::vector aten_inputs = {t0, t1}; + std::vector aten_outputs = {t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedLowering6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({5, 4, 3}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({4}); + fusion.addInput(tv1); + auto tv2 = unaryOp(UnaryOpType::Set, tv0); + auto tv3 = unaryOp(UnaryOpType::Set, tv1); + + auto tv4 = sum(tv2, {0, 2}); + auto tv5 = add(tv4, tv3); + fusion.addOutput(tv5); + + auto tv6 = broadcast(tv3, {true, false, true}); + auto tv7 = add(tv2, tv6); + fusion.addOutput(tv7); + + tv2->computeAt(tv4, -1, ComputeAtMode::BestEffort); + tv3->computeAt(tv7, -1, ComputeAtMode::BestEffort); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(1); + at::Tensor t0 = at::randn({5, 4, 3}, options); + at::Tensor t1 = at::randn({4}, options); + + auto t2 = t0; + auto t3 = t1; + + std::vector reduction_axes{0, 2}; + auto t4 = t2.sum(reduction_axes); + auto t5 = add(t4, t3); + auto t6 = t3.unsqueeze(0).unsqueeze(-1); + auto t7 = t2.add(t6); + + std::vector aten_inputs = {t0, t1}; + std::vector aten_outputs = {t5, t7}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +// Test a simple Gemm but also play around with fusion executor features +TEST_F(NVFuserTest, FusionSimpleGemm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); // M, K + TensorView* tv1 = makeSymbolicTensor(2); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = broadcast(tv0, {false, false, true}); + // tv2[I0, I1, B] = tv0[I0, I1] + + TensorView* tv3 = broadcast(tv1, {true, false, false}); + // tv3[B, I1, I2] = tv1[I1, I2] + + // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2] + TensorView* tv4 = mul(tv2, tv3); + // tv5[I0, R1, I2] = tv4[I0, I1, I2] + TensorView* tv5 = sum(tv4, {1}); + fusion.addOutput(tv5); + + tv5->split(1, 32); + // tv5[I0, R1o, R1i{32}, I2] + + auto tv6 = tv5->rFactor({1}); + // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2] + // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2] + + tv5->split(0, 4); + tv5->split(-1, 4); + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + + tv0->computeAt(tv5, -1); + tv1->computeAt(tv5, -1); + + // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] + // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] + //--> (line symbolizes compute at location) + // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] + // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv0->computeAt(tv6, -1); + tv1->computeAt(tv6, -1); + // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] + // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::TIDz); + + tv5->axis(-2)->parallelize(ParallelType::BIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + + tv5->axis(2)->parallelize(ParallelType::TIDx); + tv6->axis(2)->parallelize(ParallelType::TIDx); + + constexpr int M = 65, K = 33, N = 17; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + // Lets specify a few bounds in launch params to make sure it works + fe.runFusion({t0, t1}, LaunchParams(1, -1, -1, 32, 4, 4)); + + // Make sure bad launch params throws + // TODO: Re-enable once we have parallelization validation in. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6))); + + // Don't specify any launch params + auto cg_outputs = fe.runFusion({t0, t1}); + + auto aten_output = t0.to(at::kDouble).matmul(t1.to(at::kDouble)); + + testValidate( + &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); +} + +// Softmax with a 1D tensor. Parallelized only with a single thread block. +TEST_F(NVFuserTest, FusionSoftmax1D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int tidx = 128; + const int dimx = 1000; + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(1); + fusion.addInput(input_tv0); + + TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); + TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); + TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); + + TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); + + fusion.addOutput(output_tv4); + + bcast_sum_tv3->split(0, tidx); + + sum_exp_tv2->split(-1, tidx); + TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); + + output_tv4->split(-1, tidx); + + exp_tv1->computeAt(sum_exp_rf_tv5, -1); + exp_tv1_copy->computeAt(output_tv4, -1); + + TensorView* tensors_to_parallelize[] = { + sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dimx}, options); + at::Tensor cg_output = at::empty({dimx}, options); + at::Tensor t3_output = at::empty_like(cg_output, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + fe.runFusion({t0}, {cg_output}); + + auto aten_output = at::_softmax(t0.to(at::kDouble), -1, false); + + testValidate(&fusion, {cg_output}, {t0}, {aten_output}, __LINE__, __FILE__); +} + +// Softmax with a 1D tensor with input normalization. +TEST_F(NVFuserTest, FusionSoftmax1DNormalized_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int tidx = 128; + const int dimx = 1000; + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(1); + fusion.addInput(input_tv0); + + // Normalize with the max value before computing exp. + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); + TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {true}); + TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); + TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); + TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); + + TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); + + fusion.addOutput(output_tv7); + bcast_max_tv2->split(0, tidx); + bcast_sum_tv6->split(0, tidx); + + max_val_tv1->split(-1, tidx); + TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); + + sum_exp_tv5->split(-1, tidx); + TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); + + output_tv7->split(-1, tidx); + + sub_tv3->computeAt(sum_exp_rf_tv9, -1); + sub_tv3_copy->computeAt(output_tv7, -1); + + TensorView* tensors_to_parallelize[] = { + max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({dimx}, options); + at::Tensor t3_output = at::empty({dimx}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Softmax with a 3D tensor, where the inner-most 3rd dimension is +// normalized. Pallelized with multiple thread blocks. +TEST_F(NVFuserTest, FusionSoftmax3D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int tidx = 32; + const int dimx = 32; + const int dimy = 16; + const int dimz = 130; + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(3); + fusion.addInput(input_tv0); + + TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_tv0); + TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); + TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_tv0); + + TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); + + fusion.addOutput(output_tv4); + + bcast_sum_tv3->split(-1, tidx); + + sum_exp_tv2->split(-1, tidx); + TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); + + output_tv4->split(-1, tidx); + + exp_tv1->computeAt(sum_exp_rf_tv5, -1); + exp_tv1_copy->computeAt(output_tv4, -1); + + TensorView* tensors_to_parallelize[] = { + sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({dimx, dimy, dimz}, options); + + at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Softmax with a 3D tensor with input normalization. +TEST_F(NVFuserTest, FusionSoftmax3DNormalized_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int tidx = 32; + const int dimx = 32; + const int dimy = 16; + const int dimz = 130; + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(3); + fusion.addInput(input_tv0); + + // Normalize with the max value before computing exp. + TensorView* max_val_tv1 = reductionOp( + BinaryOpType::Max, {-1}, IrBuilder::create(0), input_tv0); + TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true}); + TensorView* sub_tv3 = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4 = unaryOp(UnaryOpType::Exp, sub_tv3); + TensorView* sum_exp_tv5 = sum(exp_tv4, {-1}); + TensorView* bcast_sum_tv6 = broadcast(sum_exp_tv5, {false, false, true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* sub_tv3_copy = sub(input_tv0, bcast_max_tv2); + TensorView* exp_tv4_copy = unaryOp(UnaryOpType::Exp, sub_tv3_copy); + + TensorView* output_tv7 = div(exp_tv4_copy, bcast_sum_tv6); + + fusion.addOutput(output_tv7); + + bcast_max_tv2->split(-1, tidx); + bcast_sum_tv6->split(-1, tidx); + + max_val_tv1->split(-1, tidx); + TensorView* max_val_rf_tv8 = max_val_tv1->rFactor({-2}); + + sum_exp_tv5->split(-1, tidx); + TensorView* sum_exp_rf_tv9 = sum_exp_tv5->rFactor({-2}); + + output_tv7->split(-1, tidx); + + sub_tv3->computeAt(sum_exp_rf_tv9, -1); + sub_tv3_copy->computeAt(output_tv7, -1); + + TensorView* tensors_to_parallelize[] = { + max_val_tv1, + bcast_max_tv2, + sum_exp_tv5, + bcast_sum_tv6, + output_tv7, + max_val_rf_tv8, + sum_exp_rf_tv9}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({dimx, dimy, dimz}, options); + at::Tensor t3_output = at::empty({dimx, dimy, dimz}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = at::_softmax(input.to(at::kDouble), -1, false); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSoftmaxComputeAt_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = add(tv0, IrBuilder::create(1.0)); + + auto tv4 = mul(tv2, tv3); + + auto tv5 = sum(tv4, {1}); + auto tv6 = broadcast(tv5, {false, true}); + + auto tv7 = sub(tv6, tv4); + fusion.addOutput(tv7); + + tv1->computeAt(tv7, 1); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(tv1->computeAt(tv7, -1)); +} + +// Similar to FusionReduction but uses grid reduction +TEST_F(NVFuserTest, FusionGridReduction1_CUDA) { + const int gdimx = 32; + const int bdimx = 128; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, bdimx); + // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] + tv1->split(1, gdimx); + // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] + // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv1, 1); + + // Re do it all at once, because why not. + tv0->computeAt(tv1, 1); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(1)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::BIDx); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 10000; + int numel_y = 65000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Same test as the above but uses BIDy and TIDx for reduction +TEST_F(NVFuserTest, FusionGridReduction2_CUDA) { + const int gdimy = 32; + const int bdimx = 128; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, bdimx); + // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] + tv1->split(1, gdimy); + // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] + // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv1, 1); + + // Re do it all at once, because why not. + tv0->computeAt(tv1, 1); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(2)->parallelize(ParallelType::BIDy); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 10000; + int numel_y = 65000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({1}); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Same test but uses BIDy and BIDz for reduction. No TID used. +TEST_F(NVFuserTest, FusionGridReduction3dim1_CUDA) { + // Grid reductions when there aren't any threads are serial reductions + // keep these numbers low so our error isn't too high compared to normal cuda + // reductions + const int gdimz = 15; + const int gdimy = 9; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, gdimy); + // tv1[I0, R1o, R1i{128}] = tv0[I0, I1] + tv1->split(1, gdimz); + // tv1[I0, R1oo, R1oi{32}, R1i{128}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] = tv0[I0, I1] + // tv1[I0, R1oi{32}, R1i{128}] = tv2[I0, R1oo, Ir1oi{32}, Ir1i{128}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv1, 1); + + // Re do it all at once, because why not. + tv0->computeAt(tv1, 1); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDz); + tv2->axis(2)->parallelize(ParallelType::BIDz); + tv1->axis(-1)->parallelize(ParallelType::BIDy); + tv2->axis(-1)->parallelize(ParallelType::BIDy); + + int numel_x = 100; + int numel_y = 6500; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Same as testGPU_FusionGridReduction3dim1 but reduces dimension 0 +TEST_F(NVFuserTest, FusionGridReduction3dim0_CUDA) { + // Grid reductions when there aren't any threads are serial reductions + // keep these numbers low so our error isn't too high compared to normal cuda + // reductions + const int gdimz = 15; + const int gdimy = 9; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[R0, I1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {0}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(0, gdimy); + // tv1[R0o, R0i{128}, I1] = tv0[I0, I1] + tv1->split(0, gdimz); + // tv1[R0oo, R0oi{32}, R0i{128}, I1] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({0}); + // tv2[R0oo, I0oi{32}, I0i{128}, I1] = tv0[I0, I1] + // tv1[ R0oi{32}, R0i{128}, I1] = tv2[R0oo, I0oi{32}, I0i{128}, I1] + + // Note that computeAt isn't going to make anything better as there + // is no dynamically sized dimension. + + // Map parallelism as [Serial, BIDz, BIDy, BIDx] + tv1->axis(-1)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::BIDx); + tv1->axis(-2)->parallelize(ParallelType::BIDy); + tv2->axis(-2)->parallelize(ParallelType::BIDy); + tv1->axis(-3)->parallelize(ParallelType::BIDz); + tv2->axis(-3)->parallelize(ParallelType::BIDz); + + int numel_x = 6500; + int numel_y = 100; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({0}); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +// This is similar to the FusionReduction, but swaps BIDx and TIDx +TEST_F(NVFuserTest, FusionGridReduction4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int bdimx = 128; + const int gdimx = 1024; + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, gdimx); + // tv1[I0, R1o, R1i{1024}] = tv0[I0, I1] + tv1->split(1, 4); + // tv1[I0, R1oo, R1oi{4}, R1i{128}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] + // tv1[I0, R1oi{4}, R1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] + + TensorView* tv3 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] = tv0[I0, I1] + // tv3[I0, R1oi{4}, Ir1i{1024}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{1024}] + // tv1[I0, R1i{1024}] = tv3[I0, R1oi{4}, Ir1i{1024}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv3, 1); + tv3->computeAt(tv1, 1); + + // Re do it all at once, because why not. + tv0->computeAt(tv1, 1); + + tv2->axis(2)->parallelize(ParallelType::Unroll); + tv1->axis(0)->parallelize(ParallelType::TIDx); + + tv1->axis(-1)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::BIDx); + + int numel_x = bdimx; + int numel_y = 65000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Grid reduction with 2D thread blocks but only TIDx and BIDx are +// mapped to a reduction dim +TEST_F(NVFuserTest, FusionGridReduction5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int bdimx = 64; + const int bdimy = 16; + const int gdimx = 4; + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + tv1->split(1, bdimx); + // tv1[I0, R1o, R1i{64}] = tv0[I0, I1] + tv1->split(1, gdimx); + // tv1[I0, R1oo, R1oi{4}, R1i{64}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] = tv0[I0, I1] + // tv1[I0, R1oi{4}, R1i{64}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{64}] + + tv0->computeAt(tv1, 1); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(-2)->parallelize(ParallelType::BIDx); + tv2->axis(-2)->parallelize(ParallelType::BIDx); + + tv1->axis(0)->parallelize(ParallelType::TIDy); + + int numel_x = bdimy; + int numel_y = 6500; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Similar to FusionGridReduction1 but with 3D tensors +TEST_F(NVFuserTest, FusionGridReduction6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + + // tv1[I0, R1, R2] = tv0[I0, I1, I2] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1, 2}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion).size(), + "Could not detect reduction in fusion."); + + // Splitting for TID + tv1->split(2, 128); + // tv1[I0, R1, R2o, R2i{128}] = tv0[I0, I1, I2] + + // Splitting for BID + tv1->split(1, 128); + + // tv1[I0, R1o, R1i{128}, R2o, R2i{128}] = tv0[I0, I1, I2] + + TensorView* tv2 = tv1->rFactor({3}); + // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] + // tv1[I0, R1o, R1i{128}, R2i{128}] + + TensorView* tv3 = tv1->rFactor({1}); + // tv2[I0, I1o, I1i{128}, R2o, I2i{128}] + // tv3[I0, R1o, I1i{128}, I2i{128}] + // tv1[I0, R1i{128}, R2i{128}] + + tv3->computeAt(tv1, 1); + tv2->computeAt(tv3, 3); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(-2)->parallelize(ParallelType::BIDx); + tv2->axis(-3)->parallelize(ParallelType::BIDx); + tv3->axis(-2)->parallelize(ParallelType::BIDx); + + int numel_x = 6500; + int numel_y = 200; + int numel_z = numel_y; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y, numel_z}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({1, 2}); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +// See issue #1049 +TEST_F(NVFuserTest, FusionGridReduction7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + tv1->split(0, 1000); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDy); + + const int numel_x = 1; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = input.sum({0}); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridReduction8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + const int numel_x = 2; + const int numel_y = 4; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = input.sum({0}); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridReduction9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv1->split(1, 2); + + tv1->axis(1)->parallelize(ParallelType::BIDx); + tv1->axis(2)->parallelize(ParallelType::BIDy); + + tv1->computeAt(tv3, 1); + + const int numel_x = 4; + const int numel_y = 10; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + at::Tensor t2 = at::randn({numel_x}, options); + + std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_output = fe.runFusion(aten_inputs); + + auto aten_output = t0.sum({1}).add(t2); + + testValidate(&fusion, cg_output, {t0, t2}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridReduction10_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {-1}); + auto tv2 = sum(tv1, {-1}); + auto tv3 = sum(tv2, {-1}); + + fusion.addOutput(tv3); + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::BIDx); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(3)->parallelize(ParallelType::TIDz); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv3, 1); + + const int numel_w = 2; + const int numel_x = 3; + const int numel_y = 4; + const int numel_z = 5; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_w, numel_x, numel_y, numel_z}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_output = fe.runFusion({t0}); + + auto aten_output = t0.sum({1, 2, 3}); + + testValidate(&fusion, cg_output, {t0}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionNonRedAxisBind_CUDA) { + int bid_x = 3; + int tid_x = 2; + int red_dim = 0; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + tv1->split(-1, tid_x); + tv1->axis(-2)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({16, bid_x * tid_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = input.to(at::kDouble).sum({red_dim}); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSplitBCast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(3); + TensorView* input_tv1 = makeSymbolicTensor(3); + fusion.addInput(input_tv0); + fusion.addInput(input_tv1); + + TensorView* sum_tv2 = reductionOp( + BinaryOpType::Add, {2}, IrBuilder::create(0), input_tv0); + TensorView* bcast_tv3 = broadcast(sum_tv2, {false, false, true}); + TensorView* output_tv4 = div(input_tv1, bcast_tv3); + + sum_tv2->split(-1, 32); + TensorView* sum_rf_tv5 = sum_tv2->rFactor({-2}); + + bcast_tv3->split(-1, 32); + output_tv4->split(-1, 32); + + sum_rf_tv5->axis(0)->parallelize(ParallelType::BIDx); + sum_tv2->axis(0)->parallelize(ParallelType::BIDx); + bcast_tv3->axis(0)->parallelize(ParallelType::BIDx); + output_tv4->axis(0)->parallelize(ParallelType::BIDx); + + sum_rf_tv5->axis(1)->parallelize(ParallelType::BIDy); + sum_tv2->axis(1)->parallelize(ParallelType::BIDy); + bcast_tv3->axis(1)->parallelize(ParallelType::BIDy); + output_tv4->axis(1)->parallelize(ParallelType::BIDy); + + sum_rf_tv5->axis(-1)->parallelize(ParallelType::TIDx); + sum_tv2->axis(-1)->parallelize(ParallelType::TIDx); + bcast_tv3->axis(-1)->parallelize(ParallelType::TIDx); + output_tv4->axis(-1)->parallelize(ParallelType::TIDx); + + fusion.addOutput(output_tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({32, 32, 128}, options); + at::Tensor t1 = at::randn({32, 32, 128}, options); + at::Tensor cg_output = at::empty({32, 32, 128}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + fe.runFusion({t0, t1}, {cg_output}); +} + +TEST_F(NVFuserTest, FusionBCastInnerDim_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // reduce then broadcast + auto tv1 = sum(tv0, {0}); + auto tv2 = broadcast(tv1, {false, true}); + + TORCH_CHECK(!tv2->axis(0)->isReduction() && tv2->axis(1)->isBroadcast()); +} + +TEST_F(NVFuserTest, FusionBCastReduce_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + + auto tv1 = broadcast(tv0, {true, false, false}); + auto tv2 = sum(tv1, {1}); + TORCH_CHECK( + tv2->axis(0)->isBroadcast() && tv2->axis(1)->isReduction() && + !tv2->axis(2)->isBroadcast() && !tv2->axis(2)->isReduction()); +} + +// Multiple consumer reduction with computeAt +// https://github.com/csarofeen/pytorch/issues/110 +TEST_F(NVFuserTest, FusionReductionMultiConsumer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = unaryOp(UnaryOpType::Exp, tv0); + auto tv2 = + reductionOp(BinaryOpType::Max, {-1}, IrBuilder::create(0), tv1); + auto tv3 = + reductionOp(BinaryOpType::Min, {-1}, IrBuilder::create(0), tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + tv1->computeAt(tv2, -1, ComputeAtMode::BestEffort); + + TORCH_CHECK(tv1->getComputeAtPosition() == 2); +} + +TEST_F(NVFuserTest, FusionComputeAtExprOrder1_CUDA) { + for (const auto i : c10::irange(2)) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); + TensorView* tv3 = add(tv1, tv2); + // Set outputs tv2 or tv1 and then tv3 + if (i == 0) { + fusion.addOutput(tv2); + } else { + fusion.addOutput(tv1); + } + fusion.addOutput(tv3); + + if (i == 0) { + tv1->computeAt(tv3, -1); + } else { + tv2->computeAt(tv3, -1); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + std::vector aten_outputs = { + aten_input + 1, (aten_input + 1) * 2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionComputeAtExprOrder2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); + TensorView* tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv3->split(-1, 32); + + tv1->computeAt(tv3, -1); + tv2->computeAt(tv3, -2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100, 100}, options); + auto aten_output = (aten_input + 1) * 2; + + at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionComputeAtExprOrder3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int64_t dimx = 13; + const int64_t dimy = 15; + + TensorView* tv0 = makeConcreteTensor({dimx, dimy}); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv2, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + TensorView* tv5 = mul(tv2, tv4); + fusion.addOutput(tv5); + + tv1->computeAt(tv2, 2); + tv3->computeAt(tv4, 1); + tv4->computeAt(tv5, 2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({dimx, dimy}, options); + auto t1 = aten_input.add(1.); + auto t2 = t1.add(2.); + auto t3 = t2.add(3.); + auto t4 = t3.add(4.); + auto aten_output = t2.mul(t4); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionZeroDimComputeAt_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + TORCH_CHECK(tv2->nDims() == 0); + tv1->computeAt(tv2, 0); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + auto aten_output = aten_input.to(at::kDouble).sum() + 1; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionZeroDimBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(0); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {true, true}); + TORCH_CHECK(tv1->nDims() == 2); + + TensorView* tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = add(tv1, tv2); + auto tv4 = sum(tv3, {0, 1}); + fusion.addOutput(tv4); + + tv3->computeAt(tv4, -1); + tv3->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({}, options); + at::Tensor t1 = at::randn({10, 10}, options); + + auto aten_output = (t0.unsqueeze(-1).unsqueeze(-1).expand({10, 10}) + t1) + .to(at::kDouble) + .sum(); + + std::vector aten_inputs = {t0, t1}; + at::Tensor cg_output = at::empty({}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, {cg_output}); + + testValidate( + &fusion, {cg_output}, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionZeroDimReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int bdimx = 32; + const int gdimx = 32; + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + tv1->split(0, bdimx); + tv1->split(0, gdimx); + auto tv2 = tv1->rFactor({0}); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-2)->parallelize(ParallelType::BIDx); + tv2->axis(-2)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({1000}, options); + auto aten_output = aten_input.to(at::kDouble).sum(); + + at::Tensor cg_output = at::empty({}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBCastAfterReduce_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + const int tidx = 128; + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + + tv1->split(1, tidx); + auto tv3 = tv1->rFactor({-2}); + + TensorView* tv4 = makeSymbolicTensor(2); + fusion.addInput(tv4); + + auto tv5 = add(tv2, tv4); + fusion.addOutput(tv5); + tv5->split(1, tidx); + + tv3->computeAt(tv5, 1); + + tv2->split(1, tidx); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + + int x = 63, y = 200; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t4 = at::randn({x, y}, options); + + auto t3 = t0.to(at::kDouble).sum({1}).unsqueeze(-1).expand({x, y}); + auto aten_output = t3.add(t4); + + std::vector aten_inputs = {t0, t4}; + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t4}); + auto cg_outputs = fe.runFusion({t0, t4}); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionOutputBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({2, 3}); + fusion.addInput(tv0); + + TensorView* tv1 = broadcast(tv0, {true, false, true, false, true}); + + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({2, 3}, options); + auto aten_output = aten_input.unsqueeze(2).unsqueeze(1).unsqueeze(0); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReductionKeepDimBasic_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({2, 3, 4, 5, 6}); + fusion.addInput(tv0); + + TensorView* tv1 = sum(tv0, {0, 2, -1}, /*keep_dim=*/true); + + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({2, 3, 4, 5, 6}, options); + auto aten_output = + aten_input.to(at::kDouble).sum({0, 2, -1}, /*keepdim=*/true); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReductionKeepDimScheduler_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + constexpr int red_dim = 1; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({bid_x, tid_x}); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, + {red_dim}, + IrBuilder::create(0), + tv0, + /*keep_dim=*/true); + + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({bid_x, tid_x}, options); + auto aten_output = + aten_input.to(at::kDouble).sum({red_dim}, /*keepdim=*/true); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionSumTo_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector tensor_shape{2, 3, 4, 5, 6}; + std::vector sum_to_shape{1, 5, 6}; + + std::vector tensor_shape_ref{2, 3, 4, 5, 6}; + std::vector sum_to_shape_ref{1, 5, 6}; + + std::vector sum_to_symb; + std::transform( + sum_to_shape.begin(), + sum_to_shape.end(), + std::back_inserter(sum_to_symb), + [](int s) -> Int* { return IrBuilder::create(s); }); + + TensorView* tv0 = makeConcreteTensor(tensor_shape); + fusion.addInput(tv0); + + TensorView* tv1 = sum_to(tv0, sum_to_symb); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn(tensor_shape_ref, options); + auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + TORCH_CHECK( + cg_outputs[0].dim() == static_cast(sum_to_shape.size()), + "sum_to not keeping the final dimension"); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSumToNoop_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector tensor_shape{4, 5, 6}; + std::vector sum_to_shape{4, 5, 6}; + + std::vector tensor_shape_ref{4, 5, 6}; + std::vector sum_to_shape_ref{4, 5, 6}; + + std::vector sum_to_symb; + std::transform( + sum_to_shape.begin(), + sum_to_shape.end(), + std::back_inserter(sum_to_symb), + [](int s) -> Int* { return IrBuilder::create(s); }); + + TensorView* tv0 = makeConcreteTensor(tensor_shape); + fusion.addInput(tv0); + + TensorView* tv1 = sum_to(tv0, sum_to_symb); + + // Dummy operator to avoid tv0 both input and output + TensorView* tv2 = add(tv1, IrBuilder::create(0)); + fusion.addOutput(tv2); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn(tensor_shape_ref, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + auto aten_output = at::sum_to(aten_input.to(at::kDouble), sum_to_shape_ref); + + TORCH_CHECK( + cg_outputs[0].dim() == static_cast(sum_to_shape.size()), + "sum_to not keeping the final dimension"); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReductionScheduler_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + constexpr int red_dim = 1; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({bid_x, tid_x}, options); + auto aten_output = aten_input.to(at::kDouble).sum({red_dim}); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +// This test checks if our system could correctly handles the case where both +// reduction and trivial reduction exist in the fusion. Trivial reduction +// deserve testing because trivial reduction is handled more like a broadcasting +// rather than a reduction. +TEST_F(NVFuserTest, FusionReductionWithTrivialReduction_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + + std::vector> shapes = { + {-1, -1, 1}, {-1, 1, -1}, {1, -1, -1}}; + + for (auto shape : shapes) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + + std::vector> reduction_dims = { + {0}, + {1}, + {2}, + {0, 1}, + {0, 2}, + {1, 2}, + {0, 1, 2}, + }; + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + for (auto rdims : reduction_dims) { + std::vector rdims_(rdims.begin(), rdims.end()); + auto tv = sum(tv0, rdims_); + fusion.addOutput(tv); + } + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto concrete_shape = shape; + std::deque concrete_values = {bid_x, tid_x}; + for (auto& s : concrete_shape) { + if (s == -1) { + s = concrete_values.front(); + concrete_values.pop_front(); + } + } + + at::Tensor aten_input = at::randn(concrete_shape, options); + std::vector aten_outputs; + for (auto rdims : reduction_dims) { + aten_outputs.push_back(aten_input.sum(rdims)); + } + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({aten_input}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + aten_outputs, + __LINE__, + __FILE__, + ""); + } +} + +// Simple reduction parallelized on a symbolic size. +TEST_F(NVFuserTest, FusionSymbolicReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + // tv1[I0, R1] = tv0[I0, I1] + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + // Interface should just be a direct split with a Parallel type. We can + // include the parallelize call if we do this. + tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({1}); + // tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] = tv0[I0, I1] + // tv1[I0, R1oi{4}, R1i{BIDx}] = tv2[I0, R1oo, Ir1oi{4}, Ir1i{BIDx}] + + // Incrementally, can print in between for debugging + tv0->computeAt(tv2, 1); + tv2->computeAt(tv1, 1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 65000; + int numel_y = 1025; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({numel_x, numel_y}, options); + auto aten_output = aten_input.to(at::kDouble).sum({1}); + + // How many threads to use for the block reduction + int runtime_threadIdx_dim = 128; + + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimNonFastest_CUDA) { + const std::vector red_dims = {0, 2}; + // Copy is because CodeGen requires int and Pytorch requires int64_t + // for a vector of reduction dimensions + const std::vector red_dims64 = {0, 2}; + const std::vector tensor_dims_in = {5, 10, 15, 20}; + const std::vector tensor_dims_out = {10, 20}; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(tensor_dims_in, options); + auto aten_output = aten_input.to(at::kDouble).sum(red_dims64); + at::Tensor cg_output = at::empty(tensor_dims_out, options); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + fe.runFusion({aten_input}, {cg_output}, lparams); + + testValidate( + &fusion, + {cg_output}, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReductionSchedulerMultiDimFastest_CUDA) { + const std::vector red_dims = {1, 3}; + // Copy is because CodeGen requires int and Pytorch requires int64_t + // for a vector of reduction dimensions + const std::vector red_dims64 = {1, 3}; + const std::vector tensor_dims_in = {5, 10, 15, 20}; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, red_dims, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(tensor_dims_in, options); + auto aten_output = aten_input.to(at::kDouble).sum(red_dims64); + + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReductionSchedulerNoODimShmoo_CUDA) { + std::vector dtypes = { + DataType::Double, DataType::Float, DataType::Half}; + // TODO: add test for complex. Currently complex fails with the following + // NVRTC compilation error message: + // error: no suitable user-defined conversion from + // "CudaCodeGen::std::complex" to "CudaCodeGen::std::complex" + // exists +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + if (at::cuda::getDeviceProperties(0)->major >= 8) { + dtypes.insert(dtypes.end(), DataType::BFloat16); + } +#endif + + std::vector red_dims; + + // Tried to cut down the number iterations with just + // doing every other power of 2. + for (int i = 1; i <= 1024 * 1024; i <<= 2) { + red_dims.push_back(i); + } + + for (auto dtype : dtypes) { + at::ScalarType aten_dtype = data_type_to_aten(dtype); + for (auto& rdim : red_dims) { + Fusion fusion; + FusionGuard fg(&fusion); + + bool is_fp16 = dtype == DataType::Half; + bool is_bf16 = dtype == DataType::BFloat16; + + TensorView* tv0 = makeSymbolicTensor(1, dtype); + fusion.addInput(tv0); + + TensorView* tv0_cast = tv0; + if (is_fp16 || is_bf16) { + tv0_cast = castOp(DataType::Float, tv0); + } + + TensorView* tv1 = sum(tv0_cast, {0}); + + TensorView* tv1_cast = tv1; + if (is_fp16) { + tv1_cast = castOp(DataType::Half, tv1); + } + if (is_bf16) { + tv1_cast = castOp(DataType::BFloat16, tv1); + } + + fusion.addOutput(tv1_cast); + + auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({rdim}, options); + auto aten_output = aten_input.to(at::kDouble).sum({0}); + + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); + } + } +} + +TEST_F(NVFuserTest, FusionReductionSchedulerDimShmoo_CUDA) { + std::vector dtypes = { + DataType::Double, DataType::Float, DataType::Half}; + // TODO: add complex support. Currently, complex fails with the following + // NVRTC compilation error: + // error: no instance of overloaded function "__shfl_xor_sync" matches the + // argument list +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + if (at::cuda::getDeviceProperties(0)->major >= 8) { + dtypes.insert(dtypes.end(), DataType::BFloat16); + } +#endif + + std::vector red_axis = {1, 0}; + std::vector output_dims = {160, 320}; + std::vector red_dims; + + // Tried to cut down the number iterations with just + // doing every other power of 2. + for (int i = 1; i <= 1024 * 1024; i <<= 2) { + red_dims.push_back(i); + } + + for (auto dtype : dtypes) { + at::ScalarType aten_dtype = data_type_to_aten(dtype); + for (auto& axis : red_axis) { + for (auto& odim : output_dims) { + for (auto& rdim : red_dims) { + Fusion fusion; + FusionGuard fg(&fusion); + + bool is_fp16 = dtype == DataType::Half; + bool is_bf16 = dtype == DataType::BFloat16; + + TensorView* tv0 = makeSymbolicTensor(2, dtype); + fusion.addInput(tv0); + + TensorView* tv0_cast = tv0; + if (is_fp16 || is_bf16) { + tv0_cast = castOp(DataType::Float, tv0); + } + + TensorView* tv1 = sum(tv0_cast, {axis}); + + TensorView* tv1_cast = tv1; + if (is_fp16) { + tv1_cast = castOp(DataType::Half, tv1); + } + if (is_bf16) { + tv1_cast = castOp(DataType::BFloat16, tv1); + } + fusion.addOutput(tv1_cast); + + auto options = + at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); + + at::Tensor aten_input = + (axis ? at::randn({odim, rdim}, options) + : at::randn({rdim, odim}, options)); + + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params != nullptr, "Reduction is not found!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + auto aten_output = aten_input.to(at::kDouble).sum({axis}); + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); + } + } + } + } +} + +TEST_F(NVFuserTest, FusionCacheBefore_CUDA) { + // TVM Cache Write + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); + fusion.addInput(tv0); + fusion.addOutput(tv2); + + // Before: TV2 = TV1 * 3 + // After: TV3 = TV1 * 3; + // TV2 = TV3; + TensorView* tv3 = tv2->cacheBefore(); + + constexpr int BSX = 32; + tv2->split(-1, BSX); + tv0->computeAt(tv2, -1); + + // Thread and Block binding + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 32, N = 750; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({M, N}, options); + at::Tensor aten_output = (aten_input + 1.0) * 3.0; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheAfter_CUDA) { + // TVM Cache Read + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); + fusion.addInput(tv0); + fusion.addOutput(tv2); + + // Before: TV1 = TV0 + 1 + // After: TV3 = TV0; + // TV1 = TV3 + 1 + TensorView* tv3 = tv0->cacheAfter(); + + constexpr int BSX = 32; + tv2->split(-1, BSX); + tv0->computeAt(tv2, -1); + + // Thread and Block binding + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 32, N = 457; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({M, N}, options); + at::Tensor aten_output = (aten_input + 1.0) * 3.0; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheFork_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = mul(tv1, IrBuilder::create(3.0)); + fusion.addInput(tv0); + fusion.addOutput(tv1); + fusion.addOutput(tv2); + // Before: TV1 = TV0 + 1 + // TV2 = TV1 * 1 + // Output: TV1, TV2 + + // After: TV1 = TV0 + 1 + // TV3 = TV1 + // TV2 = TV1 * 1 + // Output: TV3, TV2 + + // cacheFork !!does not!! automatically apply ComputeAt to the cache + auto tv3 = tv1->cacheFork(); + + constexpr int BSX = 32; + tv2->split(-1, BSX); + tv0->computeAt(tv2, -1); + + // Thread and Block binding + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 32, N = 457; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({M, N}, options); + at::Tensor aten_output1 = aten_input + 1.0; + at::Tensor aten_output2 = aten_output1 * 3.0; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output1, aten_output2}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheIndirect_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + TensorView* tv2 = makeSymbolicTensor(2); + TensorView* tv3 = makeSymbolicTensor(2); + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + fusion.addInput(tv3); + fusion.addOutput(tv6); + // t6 = ((t1 + (t2 - t3)) - t0) + + tv5->cacheAfter(); + tv5->cacheBefore(); + + // cacheAfter on inputs placed before schedule + constexpr int BSX = 32; + tv6->split(-1, BSX); + tv2->computeAt(tv6, -1); + + // Thread and Block binding + tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 32, N = 810; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t1 = at::randn({M, N}, options); + at::Tensor t2 = at::randn({M, N}, options); + at::Tensor t3 = at::randn({M, N}, options); + + std::vector aten_inputs = {t0, t1, t2, t3}; + at::Tensor aten_output = (t1 + (t2 - t3)) - t0; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheBcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(1); // (M, 1) + TensorView* tv1 = broadcast(tv0, {false, true}); + TensorView* tv2 = makeSymbolicTensor(1); // (1, N) + TensorView* tv3 = broadcast(tv2, {true, false}); + TensorView* tv4 = mul(tv1, tv3); + fusion.addInput(tv0); + fusion.addInput(tv2); + fusion.addOutput(tv4); + + // Case 1 + tv0->cacheAfter(); + + // Case 2 + tv1->cacheBefore(); + + // Case 3 + tv1->cacheAfter(); + + // Case 4 + TensorView* tv8 = tv4->cacheBefore(); + + constexpr int BSX = 128; + tv4->split(0, BSX); + tv4->split(-1, BSX); + tv4->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); + // M/BSX, N/BSY, BSX, BSY + tv0->computeAt(tv4, 2); + tv2->computeAt(tv4, 2); + // 0, 1 | 2, 3, 4 + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::BIDy); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Replay on TV3 + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv8->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 92, N = 500; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M}, options); + at::Tensor t1 = at::randn({N}, options); + std::vector aten_inputs = {t0, t1}; + at::Tensor aten_output = + t0.to(at::kDouble).unsqueeze(1).matmul(t1.to(at::kDouble).unsqueeze(0)); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheMultiConsumer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(1)); + TensorView* tv4 = add(tv3, IrBuilder::create(2)); + + fusion.addInput(tv0); + fusion.addOutput(tv2); + fusion.addOutput(tv4); + + auto tv5 = tv1->cacheBefore(); + auto tv6 = tv3->cacheBefore(); + tv5->setMemoryType(MemoryType::Shared); + tv6->setMemoryType(MemoryType::Shared); + + tv1->computeAt(tv2, -1); + tv3->computeAt(tv4, -1); + + // Fails because tensor must be recomputed twice + // auto tv7 = tv0->cacheAfter(); + + constexpr int N = 800; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({N}, options); + auto aten_output = (aten_input + 1) + 2; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output, aten_output}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionSmem_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, N) + TensorView* tv1 = makeSymbolicTensor(2); // (M, N) + TensorView* tv2 = mul(tv0, tv1); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv2); + + // Schedule + TensorView* tv3 = tv0->cacheAfter(); + TensorView* tv4 = tv1->cacheAfter(); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Shared); + + constexpr int BSY = 32; + constexpr int BSX = 128; + tv2->split(0, BSY); + tv2->split(2, BSX); + // M/BSX, BSX, N/BSX, BSX + tv2->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); + // M/BSX, N/BSX, BSX, BSX + + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Thread and Block binding + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Binding + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 128, N = 10240; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t1 = at::randn({M, N}, options); + at::Tensor aten_output = mul(t0, t1); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemReduce_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(3); // M, K, N + TensorView* tv1 = sum(tv0, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(); + tv2->setMemoryType(MemoryType::Shared); + + // Schedule + constexpr int BSX = 32; + tv1->split(2, BSX); + tv1->split(1, 128); + tv1->split(0, BSX); + // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX + tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}}); + TensorView* tv3 = tv1->rFactor({-2}); + + tv0->computeAt(tv1, -2); + tv0->computeAt(tv3, -2); + + // Thread and Block binding + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDy); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Binding + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({M, K, N}, options); + at::Tensor aten_output = sum(aten_input.to(at::kDouble), {1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemBlockGemm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + // Schedule + constexpr int BSX = 16; + tv5->split(2, BSX - 1); + tv5->split(1, BSX); + tv5->split(0, BSX + 1); + // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX + tv5->reorder({{0, 0}, {1, 3}, {2, 2}, {3, 5}, {4, 1}, {5, 4}}); + // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX + TensorView* tv6 = tv5->rFactor({-1}); + + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Shared); + tv6->setMemoryType(MemoryType::Shared); + + tv0->computeAt(tv5, 3); + tv1->computeAt(tv5, 3); + + // Thread and Block binding + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-2)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Binding + tv2->axis(-3)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-3)->parallelize(ParallelType::TIDy); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv6->axis(-3)->parallelize(ParallelType::TIDy); + tv6->axis(-2)->parallelize(ParallelType::TIDx); + + // Make sure BIDx is makred as exact (see issue #1119) + GpuLower gpulw(&fusion); + TORCH_CHECK(gpulw.parallelDimensionMap().isExact(ParallelType::BIDx)); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + std::vector aten_inputs = {t0, t1}; + at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemBlockGemmCache_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + // Schedule + // Remove reduction axis from tv5 + // tv6 = (M, R, N) + // tv5 = (M, N) + TensorView* tv6 = tv5->cacheBefore(); + + constexpr int BSX = 16; + tv5->split(1, BSX); + tv5->split(0, BSX); + // M/BSX, BSX, N/BSX, BSX + tv5->reorder({{0, 0}, {1, 2}, {2, 1}, {3, 3}}); + // tv5 = M/BSX, N/BSX, MSX, NSX + + tv6->computeAt(tv5, 2); + tv6->computeAt(tv5, 2); + + tv6->split(-1, BSX); + // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX + tv6->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 4}, {4, 2}, {5, 5}}); + // M/BSX, N/BSX, K/BSX, MSX, NSX, KSX + TensorView* tv7 = tv6->rFactor({-1}); + // tv7 = M/BSX, N/BSX, K/BSXrf, MSX, NSX, KSXr + // tv6 = M/BSX, N/BSX, K/BSXr, MSX, NSX + + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + + tv0->computeAt(tv7, 3); + tv1->computeAt(tv7, 3); + + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Shared); + tv6->setMemoryType(MemoryType::Shared); + tv7->setMemoryType(MemoryType::Shared); + // Memory Type + + // Thread and Block binding + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-2)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Binding + tv2->axis(-3)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-3)->parallelize(ParallelType::TIDy); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + + tv7->axis(-3)->parallelize(ParallelType::TIDy); + tv7->axis(-2)->parallelize(ParallelType::TIDx); + + tv6->axis(-2)->parallelize(ParallelType::TIDy); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemDynamicPersistentSoftmax2D_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* x = makeSymbolicTensor(2); + fusion.addInput(x); + TensorView* max_val = reductionOp( + BinaryOpType::Max, + {-1}, + IrBuilder::create(std::numeric_limits::lowest()), + x); // (M) + TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) + TensorView* x_max_sub = sub(x, bcast_max); // (M, N) + TensorView* exp = unaryOp(UnaryOpType::Exp, x_max_sub); // (M, N) + TensorView* sum_exp = sum(exp, {-1}); // (M, R) + TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) + TensorView* softmax = div(exp, bcast_sum); // (M, N) + fusion.addOutput(softmax); + + // Read Input into Shared Memory + // Load Input + Pwise into shared memory + auto cache_x = x->cacheAfter(); + cache_x->setMemoryType(MemoryType::Shared); + exp->setMemoryType(MemoryType::Shared); + + std::vector all_tensors( + {x, + cache_x, + max_val, + bcast_max, + x_max_sub, + exp, + sum_exp, + bcast_sum, + softmax}); + + auto tidx = IrBuilder::create(); + fusion.addInput(tidx); + + for (auto tensor : all_tensors) { + tensor->split(-1, tidx); + } + + auto sum_exp_rf = sum_exp->rFactor({1}); + all_tensors.push_back(sum_exp_rf); + + // computeAt + x->computeAt(x_max_sub, 1); + exp->computeAt(softmax, 1); + x_max_sub->computeAt(exp, 2); + + softmax->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + + const int64_t dimx = 1024; + const int64_t dimy = 4096; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({dimx, dimy}, options); + auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input, 128}); + auto cg_outputs = fe.runFusion({aten_input, 128}); + + testValidate( + &fusion, + cg_outputs, + {aten_input, 128}, + {aten_output}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerSoftmax_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int kReductionAxis = 3; + std::vector input_shape{10, 10, 10, 67}; + TensorView* input = makeSymbolicTensor(input_shape.size()); + fusion.addInput(input); + + auto output = softmax(input, kReductionAxis); + + fusion.addOutput(output); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(input_shape, options); + auto aten_output = + at::_softmax(aten_input.to(at::kDouble), kReductionAxis, false); + + auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + schedulePersistentKernel(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionTestMaskSoftmax_CUDA) { + // This test is testing the usage of all padding tokens + // with softmax like Bert might might use in a full padding + // sequence. + Fusion fusion; + FusionGuard fg(&fusion); + + const int kReductionAxis = 3; + std::vector input_shape{256, 16, 128, 128}; + TensorView* input = makeSymbolicTensor(input_shape.size()); + TensorView* mask = makeSymbolicTensor(input_shape.size()); + fusion.addInput(input); + fusion.addInput(mask); + + auto out1 = add(input, mask); + auto output = softmax(out1, kReductionAxis); + + fusion.addOutput(output); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(input_shape, options); + at::Tensor aten_mask = at::ones(input_shape, options); + // -10,000 is used here as a magic number because the padding + // tokens need to be a value that gives a value close to zero + // as to not influence softmax. Bert, in particular, does + // not use -Infinity because sometimes it will have a + // softmax of all padding tokkens that can result a divide by + // zero that creates NaN result. + aten_mask = aten_mask * -10000.0; + auto aten_out1 = aten_input + aten_mask; + auto aten_output = at::_softmax(aten_out1, kReductionAxis, false); + + auto reduction_params = + getPersistentHeuristics(&fusion, {aten_input, aten_mask}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + schedulePersistentKernel(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input, aten_mask}, lparams); + auto cg_outputs = fe.runFusion({aten_input, aten_mask}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input, aten_mask}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormBackward_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector shape{20, 100, 35, 67}; + std::vector norm_shape{67}; + + const size_t kM = shape.size(); + const size_t kN = norm_shape.size(); + const size_t kOuterNumDims = kM - kN; + + std::vector outer_shape; + for (const auto idx : c10::irange(kOuterNumDims)) { + outer_shape.push_back(shape[idx]); + } + for (const auto idx : c10::irange(kOuterNumDims, kM)) { + outer_shape.push_back(1); + } + + auto grad_out = makeSymbolicTensor(shape.size()); + auto input = makeSymbolicTensor(shape.size()); + auto mean = makeConcreteTensor(outer_shape); + auto rstd = makeConcreteTensor(outer_shape); + auto weight = makeSymbolicTensor(norm_shape.size()); + auto bias = makeSymbolicTensor(norm_shape.size()); + fusion.addInput(grad_out); + fusion.addInput(input); + fusion.addInput(mean); + fusion.addInput(rstd); + fusion.addInput(weight); + fusion.addInput(bias); + + auto grads = layer_norm_backward( + grad_out, + input, + norm_shape, + mean, + rstd, + weight, + bias, + {true, true, true}); + + fusion.addOutput(grads.grad_input); + fusion.addOutput(grads.grad_weight); + fusion.addOutput(grads.grad_bias); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_grad_out = at::randn(shape, options); + at::Tensor aten_input = at::randn(shape, options); + at::Tensor aten_weight = at::randn(norm_shape, options); + at::Tensor aten_bias = at::randn(norm_shape, options); + auto at_weight = c10::optional(aten_weight); + auto at_bias = c10::optional(aten_bias); + + const float kEps = 1e-5; + auto aten_results = + at::native_layer_norm(aten_input, norm_shape, at_weight, at_bias, kEps); + auto aten_output = std::get<0>(aten_results); + auto aten_mean = std::get<1>(aten_results); + auto aten_rstd = std::get<2>(aten_results); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector aten_inputs = { + aten_grad_out, aten_input, aten_mean, aten_rstd, aten_weight, aten_bias}; + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); + + auto aten_gradients = at::native_layer_norm_backward( + aten_grad_out.to(at::kDouble), + aten_input.to(at::kDouble), + norm_shape, + aten_mean.to(at::kDouble), + aten_rstd.to(at::kDouble), + c10::optional(aten_weight.to(at::kDouble)), + c10::optional(aten_bias.to(at::kDouble)), + {true, true, true}); + + testValidate( + &fusion, + cg_outputs, + aten_inputs, + {std::get<0>(aten_gradients), + std::get<1>(aten_gradients), + std::get<2>(aten_gradients)}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormBackward_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + const int64_t NORM_SIZE = 1024; + std::vector shape{8, 56, NORM_SIZE}; + std::vector norm_shape{NORM_SIZE}; + + const size_t kM = shape.size(); + const size_t kN = norm_shape.size(); + const size_t kOuterNumDims = kM - kN; + + std::vector outer_shape; + for (const auto idx : c10::irange(kOuterNumDims)) { + outer_shape.push_back(shape[idx]); + } + for (const auto idx : c10::irange(kOuterNumDims, kM)) { + outer_shape.push_back(1); + } + + auto grad_out = makeContigTensor(shape.size()); + auto input = makeContigTensor(shape.size()); + auto rstd = makeConcreteTensor(outer_shape); + auto weight = makeContigTensor(norm_shape.size()); + fusion.addInput(grad_out); + fusion.addInput(input); + fusion.addInput(rstd); + fusion.addInput(weight); + + auto grads = rms_norm_backward( + grad_out, input, norm_shape, rstd, weight, {true, true}); + + fusion.addOutput(grads.grad_input); + fusion.addOutput(grads.grad_weight); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_grad_out = at::randn(shape, options); + at::Tensor aten_input = at::randn(shape, options); + at::Tensor aten_weight = at::randn(norm_shape, options); + auto at_weight = c10::optional(aten_weight); + + const float kEps = 1e-6; + auto pow2 = at::pow(aten_input, 2); + auto sum = at::sum(pow2, -1, true); + auto var = at::mul(sum, 1.0 / NORM_SIZE); + auto aten_rstd = at::pow(at::add(var, kEps), -0.5); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector aten_inputs = { + aten_grad_out, aten_input, aten_rstd, aten_weight}; + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); + + auto in_mul_rstd = at::mul(aten_input, aten_rstd); + auto grad_out_mul = at::mul(aten_grad_out, in_mul_rstd); + auto aten_grad_weight = at::sum(grad_out_mul, c10::IntArrayRef{0, 1}); + auto sum_loss1 = at::sum(at::mul(aten_grad_out, aten_weight), -1, true); + auto sum_loss2 = at::sum( + at::mul( + at::mul(at::mul(aten_grad_out, aten_weight), aten_input), aten_rstd), + -1, + true); + + const float fH = NORM_SIZE; + auto term1 = at::mul(aten_rstd, 1.0 / fH); + auto aten_grad_input = at::mul(at::mul(aten_grad_out, fH), aten_weight); + aten_grad_input = at::sub(aten_grad_input, sum_loss1); + aten_grad_input = at::sub( + aten_grad_input, at::mul(at::mul(aten_input, aten_rstd), sum_loss2)); + aten_grad_input = at::mul(aten_grad_input, term1); + testValidate( + &fusion, + cg_outputs, + aten_inputs, + {aten_grad_input, aten_grad_weight}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerLayerNormalization_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + const float kEps = 1e-5; + Double* eps_ptr = IrBuilder::create(kEps); + + std::vector input_shape{20, 100, 35, 67}; + std::vector norm_shape{67}; + + auto input = makeSymbolicTensor(input_shape.size()); + fusion.addInput(input); + + auto result = layer_norm(input, norm_shape, nullptr, nullptr, eps_ptr); + + fusion.addOutput(result.output); + fusion.addOutput(result.mean); + fusion.addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(input_shape, options); + c10::optional aten_weight = c10::nullopt; + c10::optional aten_bias = c10::nullopt; + auto aten_outputs = at::native_layer_norm( + aten_input, norm_shape, aten_weight, aten_bias, kEps); + + // Check reduction axis is same for all reductions + // Generate Launch Parameters + auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto cg_outputs = fec.runFusionWithInputs({aten_input}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {std::get<0>(aten_outputs), + std::get<1>(aten_outputs), + std::get<2>(aten_outputs)}, + __LINE__, + __FILE__, + ""); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerRMSNormalization_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int64_t NORM_SIZE = 1024; + const float kEps = 1e-6; + Double* eps_ptr = IrBuilder::create(kEps); + + std::vector input_shape{8, 56, NORM_SIZE}; + std::vector norm_shape{NORM_SIZE}; + + auto input = makeContigTensor(input_shape.size()); + fusion.addInput(input); + auto result = rms_norm(input, norm_shape, nullptr, eps_ptr); + + fusion.addOutput(result.output); + fusion.addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(input_shape, options); + c10::optional aten_weight = c10::nullopt; + + auto pow2 = at::pow(aten_input, 2); + + auto sum = at::sum(pow2, -1, true); + auto var = at::mul(sum, 1.0 / NORM_SIZE); + auto invstd = at::pow(at::add(var, kEps), -0.5); + auto output = at::mul(aten_input, invstd); + //// Check reduction axis is same for all reductions + //// Generate Launch Parameters + auto reduction_params = getPersistentHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto cg_outputs = fec.runFusionWithInputs({aten_input}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {output, invstd}, + __LINE__, + __FILE__, + ""); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerBatchNormalization_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const float kMomentum = 0.1; + const float kEps = 1e-5; + const bool kTraining = true; + std::vector input_shape{20, 100, 35, 45}; + + auto input = makeSymbolicTensor(input_shape.size()); + auto weight = makeSymbolicTensor(1); + auto bias = makeSymbolicTensor(1); + auto running_mean = makeSymbolicTensor(1); + auto running_var = makeSymbolicTensor(1); + fusion->addInput(input); + fusion->addInput(weight); + fusion->addInput(bias); + fusion->addInput(running_mean); + fusion->addInput(running_var); + + Double* momentum = IrBuilder::create(kMomentum); + Double* eps = IrBuilder::create(kEps); + + auto result = batch_norm( + input, weight, bias, running_mean, running_var, kTraining, momentum, eps); + + fusion->addOutput(result.output); + fusion->addOutput(result.mean); + fusion->addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_input = at::randn(input_shape, options); + auto at_weight = at::ones({input_shape[1]}, options); + auto at_bias = at::zeros({input_shape[1]}, options); + auto at_run_mean = at::zeros({input_shape[1]}, options); + auto at_run_var = at::ones({input_shape[1]}, options); + + std::vector aten_inputs = { + at_input, at_weight, at_bias, at_run_mean, at_run_var}; + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto aten_outputs = at::native_batch_norm( + at_input, + c10::optional(at_weight), + c10::optional(at_bias), + c10::optional(at_run_mean), + c10::optional(at_run_var), + kTraining, + kMomentum, + kEps); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {std::get<0>(aten_outputs), + std::get<1>(aten_outputs), + std::get<2>(aten_outputs)}, + __LINE__, + __FILE__, + ""); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalization_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const float kMomentum = 0.1; + const float kEps = 1e-5; + const bool kUseInputStats = true; + std::vector input_shape{20, 100, 35, 45}; + + auto input = makeSymbolicTensor(input_shape.size()); + auto weight = makeSymbolicTensor(1); + auto bias = makeSymbolicTensor(1); + auto running_mean = makeSymbolicTensor(1); + auto running_var = makeSymbolicTensor(1); + fusion->addInput(input); + fusion->addInput(weight); + fusion->addInput(bias); + fusion->addInput(running_mean); + fusion->addInput(running_var); + + Double* momentum = IrBuilder::create(kMomentum); + Double* eps = IrBuilder::create(kEps); + + auto result = instance_norm( + input, + weight, + bias, + running_mean, + running_var, + kUseInputStats, + momentum, + eps); + + fusion->addOutput(result.output); + // fusion->addOutput(result.mean); + // fusion->addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_input = at::randn(input_shape, options); + auto at_weight = at::ones({input_shape[1]}, options); + auto at_bias = at::zeros({input_shape[1]}, options); + auto at_run_mean = at::zeros({input_shape[1]}, options); + auto at_run_var = at::ones({input_shape[1]}, options); + + std::vector aten_inputs = { + at_input, at_weight, at_bias, at_run_mean, at_run_var}; + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + auto cg_outputs_full = {at_run_mean, at_run_var, cg_outputs[0]}; + + auto aten_outputs = at::instance_norm( + at_input, + c10::optional(at_weight), + c10::optional(at_bias), + c10::optional(at_run_mean), + c10::optional(at_run_var), + kUseInputStats, + kMomentum, + kEps, + false); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + // TODO: can run_mean/run_var be checked here? + // fusion_outputs.size() == aten_outputs.size() && aten_outputs.size() == + // fusion->outputs().size() - output_alias_indices.size() + {aten_outputs}, + __LINE__, + __FILE__, + ""); +} + +TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalizationBackward_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + auto fusion_forward = std::make_unique(); + FusionGuard fg_forward(fusion_forward.get()); + + const float kMomentum = 0.1; + const float kEps = 1e-5; + const bool kUseInputStats = true; + const bool channels_last = true; + const int B = 2; + const int C = 5; + const int S = 3; + std::vector input_shape{B, C, S, S, S}; + // explicit channels-last for NVFuser + std::vector nvfuser_input_shape{B, S, S, S, C}; + + auto input = makeContigTensor(input_shape.size()); + auto weight = makeContigTensor(1); + auto bias = makeContigTensor(1); + fusion_forward->addInput(input); + fusion_forward->addInput(weight); + fusion_forward->addInput(bias); + + Double* momentum = IrBuilder::create(kMomentum); + Double* eps = IrBuilder::create(kEps); + auto result_forward = instance_norm( + input, + weight, + bias, + nullptr, + nullptr, + kUseInputStats, + momentum, + eps, + channels_last); + fusion_forward->addOutput(result_forward.output); + fusion_forward->addOutput(result_forward.mean); + fusion_forward->addOutput(result_forward.invstd); + + FusionExecutorCache executor_cache_forward(std::move(fusion_forward)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto at_input = at::randn(input_shape, options) + .to(at::MemoryFormat::ChannelsLast3d) + .set_requires_grad(true); + auto at_input_nvfuser = at_input.clone().detach().permute({0, 2, 3, 4, 1}); + auto at_weight = at::ones({input_shape[1]}, options).set_requires_grad(true); + auto at_weight_nvfuser = at_weight.clone().detach(); + auto at_bias = at::zeros({input_shape[1]}, options).set_requires_grad(true); + auto at_bias_nvfuser = at_bias.clone().detach(); + std::vector aten_inputs_forward = { + at_input_nvfuser, at_weight_nvfuser, at_bias_nvfuser}; + // out, mean, invstd + auto outputs_forward = + executor_cache_forward.runFusionWithInputs(aten_inputs_forward); + auto at_out = at::instance_norm( + at_input, + c10::optional(at_weight), + c10::optional(at_bias), + c10::optional(c10::nullopt), + c10::optional(c10::nullopt), + kUseInputStats, + kMomentum, + kEps, + false); + auto at_grad = + at::randn(input_shape, options).to(at::MemoryFormat::ChannelsLast3d); + auto at_grad_nvfuser = at_grad.clone().detach().permute({0, 2, 3, 4, 1}); + at_out.backward(at_grad); + auto fusion_backward = std::make_unique(); + FusionGuard fg_backward(fusion_backward.get()); + + input = makeContigTensor(input_shape.size()); + auto grad_output = makeContigTensor(input_shape.size()); + weight = makeContigTensor(1); + auto save_mean = makeContigTensor(2); + auto save_invstd = makeContigTensor(2); + auto dummy = makeContigTensor(0); + + fusion_backward->addInput(input); + fusion_backward->addInput(grad_output); + fusion_backward->addInput(weight); + fusion_backward->addInput(dummy); // dummy for run_mean + fusion_backward->addInput(dummy); // dummy for run_var + fusion_backward->addInput(save_mean); + fusion_backward->addInput(save_invstd); + + auto result_backward = instance_norm_backward( + input, + grad_output, + weight, + nullptr, + nullptr, + save_mean, + save_invstd, + kUseInputStats, + eps, + {true, true, true}, + channels_last); + + fusion_backward->addOutput(result_backward.grad_input); + fusion_backward->addOutput(result_backward.grad_weight); + fusion_backward->addOutput(result_backward.grad_bias); + + FusionExecutorCache executor_cache_backward(std::move(fusion_backward)); + std::vector aten_inputs_backward = { + at_input_nvfuser, + at_grad_nvfuser, + at_weight_nvfuser, + at::empty({}), + at::empty({}), + outputs_forward[1], + outputs_forward[2]}; + auto outputs_backward = + executor_cache_backward.runFusionWithInputs(aten_inputs_backward); + outputs_backward[0] = outputs_backward[0].permute({0, 4, 1, 2, 3}); + testValidate( + executor_cache_backward.fusion(), + outputs_backward, + aten_inputs_backward, + {at_input.grad(), at_weight.grad(), at_bias.grad()}, + __LINE__, + __FILE__, + ""); +} + +TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalShared_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int pixels_per_thread = 64; + const int TIDX = 128; + const int static_size = pixels_per_thread * TIDX; + + TensorView* sx = makeConcreteTensor({-1, static_size}); + TensorView* dx = makeSymbolicTensor(2); + fusion.addInput(sx); + fusion.addInput(dx); + + TensorView* max_sx = reductionOp( + BinaryOpType::Max, + {-1}, + IrBuilder::create(std::numeric_limits::lowest()), + sx); // (M) + TensorView* max_dx = reductionOp( + BinaryOpType::Max, + {-1}, + IrBuilder::create(std::numeric_limits::lowest()), + dx); // (M) + + // Reduction => merge local and shared memory TensorViews + TensorView* max_val = binaryOp(BinaryOpType::Max, max_sx, max_dx); + TensorView* bcast_max = broadcast(max_val, {false, true}); // (M, B) + + TensorView* sx_max_sub = sub(sx, bcast_max); // (M, N) + TensorView* dx_max_sub = sub(dx, bcast_max); // (M, N) + + TensorView* sx_exp = unaryOp(UnaryOpType::Exp, sx_max_sub); // (M, N) + TensorView* dx_exp = unaryOp(UnaryOpType::Exp, dx_max_sub); // (M, N) + + TensorView* sx_sum_exp = sum(sx_exp, {-1}); // (M, R) + TensorView* dx_sum_exp = sum(dx_exp, {-1}); // (M, R) + + // Reduction => merge local and shared memory TensorViews + TensorView* sum_exp = binaryOp(BinaryOpType::Add, sx_sum_exp, dx_sum_exp); + TensorView* bcast_sum = broadcast(sum_exp, {false, true}); // (M, B) + + TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N) + TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N) + fusion.addOutput(sx_softmax); + fusion.addOutput(dx_softmax); + + auto sx_cache = sx->cacheAfter(); + auto dx_cache = dx->cacheAfter(); + dx_cache->setMemoryType(MemoryType::Shared); + dx_exp->setMemoryType(MemoryType::Shared); + + // Reduction and Broadcast Tensors common to both memory TVs + std::vector common_tensors( + {max_val, sum_exp, bcast_max, bcast_sum}); + + // Static Local Memory TVs + std::vector static_tensors( + {sx, sx_cache, max_sx, sx_max_sub, sx_exp, sx_sum_exp, sx_softmax}); + + // Dynamic Local Memory TVs + std::vector dynamic_tensors( + {dx, dx_cache, max_dx, dx_max_sub, dx_exp, dx_sum_exp, dx_softmax}); + + std::vector all_tensors; + all_tensors.insert( + all_tensors.end(), common_tensors.begin(), common_tensors.end()); + all_tensors.insert( + all_tensors.end(), static_tensors.begin(), static_tensors.end()); + all_tensors.insert( + all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); + + // M => M + // M, N => M, N/128, 128 + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->split(-1, TIDX); + } + } + + auto sx_sum_exp_rf = sx_sum_exp->rFactor({1}); + auto dx_sum_exp_rf = dx_sum_exp->rFactor({1}); + all_tensors.push_back(sx_sum_exp_rf); + all_tensors.push_back(dx_sum_exp_rf); + + // computeAt + sx->computeAt(sx_max_sub, 1); + dx->computeAt(dx_max_sub, 1); + + sx_exp->computeAt(sx_softmax, 1); + dx_exp->computeAt(dx_softmax, 1); + + sx_max_sub->computeAt(sx_exp, 2); + dx_max_sub->computeAt(dx_exp, 2); + + sx_softmax->axis(0)->parallelize(ParallelType::BIDx); + dx_softmax->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + const int64_t dimx = 1024; + const int64_t dimy = 16384; + + auto properties = at::cuda::getDeviceProperties(0); + const size_t required_smem_size = + (dimy - static_size) * sizeof(float) + TIDX * sizeof(float); + if (properties->sharedMemPerBlockOptin < required_smem_size) { + GTEST_SKIP() << "not enough shared memory space on device to run test: " + << properties->sharedMemPerBlock; + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({dimx, dimy}, options); + at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size); + at::Tensor aten_dynamic_in = + aten_input.narrow(1, static_size, dimy - static_size); + + at::Tensor out = at::zeros({dimx, dimy}, options); + at::Tensor cg_static_out = out.narrow(1, 0, static_size); + at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size); + + std::vector aten_outputs; + + auto aten_output = at::_softmax(aten_input.to(at::kDouble), -1, false); + at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size); + at::Tensor aten_dynamic_out = + aten_output.narrow(1, static_size, dimy - static_size); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_static_in, aten_dynamic_in}); + fe.runFusion( + {aten_static_in, aten_dynamic_in}, {cg_static_out, cg_dynamic_out}); + + testValidate( + &fusion, + {cg_static_out, cg_dynamic_out}, + {aten_static_in, aten_dynamic_in}, + {cg_static_out, cg_dynamic_out}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionPersistentNormLocalShared_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int pixels_per_thread = 64; + const int TIDX = 128; + const int static_size = pixels_per_thread * TIDX; + + TensorView* sx = makeConcreteTensor({-1, static_size}); + TensorView* dx = makeSymbolicTensor(2); + fusion.addInput(sx); + fusion.addInput(dx); + + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); + fusion.addInput(gamma); + fusion.addInput(beta); + fusion.addInput(eps); + fusion.addInput(N); + + // Reduction + auto sx_sum = sum(sx, {-1}); // (M, R) + auto dx_sum = sum(dx, {-1}); // (M, R) + // Reduction => merge local and shared memory TensorViews + auto x_sum = binaryOp(BinaryOpType::Add, sx_sum, dx_sum); + + // Broadcast + auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) + // Pwise + auto x_mean = div(x_sum_bcast, N); // (M, B) + + auto sx_mean_sub = sub(sx, x_mean); // (M, N) + auto dx_mean_sub = sub(dx, x_mean); // (M, N) + + auto sx_mean_sub_pow = mul(sx_mean_sub, sx_mean_sub); // (M, N) + auto dx_mean_sub_pow = mul(dx_mean_sub, dx_mean_sub); // (M, N) + + // Reduction + auto sx_var_sum = sum(sx_mean_sub_pow, {-1}); // (M, R) + auto dx_var_sum = sum(dx_mean_sub_pow, {-1}); // (M, R) + // Reduction => merge local and shared memory TensorViews + auto var_sum = binaryOp(BinaryOpType::Add, sx_var_sum, dx_var_sum); + + // Broadcast + auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) + // Pwise + auto var = div(var_sum_bcast, N); // (M, B) + auto var_eps = add(var, eps); // (M, B) + auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) + + auto sx_norm = mul(sx_mean_sub, rvar); + auto dx_norm = mul(dx_mean_sub, rvar); + + auto sx_norm_gamma = mul(sx_norm, gamma); + auto dx_norm_gamma = mul(dx_norm, gamma); + + auto sx_norm_gamma_beta = add(sx_norm_gamma, beta); + auto dx_norm_gamma_beta = add(dx_norm_gamma, beta); + + fusion.addOutput(sx_norm_gamma_beta); + fusion.addOutput(dx_norm_gamma_beta); + + sx_norm_gamma_beta->setContiguity(false); + dx_norm_gamma_beta->setContiguity(false); + + // Read Input into Shared Memory + // Read Input minus Input_Mean into Shared Memory + auto sx_cache = sx->cacheAfter(); + auto dx_cache = dx->cacheAfter(); + dx_cache->setMemoryType(MemoryType::Shared); + dx_mean_sub->setMemoryType(MemoryType::Shared); + + std::vector common_tensors( + {x_sum, x_sum_bcast, x_mean, var_sum, var_sum_bcast, var, var_eps, rvar}); + + std::vector static_tensors( + {sx, + sx_cache, + sx_sum, + sx_mean_sub, + sx_mean_sub_pow, + sx_var_sum, + sx_norm, + sx_norm_gamma, + sx_norm_gamma_beta}); + + std::vector dynamic_tensors( + {dx, + dx_cache, + dx_sum, + dx_mean_sub, + dx_mean_sub_pow, + dx_var_sum, + dx_norm, + dx_norm_gamma, + dx_norm_gamma_beta}); + + std::vector all_tensors; + all_tensors.insert( + all_tensors.end(), common_tensors.begin(), common_tensors.end()); + all_tensors.insert( + all_tensors.end(), static_tensors.begin(), static_tensors.end()); + all_tensors.insert( + all_tensors.end(), dynamic_tensors.begin(), dynamic_tensors.end()); + + // M => M + // M, N => M, N/128, 128 + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->split(-1, TIDX); + } + } + + // Local Sum => Block Broadcast + TensorView* sx_sum_rf = sx_sum->rFactor({1}); + TensorView* sx_var_sum_rf = sx_var_sum->rFactor({1}); + TensorView* dx_sum_rf = dx_sum->rFactor({1}); + TensorView* dx_var_sum_rf = dx_var_sum->rFactor({1}); + all_tensors.push_back(sx_sum_rf); + all_tensors.push_back(sx_var_sum_rf); + all_tensors.push_back(dx_sum_rf); + all_tensors.push_back(dx_var_sum_rf); + + // ComputeAt + sx->computeAt(sx_mean_sub_pow, 1); + dx->computeAt(dx_mean_sub_pow, 1); + + var_sum->computeAt(rvar, 1); + + sx_mean_sub_pow->computeAt(sx_var_sum_rf, 2); + dx_mean_sub_pow->computeAt(dx_var_sum_rf, 2); + + sx_norm->computeAt(sx_norm_gamma_beta, 2); + dx_norm->computeAt(dx_norm_gamma_beta, 2); + + sx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); + dx_norm_gamma_beta->axis(0)->parallelize(ParallelType::BIDx); + for (auto tensor : all_tensors) { + if (tensor->nDims() > 1) { + tensor->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + const int dimx = 1024; + const int dimy = 16384; + const float kGamma = 1.0f; + const float kBeta = 0.0f; + const float kEps = 1e-5; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto properties = at::cuda::getDeviceProperties(0); + const size_t required_smem_size = + (dimy - static_size) * sizeof(float) + TIDX * sizeof(float); + if (properties->sharedMemPerBlockOptin < required_smem_size) { + GTEST_SKIP() << "not enough shared memory space on device to run test: " + << properties->sharedMemPerBlock; + } + + at::Tensor aten_input = at::randn({dimx, dimy}, options); + at::Tensor aten_static_in = aten_input.narrow(1, 0, static_size); + at::Tensor aten_dynamic_in = + aten_input.narrow(1, static_size, dimy - static_size); + + at::Tensor out = at::zeros({dimx, dimy}, options); + at::Tensor cg_static_out = out.narrow(1, 0, static_size); + at::Tensor cg_dynamic_out = out.narrow(1, static_size, dimy - static_size); + + std::vector aten_inputs = { + aten_static_in, aten_dynamic_in, kGamma, kBeta, kEps, dimy}; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + + fe.runFusion(aten_inputs, {cg_static_out, cg_dynamic_out}); + + auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); + auto at_var = at::var(aten_input.to(at::kDouble), -1, false).unsqueeze(1); + auto at_rvar = at::rsqrt(at::add(at_var, kEps)); + auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar); + auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta); + at::Tensor aten_static_out = aten_output.narrow(1, 0, static_size); + at::Tensor aten_dynamic_out = + aten_output.narrow(1, static_size, dimy - static_size); + + testValidate( + &fusion, + {cg_static_out, cg_dynamic_out}, + aten_inputs, + {aten_static_out, aten_dynamic_out}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionSmemDynamicPersistentNorm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + auto x = makeSymbolicTensor(2); + Double* gamma = IrBuilder::create(); + Double* beta = IrBuilder::create(); + Double* eps = IrBuilder::create(); + Int* N = IrBuilder::create(); + fusion.addInput(x); + fusion.addInput(gamma); + fusion.addInput(beta); + fusion.addInput(eps); + fusion.addInput(N); + + // Reduction + auto x_sum = sum(x, {-1}); // (M, R) + // Broadcast + auto x_sum_bcast = broadcast(x_sum, {false, true}); // (M, B) + // Pwise + auto x_mean = div(x_sum_bcast, N); // (M, B) + auto x_mean_sub = sub(x, x_mean); // (M, N) + auto x_mean_sub_pow = mul(x_mean_sub, x_mean_sub); // (M, N) + // Reduction + auto var_sum = sum(x_mean_sub_pow, {-1}); // (M, R) + // Broadcast + auto var_sum_bcast = broadcast(var_sum, {false, true}); // (M, B) + // Pwise + auto var = div(var_sum_bcast, N); // (M, B) + auto var_eps = add(var, eps); // (M, B) + auto rvar = unaryOp(UnaryOpType::Rsqrt, var_eps); // (M, B) + auto norm = mul(x_mean_sub, rvar); + auto norm_gamma = mul(norm, gamma); + auto norm_gamma_beta = add(norm_gamma, beta); + fusion.addOutput(norm_gamma_beta); + + // Read Input into Shared Memory + // Read Input minus Input_Mean into Shared Memory + auto cache_x = x->cacheAfter(); + cache_x->setMemoryType(MemoryType::Shared); + x_mean_sub->setMemoryType(MemoryType::Shared); + + std::vector all_tensors( + {x_sum, + x_mean, + cache_x, + x_sum_bcast, + x_mean_sub, + x_mean_sub_pow, + var_sum, + var_sum_bcast, + var, + var_eps, + rvar, + norm, + norm_gamma, + norm_gamma_beta}); + + auto tidx = IrBuilder::create(); + fusion.addInput(tidx); + + for (auto tensor : all_tensors) { + tensor->split(-1, tidx); + } + + // Local Sum => Block Broadcast + TensorView* x_sum_rf = x_sum->rFactor({1}); + TensorView* var_sum_rf = var_sum->rFactor({1}); + all_tensors.push_back(x_sum_rf); + all_tensors.push_back(var_sum_rf); + + // ComputeAt + x->computeAt(x_mean_sub_pow, 1); + var_sum->computeAt(rvar, 1); + x_mean_sub_pow->computeAt(var_sum_rf, 2); + norm->computeAt(norm_gamma_beta, 2); + + for (auto tv : all_tensors) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + const int dimx = 128; + const int dimy = 2048; + const float kGamma = 1.0f; + const float kBeta = 0.0f; + const float kEps = 1e-5; + const int TIDX = 128; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({dimx, dimy}, options); + auto at_mu = at::mean(aten_input.to(at::kDouble), -1).unsqueeze(1); + auto at_var = at::var(aten_input.to(at::kDouble), -1).unsqueeze(1); + auto at_rvar = at::rsqrt(at::add(at_var, kEps)); + auto at_norm = at::mul(at::sub(aten_input, at_mu), at_rvar); + auto aten_output = at::add(at::mul(at_norm, kGamma), kBeta); + + std::vector aten_inputs = { + aten_input, kGamma, kBeta, kEps, dimy, TIDX}; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolic_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addInput(tv0); + fusion.addOutput(tv1); + // tv1[I0, R1] = tv0[I0, I1] + + // Interface should just be a direct split with a Parallel type. We can + // include the parallelize call if we do this. + tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({2}); + tv2->setMemoryType(MemoryType::Shared); + // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1] + // tv1[I0, R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}] + + tv0->computeAt(tv1, 1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + + constexpr int numel_x = 65000, numel_y = 1024; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({numel_x, numel_y}, options); + auto aten_output = aten_input.to(at::kDouble).sum({1}); + + // How many threads to use for the block reduction + constexpr int runtime_threadIdx_dim = 128; + + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemDynamicReductionSymbolicArg_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + Int* sym_bsx = IrBuilder::create(); + TensorView* tv0 = makeSymbolicTensor(3); // M, K, N + fusion.addInput(tv0); + fusion.addInput(sym_bsx); + + TensorView* tv1 = sum(tv0, {1}); // M, R, N + fusion.addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(); + tv2->setMemoryType(MemoryType::Shared); + + // Schedule + constexpr int BSX = 32; + tv1->split(2, BSX); + tv1->split(1, sym_bsx); + tv1->split(0, BSX); + // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX + tv1->reorder({{0, 0}, {1, 2}, {2, 4}, {3, 5}, {4, 1}, {5, 3}}); + TensorView* tv3 = tv1->rFactor({-2}); + + tv0->computeAt(tv1, -2); + tv0->computeAt(tv3, -2); + + // Thread and Block binding + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::BIDy); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + // Manual Binding + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({M, K, N}, options); + at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); + + // How many threads to use for the block reduction + constexpr int runtime_threadIdx_dim = 128; + + auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input, runtime_threadIdx_dim}, lparams); + auto cg_outputs = fe.runFusion({aten_input, runtime_threadIdx_dim}, lparams); + + testValidate( + &fusion, + cg_outputs, + {aten_input, runtime_threadIdx_dim}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 0); +} + +TEST_F(NVFuserTest, FusionSmemDynamicPwiseMulSymbolicArgWAR_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Int* sym_bsx = IrBuilder::create(); + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(sym_bsx); + fusion.addOutput(tv4); + // Algorithm + + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + + constexpr int BSX = 32; + tv4->split(2, BSX); + tv4->split(1, sym_bsx); + tv4->split(0, BSX); + // M/BSX, BSX, K/BSX, BSX, N/BSX, BSX + tv4->reorder({{0, 0}, {1, 3}, {2, 1}, {3, 4}, {4, 2}, {5, 5}}); + // M/BSX, K/BSX, N/BSX, MSX, KSX, NSX + + tv0->computeAt(tv4, 3); + tv1->computeAt(tv4, 3); + // Schedule + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(2)->parallelize(ParallelType::BIDy); + // Manual Binding + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + // Thread and Block binding + + constexpr int M = 128, K = 457, N = 1024; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)); + std::vector aten_inputs = {t0, t1, BSX}; + + LaunchParams lparams(-1, -1, -1, BSX, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, + cg_outputs, + aten_inputs, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); +} + +TEST_F(NVFuserTest, FusionSmemDynamicTiledGemm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Symbolic integers we will use for runtime tiling + Int* symbolic_m_tile_dim = IrBuilder::create(); // bound to threadIdx.z + Int* symbolic_split_k_tile_dim = + IrBuilder::create(); // bound to blockIdx.x + Int* symbolic_block_k_tile_dim = + IrBuilder::create(); // bound to threadIdx.x + // Compile-time integer for tiling + int n_smem_tile = 8; // bound to threadIdx.y + + // Symbolic 2D tensors TV0[M, K], TV1[K, N] + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + // Broadcast tv0 to [M, K, *] + TensorView* tv2 = broadcast(tv0, {false, false, true}); + // Broadcast tv1 to [*, K, N] + TensorView* tv3 = broadcast(tv1, {true, false, false}); + + // Pointwise multiplication resulting in tv3[M, K, N] + TensorView* tv4 = mul(tv2, tv3); + + // Turn the K-dimension of tv4 into a reduction dimension + TensorView* tv5 = sum(tv4, {1}); + + // Register inputs and outputs + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + // Register runtime tile dims as inputs + fusion.addInput(symbolic_m_tile_dim); + fusion.addInput(symbolic_split_k_tile_dim); + fusion.addInput(symbolic_block_k_tile_dim); + + // Make a 3D tile, mix of symbolic and constant, do in reverse order because + // dims are inserted + // [M, K, N] + tv5->split(2, n_smem_tile); + tv5->split(1, symbolic_block_k_tile_dim); + tv5->split(1, symbolic_split_k_tile_dim); + tv5->split(0, symbolic_m_tile_dim); + // [Mo, Mi, Koo, Koi, Ki, No, Ni] + + // Reorder so all outer tiles are in the leftmost 3 positions + tv5->reorder({{1, 5}, {5, 1}}); + // [Mo, No, Koo, Koi, Ki, Mi, Ni] + + // Factor out the outer reduction IterDomain, then run the inter-cta + // reduction, and intra-cta reduction + auto tv6 = tv5->rFactor({2}); + // [Mo, No, rKoo, rKoi, rKi, Mi, Ni] + // [Mo, No, rKoi, rKi, Mi, Ni] + + // Scope computations + tv6->computeAt(tv5, 2); + // [Mo, No, rKoo, Koi, Ki, Mi, Ni] + // [Mo, No, rKoi, rKi, Mi, Ni] + + // Setup compute at schedule + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + tv4->computeAt(tv6, -1); + // + // T2[Mo, bNo, Koo, Koi, Kii, Mi, bNi] CA(4, 3) + // T3[bMo, No, Koo, Koi, Kii, bMi, Ni] CA(4, 3) + // T4[ Mo, No, Koo, Koi, Kii, Mi, Ni] + // T6[ Mo, No, rKoo, Koi, Kii, Mi, Ni] + // T5[ Mo, No, rKoi, rKii, Mi, Ni] + + // Cache smem tiles + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Local); + tv6->setMemoryType(MemoryType::Local); + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::BIDy); + + std::vector tv_list = {tv2, tv3, tv4, tv5, tv6}; + for (auto tv : tv_list) { + tv->axis(-2)->parallelize(ParallelType::TIDz); + tv->axis(-1)->parallelize(ParallelType::TIDy); + } + tv2->axis(3)->parallelize(ParallelType::TIDx); + tv3->axis(3)->parallelize(ParallelType::TIDx); + tv4->axis(3)->parallelize(ParallelType::TIDx); + tv6->axis(3)->parallelize(ParallelType::TIDx); + tv5->axis(2)->parallelize(ParallelType::TIDx); + + tv2->axis(4)->parallelize(ParallelType::BIDx); + tv3->axis(4)->parallelize(ParallelType::BIDx); + tv4->axis(4)->parallelize(ParallelType::BIDx); + tv6->axis(4)->parallelize(ParallelType::BIDx); + tv5->axis(3)->parallelize(ParallelType::BIDx); + + constexpr int M = 31, K = 65, N = 33; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + // Runtime tiling + int m_tile = 4; // bound to threadIdx.z + int split_k = 7; // bound to blockIdx.x + int intra_cta = 8; // bound to threadIdx.x + + std::vector aten_inputs = {t0, t1, m_tile, split_k, intra_cta}; + at::Tensor aten_output = + mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + + FusionExecutor fe; + // Generate CUDA and compile with nvRTC + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count == 1); +} + +} // namespace jit +} // namespace torch +#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp new file mode 100644 index 0000000000000..d154b454281e1 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp @@ -0,0 +1,9801 @@ +#if defined(USE_CUDA) +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; +using namespace at::indexing; + +TEST_F(NVFuserTest, FusionGlobalIntermediate_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + fusion.addInput(tv0); + fusion.addOutput(tv1); + // tv1[I0, R1] = tv0[I0, I1] + + // Interface should just be a direct split with a Parallel type. We can + // include the parallelize call if we do this. + tv1->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + // tv1[I0, R1o, R1i{BIDx}] = tv0[I0, I1] + + TensorView* tv2 = tv1->rFactor({2}); + tv2->setMemoryType(MemoryType::Global); + // tv2[I0, R1oo, Ir1i{BIDx}] = tv0[I0, I1] + // tv1[I0, R1i{BIDx}] = tv2[I0, R1oo, Ir1i{BIDx}] + + tv0->computeAt(tv1, 1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + + constexpr int numel_x = 65000, numel_y = 1024; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + // How many threads to use for the block reduction + constexpr int runtime_threadIdx_dim = 128; + + auto lparams = LaunchParams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}, lparams); + auto cg_outputs = fe.runFusion({input}, lparams); + + auto aten_output = input.to(at::kDouble).sum({1}); + testValidate( + &fusion, + cg_outputs, + {input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionGlobalIntermediateDefaultSchedule_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + TensorView* tv2 = makeSymbolicTensor(2); + TensorView* tv3 = makeSymbolicTensor(2); + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + fusion.addInput(tv3); + fusion.addOutput(tv6); + // t6 = ((t1 + (t2 - t3)) - t0) + + tv4->setMemoryType(MemoryType::Global); + tv5->setMemoryType(MemoryType::Global); + tv6->setMemoryType(MemoryType::Global); + + constexpr int M = 32, N = 810; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t1 = at::randn({M, N}, options); + at::Tensor t2 = at::randn({M, N}, options); + at::Tensor t3 = at::randn({M, N}, options); + + at::Tensor aten_output = (t1 + (t2 - t3)) - t0; + + std::vector aten_inputs = {t0, t1, t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1, t2, t3}); + auto cg_outputs = fe.runFusion({t0, t1, t2, t3}); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionConstCheck_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto one = IrBuilder::create(1); + TORCH_CHECK(one->isConstScalar()); + + auto one_x2 = mul(one, one); + TORCH_CHECK(one_x2->isConstScalar()); + + auto one_x3 = mul(one_x2, one); + TORCH_CHECK(one_x3->isConstScalar()); + + auto one_x4 = mul(one_x3, one); + TORCH_CHECK(one_x4->isConstScalar()); +} + +TEST_F(NVFuserTest, FusionUnrollWithAlloc_CUDA) { + const std::vector tensor_dims_in = {128, 128}; + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(tensor_dims_in.size()); + fusion.addInput(tv0); + + TensorView* tv1 = add(tv0, IrBuilder::create(0)); + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); + fusion.addOutput(tv2); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn(tensor_dims_in, options); + at::Tensor cg_output = at::empty({tensor_dims_in[0]}, options); + + // Schedule + tv2->split(1, 32); + tv2->split(1, 4); // unroll + + auto tv2_rf = tv2->rFactor({-3, -2}); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv2_rf->axis(0)->parallelize(ParallelType::BIDx); + tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv2_rf->axis(-2)->parallelize(ParallelType::Unroll); + + tv1->computeAt(tv2_rf, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto aten_output = (input + 0).to(at::kDouble).sum(1); + + testValidate(&fusion, cg_outputs, {input}, {aten_output}, __LINE__, __FILE__); +} + +// Test isZeroInt +TEST_F(NVFuserTest, FusionIsZeroInt_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Int* x = IrBuilder::create(0); + Int* y = IrBuilder::create(1); + Val* z = mul(x, y); + TORCH_CHECK(x->isZeroInt()); + TORCH_CHECK(!y->isZeroInt()); + TORCH_CHECK(!z->isZeroInt()); +} + +// Test isOneInt +TEST_F(NVFuserTest, FusionIsOneInt_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + Int* x = IrBuilder::create(1); + Int* y = IrBuilder::create(1); + Val* z = mul(x, y); + TORCH_CHECK(x->isOneInt()); + TORCH_CHECK(y->isOneInt()); + TORCH_CHECK(!z->isOneInt()); +} + +// This is to verify no cycle of computeAt is created. A more complex +// variation of this pattern appears in one of the Python tests +// (test_random_topo). +TEST_F(NVFuserTest, FusionComputeAtNonterminatingOutput_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // Common intermediate tensor + auto tv1 = add(tv0, IrBuilder::create(1)); + // tv1 -> tv2 + auto tv2 = add(tv1, IrBuilder::create(2)); + // tv1 -> tv3 -> tv4 + auto tv3 = add(tv1, IrBuilder::create(3)); + auto tv4 = add(tv3, IrBuilder::create(4)); + + // NOTE: This should no longer occur as of PR #201. + // The order of adding outputs matters. If tv3 is added before tv4, + // it should be fine. However, if tv4 is added before tv3, there + // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created + // first, and then tv4->tv3 is created at the final phase of + // computeAt (ComputeAt::setupOutputs). + fusion.addOutput(tv2); + fusion.addOutput(tv4); + fusion.addOutput(tv3); + + tv0->computeAt(tv2, -1); + + TORCH_CHECK(tv3->hasComputeAt()); + TORCH_CHECK(!tv4->hasComputeAt()); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn(100, options); + + auto t1 = aten_input + 1; + auto t2 = t1 + 2; + auto t3 = t1 + 3; + auto t4 = t3 + 4; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + std::vector aten_outputs = {t2, t4, t3}; + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); + TensorView* tv4 = add(tv1, IrBuilder::create(4)); + + fusion.addOutput(tv2); + fusion.addOutput(tv3); + fusion.addOutput(tv4); + + tv1->computeAt(tv3, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({10, 10}, options); + + auto t1 = aten_input + 1; + auto t2 = aten_input + 2; + auto t3 = t1 + 3; + auto t4 = t1 + 4; + + std::vector aten_outputs = {t2, t3, t4}; + + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + + TensorView* tv5 = add(tv1, tv3); + + fusion.addOutput(tv2); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + tv1->computeAt(tv5, -1); + tv3->computeAt(tv5, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({10, 10}, options); + + auto t1 = aten_input + 1; + auto t2 = t1 + 2; + auto t3 = aten_input + 3; + auto t4 = t3 + 4; + auto t5 = t1 + t3; + + std::vector aten_outputs = {t2, t4, t5}; + + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder3_CUDA) { + for (const auto i : c10::irange(2)) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + + TensorView* tv5 = add(tv1, tv3); + + fusion.addOutput(tv2); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + const int tile = 32; + + tv1->split(-1, tile); + tv2->split(-1, tile); + tv3->split(-1, tile); + tv4->split(-1, tile); + tv5->split(-1, tile); + + auto compute_at_outer = tv1; + auto compute_at_inner = tv3; + if (i == 1) { + std::swap(compute_at_inner, compute_at_outer); + } + + compute_at_outer->computeAt(tv5, -2); + compute_at_inner->computeAt(tv5, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + auto t1 = aten_input + 1; + auto t2 = t1 + 2; + auto t3 = aten_input + 3; + auto t4 = t3 + 4; + auto t5 = t1 + t3; + + std::vector aten_outputs = {t2, t4, t5}; + + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionTraversalOrder4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // First tree + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, IrBuilder::create(3)); + fusion.addOutput(tv2); + fusion.addOutput(tv3); + + // Second tree + TensorView* tv4 = makeSymbolicTensor(1); + fusion.addInput(tv4); + TensorView* tv5 = add(tv4, IrBuilder::create(5)); + TensorView* tv6 = add(tv5, IrBuilder::create(6)); + TensorView* tv7 = add(tv5, IrBuilder::create(7)); + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + tv1->computeAt(tv2, -1); + tv5->computeAt(tv6, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100}, options); + at::Tensor t4 = at::rand_like(t0, options); + + auto t1 = t0 + 1; + auto t2 = t1 + 2; + auto t3 = t1 + 3; + auto t5 = t4 + 5; + auto t6 = t5 + 6; + auto t7 = t5 + 7; + + std::vector aten_outputs = {t2, t3, t6, t7}; + std::vector aten_inputs = {t0, t4}; + std::vector cg_outputs = { + at::empty_like(t0, options), + at::empty_like(t0, options), + at::empty_like(t0, options), + at::empty_like(t0, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + fe.runFusion(aten_inputs, cg_outputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + TensorView* tv5 = add(tv2, tv4); + + fusion.addOutput(tv1); + fusion.addOutput(tv3); + fusion.addOutput(tv5); + + tv2->computeAt(tv5, -1); + tv4->computeAt(tv5, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + std::vector cg_outputs = { + at::empty_like(aten_input, options), + at::empty_like(aten_input, options), + at::empty_like(aten_input, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + auto t1 = aten_input + 1; + auto t2 = t1 + 2; + auto t3 = aten_input + 3; + auto t4 = t3 + 4; + auto t5 = t2 + t4; + + std::vector aten_outputs = {t1, t3, t5}; + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv0, IrBuilder::create(2)); + TensorView* tv3 = add(tv1, tv2); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + + fusion.addOutput(tv4); + + tv1->split(0, 32); + tv2->split(0, 32); + tv3->split(0, 32); + tv4->split(0, 32); + + tv3->computeAt(tv4, -2); + tv1->computeAt(tv3, -1); + tv2->computeAt(tv3, -2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + + auto t1 = aten_input + 1; + auto t2 = aten_input + 2; + auto t3 = t1 + t2; + auto aten_output = t3 + 4; + + at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTraversalOrder7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(2)); + TensorView* tv3 = add(tv0, IrBuilder::create(3)); + TensorView* tv4 = add(tv3, IrBuilder::create(4)); + TensorView* tv5 = add(tv2, tv4); + + fusion.addOutput(tv5); + + TensorView* tvs[] = {tv1, tv2, tv3, tv4, tv5}; + for (auto tv : tvs) { + tv->split(0, 2); + tv->split(0, 4); + tv->split(0, 8); + } + + // computeAt into inner loop nests + tv1->computeAt(tv2, -1); + tv3->computeAt(tv4, -2); + + tv2->computeAt(tv5, -4); + tv4->computeAt(tv5, -3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100}, options); + + auto t1 = aten_input + 1; + auto t2 = t1 + 2; + auto t3 = aten_input + 3; + auto t4 = t3 + 4; + auto aten_output = t2 + t4; + + at::Tensor cg_output = at::empty_like(aten_input, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +// Test predication of grid reduction +TEST_F(NVFuserTest, FusionThreadPredicate_CUDA) { + const int gdimx = 4; + const int bdimx = 128; + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv0); + TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1); + TensorView* tv3 = add(tv0, IrBuilder::create(2)); + + fusion.addOutput(tv3); + fusion.addOutput(tv2); + + tv1->split(1, bdimx); + tv1->split(1, gdimx); + tv3->split(1, bdimx); + tv3->split(1, gdimx); + + TensorView* tv1_rf = tv1->rFactor({1}); + + tv1->computeAt(tv2, -1); + + tv1->axis(0)->parallelize(ParallelType::BIDy); + tv1_rf->axis(0)->parallelize(ParallelType::BIDy); + tv2->axis(0)->parallelize(ParallelType::BIDy); + tv1->axis(-2)->parallelize(ParallelType::BIDx); + tv1_rf->axis(-2)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + + tv3->axis(3)->parallelize(ParallelType::TIDx); + tv3->axis(2)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDy); + + int numel_x = 100; + int numel_y = 1000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({numel_x, numel_y}, options); + + auto t2 = -aten_input.to(at::kDouble).sum({1}); + auto t3 = aten_input + 2.0; + + std::vector aten_outputs = {t3, t2}; + + std::vector cg_outputs = { + at::empty_like(aten_input, options), at::empty({numel_x}, options)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, cg_outputs); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionLSTMCell_CUDA) { + const int hidden_features = 512; + const int batch_size = 64; + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tvs[16]; + for (const auto i : c10::irange(16)) { + tvs[i] = makeSymbolicTensor(2); + fusion.addInput(tvs[i]); + } + + auto ingate = unaryOp( + UnaryOpType::Sigmoid, add(add(add(tvs[0], tvs[1]), tvs[2]), tvs[3])); + + auto forgetgate = unaryOp( + UnaryOpType::Sigmoid, add(add(add(tvs[4], tvs[5]), tvs[6]), tvs[7])); + + auto cellgate = unaryOp( + UnaryOpType::Tanh, add(add(add(tvs[8], tvs[9]), tvs[10]), tvs[11])); + + auto outgate = unaryOp( + UnaryOpType::Sigmoid, add(add(add(tvs[12], tvs[13]), tvs[14]), tvs[15])); + + auto cx = makeContigTensor(2); + fusion.addInput(cx); + + auto cy = add(mul(forgetgate, cx), mul(ingate, cellgate)); + + auto hy = mul(outgate, unaryOp(UnaryOpType::Tanh, cy)); + + fusion.addOutput(cy); + fusion.addOutput(hy); + + std::vector aten_inputs; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor large_tensor0 = + at::randn({batch_size, hidden_features * 4}, options); + at::Tensor large_tensor1 = + at::randn({batch_size, hidden_features * 4}, options); + at::Tensor large_tensor2 = + at::randn({batch_size, hidden_features * 4}, options); + at::Tensor large_tensor3 = + at::randn({batch_size, hidden_features * 4}, options); + + auto chunked0 = large_tensor0.chunk(4, 1); + auto chunked1 = large_tensor1.chunk(4, 1); + auto chunked2 = large_tensor2.chunk(4, 1); + auto chunked3 = large_tensor3.chunk(4, 1); + + aten_inputs.insert(aten_inputs.end(), chunked0.begin(), chunked0.end()); + aten_inputs.insert(aten_inputs.end(), chunked1.begin(), chunked1.end()); + aten_inputs.insert(aten_inputs.end(), chunked2.begin(), chunked2.end()); + aten_inputs.insert(aten_inputs.end(), chunked3.begin(), chunked3.end()); + + auto at_ingate = + chunked0[0].add(chunked0[1]).add(chunked0[2]).add(chunked0[3]).sigmoid(); + auto at_forgetgate = + chunked1[0].add(chunked1[1]).add(chunked1[2]).add(chunked1[3]).sigmoid(); + auto at_cellgate = + chunked2[0].add(chunked2[1]).add(chunked2[2]).add(chunked2[3]).tanh(); + auto at_outgate = + chunked3[0].add(chunked3[1]).add(chunked3[2]).add(chunked3[3]).sigmoid(); + + auto at_cx = at::randn({batch_size, hidden_features}, options); + aten_inputs.push_back(at_cx); + auto at_cy = at_forgetgate.mul(at_cx).add(at_ingate.mul(at_cellgate)); + auto at_hy = at_outgate.mul(at_cy.tanh()); + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {at_cy, at_hy}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReductionHalf_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(3, DataType::Half); + fusion.addInput(tv0); + + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = sum(tv2, {2}); + auto tv4 = castOp(DataType::Half, tv3); + + fusion.addOutput(tv4); + + const auto options = + at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({8, 8, 16}, options); + + auto reduction_tv = tv3; + + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}, lparams); + + auto aten_output = aten_input.add(1.0).to(at::kDouble).sum({2}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReduceSingle_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({100, 1}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({100, 1}, options); + + // Grab only tensor views, though there shouldn't be any other type + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}); + + auto aten_output = aten_input.to(at::kDouble).sum({1}); + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + constexpr int red_dim = 1; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim, 2}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}, lparams); + auto aten_output = aten_input.to(at::kDouble).sum({red_dim, 2}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast2_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + constexpr int red_dim = 1; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); + fusion.addInput(tv0); + + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); + + TensorView* tv2 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv1); + fusion.addOutput(tv2); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}, lparams); + auto aten_output = aten_input.to(at::kDouble).sum({1, 2}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionReduceImplicitBroadcast3_CUDA) { + constexpr int bid_x = 80; + constexpr int tid_x = 4096; + constexpr int red_dim = 1; + + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({bid_x, tid_x, 1}); + fusion.addInput(tv0); + + TensorView* tv1 = reductionOp( + BinaryOpType::Add, {red_dim}, IrBuilder::create(0), tv0); + + TensorView* tv2 = + reductionOp(BinaryOpType::Add, {1}, IrBuilder::create(0), tv1); + fusion.addOutput(tv2); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({bid_x, tid_x, 1}, options); + + // Apply reduction heuristic + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + // no broadcasting needed, omitting the last optional argument; + auto cg_outputs = fe.runFusion({aten_input}, lparams); + auto aten_output = aten_input.to(at::kDouble).sum({2, 1}); + + testValidate( + &fusion, + cg_outputs, + {aten_input}, + {aten_output}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeConcreteTensor({10, 20, 1}); + fusion.addInput(tv0); + TensorView* tv1 = + reductionOp(BinaryOpType::Add, {2}, IrBuilder::create(0), tv0); + fusion.addOutput(tv1); + + TORCH_CHECK( + ir_utils::getReductionOps(&fusion, true /* ignore_trivial */).empty(), + "Trivial reduction picked up by fusion"); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({10, 20, 1}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + auto aten_output = aten_input.to(at::kDouble).sum({2}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTrivialReduction2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int w = 1, x = 1, y = 7, z = 8; + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeConcreteTensor({w, x, y, z}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = sum(tv1, {0}); + auto tv3 = sum(tv2, {0}); + auto tv4 = add(tv3, tv0); + + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({y, z}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + auto aten_output = t1.to(at::kDouble).sum({0}).sum({0}).add(t0); + + std::vector aten_inputs = {t0, t1}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTrivialReduction3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int v = 1, w = 1, x = 1, y = 7, z = 8; + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeConcreteTensor({v, w, x, y, z}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = sum(tv1, {0, 1, 2}); + auto tv3 = add(tv2, tv0); + + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({y, z}, options); + at::Tensor t1 = at::randn({v, w, x, y, z}, options); + auto aten_output = t1.sum({0, 1, 2}).add(t0); + + std::vector aten_inputs = {t0, t1}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +// Make sure trivial reductions are correctly detected even with +// scheduling applied. +TEST_F(NVFuserTest, FusionDetectTrivialReduction1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {false, true}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + tv2->split(1, 4); + tv2->split(1, 8); + auto tv3 = tv2->rFactor({-1}); + auto tv4 = tv2->rFactor({-1}); + + auto tv5 = broadcast(tv0, {true, false}); + auto tv6 = add(tv5, IrBuilder::create(1)); + auto tv7 = sub(tv6, IrBuilder::create(1)); + auto tv8 = sum(tv7, {0}); + fusion.addOutput(tv8); + + auto tv9 = broadcast(tv0, {false, true, true}); + auto tv10 = sum(tv9, {1}); + auto tv11 = sum(tv10, {1}); + fusion.addOutput(tv11); + + tv8->split(0, 3); + tv10->split(1, 4); + tv11->split(1, 5); + + tv0->computeAt(tv2, -1); + tv0->computeAt(tv8, -1); + tv0->computeAt(tv11, 1); + + // Test indexing to gmem-backed tensors + tv3->setMemoryType(MemoryType::Global); + tv8->setMemoryType(MemoryType::Global); + + GpuLower gpulw(&fusion); + + // No ReductionOp should be generated as all the reduction + // exprs should be replaced with a unary set op. + for (const auto expr : gpulw.kernel()->as()->exprs()) { + TORCH_CHECK(!expr->isA()); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {t0, t0, t0}, __LINE__, __FILE__); +} + +// Test detection of partially trivial reduction +TEST_F(NVFuserTest, FusionDetectTrivialReduction2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->split(1, 1); + // tv1->axis(1): non-trivial + // tv1->axis(2): trivial + + auto tv3 = tv1->rFactor({-1}); + + // Just to suppress register-allocation warning + tv0->computeAt(tv2, 1); + tv3->computeAt(tv1, -1); + + GpuLower gpulw(&fusion); + + // tv3's reduction axis is a trivial reduction. The only + // ReductionOp should be for tv1. + for (const auto expr : gpulw.kernel()->as()->exprs()) { + if (expr->isA()) { + auto reduction_out = + expr->as()->outputs()[0]->as(); + TORCH_CHECK(reduction_out->name() == 1); + } + } +} + +TEST_F(NVFuserTest, FusionInputsIdLookup_CUDA) { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({16, 8, 8}, options); + at::Tensor t1 = at::randn({8, 8}, options); + at::Tensor t2 = at::randn({6, 4}, options); + + // create a cache with max size 2; + torch::jit::fuser::cuda::InputsIdLookup inputs_id_lookup(2); + + // testing basic function, same encoding for identical inputs + auto id_0 = inputs_id_lookup.lookupId({t0, t1, 5.0}); + auto id_0_lookup = inputs_id_lookup.lookupId({t0, t1, 2.5}); + TORCH_CHECK(id_0.id == id_0_lookup.id); + TORCH_CHECK(inputs_id_lookup.size() == 1); + TORCH_CHECK(id_0.eviction == false); + + // new input (even tho same shape, but we have different signature because of + // missing scalar input + auto id_1 = inputs_id_lookup.lookupId({t0, t1}); + auto id_1_lookup = inputs_id_lookup.lookupId({t0, t1}); + TORCH_CHECK(id_1.id == id_1_lookup.id); + TORCH_CHECK(inputs_id_lookup.size() == 2); + TORCH_CHECK(id_1.eviction == false); + + // eviction should happen at this point + auto id_2 = inputs_id_lookup.lookupId({t2, t1}); + TORCH_CHECK(id_2.id != id_0.id); + TORCH_CHECK(id_2.id != id_1.id); + TORCH_CHECK(inputs_id_lookup.size() == 2); + TORCH_CHECK(id_2.eviction == true); + TORCH_CHECK(id_2.evict_id == id_0.id); + + // look at input 1 again + auto id_1_relook = inputs_id_lookup.lookupId({t0, t1}); + TORCH_CHECK(id_1_relook.id == id_1.id); + TORCH_CHECK(id_1_relook.eviction == false); +} + +TEST_F(NVFuserTest, FusionGroupGuardSimpleTensor_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({64, 8, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // pass with identical shape + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(complyWith(t0, tensor_type)); + + // pass with dynamic shape + auto t1 = at::randn({16, 16, 8}, options); + TORCH_CHECK(complyWith(t1, tensor_type)); + + // broadcasting semantic change failure + auto t2 = at::randn({16, 1, 8}, options); + TORCH_CHECK(!complyWith(t2, tensor_type)); + + // contiguity failure via slicing + auto t3 = t0.slice(1, 0, 8, 2); + TORCH_CHECK(!complyWith(t3, tensor_type)); + + // contiguity failure via slicing + auto t4 = t0.slice(2, 0, 8, 2); + TORCH_CHECK(!complyWith(t4, tensor_type)); + + // rank failure + auto t5 = at::randn({16, 8, 8, 8}, options); + TORCH_CHECK(!complyWith(t5, tensor_type)); + + // contiguity on stride 1 dimension with implicit broadcasting + auto t = at::randn({4}, options); + auto t6 = t.unsqueeze(1).expand({4, 8}); + TORCH_CHECK(complyWith(t6, TensorType::create(t6))); +} + +TEST_F(NVFuserTest, FusionGroupGuardBroadcastTensor_CUDA) { + std::vector sizes_vec({16, 1, 8}); + std::vector strides_vec({8, 8, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // broadcasting semantic change + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(!complyWith(t0, tensor_type)); + + // dtype failure + auto t1 = at::randn({16, 1, 8}, options.dtype(at::kHalf)); + TORCH_CHECK(!complyWith(t1, tensor_type)); + + // dtype failure + auto t2 = at::randn({16, 1, 8}, options); + TORCH_CHECK(complyWith(t2, tensor_type)); + + // device inconsistency shouldn't fail + auto t3 = at::randn({16, 1, 8}, options.device(at::kCPU, 0)); + TORCH_CHECK(complyWith(t3, tensor_type)); +} + +TEST_F(NVFuserTest, FusionGroupGuardPermutedTensor_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({64, 1, 8}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // failing permutation + auto t0 = at::randn({16, 8, 8}, options); + TORCH_CHECK(!complyWith(t0, tensor_type)); + + // passing with dynamic shape + auto t1 = t0.permute({0, 2, 1}); + TORCH_CHECK(complyWith(t1, tensor_type)); +} + +TEST_F(NVFuserTest, FusionGroupGuardRelaxedCheck_CUDA) { + std::vector sizes_vec({16, 8, 8}); + std::vector strides_vec({128, 16, 1}); + auto tensor_type = TensorType::create( + at::kFloat, c10::nullopt, sizes_vec, strides_vec, c10::nullopt); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // contiguity check passes although it differs + auto t0 = at::randn({16, 16, 8}, options); + TORCH_CHECK(complyWith(t0, tensor_type)); + + // passing with dynamic shape + auto t1 = t0.slice(1, 0, 16, 2); + TORCH_CHECK(complyWith(t1, tensor_type)); +} + +TEST_F(NVFuserTest, FusionDisjointSet_CUDA) { + DisjointSets set; + + const std::set group_x({0, 1, 2}); + const std::set group_y({3, 4, 5}); + const std::set group_z({6, 7, 8}); + const std::vector> groups({group_x, group_y, group_z}); + std::set group_all; + std::for_each(groups.begin(), groups.end(), [&](const auto& g) { + group_all.insert(g.begin(), g.end()); + }); + + // Initially, nothing should be considered equivalent + for (auto i : group_all) { + for (auto j : group_all) { + TORCH_CHECK(!set.permissiveAreMapped(i, j)); + } + } + + // Sets values in group_x are equivalent + for (auto i : group_x) { + for (auto j : group_x) { + set.mapEntries(i, j); + TORCH_CHECK(set.mappingExists(i)); + TORCH_CHECK(set.mappingExists(j)); + } + } + + // All values in group_x shoudl be equivalent with each other + for (auto i : group_x) { + for (auto j : group_x) { + TORCH_CHECK(set.permissiveAreMapped(i, j)); + } + } + // But nothing else should be equivalent + for (auto i : group_all) { + for (auto j : group_y) { + TORCH_CHECK(!set.permissiveAreMapped(i, j)); + } + for (auto j : group_z) { + TORCH_CHECK(!set.permissiveAreMapped(i, j)); + } + } + + // Sets values in group_y are equivalent + for (auto i : group_y) { + for (auto j : group_y) { + set.mapEntries(i, j); + TORCH_CHECK(set.mappingExists(i)); + TORCH_CHECK(set.mappingExists(j)); + } + } + + // group_x should be still equivalent + for (auto i : group_x) { + for (auto j : group_x) { + TORCH_CHECK(set.permissiveAreMapped(i, j)); + } + } + // group_y should be now equivalent + for (auto i : group_y) { + for (auto j : group_y) { + TORCH_CHECK(set.permissiveAreMapped(i, j)); + } + } + // But group_z should not be equivalent with anything yet + for (auto i : group_all) { + for (auto j : group_z) { + TORCH_CHECK(!set.permissiveAreMapped(i, j)); + } + } + + // Sets values in group_z are equivalent + for (auto i : group_z) { + for (auto j : group_z) { + set.mapEntries(i, j); + TORCH_CHECK(set.mappingExists(i)); + TORCH_CHECK(set.mappingExists(j)); + } + } + + // Now each of the three groups should be equivalent within each + // group + for (const auto gi : c10::irange(groups.size())) { + for (const auto gj : c10::irange(groups.size())) { + for (auto i : groups[gi]) { + for (auto j : groups[gj]) { + TORCH_CHECK( + (gi == gj && set.permissiveAreMapped(i, j)) || + (gi != gj && !set.permissiveAreMapped(i, j))); + } + } + } + } + + std::vector all_elements = set.getAllElements().vector(); + std::sort(all_elements.begin(), all_elements.end()); + std::vector group_all_vec(group_all.begin(), group_all.end()); + std::sort(group_all_vec.begin(), group_all_vec.end()); + TORCH_CHECK(all_elements == group_all_vec); + + set.clear(); + TORCH_CHECK(set.getAllElements().vector().size() == 0); + + // All cleared. Nothing should be considered equivalent. + for (auto i : group_all) { + for (auto j : group_all) { + TORCH_CHECK(!set.permissiveAreMapped(i, j)); + } + } +} + +TEST_F(NVFuserTest, FusionNonUniqueBroadcastSize_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(2); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv0, {true, false}); + auto tv4 = add(tv3, tv1); + auto tv5 = add(tv3, tv2); + + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + // In order to do this, tv1->axis(1) and tv2->axis(1) must have the + // same size, but we can't prove it, so this should throw an error. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(tv3->computeAt(tv4, -1)); +} + +TEST_F(NVFuserTest, FusionBiasGeluFwd_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const float k_079 = 0.79788456; + const float k_004 = 0.044715; + + // bias vector + auto t0 = makeSymbolicTensor(1, DataType::Half); + fusion.addInput(t0); + auto t1 = castOp(DataType::Float, t0); + // input tensor + auto t2 = makeSymbolicTensor(3, DataType::Half); + fusion.addInput(t2); + auto t3 = castOp(DataType::Float, t2); + auto t4 = broadcast(t1, {true, true, false}); + auto t5 = add(t4, t3); + auto t6 = mul(t5, IrBuilder::create(0.5)); + auto t7 = mul(t5, IrBuilder::create(k_079)); + auto t8 = mul(t5, IrBuilder::create(k_004)); + auto t9 = mul(t8, t5); + auto t10 = add(t9, IrBuilder::create(1)); + auto t11 = mul(t7, t10); + auto t12 = unaryOp(UnaryOpType::Tanh, t11); + auto t13 = add(t12, IrBuilder::create(1)); + auto t14 = mul(t6, t13); + auto t15 = castOp(DataType::Half, t14); + fusion.addOutput(t15); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::manual_seed(0); + std::vector input_shape{6, 512, 4096}; + std::vector bias_shape{4096}; + + auto at_input = at::randn(input_shape, options); + auto at_bias = at::randn(bias_shape, options); + + auto at_x = + at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); + auto aten_output_float = + at_x * 0.5 * (1.0 + (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh()); + auto aten_output = aten_output_float.to(c10::ScalarType::Half); + + std::vector aten_inputs = {at_bias, at_input}; + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBiasGeluBwd_CUDA) { + if (at::cuda::getDeviceProperties(0)->major < 6) { + return; + } + Fusion fusion; + FusionGuard fg(&fusion); + + const float k_079 = 0.79788456; + const float k_004 = 0.044715; + const float k_010 = 0.1070322243; + + // gradient tensor + auto t0 = makeSymbolicTensor(3, DataType::Half); + fusion.addInput(t0); + auto t1 = castOp(DataType::Float, t0); + // bias tensor + auto t2 = makeSymbolicTensor(1, DataType::Half); + fusion.addInput(t2); + auto t3 = castOp(DataType::Float, t2); + // input tensor + auto t4 = makeSymbolicTensor(3, DataType::Half); + fusion.addInput(t4); + auto t5 = castOp(DataType::Float, t4); + auto t6 = broadcast(t3, {true, true, false}); + auto t7 = add(t6, t5); + auto t8 = mul(t7, IrBuilder::create(k_079)); + auto t9 = mul(t7, IrBuilder::create(k_004)); + auto t10 = mul(t9, t7); + auto t11 = add(t10, IrBuilder::create(1)); + auto t12 = mul(t8, t11); + auto t13 = unaryOp(UnaryOpType::Tanh, t12); + auto t14 = mul(t7, IrBuilder::create(0.5)); + auto t15 = mul(t13, t13); + auto t16 = unaryOp(UnaryOpType::Neg, t15); + auto t17 = add(t16, IrBuilder::create(1)); + auto t18 = mul(t7, IrBuilder::create(k_010)); + auto t19 = mul(t18, t7); + auto t20 = add(t19, IrBuilder::create(k_079)); + auto t21 = mul(t17, t20); + auto t22 = mul(t14, t21); + auto t23 = add(t13, IrBuilder::create(1)); + auto t24 = mul(t23, IrBuilder::create(0.5)); + auto t25 = add(t22, t24); + auto t26 = mul(t25, t1); + // Save float output for validation + fusion.addOutput(t26); + auto t27 = castOp(DataType::Half, t26); + fusion.addOutput(t27); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::manual_seed(1); + std::vector input_shape{6, 512, 4096}; + std::vector bias_shape{4096}; + auto at_input = at::randn(input_shape, options); + auto at_bias = at::randn(bias_shape, options); + auto at_grad = at::randn(input_shape, options); + + auto at_x = + at_bias.to(c10::ScalarType::Float) + at_input.to(c10::ScalarType::Float); + auto at_tanh_out = (k_079 * at_x * (1 + k_004 * at_x * at_x)).tanh(); + auto at_ff = 0.5 * at_x * + ((1 - at_tanh_out * at_tanh_out) * (k_079 + k_010 * at_x * at_x)) + + 0.5 * (1 + at_tanh_out); + auto at_out = at_ff * at_grad; + auto at_out_half = at_out.to(c10::ScalarType::Half); + + std::vector aten_inputs = {at_grad, at_bias, at_input}; + std::vector aten_outputs = {at_out, at_out_half}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +// Reproducer of issue #459 +TEST_F(NVFuserTest, FusionIssue459_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv1, tv3); + + // Create two outputs from the final arithmetic result + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + auto tv6 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv6); + + // Scheduling + for (auto output : ir_utils::filterByType(fusion.outputs())) { + output->merge(-2, -1); + } + for (auto output : ir_utils::filterByType(fusion.outputs())) { + output->split(0, 128); + } + + tv0->computeAt(tv5, -1); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + const int numel_x = 10; + const int numel_y = 20; + auto t0 = at::randn({numel_x}, options); + auto t1 = at::randn({numel_y, numel_x}, options); + auto aten_output = (t0 + 1).unsqueeze(0) + t1 + 1; + + std::vector aten_inputs = {t0, t1}; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, + cg_outputs, + aten_inputs, + {aten_output, aten_output}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionSmemIndexingSimple_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv3, -1); + + tv1->setMemoryType(MemoryType::Shared); + tv2->setMemoryType(MemoryType::Global); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + auto aten_input = at::randn({12, 34}, options); + at::Tensor aten_output = aten_input + 1.0 + 1.0 + 1.0; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSmemIndexing_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Symbolic integers we will use for runtime tiling + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); + // Compile-time integer for tiling + int n_smem_tile = 32; + + // Symbolic 2D tensors TV0[M, K], TV1[K, N] + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + // Broadcast tv0 to [M, K, *] + TensorView* tv2 = broadcast(tv0, {false, false, true}); + // Broadcast tv1 to [*, K, N] + TensorView* tv3 = broadcast(tv1, {true, false, false}); + + // Pointwise multiplication resulting in tv3[M, K, N] + TensorView* tv4 = mul(tv2, tv3); + + // Sum the K-dim + TensorView* tv5 = sum(tv4, {1}); + + // Register inputs and outputs + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + // Register runtime tile dims as inputs + fusion.addInput(symbolic_m_tile_dim); + fusion.addInput(symbolic_split_k_tile_dim); + fusion.addInput(symbolic_block_k_tile_dim); + + // Make a 3D tile, mix of symbolic and constant, do in reverse order because + // dims are inserted + // [M, rK, N] + tv5->split(2, n_smem_tile); + // [M, rK, No, Ni{32}] + tv5->split(1, symbolic_block_k_tile_dim); + // [M, rKo, rKi{i2}, No, Ni{32}] + tv5->split(1, symbolic_split_k_tile_dim); + // [M, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] + tv5->split(0, symbolic_m_tile_dim); + // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] + + // Reorder so all outer tiles are in the leftmost 3 positions + // [Mo, Mi{i0}, rKoo, rKoi{i1}, rKi{i2}, No, Ni{32}] + // [Mo, No, rKoo, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}] + tv5->reorder({{1, 5}, {5, 1}}); + + // Factor out the outer reduction IterDomain, then run the inter-cta + // reduction, and intra-cta reduction + // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}] + // [Mo, No, rKoi{i1}, rKi{i2}, Mi{i0}, Ni{32}] + auto tv6 = tv5->rFactor({2}); + + // Scope computations + tv6->computeAt(tv5, 2); + + // [Mo, No, rKoo, Koi{i1}, Ki{i2}, Mi{i0}, Ni{32}] + // [Mo, No, Ki{i2}, Mi{i0}, Ni{32}, rKoo, Koi{i1}] + tv6->reorder({ + {5, -2}, + {6, -1}, + {2, 2}, + {3, 3}, + {4, 4}, + }); + + // Setup compute at schedule + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + tv4->computeAt(tv6, -1); + + // Cache smem tiles + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Shared); + tv6->setMemoryType(MemoryType::Shared); + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::BIDy); + + std::vector tv_list = {tv2, tv3, tv4, tv5, tv6}; + for (auto tv : tv_list) { + tv->axis(-2)->parallelize(ParallelType::TIDz); + tv->axis(-1)->parallelize(ParallelType::TIDy); + } + + constexpr int M = 31, K = 65, N = 32; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + at::Tensor aten_output = + mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + + // A, B, m_tile_dim, split_k, intra_cta_tile + std::vector aten_inputs = {t0, t1, 3, 4, 5}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +// Reproducer of issue 408 +TEST_F(NVFuserTest, FusionCacheBeforeReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + tv2->split(0, 4); + + auto tv3 = tv2->cacheBefore(); + + tv0->computeAt(tv3, -1); + tv3->computeAt(tv2, -1); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + const int numel_x = 100; + const int numel_y = 200; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_x}, options); + + auto aten_output = (aten_input + 1).to(at::kDouble).sum({1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + fe.runFusion({aten_input}, {cg_output}); + + testValidate( + &fusion, {cg_output}, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionCacheBeforeReduction2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv2); + fusion.addOutput(tv3); + + auto tv4 = tv2->cacheBefore(); + + tv4->computeAt(tv3, 1); + tv0->computeAt(tv4, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + + const int numel_x = 10; + const int numel_y = 20; + const int numel_z = 30; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({numel_x, numel_y, numel_z}, options); + auto t2 = (aten_input + 1).to(at::kDouble).sum({1}); + auto t3 = t2 + 1; + std::vector aten_outputs = {t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue367_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Symbolic integers we will use for runtime tiling + Int* symbolic_m_tile_dim = IrBuilder::create(); + Int* symbolic_split_k_tile_dim = IrBuilder::create(); + Int* symbolic_block_k_tile_dim = IrBuilder::create(); + // Compile-time integer for tiling + int n_smem_tile = 32; + + // Symbolic 2D tensors TV0[M, K], TV1[K, N] + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + // Broadcast tv0 to [M, K, *] + TensorView* tv2 = broadcast(tv0, {false, false, true}); + // Broadcast tv1 to [*, K, N] + TensorView* tv3 = broadcast(tv1, {true, false, false}); + + // Pointwise multiplication resulting in tv3[M, K, N] + TensorView* tv4 = mul(tv2, tv3); + + // Sum the K-dim + TensorView* tv5 = sum(tv4, {1}); + + // Register inputs and outputs + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + // Register runtime tile dims as inputs + fusion.addInput(symbolic_m_tile_dim); + fusion.addInput(symbolic_split_k_tile_dim); + fusion.addInput(symbolic_block_k_tile_dim); + + // Make a 3D tile, mix of symbolic and constant, do in reverse order because + // dims are inserted + // [M, K, N] + tv5->split(2, n_smem_tile); + tv5->split(1, symbolic_block_k_tile_dim); + tv5->split(1, symbolic_split_k_tile_dim); + tv5->split(0, symbolic_m_tile_dim); + // [Mo, Mi, Koo, Koi, Ki, No, Ni] + tv5->reorder({{1, 5}, {5, 1}}); + // [Mo, No, Koo, Koi, Ki, Mi, Ni] + + auto tv6 = tv5->rFactor({2}); + auto tv7 = tv5->rFactor({2}); + // [Mo, No, rKoo, Koi, Ki, Mi, Ni] + // [Mo, No, rKoi, rKi, Mi, Ni] + + // Scope computations + tv6->computeAt(tv5, 2); + + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + tv4->computeAt(tv6, -1); + + // Cache smem tiles + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Local); + tv6->setMemoryType(MemoryType::Local); + tv7->setMemoryType(MemoryType::Local); + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::BIDy); + + std::vector tv_list = {tv2, tv3, tv4, tv5, tv6, tv7}; + for (auto tv : tv_list) { + tv->axis(-2)->parallelize(ParallelType::TIDz); + tv->axis(-1)->parallelize(ParallelType::TIDy); + } + tv2->axis(3)->parallelize(ParallelType::TIDx); + tv3->axis(3)->parallelize(ParallelType::TIDx); + tv4->axis(3)->parallelize(ParallelType::TIDx); + tv6->axis(3)->parallelize(ParallelType::TIDx); + tv7->axis(2)->parallelize(ParallelType::TIDx); + + tv2->axis(4)->parallelize(ParallelType::BIDx); + tv3->axis(4)->parallelize(ParallelType::BIDx); + tv4->axis(4)->parallelize(ParallelType::BIDx); + tv6->axis(4)->parallelize(ParallelType::BIDx); + tv7->axis(3)->parallelize(ParallelType::BIDx); + tv5->axis(2)->parallelize(ParallelType::BIDx); + + constexpr int M = 3, K = 6, N = 16; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + // A, B, m, split_k, block_k + std::vector aten_inputs = {t0, t1, 2, 2, 3}; + at::Tensor aten_output = + mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue468_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + tv1->axis(0)->parallelize(ParallelType::TIDy); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->axis(0)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({10, 100}, options); + at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}).sum({0}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue363_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Symbolic 2D tensors TV0[M, K], TV1[K, N] + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(2); + + // Broadcast tv0 to [M, K, *] + TensorView* tv2 = broadcast(tv0, {false, false, true}); + // Broadcast tv1 to [*, K, N] + TensorView* tv3 = broadcast(tv1, {true, false, false}); + + // Pointwise multiplication resulting in tv3[M, K, N] + TensorView* tv4 = mul(tv2, tv3); + + // Sum the K-dim + TensorView* tv5 = sum(tv4, {1}); + + // Register inputs and outputs + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + tv4->setMemoryType(MemoryType::Global); + + tv0->computeAt(tv5, -1); + tv1->computeAt(tv5, -1); + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::BIDy); + + tv5->axis(2)->parallelize(ParallelType::BIDx); + + constexpr int M = 3, K = 6, N = 16; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = + mul(t0.unsqueeze(2), t1.unsqueeze(0)).to(at::kDouble).sum(1); + + std::vector aten_inputs = {t0, t1}; + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue484_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = add(tv1, IrBuilder::create(0)); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Global); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + constexpr int M = 100; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({M, M}, options); + at::Tensor aten_output = aten_input.to(at::kDouble).sum({1}); + + torch::jit::fuser::cuda::FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue329_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + auto tv3 = sum(tv1, {1}); + fusion.addOutput(tv3); + + tv1->computeAt(tv2, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector t0_shape{17, 19}; + auto aten_input = at::randn(t0_shape, options); + auto t2 = (aten_input + 1).to(at::kDouble).sum({1}); + auto t3 = (aten_input + 1).to(at::kDouble).sum({1}); + std::vector aten_outputs = {t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue382_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = broadcast(tv1, {false, false, true}); + auto tv3 = makeSymbolicTensor(3); + fusion.addInput(tv3); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv2->merge(1); + tv4->merge(1); + + tv1->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + + const int numel_x = 12; + const int numel_y = 34; + const int numel_z = 56; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({numel_x, numel_y}, options); + auto t3 = at::randn({numel_x, numel_y, numel_z}, options); + + std::vector aten_inputs = {t0, t3}; + auto aten_output = (t0 + 1).unsqueeze(-1) + t3; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue507_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector t0_shape{17, 19}; + auto aten_input = at::randn(t0_shape, options); + auto t1 = (aten_input + 1); + auto aten_output = (t1 + 1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue532_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(1); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); + fusion.addInput(tv0); + fusion.addOutput(tv2); + + const int M_BLOCK = 64; + const int M_THREAD = 4; + + tv2->split(0, M_BLOCK); + // tv2: [M/M_BLOCK, M_BLOCK] + tv1->computeAt(tv2, 1); + // tv1: [M/M_BLOCK, M_BLOCK] + + tv1->split(-1, M_BLOCK / M_THREAD); + // tv1: [M/M_BLOCK, M_THREAD, M_BLOCK / M_THREAD] + + tv2->split(-1, M_THREAD); + // tv2: [M/M_BLOCK, M_BLOCK / M_THREAD, M_THREAD] + + constexpr int M = 1000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + at::Tensor aten_output = t0 + 1 + 1; + + testValidate( + &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionLoopUnswitch_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(1); + TensorView* tv1 = add(tv0, IrBuilder::create(1)); + TensorView* tv2 = add(tv1, IrBuilder::create(1)); + fusion.addInput(tv0); + fusion.addOutput(tv2); + + tv2->split(0, 32); + tv1->computeAt(tv2, -1); + + tv2->axis(1)->parallelize(ParallelType::Unswitch); + + constexpr int M = 1000; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + at::Tensor aten_output = t0 + 1 + 1; + + testValidate( + &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue549_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); // M, K + TensorView* tv1 = makeSymbolicTensor(2); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + + TensorView* tv3 = broadcast(tv2, {false, false, true}); + // tv3[I0, I1, B] = tv0[I0, I1] + + TensorView* tv4 = broadcast(tv1, {true, false, false}); + // tv4[B, I1, I2] = tv1[I1, I2] + + // tv5[I0, I1, I2] = tv3[I0, I1, B] * tv4[B, I1, I2] + TensorView* tv5 = mul(tv3, tv4); + // tv6[I0, R1, I2] = tv5[I0, I1, I2] + TensorView* tv6 = sum(tv5, {1}); + fusion.addOutput(tv6); + + tv6->split(1, 32); + // tv6[I0, R1o, R1i{32}, I2] + + auto tv7 = tv6->rFactor({1}); + // tv7[I0, R1o, I1i{32}, I2] = tv5[I0, I1, I2] + // tv6[I0, , R1i{32}, I2] = tv7[I0, R1o, I1i{32}, I2] + + tv6->split(0, 4); + tv6->split(-1, 4); + // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + + tv0->computeAt(tv6, -1); + tv1->computeAt(tv6, -1); + + // tv7[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] + // tv6[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] + //--> (line symbolizes compute at location) + // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] + // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] + // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv0->computeAt(tv7, -1); + tv1->computeAt(tv7, -1); + // tv5[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] + // tv7[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] + // tv6[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv6->axis(0)->parallelize(ParallelType::BIDz); + tv6->axis(1)->parallelize(ParallelType::TIDz); + + tv6->axis(-2)->parallelize(ParallelType::BIDy); + tv6->axis(-1)->parallelize(ParallelType::TIDy); + + tv6->axis(2)->parallelize(ParallelType::TIDx); + tv7->axis(2)->parallelize(ParallelType::TIDx); + + constexpr int M = 65, K = 33, N = 17; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + + // Lets specify a few bounds in launch params to make sure it works + LaunchParams lparams(1, -1, -1, 32, 4, 4); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); + + // Make sure bad launch params throws + // TODO: Re-enable once we have parallelization validation in. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + // ASSERT_ANY_THROW(fe.runFusion({t0, t1}, LaunchParams(1, 2, 3, 4, 5, 6))); + + // Don't specify any launch params + auto cg_outputs = fe.runFusion({t0, t1}); + + auto aten_output = (t0 + 1).to(at::kDouble).matmul(t1.to(at::kDouble)); + + testValidate( + &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleCompileRtc_CUDA) { + FusionExecutor fe; + std::string kernel = R"( +__global__ void kernel1(Tensor T0, Tensor T1) { + if(threadIdx.x==0){ + for(size_t ki28 = 0; ki28 < T0.size[0]; ++ki28) { + T1[ki28*T1.stride[0]] = T0[ki28*T0.stride[0]]*2; + } + } +} + )"; + fe.compileRtc(kernel, "CudaCodeGen::kernel1"); + LaunchParams lp( + 256, // gdimx + 1, // gdimy + 1, // gdimz + 1, // bdimx + 1, // bdimy + 1 // bdimz + ); + lp.setSmem(0); + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const std::vector tensor_dims = {8}; + auto in0 = at::randn(tensor_dims, options); + auto out0 = at::empty_like(in0); + fe.runRtc(lp, {in0, out0}); + + auto out_ref = in0 * 2; + TORCH_CHECK(out_ref.allclose(out0)); +} + +TEST_F(NVFuserTest, FusionSerialWelford_CUDA) { + FusionExecutor fe; + int x = 128, y = 64, z = 64; + + std::string kernel = R"( +__global__ void kernel1( + Tensor inp, + Tensor out_var, + Tensor out_avg +){ + for(int i0=0;i0 tensor_dims = {x, y, z}; + auto in0 = at::randn(tensor_dims, options); + auto out_var = at::empty({x}, options); + auto out_avg = at::empty({x}, options); + fe.runRtc(lp, {in0, out_var, out_avg}); + + TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var)); + TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); +} + +TEST_F(NVFuserTest, FusionBlockWelford_CUDA) { + FusionExecutor fe; + int x = 7, y = 8, z = 9; + + std::string kernel = R"( +__global__ void kernel1( + Tensor inp, + Tensor out_avg, + Tensor out_var, + Tensor init_avg, + Tensor init_var, + Tensor init_N +){ + //actual generated kernel will use dynamic shared mem, + // here is just for prototype + __shared__ float mem_avg[512]; + __shared__ float mem_M2[512]; + __shared__ long mem_N[512]; + float in=inp[threadIdx.x*inp.stride[0]+ + threadIdx.y*inp.stride[1]]; + float tmp_avg=0; + float tmp_M2=0; + long tmp_N=0; + blockWelford( + tmp_avg, + tmp_M2, + tmp_N, + in, + 0.f, + (long)1, + threadIdx, + blockDim, + (float*)mem_avg, + (float*)mem_M2, + (long*)mem_N, + (bool)(threadIdx.x tensor_dims = {x, y}; + const std::vector init_dims = {x, z}; + + // generate initial values + auto init_in = at::randn(init_dims, options); + auto init_var = init_in.var({1}, false); + auto init_avg = init_in.mean({1}); + auto init_N = + at::tensor(z, at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0)); + + auto in0 = at::randn(tensor_dims, options); + + // run kernel + auto out_var = at::zeros({x}, options); + auto out_avg = at::zeros({x}, options); + fe.runRtc(lp, {in0, out_avg, out_var, init_avg, init_var, init_N}); + + // compare with reference output + auto cat_tensor = at::cat({init_in, in0}, 1); + TORCH_CHECK(cat_tensor.var({1}, false).allclose(out_var)); + TORCH_CHECK( + cat_tensor.mean({1}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); +} + +TEST_F(NVFuserTest, FusionBlockWelfordNoInit_CUDA) { + FusionExecutor fe; + int x = 7, y = 8, z = 9; + + // need support IValue for integer input as initial count + std::string kernel = R"( +__global__ void kernel1( + Tensor inp, + Tensor out_avg, + Tensor out_var +){ + //actual generated kernel will use dynamic shared mem, + // here is just for prototype + __shared__ float mem_avg[512]; + __shared__ float mem_M2[512]; + __shared__ long mem_N[512]; + float in=inp[threadIdx.x*inp.stride[0]+ + threadIdx.y*inp.stride[1]+ + threadIdx.z*inp.stride[2]]; + float tmp_avg=0; + float tmp_M2=0; + long tmp_N=0; + block_sync::init(); + blockWelford( + tmp_avg, + tmp_M2, + tmp_N, + in, + 0.f, + (long) 1, + threadIdx, + blockDim, + (float*)mem_avg, + (float*)mem_M2, + (long*)mem_N, + (bool)(threadIdx.x tensor_dims = {x, y, z}; + auto in0 = at::randn(tensor_dims, options); + auto out_var = at::empty({x}, options); + auto out_avg = at::empty({x}, options); + fe.runRtc(lp, {in0, out_avg, out_var}); + + TORCH_CHECK(in0.var({1, 2}, false).allclose(out_var)); + TORCH_CHECK(in0.mean({1, 2}).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); +} + +TEST_F(NVFuserTest, FusionGridWelfordNoInit_CUDA) { + FusionExecutor fe; + int x = 128, y = 64, z = 128; + + std::string kernel = R"( +__global__ void kernel1( + Tensor inp, + Tensor out_avg, + Tensor out_var, + Tensor work_buf_avg, + Tensor work_buf_M2, + Tensor work_buf_N, + Tensor sync_flag +){ + __shared__ float shared_buf_avg[512]; + __shared__ float shared_buf_M2[512]; + __shared__ long shared_buf_N[512]; + float tmp_avg=0; + float tmp_M2=0; + long tmp_N=0; + float in = inp[ blockIdx.x * inp.stride[0]+ + blockIdx.y * inp.stride[1]+ + threadIdx.x * inp.stride[2]]; + block_sync::init(); + welford::gridWelford< + true,true,false, + true,false,false, + false + >( + tmp_avg, + tmp_M2, + tmp_N, + in, + 0.f, + (long) 1, + &work_buf_avg[0], + &work_buf_M2[0], + &work_buf_N[0], + sync_flag, + (float*)shared_buf_avg, + (float*)shared_buf_M2, + (long*)shared_buf_N, + threadIdx.x tensor_dims = {x, y, z}; + auto in0 = at::randn(tensor_dims, options); + + auto out_avg = at::empty({z}, options); + auto out_var = at::empty({z}, options); + auto work_buf_avg = at::empty({x * y * z}, options); + auto work_buf_var = at::empty({x * y * z}, options); + auto work_buf_N = at::empty({x * y * z}, options_int); + auto sync_flag = at::zeros({1}, options_int); + fe.runRtc( + lp, + {in0, + out_avg, + out_var, + work_buf_avg, + work_buf_var, + work_buf_N, + sync_flag}); + std::vector dims{0, 1}; + + TORCH_CHECK(in0.mean(dims).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6)); + TORCH_CHECK(in0.var(dims, false).allclose(out_var)); +} + +TEST_F(NVFuserTest, FusionWelfordOp_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = mul(tv0, IrBuilder::create(1)); + auto tvs = Welford(tv1, {1}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + fusion.addOutput(tv_N); + + tv_avg->split(1, 32); + tv_avg->split(0, 32); + tv_avg->split(0, 4); + tv_avg->reorder({{-1, -3}, {-3, -1}}); + tv1->computeAt(tv_avg, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + // by default Welford outputs sum of square diff so need to divide to get var + outputs[1] /= N; + + testValidate( + fe.kernel(), + outputs, + {t0}, + {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionBlockWelfordOp_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = mul(tv0, IrBuilder::create(1)); + auto tvs = Welford(tv1, {1}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + fusion.addOutput(tv_N); + + tv_avg->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->computeAt(tv_avg, -1); + + // + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t_var = at::empty({M}, options); + at::Tensor t_avg = at::empty({M}, options); + at::Tensor t_N = at::empty({M}, options_int); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + // by default Welford outputs sum of square diff so need to divide to get var + outputs[1] /= N; + + testValidate( + fe.kernel(), + outputs, + {t0}, + {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionGridWelfordOp_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = mul(tv0, IrBuilder::create(1)); + auto tvs = Welford(tv1, {1}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + fusion.addOutput(tv_N); + + tv_avg->axis(0)->parallelize(ParallelType::TIDx); + tv_avg->axis(-1)->parallelize(ParallelType::BIDx); + + tv1->computeAt(tv_avg, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t_avg = at::empty({M}, options); + at::Tensor t_var = at::empty({M}, options); + at::Tensor t_N = at::empty({M}, options_int); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + // by default Welford outputs sum of square diff so need to divide to get var + outputs[1] /= N; + + testValidate( + fe.kernel(), + outputs, + {t0}, + {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionRfactorWelfordOp_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = mul(tv0, IrBuilder::create(1)); + auto tvs = Welford(tv1, {1}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + fusion.addOutput(tv_N); + + tv_avg->split(1, 4); + ir_utils::rfactorHelper(tvs.avg, {2}); + tv1->computeAt(tv_avg, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + at::Tensor t_avg = at::empty({M}, options); + at::Tensor t_var = at::empty({M}, options); + at::Tensor t_N = at::empty({M}, options_int); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + // by default Welford outputs sum of square diff so need to divide to get var + outputs[1] /= N; + + testValidate( + fe.kernel(), + outputs, + {t0}, + {t0.mean({1}), t0.var({1}, false), at::ones({M}, options_int) * N}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionWelfordSchedule_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = mul(tv0, IrBuilder::create(1)); + auto tvs = Welford(tv1, {1}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + fusion.addOutput(tv_N); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + // TODO: Why do we use launch params from here, but not scheduling??? + auto reduction_params = getReductionHeuristics(&fusion, {t0}); + scheduleReduction(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}, lparams); + auto outputs = fe.runFusion({t0}, lparams); + + // by default Welford outputs sum of square diff so need to divide to get var + outputs[1] /= N; + + auto at_avg = t0.mean({1}); + auto at_var = t0.var({1}, false); + auto at_n = at::ones({M}, options_int) * N; + + testValidate( + fe.kernel(), + outputs, + {t0}, + {at_avg, at_var, at_n}, + __LINE__, + __FILE__, + "validate welford", + reduction_params->lparams); +} + +namespace { +void testWelford(DataType dtype, int red_axis, int odim, int rdim) { + const int axis = red_axis; + at::ScalarType aten_dtype = data_type_to_aten(dtype); + + Fusion fusion; + FusionGuard fg(&fusion); + TensorView* tv0 = makeSymbolicTensor(2, dtype); + bool is_fp16 = dtype == DataType::Half; + bool is_bf16 = dtype == DataType::BFloat16; + TensorView* tv0_cast = tv0; + if (is_fp16 || is_bf16) { + tv0_cast = castOp(DataType::Float, tv0); + } + fusion.addInput(tv0); + auto tv1 = mul(tv0_cast, IrBuilder::create(1)); + auto tvs = Welford(tv1, {axis}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + + TensorView* avg_cast = tv_avg; + TensorView* M2_cast = tv_M2; + + if (is_fp16) { + avg_cast = castOp(DataType::Half, tv_avg); + M2_cast = castOp(DataType::Half, tv_M2); + } + if (is_bf16) { + avg_cast = castOp(DataType::BFloat16, tv_avg); + M2_cast = castOp(DataType::BFloat16, tv_M2); + } + + fusion.addOutput(avg_cast); + fusion.addOutput(M2_cast); + fusion.addOutput(tv_N); + + auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + std::vector outputs_of_red; + at::Tensor aten_input = + (axis ? at::randn({odim, rdim}, options) + : at::randn({rdim, odim}, options)); + + if (is_fp16 || is_bf16) { + outputs_of_red.push_back(avg_cast); + outputs_of_red.push_back(M2_cast); + } + + auto reduction_params = getReductionHeuristics(&fusion, {aten_input}); + scheduleReduction(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}, lparams); + auto outputs = fe.runFusion({aten_input}, lparams); + + // by default Welford outputs sum of square diff so need to divide to + // get var + + outputs[1] /= rdim; + + auto at_avg = aten_input.mean({axis}); + auto at_var = aten_input.var({axis}, false); + auto at_n = + (axis ? at::ones({odim, rdim}, options) + : at::ones({rdim, odim}, options)); + at_n = at_n.sum({axis}); + + testValidate( + fe.kernel(), + outputs, + {aten_input}, + {at_avg, at_var, at_n}, + __LINE__, + __FILE__, + "validate welford", + reduction_params->lparams); +} +} // namespace + +TEST_F(NVFuserTest, FusionWelfordShmoo_CUDA) { + std::vector dtypes = { + DataType::Double, DataType::Float, DataType::Half}; + // TODO: enable this for complex. Currently, complex yields + // silent wrong results: + // Detected abs error of: 3.8062 + // absolute tolerance was set to 2.23704e-06 + // and relative tolerance set to 2.23704e-08 +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + if (at::cuda::getDeviceProperties(0)->major >= 8) { + dtypes.insert(dtypes.end(), DataType::BFloat16); + } +#endif + + std::vector red_axis = {1, 0}; + std::vector output_dims = {160, 320}; + std::vector red_dims; + + // Tried to cut down the number iterations with just + // doing every other power of 2. + for (int i = 1; i <= 1024 * 1024; i <<= 2) { + red_dims.push_back(i); + } + + for (auto dtype : dtypes) { + for (auto& axis : red_axis) { + for (auto& odim : output_dims) { + for (auto& rdim : red_dims) { + // TODO: original welford algorithm actually keeps a running sum of + // squares, i.e. M_{2n} in the + // cf: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + // algorithm notation, and it can reach inf for large numbers + // with half precision. skipping too large volumes for half for + // nwo might need further numerical experiments to re-design + // this. + if (rdim > 32768 && + (dtype == DataType::Half || dtype == DataType::BFloat16)) { + continue; + } + testWelford(dtype, axis, odim, rdim); + } + } + } + } +} + +namespace { +void testVarMean(at::ScalarType dtype, int correction, bool keepdim) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + int M = 64, N = 128; + + auto tv0 = makeSymbolicTensor(2, aten_to_data_type(dtype)); + fusion->addInput(tv0); + auto tvs = variance_mean(tv0, {1}, correction, keepdim); + auto tv_mean = tvs.mean; + auto tv_var = tvs.var; + fusion->addOutput(tv_var); + fusion->addOutput(tv_mean); + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + + auto at_var_mean = at::var_mean(t0, {1}, correction, keepdim); + std::vector aten_outputs = { + std::get<0>(at_var_mean), std::get<1>(at_var_mean)}; + + testValidate( + executor_cache.fusion(), outputs, {t0}, aten_outputs, __LINE__, __FILE__); +} +} // namespace + +TEST_F(NVFuserTest, FusionVarMean_CUDA) { + std::vector dtypes = {at::kFloat, at::kDouble}; + std::vector corrections = {0, 1}; + std::vector keepdims = {false, true}; + for (auto correction : corrections) { + for (auto keepdim : keepdims) { + for (auto dtype : dtypes) { + testVarMean(dtype, correction, keepdim); + } + } + } +} + +TEST_F(NVFuserTest, FusionSimpleGemmTransposed_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + + TensorView* tv0 = makeSymbolicTensor(2); // K, M + TensorView* tv1 = makeSymbolicTensor(2); // N, K + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv0_t = transpose(tv0); + TensorView* tv1_t = transpose(tv1); + + TensorView* tv2 = broadcast(tv0_t, {false, false, true}); + // tv2[I0, I1, B] = tv0[I0, I1] + + TensorView* tv3 = broadcast(tv1_t, {true, false, false}); + // tv3[B, I1, I2] = tv1[I1, I2] + + // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2] + TensorView* tv4 = mul(tv2, tv3); + // tv5[I0, R1, I2] = tv4[I0, I1, I2] + TensorView* tv5 = sum(tv4, {1}); + fusion.addOutput(tv5); + + tv5->split(1, 32); + // tv5[I0, R1o, R1i{32}, I2] + + auto tv6 = tv5->rFactor({1}); + // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2] + // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2] + + tv5->split(0, 4); + tv5->split(-1, 4); + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}] + + tv0_t->computeAt(tv5, -1); + tv1_t->computeAt(tv5, -1); + + // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}] + // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}] + //--> (line symbolizes compute at location) + // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o] + // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv0_t->computeAt(tv6, -1); + tv1_t->computeAt(tv6, -1); + // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |] + // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |] + // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|] + + tv5->axis(0)->parallelize(ParallelType::BIDz); + tv5->axis(1)->parallelize(ParallelType::TIDz); + + tv5->axis(-2)->parallelize(ParallelType::BIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + + tv5->axis(2)->parallelize(ParallelType::TIDx); + tv6->axis(2)->parallelize(ParallelType::TIDx); + + constexpr int M = 65, K = 33, N = 17; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({K, M}, options); + at::Tensor t1 = at::randn({N, K}, options); + + // Lets specify a few bounds in launch params to make sure it works + LaunchParams lparams(1, -1, -1, 32, 4, 4); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + fe.runFusion({t0, t1}, lparams); + + // Don't specify any launch params + auto cg_outputs = fe.runFusion({t0, t1}); + + auto aten_output = t0.t().to(at::kDouble).matmul(t1.t().to(at::kDouble)); + + testValidate( + &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSoftmax3DTransposed_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int tidx = 32; + const int dimx = 32; + const int dimy = 16; + const int dimz = 130; + + // Set up your input tensor views + TensorView* input_tv0 = makeSymbolicTensor(3); + fusion.addInput(input_tv0); + + TensorView* input_t = transpose(input_tv0, 1, 2); + + TensorView* exp_tv1 = unaryOp(UnaryOpType::Exp, input_t); + TensorView* sum_exp_tv2 = sum(exp_tv1, {-1}); + TensorView* bcast_sum_tv3 = broadcast(sum_exp_tv2, {false, false, true}); + + // Replicate exp_tv4 as exp_tv4_copy because exp_tv4 is going to be + // computed at sum_exp_rf_tv8. + TensorView* input_t_copy = transpose(input_tv0, 1, 2); + TensorView* exp_tv1_copy = unaryOp(UnaryOpType::Exp, input_t_copy); + + TensorView* output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); + + fusion.addOutput(output_tv4); + + bcast_sum_tv3->split(-1, tidx); + + sum_exp_tv2->split(-1, tidx); + TensorView* sum_exp_rf_tv5 = sum_exp_tv2->rFactor({-2}); + + output_tv4->split(-1, tidx); + + input_t->computeAt(sum_exp_rf_tv5, -1); + input_t_copy->computeAt(output_tv4, -1); + + TensorView* tensors_to_parallelize[] = { + sum_exp_tv2, bcast_sum_tv3, output_tv4, sum_exp_rf_tv5}; + + for (auto tv : tensors_to_parallelize) { + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({dimx, dimz, dimy}, options); + + at::Tensor cg_output = at::empty({dimx, dimy, dimz}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_input_t = at::transpose(input, 1, 2); + auto aten_output = at::_softmax(aten_input_t.to(at::kDouble), -1, false); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed1_CUDA) { + // Case 1 + // tv1 = tv0 * 0.5 + // tv2 = tv1 * -1 + // tv3 = tv1 + 3 + // tv4 = tv1 * 2 + // tv5 = tv3 + tv2 + // tv6 = tv5 + tv4 + // tv7 = tv1 + tv4 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + tv0 = transpose(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(0.5)); + TensorView* tv2 = mul(tv1, IrBuilder::create(-1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(3.0)); + TensorView* tv4 = mul(tv1, IrBuilder::create(2.0)); + TensorView* tv5 = add(tv3, tv2); + + TensorView* tv6 = add(tv5, tv4); + TensorView* tv7 = add(tv1, tv4); + + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + // Lets setup to actually run + tv7->merge(0); + tv7->split(0, 128); + tv7->split(0, 4); + + tv7->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv7, 1); + + // The this-position of the last tensor should be zero. + TORCH_CHECK( + tv7->nDims() == 3 && tv7->getComputeAtPosition() == 0 && + tv7->getMaxProducerPosition() == 1); + TORCH_CHECK( + tv6->nDims() == 3 && tv6->getComputeAtPosition() == 0 && + tv6->getMaxProducerPosition() == 1); + // The position of every other tensor should be 1. + for (auto tv : {tv1, tv2, tv3, tv4, tv5}) { + TORCH_CHECK(tv->nDims() == 3 && tv->getComputeAtPosition() == 1); + } + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::randn({129, 127}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + at::Tensor aten_input_t = aten_input.t(); + + auto t1 = aten_input_t.mul({0.5}); + auto t2 = t1.mul({-1.0}); + auto t3 = t1.add({3.0}); + auto t4 = t1.mul({2.0}); + auto t5 = t3.add(t2); + auto t6 = t5.add(t4); + auto t7 = t1.add(t4); + + std::vector aten_outputs = {t6, t7}; + + testValidate( + &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed2_CUDA) { + // Case 2 + // tv1 = tv0 * -1 + // tv2 = tv0 + 3 + // tv3 = tv0 * 2 + // tv4 = tv2 + tv1 + // tv5 = tv4 + tv3 + // tv6 = tv5 + tv3 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + tv0 = transpose(tv0); + + TensorView* tv1 = mul(tv0, IrBuilder::create(-1.0)); + TensorView* tv2 = add(tv0, IrBuilder::create(3.0)); + TensorView* tv3 = mul(tv0, IrBuilder::create(2.0)); + TensorView* tv4 = add(tv2, tv1); + + TensorView* tv5 = add(tv4, tv3); + TensorView* tv6 = add(tv5, tv3); + + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + // Lets setup to actually run + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv6, 1); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({129, 127}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + auto input_t = input.t(); + auto t1 = input_t.mul({-1.0}); + auto t2 = input_t.add({3.0}); + auto t3 = input_t.mul({2.0}); + auto t4 = t2.add(t1); + auto t5 = t4.add(t3); + auto t6 = t5.add(t3); + + std::vector aten_outputs = {t5, t6}; + + testValidate(&fusion, cg_outputs, {input}, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed3_CUDA) { + // Case 3 + // T2 = T1 * 0.979361 + // T3 = T2 * T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + tv0 = permute(tv0, {3, 0, 1, 2}); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + tv1 = permute(tv1, {3, 0, 1, 2}); + + TensorView* tv2 = mul(tv1, IrBuilder::create(.979361)); + TensorView* tv3 = mul(tv2, tv0); + + fusion.addOutput(tv3); + + // Lets setup to actually run + while (tv3->nDims() > 1) + tv3->merge(0); + tv3->split(0, 128); + tv3->split(0, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t0_t = t0.permute({3, 0, 1, 2}); + auto t1_t = t1.permute({3, 0, 1, 2}); + auto t2 = t1_t.mul({0.979361}); + auto aten_output = t2.mul(t0_t); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed4_CUDA) { + // Case 4 + // T4 = T2 - T3 + // T5 = T1 + T4 + // T6 = T5 - T0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(4); + fusion.addInput(tv0); + + tv0 = permute(tv0, {3, 0, 1, 2}); + + TensorView* tv1 = makeSymbolicTensor(4); + fusion.addInput(tv1); + + tv1 = permute(tv1, {3, 0, 1, 2}); + + TensorView* tv2 = makeSymbolicTensor(4); + fusion.addInput(tv2); + + tv2 = permute(tv2, {3, 0, 1, 2}); + + TensorView* tv3 = makeSymbolicTensor(4); + fusion.addInput(tv3); + + tv3 = permute(tv3, {3, 0, 1, 2}); + + TensorView* tv4 = sub(tv2, tv3); + TensorView* tv5 = add(tv1, tv4); + TensorView* tv6 = sub(tv5, tv0); + + fusion.addOutput(tv6); + + // Lets setup to actually run + while (tv6->nDims() > 1) + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + + tv0->computeAt(tv6, 1); + tv1->computeAt(tv6, 1); + tv2->computeAt(tv6, 1); + tv3->computeAt(tv6, 1); + + tv6->axis(0)->parallelize(ParallelType::BIDx); + + for (Val* val : fusion.vals()) { + if (!val->isFusionInput() && + val->getValType().value() == ValType::TensorView) { + TensorView* tv = static_cast(val); + + tv->axis(1)->parallelize(ParallelType::Unroll); + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({129, 127, 63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + at::Tensor t2 = at::rand_like(t0, options); + at::Tensor t3 = at::rand_like(t0, options); + + std::vector aten_inputs = {t0, t1, t2, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t0_t = t0.permute({3, 0, 1, 2}); + auto t1_t = t1.permute({3, 0, 1, 2}); + auto t2_t = t2.permute({3, 0, 1, 2}); + auto t3_t = t3.permute({3, 0, 1, 2}); + auto t4 = t2_t.sub(t3_t); + auto t5 = t1_t.add(t4); + auto aten_output = t5.sub(t0_t); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed5_CUDA) { + // Case 5 + // tv2 = tv0 + 2.0 + // tv3 = tv1 * tv2 + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + tv0 = transpose(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + tv1 = transpose(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv3->merge(0); + tv3->split(-1, 8); + tv3->split(-1, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t2 = t0.t().add(2.0); + auto aten_output = t1.t().mul(t2); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionAdvancedComputeAtTransposed6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + tv0 = transpose(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + tv1 = transpose(tv1); + TensorView* tv2 = add(tv0, IrBuilder::create(2.0)); + TensorView* tv3 = mul(tv1, tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(-1, 8); + tv2->split(-1, 4); + tv3->merge(0); + tv3->split(-1, 8); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({63, 65}, options); + at::Tensor t1 = at::rand_like(t0, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t2 = t0.t().add(2.0); + auto aten_output = t1.t().mul(t2); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSegmentReducePointwise_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(1); + TensorView* tv2 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 + TensorView* tv4 = + max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) + TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, + // keeps normalization scheduler away) + TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) + + fusion->addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, 65}, options); + at::Tensor t1 = at::randn({65}, options); + at::Tensor t2 = at::randn({128, 65}, options); + + auto t3 = t0.add(1.0); + auto t4 = std::get<0>(at::max(t3, 0)); + auto t5 = t4.add(t1); + auto t6 = t5.add(t2); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation didn't happen"); + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime() + ->fusionSegments() + ->groups() + .size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1, t2}, {t6}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMultipleVectorize_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(1); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv3 = add(tv0, tv1); + fusion->addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({40960}, options); + at::Tensor t1 = at::randn({40960}, options); + auto t2 = t0 + t1; + + FusionExecutorCache executor_cache(std::move(fusion)); + executor_cache.profile(true); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1}); + auto runtime1 = executor_cache.getMostRecentKernelRuntime(); + auto log1 = + executor_cache.getMostRecentExecutorInfo().params->as(); + TORCH_CHECK(log1 != nullptr); + TORCH_CHECK(log1->vectorize); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); + + t0 = at::randn({40964}, options); + t1 = at::randn({40964}, options); + t2 = t0 + t1; + + outputs = executor_cache.runFusionWithInputs({t0, t1}); + auto runtime2 = executor_cache.getMostRecentKernelRuntime(); + auto log2 = + executor_cache.getMostRecentExecutorInfo().params->as(); + TORCH_CHECK(log2 != nullptr); + TORCH_CHECK(log2->vectorize); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); + + t0 = at::randn({40962}, options); + t1 = at::randn({40962}, options); + t2 = t0 + t1; + + outputs = executor_cache.runFusionWithInputs({t0, t1}); + auto runtime3 = executor_cache.getMostRecentKernelRuntime(); + auto log3 = + executor_cache.getMostRecentExecutorInfo().params->as(); + TORCH_CHECK(log3 != nullptr); + TORCH_CHECK(log3->vectorize); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1}, {t2}, __LINE__, __FILE__); + + TORCH_CHECK(runtime1 == runtime2); + TORCH_CHECK(runtime1 != runtime3); +} + +TEST_F(NVFuserTest, FusionVectorizeSimple_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(3); + + fusion.addInput(tv0); + + auto tv1 = unaryOp(UnaryOpType::Sin, tv0); + + fusion.addOutput(tv1); + + auto tv0_cache = tv0->cacheAfter(); + + auto tv1_cache = tv1->cacheBefore(); + + tv1->merge(0); + tv1->merge(0); + tv1->split(0, 4); + tv1->split(0, 128); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv1, 2); + + tv0_cache->axis(2)->parallelize(ParallelType::Vectorize); + tv1->axis(2)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor aten_input = at::empty({2, 6, 32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {aten_input}); + auto cg_outputs = fe.runFusion({aten_input}); + + at::Tensor aten_output = aten_input.sin(); + + testValidate( + &fusion, cg_outputs, {aten_input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // dimensionality of the problem + int nDims = 3; + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(nDims); + TensorView* tv1 = makeContigTensor(nDims); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv2 = add(tv1, IrBuilder::create(2.0)); + TensorView* tv3 = add(tv0, tv2); + + // Register your outputs + fusion.addOutput(tv3); + + auto tv0_cache = tv0->cacheAfter(); + auto tv1_cache = tv1->cacheAfter(); + auto tv3_cache = tv3->cacheBefore(); + + // Do transformations, remember, transformations are outputs to inputs + // This doesn't have to be in this order + tv3->merge(1); + + // Split by n_threads + tv3->split(1, 2); + tv3->split(0, 3); + tv3->split(0, 1); + + // [bidx, unswitch, unroll{2}, tidx, vectorize{2}] + + // Parallelize TV3 + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::Unswitch); + tv3->axis(2)->parallelize(ParallelType::Unroll); + tv3->axis(3)->parallelize(ParallelType::TIDx); + + tv3->reorder({{4, 2}}); + // [bidx, unswitch, vectorize{2}, unroll{2}, tidx] + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv3); + + tv0_cache->axis(2)->parallelize(ParallelType::Vectorize); + tv1_cache->axis(2)->parallelize(ParallelType::Vectorize); + tv3->axis(2)->parallelize(ParallelType::Vectorize); + + // For all inputs, computeAt the output inline, temporaries should be squeezed + // between them + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + tv1->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({64, 2, 128}, options); + at::Tensor input2 = at::rand_like(input1); + at::Tensor output = at::empty_like(input1); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + fe.runFusion({input1, input2}, {output}); + + at::Tensor tv2_ref = input2 + 2.0; + at::Tensor output_ref = input1 + tv2_ref; + + TORCH_CHECK(output_ref.equal(output)); +} + +TEST_F(NVFuserTest, FusionSegmentReduceSoftmax_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + std::vector input_shape{32, 64, 8}; + const int kReductionAxis = 1; + + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + + fusion->addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = sum(tv1, {2}); // Group 0 + + auto output = softmax(tv2, kReductionAxis); // Group 1 + fusion->addOutput(output); + + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + at::Tensor at_x = at::randn(input_shape, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto outputs = executor_cache.runFusionWithInputs({at_x}); + + auto t1 = at_x.add(1.0); + auto t2 = t1.sum({2}); + auto t3 = at::_softmax(t2.to(at::kDouble), -1, false); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen"); + TORCH_CHECK( + optimized_fusion->fusionSegments()->groups().size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, {at_x}, {t3}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSwizzle1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + fusion.addOutput(tv2); + + tv2->split(0, 7); + tv2->split(0, 9); + + tv0->computeAt(tv2, 1); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + + tv1->setMemoryType(MemoryType::Shared); + tv1->swizzle(SwizzleType::Transpose, {1, 2}); + + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::TIDy); + + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100}, options); + + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = (t0 + 1) * 2; + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSwizzle2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + fusion.addOutput(tv2); + + tv1->split(-1, 4); + tv1->split(-2, 4); + + tv2->split(-1, 4); + tv2->split(-2, 4); + + tv0->computeAt(tv2, 1); + + tv2->reorder({{-1, -2}}); + + tv1->setMemoryType(MemoryType::Shared); + tv1->swizzle(SwizzleType::Transpose, {-2, -1}); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDy); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-2)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({123}, options); + + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = (t0 + 1) * 2; + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridPersistence_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = broadcast(tv1, {true}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + std::vector tvs = {tv1, tv2, tv3}; + for (auto tv : tvs) { + tv->split(0, 2); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + } + + const int numel_x = 10; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = input.sum({0}).unsqueeze(-1).add(input); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridPersistence2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = broadcast(tv1, {true, false}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + std::vector tvs = {tv1, tv2, tv3}; + for (auto tv : tvs) { + tv->split(0, 2); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::TIDy); + tv->axis(2)->parallelize(ParallelType::TIDx); + } + + const int numel_x = 10; + const int numel_y = 3; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = input.sum({0}).unsqueeze(0).add(input); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWelfordPersistence_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {0}); + auto tv4 = add(tvs.avg, tvs.var_sum); + auto tv5 = broadcast(tv4, {true}); + auto tv6 = add(tv0, tv5); + fusion.addOutput(tv6); + + std::vector schedule_tvs = { + tvs.avg, tvs.var_sum, tvs.n, tv5, tv6}; + + for (auto tv : schedule_tvs) { + tv->split(0, 2); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::BIDy); + } + + const int numel_x = 10; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) + .unsqueeze(-1) + .add(input); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWelfordPersistence2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {0}); + auto tv4 = add(tvs.avg, tvs.var_sum); + auto tv5 = broadcast(tv4, {true, false}); + auto tv6 = add(tv0, tv5); + fusion.addOutput(tv6); + + std::vector schedule_tvs = { + tvs.avg, tvs.var_sum, tvs.n, tv5, tv6}; + for (auto tv : schedule_tvs) { + tv->split(0, 2); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::TIDy); + tv->axis(2)->parallelize(ParallelType::TIDx); + } + tv4->axis(0)->parallelize(ParallelType::TIDx); + + const int numel_x = 10; + const int numel_y = 3; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto out = fe.runFusion({input}); + + auto aten_output = (input.mean({0}) + (input.var({0}, false) * numel_x)) + .unsqueeze(0) + .add(input); + + testValidate(&fusion, out, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue633_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int dx = 10; + const int dy = 11; + const int dz = 12; + + auto tv0 = makeConcreteTensor({dx, dy, dz}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({dx, dy, 1}); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->merge(1); + tv2->merge(0); + tv2->split(-1, 128); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dx, dy, dz}, options); + at::Tensor t1 = at::randn({dx, dy, 1}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastAcrossComputeAt_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{17, 19}; + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv3->split(1, 128); + tv0->computeAt(tv3, 2); + + for (auto tv : {tv2, tv3}) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({shape[0]}, options); + at::Tensor t1 = at::randn(shape, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto t3 = t0.unsqueeze(-1).expand(shape) + t1; + + testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwise_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + const int kTDX = 64; + const int kVecSize = 4; + const int kNumElems = kTDX * kVecSize; + + tv2->split(1, kNumElems); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + tv2->split(-1, kVecSize); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 457; + at::Tensor t0 = at::randn({bx, by}, options); + at::Tensor t1 = at::randn({bx, by}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeContig_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(4); + auto tv1 = makeContigTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->reorder({{0, 1}, {1, 0}}); + tv2->merge(-2); + + const int kTDX = 64; + const int kVecSize = 2; + const int kNumElems = kTDX * kVecSize; + + tv2->split(-1, kNumElems); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + tv2->split(0, 128); + tv2->split(-1, kVecSize); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int n = 32; + const int c = 127; + const int h = 51; + const int w = 23; + at::Tensor t0 = at::randn({n, c, h, w}, options); + at::Tensor t1 = at::randn({n, c, h, w}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicPass_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int kNumDims = 4; + constexpr int kTDX = 64; + constexpr int kVecSize = 2; + constexpr int kNumElems = kTDX * kVecSize; + + auto tv0 = makeSymbolicTensor(kNumDims); + auto tv1 = makeSymbolicTensor(kNumDims); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + // Create caches for vectorization + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + // Merge all dimensions together except inner-most dim + for (const auto idx : c10::irange(kNumDims - 2)) { + tv2->merge(0); + } + // Split inner-most dim + tv2->split(-1, kNumElems); + tv2->split(-1, kVecSize); + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + // Parallelization Strategy + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int n = 5; + const int c = 3; + const int h = 51; + const int w = 257; + at::Tensor t0 = at::randn({n, c, h, w}, options); + at::Tensor t1 = at::randn({n, c, h, w}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedPointwiseMergeSymbolicFail_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int kNumDims = 4; + constexpr int kTDX = 64; + constexpr int kVecSize = 2; + constexpr int kNumElems = kTDX * kVecSize; + std::vector bcast_shape{1, 1, 1, -1}; + + auto tv0 = makeContigTensor(kNumDims); + auto tv1 = TensorViewBuilder().shape(bcast_shape).build(); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + // Create caches for vectorization + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + // Merge all dimensions together + // Backward merge order is necessary for vectorize validation + for (int idx = kNumDims - 1; idx > 0; --idx) { + tv2->merge(idx - 1); + } + tv2->split(-1, kNumElems); + tv2->split(-1, kVecSize); + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + // Parallelization Strategy + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int n = 32; + const int c = 128; + const int h = 51; + const int w = 23; + at::Tensor t0 = at::randn({n, c, h, w}, options); + at::Tensor t1 = at::randn({1, 1, 1, w}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + // TODO: throw assertion - cannot merge non-contiguous vectorization axes + // Make sure compilation fails + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedRFactor_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + auto tv1 = makeContigTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + + auto tv3 = sum(tv2, {-1}); + + fusion.addOutput(tv3); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + + tv3->split(-1, 128 * 4); + tv3->split(-1, 4); + // Reduce outer dim first + auto tv4 = tv3->rFactor({-3, -1}); + // Tv3 will reduce threads + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv4, -2); + tv1->computeAt(tv4, -2); + + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv4->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + tv2->computeAt(tv4, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2050; + at::Tensor t0 = at::randn({bx, by}, options); + at::Tensor t1 = at::randn({bx, by}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0.add(t1).sum(1); + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedWrongDimFail_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + auto tv1 = makeContigTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->split(1, 16); + tv2->split(1, 64); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + std::vector vectorized_tvs = {c0, c1, tv2}; + for (auto tv : vectorized_tvs) { + tv->split(-1, 4); + // Vectorize the wrong dimension + tv->axis(-2)->parallelize(ParallelType::MisalignedVectorize); + } + + FusionExecutor fe; + // Make sure compilation fails + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedStride_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + const int kTDX = 64; + const int kVecSize = 4; + const int kNumElems = kTDX * kVecSize; + + tv2->split(1, kNumElems); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + + tv2->split(-1, kVecSize); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2049; + at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); + at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeMisalignedStrideFail_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + const int kTDX = 64; + const int kVecSize = 4; + const int kNumElems = kTDX * kVecSize; + + tv2->split(1, kNumElems); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + tv2->split(-1, kVecSize); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + c0->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + c1->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::MisalignedVectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2049; + at::Tensor t0 = at::randn({bx, by}, options).index({"...", Slice(3)}); + at::Tensor t1 = at::randn({bx, by}, options).index({"...", Slice(3)}); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + + // Failure because the input + output tensors do not have the same stride + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); +} + +TEST_F(NVFuserTest, FusionVectorization1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->split(1, 16); + tv2->split(1, 64); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + std::vector vectorized_tvs = {c0, c1, tv2}; + for (auto tv : vectorized_tvs) { + tv->split(-1, 4); + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2048; + at::Tensor t0 = at::randn({bx, by}, options); + at::Tensor t1 = at::randn({bx, by}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorization2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->split(1, 16); + tv2->split(1, 64); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + std::vector vectorized_tvs = {c0, c1, tv2}; + for (auto tv : vectorized_tvs) { + tv->split(-1, 4); + // Vectorize the wrong dimension + tv->axis(-2)->parallelize(ParallelType::Vectorize); + } + + FusionExecutor fe; + // Make sure compilation fails + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionVectorization3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + tv2->split(1, 16); + tv2->split(1, 64); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + + auto c0 = tv0->cacheAfter(); + auto c1 = tv1->cacheAfter(); + auto c2 = tv2->cacheBefore(); + + c0->computeAt(tv2, -2); + c1->computeAt(tv2, -2); + + std::vector vectorized_tvs = {c0, c1, tv2}; + for (auto tv : vectorized_tvs) { + tv->split(-1, 4); + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2049; + at::Tensor t0 = at::randn({bx, by}, options); + at::Tensor t1 = at::randn({bx, by}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); + + aten_inputs[0] = t0.index({"...", Slice(1)}); + aten_inputs[1] = t1.index({"...", Slice(1)}); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion(aten_inputs)); + + t0 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); + t1 = at::randn({bx, 2048}, options).index({"...", Slice(4)}); + aten_inputs = {t0, t1}; + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0 + t1; + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizationRFactor_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + + auto tv3 = sum(tv2, {-1}); + + fusion.addOutput(tv3); + + tv3->split(-1, 128 * 4); + tv3->split(-1, 4); + // Reduce outer dim first + auto tv4 = tv3->rFactor({-3, -1}); + // Tv3 will reduce threads + + auto tv6 = tv0->cacheAfter(); + auto tv7 = tv1->cacheAfter(); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + + tv0->computeAt(tv4, -2); + tv1->computeAt(tv4, -2); + + tv6->axis(-1)->parallelize(ParallelType::Vectorize); + tv7->axis(-1)->parallelize(ParallelType::Vectorize); + + tv4->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const int bx = 128; + const int by = 2048; + at::Tensor t0 = at::randn({bx, by}, options); + at::Tensor t1 = at::randn({bx, by}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto aten_output = t0.add(t1).sum(1); + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + + auto t3 = t0.add(t1).sum(1); + + testValidate(&fusion, cg_outputs, aten_inputs, {t3}, __LINE__, __FILE__); +} + +// Unswitched loops with extent one may omit else clause. +TEST_F(NVFuserTest, FusionSizeOneLoop1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Progressively broadcast tensors + TensorView* tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + TensorView* tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + TensorView* tv2 = makeSymbolicTensor(3); + fusion.addInput(tv2); + + TensorView* tv3 = broadcast(tv0, {false, true}); + TensorView* tv4 = add(tv3, tv1); + TensorView* tv5 = add(tv4, tv2); + + fusion.addOutput(tv5); + + // Split inner dimension + tv5->split(1, 8); + // Merge middle dims with outer dimensions + tv5->merge(2); + tv5->merge(0); + + // tv5[I0*I1o, I1i*I2] + // Get a dim of size 1 to unswitch + tv5->split(0, 1, false); + + // Compute everything inline + tv0->computeAt(tv5, -1); + + tv5->axis(0)->parallelize(ParallelType::Unswitch); + tv5->axis(1)->parallelize(ParallelType::BIDx); + tv5->axis(2)->parallelize(ParallelType::TIDx); + + // Make sure the unswitched loop does not have an else clause. + GpuLower gpulw(&fusion); + TORCH_CHECK(!UnswitchInElseChecker::check(gpulw)); + + const int x = 11; + const int y = 12; + const int z = 13; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x}, options); + at::Tensor t1 = at::randn({x, y}, options); + at::Tensor t2 = at::randn({z, x, y}, options); + std::vector aten_inputs = {t0, t1, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + auto t6 = (t0.unsqueeze(-1) + t1).unsqueeze(0) + t2; + + testValidate(&fusion, cg_outputs, aten_inputs, {t6}, __LINE__, __FILE__); +} + +// The unswitched loop has extent one but inner loops don't. The else +// part should not be omitted. +TEST_F(NVFuserTest, FusionSizeOneLoop2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int x = 15; + auto tv0 = makeConcreteTensor({x}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv1); + + tv1->split(-1, 4); + tv1->split(-2, 1); + + tv1->axis(-2)->parallelize(ParallelType::Unswitch); + + // Make sure the size-one unswitched loop does not omit the else clause. + GpuLower gpulw(&fusion); + TORCH_CHECK(UnswitchInElseChecker::check(gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + auto t1 = t0 + 1; + + testValidate(&fusion, cg_outputs, aten_inputs, {t1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionValidateParallelize1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + + // Invalid as tv1 and tv2 do have the same ParallelType + FusionExecutor fe; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionValidateParallelize2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + tv1->setMemoryType(MemoryType::Shared); + + // tv1 and tv2 do have the same ParallelType, but tv1 is on shared + // memory, so it is valid + FusionExecutor fe; + fe.compileFusion(&fusion); +} + +TEST_F(NVFuserTest, FusionValidateParallelize3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->split(-1, 4); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->split(-1, 4); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->setMemoryType(MemoryType::Global); + + // tv1 and tv2 have the same shape and ParallelType + FusionExecutor fe; + fe.compileFusion(&fusion); +} + +TEST_F(NVFuserTest, FusionValidateParallelize4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->split(-1, 4); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->split(-1, 8); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->setMemoryType(MemoryType::Global); + + // tv1 and tv2 do not have the same shape but global memory comm is supported. + FusionExecutor fe; + fe.compileFusion(&fusion); +} + +TEST_F(NVFuserTest, FusionValidateParallelize5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv1->split(-1, 4); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->setMemoryType(MemoryType::Shared); + + tv2->split(-1, 8); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + // tv1 and tv2 do not have the same shape, but tv1 is on shared + // memory, so it is valid + FusionExecutor fe; + fe.compileFusion(&fusion); +} + +// See issue #995 +TEST_F(NVFuserTest, FusionValidateParallelize6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int64_t W = 5, X = 6, Y = 7, Z = 8; + + auto tv0 = makeConcreteTensor({X, Y, Z}); + auto tv1 = makeConcreteTensor({W, X, Y, Z}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = broadcast(tv2, {true, false, false, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->merge(0); + tv4->merge(0); + tv4->split(0, 4); + tv4->split(0, 3); + tv4->split(0, 2); + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv0->computeAt(tv2, 2); + tv3->computeAt(tv4, 2); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + // Validation should throw an exception saying the first axes of tv2 + // and tv3 have incompatible parallelization. See also issue #995. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fusion.printKernel()); +} + +// Repro of #2046 +TEST_F(NVFuserTest, FusionValidateParallelize7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv1); + fusion.addOutput(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Global); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->axis(1)->parallelize(ParallelType::TIDy); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + // tv2 uses tv1 but is not parallelized with BIDx, so a grid sync is + // required. It should be placed as a top-level expression. + + GpuLower gpulw(&fusion); + TORCH_CHECK( + std::any_of( + gpulw.kernel()->topLevelExprs().begin(), + gpulw.kernel()->topLevelExprs().end(), + [](Expr* expr) { return expr->isA(); }), + "Grid sync not found"); +} + +TEST_F(NVFuserTest, FusionDAGMerging_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(5); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // Branch 0 + auto tv2 = sum(tv0, {0}); // 0 + auto tv3 = sum(tv2, {0}); // 1 + auto tv4 = sum(tv3, {0}); // 2 + auto tv5 = sum(tv4, {0}); // 3 + + // Branch 1 + auto tv6 = add(tv1, IrBuilder::create(1)); // 4 + + // Merge + auto tv7 = add(tv6, tv5); // 5 + + // Maximum expected output groups (can improve overtime): + // {0}, {1}, {2}, {3,4,5} + // without final merge would have been {0}, {1}, {2}, {3,4}, {5} + + fusion.addOutput(tv7); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 2, 2, 2}, options); + at::Tensor t1 = at::randn({2}, options); + + std::vector aten_inputs = {t0, t1}; + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(aten_inputs); + + auto fusion_segments = fusion.segment(args); + TORCH_CHECK(fusion_segments->groups().size() <= 4); +} + +TEST_F(NVFuserTest, FusionDAGScalarMerging_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + auto i0 = IrBuilder::create(); + + fusion->addInput(tv0); + fusion->addInput(i0); + + auto i1 = add(i0, IrBuilder::create(1.0)); + auto i2 = mul(i1, i1); + auto i3 = add(i2, i1); + + // Branch 0 + auto tv1 = sum(tv0, {0}); // 0 + auto tv2 = add(tv1, i2); + // Branch 1 + auto tv3 = sum(tv2, {0}); // 1 + auto tv4 = add(tv3, i3); + + auto tv5 = add(tv4, i0); + + fusion->addOutput(tv5); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({16, 16, 16}, options); + double s0 = 0.5; + + auto s1 = s0 + 1.0; + auto s2 = s1 * s1; + auto s3 = s2 + s1; + auto t1 = t0.sum({0}); + auto t2 = t1 + s2; + auto t3 = sum(t2, {0}); + auto t4 = t3 + s3; + auto t5 = t4 + s0; + + auto outputs = executor_cache.runFusionWithInputs({t0, s0}); + + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation didn't happen"); + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime() + ->fusionSegments() + ->groups() + .size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, {t0, s0}, {t5}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBlockReduceInSerialLoop_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int M = 10; + constexpr int N = 20; + constexpr int K = 20; + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = sum(tv0, {{1, 2}}); + fusion.addInput(tv0); + fusion.addOutput(tv1); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N, K}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + at::Tensor aten_output = t0.sum({1, 2}); + testValidate( + &fusion, outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBlockWelfordInSerialLoop_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int M = 10; + constexpr int N = 20; + constexpr int K = 20; + + auto tv0 = makeSymbolicTensor(3); + auto tvs = Welford(tv0, {{1, 2}}); + fusion.addInput(tv0); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + + tv_avg->axis(-1)->parallelize(ParallelType::TIDx); + tv_avg->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M, N, K}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + at::Tensor aten_avg = t0.mean({1, 2}); + at::Tensor aten_M2 = t0.var({1, 2}, false) * N * K; + testValidate( + &fusion, outputs, aten_inputs, {aten_avg, aten_M2}, __LINE__, __FILE__); +} + +// See Issue #716 +TEST_F(NVFuserTest, FusionIOTensorTrivialReductionRepro_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int M = 10; + constexpr int N = 11; + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + std::vector reduction_axes = {1}; + std::vector broadcast_mask = {false, true}; + + auto tv0_bcast = broadcast(tv0, broadcast_mask); + auto path1_bcast = add(tv0_bcast, IrBuilder::create(1.0)); + auto path1 = sum(path1_bcast, reduction_axes); + fusion.addOutput(path1); + + auto p = path1->split(1, 1); + path1->rFactor({1}); + path1->axis(0)->parallelize(ParallelType::BIDx); + tv0->computeAt(path1, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({M}, options); + at::Tensor t0_ref = t0.clone(); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + + // inplace op, we are adding t0 to itself + auto outputs = fe.runFusion(aten_inputs, {t0}); + + TORCH_CHECK(outputs[0].allclose(t0_ref.add(1))); +} + +TEST_F(NVFuserTest, FusionReductionPredicate_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + auto tv2 = tv0->cacheAfter(); + + const int bdimx = 128; + tv1->split(1, bdimx); + tv1->split(1, 4); + tv1->split(1, 1); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::Unroll); + tv1->split(0, 10); + tv0->computeAt(tv1, 4); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 650; + int numel_y = 102; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({numel_x, numel_y}, options); + at::Tensor cg_output = at::empty({numel_y}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + fe.runFusion({input}, {cg_output}); + + auto aten_output = input.to(at::kDouble).sum({0}); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue728_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addOutput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion.addOutput(tv1); + auto tv2 = makeSymbolicTensor(1); + fusion.addOutput(tv2); + + auto tv3 = add(tv0, IrBuilder::create(1)); + auto tv4 = add(tv3, tv1); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + // tv0 -> tv3 -+ + // tv1 --------+-> tv4 -> tv5 + // + // tv2 -> tv6 + + auto all_vals_under_tv3 = + DependencyCheck::getAllValsBetween({tv3}, fusion.outputs()); + std::unordered_set included_tensors({tv3, tv4, tv5}); + for (auto tv : included_tensors) { + TORCH_CHECK( + std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) != + all_vals_under_tv3.end(), + "TV", + tv->name(), + " not found"); + } + for (auto tv : ir_utils::filterByType(fusion.vals())) { + if (included_tensors.find(tv) == included_tensors.end()) { + TORCH_CHECK( + std::find(all_vals_under_tv3.begin(), all_vals_under_tv3.end(), tv) == + all_vals_under_tv3.end(), + "TV", + tv->name(), + " should not be found"); + } + } + + auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs()); + TORCH_CHECK(no_dependency.empty(), "No val should be returned"); + + auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6}); + TORCH_CHECK(no_dep_path.empty(), "No val should be returned"); + + auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5}); + TORCH_CHECK(no_dep_path2.empty(), "No val should be returned"); + + auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3}); + TORCH_CHECK( + just_tv3.size() == 1 && *(just_tv3.begin()) == tv3, + "Only tv3 should be included"); +} + +TEST_F(NVFuserTest, FusionIssue757_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv1->computeAt(tv4, -1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 650; + int numel_y = 102; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + at::Tensor t3 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = t0.sum({1}); + auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y}); + auto t4 = t2 + t3; + + testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); +} + +// See issue #759 +TEST_F(NVFuserTest, FusionPredicatedBlockBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = makeSymbolicTensor(2); + fusion.addInput(tv3); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->split(0, 4); + tv1->computeAt(tv4, -1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDy); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(1)->parallelize(ParallelType::TIDy); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + int numel_x = 100; + int numel_y = 101; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + at::Tensor t3 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto t1 = t0.sum({1}); + auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y}); + auto t4 = t2 + t3; + + testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSegmentVerticalMerge_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + // {first kernel} + auto tv1 = sum(tv0, {0}); + auto tv2 = add(tv1, tv0); + auto tv3 = sum(tv2, {0}); + auto tv4 = add(tv3, tv0); + auto tv5 = sum(tv4, {0}); + auto tv6 = sum(tv5, {0}); + // {second kernel} + auto tv7 = add(tv6, tv5); + auto tv8 = add(tv7, tv5); + auto tv9 = sum(tv8, {0}); + + fusion->addOutput(tv9); + + SegmentCandidateFinderOptions segment_options; + segment_options.run_herrmann_merge = false; + segment_options.run_final_merge = false; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 2}, options); + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + + auto segmented_fusion = + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + + TORCH_CHECK(segmented_fusion->groups().size() == 2); +} + +TEST_F(NVFuserTest, FusionSegmentHorizontalMerge_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + auto i0 = IrBuilder::create(); + + fusion->addInput(tv0); + fusion->addInput(i0); + + // Branch 0 {first kernel} + auto tv1 = sum(tv0, {0}); + auto tv2 = add(tv0, i0); + auto tv3 = unaryOp(UnaryOpType::Rsqrt, tv2); + auto tv4 = sum(tv3, {0}); + + // Branch 1 {first kernel} + auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv3); + auto tv6 = sum(tv5, {0}); + + // Incompatible {second kernel} + auto tv7 = sum(tv6, {0}); + + fusion->addOutput(tv1); + fusion->addOutput(tv4); + fusion->addOutput(tv7); + + SegmentCandidateFinderOptions segment_options; + segment_options.run_herrmann_merge = false; + segment_options.run_final_merge = false; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 2}, options); + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + c10::IValue scalar = 1.0; + args.push(scalar); + + auto segmented_fusion = + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + + TORCH_CHECK(segmented_fusion->groups().size() == 2); +} + +TEST_F(NVFuserTest, FusionSegmentMixReduction_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + + // def of tv1 in kernel 1 through horizontal + auto tv1 = sum(tv0, {0, 1}); + // kernel 2 + auto tv2 = sum(tv0, {2}); + auto tv3 = broadcast(tv2, {false, false, true}); + auto tv4 = add(tv0, tv3); + auto tv5 = sum(tv4, {2}); + // end of kernel 2 + // kernel 1 + auto tv6 = unaryOp(UnaryOpType::Rsqrt, tv0); + auto tv7 = sum(tv6, {0, 1}); + auto tv8 = sum(tv6, {0, 1}); + + fusion->addOutput(tv1); + fusion->addOutput(tv5); + fusion->addOutput(tv7); + fusion->addOutput(tv8); + + SegmentCandidateFinderOptions segment_options; + segment_options.run_herrmann_merge = false; + segment_options.run_final_merge = false; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 2, 2}, options); + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(t0); + + auto segmented_fusion = + SegmentCandidateFinder::segment(fusion.get(), args, segment_options); + + TORCH_CHECK(segmented_fusion->groups().size() <= 2); +} + +TEST_F(NVFuserTest, FusionSBAR_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // N, H, W, C format + std::vector input_shape{656, 7, 7, 64}; + + auto x = makeContigTensor(4); + auto y = makeContigTensor(4); + auto weight = makeContigTensor(1); + auto bias = makeContigTensor(1); + + fusion.addInput(x); + fusion.addInput(y); + fusion.addInput(weight); + fusion.addInput(bias); + + const size_t kNumberOfDims = x->nDims(); + std::vector broadcast_mask(kNumberOfDims, false); + for (const auto axis : c10::irange(kNumberOfDims - 1)) { + broadcast_mask[axis] = true; + } + + auto weight_bcast = broadcast(weight, broadcast_mask); + auto scale = mul(x, weight_bcast); + auto bias_bcast = broadcast(bias, broadcast_mask); + auto scale_bias = add(scale, bias_bcast); + auto scale_bias_add = add(scale_bias, y); + auto scale_bias_add_relu = unaryOp(UnaryOpType::Relu, scale_bias_add); + + fusion.addOutput(scale_bias_add_relu); + + // inputs + at::manual_seed(0); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_x = at::randn(input_shape, options); + at::Tensor at_y = at::randn(input_shape, options); + at::Tensor at_weight = at::ones({input_shape[3]}, options); + at::Tensor at_bias = at::zeros({input_shape[3]}, options); + + // inputs + std::vector inputs = {at_x, at_y, at_weight, at_bias}; + + // outputs + std::vector outputs; + + auto lparams = schedulePointwise(&fusion, inputs); + + FusionExecutor executor; + executor.compileFusion(&fusion, inputs, lparams); + outputs = executor.runFusion(inputs, lparams); + + auto at_scale = at::mul(at_x, at_weight); + auto at_scale_bias = at::add(at_scale, at_bias); + auto pwise_add = at::add(at_scale_bias, at_y); + auto output = at::relu(pwise_add); + + testValidate(&fusion, outputs, inputs, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSingleElement_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(0); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(2.5)); + + auto tv2 = add(tv1, IrBuilder::create(3.5)); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({}, options); + + at::Tensor cg_output = at::empty({}, options); + + auto lparams = schedulePointwise(&fusion, {input}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}, lparams); + fe.runFusion({input}, {cg_output}, lparams); + + auto aten_output = input.add(2.5).add(3.5); + + testValidate( + &fusion, {cg_output}, {input}, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBNBackwardRepro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int batch = 4; + int c = 4; + int h = 4; + int w = 4; + int numDims = 4; + + auto input = makeSymbolicTensor(numDims); + fusion.addInput(input); + auto weight = makeSymbolicTensor(1); + fusion.addInput(weight); + auto running_mean = makeSymbolicTensor(1); + fusion.addInput(running_mean); + auto running_var = makeSymbolicTensor(1); + fusion.addInput(running_var); + auto save_mean = makeSymbolicTensor(1); + fusion.addInput(save_mean); + auto save_invstd = makeSymbolicTensor(1); + fusion.addInput(save_invstd); + + auto grad_out_prev = makeSymbolicTensor(numDims); + fusion.addInput(grad_out_prev); + auto gt_0 = + makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. + fusion.addInput(gt_0); + + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); + auto gt_float = castOp(DataType::Float, gt_bool); + + auto grad_out = mul(grad_out_prev, gt_float); + + Val* eps_ptr = IrBuilder::create(1e-5); + + auto grads = batch_norm_backward( + input, + grad_out, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + true, + eps_ptr, + {true, true, true}); + + fusion.addOutput(grads.grad_input); + fusion.addOutput(grads.grad_weight); + fusion.addOutput(grads.grad_bias); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({batch, c, h, w}, options); + at::Tensor input1 = at::randn({c}, options); + at::Tensor input2 = at::randn_like(input1); + at::Tensor input3 = at::randn_like(input1); + at::Tensor input4 = at::randn_like(input1); + at::Tensor input5 = at::randn_like(input1); + at::Tensor input6 = at::randn_like(input0); + at::Tensor input7 = at::randn_like(input0); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector inputs = { + input0, input1, input2, input3, input4, input5, input6, input7}; + auto outputs = fec.runFusionWithInputs(inputs); +} + +// TODO: We only changed inputs, merge this with the test above. +TEST_F(NVFuserTest, FusionBNBackwardRepro2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int batch = 2; + int c = 81; + int h = 1; + int w = 1; + int numDims = 4; + + // auto input = makeSymbolicTensor(numDims); + auto input = makeConcreteTensor({-1, -1, 1, 1}); + fusion.addInput(input); + auto weight = makeSymbolicTensor(1); + fusion.addInput(weight); + auto running_mean = makeSymbolicTensor(1); + fusion.addInput(running_mean); + auto running_var = makeSymbolicTensor(1); + fusion.addInput(running_var); + auto save_mean = makeSymbolicTensor(1); + fusion.addInput(save_mean); + auto save_invstd = makeSymbolicTensor(1); + fusion.addInput(save_invstd); + + // auto grad_out_prev = makeSymbolicTensor(numDims); + auto grad_out_prev = makeConcreteTensor({-1, -1, 1, 1}); + fusion.addInput(grad_out_prev); + // auto gt_0 = + // makeSymbolicTensor(numDims); // single tensor broadcasted is dangerous. + auto gt_0 = makeConcreteTensor({-1, -1, 1, 1}); + fusion.addInput(gt_0); + + auto gt_bool = binaryOp(BinaryOpType::GT, gt_0, IrBuilder::create(1)); + auto gt_float = castOp(DataType::Float, gt_bool); + + auto grad_out = mul(grad_out_prev, gt_float); + + Val* eps_ptr = IrBuilder::create(1e-5); + + auto grads = batch_norm_backward( + input, + grad_out, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + true, + eps_ptr, + {true, true, true}); + + fusion.addOutput(grads.grad_input); + fusion.addOutput(grads.grad_weight); + fusion.addOutput(grads.grad_bias); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({batch, c, h, w}, options); + at::Tensor input1 = at::randn({c}, options); + at::Tensor input2 = at::randn_like(input1); + at::Tensor input3 = at::randn_like(input1); + at::Tensor input4 = at::randn_like(input1); + at::Tensor input5 = at::randn_like(input1); + at::Tensor input6 = at::randn_like(input0); + at::Tensor input7 = at::randn_like(input0); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector inputs = { + input0, input1, input2, input3, input4, input5, input6, input7}; + auto outputs = fec.runFusionWithInputs(inputs); +} + +TEST_F(NVFuserTest, FusionBNRepro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + const bool kTraining = true; + const float kMomentum = 0.1; + const float kEps = 1e-5; + + int batch = 14; + int c = 65; + int h = 7; + int w = 7; + int numDims = 4; + + auto input = makeSymbolicTensor(numDims); + fusion.addInput(input); + auto weight = makeSymbolicTensor(1); + fusion.addInput(weight); + auto bias = makeSymbolicTensor(1); + fusion.addInput(bias); + auto running_mean = makeSymbolicTensor(1); + fusion.addInput(running_mean); + auto running_var = makeSymbolicTensor(1); + fusion.addInput(running_var); + + auto momentum_ptr = IrBuilder::create(kMomentum); + auto eps_ptr = IrBuilder::create(kEps); + + auto result = batch_norm( + input, + weight, + bias, + running_mean, + running_var, + kTraining, + momentum_ptr, + eps_ptr); + + fusion.addOutput(result.output); + fusion.addOutput(result.mean); + fusion.addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({batch, c, h, w}, options); + at::Tensor input2 = at::randn({c}, options); + at::Tensor input3 = at::randn_like(input2); + at::Tensor input4 = at::randn_like(input2); + at::Tensor input5 = at::randn_like(input2); + + auto input1_ref = input1.clone(); + auto input2_ref = input2.clone(); + auto input3_ref = input3.clone(); + auto input4_ref = input4.clone(); + auto input5_ref = input5.clone(); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector aten_inputs = {input1, input2, input3, input4, input5}; + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); + + auto at_results = at::native_batch_norm( + input1_ref, + input2_ref, + input3_ref, + input4_ref, + input5_ref, + kTraining, + kMomentum, + kEps); + + auto at_output = std::get<0>(at_results); + auto at_mean = std::get<1>(at_results); + auto at_invstd = std::get<2>(at_results); + + std::vector aten_outputs = {at_output, at_mean, at_invstd}; + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBNRepro2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + const bool kTraining = true; + const float kMomentum = 0.1; + const float kEps = 1e-5; + + int batch = 2; + int c = 4; + int h = 17; + int w = 17; + int numDims = 4; + + auto input = makeSymbolicTensor(numDims); + fusion.addInput(input); + + Val* momentum_ptr = IrBuilder::create(kMomentum); + Val* eps_ptr = IrBuilder::create(kEps); + + auto result = batch_norm( + input, + nullptr, + nullptr, + nullptr, + nullptr, + kTraining, + momentum_ptr, + eps_ptr); + + fusion.addOutput(result.output); + fusion.addOutput(result.mean); + fusion.addOutput(result.invstd); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({batch, c, h, w}, options); + + auto input1_ref = input1.clone(); + at::Tensor r_m; + at::Tensor r_v; + at::Tensor weight; + at::Tensor bias; + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector aten_inputs = {input1}; + auto cg_outputs = fec.runFusionWithInputs(aten_inputs); + + auto at_results = at::native_batch_norm( + input1_ref, r_m, r_v, weight, bias, kTraining, kMomentum, kEps); + + auto at_output = std::get<0>(at_results); + auto at_mean = std::get<1>(at_results); + auto at_invstd = std::get<2>(at_results); + + std::vector aten_outputs = {at_output, at_mean, at_invstd}; + + testValidate( + &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionZeroSizeTensorPW_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor({0}); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(2.5)); + fusion.addOutput(tv2); + + // This test used to just have: + // auto tv3 = makeConcreteTensor({0}); + // and somehow that was running through our system fine, but size-0 tensors + // are not supported, so making sure this fails. + auto tv3 = set(tv1); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input0 = at::randn({2}, options); + at::Tensor input1 = at::randn({0}, options); + at::Tensor cg_output2 = at::empty({2}, options); + at::Tensor cg_output3 = at::empty({0}, options); + + // Fails at schedule pointwise because our (maybe only) size-0 check is in + // binding input sizes which the scheduler ends up calling. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(schedulePointwise(&fusion, {input0, input1})); +} + +TEST_F(NVFuserTest, FusionZeroSizeTensorReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor({0}); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {1}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({0}); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input0 = at::randn({2, 4}, options); + at::Tensor input1 = at::randn({0}, options); + at::Tensor cg_output2 = at::empty({2}, options); + at::Tensor cg_output3 = at::empty({0}, options); + + auto reduction_params = getReductionHeuristics(&fusion, {input0, input1}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fusion, *reduction_params); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + + auto lparams = reduction_params->lparams; + FusionExecutor fe; + fe.compileFusion(&fusion, {input0, input1}, lparams); + auto cg_outputs = fe.runFusion({input0, input1}, lparams); + auto aten_output2 = input0.sum({1}); + at::Tensor aten_output3 = at::empty({0}, options); + + testValidate( + &fusion, + cg_outputs, + {input0, input1}, + {aten_output2, aten_output3}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionZeroSizeTensorNormalization_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = makeConcreteTensor({0}); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {0}); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + auto tv5 = makeConcreteTensor({0}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input0 = at::randn({2, 4}, options); + at::Tensor input1 = at::randn({0}, options); + at::Tensor cg_output2 = at::empty({2, 4}, options); + at::Tensor cg_output3 = at::empty({0}, options); + + auto reduction_params = getPersistentHeuristics(&fusion, {input0, input1}); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + schedulePersistentKernel(&fusion, *reduction_params); + + auto lparams = reduction_params->lparams; + FusionExecutor fe; + fe.compileFusion(&fusion, {input0, input1}, lparams); + auto cg_outputs = fe.runFusion({input0, input1}, lparams); + auto aten_output2 = input0.sum({0}).add(input0); + at::Tensor aten_output3 = at::empty({0}, options); + + testValidate( + &fusion, + cg_outputs, + {input0, input1}, + {aten_output2, aten_output3}, + __LINE__, + __FILE__, + "", + lparams); +} + +TEST_F(NVFuserTest, FusionSegmentIoAlias_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(1); + TensorView* tv2 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 + TensorView* tv4 = + max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) + TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, + // keeps normalization scheduler away) + TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) + + // Note: test alias; + fusion->aliasOutputToInput(tv6, tv0); + // TODO: support output on aliased fusion #1488 + // remove tv7 after #1488 + // fusion->addOutput(tv6); + TensorView* tv7 = add(tv6, IrBuilder::create(1)); // Group 0 + fusion->addOutput(tv7); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, 65}, options); + at::Tensor t1 = at::randn({65}, options); + at::Tensor t2 = at::randn({128, 65}, options); + + auto t3 = t0.add(1.0); + auto t4 = std::get<0>(at::max(t3, 0)); + auto t5 = t4.add(t1); + auto t6 = t5.add(t2); + auto t7 = t6.add(1.0); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + // TODO: support output on aliased fusion #1488 + // validating aliasing + // TORCH_INTERNAL_ASSERT(outputs[0].data_ptr() == t0.data_ptr()); + + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation didn't happen"); + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime() + ->fusionSegments() + ->groups() + .size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, {t0, t1, t2}, {t7}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWelford1Output_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tvs = Welford(tv0, {1}); + fusion->addOutput(tvs.var_sum); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, 65}, options); + auto outputs = executor_cache.runFusionWithInputs({t0}); + + auto t1 = t0.var({1}, false) * 65; + testValidate(fusion, outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTranslate1Welford_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tvs = Welford(tv0, {1}); + auto tv_out = add(tv0, broadcast(tvs.avg, {false, true})); + fusion->addOutput(tv_out); + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto run_test = [&executor_cache, + fusion](auto inner_size) -> FusionKernelRuntime* { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, inner_size}, options); + auto outputs = executor_cache.runFusionWithInputs({t0}); + // Square sums does not fit well in the testValidate assumptions, + // so we just compare the divided output here. + testValidate( + fusion, + outputs, + {t0}, + {t0.add(t0.mean({1}).unsqueeze(1))}, + __LINE__, + __FILE__); + + return executor_cache.getMostRecentKernelRuntime(); + }; + + // Run a translated welford + auto runtime1 = run_test(64); + // Check it was translated + TORCH_CHECK( + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 2); + + // Run an un-translated welford + auto runtime2 = run_test(65536); + + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); +} + +TEST_F(NVFuserTest, FusionTranslate2Welford_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tvs1 = Welford(tv0, {1}); + auto tv_out1 = add(tv0, broadcast(tvs1.avg, {false, true})); + fusion->addOutput(tv_out1); + + auto tvs2 = Welford(tv0, {1}); + auto tv_out2 = add(tv0, broadcast(tvs2.avg, {false, true})); + fusion->addOutput(tv_out2); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto run_test = [&executor_cache, + fusion](auto inner_size) -> FusionKernelRuntime* { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, inner_size}, options); + auto outputs = executor_cache.runFusionWithInputs({t0}); + + // Square sums does not fit well in the testValidate assumptions, + // so we just compare the divided output here. + auto out = t0.add(t0.mean({1}).unsqueeze(1)); + testValidate(fusion, outputs, {t0}, {out, out}, __LINE__, __FILE__); + + return executor_cache.getMostRecentKernelRuntime(); + }; + + // Run a translated welford + auto runtime1 = run_test(64); + // Check it was translated + TORCH_CHECK( + runtime1->fusionSegments()->groups().size() == 1 && + runtime1->fusionSegments()->groups()[0]->exprs().size() > 4); + + // Run an un-translated welford + auto runtime2 = run_test(65536); + // // Check it was not translated + bool found_welford = false; + for (auto group : runtime2->fusionSegments()->groups()) { + for (auto expr : group->exprs()) { + if (expr->isA()) { + found_welford = true; + } + } + } + TORCH_CHECK(found_welford); +} + +TEST_F(NVFuserTest, FusionLargeWelfordNormalization_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tvs1 = Welford(tv0, {1}); + auto sum_of_tv0 = sum(tv0, {1}); + + fusion->addOutput(tvs1.var_sum); + fusion->addOutput(sum_of_tv0); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto run_test = [&executor_cache, + fusion](auto inner_size) -> FusionKernelRuntime* { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, inner_size}, options); + auto outputs = executor_cache.runFusionWithInputs({t0}); + + auto t1 = t0.var({1}, false) * inner_size; + auto t2 = t0.sum({1}); + testValidate(fusion, outputs, {t0}, {t1, t2}, __LINE__, __FILE__); + + return executor_cache.getMostRecentKernelRuntime(); + }; + + auto runtime = run_test(65536); + TORCH_CHECK(!runtime->isSegmented()); +} + +TEST_F(NVFuserTest, FusionWelfordOuterPersistence_CUDA) { + auto fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tvs1 = Welford(tv0, {1}); + auto sum_of_tv0 = sum(tv0, {1}); + auto sum_bcasted = broadcast(sum_of_tv0, {false, true}); + auto avg_bcasted = broadcast(tvs1.avg, {false, true}); + auto tv0_plus_sum = add(tv0, sum_bcasted); + auto tv0_plus_avg = add(tv0, avg_bcasted); + + fusion->addOutput(tv0_plus_sum); + fusion->addOutput(tv0_plus_avg); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + + auto run_test = [&executor_cache, + fusion](auto inner_size) -> FusionKernelRuntime* { + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({128, inner_size}, options); + auto outputs = executor_cache.runFusionWithInputs({t0}); + + auto t1 = t0.to(c10::kDouble).mean({1}).unsqueeze(1) + t0; + auto t2 = t0.to(c10::kDouble).sum({1}).unsqueeze(1) + t0; + testValidate(fusion, outputs, {t0}, {t2, t1}, __LINE__, __FILE__); + + return executor_cache.getMostRecentKernelRuntime(); + }; + + for (auto inner_size : {4096, 8192, 32768}) { + auto runtime = run_test(inner_size); + TORCH_CHECK(!runtime->isSegmented()); + } +} + +TEST_F(NVFuserTest, FusionSegmentIslands_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = sum(tv0, {0}); + auto tv3 = sum(tv1, {1}); + fusion->addOutput(tv2); + fusion->addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({16, 16}, options); + at::Tensor t1 = at::randn({16, 16}, options); + + FusionExecutorCache fusion_executor_cache(std::move(fusion)); + fusion_executor_cache.runFusionWithInputs({t0, t1}); +} + +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(2); + auto tv2 = makeSymbolicTensor(4); + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv3 = broadcast(tv0, {false, true, true, true}); + auto tv4 = broadcast(tv1, {false, false, true, true}); + auto tv5 = unaryOp(UnaryOpType::Rsqrt, tv2); + + auto tv6 = add(tv3, tv5); + auto tv7 = add(tv4, tv5); + auto tv8 = add(tv3, tv4); + + auto tv9 = add(tv6, tv7); + auto tv10 = add(tv9, tv8); + + fusion->addOutput(tv10); + + tv0->computeAt(tv10, -2); + tv1->computeAt(tv10, -2); + tv2->computeAt(tv10, -2); + + TORCH_CHECK(tv3->getComputeAtPosition() == 1); + TORCH_CHECK(tv4->getComputeAtPosition() == 2); + TORCH_CHECK(tv5->getComputeAtPosition() == 3); + + TORCH_CHECK(tv6->getMaxProducerPosition() == 3); + TORCH_CHECK(tv7->getMaxProducerPosition() == 3); + TORCH_CHECK(tv8->getMaxProducerPosition() == 2); +} + +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast2_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(3); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = broadcast(tv0, {false, false, true}); + auto tv3 = add(tv2, tv1); + + fusion->addOutput(tv3); + tv3->split(-2, 4); + tv3->reorder({{-1, -2}}); + tv0->computeAt(tv3, -2); + tv1->computeAt(tv3, -2); + TORCH_CHECK(tv2->getComputeAtPosition() == 2); + TORCH_CHECK(tv3->getMaxProducerPosition() == 2); +} + +TEST_F(NVFuserTest, FusionBackOffInnerBroadcast3_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(4); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = broadcast(tv0, {false, false, true}); + auto tv3 = broadcast(tv2, {false, true, false, false}); + auto tv4 = add(tv3, tv1); + + fusion->addOutput(tv4); + tv0->computeAt(tv4, -1); + tv1->computeAt(tv4, -1); + TORCH_CHECK(tv2->getComputeAtPosition() == 2); + TORCH_CHECK(tv3->getMaxProducerPosition() == 3); +} + +TEST_F(NVFuserTest, FusionSimpleWarp_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + tv1->split(1, 32); + auto tv1_rf = tv1->rFactor({1}); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 128}, options); + + auto at_output = input1.sum({1}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSimpleWarpPad_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->split(1, 8, false); + auto tv1_rf = tv1->rFactor({1}); + tv1_rf->axis(0)->parallelize(ParallelType::BIDx); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv1_rf->axis(-1)->padToMultipleOfWarp(32); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(32); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv0->axis(-1)->padToMultipleOfWarp(32); + tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); + tv0_cache->axis(-1)->padToMultipleOfWarp(32); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->padToMultipleOfWarp(32); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->padToMultipleOfWarp(32); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 127}, options); + + auto at_output = input1.sum({1}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWarpPadMergeSplit_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1, 2}); + auto tv2 = broadcast(tv1, {false, true, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->merge(1); + tv1->split(1, 8, false); + + auto tv1_rf = tv1->rFactor({1}); + tv1_rf->axis(0)->parallelize(ParallelType::BIDx); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 17, 128}, options); + + auto at_output = input1.sum({1, 2}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSerialWarpReduction_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1, 2}); + auto tv2 = broadcast(tv1, {false, true, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->merge(1); + tv1->split(1, 8, false); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(); + TransformPropagatorWithCheck propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 17, 128}, options); + + auto at_output = input1.sum({1, 2}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTrivialWarpReduction_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({17, 18, 128, 1}); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1, 2, 3}); + auto tv2 = broadcast(tv1, {false, true, true, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->merge(1); + tv1->split(1, 8, false); + + auto tv1_rf = tv1->rFactor({1}); + tv1_rf->axis(0)->parallelize(ParallelType::BIDx); + tv1_rf->axis(-2)->parallelize(ParallelType::TIDx); + tv1->axis(-2)->parallelize(ParallelType::TIDx); + tv1->axis(-2)->padToMultipleOfWarp(); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv0->axis(-2)->parallelize(ParallelType::TIDx); + tv0_cache->axis(-2)->parallelize(ParallelType::TIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(-2)->parallelize(ParallelType::TIDx); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({17, 18, 128, 1}, options); + + auto at_output = input1.sum({1, 2, 3}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMultipleDimBinding_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + auto tv_add = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv_add); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + auto tv4 = add(tv0, tv_add); + + fusion->addOutput(tv3); + fusion->addOutput(tv4); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->split(1, 8, false); + auto tv1_rf = tv1->rFactor({1}); + tv1_rf->axis(0)->parallelize(ParallelType::BIDx); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv1_rf->axis(-1)->padToMultipleOfWarp(32); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(32); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv0->axis(-1)->padToMultipleOfWarp(32); + tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); + tv0_cache->axis(-1)->padToMultipleOfWarp(32); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->padToMultipleOfWarp(32); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->padToMultipleOfWarp(32); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->padToMultipleOfWarp(64); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 128}, options); + at::Tensor input2 = at::randn({16, 128}, options); + + auto at_output = input1.sum({1}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + testValidate( + fusion.get(), + outputs, + {input1, input2}, + {at_output, input1 + input2}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionPadNoWarpReduce_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv1->axis(0)->parallelize(ParallelType::TIDy); + tv2->axis(0)->parallelize(ParallelType::TIDy); + tv3->axis(0)->parallelize(ParallelType::TIDy); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 31}, options); + + auto at_output = input1.sum({1}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWarpMutipleThreadDim_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + fusion->addOutput(tv2); + + tv2->split(1, 8); + auto tv2_rf = tv2->rFactor({-1}); + tv2_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv2_rf->axis(-1)->padToMultipleOfWarp(); + + TransformPropagatorWithCheck propagator(tv2_rf); + MaxRootDomainInfoSpanningTree(tv2_rf).traverse(&propagator); + + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::TIDy); + tv0->computeAt(tv2, 2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 31}, options); + + auto at_output = (input1 + 1).sum({1}); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionWarpReduceUnrollOuterLoop_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + + fusion->addOutput(tv3); + + // Schedule a persistent kernel + auto tv0_cache = tv0->cacheAfter(); + tv1->split(1, 8, false); + tv1->split(0, 4); + auto tv1_rf = tv1->rFactor({2}); + + tv1_rf->axis(0)->parallelize(ParallelType::BIDx); + tv1_rf->axis(1)->parallelize(ParallelType::Unroll); + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->padToMultipleOfWarp(); + tv1->axis(1)->parallelize(ParallelType::Unroll); + TransformPropagatorWithCheck propagator(tv1_rf); + MaxRootDomainInfoSpanningTree(tv1_rf).traverse(&propagator); + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv0->axis(1)->parallelize(ParallelType::Unroll); + tv0_cache->axis(-1)->parallelize(ParallelType::TIDx); + tv0_cache->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unroll); + + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({16, 128}, options); + + auto at_output = input1.sum({1}, true).add(input1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_output}, __LINE__, __FILE__); +} + +// Repro of issue #1579 +TEST_F(NVFuserTest, FusionWarpReducePredication_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape1 = {1024}; + std::vector shape2 = {50}; + + auto tv0 = makeConcreteTensor(shape1); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + auto tv2 = makeConcreteTensor(shape2); + fusion.addInput(tv2); + auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = sum(tv3, {0}); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + // Just to fill the smem buffer by a thread block of 1024 threads + // with some values + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + // Make the tv4_rf reduction a warp reduction to trigger the + // bug. Since the smem buffer is filled with some values due to the + // reduction of tv1, those values would be used by predicated-out + // threads. + tv4->split(-1, 10); + auto tv4_rf = tv4->rFactor({-1}); + tv4_rf->axis(-1)->parallelize(ParallelType::TIDx); + tv4_rf->axis(-1)->padToMultipleOfWarp(); + + tv4_rf->computeAt(tv4, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t2 = at::randn(shape2, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t2}); + auto cg_outputs = fe.runFusion({t0, t2}); + + auto t1 = t0.sum({0}); + auto t4 = (t2 + 1).sum({0}) + 1; + + testValidate(&fusion, cg_outputs, {t0, t2}, {t1, t4}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSegfaultReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int batch = 2; + int c = 1; + int h = 1; + int w = 1; + int numDims = 4; + + auto input = makeConcreteTensor({-1, 1, 1, 1}); + fusion.addInput(input); + auto bcast_bias = makeConcreteTensor({-1, 1, 1, 1}); + fusion.addInput(bcast_bias); + + std::vector at_sum_axes; + std::vector outer_reduction_axes; + std::vector outer_broadcast_mask(numDims, false); + Val* N = IrBuilder::create(1); + for (const auto axis : c10::irange(numDims)) { + if (axis != 1) { + outer_reduction_axes.push_back(axis); + at_sum_axes.push_back(axis); + outer_broadcast_mask[axis] = true; + N = mul(N, input->domain()->domain()[axis]->extent()); + } + } + + auto output0 = mul(input, bcast_bias); + fusion.addOutput(output0); + auto output1 = sum(output0, outer_reduction_axes); + fusion.addOutput(output1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({batch, c, h, w}, options); + at::Tensor input1 = at::randn({batch, c, h, w}, options); + + auto at_output0 = input0.mul(input1); + auto at_output1 = at_output0.sum(at_sum_axes); + + FusionExecutorCache fec(std::move(fusion_ptr)); + std::vector inputs = {input0, input1}; + auto outputs = fec.runFusionWithInputs(inputs); + + testValidate( + &fusion, outputs, inputs, {at_output0, at_output1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPredicateElimination1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); + auto tv3 = add(tv2, IrBuilder::create(3)); + + fusion.addOutput(tv3); + + tv3->split(0, 32); + tv0->computeAt(tv3, 1); + + tv2->axis(1)->parallelize(ParallelType::Unswitch); + + { + GpuLower gpulw(&fusion); + TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); + } + + tv2->axis(1)->parallelize(ParallelType::Serial); + tv2->split(1, 5); + + { + GpuLower gpulw(&fusion); + TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); + } +} + +// Repro of issue #1571 +TEST_F(NVFuserTest, FusionPredicateElimination2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({10, 11}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + tv1->split(1, 4); + tv1->split(0, 4); + tv2->split(1, 4); + tv2->split(0, 4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = (t0 + 1).sum({1}) + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + auto tv3 = tv0->cacheAfter(); + + tv1->split(0, 10); + tv1->split(0, 33); + TransformPropagatorWithCheck propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + auto tv4 = tv1->rFactor({-1}); + auto tv5 = tv1->rFactor({-1}); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4); + + GpuLower gpulw(&fusion); + + // The fusion has three reductions: one within each thread, one + // within each block, and another with the whole grid. All of them + // should not need to be predicated as they use the same init value + // and same reduction op. + TORCH_CHECK(!PredicatedChecker::isPredicated(tv4, gpulw)); + TORCH_CHECK(!PredicatedChecker::isPredicated(tv5, gpulw)); + TORCH_CHECK(!PredicatedChecker::isPredicated(tv1, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + for (auto size : {1, 2, 999, 1001, 1234, 10000}) { + auto t0 = at::randn({size}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = sum(t0) + 1; + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + + auto tv2 = sum(tv1, {0}); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + auto tv4 = max(tv1, {0}); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv1->split(1, 7); + tv1->split(0, 11); + tv1->reorder({{1, 2}, {2, 1}}); + TransformPropagatorWithCheck propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + tv1->axis(0)->parallelize(ParallelType::TIDy); + tv1->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv1); + + GpuLower gpulw(&fusion); + + // tv2 uses the same op and init with tv1, so tv2 should be fine + // without a predicate. However, tv4, while it uses the tv1 as its + // input, the reduction op and init value is different from those of + // tv1, so tv4 needs to be predicated. + TORCH_CHECK(!PredicatedChecker::isPredicated(tv2, gpulw)); + TORCH_CHECK(PredicatedChecker::isPredicated(tv4, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector sizes = {1, 2, 33, 34, 64, 99}; + for (auto s0 : sizes) { + for (auto s1 : sizes) { + auto t0 = at::randn({s0, s1}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto t1 = t0.sum({1}); + auto t3 = t1.sum({0}) + 1; + auto t5 = std::get<0>(t1.max(0)) + 1; + + testValidate(&fusion, cg_outputs, {t0}, {t3, t5}, __LINE__, __FILE__); + } + } +} + +TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tvs2 = Welford(tv1, {0}); + auto tv3 = set(tvs2.avg); + fusion.addOutput(tv3); + + tvs2.avg->split(0, 4); + TransformPropagatorWithCheck propagator(tvs2.avg); + MaxRootDomainInfoSpanningTree(tvs2.avg).traverse(&propagator); + auto avg_rf = ir_utils::rfactorHelper(tvs2.avg, {1}); + + avg_rf->axis(0)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(avg_rf); + + GpuLower gpulw(&fusion); + + // The first per-thread welford needs to be predicated as the N + // input is different from its init value. The second welford op + // does not need a predicate. + TORCH_CHECK(PredicatedChecker::isPredicated(avg_rf, gpulw)); + TORCH_CHECK(!PredicatedChecker::isPredicated(tvs2.avg, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + std::vector sizes = {1, 2, 33, 34, 64, 99}; + for (auto s0 : sizes) { + auto t0 = at::randn({s0}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0.mean({0}); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionPredicateElimination6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + fusion.addOutput(tv4); + + tv4->split(1, 5); + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv4->reorder({{0, 1}, {1, 0}}); + tv3->computeAt(tv4, 1); + + GpuLower gpulw(&fusion); + + // The expression for tv2 is a local-to-local expression. It + // satisfies all the requirements of predicate elimination, except + // for the on on split root domains. As the second root axis of tv2 + // is split, its index exceeds its extent (i.e., 3 in this case) + // without its predicate. + TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); + + // Unlike tv2, tv3 is computed at tv4, so the second root axis does + // have a zero domain. Its index should look like "i * 5 + j", where + // i comes from the first root domain and j comes from the split + // inner domain. + TORCH_CHECK(!PredicatedChecker::isPredicated(tv3, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 3}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 4; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPredicateElimination7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + tv3->split(-1, 5); + tv3->split(-1, 4); + tv3->split(-1, 3); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, 1); + + // The last split of tv2 is a non-divisible split, and omitting it + // is invalid. + GpuLower gpulw(&fusion); + TORCH_CHECK(PredicatedChecker::isPredicated(tv2, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({123}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 3; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionForceFp16Simple_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Group 1 + auto tv2 = sum(tv0, {1}); + auto tv3 = broadcast(tv2, {false, true}); + + // Group 2 + auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast + auto tv5 = castOp(DataType::Half, tv4); + + fusion->addOutput(tv5); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + std::vector shape{15, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + // Check the segmented edge is fp16 + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + TORCH_CHECK(edge_tv->getDataType() == DataType::Half); + } +} + +TEST_F(NVFuserTest, FusionForceBf16Simple_CUDA) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Group 1 + auto tv2 = sum(tv0, {1}); + auto tv3 = broadcast(tv2, {false, true}); + + // Group 2 + auto tv4 = add(tv3, tv1); // Edge: tv3: expect cast + auto tv5 = castOp(DataType::BFloat16, tv4); + + fusion->addOutput(tv5); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + std::vector shape{15, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + // Check the segmented edge is bf16 + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + TORCH_CHECK(edge_tv->getDataType() == DataType::BFloat16); + } +#else + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; +#endif +} + +TEST_F(NVFuserTest, FusionForceFp16NotAllCast_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Group 1 + auto tv3 = sum(tv0, {1}); + auto tv4 = broadcast(tv3, {false, true, false}); + auto tv5 = sum(tv0, {1}); + + // Group 2 + auto tv6 = add(tv4, tv1); // edge tv4, expect cast + auto tv7 = castOp(DataType::Half, tv6); + + // Group 3 + auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast + + fusion->addOutput(tv7); + fusion->addOutput(tv8); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + std::vector shape{16, 16, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + auto complete_fusion = segmented_fusion->completeFusion(); + + // Check that the edge that wasn't fp16 is the producer of the + // reduction op, i.e. tv8 = sum(tv5,{1});. + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + if (edge_tv->getDataType() == DataType::Float) { + auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); + TORCH_CHECK(consumer->isA()); + } + } +} + +TEST_F(NVFuserTest, FusionForceBf16NotAllCast_CUDA) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + // Group 1 + auto tv3 = sum(tv0, {1}); + auto tv4 = broadcast(tv3, {false, true, false}); + auto tv5 = sum(tv0, {1}); + + // Group 2 + auto tv6 = add(tv4, tv1); // edge tv4, expect cast + auto tv7 = castOp(DataType::BFloat16, tv6); + + // Group 3 + auto tv8 = sum(tv5, {1}); // edge tv5, don't expect cast + + fusion->addOutput(tv7); + fusion->addOutput(tv8); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + std::vector shape{16, 16, 16}; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn(shape, options); + auto in1 = at::randn(shape, options); + fec.runFusionWithInputs({in0, in1}); + + auto segmented_fusion = fec.getMostRecentKernelRuntime()->fusionSegments(); + auto complete_fusion = segmented_fusion->completeFusion(); + + // Check that the edge that wasn't fp16 is the producer of the + // reduction op, i.e. tv8 = sum(tv5,{1});. + for (auto edge : segmented_fusion->edges()) { + auto edge_tv = edge->val->as(); + if (edge_tv->getDataType() == DataType::Float) { + auto consumer = *(complete_fusion->unordered_uses(edge_tv).begin()); + TORCH_CHECK(consumer->isA()); + } + } +#else + GTEST_SKIP() << "requires cuda 11.0 or newer toolkit"; +#endif +} + +TEST_F(NVFuserTest, FusionBufferReuseBroadCastMultiVisit_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({2, 2}); + auto tv1 = makeConcreteTensor({2, 2, 2}); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = broadcast(tv2, {false, false, true}); + auto tv4 = add(tv3, tv1); + auto tv5 = mul(tv4, IrBuilder::create(3)); + fusion->addOutput(tv5); + + // t4 cannot inner re-use t2, because there's a broadcast + // between them. + tv0->computeAt(tv5, 1, ComputeAtMode::BestEffort); + tv3->computeAt(tv5, 2, ComputeAtMode::BestEffort); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({2, 2}, options); + auto in1 = at::randn({2, 2, 2}, options); + + auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; + FusionExecutor fe; + fe.compileFusion(fusion, {in0, in1}); + auto outputs = fe.runFusion({in0, in1}); + + testValidate(fusion, outputs, {in0, in1}, {at_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseStressTest_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({2, 2}); + auto tv1 = makeConcreteTensor({2, 2, 2}); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); + auto tv4 = mul(tv2, tv3); + // Broadcast buffer can be reused through outer sharing + auto tv5 = broadcast(tv4, {true, false, false}); + auto tv6 = mul(tv5, IrBuilder::create(5)); + auto tv7 = mul(tv6, tv1); + auto tv8 = mul(tv7, IrBuilder::create(7)); + // tv9 shouldn't alias to avoid buffer over-subscription + auto tv9 = broadcast(tv4, {true, false, false}); + auto tv10 = mul(tv9, IrBuilder::create(9)); + auto tv11 = add(tv5, tv9); + fusion->addOutput(tv7); + fusion->addOutput(tv11); + + tv0->computeAt(tv5, 1, ComputeAtMode::BestEffort); + tv0->computeAt(tv9, 1, ComputeAtMode::BestEffort); + + tv5->computeAt(tv7, 1, ComputeAtMode::BestEffort); + tv5->computeAt(tv11, 1, ComputeAtMode::BestEffort); + tv9->computeAt(tv11, 1, ComputeAtMode::BestEffort); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({2, 2}, options); + auto in1 = at::randn({2, 2, 2}, options); + auto t2 = in0 * 2; + auto t3 = in0 * 3; + auto t4 = t2 * t3; + auto t5 = t4.unsqueeze(0); + auto t6 = t5 * 5; + auto t7 = t6 * in1; + auto t8 = t7 * 7; + auto t9 = t4.unsqueeze(0); + auto t10 = t9 * 9; + auto t11 = t5 + t9; + FusionExecutor fe; + fe.compileFusion(fusion, {in0, in1}); + + auto at_output = ((in0 * 2).unsqueeze(2) + in1) * 3; + auto outputs = fe.runFusion({in0, in1}); + + testValidate(fusion, outputs, {in0, in1}, {t7, t11}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseLargeBuffer_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({256, 512}); + + fusion->addInput(tv0); + + auto tv1 = mul(tv0, IrBuilder::create(2)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); + auto tv4 = mul(tv3, IrBuilder::create(2)); + auto tv5 = mul(tv4, IrBuilder::create(2)); + auto tv6 = mul(tv5, IrBuilder::create(2)); + + fusion->addOutput(tv6); + + tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); + tv6->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({256, 512}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {in0}); + auto outputs = fe.runFusion({in0}); + + auto at_out = in0.mul(2).mul(2).mul(2).mul(2).mul(2).mul(2); + + testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseNo2hop_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({2, 2}); + auto tv1 = makeConcreteTensor({2, 2, 2}); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = broadcast(tv2, {false, false, true}); + auto tv4 = add(tv3, tv1); // T4 to be inner aliased first, and + // shouldn't outer alias on top + auto tv5 = mul(tv4, IrBuilder::create(3)); + auto tv6 = mul(tv5, IrBuilder::create(3)); + fusion->addOutput(tv6); + + tv0->computeAt(tv6, 1, ComputeAtMode::BestEffort); + tv4->computeAt(tv6, 2, ComputeAtMode::BestEffort); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({2, 2}, options); + auto in1 = at::randn({2, 2, 2}, options); + FusionExecutor fe; + fe.compileFusion(fusion, {in0, in1}); + auto outputs = fe.runFusion({in0, in1}); + + auto at_out = (in0.mul(2.0).unsqueeze(2) + in1).mul(3.0).mul(3.0); + + testValidate(fusion, outputs, {in0, in1}, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseAllocationOrder_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({3, 3, 3}); + + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); + + fusion->addOutput(tv3); + + // In this case tv1 "reuses" allocation of tv2 + // due to the switched allocation order + tv1->computeAt(tv2, 1, ComputeAtMode::BestEffort); + + tv0->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({3, 3, 3}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {in0}); + auto outputs = fe.runFusion({in0}); + + auto at_out = in0.sum(1).mul(2).mul(2); + + testValidate(fusion, outputs, {in0}, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseLiveInterval_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({16, 16}); + + fusion->addInput(tv0); + + auto tv1 = mul(tv0, IrBuilder::create(3)); + auto tv2 = mul(tv1, IrBuilder::create(2)); + auto tv3 = mul(tv2, IrBuilder::create(2)); + // tv1 used till here, cannot be reused by tv2 or tv3 + auto tv4 = mul(tv3, tv1); + + fusion->addOutput(tv4); + + tv0->computeAt(tv4, 1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({16, 16}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {in0}); + auto cg_outputs = fe.runFusion({in0}); + + auto at_t0 = in0 * 3.0; + auto at_out = at_t0 * 2.0 * 2.0 * at_t0; + + testValidate(fusion, cg_outputs, {in0}, {at_out}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBufferReuseNoAcrossBroadcast_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeConcreteTensor({2, 2}); + auto tv1 = makeConcreteTensor({2, 2, 2}); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = mul(tv0, IrBuilder::create(2)); + auto tv3 = mul(tv0, IrBuilder::create(3)); + auto tv4 = mul(tv2, tv3); + auto tv5 = broadcast(tv4, {false, false, true}); + auto tv6 = mul(tv5, tv1); + auto tv7 = mul(tv6, IrBuilder::create(7)); + fusion->addOutput(tv7); + + // tv6 shouldn't re-use t2 or t3 because of + // the broadcast in between + tv0->computeAt(tv4, 1, ComputeAtMode::BestEffort); + tv4->computeAt(tv7, 2, ComputeAtMode::BestEffort); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto in0 = at::randn({2, 2}, options); + auto in1 = at::randn({2, 2, 2}, options); + FusionExecutor fe; + fe.compileFusion(fusion, {in0, in1}); + auto outputs = fe.runFusion({in0, in1}); + + auto t2 = in0 * 2; + auto t3 = in0 * 3; + auto t4 = t2 * t3; + auto t5 = t4.unsqueeze(2); + auto t6 = t5 * in1; + auto t7 = t6 * 7; + testValidate(fusion, outputs, {in0, in1}, {t7}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue970_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int nelm = 10; + + // tv3 = tv0 + sum(tv0) + auto tv0 = makeConcreteTensor({nelm, nelm}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + fusion.addOutput(tv3); + + tv1->split(1, 4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({nelm, nelm}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + auto ref = sum(t0, {1}).unsqueeze(-1).expand({nelm, nelm}) + t0; + + testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Reproducer of #1016 +TEST_F(NVFuserTest, FusionIssue1016_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(2)); + + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + + tv2->split(-1, 8); + + int numel_x = 10; + int numel_y = 11; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({numel_x, numel_y}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto ref = t0 + 1 + 2; + + testValidate(&fusion, outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Reproducer of #1021 +TEST_F(NVFuserTest, FusionIssue1021_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = broadcast(tv1, {false, true}); + fusion.addOutput(tv2); + + auto tv3 = tv2->cacheBefore(); + + tv2->split(0, 2); + + tv1->computeAt(tv2, 1); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10}, options); + std::vector inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + auto ref = (t0 + 1).unsqueeze(-1); + + testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); +} + +// Reproducer of issue #1053 +TEST_F(NVFuserTest, FusionNonUniqueThreadDim_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion->addOutput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv2); + + tv1->split(0, 8); + auto tv1_rf = tv1->rFactor({-1}); + + tv1_rf->computeAt(tv1, 1); + + tv1_rf->axis(-1)->parallelize(ParallelType::TIDx); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({32}, options); + + auto at_tv1 = (input1).sum({0}); + auto at_tv2 = input1 + 1; + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + testValidate( + fusion.get(), outputs, {input1}, {at_tv1, at_tv2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionParallelDimensionMap1_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv1); + fusion->addOutput(tv2); + + tv1->split(0, 8, false); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->split(0, 8, false); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + // The extents of tv1 and tv2 axes are equal even though their + // actual values are not statically known + GpuLower gpulw(fusion.get()); + const auto& pdmap = gpulw.parallelDimensionMap(); + for (const auto i : c10::irange(tv1->domain()->domain().size())) { + auto dom1 = tv1->domain()->domain()[i]; + auto dom2 = tv2->domain()->domain()[i]; + TORCH_INTERNAL_ASSERT(pdmap.equalDim(dom1->extent(), dom2->extent())); + } + + TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({32}, options); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + + testValidate( + fusion.get(), + outputs, + {input1}, + {input1 + 1, input1 + 1}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionParallelDimensionMap2_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion->addInput(tv1); + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + + tv3->split(-1, 8, false); + tv2->computeAt(tv3, -1); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + GpuLower gpulw(fusion.get()); + const auto& pdmap = gpulw.parallelDimensionMap(); + TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({11}, options); + at::Tensor input2 = at::randn({11, 13}, options); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + auto ref = input1.unsqueeze(-1) + input2; + + testValidate( + fusion.get(), outputs, {input1, input2}, {ref}, __LINE__, __FILE__); +} + +// Mix symbolic and concrete tensors +TEST_F(NVFuserTest, FusionParallelDimensionMap3_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv2 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv2); + auto tv3 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv3); + + tv2->split(0, 10); + tv3->split(0, 20); + + auto tv4 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv4); + auto tv5 = add(tv0, IrBuilder::create(1)); + fusion->addOutput(tv5); + + // Not mapped but equal extent + tv4->split(0, 10); + tv5->split(0, 10); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv4->axis(-1)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + + GpuLower gpulw(fusion.get()); + const auto& pdmap = gpulw.parallelDimensionMap(); + TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); + TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDy)->isConst() && + pdmap.get(ParallelType::TIDy)->as()->value().value() == 10); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({13}, options); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {input1}); + auto outputs = fe.runFusion({input1}); + + testValidate( + fusion.get(), + outputs, + {input1}, + {input1 + 1, input1 + 1, input1 + 1, input1 + 1}, + __LINE__, + __FILE__); +} + +// Parallelizing merged broadcast domains +TEST_F(NVFuserTest, FusionParallelDimensionMap4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->split(1, 4); + tv4->reorder({{1, 2}, {2, 1}}); + tv4->merge(0); + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + // TIDx is mapped to tv4.axis(0) as well as tv2.axis(0), so it's not + // exact. + tv4->axis(0)->parallelize(ParallelType::TIDx); + + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + + GpuLower gpulw(&fusion); + const auto& pdmap = gpulw.parallelDimensionMap(); + TORCH_CHECK(!pdmap.isExact(ParallelType::TIDx)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDx)->isA() && + pdmap.get(ParallelType::TIDx)->as()->name() == "blockDim.x"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({13}, options); + at::Tensor input2 = at::randn({15, 13}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + auto ref = (input1 + 1).unsqueeze(0) + input2; + + testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionParallelDimensionMap5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv3 = broadcast(tv0, {false, true}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->split(1, 4); + tv0->computeAt(tv4, -1); + tv1->computeAt(tv4, -1); + + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv4->axis(-2)->parallelize(ParallelType::TIDy); + tv3->axis(-2)->parallelize(ParallelType::TIDy); + + GpuLower gpulw(&fusion); + const auto& pdmap = gpulw.parallelDimensionMap(); + TORCH_CHECK(pdmap.isExact(ParallelType::TIDx)); + TORCH_CHECK(pdmap.isExact(ParallelType::TIDy)); + TORCH_CHECK( + pdmap.get(ParallelType::TIDx)->isConst() && + pdmap.get(ParallelType::TIDx)->as()->value().value() == 4); + TORCH_CHECK( + pdmap.get(ParallelType::TIDy)->isA() && + pdmap.get(ParallelType::TIDy)->as()->name() == "blockDim.y"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({13}, options); + at::Tensor input2 = at::randn({13, 15}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1, input2}); + auto outputs = fe.runFusion({input1, input2}); + + auto ref = (input1).unsqueeze(-1) + input2; + + testValidate(&fusion, outputs, {input1, input2}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSegmenterCombineReductionsCycleRepro_CUDA) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto t0 = makeSymbolicTensor(3, DataType::Float); + auto t1 = makeSymbolicTensor(3, DataType::Half); + auto t3 = makeSymbolicTensor(3, DataType::Half); + auto t5 = makeSymbolicTensor(3, DataType::Half); + auto t7 = makeSymbolicTensor(1, DataType::Half); + auto t11 = makeSymbolicTensor(3, DataType::Half); + auto t13 = makeSymbolicTensor(3, DataType::Half); + auto t15 = makeSymbolicTensor(3, DataType::Half); + auto t17 = makeSymbolicTensor(3, DataType::Half); + auto d56 = IrBuilder::create(); + + fusion.addInput(t0); + fusion.addInput(t1); + fusion.addInput(t3); + fusion.addInput(t5); + fusion.addInput(t7); + fusion.addInput(t11); + fusion.addInput(t13); + fusion.addInput(t15); + fusion.addInput(t17); + fusion.addInput(d56); + + auto t2 = castOp(DataType::Float, t1); + auto t4 = castOp(DataType::Float, t3); + auto t22 = sub(t2, t4); + auto t6 = castOp(DataType::Float, t5); + auto t23 = mul(t22, t6); + auto t16 = castOp(DataType::Float, t15); + auto t18 = castOp(DataType::Float, t17); + auto t19 = add(t16, t18); + auto t14 = castOp(DataType::Float, t13); + auto t20 = add(t19, t14); + auto t12 = castOp(DataType::Float, t11); + auto t21 = add(t20, t12); + auto t8 = castOp(DataType::Float, t7); + auto t24 = broadcast(t8, {true, true, false}); + auto t25 = mul(t21, t24); + auto t27 = sum(t25, {2}); + auto t28 = broadcast(t27, {false, false, true}); + auto t29 = mul(t25, t23); + auto t30 = sum(t29, {2}); + auto t31 = broadcast(t30, {false, false, true}); + auto d59 = + mul(t1->getRootDomain()[2]->extent(), IrBuilder::create(1)); + auto t26 = mul(d59, t25); + auto txx = mul(t26, IrBuilder::create(1)); + auto t33 = sub(txx, t28); + auto d70 = unaryOp(UnaryOpType::Reciprocal, d59); + auto t35 = mul(d70, t6); + auto t39 = sum(t21, {0, 1}); + auto t47 = castOp(DataType::Half, t39); + auto t37 = mul(t21, t23); + auto t38 = sum(t37, {0, 1}); + auto t46 = castOp(DataType::Half, t38); + auto t32 = mul(t23, t31); + auto t34 = sub(t33, t32); + auto t36 = mul(t35, t34); + auto t45 = castOp(DataType::Half, t36); + auto t40 = mul(t36, t0); + auto t41 = mul(t40, d56); + auto t44 = castOp(DataType::Half, t41); + auto t42 = sum(t41, {0, 1}); + auto t43 = castOp(DataType::Half, t42); + + fusion.addOutput(t43); + fusion.addOutput(t44); + fusion.addOutput(t45); + fusion.addOutput(t46); + fusion.addOutput(t47); + + auto options_half = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + auto options_float = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_t0 = at::randn({128, 64, 1024}, options_float); + at::Tensor at_t1 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t3 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t5 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t7 = at::randn({1024}, options_half); + at::Tensor at_t11 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t13 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t15 = at::randn({128, 64, 1024}, options_half); + at::Tensor at_t17 = at::randn({128, 64, 1024}, options_half); + double at_d56 = 1.1111; + + std::vector aten_inputs = { + at_t0, at_t1, at_t3, at_t5, at_t7, at_t11, at_t13, at_t15, at_t17}; + + c10::IValue val = at_d56; + + KernelArgumentHolder args(KernelIndexMode::INT32); + args.setDeviceIndex(0); + args.push(aten_inputs); + args.push(val); + + for (auto _ : c10::irange(5)) { + auto segmented_fusion = + SegmentCandidateFinder::segment(fusion_ptr.get(), args); + } +} + +TEST_F(NVFuserTest, FusionSerialAndParallelIndexing_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + auto tv3 = add(tv0, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + fusion.addOutput(tv4); + + auto tv5 = add(tv0, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + fusion.addOutput(tv6); + + // Case 1: local memory tensor computed serially and used by + // parallel threads + tv2->split(-1, 4); + tv1->computeAt(tv2, -2); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + // Case 2: shared memory tensor computed serially and used by BID + tv4->split(-1, 4); + tv3->computeAt(tv4, -2); + tv4->axis(-1)->parallelize(ParallelType::BIDx); + tv3->setMemoryType(MemoryType::Shared); + + // Case 3: shared memory tensor computed by TID and used by BID + tv6->split(-1, 4); + tv5->computeAt(tv6, -2); + tv6->axis(-1)->parallelize(ParallelType::BIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + tv5->setMemoryType(MemoryType::Shared); + + const int nx = 11; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({nx}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref = t0 + 2; + + testValidate( + &fusion, outputs, aten_inputs, {ref, ref, ref}, __LINE__, __FILE__); +} + +// Repro of issue #1105 +TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv2->setMemoryType(MemoryType::Shared); + + tv3->split(0, 4); + tv0->computeAt(tv3, 1); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + tv3->axis(-1)->parallelize(ParallelType::TIDz); + + // Make sure a WAR sync is inserted at the end of the outer loop + GpuLower gpulw(&fusion); + for (const auto& kir_node : gpulw.kernel()->topLevelExprs()) { + if (auto loop = dynamic_cast(kir_node)) { + const auto& body = loop->body().exprs(); + TORCH_CHECK(!body.empty()); + auto last_expr = dynamic_cast(body.back()); + TORCH_CHECK(last_expr != nullptr, "Invalid expr found"); + TORCH_CHECK(last_expr->isWarHazardSync(), "Not a sync for WAR hazard"); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 3; + + testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1099_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + auto tv3 = makeSymbolicTensor(1); + fusion.addInput(tv3); + + // Just to make TIDx/y/z non-exact + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + fusion.addOutput(tv6); + + tv2->split(0, 4); + tv0->computeAt(tv2, 1); + + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv1->axis(-1)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDz); + tv2->axis(0)->parallelize(ParallelType::BIDx); + + tv1->setMemoryType(MemoryType::Shared); + + tv4->split(0, 5); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv4->setMemoryType(MemoryType::Shared); + tv5->split(0, 6); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + tv5->setMemoryType(MemoryType::Shared); + tv6->split(0, 7); + tv6->axis(-1)->parallelize(ParallelType::TIDz); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + at::Tensor t3 = at::randn({19}, options); + std::vector aten_inputs = {t0, t3}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref_t2 = t0 + 2; + auto ref_t3 = t3 + 3; + + testValidate( + &fusion, outputs, aten_inputs, {ref_t2, ref_t3}, __LINE__, __FILE__); +} + +// Repro of issue #1080 +TEST_F(NVFuserTest, FusionUnswitchPredicate_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv2->split(0, 4); + tv0->computeAt(tv2, 2); + + tv2->split(-1, 8); + tv1->split(-1, 8); + + tv2->axis(1)->parallelize(ParallelType::Unswitch); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-2)->parallelize(ParallelType::TIDy); + + // swap TIDx and TIDy + tv1->axis(-1)->parallelize(ParallelType::TIDy); + tv1->axis(-2)->parallelize(ParallelType::TIDx); + + tv1->setMemoryType(MemoryType::Shared); + + const int nx = 4; + const int ny = 10; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({nx, ny}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref = t0 + 2; + + testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1189_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({16, 16}); + auto tv1 = makeConcreteTensor({16, 16}); + + auto tv0b = broadcast(tv0, {false, false, true}); + auto tv1b = broadcast(tv1, {false, false, true}); + + fusion.addInput(tv0b); + fusion.addInput(tv1b); + + auto tv2 = add(tv0b, tv1b); + auto tv3 = sum(tv2, {1}); + fusion.addOutput(tv3); + + auto parallelize = [](auto tv) { + tv->axis(0)->parallelize(ParallelType::TIDx); + tv->axis(1)->parallelize(ParallelType::BIDx); + tv->axis(2)->parallelize(ParallelType::BIDy); + }; + + parallelize(tv0b); + parallelize(tv1b); + parallelize(tv2); + parallelize(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({16, 16, 1}, options); + at::Tensor t1 = at::randn({16, 16, 1}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto outputs = fe.runFusion({t0, t1}); + + auto ref = (t0 + t1).sum({1}); + + testValidate(&fusion, outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1052_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv2); + + auto tv3 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv3); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv2, {tv0}); + scheduler_utils::parallelizeAllLike(tv3, {tv1}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10}, options); + at::Tensor t1 = at::randn({100}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref_t2 = t0 + 1; + auto ref_t3 = t1 + 1; + + testValidate( + &fusion, outputs, aten_inputs, {ref_t2, ref_t3}, __LINE__, __FILE__); +} + +// Repro of issue #1115 +TEST_F(NVFuserTest, FusionPointwiseBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector input_shape{3, 17, 80}; + std::vector output_shape{3, 17, 1, 80}; + + TensorView* x = makeSymbolicTensor(input_shape.size()); + TensorView* bias = makeSymbolicTensor(input_shape.size()); + fusion.addInput(x); + fusion.addInput(bias); + + auto x_add_bias = add(x, bias); + auto x_bcast = broadcast(x_add_bias, {false, false, true, false}); + auto y = gelu(x_bcast); + fusion.addOutput(y); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_x = at::randn(input_shape, options); + at::Tensor at_bias = at::randn(input_shape, options); + std::vector aten_inputs = {at_x, at_bias}; + + schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto at_x_add_bias = at_x + at_bias; + auto at_x_view = at::native::view(at_x_add_bias, output_shape); + auto aten_y = at::gelu(at_x_view); + + testValidate(&fusion, outputs, aten_inputs, {aten_y}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPointwiseVectorize_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int size = 1024 * 64; + + TensorView* x = makeContigTensor(1); + fusion.addInput(x); + auto y = sin(x); + fusion.addOutput(y); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // PyTorch's CUDA caching allocator should always return aligned pointer for + // freshly allocated tensor + at::Tensor at_x = at::randn({size}, options); + + schedulePointwise(&fusion, {at_x}); + + for (auto x_consumer : ir_utils::consumerTvsOf(x)) { + bool found_vec_in_input = false; + for (auto id : x_consumer->domain()->domain()) { + if (isParallelTypeVectorize(id->getParallelType())) { + found_vec_in_input = true; + break; + } + } + TORCH_CHECK(found_vec_in_input, "Expect input to be vectorized"); + } + + for (auto id : y->domain()->domain()) { + if (isParallelTypeVectorize(id->getParallelType())) { + return; + } + } + TORCH_CHECK(false, "Expect output to be vectorized"); +} + +TEST_F(NVFuserTest, FusionSmemAliasSerial_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // Just set the dimension of TIDx + auto tv4 = makeSymbolicTensor(1); + fusion.addInput(tv4); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv1->setMemoryType(MemoryType::Shared); + tv2->setMemoryType(MemoryType::Shared); + + tv5->axis(0)->parallelize(ParallelType::TIDx); + + // tv1 and tv2 are on shared memory and are not parallelized with + // TIDx. They should be predicated as they are redundant and can + // interfere with smem aliasing (issue #1100). + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10}, options); + at::Tensor t4 = at::randn({1024}, options); + std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 3; + auto ref2 = t4 + 1; + + testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + auto tv3 = sum(tv2, {0}); + fusion.addOutput(tv3); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + at::Tensor t2 = at::randn({19}, options); + std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 1; + auto ref2 = sum(t2); + + testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + auto tv3 = Welford(tv2, {0}).avg; + fusion.addOutput(tv3); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + at::Tensor t2 = at::randn({19}, options); + std::vector aten_inputs = {t0, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 1; + auto ref2 = mean(t2, {0}); + + testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionGridReductionWithNonExactParallelDimensions2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 1}); + fusion.addOutput(tv1); + + auto tv2 = makeSymbolicTensor(3); + fusion.addInput(tv2); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + auto tv4 = makeSymbolicTensor(3); + fusion.addInput(tv4); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDz); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(2)->parallelize(ParallelType::BIDz); + + // TODO: This needs a fix for issue #1102. + // Also, need to allow predicated grid reductions. +#if 0 + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 3}, options); + at::Tensor t2 = at::randn({5, 6, 7}, options); + at::Tensor t4 = at::randn({8, 9, 10}, options); + std::vector aten_inputs = {t0, t2, t4}; + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0.sum(at::IntArrayRef{0, 1}); + auto ref2 = t2 + 1; + auto ref3 = t4 + 1; + + testValidate( + &fusion, outputs, aten_inputs, {ref1, ref2, ref3}, __LINE__, __FILE__); +#endif +} + +TEST_F(NVFuserTest, FusionGridWelfordWithNonExactParallelDimensions2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {0, 1}); + fusion.addOutput(tvs.avg); + + auto tv2 = makeSymbolicTensor(3); + fusion.addInput(tv2); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + auto tv4 = makeSymbolicTensor(3); + fusion.addInput(tv4); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tvs.avg->axis(0)->parallelize(ParallelType::BIDx); + tvs.avg->axis(1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDz); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(2)->parallelize(ParallelType::BIDz); + + // TODO: needs a fix for issue #1102 + // Also, need to allow predicated grid reductions. +#if 0 + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({2, 3}, options); + at::Tensor t2 = at::randn({5, 6, 7}, options); + at::Tensor t4 = at::randn({8, 9, 10}, options); + std::vector aten_inputs = {t0, t2, t4}; + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0.mean(at::IntArrayRef{0, 1}); + auto ref2 = t2 + 1; + auto ref3 = t4 + 1; + + testValidate( + &fusion, outputs, aten_inputs, {ref1, ref2, ref3}, __LINE__, __FILE__); +#endif +} + +// Repro of issue #1102 +TEST_F(NVFuserTest, FusionPredicateParallelizedDomains_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // Just to make TIDx/y/z non-exact + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + auto tv4 = makeSymbolicTensor(1); + fusion.addInput(tv4); + + auto tv5 = add(tv4, IrBuilder::create(1)); + auto tv6 = add(tv5, IrBuilder::create(1)); + auto tv7 = add(tv6, IrBuilder::create(1)); + auto tv8 = add(tv7, IrBuilder::create(1)); + auto tv9 = sum(tv8, {0}); + fusion.addOutput(tv9); + + tv1->split(0, 5); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv1->setMemoryType(MemoryType::Shared); + tv2->split(0, 6); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + tv2->setMemoryType(MemoryType::Shared); + tv3->split(0, 7); + tv3->axis(-1)->parallelize(ParallelType::TIDz); + + tv9->split(0, 4); + tv4->computeAt(tv9, 1); + + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + tv6->axis(-1)->parallelize(ParallelType::TIDz); + tv7->axis(-1)->parallelize(ParallelType::TIDz); + tv8->axis(-1)->parallelize(ParallelType::TIDz); + tv9->axis(-1)->parallelize(ParallelType::TIDz); + tv9->axis(0)->parallelize(ParallelType::BIDx); + + tv5->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + at::Tensor t4 = at::randn({19}, options); + std::vector aten_inputs = {t0, t4}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 3; + auto ref2 = sum(t4 + 4); + + testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); +} + +// Repro of #1102 and #1129 +TEST_F(NVFuserTest, FusionSmemPredicateUnswitch_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + auto tv4 = add(tv3, IrBuilder::create(1)); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + // Just to make TIDx/y/z non-exact + auto tvx = add(tv1, IrBuilder::create(1)); + auto tvy = add(tvx, IrBuilder::create(1)); + auto tvz = add(tvy, IrBuilder::create(1)); + fusion.addOutput(tvz); + + tv5->split(0, 4); + tv0->computeAt(tv5, 1); + + tv0->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + tv3->axis(-1)->parallelize(ParallelType::TIDz); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDy); + tv5->axis(0)->parallelize(ParallelType::Unswitch); + + tvx->split(0, 5); + tvx->axis(-1)->parallelize(ParallelType::TIDx); + tvy->split(0, 6); + tvy->axis(-1)->parallelize(ParallelType::TIDy); + tvz->split(0, 7); + tvz->axis(-1)->parallelize(ParallelType::TIDz); + + for (auto tv : {tv2, tv3, tv4, tvx, tvy}) { + tv->setMemoryType(MemoryType::Shared); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({17}, options); + at::Tensor t1 = at::randn({19}, options); + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref1 = t0 + 4; + auto ref2 = t1 + 3; + + testValidate(&fusion, outputs, aten_inputs, {ref1, ref2}, __LINE__, __FILE__); +} + +// Repro of issue #1136 +TEST_F(NVFuserTest, FusionFloatPow_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(4)); + // To check if pow(tv0, 2) is replaced with tv0 * tv0 + auto tv2 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); + // To check if pow(tv0, 2.0) is replaced with tv0 * tv0 + auto tv3 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(2)); + auto tv4 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto tv5 = binaryOp(BinaryOpType::Pow, tv0, IrBuilder::create(3)); + auto s = binaryOp( + BinaryOpType::Pow, + IrBuilder::create(3), + IrBuilder::create(3)); + auto tv6 = add(tv0, s); + + fusion.addOutput(tv1); + fusion.addOutput(tv2); + fusion.addOutput(tv3); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + tv1->split(0, 32); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + + TransformPropagatorWithCheck propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv1, {tv2, tv3, tv4, tv5, tv6}); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1000}, options); + // Negative inputs cause nan in Fuesr as use_fast_math is enabled + t0 = abs(t0); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto p4 = at::pow(t0, 4); + auto p2 = at::pow(t0, 2); + auto p3 = at::pow(t0, 3); + auto t6 = t0 + std::pow(3, 3); + + testValidate( + &fusion, + outputs, + aten_inputs, + {p4, p2, p2, p3, p3, t6}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1127_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + const int numel = 4; + + auto tv0 = makeConcreteTensor({numel}); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = broadcast(tv1, {true}); + + auto tv3 = makeConcreteTensor({numel, numel}); + fusion.addInput(tv3); + + auto tv4 = sum(tv3, {1}); + + auto tv5 = add(tv2, tv4); + fusion.addOutput(tv5); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(1)->parallelize(ParallelType::TIDx); + tv5->axis(0)->parallelize(ParallelType::TIDx); + + // Lowering should fail since tv5 is predicated and paralellized with TIDx. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fusion.printKernel()); +} + +TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) { + // This test may not pass if using a custom block sync as there may + // be additional calls. Skip the test as it's not specifically + // relevant with block synchronizatin. + if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) { + return; + } + auto g = std::make_shared(); + const auto graph0_string = R"IR( + graph(%0 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]), + %1 : Half(8, 4, 10, 16, strides=[640, 160, 16, 1])): + %o.1 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]) = aten::mul(%0, %1) # sum_dyn.py:5:6 + %3 : Half(8, 4, 10, 16, strides=[640, 1, 64, 4]) = aten::relu(%o.1) # sum_dyn.py:6:9 + return (%3))IR"; + parseIR(graph0_string, g.get()); + + // strides are not yet supported in the irparser. + { + auto val = g->block()->inputs()[0]; + val->setType(val->type()->castRaw()->withSizesStrides( + {8, 4, 10, 16}, {640, 1, 64, 4})); + } + + { + auto val = g->block()->inputs()[1]; + val->setType(val->type()->castRaw()->withSizesStrides( + {8, 4, 10, 16}, {640, 160, 16, 1})); + } + + for (auto node : g->block()->nodes()) { + for (auto val : node->outputs()) { + if (val->isCompleteTensor()) + val->setType(val->type()->castRaw()->withSizesStrides( + {8, 4, 10, 16}, {640, 1, 64, 4})); + } + } + + auto fusion = parseJitIR(g); + FusionGuard fg(fusion.get()); + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor input0 = + at::randn({2, 2, 2, 16}, options).clone(c10::MemoryFormat::ChannelsLast); + at::Tensor input1 = at::randn({2, 2, 2, 16}, options); + auto lparams = schedulePointwise(fusion.get(), {input0, input1}); + + // CONSIDER: + // 1. this can be moved to a dedicated "golden" file + // 2. use a fuzzy compare (ignore non-significant whitespaces for example) + const std::string expected_kernel = R"( +__global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) { + int64_t i165; + i165 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x); + if ((i165 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) { + __half T9[1]; + T9[0] = 0; + T9[0] + = T2[((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * ((T0.size[2] * T0.size[1]) * T0.size[3])) + ((((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * (T0.size[2] * T0.size[1])) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * T0.size[2]) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3])]; + __half T8[1]; + T8[0] = 0; + T8[0] + = T0[i165]; + float T3[1]; + T3[0] + = __half2float(T9[0]); + float T4[1]; + T4[0] + = T3[0]; + float T1[1]; + T1[0] + = __half2float(T8[0]); + float T5[1]; + T5[0] + = T1[0] + * T4[0]; + float T6[1]; + T6[0] + = relu(T5[0]); + __half T10[1]; + T10[0] + = __float2half(T6[0]); + T7[i165] + = T10[0]; + } +} +)"; + + const std::string actual_kernel = + "\n" + codegen::generateCudaKernel(GpuLower(fusion.get()).kernel()); + + if (expected_kernel.size() != actual_kernel.size() || + expected_kernel.compare(actual_kernel) != 0) { + std::cerr + << " Codegen mismatch, codegen possibly changed, or is incorrect. " + << " \n ========= EXPECTED ========= \n" + << expected_kernel << "\n========= ACTUAL ========== \n" + << actual_kernel << "\n=================" << std::endl; + auto it = std::mismatch( + expected_kernel.begin(), + expected_kernel.end(), + actual_kernel.begin(), + actual_kernel.end()); + std::string actual_mismatched_snippet(it.second, actual_kernel.end()); + actual_mismatched_snippet = actual_mismatched_snippet.substr(0, 10); + std::string expected_mismatched_snippet(it.first, expected_kernel.end()); + expected_mismatched_snippet = expected_mismatched_snippet.substr(0, 10); + std::cerr << "First mismatch found at: " << actual_mismatched_snippet + << ", expected: " << expected_mismatched_snippet << std::endl; + TORCH_CHECK(false); + } + + // TODO: runFusion hits assertion. I'm probably doing something wrong here. + // FusionExecutor fe; + // fe.compileFusion(fusion.get()); + // auto outputs = fe.runFusion({input0, input1}, lparams); + // at::Tensor output_ref = (input0 * input1).relu(); + // TORCH_CHECK(output_ref.equal(outputs[0])); +} + +TEST_F(NVFuserTest, FusionThreadPredicateUnswitch_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({10, 1024}); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->computeAt(tv3, -1); + tv3->axis(0)->parallelize(ParallelType::Unswitch); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 1024}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref = sum(t0, {1}) + 2; + + testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionNonContigOutputs_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + fusion.addOutput(tv1); + + tv1->setContiguity(false); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_input = at::randn({10}, options); + at::Tensor at_output = at::empty_strided({10}, {2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_input}); + auto returned_outputs = fe.runFusion({at_input}, {at_output}); + + // Returned outputs should only contain one tensor that is the same + // as the output tensor given to runFusion + TORCH_CHECK(returned_outputs.size() == 1); + TORCH_CHECK(returned_outputs[0].is_same(at_output)); + TORCH_CHECK(!returned_outputs[0].is_contiguous()); + + auto at_ref = at_input + 1; + + testValidate(&fusion, {at_output}, {at_input}, {at_ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTestWarpSoftMax_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Setup softmax fusion + auto input = makeContigTensor(2); + fusion.addInput(input); + auto output = softmax(input, 1); + fusion.addOutput(output); + + // Setup runtime input + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_input = at::randn({8, 16 * 197}, options); + std::vector aten_inputs({aten_input}); + + // Schedule through magic scheduler + SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs, true); + TORCH_CHECK(SchedulerEntry::canSchedule( + ScheduleHeuristic::Persistent, &fusion, runtime_info)); + auto scheduler = SchedulerEntry::makeEntry( + ScheduleHeuristic::Persistent, &fusion, runtime_info); + scheduler->schedule(&fusion); + + // Modify the schedule to use warp reduction + auto used_vals = fusion.usedMathVals(); + for (auto tv : ir_utils::filterByType(used_vals)) { + for (IterDomain* id : tv->domain()->domain()) { + if (id->getParallelType() == ParallelType::TIDx) { + id->padToMultipleOfWarp(); + } + } + } + + // Test result + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + auto ref_output = at::_softmax(aten_input, 1, false); + testValidate(&fusion, outputs, aten_inputs, {ref_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1133_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1}); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + tv0->computeAt(tv3, 1); + + const int split_factor = 32; + + tv2->split(-1, split_factor); + tv1->computeAt(tv2, -2); + + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::Unswitch); + + tv1->setMemoryType(MemoryType::Shared); + tv2->setMemoryType(MemoryType::Shared); + + // Both tv1 and tv2 should be allocated at the top-level scope + GpuLower gpulw(&fusion); + bool tv1_validated = false; + bool tv2_validated = false; + for (const auto& kir_node : gpulw.kernel()->topLevelExprs()) { + if (auto alloc = dynamic_cast(kir_node)) { + auto size = alloc->size(); + if (!(alloc->buffer()->name() == 1 || alloc->buffer()->name() == 2)) { + // There should be no allocation other than those for tv1 and tv2 + TORCH_CHECK(false, "Invalid allocation detected"); + } + TORCH_CHECK(size->isA(), "Invalid allocation size"); + TORCH_CHECK(size->as()->isConst(), "Allocation not constant"); + auto size_int = size->as()->value().value(); + if (alloc->buffer()->name() == 1) { + TORCH_CHECK( + size_int == split_factor, + "Invalid allocation size: ", + size->as()->value().value()); + tv1_validated = true; + } else { + TORCH_CHECK( + size_int == 1, + "Invalid allocation size: ", + size->as()->value().value()); + tv2_validated = true; + } + } + } + + TORCH_CHECK(tv1_validated, "Failed to validate tv1 allocation"); + TORCH_CHECK(tv2_validated, "Failed to validate tv2 allocation"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({99, 101}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref = (t0 + 1).sum({1}) + 1; + + testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionRfactorContigIDs_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + tv1->split(1, 32); + + auto tv2 = tv1->rFactor({1}); + + // This merged domain is not contiguous. + tv2->merge(0, 2); + + tv2->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({99, 101}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto ref = t0.sum({1}); + + testValidate(&fusion, outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPersistentBufferCalculation1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true}); + auto tv4 = set(tv1); + auto tv5 = add(tv3, tv4); + fusion.addOutput(tv5); + + auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); + + auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { + return std::find(vec.begin(), vec.end(), tv) != vec.end(); + }; + + auto tvEntryInVecVec = [](std::vector>& vec_o_vec, + std::vector& buffer_vec, + TensorView* tv) { + auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); + return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); + }; + + auto& buffers = persistent_buffer_info.persistent_buffers; + auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; + auto& projectable = persistent_buffer_info.projectable_persistent_buffers; + auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; + + TORCH_INTERNAL_ASSERT(buffers.size() == 1); + TORCH_INTERNAL_ASSERT(resolution.size() == 1 && resolution[0].size() == 1); + TORCH_INTERNAL_ASSERT(projectable.size() == 1); + TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); + + TORCH_INTERNAL_ASSERT(isTvWithinVec(buffers, tv1)); + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable, tv1)); + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); + + auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); + TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) + + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv5)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor aten_t0 = at::randn({99, 101}, options); + + // Schedule through magic scheduler + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + auto persistent_buffer_size = + persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); + + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.persistent_buffer_size == + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.projected_persistent_buffer_size == + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); +} + +TEST_F(NVFuserTest, FusionPersistentBufferCalculation2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv0); + + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true}); + auto tv4 = set(tv1); + auto tv5 = add(tv3, tv4); + auto tv6 = castOp(DataType::Half, tv5); + fusion.addOutput(tv6); + + auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); + + auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { + return std::find(vec.begin(), vec.end(), tv) != vec.end(); + }; + + auto tvEntryInVecVec = [](std::vector>& vec_o_vec, + std::vector& buffer_vec, + TensorView* tv) { + auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); + return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); + }; + + auto& buffers = persistent_buffer_info.persistent_buffers; + auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; + auto& projectable = persistent_buffer_info.projectable_persistent_buffers; + auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; + + TORCH_INTERNAL_ASSERT(buffers.size() == 1); + TORCH_INTERNAL_ASSERT(resolution.size() == 1 && resolution[0].size() == 1); + TORCH_INTERNAL_ASSERT(projectable.size() == 1); + TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); + + TORCH_INTERNAL_ASSERT(isTvWithinVec(buffers, tv1)); + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable, tv1)); + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); + + auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); + TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) + + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv5)); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor aten_t0 = at::randn({99, 101}, options); + + // Schedule through magic scheduler + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + auto persistent_buffer_size = + persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); + + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.persistent_buffer_size == + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Float))); + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.projected_persistent_buffer_size == + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Half))); +} + +TEST_F(NVFuserTest, FusionPersistentBufferCalculation3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv0); + + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = set(tv1); + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + + auto tv5 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv5); + + auto tv6 = castOp(DataType::Float, tv5); + + auto tv7 = add(tv6, tv4); + auto tv8 = set(tv1); + auto tv9 = add(tv7, tv8); + auto tv10 = sum(tv9, {1}); + auto tv11 = broadcast(tv10, {false, true}); + auto tv12 = set(tv7); + auto tv13 = add(tv12, tv11); + + fusion.addOutput(tv13); + + auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); + + auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { + return std::find(vec.begin(), vec.end(), tv) != vec.end(); + }; + + auto tvEntryInVecVec = [](std::vector>& vec_o_vec, + std::vector& buffer_vec, + TensorView* tv) { + auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); + return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); + }; + + auto& buffers = persistent_buffer_info.persistent_buffers; + auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; + auto& projectable = persistent_buffer_info.projectable_persistent_buffers; + auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; + + TORCH_INTERNAL_ASSERT(buffers.size() == 2); + TORCH_INTERNAL_ASSERT( + resolution.size() == 2 && resolution[0].size() == 1 && + resolution[1].size() == 1); + TORCH_INTERNAL_ASSERT(projectable.size() == 1); + TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); + + TORCH_INTERNAL_ASSERT( + isTvWithinVec(buffers, tv1) && isTvWithinVec(buffers, tv7)); + TORCH_INTERNAL_ASSERT( + isTvWithinVec(projectable, tv1) && !isTvWithinVec(projectable, tv7)); + + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); + + auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); + TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv9)); + + auto tv7_resolution_it = tvEntryInVecVec(resolution, buffers, tv7); + TORCH_INTERNAL_ASSERT(tv7_resolution_it != resolution.end()) + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv7_resolution_it, tv13)); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor aten_t0 = at::randn({99, 101}, options); + at::Tensor aten_t5 = at::randn({99, 101}, options); + + // Schedule through magic scheduler + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0, aten_t5}, true); + auto persistent_buffer_size = + persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); + + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.persistent_buffer_size == + static_cast( + aten_t0.size(1) * dataTypeSize(DataType::Float) * 2)); + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.projected_persistent_buffer_size == + static_cast( + aten_t0.size(1) * + (dataTypeSize(DataType::Half) + dataTypeSize(DataType::Float)))); +} + +TEST_F(NVFuserTest, FusionPersistentBufferCalculation4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv0); + + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = set(tv1); + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = set(tv1); + auto tv6 = add(tv4, tv5); + auto tv7 = set(tv2); + auto tv8 = add(tv7, tv6); + auto tv9 = castOp(DataType::Half, tv8); + + fusion.addOutput(tv9); + + auto persistent_buffer_info = scheduler_utils::persistentBuffers(&fusion); + + auto isTvWithinVec = [](std::vector& vec, TensorView* tv) { + return std::find(vec.begin(), vec.end(), tv) != vec.end(); + }; + + auto tvEntryInVecVec = [](std::vector>& vec_o_vec, + std::vector& buffer_vec, + TensorView* tv) { + auto buffer_it = std::find(buffer_vec.begin(), buffer_vec.end(), tv); + return vec_o_vec.begin() + std::distance(buffer_vec.begin(), buffer_it); + }; + + auto& buffers = persistent_buffer_info.persistent_buffers; + auto& resolution = persistent_buffer_info.persistent_buffer_resolution_points; + auto& projectable = persistent_buffer_info.projectable_persistent_buffers; + auto& projectable_inputs = persistent_buffer_info.projectable_buffer_inputs; + + TORCH_INTERNAL_ASSERT(buffers.size() == 2); + TORCH_INTERNAL_ASSERT( + resolution.size() == 2 && resolution[0].size() == 1 && + resolution[1].size() == 1); + + TORCH_INTERNAL_ASSERT(projectable.size() == 2); + TORCH_INTERNAL_ASSERT(projectable_inputs.size() == 1); + + TORCH_INTERNAL_ASSERT( + isTvWithinVec(buffers, tv1) && isTvWithinVec(buffers, tv2)); + TORCH_INTERNAL_ASSERT( + isTvWithinVec(projectable, tv1) && isTvWithinVec(projectable, tv2)); + + TORCH_INTERNAL_ASSERT(isTvWithinVec(projectable_inputs, tv0)); + + auto tv1_resolution_it = tvEntryInVecVec(resolution, buffers, tv1); + TORCH_INTERNAL_ASSERT(tv1_resolution_it != resolution.end()) + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv1_resolution_it, tv6)); + + auto tv2_resolution_it = tvEntryInVecVec(resolution, buffers, tv2); + TORCH_INTERNAL_ASSERT(tv2_resolution_it != resolution.end()) + TORCH_INTERNAL_ASSERT(isTvWithinVec(*tv2_resolution_it, tv8)); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor aten_t0 = at::randn({99, 101}, options); + + // Schedule through magic scheduler + SchedulerRuntimeInfo runtime_info(&fusion, {aten_t0}, true); + auto persistent_buffer_size = + persistentBufferSize(&fusion, runtime_info, persistent_buffer_info); + + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.persistent_buffer_size == + static_cast( + aten_t0.size(1) * dataTypeSize(DataType::Float) * 2)); + + TORCH_INTERNAL_ASSERT( + persistent_buffer_size.projected_persistent_buffer_size == + static_cast(aten_t0.size(1) * dataTypeSize(DataType::Half))); +} + +TEST_F(NVFuserTest, FusionPersistentBufferProjection_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv0); + + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = set(tv1); + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = set(tv1); + auto tv6 = add(tv4, tv5); + auto tv7 = set(tv2); + auto tv8 = add(tv7, tv6); + auto tv9 = castOp(DataType::Half, tv8); + + fusion.addOutput(tv9); + + reduction_scheduler_utils::projectPersistentBuffers(&fusion); + + auto tv5_producers = ir_utils::producerTvsOf(tv5); + auto tv7_producers = ir_utils::producerTvsOf(tv7); + + // Projection should have broken these dependencies + + TORCH_INTERNAL_ASSERT( + std::find(tv5_producers.begin(), tv5_producers.end(), tv1) == + tv5_producers.end()); + TORCH_INTERNAL_ASSERT( + std::find(tv7_producers.begin(), tv7_producers.end(), tv2) == + tv7_producers.end()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor aten_t0 = at::randn({99, 101}, options); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto cg_outputs = fec.runFusionWithInputs({aten_t0}); + + auto aten_t1 = aten_t0.to(c10::kDouble); + auto aten_t3 = aten_t1.sum({1}); + auto aten_t4 = aten_t3.unsqueeze(1); + auto aten_t7 = aten_t4.add(aten_t1).add(aten_t1); + + testValidate(&fusion, cg_outputs, {aten_t0}, {aten_t7}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1223_CUDA) { + if (!deviceMajorMinorCheck(7)) { + GTEST_SKIP() << "skipping tests on pre-Volta GPUs"; + return; + } + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {0, 1}); + fusion.addOutput(tv2); + + auto tv3 = add(tv0, IrBuilder::create(0)); + fusion.addOutput(tv3); + + tv2->split(0, 4); + tv2->split(1, 1, false); + tv2->split(-1, 4); + + tv2->axis(1)->parallelize(ParallelType::Unswitch); + tv2->axis(-3)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDy); + + tv1->computeAt(tv2, -1); + + // Make TIDx and TIDy non-exact + tv3->split(0, 32); + tv3->split(-1, 32); + tv3->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(3)->parallelize(ParallelType::TIDy); + + // The second axis of both tv1 and tv2 are fully unswitched, so they + // don't need to predicate the parallel type usage of TIDy, whereas + // the first axis is only partially unswitched, i.e., part of its + // split output domains is outside the unswitched axis, so the first + // axis, which uses TIDx, needs to predicate the parallel + // dimension. Previously, as reported in issue #1223, unswitched + // expressions didn't predicate parallel dimensions. It should be + // fixed by PR #1222. + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_t0 = at::ones({11, 10}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0}); + auto cg_outputs = fe.runFusion({at_t0}); + + auto at_t1 = (at_t0 + 1).sum(); + + testValidate( + &fusion, cg_outputs, {at_t0}, {at_t1, at_t0}, __LINE__, __FILE__); +} + +// See #1247 and #1250 +TEST_F(NVFuserTest, FusionRfactorPredication1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = min(tv1, {0}); + + fusion.addOutput(tv2); + + // Make TIDx non-exact + auto tv3 = makeContigTensor(1); + fusion.addInput(tv3); + + auto tv4 = add(tv3, IrBuilder::create(1)); + fusion.addOutput(tv4); + + tv2->split(0, 4); + auto tv5 = tv2->rFactor({1}); + + tv0->computeAt(tv2, 1); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + + tv4->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_t0 = at::randn({9}, options); + at_t0 = at::abs(at_t0); + at::Tensor at_t3 = at::randn({128}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t3}); + auto cg_outputs = fe.runFusion({at_t0, at_t3}); + + auto at_t2 = (at_t0 + 1).min(); + auto at_t4 = at_t3 + 1; + + testValidate( + &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionRfactorPredication2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = min(tv0, {0}); + fusion.addOutput(tv1); + + // Make TIDx non-exact + auto tv2 = makeContigTensor(1); + fusion.addInput(tv2); + + auto tv3 = add(tv2, IrBuilder::create(1)); + fusion.addOutput(tv3); + + tv1->split(0, 4); + auto tv4 = tv1->rFactor({0}); + + tv1->split(0, 3); + + // tv0->computeAt(tv1, 3); + tv4->reorder({{0, 1}}); + tv4->split(0, 3); + tv4->setMemoryType(MemoryType::Shared); + + // tv0: [I] + // tv4: [4/3, 3, I/4] + // tv1: [4/3, 3] + + tv1->axis(0)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv1, {tv4}); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor at_t0 = at::randn({9}, options); + at_t0 = at::abs(at_t0); + at::Tensor at_t3 = at::randn({128}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_t0, at_t3}); + auto cg_outputs = fe.runFusion({at_t0, at_t3}); + + auto at_t2 = std::get<0>(at_t0.min(0)); + auto at_t4 = at_t3 + 1; + + testValidate( + &fusion, cg_outputs, {at_t0, at_t3}, {at_t2, at_t4}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionRfactorIndirectRoot_CUDA) { + // https://github.com/csarofeen/pytorch/issues/1692 + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1, 2}); + fusion.addOutput(tv1); + + tv1->split(2, 4); + tv1->split(1, 3); + tv1->merge(2, 3); + auto rf = tv1->rFactor({-1}); + + tv1->split(0, 256); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(1)->parallelize(ParallelType::TIDx); + rf->computeAt(tv1, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto at_in = at::randn({6, 6, 6}, options); + auto at_out = at_in.sum({1, 2}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {at_in}); + auto cg_outputs = fe.runFusion({at_in}); + + testValidate(&fusion, cg_outputs, {at_in}, {at_out}, __LINE__, __FILE__); +} + +} // namespace jit +} // namespace torch +#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp new file mode 100644 index 0000000000000..8d24cc3803747 --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp @@ -0,0 +1,6538 @@ +#if defined(USE_CUDA) +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; +using namespace at::indexing; + +TEST_F(NVFuserTest, FusionNonDivisibleSplit1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + // [I] + tv1->split(0, 5); + // [ceilDiv(I, 5), 5] + + // This second split is non-divisible. The split domain must be predicated. + tv1->split(1, 3); + // [ceilDiv(I, 5), 2, 3] + + auto tv2 = sum(tv0, {0}); + fusion.addOutput(tv2); + + // tv2 shouldn't need to have another predicate + tv2->split(0, 4); + tv2->split(1, 2); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), + "There must be no split to validate"); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, + "Only tv1 should have a non-divisible predicate."); + for (auto tv : {loweredTv(tv1, gpulw)}) { + auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); + TORCH_CHECK( + it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), + "No info found for ", + tv); + const auto& splits_to_predicate = it->second; + TORCH_CHECK( + splits_to_predicate.size() == 1, + "There must be one split to predicate"); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({24}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0.sum(); + + testValidate(&fusion, cg_outputs, {t0}, {ref, ref}, __LINE__, __FILE__); +} + +// Repro of issue #1074 +TEST_F(NVFuserTest, FusionNonDivisibleSplit2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + tv2->split(0, 2); + tv2->split(-1, 4); + tv2->reorder({{1, 2}, {2, 1}}); + tv0->computeAt(tv2, 2); + + tv2->split(-1, 3); + + // To make the sanitizer catch the invalid accesses. Not necessary + // to expose the bug. + tv1->setMemoryType(MemoryType::Shared); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), + "There must be no split to validate"); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 1, + "Only tv2 should have a non-divisible predicate."); + for (auto tv : {loweredTv(tv2, gpulw)}) { + auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); + TORCH_CHECK( + it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), + "No info found for ", + tv); + const auto& splits_to_predicate = it->second; + TORCH_CHECK( + splits_to_predicate.size() == 1, + "There must be one split to predicate"); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({13, 17}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Similar to FusionNonDivisibleSplit1 but with unswitch +TEST_F(NVFuserTest, FusionNonDivisibleSplit3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + tv2->split(0, 5); + tv2->split(1, 3); + + tv0->computeAt(tv2, -1); + + tv2->axis(0)->parallelize(ParallelType::Unswitch); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), + "There must be no split to validate"); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, + "Both tv1 and tv2 should have a non-divisible predicate."); + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { + auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); + TORCH_CHECK( + it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), + "No info found for ", + tv); + const auto& splits_to_predicate = it->second; + TORCH_CHECK( + splits_to_predicate.size() == 1, + "There must be one split to predicate"); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({24}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = (t0 + 1).sum(); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Non-divisible split through merge +TEST_F(NVFuserTest, FusionNonDivisibleSplit4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {0, 1}); + fusion.addOutput(tv2); + + tv2->split(0, 5); + tv2->merge(1, 2); + tv2->split(1, 3); + + tv0->computeAt(tv2, -1); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), + "There must be no split to validate"); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, + "Both tv1 and tv2 should have a non-divisible predicate."); + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { + auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); + TORCH_CHECK( + it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), + "No info found for ", + tv); + const auto& splits_to_predicate = it->second; + TORCH_CHECK( + splits_to_predicate.size() == 1, + "There must be one split to predicate"); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({24, 2}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = (t0 + 1).sum(); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Nested splits +TEST_F(NVFuserTest, FusionNonDivisibleSplit5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + // [I] + tv2->split(0, 8); + // [I/8, 8] + tv2->split(1, 2); + // [I/8, 4, 2] + tv2->split(1, 3); // non-divisible split of outer output + // [I/8, 2, 3, 2] + + tv0->computeAt(tv2, -1); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().empty(), + "There must be no split to validate"); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToPredicate().size() == 2, + "Both tv1 and tv2 should have a non-divisible predicate."); + for (auto tv : {loweredTv(tv1, gpulw), loweredTv(tv2, gpulw)}) { + auto it = gpulw.nonDivisibleSplitInfo().splitsToPredicate().find(tv); + TORCH_CHECK( + it != gpulw.nonDivisibleSplitInfo().splitsToPredicate().end(), + "No info found for ", + tv); + const auto& splits_to_predicate = it->second; + TORCH_CHECK( + splits_to_predicate.size() == 1, + "There must be one split to predicate"); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({24}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = (t0 + 1).sum(); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Vectorized non-divisible split. Must be validated at run time +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + tv1->split(0, 8, false); + tv1->split(1, 4); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().size() == 1, + "There should be one split to validate"); + for (const auto& kv : gpulw.nonDivisibleSplitInfo().splitsToPredicate()) { + const auto& splits_to_predicate = kv.second; + TORCH_CHECK( + splits_to_predicate.empty(), + "There must be no split to predicate, but tensor t", + kv.first->name(), + " has:", + splits_to_predicate); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); + + auto t0_non_divisible = at::randn({8}, options); + // Since ceilDiv(8, 8) is not divisible by 4, the vectorization is + // illegal. The run-time validation of vectorization should throw an error. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0_non_divisible})); +} + +// If a split is validated at run time, it's not necessary to predicate. +TEST_F(NVFuserTest, FusionNonDivisibleSplitVectorize2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = sum(tv2, {0}); + fusion.addOutput(tv3); + + tv3->split(0, 8, false); + tv3->split(1, 4); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv3->axis(1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3, {tv1, tv2}); + + tv1->axis(2)->parallelize(ParallelType::Vectorize); + + GpuLower gpulw(&fusion); + TORCH_CHECK( + gpulw.nonDivisibleSplitInfo().splitsToValidate().size() == 1, + "There should be one split to validate"); + for (const auto& kv : gpulw.nonDivisibleSplitInfo().splitsToPredicate()) { + const auto& splits_to_predicate = kv.second; + TORCH_CHECK( + splits_to_predicate.empty(), + "There must be no split to predicate, but tensor t", + kv.first->name(), + " has:", + splits_to_predicate); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + auto t0 = at::randn({1024}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = (t0 + 1).sum(); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1284Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {10, 20}; + std::vector input_shape_1 = {15}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + fusion.addInput(in_0); + fusion.addInput(in_1); + + TensorView* out_0 = add(in_0, IrBuilder::create(0.f)); + TensorView* out_1 = add(in_1, IrBuilder::create(2.f)); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + std::vector aten_inputs = {at_in_0, at_in_1}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t1 = at_in_1 + 2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, outputs, {at_in_0, at_in_1}, {at_in_0, t1}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1284Repro2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + std::vector input_shape_0 = {4, 4}; + std::vector input_shape_1 = {3, 4, 4}; + std::vector input_shape_2 = {2, 8, 4, 4}; + + TensorView* in_0 = makeSymbolicTensor(input_shape_0.size()); + TensorView* in_1 = makeSymbolicTensor(input_shape_1.size()); + TensorView* in_2 = makeSymbolicTensor(input_shape_2.size()); + + fusion.addInput(in_0); + fusion.addInput(in_1); + fusion.addInput(in_2); + + TensorView* out_0 = add(in_0, in_1); + TensorView* out_1 = add(in_0, in_2); + + fusion.addOutput(out_0); + fusion.addOutput(out_1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at_in_0 = at::randn(input_shape_0, options); + at::Tensor at_in_1 = at::randn(input_shape_1, options); + at::Tensor at_in_2 = at::randn(input_shape_2, options); + + std::vector aten_inputs = {at_in_0, at_in_1, at_in_2}; + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto outputs = fec.runFusionWithInputs(aten_inputs); + + auto t0 = at_in_0 + at_in_1; + auto t1 = at_in_0 + at_in_2; + + auto runtime = fec.getMostRecentKernelRuntime(); + TORCH_INTERNAL_ASSERT(runtime->isSegmented()); + TORCH_INTERNAL_ASSERT(runtime->fusionSegments()->groups().size() == 2); + + testValidate( + &fusion, + outputs, + {at_in_0, at_in_1, at_in_2}, + {t0, t1}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionIssue1305Repro_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto t0 = makeContigTensor(1); + auto t1 = makeContigTensor(2); + + fusion.addInput(t0); + fusion.addInput(t1); + + auto t2 = broadcast(t0, {true, false}); + auto t3 = add(t1, t2); + auto t4 = add(t3, t2); + auto t5 = sum(t4, {1}); + auto t6 = broadcast(t5, {false, true}); + auto t7 = add(t3, t6); + + fusion.addOutput(t7); + + t3->computeAt(t7, -1, ComputeAtMode::MostInlined); + + TORCH_INTERNAL_ASSERT(t3->getComputeAtPosition() == 1); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, 1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, -1); + + tv3->axis(-2)->parallelize(ParallelType::BIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, 1); + + // tv2 is invalid to double-buffer as its producer, tv1, is + // computed inside the double-buffering loop. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(tv2->doubleBuffer()); + + // Moving tv2 inner makes tv1 large enough to double-buffer tv2 + tv2->computeAt(tv3, 2); + + tv2->doubleBuffer(); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 32); + tv3->split(-1, 8); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, 2); + tv2->computeAt(tv3, -1); + + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv3); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering gmem to shared and unswitch +TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + + tv2->split(-1, 128); + tv2->split(-1, 32); + tv2->split(-1, 8); + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, -1); + + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::Unswitch); + scheduler_utils::parallelizeAllLike(tv2); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1000}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering smem to local and unroll +TEST_F(NVFuserTest, FusionDoubleBuffering6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = set(tv1); + auto tv3 = add(tv2, IrBuilder::create(1.0)); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + + tv3->split(-1, 128); + tv3->split(-1, 16); + tv3->split(-2, 4); + tv3->split(-2, 2); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv0->computeAt(tv3, 1); + tv2->computeAt(tv3, -1); + + tv3->axis(2)->parallelize(ParallelType::Unroll); + tv3->axis(4)->parallelize(ParallelType::TIDx); + + tv2->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({199}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Double buffering and vectorize +TEST_F(NVFuserTest, FusionDoubleBuffering7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = add(tv1, IrBuilder::create(1.0)); + fusion.addOutput(tv2); + + tv2->split(-1, 128); + tv2->split(-1, 4); + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv1->computeAt(tv2, 2); + + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + tv1->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({200}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Multiple tensors to double-buffer +TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(1); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv4->split(0, 32); + tv4->split(0, 4); + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv0->computeAt(tv4, 1); + tv1->computeAt(tv4, 1); + + tv4->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv4); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({100}, options); + auto t1 = at::randn({100}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Nested double buffering from gmem to smem and smem to register +TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1)); + auto out = tv1; + fusion.addOutput(out); + + auto tv2 = tv0->cacheAfter(); + auto tv3 = tv2->cacheAfter(); + + out->split(0, 32); + out->split(0, 4); + TransformPropagatorWithCheck propagator(out); + MaxRootDomainInfoSpanningTree(out).traverse(&propagator); + + tv2->setMemoryType(MemoryType::Shared); + + tv2->computeAt(out, 1); + tv3->computeAt(out, -1); + + out->axis(-1)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(out); + + tv2->doubleBuffer(); + tv3->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1001}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0 + 1; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// FusionSmemBlockGemmCache + double buffering at both smem and local +TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Algorithm + TensorView* tv0 = makeSymbolicTensor(2); // (M, K) + TensorView* tv1 = makeSymbolicTensor(2); // (K, N) + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addOutput(tv5); + + TensorView* tv6 = tv5->cacheBefore(); + + // For smem double buffering + auto tv0_cache_local = tv0->cacheAfter(); + auto tv1_cache_local = tv1->cacheAfter(); + + // For register double buffering + auto tv0_cache_smem = tv0->cacheAfter(); + auto tv1_cache_smem = tv1->cacheAfter(); + + const int BSX = 32; + const int TSX = 8; + + // [M, K, N] + tv6->split(-1, BSX); + tv6->split(-1, TSX); + tv6->split(1, BSX); + tv6->split(0, BSX); + tv6->split(1, TSX); + // [M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->reorder( + {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); + // [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + + auto tv6_rf = tv6->rFactor({-1}); + + TransformPropagatorWithCheck propagator(tv6_rf); + MaxRootDomainInfoSpanningTree(tv6_rf).traverse(&propagator); + + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + + tv6_rf->computeAt(tv6, -1); + tv0_cache_local->computeAt(tv6_rf, -1); + tv1_cache_local->computeAt(tv6_rf, -1); + + tv0_cache_smem->setMemoryType(MemoryType::Shared); + tv1_cache_smem->setMemoryType(MemoryType::Shared); + + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-3)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv5); + + tv0_cache_local->doubleBuffer(); + tv1_cache_local->doubleBuffer(); + + tv0_cache_smem->doubleBuffer(); + tv1_cache_smem->doubleBuffer(); + + constexpr int M = 154, K = 45, N = 1524; + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({M, K}, options); + at::Tensor t1 = at::randn({K, N}, options); + at::Tensor aten_output = matmul(t0.to(at::kDouble), t1.to(at::kDouble)); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); + // The smem cache write in this test case is redundant predicated, + // and also double buffered. Currently we are relying on WAR sync + // insertion to ensure ordering of double buffered tensor access. + // The check below makes sure that the sync is inserted so that the + // test isn't running on a race condition. + TORCH_CHECK(fe.kernel()->summary().war_hazard_syncs_count > 0); +} + +TEST_F(NVFuserTest, FusionIntermediateTensorVectorize_CUDA) { + std::vector mem_types = {MemoryType::Shared, MemoryType::Local}; + + for (auto mem_type : mem_types) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(mem_type); + + tv3->split(-1, 4); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv1->computeAt(tv3, -2); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({15}, options); + FusionExecutor fe; + fe.compileFusion(&fusion); + + // This should throw an exception as the extent of t0 is not + // divisible by the vector width + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0})); + + auto t1 = at::randn({16}, options); + auto cg_outputs = fe.runFusion({t1}); + + auto ref = t1; + + testValidate(&fusion, cg_outputs, {t1}, {ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({10, 1}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({10, 20}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({10, 10}); + fusion.addInput(tv2); + + // Not concretized + auto tv3 = sum(tv2, {1}); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = add(tv0, tv4); + fusion.addOutput(tv5); + + // Concretized + auto tv6 = sum(tv2, {1}); + auto tv7 = broadcast(tv6, {false, true}); + auto tv8 = add(tv1, tv7); + fusion.addOutput(tv8); + + for (auto tv : {tv3, tv4, tv5, tv6, tv7, tv8}) { + tv->axis(1)->parallelize(ParallelType::TIDx); + } + + GpuLower gpulw(&fusion); + TORCH_CHECK(!gpulw.concretizedBroadcastDomains()->isConcretized( + loweredTv(tv4, gpulw)->axis(1))); + TORCH_CHECK(gpulw.concretizedBroadcastDomains()->isConcretized( + loweredTv(tv7, gpulw)->axis(1))); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 1}, options); + auto t1 = at::randn({10, 20}, options); + auto t2 = at::randn({10, 10}, options); + std::vector aten_inputs = {t0, t1, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = t0 + t2.sum({1}).unsqueeze(-1); + auto t8 = t1 + t2.sum({1}).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t5, t8}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 1}); + auto tv2 = broadcast(tv1, {true}); + auto tv3 = broadcast(tv2, {false, true}); + fusion.addOutput(tv3); + + // tv1 is thread-predicated with TIDx and TIDy + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + // tv2 broadcasts along TIDx + tv2->axis(0)->parallelize(ParallelType::TIDx); + // tv3 broadcasts along TIDy + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDy); + + // Both tv2 and tv3 broadcast along predicated TID dimensions, but + // since the broadcast domains are not concretized, there should be + // no actual parallel broadcast + + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({10, 11}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t3 = t0.sum().unsqueeze(-1).unsqueeze(-1); + + testValidate(&fusion, outputs, aten_inputs, {t3}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionBroadcastConcretization3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector input_shape({10, 4, 8}); + std::vector output_shape({8, 4, 1}); + + auto tv0 = makeConcreteTensor(input_shape); + fusion.addInput(tv0); + + auto tv2 = sum(tv0, {0}); + auto tv3 = set(tv2); + auto tv4 = + view(tv3, {input_shape.begin() + 1, input_shape.end()}, output_shape); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + // The view op adds a broadcast domain in tv4, which is + // parallelized. Howver, it is never materialized, so there should + // be no parallel broadcast. + + GpuLower gpulw(&fusion); + TORCH_CHECK( + !gpulw.kernel()->summary().has_block_broadcasts && + !gpulw.kernel()->summary().has_grid_broadcasts, + "There must be no parallel broadcast in this fusion"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn(input_shape, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + auto t5 = at::native::view(t0.sum(0), output_shape) + 1; + + testValidate(&fusion, outputs, aten_inputs, {t5}, __LINE__, __FILE__); +} + +// Merging non-broadcast and broadcast domains +// TODO: Fix use case see issue https://github.com/csarofeen/pytorch/issues/1418 +// validateParallelize does not pass. Even if it's skipped, +// generated code is invalid as blockBroadcast is not used. +#if 0 +TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = broadcast(tv1, {false, true}); + auto tv3 = add(tv2, tv0); + fusion.addOutput(tv3); + + tv1->axis(1)->parallelize(ParallelType::TIDx); + + tv2->merge(0, 1); + tv2->axis(0)->parallelize(ParallelType::TIDx); + // TODO: When set to shared memory, this kernel should be correct, but fails + // validation and when skipped produces incorrect code + tv2->setMemoryType(MemoryType::Shared); + + tv3->merge(0, 1); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + fusion.printMath(); + fusion.printKernel(); +} +#endif + +TEST_F(NVFuserTest, FusionBroadcastConcretization5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(1); + fusion.addInput(tv2); + auto tv3 = makeSymbolicTensor(1); + fusion.addInput(tv3); + + // Assert tv2 and tv3 have the same shape + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + // Concretize a broadcast domain to multiple non-concrete domains + // through a multi-output expression. It should be considered to be + // non-uniquely concretized. + auto tv5 = broadcast(tv0, {false, true}); + // Reduce only the non-broadcast domain. + auto tvs = Welford(tv5, {0}); + auto tv9 = add(tvs.avg, tv1); + auto tv10 = add(tvs.var_sum, tv2); + fusion.addOutput(tv9); + fusion.addOutput(tv10); + + // Same pattern as the above, but concretize the broadcast domain + // with tv2 and tv3, which have the exactly same shape, so the + // broadcast should be considered uniquely concretized. + auto tv11 = broadcast(tv0, {false, true}); + // Reduce only the non-broadcast domain. + auto tvs2 = Welford(tv11, {0}); + auto tv15 = add(tvs2.avg, tv2); + auto tv16 = add(tvs2.var_sum, tv3); + fusion.addOutput(tv15); + fusion.addOutput(tv16); + + // Reduce only the broadcast domain. Since it's reduced, it should + // not be considered to be concretized. + auto tv17 = broadcast(tv0, {false, true}); + auto tvs3 = Welford(tv17, {1}); + fusion.addOutput(tvs3.avg); + + ConcretizedBroadcastDomains bcast_concretization_info(&fusion); + + TORCH_CHECK( + bcast_concretization_info.maybeNonUniquelyConcretized(tv5->axis(1)), + "Failed to detect non-unique concretization of ", + tv5->toString()); + + TORCH_CHECK( + bcast_concretization_info.isUniquelyConcretized(tv11->axis(1)), + "Failed to detect unique concretization of ", + tv11->toString()); + + TORCH_CHECK( + !bcast_concretization_info.isConcretized(tv17->axis(1)), + "Failed to detect non-concretization of ", + tv17->toString()); +} + +TEST_F(NVFuserTest, FusionIssue1430_CUDA) { + // Derived from an expression sorting issue when using loop map, now expr + // sorting uses parallel map. + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int V = 2, W = 3, X = 4, Y = 5, Z = 6; + + // setup fusion + auto tv0 = TensorViewBuilder() + .ndims(5) + .dtype(DataType::Half) + .contiguity(std::vector(5, true)) + .shape({V, W, X, Y, Z}) + .build(); + + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = castOp(DataType::Float, tv1); + + auto tvs = Welford(tv2, {1, 2, 3, 4}); + auto tv3 = tvs.avg; + auto tv4 = tvs.var_sum; + auto tv5 = tvs.n; + + // avg + auto tv6 = broadcast(tvs.avg, {false, true, true, true, true}); + + // var + auto tv7 = mul(tv4, IrBuilder::create(1. / (W * X * Y * Z))); + auto tv8 = add(tv7, IrBuilder::create(1.e-6)); + auto tv9 = broadcast(tv8, {false, true, true, true, true}); + auto tv10 = rsqrt(tv9); + + auto tv11 = castOp(DataType::Float, tv1); + auto tv12 = sub(tv11, tv6); + auto tv13 = mul(tv12, tv10); + + auto tv14 = set(tv13); + fusion.addOutput(tv14); + + tv3->axis(0)->parallelize(ParallelType::BIDy); + tv3->axis(2)->parallelize(ParallelType::BIDx); + tv3->axis(3)->parallelize(ParallelType::TIDx); + tv3->axis(4)->parallelize(ParallelType::Vectorize); + + // tv3->reorder({{1, -2}}); + + auto rfactor = ir_utils::rfactorHelper(tv3, {1, 4}); + + scheduler_utils::parallelizeAllLike(rfactor); + + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv != tv1 || tv != tv3) { + for (auto i : c10::irange(tv->nDims())) { + if (isParallelTypeVectorize(tv->axis(i)->getParallelType())) { + tv->axis(i)->parallelize(ParallelType::Serial); + } + } + } + } + + tv0->computeAt(tv14, 1); + tv13->computeAt(tv14, -2); + tv2->computeAt(tv14, -1, ComputeAtMode::MostInlined); + tv11->computeAt(tv14, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({V, W, X, Y, Z}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion); + auto cg_outputs = fe.runFusion({t0}, LaunchParams(X, V, -1, Y, -1, -1)); + + auto t0_double = t0.to(at::kDouble); + + auto at_mu = at::mean(t0_double, {1, 2, 3, 4}) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1); + auto at_var = at::var(t0_double, {1, 2, 3, 4}, false) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1); + + auto at_out = t0_double.sub(at_mu).div(at_var.add(1.e-6).sqrt()); + + testValidate( + &fusion, + cg_outputs, + {t0}, + {at_out}, + __LINE__, + __FILE__, + "", + LaunchParams(X, V, -1, Y, -1, -1)); +} + +// Test code generation of allocated scalars +TEST_F(NVFuserTest, FusionCodegenAllocatedScalars_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Fusion is just a dummy container in this test, just used for + // getting a Kernel container + auto tv0 = makeSymbolicTensor(0); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + GpuLower gpulw(&fusion); + auto kernel = gpulw.kernel(); + + // Set the kernel as the current fusion + FusionGuard kg(kernel); + + // Create alocated scalars + auto ks0 = add(kernel->zeroVal(), kernel->oneVal()); + auto ks0_alloc = IrBuilder::create( + ks0, MemoryType::Local, kernel->oneVal()); + + auto ks1 = add(ks0, kernel->oneVal()); + auto ks1_alloc = IrBuilder::create( + ks1, MemoryType::Local, kernel->oneVal()); + + auto tk0 = kernel->inputs()[0]->as(); + auto tki0 = IrBuilder::create(tk0, std::vector{ks0}); + auto tki1 = IrBuilder::create(tk0, std::vector{ks1}); + auto tk0_expr = IrBuilder::create(UnaryOpType::Set, tki0, tki1); + + // Insert the scalar expression and the allocation of the + // output directly to the kernel + auto proxy = kir::KernelInternalProxy(kernel); + + const auto indent = " "; + const auto ks0_name = "i" + std::to_string(ks0->name()); + const auto ks1_name = "i" + std::to_string(ks1->name()); + const auto tk0_name = "T" + std::to_string(tk0->name()); + + auto& exprs = proxy.topLevelExprs(); + exprs.push_back(tk0_expr); + + // Invalid code gen + const auto no_alloc_code = codegen::generateCudaKernel(kernel); + + // Without alloc, Int vals are just inlined, resulting in: + // t0[(0 + 1)] = t0[((0 + 1) + 1)] + std::stringstream no_alloc_ref; + no_alloc_ref << "\n" + << indent << tk0_name << "[(0 + 1)]\n" + << indent << indent << " = " << tk0_name << "[((0 + 1) + 1)];\n"; + + TORCH_CHECK( + no_alloc_code.find(no_alloc_ref.str()) != std::string::npos, + "Invalid code generation. Expected:", + no_alloc_ref.str(), + "Actual:\n", + no_alloc_code); + + // Insert proper allocations and definitions + exprs.insert(std::find(exprs.begin(), exprs.end(), tk0_expr), ks0_alloc); + exprs.insert( + std::find(exprs.begin(), exprs.end(), tk0_expr), ks0->definition()); + exprs.insert(std::find(exprs.begin(), exprs.end(), tk0_expr), ks1_alloc); + exprs.insert( + std::find(exprs.begin(), exprs.end(), tk0_expr), ks1->definition()); + + const auto valid_code = codegen::generateCudaKernel(kernel); + + std::stringstream valid_ref; + valid_ref << "\n" + << indent << tk0_name << "[" << ks0_name << "]\n" + << indent << indent << " = " << tk0_name << "[" << ks1_name + << "];\n"; + + TORCH_CHECK( + valid_code.find(valid_ref.str()) != std::string::npos, + "Invalid code generation. Expected:", + valid_ref.str(), + "Actual:\n", + valid_code); +} + +TEST_F(NVFuserTest, FusionIndexHoist1_CUDA) { + if (isOptionDisabled(DisableOption::IndexHoist)) { + GTEST_SKIP() << "Index hoisting disabled"; + } + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + auto tv4 = set(tv3); + auto tv5 = set(tv4); + fusion.addOutput(tv5); + + tv1->split(-1, 4); + tv2->split(-1, 4); + tv3->merge(0, 1); + tv3->split(0, 8); + tv5->merge(0, 1); + tv5->split(0, 8); + tv4->computeAt(tv5, -1); + + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + + // Use Int32 as the index type to verify Int32 is used as the type + // of hoisted indices + GpuLower gpulw(&fusion, DataType::Int32); + auto kernel = gpulw.kernel(); + + auto is_index_times_ns = [](Val* val, Val* index, std::string name) -> bool { + auto def = dynamic_cast(val->definition()); + if (def == nullptr) { + return false; + } + return def->getBinaryOpType() == BinaryOpType::Mul && + def->rhs()->isA() && + def->rhs()->as()->name() == name && def->lhs() == index; + }; + + // Validate indices in the kernel are hoisted as + // intended. Validation could be also done by just string comparison + // as the parser test, but updating such tests would be tedious. + for (auto top_level_loop : + ir_utils::filterByType(kernel->topLevelExprs())) { + auto innermost_loop = top_level_loop; + while (auto first_expr_loop = dynamic_cast( + innermost_loop->body().exprs().at(0))) { + innermost_loop = first_expr_loop; + } + const auto& exprs = innermost_loop->body().exprs(); + TORCH_CHECK(!exprs.empty(), "No expression found"); + TORCH_CHECK( + exprs.at(0)->isA(), + "Invalid expression: ", + exprs.at(0)->toString()); + auto hoisted_index = exprs.at(0)->as()->buffer(); + TORCH_CHECK( + hoisted_index->dtype() == DataType::Int32, + "Invalid data type of hoisted indices. Should be Int32 but: ", + hoisted_index->dtype()); + kir::Predicate* pred = nullptr; + for (auto expr : exprs) { + if (expr->isA()) { + pred = expr->as()->predicate(); + auto arith_expr = expr->as()->thenBody().exprs().at(0); + auto out_ti = arith_expr->outputs()[0]->as(); + if (out_ti->view()->name() == 1) { + // Ref: T1[*, hoisted_index] = T0[*, hoisted_index * T0.stride]; + auto t1_index = out_ti->index(1); + TORCH_CHECK( + t1_index == hoisted_index, + "Invalid index: ", + t1_index->toInlineString()); + // Pred: hoisted_index < T0.size[1] + TORCH_CHECK( + pred->value()->definition()->as()->lhs() == + hoisted_index, + "Invalid predicate: ", + pred->value()->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK(arith_expr->inputs().size() == 1); + auto in0 = arith_expr->inputs().front()->as(); + TORCH_CHECK(in0->view()->name() == 0); + // hoisted_index * T0.stride[1] + auto t0_index = in0->index(1); + TORCH_CHECK( + is_index_times_ns(t0_index, hoisted_index, "T0.stride[1]"), + "Invalid index: ", + t0_index->toInlineString(), + ", ", + expr->toString()); + } else if (out_ti->view()->name() == 2) { + // Ref: T3[*, hoisted_index] = T2[*, hoisted_index]; + auto out_index = out_ti->index(1); + TORCH_CHECK( + out_index == hoisted_index, + "Invalid index: ", + out_index->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK( + pred->value()->definition()->as()->lhs() == + hoisted_index, + "Invalid predicate: ", + pred->value()->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK(arith_expr->inputs().size() == 1); + auto in0 = arith_expr->inputs().front()->as(); + TORCH_CHECK(in0->view()->name() == 1); + auto in0_index = in0->index(1); + TORCH_CHECK( + in0_index == hoisted_index, + "Invalid index: ", + in0_index->toInlineString(), + ", ", + expr->toString()); + } else if (out_ti->view()->name() == 3) { + // Ref: T3[hoisted_index] = T2[hoisted_index]; + auto out_index = out_ti->index(0); + TORCH_CHECK( + out_index == hoisted_index, + "Invalid index: ", + out_index->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK( + pred->value()->definition()->as()->lhs() == + hoisted_index, + "Invalid predicate: ", + pred->value()->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK(arith_expr->inputs().size() == 1); + auto in0 = arith_expr->inputs().front()->as(); + TORCH_CHECK(in0->view()->name() == 2); + auto in0_index = in0->index(0); + TORCH_CHECK( + in0_index == hoisted_index, + "Invalid index: ", + in0_index->toInlineString(), + ", ", + expr->toString()); + } else if (out_ti->view()->name() == 4) { + // Ref: T4[0] = T3[hoisted_index]; + TORCH_CHECK( + pred->value()->definition()->as()->lhs() == + hoisted_index, + "Invalid predicate: ", + pred->value()->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK(arith_expr->inputs().size() == 1); + auto in0 = arith_expr->inputs().front()->as(); + TORCH_CHECK(in0->view()->name() == 3); + auto in0_index = in0->index(0); + TORCH_CHECK( + in0_index == hoisted_index, + "Invalid index: ", + in0_index->toInlineString(), + ", ", + expr->toString()); + } else if (out_ti->view()->name() == 5) { + // Ref: T5[hoisted_index] = T4[0] + auto out_index = out_ti->index(0); + TORCH_CHECK( + out_index == hoisted_index, + "Invalid index: ", + out_index->toInlineString(), + ", ", + expr->toString()); + TORCH_CHECK( + pred->value()->definition()->as()->lhs() == + hoisted_index, + "Invalid predicate: ", + pred->value()->toInlineString(), + ", ", + expr->toString()); + } + } + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({15, 17}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Hoist indices for vectorized tensors +TEST_F(NVFuserTest, FusionIndexHoist2_CUDA) { + if (isOptionDisabled(DisableOption::IndexHoist)) { + GTEST_SKIP() << "Index hoisting disabled"; + } + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(1); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + auto tv5 = set(tv4); + fusion.addOutput(tv5); + + tv5->split(-1, 4); + TransformPropagatorWithCheck propagator(tv5); + MaxRootDomainInfoSpanningTree(tv5).traverse(&propagator); + + tv4->split(-1, 3); + + tv0->computeAt(tv5, 1); + tv1->computeAt(tv5, 1); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + tv3->axis(-1)->parallelize(ParallelType::Vectorize); + tv5->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({16}, options); + auto t1 = at::randn({16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTestGridComm_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + int X = 3, Y = 4, Z = 2; + auto tv0 = makeConcreteTensor({X, Y, Z}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({X, Y, Z}); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = add(tv2, tv1); + auto tv4 = set(tv3); + auto tv5 = set(tv4); + fusion.addOutput(tv5); + + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + tv4->setMemoryType(MemoryType::Global); + + tv2->axis(0)->parallelize(ParallelType::BIDy); + tv2->axis(1)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::Vectorize); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::BIDy); + + tv4->axis(0)->parallelize(ParallelType::BIDy); + tv4->axis(1)->parallelize(ParallelType::BIDx); + + tv5->axis(0)->parallelize(ParallelType::BIDy); + tv5->axis(1)->parallelize(ParallelType::BIDx); + tv5->axis(2)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({X, Y, Z}, options); + auto t1 = at::randn({X, Y, Z}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// See issue https://github.com/csarofeen/pytorch/issues/1497 +TEST_F(NVFuserTest, FusionTestGridComm2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int64_t W = 3, X = 4; + + auto tv0 = makeConcreteTensor({X}); + auto tv1 = makeConcreteTensor({W, X}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, IrBuilder::create(1)); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 2); + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv3->computeAt(tv4, 1); + + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + tv2->setMemoryType(MemoryType::Global); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({X}, options); + auto t1 = at::randn({W, X}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1 + 1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Vectorized reset test for double buffered registers +TEST_F(NVFuserTest, FusionDoubleBufferVector_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = sum(tv1, {0}); + auto tv2c = tv2->cacheBefore(); + + fusion.addOutput(tv2); + + auto tv1cw = tv1->cacheAfter(); + auto tv1cr = tv1cw->cacheAfter(); + + tv1cw->split(-1, 32); + tv1cr->split(-1, 32); + tv1cr->split(-1, 4); + tv1cr->axis(-1)->parallelize(ParallelType::Vectorize); + + tv1cw->computeAt(tv1cr, 1); + tv0->computeAt(tv1cw, -1); + tv2c->split(-1, 32); + tv2c->split(-1, 4); + tv1cr->computeAt(tv2c, 2); + + tv1cw->setMemoryType(MemoryType::Shared); + tv1cr->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::manual_seed(0); + auto t0 = at::randn({200}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto ref = (t0 + 1).sum({0}); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Request 48KB of data in shared mem, +// should be large enough not to fit in +// static allocations, but small enough +// to fit in supported devices (sm70+). +TEST_F(NVFuserTest, FusionLargeSmem_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); + fusion.addOutput(tv2); + + tv2->split(0, 12288); + tv2->split(1, 128); + tv1->computeAt(tv2, 1); + tv1->split(1, 128); + tv0->computeAt(tv1, -1); + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::manual_seed(0); + auto t0 = at::randn({12288 * 4}, options); + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto ref = t0 + 1 + 2; + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Request a smem allocation that is equal to the device limit +TEST_F(NVFuserTest, FusionTooLargeSmem_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto properties = at::cuda::getDeviceProperties( + c10::Device(c10::DeviceType::CUDA, 0).index()); + int device_limit = properties->sharedMemPerBlockOptin; + + auto tv0 = makeContigTensor(1); + fusion.addInput(tv0); + auto tv1 = add(tv0, IrBuilder::create(1.0)); + auto tv2 = add(tv1, IrBuilder::create(2.0)); + fusion.addOutput(tv2); + + // 4 byte per float + tv2->split(0, device_limit / 4); + tv2->split(1, 128); + tv1->computeAt(tv2, 1); + tv1->split(1, 128); + tv0->computeAt(tv1, -1); + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::manual_seed(0); + auto t0 = at::randn({12288 * 4}, options); + FusionExecutor fe; + + // First compile gets a compiled kernel + fe.compileFusion(&fusion, {t0}); + + // Should be throwing because the kernel + // requested absolute device limit + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0})); +} + +// Try to test alignment when multiple tensors are +// in shared mem. +TEST_F(NVFuserTest, FusionSmemAlignment_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({3, 4, 7, 2, 5}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {4}); + auto tv2 = sum(tv1, {3}); + auto tv3 = sum(tv2, {2}); + auto tv4 = sum(tv3, {1}); + fusion.addOutput(tv4); + + auto tv0c = tv0->cacheAfter(); + auto tv1bc = tv1->cacheBefore(); + auto tv2bc = tv2->cacheBefore(); + auto tv3bc = tv3->cacheBefore(); + auto tv4bc = tv4->cacheBefore(); + + tv0c->setMemoryType(MemoryType::Shared); + tv1bc->setMemoryType(MemoryType::Shared); + tv2bc->setMemoryType(MemoryType::Shared); + tv3bc->setMemoryType(MemoryType::Shared); + tv4bc->setMemoryType(MemoryType::Shared); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + tv3->axis(-1)->parallelize(ParallelType::Vectorize); + tv0->computeAt(tv4, 0); + tv0->computeAt(tv2, 2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::manual_seed(0); + auto t0 = at::randn({3, 4, 7, 2, 5}, options); + FusionExecutor fe; + + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto tref = t0.sum({1, 2, 3, 4}); + + testValidate(&fusion, cg_outputs, {t0}, {tref}, __LINE__, __FILE__); +} + +// Repro of #1521 +TEST_F(NVFuserTest, FusionImmediateValueAsInput_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto immediate_scalr = IrBuilder::create(0.1); + // Adding an immediate scalar value as an input is not allowed + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fusion.addInput(immediate_scalr)); + + // Instead, use a symbolic value + auto symbolic_scalar = IrBuilder::create(); + fusion.addInput(symbolic_scalar); + + auto tv1 = add(tv0, symbolic_scalar); + fusion.addOutput(tv1); + + // Make sure the kernel is compiled. + FusionExecutor fe; + fe.compileFusion(&fusion); +} + +// Repro of #1506 +TEST_F(NVFuserTest, FusionVectorizeContigIndex_CUDA) { + std::vector shape{14, 14}; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion.addOutput(tv2); + + tv2->merge(0); + + // Vectorize by 4 should be allowed + tv2->split(0, 4); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv0->computeAt(tv2, 1); + + tv1->axis(1)->parallelize(ParallelType::Vectorize); + tv2->axis(1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + TORCH_CHECK(t0.equal(cg_outputs[0])); +} + +// Make sure the same fusion as FusionVectorizeContigIndex fails if +// not contig. +TEST_F(NVFuserTest, FusionVectorizeContigIndexFail_CUDA) { + std::vector shape{14, 14}; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion.addOutput(tv2); + + tv2->merge(0); + + tv2->split(0, 4); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv0->computeAt(tv2, 1); + + tv1->axis(1)->parallelize(ParallelType::Vectorize); + tv2->axis(1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + + // This should fail at the launch time as 14 is not divisible by the + // vector word size. The two domains are merged, but they are not + // contiguous, so contig indexing is not involved in this case. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0})); +} + +TEST_F(NVFuserTest, FusionVectorizeInputToOutput_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + tv1->split(0, 4); + + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + + const int n = 12; + auto t0 = at::randn({n}, options); + // Shift by one to make it non-aligned + auto t0_misaligned = at::randn({n + 1}, options).index({Slice(1)}); + auto t1_misaligned = at::empty({n + 1}, options).index({Slice(1)}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + TORCH_CHECK(t0.equal(cg_outputs[0])); + + // Pass misaligned input. This must fail. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0_misaligned})); + + // Pass misaligned output. This must fail too. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0}, {t1_misaligned})); +} + +// Repro of issue #1530 +TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail_CUDA) { + std::vector shape{1, 2, 1}; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(shape.size()); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + tv1->merge(1); + tv1->merge(0); + + auto invalid_vec_size = shape[0] * shape[1] * shape[2]; + invalid_vec_size *= invalid_vec_size; + + tv1->split(0, invalid_vec_size); + + tv1->axis(1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0})); +} + +TEST_F(NVFuserTest, FusionContigIndexingWithBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({4}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({3, 4}); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->merge(0); + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv2->setMemoryType(MemoryType::Local); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({4}, options); + auto t1 = at::randn({3, 4}, options); + + auto t3 = t0.unsqueeze(0).add(t1); + { + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + testValidate(&fusion, cg_outputs, {t0, t1}, {t3}, __LINE__, __FILE__); + } + + // Make sure tv2 indexing also works when it's stored in global memory + tv2->setMemoryType(MemoryType::Global); + { + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + testValidate(&fusion, cg_outputs, {t0, t1}, {t3}, __LINE__, __FILE__); + } +} + +// Repro of #1534. Validation should detect invalid vectorization. +TEST_F(NVFuserTest, FusionVectorizeContigIndexValidationFail2_CUDA) { + std::vector shape1{2, 3, 2}; + std::vector shape2{2, 2}; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor(shape1); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2); + fusion.addInput(tv1); + + auto tv2 = set(tv1); + auto tv3 = broadcast(tv2, {false, true, false}); + auto tv4 = add(tv0, tv3); + fusion.addOutput(tv4); + + tv4->merge(1, 2); + tv4->merge(0, 1); + tv4->split(0, 4); + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv0->computeAt(tv4, -2); + tv1->computeAt(tv4, -2); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t1 = at::randn(shape2, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + + // Vectorization of tv2 should be detected as invalid. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fe.runFusion({t0, t1})); +} + +TEST_F(NVFuserTest, FusionVectorizeContigIndexWithBroadcast_CUDA) { + std::vector shape1{2, 2, 2}; + std::vector shape2{1, 2, 2}; + + Fusion fusion; + FusionGuard fg(&fusion); + + // [I0, I1, I2] + auto tv0 = makeContigTensor(shape1.size()); + fusion.addInput(tv0); + + // [B3, I1, I2] + auto tv1 = makeContigConcreteTensor(shape2); + fusion.addInput(tv1); + + auto tv2 = set(tv1); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + tv3->merge(1, 2); + tv3->merge(0, 1); + tv3->split(0, 4); + + // Don't modify tv1 so that it's replayed as tv2 with actual + // transformations. It would create temporary IterDomains, and the + // validation should still be able to detect vectorization by 4 is valid. + // TransformPropagatorWithCheck propagator(tv3); + // MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + + tv2->merge(1, 2); + tv2->merge(0, 1); + tv2->split(0, 4); + + tv2->computeAt(tv3, -2); + + tv2->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t1 = at::randn(shape2, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionVectorizeContigIndexPointwiseSchedule_CUDA) { + std::vector shape0{100, 14, 2, 14}; + std::vector shape1{100, 2, 14}; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(shape0.size()); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(shape1.size()); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv1, {false, true, false, false}); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape0, options); + auto t1 = at::randn(shape1, options); + + auto lparams = schedulePointwise(&fusion, {t0, t1}); + + GpuLower gpulw(&fusion); + auto kernel = gpulw.kernel(); + + // The innermost two dimensions are merged and contiguous, so + // vectorization can be done against 2*14=28 rather than 14, so + // vector word size should be 4. Broadcasting of tv1 should not + // matter. + for (const auto& vec_info : kernel->summary().vectorized_set_info) { + TORCH_CHECK( + vec_info.word_size == 4, + "Invalid vector word size: ", + vec_info.word_size); + } + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}, lparams); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1.unsqueeze(-3); + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Repro of issue #1539. +TEST_F(NVFuserTest, FusionTrivialReductionForwarding1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {true, false}); + auto tv2 = sum(tv1, {0}); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv2->merge(0); + tv2->split(0, 4); + + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + // All tensors must be transformed to a 2D tensor with each axis + // mapped with each other in the LOOP map. + ComputeAtMap ca_map(&fusion); + for (auto tv : ir_utils::allTvs(&fusion)) { + TORCH_CHECK( + tv->nDims() == 2, "Expected to be a 2D tensor but: ", tv->toString()); + for (const auto i : c10::irange(2)) { + TORCH_CHECK(ca_map.areMapped( + tv->axis(i), tv3->axis(i), IdMappingMode::PERMISSIVE)); + } + } +} + +TEST_F(NVFuserTest, FusionTrivialReductionForwarding2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = broadcast(tv0, {true, false}); + auto tv2 = sum(tv1, {0}); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // Merging a trivial reduction with a non-reduction domain + tv2->merge(0, 1); + tv2->split(0, 4); + + tv3->split(0, 4); + + // tv2 and tv3 are different as tv3 lacks the trivial reduction, but + // they are mapped with each other by BestEffortReplay as the merge + // of trivial reduciton dim is forwarded. + + PairwiseRootDomainMap root_map(tv2, tv3); + + auto p2c = BestEffortReplay::replayCasP(tv3, tv2, 2, root_map).getReplay(); + for (const auto i : c10::irange(tv2->nDims())) { + auto tv2_id = tv2->axis(i); + auto it = p2c.find(tv2_id); + TORCH_CHECK( + it != p2c.end(), + "Expected mapped consumer ID but not found: ", + tv2_id->toString()); + auto tv3_mapped_id = it->second; + TORCH_CHECK( + tv3_mapped_id == tv3->axis(i), + "Unexpected mapped consumer ID: ", + tv3_mapped_id->toString()); + } + + auto c2p = BestEffortReplay::replayPasC(tv2, tv3, 2, root_map).getReplay(); + for (const auto i : c10::irange(tv3->nDims())) { + auto tv3_id = tv3->axis(i); + auto it = c2p.find(tv3_id); + TORCH_CHECK( + it != c2p.end(), + "Expected mapped producer ID but not found: ", + tv3_id->toString()); + auto tv2_mapped_id = it->second; + TORCH_CHECK( + tv2_mapped_id == tv2->axis(i), + "Unexpected mapped consumer ID: ", + tv2_mapped_id->toString()); + } +} + +TEST_F(NVFuserTest, FusionTrivialReductionForwarding3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + auto tv2 = add(tv1, IrBuilder::create(1)); + fusion.addOutput(tv2); + + // Similar pattern as FusionTrivialReductionForwarding2 but trivial + // reduciton at non-root domain + + // Create a trivial reduction by splitting with a factor of 1 + tv1->split(1, 1, false); + // Merging with a trivial reduction + tv1->merge(0, 1); + auto tv1_merge_out_id = tv1->axis(0); + tv1->split(0, 5); + + tv2->split(0, 5); + + // The merge of tv1 is done with a non-root trivial + // reduciton. BestEffortReplay should forward the merge. + + PairwiseRootDomainMap root_map(tv1, tv2); + auto p2c = BestEffortReplay::replayCasP(tv2, tv1, 2, root_map).getReplay(); + + // The two tensors should look like: + // tv1: [I1*1//5, 5, I2//1] + // tv2: [I1//5, 5] + // + // BestEffortRepaly should forward the merge of (I1 * 1) and create + // mappings of: + // I1*1//5 -> I1//5 + // 5 -> 5 + // I1*1 -> I1 + + TORCH_CHECK(p2c.size() == 3, "Unexpected number of mappings"); + TORCH_CHECK(p2c.count(tv1->axis(0)) && p2c[tv1->axis(0)] == tv2->axis(0)); + TORCH_CHECK(p2c.count(tv1->axis(1)) && p2c[tv1->axis(1)] == tv2->axis(1)); + TORCH_CHECK( + p2c.count(tv1_merge_out_id) && + p2c[tv1_merge_out_id] == tv2->getRootDomain()[0]); +} + +TEST_F(NVFuserTest, FusionTrivialReductionForwarding4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + // tv4 has a trivial reduction axis + auto tv4 = sum(tv2, {0}); + auto tv5 = add(tv4, IrBuilder::create(1)); + fusion.addOutput(tv5); + + tv3->merge(0, 1); + tv3->split(0, 32); + + // This causes the trivial reduction of tv4 to be merged with + // another axis of tv4, and then forward computeAt is done from tv4 + // to tv5. The split of the merged id of tv4 should be done on tv5 + // by forwarding the merge of the trivial reduction. + tv0->computeAt(tv3, -1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({111}, options); + auto t1 = at::randn({123, 111}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto t2 = t0.unsqueeze(0); + auto t3 = t1 + t2; + auto t5 = sum(t2, {0}) + 1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {t3, t5}, __LINE__, __FILE__); +} + +// See issue #1598 +TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + // Place tv2 on shared memory + tv2->split(0, 2); + tv2->split(-1, 4); + tv2->setMemoryType(MemoryType::Shared); + tv2->axis(-2)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv3->split(0, 2); + tv3->split(-1, 4); + // swap tidx and tidy + tv3->axis(-2)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDy); + + tv4->split(0, 2); + tv4->split(-1, 4); + tv4->axis(-2)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDy); + + tv0->computeAt(tv4, 1); + tv3->computeAt(tv4, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({10, 64}, options); + auto t1 = at::randn({10, 64}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// See issue #1598 +TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + tv2->split(0, 2); + tv2->split(-1, 4); + tv2->setMemoryType(MemoryType::Shared); + + tv2->axis(-2)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + tv4->split(0, 2); + tv4->split(-1, 4); + // Also do unroll for tv3 and tv4 + tv4->split(-2, 8, false); + tv4->axis(-3)->parallelize(ParallelType::Unroll); + // swap tidx and tidy + tv4->axis(-2)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDy); + + tv0->computeAt(tv4, 1); + tv3->computeAt(tv4, -1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({10, 64}, options); + auto t1 = at::randn({10, 64}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// See issue #1599 +TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = add(tv2, tv3); + fusion.addOutput(tv4); + + // Use unroll where a RAW-sync tensor is stored + + tv4->split(0, 2); + tv4->split(0, 3); + tv4->split(-1, 4); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->axis(-2)->parallelize(ParallelType::TIDx); + tv4->axis(-1)->parallelize(ParallelType::TIDy); + + tv0->computeAt(tv4, 3); + tv3->computeAt(tv4, -1); + + tv2->split(-1, 4); + tv2->axis(-2)->parallelize(ParallelType::TIDy); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->setMemoryType(MemoryType::Shared); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({50, 64}, options); + auto t1 = at::randn({50, 64}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// See #1618 +TEST_F(NVFuserTest, FusionRAWSyncInsertionPlace4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({16, 128}); + auto tv1 = makeConcreteTensor({16, 128}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + auto tv3 = set(tv1); + auto tv4 = set(tv2); + auto tv5 = set(tv3); + auto tv6 = add(tv4, tv5); + fusion.addOutput(tv6); + + tv2->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + + tv2->computeAt(tv6, 0); + tv3->computeAt(tv6, 1); + tv4->computeAt(tv6, 1); + tv5->computeAt(tv6, -1); + tv2->split(1, 64); + tv3->split(1, 64); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + tv6->axis(-1)->parallelize(ParallelType::TIDx); + + // Check the block sync is inserted at the correct location. + // There is exactly one block sync needed in this test case + // and the sync needs to be after the 2 expressions + // that modify shared memory. + class SyncInsertionPointChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + + private: + void handle(UnaryOp* uop) final { + // Record number of unary ops that modifies shared memory. + if (uop->out()->isA() && + uop->out()->as()->view()->getMemoryType() == + MemoryType::Shared && + // Filter out initialization expressions + uop->in()->isA()) { + number_of_writes_++; + } + } + void handle(kir::BlockSync* bsync) final { + // Make sure both shared memory modifying expressions + // have been observed at the sync insertion point. + TORCH_INTERNAL_ASSERT( + number_of_writes_ == 2, + "FusionRAWSyncInsertionPlace4 test fail:", + "only 1 sync after the 2 shared mem writes is needed in this test," + "either a redundant sync has been inserted or the block sync is not inserted at the right place"); + } + + private: + int number_of_writes_ = 0; + } sync_insertion_checker; + GpuLower gpulw(&fusion); + sync_insertion_checker.handle(gpulw.kernel()->topLevelExprs()); +} + +// Test serial write and parallel read of shared mem: mapped case +TEST_F(NVFuserTest, FusionSerialSmemWriteParallelRead1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({128, 6}); + TensorView* tv1 = makeConcreteTensor({128, 6}); + TensorView* tv2 = makeConcreteTensor({128, 6}); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + TensorView* tv3 = add(tv0, tv1); + TensorView* tv4 = add(tv3, tv2); + + fusion.addOutput(tv4); + + // Use shared memory + tv3->setMemoryType(MemoryType::Shared); + + // Parallelize t4, in this case dim 0 on tv3 will + // not be parallelized but dim0 of t4 will be. + // We will need to make sure a sync is inserted + // even if these dimensions are mapped. + tv4->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({128, 6}, options); + at::Tensor t1 = at::randn({128, 6}, options); + at::Tensor t2 = at::randn({128, 6}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1, t2}); + auto cg_outputs = fe.runFusion({t0, t1, t2}); + + auto ref = t0 + t1 + t2; + + testValidate(&fusion, cg_outputs, {t0, t1, t2}, {ref}, __LINE__, __FILE__); +} + +// Test serial write and parallel read of shared mem: un-mapped case +TEST_F(NVFuserTest, FusionSerialSmemWriteParallelRead2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({128, 6}); + TensorView* tv1 = makeConcreteTensor({128, 6}); + TensorView* tv2 = makeConcreteTensor({128, 6}); + fusion.addInput(tv0); + fusion.addInput(tv1); + fusion.addInput(tv2); + + TensorView* tv3 = add(tv0, tv1); + TensorView* tv4 = add(tv3, tv2); + + fusion.addOutput(tv4); + + // Use shared memory + tv3->setMemoryType(MemoryType::Shared); + + // Split and parallelize t4, + // the parallelized dimension in t4 will not + // map across to the shared mem tensor, t3. So + // there will need to be a sync before use of t3. + tv4->split(0, 2); + tv4->axis(0)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({128, 6}, options); + at::Tensor t1 = at::randn({128, 6}, options); + at::Tensor t2 = at::randn({128, 6}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1, t2}); + auto cg_outputs = fe.runFusion({t0, t1, t2}); + + auto ref = t0 + t1 + t2; + + testValidate(&fusion, cg_outputs, {t0, t1, t2}, {ref}, __LINE__, __FILE__); +} + +// Simple test of async copy primitive +TEST_F(NVFuserTest, FusionSimpleCpAsync_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int m = 33, n = 31; + + TensorView* tv0 = makeConcreteTensor({m, n}); + TensorView* tv1 = makeConcreteTensor({m, n}); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + fusion.addOutput(tv2); + + auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); + tv0_shared->setMemoryType(MemoryType::Shared); + + tv0->computeAt(tv2, 1); + tv0_shared->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + at::Tensor t1 = at::randn({m, n}, options); + + FusionExecutor fe; + + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + } + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Simple test of async copy primitive: double buffered +// Double buffer case 1, both block sync and async wait +// are needed. +TEST_F(NVFuserTest, FusionDoubleBufferCpAsync1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. + int m = 33, n = 48; + + TensorView* tv0 = makeConcreteTensor({m, n}); + TensorView* tv1 = makeConcreteTensor({m, n}); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + fusion.addOutput(tv2); + + auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); + tv0_shared->setMemoryType(MemoryType::Shared); + tv0->computeAt(tv2, 1); + + // Asynchronously load a tile in one schedule + tv0_shared->split(1, 4); + tv0_shared->axis(-1)->parallelize(ParallelType::Vectorize); + tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); + + // Consume the loaded tile in another schedule, + // triggering the need for a sync. + tv2->split(1, 12); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + + // Double buffer the shared mem tensor. + tv0_shared->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + at::Tensor t1 = at::randn({m, n}, options); + + FusionExecutor fe; + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + } + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Simple test of async copy primitive: double buffered +// Double buffer case 2, only async wait is needed +TEST_F(NVFuserTest, FusionDoubleBufferCpAsync2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. + int m = 33, n = 48; + + TensorView* tv0 = makeConcreteTensor({m, n}); + TensorView* tv1 = makeConcreteTensor({m, n}); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + fusion.addOutput(tv2); + + auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); + tv0_shared->setMemoryType(MemoryType::Shared); + tv0->computeAt(tv2, 1); + + // Asynchronously load a tile in one schedule + tv0_shared->split(1, 4); + tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); + + // Consume the loaded tile in another schedule, + // triggering the need for a sync. + tv2->split(1, 4); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + // Double buffer the shared mem tensor. + tv0_shared->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + at::Tensor t1 = at::randn({m, n}, options); + + FusionExecutor fe; + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0, t1})); + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + } + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Simple test for double buffer in shared mem, +// where we should not insert redundant syncs when +// they are not needed. +TEST_F(NVFuserTest, FusionDoubleBufferNoSync_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. + int m = 33, n = 48; + + TensorView* tv0 = makeConcreteTensor({m, n}); + TensorView* tv1 = makeConcreteTensor({m, n}); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + + fusion.addOutput(tv2); + + auto tv0_shared = tv0->cacheAfter(); + tv0_shared->setMemoryType(MemoryType::Shared); + tv0->computeAt(tv2, 1); + + // Asynchronously load a tile in one schedule + tv0_shared->split(1, 4); + tv0_shared->axis(-2)->parallelize(ParallelType::TIDx); + + // Consume the loaded tile in another schedule, + // triggering the need for a sync. + tv2->split(1, 4); + tv2->axis(-2)->parallelize(ParallelType::TIDx); + + // Double buffer the shared mem tensor. + tv0_shared->doubleBuffer(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + at::Tensor t1 = at::randn({m, n}, options); + + GpuLower gpulw(&fusion); + auto flattened_exprs = + ir_utils::flattenScopedExprs(gpulw.kernel()->topLevelExprs()); + bool sync_inserted = std::any_of( + flattened_exprs.begin(), flattened_exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + TORCH_INTERNAL_ASSERT(!sync_inserted, "Un-expected block sync inserted"); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Test predicate inversion for cp.async +TEST_F(NVFuserTest, FusionCpAsyncPredicate_CUDA) { + // requires ampere+ GPU + + Fusion fusion; + FusionGuard fg(&fusion); + + // Using vectorization so need to keep n multiple of 4. + int m = 33, n = 48; + + TensorView* tv0 = makeConcreteTensor({m, n}); + + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + auto tv0_shared = tv0->cacheAfter(LoadStoreOpType::CpAsync); + auto tv0_reg = tv0_shared->cacheAfter(); + tv0_shared->setMemoryType(MemoryType::Shared); + tv0->computeAt(tv1, 1); + + tv0_shared->split(-1, 32); + tv0_shared->split(-1, 4); + tv0_shared->axis(-1)->parallelize(ParallelType::Vectorize); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({m, n}, options); + + FusionExecutor fe; + if (!deviceMajorMinorCheck(8)) { + ASSERT_ANY_THROW(fe.compileFusion(&fusion, {t0})); + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + } + + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0.sum({1}); + + testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Test predicate removal on reg-to-reg expressions +TEST_F(NVFuserTest, FusionPredRemovalCheck_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + TensorView* tv4 = set(tv3); + + fusion.addOutput(tv4); + tv4->split(1, 4); + tv0->computeAt(tv4, -2); + tv3->axis(-1)->parallelize(ParallelType::Vectorize); + + class PredicateRemovalChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + + private: + void handle(UnaryOp* uop) final { + assertOnLocalToLocal(uop); + } + + // Utility to assert any local-to-local expr is only trivially predicated. + void assertOnLocalToLocal(Expr* expr) { + bool is_local = true; + for (auto in : ir_utils::filterByType(expr->inputs())) { + if (in->view()->getMemoryType() != MemoryType::Local) { + is_local = false; + } + } + for (auto in : + ir_utils::filterByType(expr->outputs())) { + if (in->view()->getMemoryType() != MemoryType::Local) { + is_local = false; + } + } + + if (is_local) { + if (auto ite = dynamic_cast(scope_exprs_.back())) { + TORCH_INTERNAL_ASSERT( + ite->predicate()->value()->isConst(), + "redundant predicate on: ", + expr); + } + } + } + + private: + bool within_ite_ = false; + } pred_checker; + + GpuLower gpulw(&fusion); + pred_checker.handle(gpulw.kernel()->topLevelExprs()); +} + +TEST_F(NVFuserTest, FusionPropagateParallelTypesToSiblings_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tvs = Welford(tv0, {0}); + auto tv_avg = tvs.avg; + fusion.addOutput(tv_avg); + + tv_avg->split(0, 128); + TransformPropagatorWithCheck propagator(tv_avg); + MaxRootDomainInfoSpanningTree(tv_avg).traverse(&propagator); + + tv_avg->axis(0)->parallelize(ParallelType::BIDx); + tv_avg->axis(1)->parallelize(ParallelType::TIDx); + + // Make sure the parallelization of tv_avg is propagated to the var + // and count tensors. + GpuLower gpulw(&fusion); + for (const auto expr : gpulw.kernel()->exprs()) { + auto wop = dynamic_cast(expr); + if (wop == nullptr) { + continue; + } + auto ref = wop->outAvg()->as(); + for (auto sibling : ir_utils::filterByType(wop->outputs())) { + if (ref == sibling) { + continue; + } + TORCH_CHECK( + ref->nDims() == sibling->nDims(), + "Invalid sibling: ", + sibling->toString()); + for (const auto i : c10::irange(ref->nDims())) { + TORCH_CHECK( + ref->axis(i)->getParallelType() == + sibling->axis(i)->getParallelType(), + "Mismatched parallel types between siblings. ", + ref->toString(), + ", ", + sibling->toString()); + } + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({9999}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto outputs = fe.runFusion({t0}); + + testValidate(fe.kernel(), outputs, {t0}, {t0.mean({0})}, __LINE__, __FILE__); +} + +// Test ExactRootDomainMap +TEST_F(NVFuserTest, FusionExactRootDomainMap_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = transpose(tv2); + auto tv4 = add(tv2, tv1); + auto tv5 = add(tv2, tv3); + auto tv6 = add(tv3, tv1); + fusion.addOutput(tv4); + fusion.addOutput(tv5); + fusion.addOutput(tv6); + + const auto exact_map = ExactRootDomainMap(&fusion); + + // In the exact mapping, the broadcast domain introduced at tv2 is + // only mapped with the another one in tv3, which is just transposed + // from tv2. Any other domain, including the second domain of tv4, + // must not be mapped. + + auto tv2_bc = tv2->axis(1); + auto tv3_bc = tv3->axis(0); + + TORCH_CHECK( + exact_map.areMapped(tv2_bc, tv3_bc), + "Invalid exact root domain map: ", + exact_map.toString()); + + // They must not be mapped with anything else. + for (auto tv : ir_utils::allTvs(&fusion)) { + for (auto root_id : tv->getRootDomain()) { + if (root_id == tv2_bc || root_id == tv3_bc) { + continue; + } + TORCH_CHECK( + !exact_map.areMapped(root_id, tv2_bc), + "Invalid exact root domain map: ", + exact_map.toString()); + TORCH_CHECK( + !exact_map.areMapped(root_id, tv3_bc), + "Invalid exact root domain map: ", + exact_map.toString()); + } + } +} + +class NVFuserMultithreadedTest : public ::testing::Test { + protected: + bool was_enabled = false; + + void SetUp() override { + was_enabled = fuser::cuda::setEnabled(true); + } + + void TearDown() override { + fuser::cuda::setEnabled(was_enabled); + } +}; + +TEST_F(NVFuserMultithreadedTest, SingleFunction_CUDA) { + std::string ir = R"IR( +graph(%x.1 : Tensor, + %y.1 : Tensor): + %12 : NoneType = prim::Constant() + %11 : bool = prim::Constant[value=0]() + %9 : int = prim::Constant[value=1]() + %3 : Tensor = aten::exp(%x.1) + %5 : Tensor = aten::relu(%y.1) + %6 : Tensor = aten::sin(%5) + %8 : Tensor = aten::add(%3, %6, %9) + %10 : int[] = prim::ListConstruct(%9) + %13 : Tensor = aten::sum(%8, %10, %11, %12) + return (%13) +)IR"; + auto g = std::make_shared(); + torch::jit::parseIR(ir, g.get()); + GraphFunction fn("nvfuser_test", g, nullptr); + + auto run_kernel = [&fn]() { + auto x = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); + auto y = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); + std::vector results; + for (const auto& _ : c10::irange(10)) { + auto stack = createStack({x.clone(), y.clone()}); + fn.run(stack); + results.push_back(stack.back()); + } + for (const auto& i : c10::irange(1, 10)) { + auto t0 = results[0].toTensor(); + auto ti = results[i].toTensor(); + ASSERT_TRUE(at::allclose(t0, ti)); + } + }; + + constexpr size_t kNumThreads = 4; + std::vector threads; + for (size_t id = 0; id < kNumThreads; ++id) { + threads.emplace_back(run_kernel); + } + for (auto& t : threads) { + t.join(); + } +} + +TEST_F(NVFuserMultithreadedTest, MultipleFunctions_CUDA) { + auto run_kernel = []() { + const std::string ir = R"IR( + graph(%x.1 : Tensor, + %y.1 : Tensor): + %12 : NoneType = prim::Constant() + %11 : bool = prim::Constant[value=0]() + %9 : int = prim::Constant[value=1]() + %3 : Tensor = aten::exp(%x.1) + %5 : Tensor = aten::relu(%y.1) + %6 : Tensor = aten::sin(%5) + %8 : Tensor = aten::add(%3, %6, %9) + %10 : int[] = prim::ListConstruct(%9) + %13 : Tensor = aten::sum(%8, %10, %11, %12) + return (%13) + )IR"; + auto g = std::make_shared(); + torch::jit::parseIR(ir, g.get()); + GraphFunction fn("nvfuser_test", g, nullptr); + + auto x = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); + auto y = torch::rand({32, 32}, at::TensorOptions(at::kCUDA)); + std::vector results; + constexpr size_t numRuns = 10; + for (const auto& _ : c10::irange(numRuns)) { + auto stack = createStack({x.clone(), y.clone()}); + fn.run(stack); + results.push_back(stack.back()); + } + for (const auto& i : c10::irange(1, numRuns)) { + auto t0 = results[0].toTensor(); + auto ti = results[i].toTensor(); + ASSERT_TRUE(at::allclose(t0, ti)); + } + }; + + constexpr size_t kNumThreads = 4; + std::vector threads; + for (size_t id = 0; id < kNumThreads; ++id) { + threads.emplace_back(run_kernel); + } + for (auto& t : threads) { + t.join(); + } +} + +// Repro of issue #1655 +TEST_F(NVFuserTest, FusionIncompleteConcreteID_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv0, {true, true, false}); + auto tv4 = broadcast(tv1, {false, true, false}); + auto tv5 = broadcast(tv2, {true, false, false}); + + auto tv6 = add(tv3, tv4); + auto tv7 = add(tv3, tv5); + + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + tv6->merge(0); + tv6->merge(0); + + TransformPropagatorWithCheck propagator(tv6); + MaxRootDomainInfoSpanningTree(tv6).traverse(&propagator); + + tv0->computeAt(tv6, -1, ComputeAtMode::MostInlined); + tv1->computeAt(tv6, -1, ComputeAtMode::MostInlined); + tv2->computeAt(tv7, -1, ComputeAtMode::MostInlined); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(fusion.printKernel()); +} + +TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int X = 256, Y = 7, Z = 2048; + + // setup fusion + auto tv0 = makeContigTensor(4, DataType::Half); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + + auto tvs = Welford(tv1, {0, 1, 2}); + auto tv_avg = tvs.avg; + auto tv_M2 = tvs.var_sum; + auto tv_N = tvs.n; + fusion.addOutput(tv_avg); + fusion.addOutput(tv_M2); + + auto cached_input = tv0->cacheAfter(); + auto cached_avg = tv_avg->cacheBefore(); + auto cached_M2 = tv_M2->cacheBefore(); + + auto reduction_tv = scheduler_utils::getReductionTvs(&fusion)[0]; + + reduction_tv->merge(0); + reduction_tv->merge(0); + + int TIDx = 16; + int vec = 4; + + int TIDy = 16; + int outer_tidy_fact = 16; + + reduction_tv->split(-1, TIDx * vec); + reduction_tv->split(-1, vec); + reduction_tv->axis(-2)->parallelize(ParallelType::TIDx); + reduction_tv->axis(-1)->parallelize(ParallelType::Vectorize); + reduction_tv->axis(-3)->parallelize(ParallelType::BIDx); + + reduction_tv->split(0, TIDy); + reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + reduction_tv->split(0, outer_tidy_fact); + reduction_tv->axis(0)->parallelize(ParallelType::BIDy); + + // T2_g[ rblockIdx.y, rS{16}, rthreadIdx.y, iblockIdx.x, ithreadIdx.x24, + // iV25{4} ] + reduction_tv->reorder({{3, 0}, {4, 1}, {0, 2}, {2, 3}, {1, 4}, {5, 5}}); + // T2_g[iblockIdx.x, ithreadIdx.x24, rblockIdx.y, rthreadIdx.y, rS{16}, + // iV25{4}] + + TransformPropagatorWithCheck propagator(reduction_tv); + MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator); + auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4}); + scheduler_utils::parallelizeAllLike(rfactor_tv); + + tv0->computeAt(tv_avg, 2); + tv0->computeAt(cached_input, -2); + + cached_input->computeAt(rfactor_tv, 4, ComputeAtMode::BestEffort); + + for (auto tv : ir_utils::allTvs(&fusion)) { + if (tv == cached_input || tv == tv_avg || tv == tv_M2) { + continue; + } + tv->axis(-1)->parallelize(ParallelType::Serial); + } + + FusionExecutor fe; + fe.compileFusion(&fusion, {}, LaunchParams()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({X, Y, Y, Z}, options); + + auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, -1, -1, -1, -1, -1)); + + // by default Welford outputs sum of square diff so need to divide to get var + cg_outputs[1] = cg_outputs[1].div((float)(X * Y * Y)); + + auto at_mu = at::mean(t0.to(at::kDouble), {0, 1, 2}); + auto at_var = at::var(t0.to(at::kDouble), {0, 1, 2}, false); + + testValidate( + &fusion, + cg_outputs, + {t0}, + {at_mu, at_var}, + __LINE__, + __FILE__, + "", + LaunchParams(-1, -1, -1, -1, -1, -1)); +} + +// Test sync insertion with redundant predicates +TEST_F(NVFuserTest, FusionRedundantPredSync_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({32}); + TensorView* tv1 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = add(tv2, tv1); + + fusion.addOutput(tv3); + + auto tv0c = tv0->cacheAfter(); + + // Make a redundant write through smem + tv0c->setMemoryType(MemoryType::Shared); + + tv0->computeAt(tv3, 0); + tv1->computeAt(tv3, 0); + + tv0c->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDy); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::TIDy); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + GpuLower gpulw(&fusion); + auto flattened_exprs = + ir_utils::flattenScopedExprs(gpulw.kernel()->topLevelExprs()); + bool sync_inserted = std::any_of( + flattened_exprs.begin(), flattened_exprs.end(), [](Expr* expr) { + return expr->isA(); + }); + TORCH_INTERNAL_ASSERT(sync_inserted, "Expected block sync not inserted"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({32}, options); + at::Tensor t1 = at::randn({32, 32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Test case for removing syncs on chain of redundant uses. +TEST_F(NVFuserTest, FusionRedundantPredSync2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({32}); + TensorView* tv1 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = add(tv2, tv1); + + fusion.addOutput(tv3); + + auto tv0c = tv0->cacheAfter(); + + // Make a redundant write through smem + tv0c->setMemoryType(MemoryType::Shared); + tv2->setMemoryType(MemoryType::Shared); + + tv0->computeAt(tv3, 0); + tv1->computeAt(tv3, 0); + + tv0c->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDy); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::TIDy); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + // Utility class to make sure one block sync + // is inserted by RAW pass. + class SyncChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + int result() { + return sync_seen_; + } + + private: + void handle(kir::BlockSync*) final { + sync_seen_++; + } + + private: + int sync_seen_ = 0; + } checker; + + GpuLower gpulw(&fusion); + checker.handle(gpulw.kernel()->topLevelExprs()); + TORCH_INTERNAL_ASSERT( + checker.result() < 2, "More syncs were inserted than expected"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({32}, options); + at::Tensor t1 = at::randn({32, 32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref}, __LINE__, __FILE__); +} + +// Test case for sync insertion after redundant predicated smem write +// Check that syncs are removed only when all paths are redundant. +TEST_F(NVFuserTest, FusionRedundantPredSync3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({32}); + TensorView* tv1 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {true, false}); + auto tv3 = set(tv2); + auto tv4 = add(tv3, tv1); + auto tv5 = add(tv2, tv1); + + fusion.addOutput(tv4); + fusion.addOutput(tv5); + + auto tv0c = tv0->cacheAfter(); + + // In this scheduling config, + // tv0c -> tv2 -> tv3 is a redundant path for tidy + // tv0c -> tv2 -> tv5 is not. + // So we need a RAW sync in tv0c->tv2 to make sure + // tv2 has the correct value to produce tv5. + tv0c->setMemoryType(MemoryType::Shared); + tv3->setMemoryType(MemoryType::Shared); + + tv0c->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDy); + tv2->axis(1)->parallelize(ParallelType::TIDx); + + tv3->axis(0)->parallelize(ParallelType::TIDy); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + tv5->axis(0)->parallelize(ParallelType::TIDy); + tv5->axis(1)->parallelize(ParallelType::TIDx); + + // Utility class to make sure one block sync + // is inserted by RAW pass. + class SyncChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + int result() { + return sync_seen_; + } + + private: + void handle(kir::BlockSync* sync) final { + if (!sync->isWarHazardSync()) { + sync_seen_++; + } + } + + private: + int sync_seen_ = 0; + } checker; + + GpuLower gpulw(&fusion); + checker.handle(gpulw.kernel()->topLevelExprs()); + + // This is implicit checking. There are exactly 2 places + // where RAW hazards happen: one producing tv2 and the other + // producing tv3. This test case expect syncs in both of + // these places so we check that 2 RAW syncs are inserted. + TORCH_INTERNAL_ASSERT( + checker.result() == 2, + "Exactly 2 RAW sync expected for the two shared memory transfers"); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({32}, options); + at::Tensor t1 = at::randn({32, 32}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + + auto ref = t0 + t1; + + testValidate(&fusion, cg_outputs, {t0, t1}, {ref, ref}, __LINE__, __FILE__); +} + +// Unit test case for detecting thread redundant usage of shared tensors. +TEST_F(NVFuserTest, FusionRedundantUseCheck_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = set(tv2); + auto tv4 = set(tv3); + + auto tv5 = set(tv4); + + auto tv6 = set(tv4); + auto tv7 = set(tv6); + + fusion.addOutput(tv5); + fusion.addOutput(tv7); + + tv2->setMemoryType(MemoryType::Shared); + tv4->setMemoryType(MemoryType::Shared); + + tv7->axis(-1)->parallelize(ParallelType::TIDx); + + // Thread pred map cannot be built without an active lower + // object. So would need to lower the whole fusion for + // testing. However, lower also keeps an copy of the fusion + // so the original pointers cannot be used to querry the + // thread pred map. So have to traverse the new expr list + // to find the pointers; + GpuLower gpulw(&fusion); + + TensorView *lowered_tv2 = nullptr, *lowered_tv4 = nullptr; + auto used_vals = gpulw.kernel()->usedMathVals(); + + for (auto tv : ir_utils::filterByType(used_vals)) { + if (tv->name() == 2) { + lowered_tv2 = tv; + } + if (tv->name() == 4) { + lowered_tv4 = tv; + } + } + + TORCH_INTERNAL_ASSERT( + lowered_tv2 != nullptr && lowered_tv4 != nullptr, + "tv2 or tv4 not lowered or mangled"); + + auto tv2_info = gpulw.threadPredMap().getPredicateInfo(lowered_tv2); + auto tv4_info = gpulw.threadPredMap().getPredicateInfo(lowered_tv4); + + // tv2 -> tv3 -> tv4 (shared) is the only use chain for tv2, + // and tv4 is redundantly written in tidx so tv2 is redundantly + // consumed in tidx. + TORCH_INTERNAL_ASSERT( + tv2_info.redundant_use_types.get(ParallelType::TIDx), + "TV2 is redundantly used but not detected."); + + // tv4->tv5 (global) is a redundant use chain, but + // tv4->tv6->tv7 is not, so tv4 should not be detected as + // a redundant used tensor in tidx. + TORCH_INTERNAL_ASSERT( + !tv4_info.redundant_use_types.get(ParallelType::TIDx), + "TV4 is not redundantly used but not detected."); +} + +// Test a basic swizzle pattern +TEST_F(NVFuserTest, FusionSimpleSwizzle0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + // Make a 2x8 Zshape tile + tv1->split(-1, 16); + tv1->split(-1, 8); + // [O, 2, 8] + + tv2->split(-1, 16); + tv2->split(-1, 4); + //[O, 4, 4] + + tv1->computeAt(tv2, 1); + tv1->swizzle(Swizzle2DType::ZShape, -2, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 32}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Test swizzle inlining +TEST_F(NVFuserTest, FusionSimpleSwizzle1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // Make a 2x8 Zshape tile + tv2->split(-1, 16); + tv2->split(-1, 8); + // [O, 2, 8] + + tv3->split(-1, 16); + tv3->split(-1, 4); + //[O, 4, 4] + + tv2->computeAt(tv3, 1); + tv2->swizzle(Swizzle2DType::ZShape, -2, -1); + + // Inlining a producer into a swizzled consumer is ok + tv1->computeAt(tv2, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 32}, options); + auto t3 = t0 + 3.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t3}, __LINE__, __FILE__); +} + +// Test sync insertion and memory check in parallelized swizzles. +// In this test, data is parallel written into smem in zcurve +// pattern and then read out and output to global mem unswizzled. +TEST_F(NVFuserTest, FusionSimpleSwizzle2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + tv1->swizzle(Swizzle2DType::ZShape, -2, -1); + + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv1->axis(1)->parallelize(ParallelType::TIDy); + + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDy); + + // Validation should fail since TV1 is not in shared + // memory as required by sync info pass. + ASSERT_ANY_THROW(GpuLower gpulw_throw(&fusion)); + + tv1->setMemoryType(MemoryType::Shared); + + // Make sure that a sync is inserted: + bool sync_found = false; + GpuLower gpu_lw(&fusion); + auto flattened_exps = + ir_utils::flattenScopedExprs(gpu_lw.kernel()->topLevelExprs()); + + for (auto expr : flattened_exps) { + if (expr->isA()) { + sync_found = true; + } + // Will require a sync thread before any shared memory read. + for (auto inp_tv : ir_utils::filterByType(expr->inputs())) { + if (inp_tv->getMemoryType() == MemoryType::Shared) { + TORCH_INTERNAL_ASSERT( + sync_found, "Block sync required but not inserted"); + } + } + } + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({32, 32}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Test BestEffortReplay behavior with swizzle op +TEST_F(NVFuserTest, FusionSwizzleMapping_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + // Make a 2x8 Zshape tile + tv2->split(-1, 16); + tv2->split(-1, 8); + // [O, 2, 8] + + tv3->split(-1, 16); + tv3->split(-1, 4); + //[O, 4, 4] + + tv2->computeAt(tv3, 1); + tv2->swizzle(Swizzle2DType::ZShape, -2, -1); + + // Inlining a producer into a swizzled consumer is ok + tv1->computeAt(tv2, -1); + + // Check BestEffortReplay behavior with skip swizzles option on. + PairwiseRootDomainMap root_map(tv1, tv2); + + // Check producer to consumer map, + // i.e. unswizzled tensor to swizzled tensor map + //---------------------------------------------------------- + auto p2c = BestEffortReplay::replayCasP(tv2, tv1, -1, root_map).getReplay(); + auto swizzle_x_it0 = p2c.find(tv1->axis(-2)); + auto swizzle_y_it0 = p2c.find(tv1->axis(-1)); + // P2C map should exist and both the x and y map should + // map to the output of the swizzle op. + TORCH_INTERNAL_ASSERT( + swizzle_x_it0 != p2c.end() && swizzle_y_it0 != p2c.end()); + TORCH_INTERNAL_ASSERT( + swizzle_x_it0->second == tv2->axis(-2) && + swizzle_y_it0->second == tv2->axis(-1)); + + // Check consumer to producer map, + // i.e. swizzled tensor to unswizzled tensor map + //---------------------------------------------------------- + auto c2p = BestEffortReplay::replayPasC(tv1, tv2, -1, root_map).getReplay(); + + auto swizzle_op = tv2->axis(-1)->definition()->as(); + + // Find mapping for swizzle inputs + auto swizzle_x_it1 = c2p.find(swizzle_op->inX()); + auto swizzle_y_it1 = c2p.find(swizzle_op->inY()); + + // Find mapping for swizzle outputs + auto swizzle_x_it2 = c2p.find(swizzle_op->outX()); + auto swizzle_y_it2 = c2p.find(swizzle_op->outY()); + + // Input of swizzle ops will not be mapped to any + // by BestEffortReplay, as BestEffortReplay has to be + // one to one. IdGraph will further map them together. + TORCH_INTERNAL_ASSERT( + swizzle_x_it1 == c2p.end() && swizzle_y_it1 == c2p.end()); + + // Mapping for swizzle outputs should be mapped and should + // also map to the corresponding axes on the unswizzled tensor. + TORCH_INTERNAL_ASSERT( + swizzle_x_it2 != c2p.end() && swizzle_y_it2 != c2p.end()); + TORCH_INTERNAL_ASSERT( + swizzle_x_it2->second == tv1->axis(-2) && + swizzle_y_it2->second == tv1->axis(-1)); + + // Check id graph behavior + //---------------------------------------------------------- + ComputeAtMap ca_map(&fusion); + // Corresponding inputs and outputs of swizzle ops are + // map through by exact and permissive map. + TORCH_INTERNAL_ASSERT( + ca_map.areMapped(tv1->axis(-2), swizzle_op->inX(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT( + ca_map.areMapped(tv1->axis(-1), swizzle_op->inY(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->outX(), IdMappingMode::EXACT)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->outY(), IdMappingMode::EXACT)); + + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->inX(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->inY(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-2), swizzle_op->outX(), IdMappingMode::PERMISSIVE)); + TORCH_INTERNAL_ASSERT(ca_map.areMapped( + tv1->axis(-1), swizzle_op->outY(), IdMappingMode::PERMISSIVE)); +} + +// Test a basic loop swizzle pattern +TEST_F(NVFuserTest, FusionLoopSwizzle0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + tv2->split(-1, 16); + tv2->split(-1, 4); + //[O, 4, 4] + + tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); + + tv0->computeAt(tv2, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 32}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Outer block zshape pattern +TEST_F(NVFuserTest, FusionLoopSwizzle1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + tv2->split(-2, 8); + tv2->split(-1, 4); + //[I0o, I0i, I1o, I1i] + tv2->reorder({{1, 2}, {2, 1}}); + //[I0o, I1o, I0i, I1i] + + tv2->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop); + tv0->computeAt(tv2, -1); + + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::BIDy); + + FusionExecutor fe; + fe.compileFusion(&fusion); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({45, 77}, options); + auto t2 = t0 + 2.0; + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +// Test assertion in unsupported pattern: non-leaf loop swizzle. +TEST_F(NVFuserTest, FusionLoopSwizzleCheck0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + + fusion.addOutput(tv2); + + tv2->split(-1, 16); + tv2->split(-1, 4); + //[O, 4, 4] + + // Swizzle the inner tile. + tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); + + // Make swizzle output not a leaf domain. + tv2->merge(-2); + + tv0->computeAt(tv2, -1); + + FusionExecutor fe; + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +// Test assertion in unsupported pattern: half-inlined loop swizzle. +TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 32}); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = add(tv1, IrBuilder::create(1)); + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv3); + + //[O, 4, 4] + tv2->split(-1, 16); + tv2->split(-1, 4); + + //[O, 4, 4] + tv3->split(-1, 16); + tv3->split(-1, 4); + + // Swizzle inner tile of tv2 + tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop); + + // Make tv2 swizzled and partially-inlined (unsupported). + tv0->computeAt(tv3, -2); + + FusionExecutor fe; + ASSERT_ANY_THROW(fe.compileFusion(&fusion)); +} + +TEST_F(NVFuserTest, FusionUnsqueeze1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({10, 11}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + // [I, R] + auto tv1 = sum(tv0, {1}); + // [I, B] + auto tv2 = unsqueeze(tv1, -1); + fusion.addOutput(tv2); + + TORCH_CHECK( + tv2->nDims() == 2, "Unexpected unsqueeze result: ", tv2->toString()); + TORCH_CHECK( + tv2->axis(1)->isBroadcast(), + "Unexpected unsqueeze result: ", + tv2->toString()); + + // tv1 has only one non-reduction axis. An exception should be + // thrown. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(unsqueeze(tv1, 2)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 11}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = t0.sum(1).unsqueeze(-1); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSqueeze1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({10, 11}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + // [I, B] + auto tv1 = sum(tv0, {1}, true); + // [I] + auto tv2 = squeeze(tv1, {shape[0], 1}); + fusion.addOutput(tv2); + + TORCH_CHECK( + tv2->nDims() == 2, "Unexpected squeeze result: ", tv2->toString()); + + // [I, R] + auto tv3 = sum(tv0, {1}); + // tv3 has only one non-reduction axis. The extent of the first axis + // is not one, so squeeze should fail. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) + ASSERT_ANY_THROW(squeeze(tv3, {shape[0], 1})); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 11}, options); + std::vector aten_inputs = {t0}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + auto ref = t0.sum(1, true).squeeze(-1); + + testValidate(&fusion, cg_outputs, aten_inputs, {ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionContigPredicate_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = broadcast(tv1, {false, true, false}); + fusion.addOutput(tv2); + + tv2->merge(-2, -1); + tv2->merge(-2, -1); + tv2->split(-1, 100); + tv0->computeAt(tv2, -1); + + GpuLower gpulw(&fusion); + TORCH_CHECK(PredicatedChecker::isPredicated(tv1, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({3, 4}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + auto ref = t0.unsqueeze(1); + + testValidate(fe.kernel(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Repro of https://github.com/csarofeen/pytorch/issues/1777 +TEST_F(NVFuserTest, FusionDivScalarLhs_CUDA) { + // tv1 = 2.0 / tv0 + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + TensorView* tv1 = div(IrBuilder::create(2.0), tv0); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({3, 3}, options); + // There's no overload div(Scalar, Tensor) in ATen + auto aten_output = at::div( + at::native::wrapped_scalar_tensor(at::Scalar(2.0), options.device()), t0); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {aten_output}, __LINE__, __FILE__); +} + +// Repro of an issue of the reduction scheduler with a broadcast +// domain concretized to multiple domains that are not proven to have +// the same extent +TEST_F(NVFuserTest, FusionRepro1713_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(2); + auto tv1 = makeSymbolicTensor(2); + auto tv2 = makeSymbolicTensor(1); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + auto tv3 = broadcast(tv2, {false, true}); + + auto tv4 = add(tv3, tv0); + + auto tv5 = add(tv3, tv1); + auto tv6 = sum(tv5, {0}); + fusion->addOutput(tv4); + fusion->addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1024, 204800}, options); + // Original repro had the same shape as t0, but this should work + // with a different extent at the second axis + at::Tensor t1 = at::randn({1024, 123}, options); + at::Tensor t2 = at::randn({1024}, options); + std::vector aten_inputs({t0, t1, t2}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t3 = t2.unsqueeze(-1); + auto t4 = t3 + t0; + auto t5 = t3 + t1; + auto t6 = sum(t5, {0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {t0, t1, t2}, + {t4, t6}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionExpand_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto w = 2, x = 3, y = 4, z = 5; + + // Test + // a simple expand + // Expand that's propagated + // expand_as + // symbolic expand + + // x + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = broadcast(tv0, {false, true}); + auto tv2 = expand(tv1, {tv0->axis(0)->extent(), IrBuilder::create(y)}); + + // x + auto tv3 = makeSymbolicTensor(1); + fusion->addInput(tv3); + auto tv4 = broadcast(tv3, {false, true}); + auto tv5 = add(tv4, tv2); + // [x, e_y] + + // [x, y, z] + auto tv6 = makeSymbolicTensor(3); + fusion->addInput(tv6); + + // Disjoint set op will cause a segmentation for just this op. + auto tmp_7 = set(tv6); + fusion->addOutput(tmp_7); + + auto tv7 = broadcast(tv5, {false, false, true}); + + auto tv8 = expand_as(tv7, tv6); + // [x, e_y, e_z] + + auto w_symbolic = IrBuilder::create(); + fusion->addInput(w_symbolic); + + auto tv9 = broadcast(tv8, {true, false, false, false}); + //[1, x, e_y, e_z] + + auto tv10 = expand( + tv9, + {w_symbolic, + tv9->axis(1)->extent(), + tv9->axis(2)->expandedExtent(), + tv9->axis(3)->expandedExtent()}); + + fusion->addOutput(tv10); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({x}, options); + at::Tensor t3 = at::randn({x}, options); + at::Tensor t6 = at::randn({x, y, z}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3, t6, w}); + auto cg_out = cg_outputs[1]; + + TORCH_INTERNAL_ASSERT(cg_out.size(0) == w); + TORCH_INTERNAL_ASSERT(cg_out.size(1) == x); + TORCH_INTERNAL_ASSERT(cg_out.size(2) == y); + TORCH_INTERNAL_ASSERT(cg_out.size(3) == z); + TORCH_INTERNAL_ASSERT(cg_out.stride(0) == 0); + TORCH_INTERNAL_ASSERT(cg_out.stride(1) == 1); + TORCH_INTERNAL_ASSERT(cg_out.stride(2) == 0); + TORCH_INTERNAL_ASSERT(cg_out.stride(3) == 0); + + auto t10 = t0.unsqueeze(-1) + .expand({x, y}) + .add(t3.unsqueeze(-1)) + .unsqueeze(-1) + .expand_as(t6) + .unsqueeze(0) + .expand({w, x, y, z}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {t0, t3, t6, w}, + {t6, t10}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionExpandIssue1751_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto x = 3, y = 4, z = 5; + + // y, z + auto tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + + auto tv1 = broadcast(tv0, {true, false, false}); + + // Two ways to propagate extents as is: use -1 or explicitly pass + // the extent vals. + + auto tv2 = expand( + tv1, + {IrBuilder::create(x), + IrBuilder::create(-1), + IrBuilder::create(-1)}); + + auto tv3 = expand( + tv1, + {IrBuilder::create(x), + tv0->axis(0)->extent(), + tv0->axis(1)->extent()}); + + fusion->addOutput(tv2); + fusion->addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({y, z}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + for (const auto& cg_out : cg_outputs) { + TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); + TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); + TORCH_INTERNAL_ASSERT(cg_out.size(2) == z); + } + + auto t2 = t0.expand({x, y, z}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {t2, t2}, __LINE__, __FILE__); +} + +// TODO: Make sure the kernel uses the expanded concrete size instead +// of the symbolic size +TEST_F(NVFuserTest, FusionExpandToConcrete_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto x = 3, y = 4; + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = broadcast(tv0, {true, false}); + + auto tv2 = + expand(tv1, {IrBuilder::create(x), IrBuilder::create(y)}); + + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({y}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + for (const auto& cg_out : cg_outputs) { + TORCH_INTERNAL_ASSERT(cg_out.size(0) == x); + TORCH_INTERNAL_ASSERT(cg_out.size(1) == y); + } + + auto t2 = t0.expand({x, y}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {t2}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 32, 16, 112, 112}, options).transpose(-1, -2); + at::Tensor t1 = at::randn({32, 1, 112, 1}, options).transpose(-1, -2); + + auto tv0 = TensorViewBuilder() + .ndims(5) + .contiguity({true, true, false, false, false}) // ttfff + .shape({-1, -1, -1, -1, -1}) + .dtype(DataType::Half) + .build(); + auto tv1 = TensorViewBuilder() + .ndims(4) + .contiguity({true, false, false, true}) // tfft + .shape({-1, 1, 1, -1}) + .dtype(DataType::Half) + .build(); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = add(tv0, tv1); + + fusion->addOutput(tv2); + + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t2 = t0 + t1; + + testValidate( + executor_cache.fusion(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); +} + +namespace { + +// check that the resulting sibling are identical +void checkSiblingConsistency(TensorView* replay, TensorView* target) { + auto replay_root = replay->getRootDomain(); + auto replay_dom = replay->domain()->domain(); + auto target_root = target->getRootDomain(); + auto target_dom = target->domain()->domain(); + std::unordered_map target2replay_map; + TORCH_CHECK(replay_root.size() == target_root.size()); + target2replay_map.reserve(replay_root.size()); + std::transform( + target_root.begin(), + target_root.end(), + replay_root.begin(), + std::inserter(target2replay_map, target2replay_map.begin()), + [](auto a, auto b) { return std::make_pair(a, b); }); + BestEffortReplay replay_(replay_dom, target_dom, target2replay_map); + auto r = replay_.getReplay(); + for (int64_t i = 0; i < replay_dom.size(); i++) { + auto target_id = target_dom[i]; + auto replay_it = r.find(target_id); + TORCH_CHECK(replay_it != r.end()); + TORCH_CHECK( + replay_it->second == replay_dom[i], + "IterDomain mismatch when checking ", + replay, + " and ", + target, + " at ", + i, + ", got ", + replay_it->second, + " and ", + replay_dom[i]); + } +}; + +} // namespace + +TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) { + // https://github.com/csarofeen/pytorch/issues/1760 + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {1}); + fusion.addOutput(tvs.var_sum); + + tvs.avg->split(1, 1); + tvs.avg->split(1, 2); + tvs.avg->split(1, 3); + tvs.var_sum->split(1, 1); + tvs.var_sum->split(1, 2); + tvs.var_sum->split(1, 3); + tvs.n->split(1, 1); + tvs.n->split(1, 2); + tvs.n->split(1, 3); + + auto var_sum_rf = ir_utils::rfactorHelper(tvs.var_sum, {1, 4}); + + TransformPropagatorWithCheck propagator(var_sum_rf); + MaxRootDomainInfoSpanningTree(var_sum_rf).traverse(&propagator); + + auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); + + std::vector siblings[] = {{tvs.avg, tvs.var_sum, tvs.n}, rf_tvs}; + for (auto tensors : siblings) { + for (auto t1 : tensors) { + for (auto t2 : tensors) { + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); + } + } + } +} + +TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tvs = Welford(tv0, {1}); + fusion.addOutput(tvs.var_sum); + + tvs.avg->split(1, 1); + tvs.avg->split(1, 2); + tvs.avg->split(1, 3); + tvs.var_sum->split(1, 1); + tvs.var_sum->split(1, 2); + tvs.var_sum->split(1, 3); + tvs.n->split(1, 1); + tvs.n->split(1, 2); + tvs.n->split(1, 3); + + auto var_sum_rf = ir_utils::rfactorHelper(tvs.var_sum, {1, 4}); + + struct DisableTv0 : public MaxInfoSpanningTree::Selector { + TensorView* tv0; + virtual bool allowC2P(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowP2C(TensorView* from, TensorView* to) override { + return from != tv0 && to != tv0; + }; + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return true; + } + DisableTv0(TensorView* tv0) : tv0(tv0) {} + } selector1(tv0); + + struct DisableTv0AndSibling : public DisableTv0 { + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return false; + } + using DisableTv0::DisableTv0; + } selector2(tv0); + + TransformPropagatorWithCheck propagator(var_sum_rf); + MaxRootDomainInfoSpanningTree good_path(var_sum_rf, &selector1); + MaxRootDomainInfoSpanningTree bad_path(var_sum_rf, &selector2); + + auto rf_tvs = ir_utils::producerTvsOf(tvs.var_sum); + + auto check = [&]() { + std::vector siblings[] = { + {tvs.avg, tvs.var_sum, tvs.n}, rf_tvs}; + for (auto tensors : siblings) { + for (auto t1 : tensors) { + for (auto t2 : tensors) { + TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2)); + } + } + } + }; + + bad_path.traverse(&propagator); + ASSERT_ANY_THROW(check()); + good_path.traverse(&propagator); + check(); +} + +TEST_F(NVFuserTest, FusionTransformPropagatePosition_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(4); + auto tv1 = makeSymbolicTensor(6); + fusion.addInput(tv0); + + auto tv2 = broadcast(tv0, {false, false, true, false, false, true}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv0->merge(2); + tv0->merge(0); + TransformPropagatorWithCheck propagator(tv0); + MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + + TORCH_CHECK(tv1->nDims() == 4); +} + +TEST_F(NVFuserTest, FusionIgnoreZeroDimReduction_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = sum(tv0, {0}); + // tv1 is effectively a zero-dim tensor as it only has a reduction + // axis. + // Reducing it further is converted to just a set op. + auto tv2 = sum(tv1, {0}); + fusion->addOutput(tv2); + + auto tv2_def = dynamic_cast(tv2->definition()); + TORCH_CHECK( + tv2_def != nullptr, + "Expected UnaryOp but found ", + tv2->definition()->toString()); + + TORCH_CHECK( + tv2_def->getUnaryOpType() == UnaryOpType::Set, + "Expected UnaryOpType::Set but found ", + tv2_def->getUnaryOpType()); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({12345}, options); + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = sum(t0, {0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +// Repro of issue #1770 +TEST_F(NVFuserTest, FusionIssue1770Repro_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion->addInput(tv1); + + auto tv2 = ge(tv0, tv1); + auto tv3 = + where(tv2, IrBuilder::create(1), IrBuilder::create(2)); + fusion->addOutput(tv3); + + std::vector shape({999}); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn(shape, options); + at::Tensor t1 = at::randn(shape, options); + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto ref = where(t0 >= t1, 1.0, 2.0); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {ref}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionTransformPropagatorSelector_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = makeSymbolicTensor(1); + fusion->addInput(tv1); + + auto tv2 = add(tv0, tv1); + + auto tv3 = sin(tv2); + auto tv4 = cos(tv2); + + fusion->addOutput(tv3); + fusion->addOutput(tv4); + + tv2->split(0, 10); + + struct Selector : public MaxInfoSpanningTree::Selector { + TensorView* tv0; + TensorView* tv3; + virtual bool allowC2P(TensorView* from, TensorView* to) override { + return to == tv0; + } + virtual bool allowP2C(TensorView* from, TensorView* to) override { + return to == tv3; + } + virtual bool allowSibling(TensorView* from, TensorView* to) override { + return false; + } + Selector(TensorView* tv0, TensorView* tv3) : tv0(tv0), tv3(tv3) {} + } selector(tv0, tv3); + + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2, &selector).traverse(&propagator); + + TORCH_CHECK(tv0->nDims() == 2); + TORCH_CHECK(tv1->nDims() == 1); + TORCH_CHECK(tv2->nDims() == 2); + TORCH_CHECK(tv3->nDims() == 2); + TORCH_CHECK(tv4->nDims() == 1); +} + +TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({22, 105}); + fusion->addInput(tv0); + + auto tv1 = sin(tv0); + fusion->addOutput(tv1); + + tv1->split(0, 2); + tv1->split(-1, 3); + tv1->split(-1, 5); + + TransformPropagatorWithCheck propagator(tv1, 2); + MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator); + + auto expect = makeConcreteTensor({22, 105}); + expect->split(0, 2); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0)); +} + +TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(3); + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {0}); + auto tv2 = neg(tv1); + + fusion->addOutput(tv2); + + tv1->split(0, 10); + + struct Printer : public MaxInfoSpanningTree::Propagator { + std::stringstream ss; + virtual void propagateC2P(TensorView* from, TensorView* to) override { + ss << "propagateC2P" << std::endl; + ss << "from: " << from->name() << std::endl; + ss << "to: " << to->name() << std::endl; + } + virtual void propagateP2C(TensorView* from, TensorView* to) override { + ss << "propagateP2C" << std::endl; + ss << "from: " << from->name() << std::endl; + ss << "to: " << to->name() << std::endl; + } + virtual void propagateSibling(TensorView* from, TensorView* to) override { + ss << "propagateSibling" << std::endl; + ss << "from: " << from->name() << std::endl; + ss << "to: " << to->name() << std::endl; + } + } printer1, printer2; + printer1.ss << std::endl; + printer2.ss << std::endl; + + MaxRootDomainInfoSpanningTree path(tv1); + path.traverse(&printer1); + path.traverse(&printer2); + + auto expect = R"ESCAPE( +propagateC2P +from: 1 +to: 0 +propagateP2C +from: 1 +to: 2 +)ESCAPE"; + TORCH_CHECK(printer1.ss.str() == expect); + TORCH_CHECK(printer2.ss.str() == expect); +} + +TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + auto tv1 = broadcast(tv0, {true, false, true}); + auto tv2 = sin(tv1); + fusion->addOutput(tv2); + + tv0->split(0, 2); + tv2->split(1, 2); + tv2->split(0, 4); + + MaxRootDomainInfoSpanningTree path1(tv2); + TransformPropagatorWithCheck propagator1(tv2); + path1.traverse(&propagator1); + + MaxRootDomainInfoSpanningTree path2(tv0); + TransformPropagatorWithCheck propagator2(tv0); + path2.traverse(&propagator2); + + TORCH_CHECK(tv1->axis(0)->isBroadcast()); + TORCH_CHECK(tv1->axis(1)->isBroadcast()); + TORCH_CHECK(!tv1->axis(2)->isBroadcast()); + TORCH_CHECK(!tv1->axis(3)->isBroadcast()); + TORCH_CHECK(tv1->axis(4)->isBroadcast()); + + auto expect = makeSymbolicTensor(3); + expect->split(1, 2); + expect->split(0, 4); + TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1)); +} + +TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + + // Register your inputs + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = set(tv0); + // [B, I] + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + auto tv5 = set(tv4); + + // Register your outputs + fusion.addOutput(tv5); + + tv5->split(0, 8); + tv5->split(-1, 8); + + // [Serial, TIDy, TIDX, Serial] + + tv4->computeAt(tv5, -2); + tv3->computeAt(tv4, -1); + tv2->computeAt(tv3, 0); + tv2->split(0, 8); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv1->computeAt(tv5, -2); + + tv5->axis(1)->parallelize(ParallelType::TIDy); + tv5->axis(2)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor in1 = at::randn({16}, options); + at::Tensor in2 = at::randn({12, 16}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {in1, in2}); + auto cg_outputs = fe.runFusion({in1, in2}); + + auto tv_ref = in1 + in2; + + testValidate(&fusion, cg_outputs, {in1, in2}, {tv_ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSkipReplay_CUDA) { + { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(1); + TensorView* tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->split(1, 2, false); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + } + + { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(3); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0, 2}); + auto tv2 = sin(tv1); + fusion.addOutput(tv2); + + tv0->split(1, 2, false); + + TransformPropagatorWithCheck propagator(tv0); + MaxRootDomainInfoSpanningTree(tv0).traverse(&propagator); + } +} + +TEST_F(NVFuserTest, FusionInlineRepro1803_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(2); + + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tvs = Welford(tv1, {1}); + auto tvo = set(tvs.var_sum); + fusion.addOutput(tvo); + + tvo->split(0, 16); + tvo->axis(1)->parallelize(ParallelType::Unroll); + + tv0->computeAt(tvo, -1, ComputeAtMode::BestEffort); + + TORCH_CHECK( + tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition()); + TORCH_CHECK( + tvs.var_sum->getComputeAtPosition() == tvs.n->getComputeAtPosition()); + TORCH_CHECK(tvs.var_sum->getComputeAtPosition() == 1); +} + +// Unit test for the transform selection logic +TEST_F(NVFuserTest, FusionBoundedDirectionSelection1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeContigTensor(2); + + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + auto tv3 = add(tv2, tv1); + fusion.addOutput(tv3); + + tv3->split(-1, 5); + tv3->split(-1, 8); + + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + tv3, -1, {tv0, tv2}); + + // Check that the splits are replayed on tv2 + TORCH_INTERNAL_ASSERT( + tv2->nDims() == tv3->nDims(), + "Propagator didn't propagate to tv2: ", + tv2->toString()); + + // Check that the splits are replayed on tv1 as well. Even though + // one of its consumers, tv2, is part of the boundary, another + // consumer is not a boundary, so tv1 should be transformed as well. + TORCH_INTERNAL_ASSERT( + tv1->nDims() == tv3->nDims(), + "Propagator didn't propagate to tv1: ", + tv1->toString()); +} + +TEST_F(NVFuserTest, FusionIssueRepro1844_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + std::vector shape = {2, 1, 768}; + std::vector sum_to_shape = {768}; + std::vector sum_to_axes = {0, 1}; + double kProb = 0.5; + + std::vector sum_to_symb; + std::transform( + sum_to_shape.begin(), + sum_to_shape.end(), + std::back_inserter(sum_to_symb), + [](int s) -> Int* { return IrBuilder::create(s); }); + + TensorView* tv0 = makeContigConcreteTensor(shape); + TensorView* tv1 = makeContigConcreteTensor(shape); + TensorView* tv2 = makeContigConcreteTensor(shape, DataType::Bool); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + Double* prob = IrBuilder::create(kProb); + auto grad_input = dropout_backward(tv1, tv2, prob); + auto grad_gelu = gelu_backward(grad_input, tv0); + auto grad_bias = sum_to(grad_gelu, sum_to_symb); + + fusion->addOutput(grad_gelu); + fusion->addOutput(grad_bias); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + const auto mask_options = + at::TensorOptions().dtype(at::kBool).device(at::kCUDA, 0); + at::manual_seed(0); + + at::Tensor a = at::randn(shape, options); + at::Tensor b = at::randn(shape, options); + at::Tensor c = at::randn(shape, options); + auto mask = at::gt(c, 0.0f); + std::vector aten_inputs = {a, b, mask}; + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto dinput = at::native_dropout_backward(b, mask, kProb); + auto dgelu = at::gelu_backward(dinput, a, "none"); + auto dbias = dgelu.sum(sum_to_axes); + + testValidate( + executor_cache.fusion(), + cg_outputs, + aten_inputs, + {dgelu, dbias}, + __LINE__, + __FILE__); +} + +TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = set(tv1); + fusion.addOutput(tv2); + + tv2->split(0, 32); + tv2->split(-1, 2); + tv2->reorder({{1, 2}, {2, 1}}); + tv2->merge(0); + + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + + tv0->computeAt(tv2, 1); + + // The predicate of tv2 should be protected with magic zero + GpuLower gpulw(&fusion); + TORCH_CHECK( + PredicateMagicZeroChecker::isProtected(tv2, gpulw), + "Failed to protect the predicates of ", + tv2->toString()); +} + +TEST_F(NVFuserTest, FusionRepro1860_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + std::vector contiguity{true, false, false}; + + std::vector shape{1, -1, -1}; + TensorView* tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + TensorView* tv1 = makeContigConcreteTensor(shape); + fusion.addInput(tv1); + TensorView* tv2 = makeContigConcreteTensor(shape); + fusion.addInput(tv2); + + std::vector domain1(3, nullptr); + for (const auto i : c10::irange(3)) { + if (i == 0) { + domain1[i] = + IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) + .iter_type(IterType::Broadcast) + .build(); + } else { + domain1[i] = + IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) + .expanded_extent(IrBuilder::create(1 + i)) + .iter_type(IterType::Broadcast) + .build(); + } + } + + TensorView* tv22 = IrBuilder::create( + IrBuilder::create(domain1, contiguity), DataType::Float); + + fusion.addInput(tv22); + + auto tv3 = add(tv0, tv1); + auto tv4 = softmax(tv3, 0); + auto tv5 = add(tv4, tv22); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor input1 = at::randn({1, 2, 3}, options); + at::Tensor input2 = at::randn({1, 2, 3}, options); + at::Tensor input3 = at::randn({1, 2, 3}, options); + at::Tensor input4 = at::randn({1, 1, 1}, options).expand({1, 2, 3}); + std::vector aten_inputs = {input1, input2, input3, input4}; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(aten_inputs); +} + +TEST_F(NVFuserTest, FusionExpandReduce_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({1, 8}); + fusion->addInput(tv0); + + auto tv1 = + expand(tv0, {IrBuilder::create(12), IrBuilder::create(8)}); + + auto tv2 = sum(tv1, {0}); + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1, 8}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + auto ref = t0.expand({12, 8}).sum({0}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +// Predicate elimination issue repro: +TEST_F(NVFuserTest, FusionExpandReduce2_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({1, 4}); + fusion->addInput(tv0); + + auto tv1 = + expand(tv0, {IrBuilder::create(3), IrBuilder::create(4)}); + + auto tv2 = sum(tv1, {0}); + fusion->addOutput(tv2); + + // tv2[r{3}, i{4}] + tv2->split(0, NamedScalar::getParallelDim(ParallelType::TIDy)); + tv2->axis(1)->parallelize(ParallelType::TIDy); + tv2->split(0, NamedScalar::getParallelDim(ParallelType::BIDy), false); + tv2->axis(0)->parallelize(ParallelType::BIDy); + tv2->split(-1, NamedScalar::getParallelDim(ParallelType::TIDx)); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv2->axis(-2)->parallelize(ParallelType::BIDx); + // [rBIDy, rO, rTIDy, iBIDx, iTIDx] + tv2->reorder({{-2, 0}, {-1, 1}, {2, 2}}); + // [iBIDx, iTIDx, rTIDy, rBIDy, rO] + auto tv3 = tv2->rFactor({-1}); + + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv3); + tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(0); + auto t0 = at::randn({1, 4}, options); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); + auto cg_outputs = fe.runFusion({t0}, LaunchParams(-1, 2, -1, 4, 2, 1)); + + auto ref = t0.expand({3, 4}).sum({0}); + + testValidate( + fusion.get(), + cg_outputs, + {t0}, + {ref}, + __LINE__, + __FILE__, + "", + LaunchParams(-1, 2, -1, 4, 2, 1)); +} + +TEST_F(NVFuserTest, FusionExpandBadShapeTest_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(&fusion); + std::vector contiguity{false, false}; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + std::vector domains = { + IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create()) + .build(), + IterDomainBuilder( + FusionGuard::getCurFusion()->zeroVal(), IrBuilder::create(1)) + .expanded_extent(IrBuilder::create(10)) + .iter_type(IterType::Broadcast) + .build()}; + + // expand to 10 + TensorView* tv22 = IrBuilder::create( + IrBuilder::create(domains, contiguity), DataType::Float); + + fusion.addInput(tv22); + + auto tv3 = add(tv0, tv22); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // Incompatible shapes + at::Tensor input1 = at::randn({2, 3}, options); + // Passing expand size of 5, not 10. Should cause an error + at::Tensor input4 = at::randn({2, 1}, options).expand({2, 5}); + + std::vector aten_inputs = {input1, input4}; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + ASSERT_ANY_THROW(executor_cache.runFusionWithInputs(aten_inputs)); +} + +TEST_F( + NVFuserTest, + FusionPointwiseScheduleWithBroadcastAndTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(3); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = broadcast(tv0, {false, true, false, true, false, true}); + auto tv3 = sin(tv2); + auto tv4 = add(tv3, tv1); + auto tv5 = sum(tv4, {1}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100, 100, 10}, options); + at::Tensor t1 = at::randn({10, 20}, options); + + auto aten_output = (t0.view({100, 1, 100, 1, 10, 1}).sin() + t1).squeeze(1); + + std::vector aten_inputs = {t0, t1}; + + auto lparams = schedulePointwise(&fusion, aten_inputs); + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs, lparams); + auto cg_outputs = fe.runFusion(aten_inputs, lparams); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningMismatchedDims1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = cos(tv1); + auto tv3 = transpose(tv2, 1, 2); + auto tv4 = exp(tv3); + auto tv5 = tan(tv4); + fusion.addOutput(tv5); + + inlineMost(); + + TORCH_CHECK(tv5->getComputeAtPosition() == 3); + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 3); + TORCH_CHECK(tv2->getComputeAtPosition() == 1); + TORCH_CHECK(tv1->getComputeAtPosition() == 3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().cos().transpose(1, 2).exp().tan(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningMismatchedDims2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = cos(tv1); + auto tv3 = transpose(tv2, 1, 2); + auto tv4 = exp(tv3); + auto tv5 = tan(tv4); + fusion.addOutput(tv5); + + inlineAllAt(tv5, -1, true); + + TORCH_CHECK(tv5->getComputeAtPosition() == 3); + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 3); + TORCH_CHECK(tv2->getComputeAtPosition() == 1); + TORCH_CHECK(tv1->getComputeAtPosition() == 1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().cos().transpose(1, 2).exp().tan(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningMismatchedDims3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + // broadcasting + auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); + auto tv3 = relu(tv2); + // trivial reduction + auto tv4 = sum(tv3, {1, 3, 5}); + auto tv5 = cos(tv4); + auto tv6 = transpose(tv5, 1, 2); + auto tv7 = exp(tv6); + auto tv8 = tan(tv7); + fusion.addOutput(tv8); + + for (auto tv : {tv2, tv3, tv4}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + inlineMost(); + + TORCH_CHECK(tv8->getComputeAtPosition() == 3); + TORCH_CHECK(tv7->getComputeAtPosition() == 3); + TORCH_CHECK(tv6->getComputeAtPosition() == 3); + TORCH_CHECK(tv5->getComputeAtPosition() == 1); + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 3); + TORCH_CHECK(tv2->getComputeAtPosition() == 3); + TORCH_CHECK(tv1->getComputeAtPosition() == 3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().relu().cos().transpose(1, 2).exp().tan(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningMismatchedDims4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = exp(tv1); + auto tv3 = relu(tv2); + auto tv4 = cos(tv3); + auto tv5 = tan(tv4); + fusion.addOutput(tv5); + + tv3->merge(1); + inlineMost(); + + TORCH_CHECK(tv5->getComputeAtPosition() == 3); + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 1); + TORCH_CHECK(tv2->getComputeAtPosition() == 1); + TORCH_CHECK(tv1->getComputeAtPosition() == 3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().exp().relu().cos().tan(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + // broadcasting + auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); + auto tv3 = cos(tv2); + auto tv4 = tan(tv3); + fusion.addOutput(tv4); + + for (auto tv : {tv2, tv3, tv4}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + inlineMost(); + + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 3); + TORCH_CHECK(tv2->getComputeAtPosition() == 3); + TORCH_CHECK(tv1->getComputeAtPosition() == 3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().view({2, 1, 3, 1, 4, 1}).cos().tan(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInliningBroadcastTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + // broadcasting + auto tv2 = broadcast(tv1, {false, true, false, true, false, true}); + auto tv3 = tan(tv2); + // trivial reduction + auto tv4 = sum(tv3, {1, 3, 5}); + auto tv5 = cos(tv4); + auto tv6 = exp(tv5); + fusion.addOutput(tv6); + + for (auto tv : {tv2, tv3, tv4}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + inlineMost(); + + TORCH_CHECK(tv6->getComputeAtPosition() == 3); + TORCH_CHECK(tv5->getComputeAtPosition() == 3); + TORCH_CHECK(tv4->getComputeAtPosition() == 3); + TORCH_CHECK(tv3->getComputeAtPosition() == 3); + TORCH_CHECK(tv2->getComputeAtPosition() == 3); + TORCH_CHECK(tv1->getComputeAtPosition() == 3); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn({2, 3, 4}, options); + auto output = input.sin().tan().cos().exp(); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}); + auto cg_outputs = fe.runFusion({input}); + + testValidate(&fusion, cg_outputs, {input}, {output}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 1, 3, 1, 4, 1}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1, 3, 5}); + auto tv2 = sin(tv1); + fusion.addOutput(tv1); + + for (auto tv : {tv0, tv1}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3); +} + +TEST_F(NVFuserTest, FusionMatchedLeafPosWithoutReplayBroadcast_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {false, true, false, true, false, true}); + auto tv2 = sin(tv1); + fusion.addOutput(tv2); + + for (auto tv : {tv1, tv2}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv0, tv1, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv1, tv0, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC(tv1, tv2, 3) == 3); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP(tv2, tv1, 3) == 3); +} + +TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion.addInput(tv0); + auto tv1 = broadcast(tv0, {false, true, false, true, false, true}); + auto tv2 = sum(tv1, {1, 3, 5}); + auto tv3 = sin(tv2); + fusion.addOutput(tv3); + + for (auto tv : {tv1, tv2}) { + tv->merge(0); + tv->merge(1); + tv->merge(2); + } + + inlineMost(); + + ComputeAtMap ca_map(&fusion); + + auto all_tvs = ir_utils::allTvs(&fusion); + for (auto tv1 : all_tvs) { + for (auto tv2 : all_tvs) { + if (tv1->isFusionInput() || tv2->isFusionInput()) { + continue; + } + for (int i : c10::irange(3)) { + auto id1 = tv1->axis(i); + auto id2 = tv2->axis(i); + TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::LOOP)); + TORCH_CHECK(ca_map.areMapped(id1, id2, IdMappingMode::PERMISSIVE)); + } + } + } +} + +TEST_F(NVFuserTest, FusionPrint_CUDA) { + auto dtypes = { + at::kFloat, + at::kDouble, + at::kHalf, + at::kBFloat16, + at::kInt, + at::kLong, + at::kBool}; + for (auto dtype : dtypes) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype)); + fusion->addInput(tv0); + auto tv1 = print(tv0); + auto tv2 = sin(tv1); + fusion->addOutput(tv2); + + // There is no way to check if anything is printed to the console, but we + // can validate that when print exist, compilation and computation are not + // broken. + auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + at::Tensor t0 = at::arange(2, options).to(dtype); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {t0}, + {t0.sin()}, + __LINE__, + __FILE__); + } +} + +TEST_F(NVFuserTest, FusionCheckedSymbolicShape_CUDA) { + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor a = at::randn({123, 456}, options); + at::Tensor b = at::randn({123, 456}, options); + at::Tensor c = at::randn({321, 654}, options); + + using return_t = + std::pair, std::vector>; + auto matched_add = [](at::Tensor a, at::Tensor b) -> return_t { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* s1 = IrBuilder::create(); + Val* s2 = IrBuilder::create(); + auto builder = TensorViewBuilder().shape(std::vector{s1, s2}); + TensorView* tv0 = builder.build(); + TensorView* tv1 = builder.build(); + + fusion->addInput(tv0); + fusion->addInput(tv1); + + auto tv2 = add(tv0, tv1); + + fusion->addOutput(tv2); + + auto executor_cache = + std::make_unique(std::move(fusion)); + auto cg_outputs = executor_cache->runFusionWithInputs({a, b}); + return {std::move(executor_cache), std::move(cg_outputs)}; + }; + + { + auto ret1 = matched_add(a, b); + testValidate( + ret1.first->fusion(), ret1.second, {a, b}, {a + b}, __LINE__, __FILE__); + } + + { + EXPECT_THAT( + [&]() { matched_add(a, c); }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("Attempting to bind"))); + } +} + +TEST_F(NVFuserTest, FusionSizeDependentData_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* s1 = IrBuilder::create(); + auto builder = TensorViewBuilder().shape(std::vector{s1}); + TensorView* tv0 = builder.build(); + + fusion->addInput(tv0); + + auto tv1 = add(tv0, s1); + + fusion->addOutput(tv1); + + const auto options = + at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor a = at::zeros({123}, options); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs({a}); + + testValidate( + executor_cache.fusion(), cg_outputs, {a}, {a + 123}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionDependencyCheck_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + TensorView* tv1 = makeSymbolicTensor(1); + TensorView* tv2 = makeSymbolicTensor(1); + TensorView* tv3 = makeSymbolicTensor(1); + + auto tv4 = add(tv0, tv1); + auto tv5 = add(tv0, tv2); + auto tv6 = add(tv0, tv3); + + auto tv7 = add(tv1, tv2); + auto tv8 = add(tv1, tv3); + + auto tv9 = add(tv2, tv3); + + { + auto all_vals = DependencyCheck::getAllValsBetween( + {tv0, tv1}, {tv4, tv5, tv6, tv7, tv8, tv9}); + std::unordered_set all_vals_set(all_vals.begin(), all_vals.end()); + std::vector results({tv0, tv1, tv4, tv5, tv6, tv7, tv8}); + for (auto result : results) { + TORCH_CHECK(all_vals_set.count(result) > 0); + all_vals_set.erase(result); + } + TORCH_CHECK(all_vals_set.empty()); + } + + auto tv10 = add(tv6, tv7); + { + auto all_vals = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv10}); + std::unordered_set all_vals_set(all_vals.begin(), all_vals.end()); + std::vector results({tv0, tv1, tv6, tv7, tv10}); + for (auto result : results) { + TORCH_CHECK(all_vals_set.count(result) > 0); + all_vals_set.erase(result); + } + TORCH_CHECK(all_vals_set.empty()); + } +} + +// Repro for issue #1925 +TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(4); + auto tv1 = makeConcreteTensor({-1, -1, -1, 1}); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input0 = at::randn({1, 1, 333, 1}, options); + at::Tensor input1 = at::randn({1, 1, 333, 1}, options); + + auto lparams = scheduleTranspose(&fusion, {input0, input1}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input0, input1}, lparams); + auto outputs = fe.runFusion({input0, input1}, lparams); + + auto tv_ref = input0 + input1; + + testValidate( + &fusion, outputs, {input0, input1}, {tv_ref}, __LINE__, __FILE__); +} + +// Repro for issue #1873 +TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + tv0->computeAt(tv4, 1); + + tv2->split(-1, 8); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({123}, options); + at::Tensor t1 = at::randn({3, 123}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + + auto outputs = fe.runFusion({t0, t1}); + + auto tv_ref = t0 + t1; + + testValidate(&fusion, outputs, {t0, t1}, {tv_ref}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { + // https://github.com/csarofeen/pytorch/issues/1926 + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = set(tv1); + fusion->addOutput(tv2); + + tv1->setMemoryType(MemoryType::Shared); + for (auto tv : {tv1, tv2}) { + tv->split(0, 4); + tv->reorder({{1, -1}}); + tv->split(1, 8); + tv->merge(0); + tv->split(0, 1); + tv->axis(0)->parallelize(ParallelType::BIDx); + tv->axis(1)->parallelize(ParallelType::Unswitch); + } + tv1->merge(2); + tv2->reorder({{2, 3}}); + tv2->merge(2); + for (auto tv : {tv1, tv2}) { + tv->axis(-1)->parallelize(ParallelType::TIDx); + } + + inlineMost(); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({5, 5}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto out = cg_outputs[0]; + + testValidate(fusion, {out}, {t0}, {t0}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, AsyncCompilation_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeSymbolicTensor(2); + TensorView* tv1 = makeSymbolicTensor(1); + TensorView* tv2 = makeSymbolicTensor(2); + + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addInput(tv2); + + TensorView* tv3 = add(tv0, IrBuilder::create(1)); // Group 0 + TensorView* tv4 = + max(tv3, {0}); // Group 0 (use max instead to avoid numerical issues) + TensorView* tv5 = add(tv4, tv1); // Group 0 (Non Broadcast after reduce, + // keeps normalization scheduler away) + TensorView* tv6 = add(tv5, tv2); // Group 1 (Broadcast after reduce) + + fusion->addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({8, 5}, options); + at::Tensor t1 = at::randn({5}, options); + at::Tensor t2 = at::randn({8, 5}, options); + + auto t3 = t0.add(1.0); + auto t4 = std::get<0>(at::max(t3, 0)); + auto t5 = t4.add(t1); + auto t6 = t5.add(t2); + + FusionExecutorCache executor_cache(std::move(fusion)); + + std::vector aten_inputs = {t0, t1, t2}; + + executor_cache.compileFusionAsync(aten_inputs); + + while (!executor_cache.isCompiled(aten_inputs)) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + printf("."); + } + + auto outputs = executor_cache.runFusionWithInputs(aten_inputs); + + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime()->isSegmented(), + "segmentation didn't happen"); + TORCH_CHECK( + executor_cache.getMostRecentKernelRuntime() + ->fusionSegments() + ->groups() + .size() == 2, + "segmentation didn't happen as expected"); + + testValidate( + executor_cache.fusion(), outputs, aten_inputs, {t6}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({1, 1}); + TensorView* tv1 = makeConcreteTensor({-1}); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = sum(tv0, {1}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + tv0->merge(0); + + MaxRootDomainInfoSpanningTree tree(tv0); + TransformPropagatorWithCheck tp(tv0); + tree.traverse(&tp); + + inlineMost(); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1, 1}, options); + at::Tensor t1 = at::randn({10}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto out = cg_outputs[0]; + + testValidate( + fusion, {out}, {t0, t1}, {t1 + t0.flatten()}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({-1, 1, 1}); + TensorView* tv1 = makeConcreteTensor({-1, -1}); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = sum(tv0, {1}); + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + tv2->merge(1); + tv2->merge(0); + + MaxRootDomainInfoSpanningTree tree(tv0); + TransformPropagatorWithCheck tp(tv0); + tree.traverse(&tp); + + inlineMost(); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 1, 1}, options); + at::Tensor t1 = at::randn({10, 10}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto out = cg_outputs[0]; + + testValidate( + fusion, {out}, {t0, t1}, {t1 + t0.squeeze(-1)}, __LINE__, __FILE__); +} + +// Simple test case exercising the null scheduler path. +TEST_F(NVFuserTest, FusionNullScheduler_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({1, 1, 1}); + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {0, 1, 2}); + + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1, 1, 1}, options); + + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.sum({0, 1, 2}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + + auto groups = + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); + + // Check that all groups on the resulting runtime are null. + for (auto group : groups) { + TORCH_INTERNAL_ASSERT(group->heuristic() == ScheduleHeuristic::NoOp); + } +} + +// Simple test case exercising the null scheduler path. +TEST_F(NVFuserTest, FusionNullScheduler2_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({0, 1, 9223372036854775807L}); + fusion->addInput(tv0); + + auto tv1 = sum(tv0, {0, 1, 2}); + + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({0, 1, 9223372036854775807L}, options); + + std::vector aten_inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + auto t1 = t0.sum({0, 1, 2}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + + auto groups = + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); + + // Check that all groups on the resulting runtime are null. + for (auto group : groups) { + TORCH_INTERNAL_ASSERT(group->heuristic() == ScheduleHeuristic::NoOp); + } +} + +// Simple test case exercising the null scheduler path. +TEST_F(NVFuserTest, FusionNullScheduler3_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = TensorViewBuilder().ndims(0).build(); + auto tv1 = TensorViewBuilder().ndims(0).build(); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({}, options); + at::Tensor t1 = at::randn({}, options); + + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {t0, t1}, + {t0 + t1}, + __LINE__, + __FILE__); + + auto groups = + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); + + // Check that all groups on the resulting runtime are null. + for (auto group : groups) { + TORCH_INTERNAL_ASSERT(group->heuristic() == ScheduleHeuristic::NoOp); + } +} + +TEST_F(NVFuserTest, FusionEmpty_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({10, 10, 10}); + auto tv1 = makeConcreteTensor({10, 10, 10}); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv0); + fusion->addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 10, 10}, options); + at::Tensor t1 = at::randn({10, 10, 10}, options); + + std::vector aten_inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {t0, t1}, + {t0, t1}, + __LINE__, + __FILE__); + + auto groups = + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(); + + // Check that all groups on the resulting runtime are null. + for (auto group : groups) { + TORCH_INTERNAL_ASSERT(group->heuristic() == ScheduleHeuristic::NoOp); + } +} + +TEST_F(NVFuserTest, FusionMappingRelation_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({1, 1}); + TensorView* tv1 = makeConcreteTensor({-1, 1, 1}); + fusion->addInput(tv0); + fusion->addInput(tv1); + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false, false}); + auto tv4 = add(tv3, tv1); + + fusion->addOutput(tv4); + + tv4->merge(-2); + tv4->merge(-1); + + tv0->computeAt(tv4, -1); + tv1->computeAt(tv4, -1); + + ComputeAtMap ca_map(fusion); + + // FIXME: This is the concerning part that would motivate some + // more formalization on concrete/permissive mapping: + // exact mapping should ideally imply permissive mapping. + auto tv4_inner_node = tv4->axis(0)->definition()->input(1)->as(); + TORCH_CHECK( + ca_map.areMapped(tv2->axis(0), tv4_inner_node, IdMappingMode::EXACT)); + TORCH_CHECK(!ca_map.areMapped( + tv2->axis(0), tv4_inner_node, IdMappingMode::PERMISSIVE)); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1, 1}, options); + at::Tensor t1 = at::randn({2, 1, 1}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0, t1}); + auto cg_outputs = fe.runFusion({t0, t1}); + auto out = cg_outputs[0]; + + testValidate( + fusion, {out}, {t0, t1}, {t1 + t0.squeeze(0)}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionInlineAt_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(2); + fusion->addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = cos(tv1); + fusion->addOutput(tv2); + + tv1->inlineAt(-1); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({100, 2}, options); + + FusionExecutor fe; + fe.compileFusion(fusion, {t0}); + auto cg_outputs = fe.runFusion({t0}); + auto out = cg_outputs[0]; + + testValidate(fusion, {out}, {t0}, {t0.sin().cos()}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeConcreteTensor({-1, -1}); + TensorView* tv1 = makeConcreteTensor({-1, -1}); + fusion->addInput(tv0); + fusion->addInput(tv1); + // Note: tv2 is not needed. Kept it here since previously there was an + // assertion from sorting in codegen. + auto tv2 = add(tv1, IrBuilder::create(3.141)); + fusion->addOutput(tv0); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({10, 4}, options); + at::Tensor t1 = at::randn({10, 4}, options); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto cg_outputs = fec.runFusionWithInputs({t0, t1}); + + testValidate(fusion, cg_outputs, {t0, t1}, {t0}, __LINE__, __FILE__); + + // Second run to ensure cache hit handles trivial forwarding properly + TORCH_CHECK(fec.isCompiled({t0, t1})); + auto cg_outputs2 = fec.runFusionWithInputs({t0, t1}); + testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionTrivialInputForwarding2_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(0); + fusion->addInput(tv0); + fusion->addOutput(tv0); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({}, options); + + FusionExecutorCache fec(std::move(fusion_ptr)); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + + // Second run to ensure cache hit handles trivial forwarding properly + TORCH_CHECK(fec.isCompiled({t0})); + auto cg_outputs2 = fec.runFusionWithInputs({t0}); + testValidate(fusion, cg_outputs2, {t0}, {t0}, __LINE__, __FILE__); +} + +// Simplified repro of issue #2008 +TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({10, 1, 1}); + + auto tv0 = makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = add(tv0, IrBuilder::create(1)); + auto tv2 = sum(tv1, {1, 2}); + auto tv3 = broadcast(tv2, {false, true, true}); + fusion.addOutput(tv3); + + tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4); + + MaxRootDomainInfoSpanningTree tree(tv0); + TransformPropagator tp(tv0); + tree.traverse(&tp); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn(shape, options); + std::vector aten_inputs({t0}); + + FusionExecutor fe; + fe.compileFusion(fusion_ptr.get(), aten_inputs); + auto outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__); +} + +namespace { + +size_t getVecSizeForPointwise(FusionExecutorCache& fec) { + auto most_recent_params = + fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params; + auto params = dynamic_cast(most_recent_params.get()); + if (params->vectorize) { + return params->unroll_factor; + } + return 1; +} + +} // namespace + +TEST_F(NVFuserTest, FusionVectorizeStrideContiguity2D_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = + TensorViewBuilder().ndims(2).contiguity({false, true}).build(); + fusion->addInput(tv0); + auto tv1 = set(tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + std::vector> size_and_vec{{17, 1}, {18, 2}, {32, 4}}; + + for (auto pair : size_and_vec) { + auto size = pair.first; + auto vec = pair.second; + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + TORCH_CHECK(getVecSizeForPointwise(fec) == vec); + + testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionVectorizeStrideContiguity3D_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = + TensorViewBuilder().ndims(3).contiguity({false, true, true}).build(); + fusion->addInput(tv0); + auto tv1 = set(tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + std::vector> size_and_vec{{17, 1}, {10, 2}, {16, 4}}; + + for (auto pair : size_and_vec) { + auto size = pair.first; + auto vec = pair.second; + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + TORCH_CHECK(getVecSizeForPointwise(fec) == vec); + + testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionVectorizeStrideContiguity5D_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(5) + .contiguity({false, true, false, true, true}) + .build(); + fusion->addInput(tv0); + auto tv1 = set(tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + + std::vector> sizes_and_vec{ + {9, 17, 1}, {9, 10, 2}, {9, 16, 4}}; + + for (auto tup : sizes_and_vec) { + auto size1 = std::get<0>(tup); + auto size2 = std::get<1>(tup); + auto vec = std::get<2>(tup); + at::Tensor t0 = at::randn({4, size1, 12345, size2, 3}, options) + .narrow(1, 0, 8) + .narrow(3, 0, 4); + auto cg_outputs = fec.runFusionWithInputs({t0}); + + TORCH_CHECK(getVecSizeForPointwise(fec) == vec); + + testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = TensorViewBuilder() + .ndims(5) + .contiguity({false, true, false, true, true}) + .build(); + fusion->addInput(tv0); + auto tv1 = set(tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + fec.profile(true); + + auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0); + + std::vector> sizes_strides_and_vec{ + {4, 4, 4, 4}, + {4, 4, 2, 2}, + {4, 2, 4, 2}, + {2, 4, 4, 2}, + {4, 4, 1, 1}, + {4, 1, 4, 1}, + {1, 4, 4, 1}, + {2, 2, 2, 2}, + {2, 2, 1, 1}, + {2, 1, 2, 1}, + {1, 2, 2, 1}}; + + for (auto tup : sizes_strides_and_vec) { + auto size = std::get<0>(tup); + auto stride1 = std::get<1>(tup); + auto stride2 = std::get<2>(tup); + auto vec = std::get<3>(tup); + std::vector shape = {4, 4, 12345, size, 3}; + std::vector stride = {stride1, stride2 * 12345, stride2, 3, 1}; + at::Tensor t0 = at::empty_strided(shape, stride, options); + t0.random_(); + auto cg_outputs = fec.runFusionWithInputs({t0}); + TORCH_CHECK(getVecSizeForPointwise(fec) == vec); + testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // requires ampere+ GPU + if (!deviceMajorMinorCheck(8)) { + GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs"; + return; + } + + auto tv0 = makeContigTensor(1); + + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + fusion.addOutput(tv1); + + auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync); + tv_cache->setMemoryType(MemoryType::Shared); + + tv1->split(0, 16); + tv0->computeAt(tv1, 1); + + tv_cache->circularBuffer(10); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input1 = at::randn({255}, options); + + // Add check that the cp async op has an inlined predicate. + class InlinedCpAsyncPredChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + + private: + void handle(kir::IfThenElse* ite) final { + auto prev_within_ite = within_ite_; + within_ite_ = true; + kir::IrVisitor::handle(ite); + within_ite_ = prev_within_ite; + } + + void handle(LoadStoreOp* ldst) final { + if (ldst->opType() == LoadStoreOpType::CpAsync) { + TORCH_INTERNAL_ASSERT(!within_ite_, "CPASYNC predicate not inlined"); + TORCH_INTERNAL_ASSERT( + ldst->predicate()->hasValue() && + !ldst->predicate()->value()->isConst(), + "CPASYNC predicate is not generated"); + } + } + + private: + bool within_ite_ = false; + } pred_checker; + + // Check that cp async is inlined: + GpuLower gpulw(&fusion); + pred_checker.handle(gpulw.kernel()->topLevelExprs()); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input1}); + auto cg_outputs = fe.runFusion({input1}); + + testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__); +} + +// Test file size should be up to 10K LoC. Create a new file for more tests. + +} // namespace jit +} // namespace torch +#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp index 3b9e7cbd962c6..e827de56e56bd 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include @@ -2391,10 +2391,7 @@ TEST_F(NVFuserTest, FusionCrossIterationGroupedGridAllreduceWelfordShmoo_CUDA) { transform_ref_rf->axis(unswitch_id)->parallelize(ParallelType::Unswitch); - InlinePropagator inline_propagator( - transform_ref_rf, -1, ComputeAtMode::MostInlined); - MaxRootDomainInfoSpanningTree(transform_ref_rf) - .traverse(&inline_propagator); + inlineMost(); // Make sure the reduction expr is converted to GroupedGridReduciton // and the non-reduction domains of the output TV are either diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu index e6acc4c5307a1..a1ff6562e6bda 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_rng.cu @@ -141,7 +141,7 @@ TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) { TensorView* tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype)); fusion->addInput(tv0); - auto tv1 = randlike(tv0); + auto tv1 = rand_like(tv0); auto tv2 = set(tv1); fusion->addOutput(tv2); @@ -166,6 +166,41 @@ TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand_CUDA) { testValidate(fusion, {out}, {t0}, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionRNGManualScheduleValidateWithCURand2_CUDA) { +#ifdef FBCODE_CAFFE2 + GTEST_SKIP() << "Fails accuracy on V100 32gb"; +#endif + auto dtype = kFloat; + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + Int* size1 = IrBuilder::create(); + Int* size2 = IrBuilder::create(); + Int* size3 = IrBuilder::create(); + Int* size4 = IrBuilder::create(); + fusion->addInput(size1); + fusion->addInput(size2); + fusion->addInput(size3); + fusion->addInput(size4); + TensorView* tv0 = rand({size1, size2, size3, size4}, DataType::Float); + fusion->addOutput(tv0); + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + + FusionExecutor fe; + fe.compileFusion(fusion, {10, 10, 10, 10}); + + at::manual_seed(0); + auto cg_outputs = fe.runFusion({10, 10, 10, 10}); + auto out = cg_outputs[0]; + + at::manual_seed(0); + auto ref = generate_uniform(10000, dtype).view({10, 10, 10, 10}); + + testValidate(fusion, {out}, {10, 10, 10, 10}, {ref}, __LINE__, __FILE__); +} + TEST_F(NVFuserTest, FusionBroadcastingRNG_CUDA) { for (auto dtype : {kFloat, kDouble}) { std::unique_ptr fusion_ptr = std::make_unique(); @@ -176,7 +211,7 @@ TEST_F(NVFuserTest, FusionBroadcastingRNG_CUDA) { TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype)); fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = randlike(tv0); + auto tv2 = rand_like(tv0); auto tv3 = add(tv1, tv2); auto tv4 = add(tv0, tv3); fusion->addOutput(tv4); @@ -207,7 +242,7 @@ TEST_F(NVFuserTest, FusionBroadcastingRNG2_CUDA) { TensorView* tv1 = makeSymbolicTensor(1, aten_to_data_type(dtype)); fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = randlike(tv0); + auto tv2 = rand_like(tv0); auto tv3 = add(tv1, tv2); fusion->addOutput(tv3); @@ -239,7 +274,7 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmem_CUDA) { TensorView* tv1 = makeConcreteTensor({5, 5}, aten_to_data_type(dtype)); fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = randlike(tv0); + auto tv2 = rand_like(tv0); auto tv3 = add(tv1, tv2); auto tv4 = add(tv0, tv3); fusion->addOutput(tv4); @@ -272,7 +307,7 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) { TensorView* tv1 = makeConcreteTensor({5, 5}); fusion->addInput(tv0); fusion->addInput(tv1); - auto tv2 = randlike(tv0); + auto tv2 = rand_like(tv0); auto tv3 = add(tv1, tv2); auto tv4 = add(tv0, tv3); fusion->addOutput(tv4); @@ -297,5 +332,71 @@ TEST_F(NVFuserTest, FusionBroadcastingRNGSmemNonSquareTile_CUDA) { TORCH_CHECK((out.select(1, 0) == out.select(1, 4)).all().item()); } +TEST_F(NVFuserTest, FusionUniform_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + Int* size_val = IrBuilder::create(); + Double* low = IrBuilder::create(); + Double* high = IrBuilder::create(); + fusion->addInput(size_val); + fusion->addInput(low); + fusion->addInput(high); + TensorView* tv0 = uniform({size_val}, low, high, DataType::Float); + TensorView* tv1 = uniform({size_val}, low, high, DataType::Double); + fusion->addOutput(tv0); + fusion->addOutput(tv1); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + for (int64_t size : {16, 1024, 10001, 10002, 10003, 100000, 10000001}) { + at::manual_seed(0); + auto cg_outputs = fec.runFusionWithInputs({size, -1.0, 1.0}); + + at::manual_seed(0); + auto ref0 = generate_uniform(size, kFloat) * 2 - 1; + auto ref1 = generate_uniform(size, kDouble) * 2 - 1; + + testValidate( + fec.fusion(), + cg_outputs, + {size, -1.0, 1.0}, + {ref0, ref1}, + __LINE__, + __FILE__); + } +} + +TEST_F(NVFuserTest, FusionRandLikeReduction_CUDA) { + auto dtype = kFloat; + std::unique_ptr fusion_ptr = std::make_unique(); + auto fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(2, aten_to_data_type(dtype)); + fusion->addInput(tv0); + auto tv1 = sum(tv0, {0}); + auto tv2 = rand_like(tv1); + auto tv3 = add(tv1, tv2); + fusion->addOutput(tv3); + + FusionExecutorCache fec(std::move(fusion_ptr)); + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + at::Tensor t0 = at::zeros({2, 3}, options); + + at::manual_seed(0); + auto cg_outputs = fec.runFusionWithInputs({t0}); + auto out = cg_outputs[0]; + + at::manual_seed(0); + auto t1 = t0.sum(0); + auto t2 = generate_uniform(3, dtype).expand_as(t1); + auto t3 = t1.add(t2); + + testValidate(fec.fusion(), {out}, {t0}, {t3}, __LINE__, __FILE__); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp deleted file mode 100644 index 8e611364bd521..0000000000000 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#if defined(USE_CUDA) -#include -#include - -#include -#include -#include -#include -#include - -// Tests go in torch::jit -namespace torch { -namespace jit { - -using namespace torch::jit::fuser::cuda; - -TEST_F(NVFuserTest, FusionSplitDims_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int64_t* p = prime_numbers; - auto tv = makeConcreteTensor( - {p[0] * p[1] * p[2], p[3], p[4], p[5] * p[6], p[7], p[8], p[9] * p[10]}); - std::vector dims{0, 1, 2, 3, 4, 5, 6}; - scheduler_utils::splitDims( - tv, {{0, p[2]}, {0, p[1]}, {3, p[6]}, {6, p[10]}}, dims); - TORCH_CHECK(tv->nDims() == 11); - for (auto i : c10::irange(11)) { - TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == p[i]); - } - std::vector expect{0, 3, 4, 5, 7, 8, 9}; - TORCH_CHECK(dims == expect); -} - -TEST_F(NVFuserTest, FusionMergeDims_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - int64_t* p = prime_numbers; - auto tv = makeConcreteTensor( - {p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10]}); - std::vector dims{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - auto merged = scheduler_utils::mergeDims(tv, {2, 3, 7, 8, 9}, dims); - TORCH_CHECK(merged == 2); - std::vector expect_shape{ - p[0], p[1], p[2] * p[3] * p[7] * p[8] * p[9], p[4], p[5], p[6], p[10]}; - TORCH_CHECK(tv->nDims() == expect_shape.size()); - for (auto i : c10::irange(expect_shape.size())) { - TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == expect_shape[i]); - } - std::vector expect_dims{0, 1, 2, 2, 3, 4, 5, 2, 2, 2, 6}; - TORCH_CHECK(dims == expect_dims); -} - -} // namespace jit -} // namespace torch -#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp index b2302013f5fd9..d1f185011826e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp @@ -2976,6 +2976,7 @@ TEST_F(NVFuserTest, FusionConv2D_CUDA) { TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) { Fusion fusion; FusionGuard fg(&fusion); + ContextCudnnTF32Disabled disabling_tf32_cudnn; // Input: [C, H, W] auto inp = makeSymbolicTensor(3); @@ -5394,6 +5395,72 @@ TEST_F(NVFuserTest, FusionGatherIterTypePromotion_CUDA) { testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionContigPredicateShift_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape({2, 2}); + + auto tv0 = makeConcreteTensor(shape); + // [0:I] + fusion.addInput(tv0); + + // Below, tv2 and tv3 are mostly the same, except for tv2 is padded + // with 0, whereas tv3 is not, so the valid range of tv3 is [0:I-1] + + // [0:I] + auto tv1 = shift(tv0, {-1, 0}); + + // [0:I-1] + auto tv2 = shift(tv0, {-1, 0}, false); + + // tv3 is not an output of shift, but it gets a partial root + // domain from tv2, so it must be predicated at the root domain + auto tv3 = add(tv2, IrBuilder::create(1)); + + fusion.addOutput(tv1); + fusion.addOutput(tv3); + + // contig merge + tv1->merge(0); + tv1->split(0, 4); + TransformPropagator propagator(tv1); + MaxRootDomainInfoSpanningTree(tv1).traverse(&propagator); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + // Create 3x2 and trim to 2x2. This would cause the output tensor + // non-zero values if not properly predicated. + at::Tensor t0 = at::randn({3, 2}, options); + t0 = t0.index( + {at::indexing::Slice(0, 2), at::indexing::Slice(0, at::indexing::None)}); + + // Use random output to detect invalid writes + at::Tensor t1 = at::rand_like(t0, options); + // Use zero-cleared output to detect invalid writes + at::Tensor t3 = at::zeros_like(t0, options); + + std::vector inputs = {t0}; + std::vector outputs = {t1, t3}; + + std::vector indices{ + at::indexing::Slice(0, -1), at::indexing::Slice(0, at::indexing::None)}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + fe.runFusion(inputs, outputs); + + // Make sure the padded region is zero filled + TORCH_CHECK(t1[1].equal(at::zeros(2, options))); + // Make sure not touched as the shift is not padded + TORCH_CHECK(t3[1].equal(at::zeros(2, options))); + + auto ref = shift(t0, {-1, 0}); + + TORCH_CHECK(t1.equal(ref)); + TORCH_CHECK(t3.index(indices).equal((ref + 1).index(indices))); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp index 15bdda0c0ec1c..06e93fcd579e3 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp @@ -18,12 +18,193 @@ namespace jit { using namespace torch::jit::fuser::cuda; +TEST_F(NVFuserTest, FusionStandaloneFull_CUDA) { + auto sizes = {0, 1, 10, 17, 1024}; + auto dtypes = { + kBool, + kFloat, + kLong, + kDouble, + kHalf, + kBFloat16, + kInt, + kComplexFloat, + kComplexDouble}; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* size = IrBuilder::create(); + Val* fill_val1 = IrBuilder::create(); + Val* fill_val2 = IrBuilder::create(); + Val* fill_val3 = IrBuilder::create(); + fusion->addInput(size); + fusion->addInput(fill_val1); + fusion->addInput(fill_val2); + fusion->addInput(fill_val3); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + auto out_tv = full({size}, fill_val1, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = full({size, size}, fill_val2, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = full_like(out_tv, fill_val3); + fusion->addOutput(out_tv); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + + for (auto size : sizes) { + std::vector expect; + expect.reserve(dtypes.size()); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + const auto options = + at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + expect.emplace_back(at::full({size}, 11, options)); + expect.emplace_back(at::full({size, size}, 12, options)); + expect.emplace_back(at::full({size, size}, 13, options)); + } + auto cg_outputs = executor_cache.runFusionWithInputs({size, 11, 12, 13}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {size, 11, 12, 13}, + expect, + __LINE__, + __FILE__); + } +} + +TEST_F(NVFuserTest, FusionStandaloneZeros_CUDA) { + auto sizes = {0, 1, 10, 17, 1024}; + auto dtypes = { + kBool, + kFloat, + kLong, + kDouble, + kHalf, + kBFloat16, + kInt, + kComplexFloat, + kComplexDouble}; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* size = IrBuilder::create(); + fusion->addInput(size); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + auto out_tv = zeros({size}, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = zeros({size, size}, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = zeros_like(out_tv); + fusion->addOutput(out_tv); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + + for (auto size : sizes) { + std::vector expect; + expect.reserve(dtypes.size()); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + const auto options = + at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + expect.emplace_back(at::zeros({size}, options)); + expect.emplace_back(at::zeros({size, size}, options)); + expect.emplace_back(at::zeros({size, size}, options)); + } + auto cg_outputs = executor_cache.runFusionWithInputs({size}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {size}, + expect, + __LINE__, + __FILE__); + } +} + +TEST_F(NVFuserTest, FusionStandaloneOnes_CUDA) { + auto sizes = {0, 1, 10, 17, 1024}; + auto dtypes = { + kBool, + kFloat, + kLong, + kDouble, + kHalf, + kBFloat16, + kInt, + kComplexFloat, + kComplexDouble}; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* size = IrBuilder::create(); + fusion->addInput(size); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + auto out_tv = ones({size}, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = ones({size, size}, aten_to_data_type(dtype)); + fusion->addOutput(out_tv); + out_tv = ones_like(out_tv); + fusion->addOutput(out_tv); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + + for (auto size : sizes) { + std::vector expect; + expect.reserve(dtypes.size()); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + const auto options = + at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + expect.emplace_back(at::ones({size}, options)); + expect.emplace_back(at::ones({size, size}, options)); + expect.emplace_back(at::ones({size, size}, options)); + } + auto cg_outputs = executor_cache.runFusionWithInputs({size}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {size}, + expect, + __LINE__, + __FILE__); + } +} + TEST_F(NVFuserTest, FusionStandaloneARange_CUDA) { auto starts_ends = {-1., 0., 10.3, 1024. * 256}; auto steps = {-1.5, 1., 2.}; auto dtypes = {kFloat, kLong, kDouble}; for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -97,6 +278,62 @@ TEST_F(NVFuserTest, FusionStandaloneARange_CUDA) { } } +TEST_F(NVFuserTest, FusionStandaloneEye_CUDA) { + auto sizes = {0, 1, 10, 17, 1024}; + auto dtypes = { + kBool, + kFloat, + kLong, + kDouble, + kHalf, + kBFloat16, + kInt, + kComplexFloat, + kComplexDouble}; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + Val* size = IrBuilder::create(); + Val* maybe_m = IrBuilder::create(); + fusion->addInput(size); + fusion->addInput(maybe_m); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + auto out_tv1 = eye(size, aten_to_data_type(dtype)); + fusion->addOutput(out_tv1); + auto out_tv2 = eye(size, maybe_m, aten_to_data_type(dtype)); + fusion->addOutput(out_tv2); + } + + FusionExecutorCache executor_cache(std::move(fusion)); + + for (auto size : sizes) { + std::vector expect; + expect.reserve(dtypes.size()); + for (auto dtype : dtypes) { + if (!isSupportedTypeByDevice(aten_to_data_type(dtype))) { + continue; + } + const auto options = + at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + expect.emplace_back(at::eye(size, options)); + expect.emplace_back(at::eye(size, 15, options)); + } + auto cg_outputs = executor_cache.runFusionWithInputs({size, 15}); + + testValidate( + executor_cache.fusion(), + cg_outputs, + {size, 15}, + expect, + __LINE__, + __FILE__); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp index 5e8b6bc1bda69..b10360f00315e 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp @@ -3,7 +3,8 @@ #include #include -#include +#include +#include #include #include #include @@ -261,9 +262,11 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSinTransposeCos_CUDA) { testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__); } -// t0->transpose--. -// | -// t1->transpose---add-->sin->t5 +/* + * t0->transpose--. + * \ + * t1->transpose---add-->sin->t5 + */ TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -325,10 +328,12 @@ TEST_F(NVFuserTest, FusionScheduleTransposeMultipleOutput_CUDA) { &fusion, outputs, {input}, {tv_ref1, tv_ref2}, __LINE__, __FILE__); } -// t0->transpose->sin->t3 -// \_.-->cos->t5 -// / -// t1 +/* + * t0->transpose->sin->t3 + * \_.-->cos->t5 + * / + * t1 + */ TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInputOutput_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -366,9 +371,11 @@ TEST_F(NVFuserTest, FusionScheduleTransposeMultipleInputOutput_CUDA) { __FILE__); } -// .------>sin------>z -// x->transpose->transpose->add->y -// \_______________________/ +/* + * .------>sin------>z + * x->transpose->transpose->add->y + * \_______________________/ + */ TEST_F(NVFuserTest, FusionScheduleTransposeMatchingSkipConnection_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -743,9 +750,7 @@ TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) { } // inline - MaxRootDomainInfoSpanningTree entire_dag(tv9); - InlinePropagator inline_propagator(tv9, -1, ComputeAtMode::MostInlined); - entire_dag.traverse(&inline_propagator); + inlineMost(); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input0 = at::randn({512, 1024, 256}, options); @@ -789,6 +794,61 @@ TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) { TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion)); } +TEST_F(NVFuserTest, FusionTransposeSelfMapping_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + auto tv2 = add(tv0, tv1); + fusion.addOutput(tv2); + + EXPECT_THAT( + [&]() { IterDomainGraph(fusion_ptr.get()); }, + testing::ThrowsMessage( + testing::HasSubstr("Unsupported domain mapping detected"))); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5, 5}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + auto ref = t0.transpose(0, 1) + t0; + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} + +#if 0 +// silent wrong result +TEST_F(NVFuserTest, FusionTransposeViewSelfMapping_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + auto tv1 = transpose(tv0, 0, 1); + auto tv2 = view(tv0, {2, 3}, {3, 2}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({2, 3}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + auto ref = t0.transpose(0, 1) + t0.view({3, 2}); + + testValidate( + executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); +} +#endif + // t0------------. // t2->broadcast->sub->mul->relu->t6 // t1------------------' @@ -932,6 +992,269 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) { testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__); } +// x->sin->transpose->cos->y +TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) { + std::array, 2> shapes{ + std::vector{1024 * 1024 * 128, 2}, + std::vector{2, 1024 * 1024 * 128}}; + for (const auto& shape : shapes) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(2); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = cos(tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor input = at::randn(shape, options); + + auto lparams = scheduleTranspose(&fusion, {input}); + + FusionExecutor fe; + fe.compileFusion(&fusion, {input}, lparams); + auto outputs = fe.runFusion({input}, lparams); + + auto tv_ref = input.sin().transpose(0, 1).cos(); + + testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict1_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(0)->parallelize(ParallelType::TIDx); + tv2->axis(0)->parallelize(ParallelType::TIDx); + tv3->axis(0)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{0, 32}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict3_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}, DataType::Bool); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{8, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict4_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 0, 1); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->merge(0); + tv1->split(0, 4); + tv1->split(0, 8); + tv1->axis(-1)->parallelize(ParallelType::Vectorize); + tv1->axis(0)->parallelize(ParallelType::TIDx); + // T1 [TIDx(32), 8, V(4)] + + tv2->setMemoryType(MemoryType::Shared); + tv2->merge(0); + tv2->split(0, 4); + tv2->split(0, 32); + tv2->axis(1)->parallelize(ParallelType::TIDx); + // T2 [8, TIDx(32), 4] + + tv3->merge(0); + tv3->split(0, 2); + tv3->split(0, 32); + tv3->axis(1)->parallelize(ParallelType::TIDx); + // T3 [16, TIDx(32), 2] + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect1{0, 8}; + std::pair expect2{8, 4}; + std::pair expect3{2, 0}; + TORCH_CHECK( + info.second == expect1 || info.second == expect2 || + info.second == expect3); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict5_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDx); + tv3->axis(2)->parallelize(ParallelType::TIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict6_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 32, 32}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{32, 0}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict7_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(1)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::TIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + tv1->axis(2)->parallelize(ParallelType::TIDy); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + TORCH_CHECK(!bank_conflict_info.empty()); + for (auto info : bank_conflict_info) { + std::pair expect{0, 2}; + TORCH_CHECK(info.second == expect); + } +} + +TEST_F(NVFuserTest, FusionTransposeBankConflict8_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1024, 8, 8}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + auto tv2 = transpose(tv1, 1, 2); + auto tv3 = set(tv2); + fusion.addOutput(tv3); + + tv1->setMemoryType(MemoryType::Shared); + tv1->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(2)->parallelize(ParallelType::TIDy); + tv3->axis(2)->parallelize(ParallelType::TIDy); + tv1->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(0)->parallelize(ParallelType::BIDx); + + auto bank_conflict_info = fusion.bankConflictInfo(); + + // no bank confliction + TORCH_CHECK(bank_conflict_info.empty()); +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp new file mode 100644 index 0000000000000..19c3c6f9bf6db --- /dev/null +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp @@ -0,0 +1,273 @@ +#if defined(USE_CUDA) +#include +#include + +#include +#include +#include +#include +#include +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +using namespace torch::jit::fuser::cuda; + +TEST_F(NVFuserTest, FusionSplitDims_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int64_t* p = prime_numbers; + auto tv = makeConcreteTensor( + {p[0] * p[1] * p[2], p[3], p[4], p[5] * p[6], p[7], p[8], p[9] * p[10]}); + std::vector dims{0, 1, 2, 3, 4, 5, 6}; + scheduler_utils::splitDims( + tv, {{0, p[2]}, {0, p[1]}, {3, p[6]}, {6, p[10]}}, dims); + TORCH_CHECK(tv->nDims() == 11); + for (auto i : c10::irange(11)) { + TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == p[i]); + } + std::vector expect{0, 3, 4, 5, 7, 8, 9}; + TORCH_CHECK(dims == expect); +} + +TEST_F(NVFuserTest, FusionMergeDims_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int64_t* p = prime_numbers; + auto tv = makeConcreteTensor( + {p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10]}); + std::vector dims{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + auto merged = scheduler_utils::mergeDims(tv, {2, 3, 7, 8, 9}, dims); + TORCH_CHECK(merged == 2); + std::vector expect_shape{ + p[0], p[1], p[2] * p[3] * p[7] * p[8] * p[9], p[4], p[5], p[6], p[10]}; + TORCH_CHECK(tv->nDims() == expect_shape.size()); + for (auto i : c10::irange(expect_shape.size())) { + TORCH_CHECK(tv->axis(i)->extent()->evaluateInt() == expect_shape[i]); + } + std::vector expect_dims{0, 1, 2, 2, 3, 4, 5, 2, 2, 2, 6}; + TORCH_CHECK(dims == expect_dims); +} + +TEST_F(NVFuserTest, FusionReorderAsRFactor_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int a = 1, b = 2, c = 3, d = 4; + + TensorView* tv0 = makeConcreteTensor({a, b, c, d}); + fusion.addInput(tv0); + fusion.addOutput(tv0); + + // [a, b, c, d] + tv0->merge(0, 2); + // [a*c, b, d] + tv0->split(1, 2); + // [a*c, bo, bi, d] + tv0->split(3, 3); + // [a*c, bo, bi, do, di] + tv0->reorder({{1, 4}, {2, 1}, {3, 3}, {4, 2}}); + // [a*c, bi, di, do, bo] + tv0->merge(3); + tv0->merge(1); + // [a*c, bi*di, do*bo] + tv0->reorder({{0, 2}}); + // [bi*di, do*bo, a*c] + // Order we want is: + // [a*c, do*bo, bi*di] + auto old2new = scheduler_utils::domainReorderAsRfactorMap(tv0); + TORCH_CHECK(old2new[0] == 2); + TORCH_CHECK(old2new[1] == 1); + TORCH_CHECK(old2new[2] == 0); +} + +TEST_F(NVFuserTest, FusionDisjointViewSet_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({2, 3, 4}); + fusion->addInput(tv0); + + auto tv1 = view(tv0, {2, 3, 4}, {2, 12}); + + auto tv2 = makeConcreteTensor({2, 12}); + fusion->addInput(tv2); + + auto tv3 = add(tv2, tv1); + fusion->addOutput(tv3); + + auto disjoint_exact = scheduler_utils::disjointViewSets(fusion.get()); + + TORCH_INTERNAL_ASSERT( + disjoint_exact.strictAreMapped(tv0->axis(1), tv0->axis(2))); +} + +TEST_F(NVFuserTest, FusionMatchingViews_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 2, y = 3, z = 4; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = view(tv0, {x, y, z}, {x * y, z}); + + auto tv2 = sin(tv1); + + auto tv3 = view(tv2, {x * y, z}, {x, y * z}); + fusion.addOutput(tv3); + + auto tv4 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv4); + + auto tv5 = view(tv4, {x, y, z}, {x, y * z}); + fusion.addOutput(tv5); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv6 = add(tv0, tv4); + fusion.addOutput(tv6); + + TORCH_INTERNAL_ASSERT(!scheduler_utils::allMatchingViews(&fusion)); +} + +TEST_F(NVFuserTest, FusionBroadcastViewMultiples_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int a = 2, b = 3, c = 5, d = 7, e = 11, f = 13; + + auto tv0 = makeConcreteTensor({a, b, c, d, e, f}); + fusion.addInput(tv0); + + // tie e and f together (swapping values next to eachother enforces they'll be + // merged then split by view) + auto tv1 = view(tv0, {a, b, c, d, e, f}, {a, b, c, d, f, e}); + fusion.addOutput(tv1); + + // swap d and e + auto tv2 = transpose(tv1, 3, 4); + // tie c and e together + auto tv3 = view(tv2, {a, b, c, e, d, f}, {a, b, e, c, d, f}); + + fusion.addOutput(tv3); + + auto tv4 = set(tv0); + // Use tv4 as the reference + fusion.addOutput(tv4); + + // a, b, d aren't tied to anything so they are valid broadcasts from the + // perspective of broadcast multiples analysis. + auto tv5 = makeConcreteTensor({1, 1, c, 1, e, f}); + fusion.addInput(tv5); + + // c, e, and f are tied together so this shouldn't be counted as a broadcast + // dim in the reference since it's a partial bcast + auto tv6 = makeConcreteTensor({a, b, c, 1, 1, 1}); + fusion.addInput(tv6); + + // c, e, and f are tied together this should be counted as a broadcast dim in + // the reference since it's a partial bcast + auto tv7 = makeConcreteTensor({a, b, 1, 1, 1, 1}); + fusion.addInput(tv7); + + // plug the broadcasts into the fusion + auto tv8 = add(tv5, tv4); + auto tv9 = add(tv6, tv8); + auto tv10 = add(tv7, tv9); + fusion.addOutput(tv10); + + auto bcast_info = + scheduler_utils::getBroadcastMultiples(tv4, DataType::Int32); + + // linked c, e, and f together so they should have the same id. + TORCH_CHECK(bcast_info.view_disjoint_set_ids[5] == 0); + TORCH_CHECK(bcast_info.view_disjoint_set_ids[4] == 0); + TORCH_CHECK(bcast_info.view_disjoint_set_ids[3] == 1); + TORCH_CHECK(bcast_info.view_disjoint_set_ids[2] == 0); + TORCH_CHECK(bcast_info.view_disjoint_set_ids[1] == 2); + TORCH_CHECK(bcast_info.view_disjoint_set_ids[0] == 3); + + TORCH_CHECK( + scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 0)); + TORCH_CHECK( + scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 1)); + TORCH_CHECK( + scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 2)); + TORCH_CHECK( + !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 3)); + TORCH_CHECK( + !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 4)); + TORCH_CHECK( + !scheduler_utils::breakIsDisjoint(bcast_info.view_disjoint_set_ids, 5)); + + // tv0 [a, b, c, d, e, f] + // tv1 [a, b, c, d, e, f] + // tv3 [a, b, c, d, e, f] + // tv4 [a, b, c, d, e, f] + // tv5 [1, 1, c, 1, e, f] -> Left bcasts should show up in some multiples + // tv6 [a, b, c, 1, 1, 1] -> view interferes with bcasts, non of these should + // show up + // tv7 [a, b, 1, 1, 1, 1] -> These broadcasts could be recognized + // tv10 [a, b, c, d, e, f] + + TORCH_CHECK( + bcast_info.broadcast_multiples[0].lhs_multiple == 0 && + bcast_info.broadcast_multiples[0].rhs_multiple == 8 * 4); + + TORCH_CHECK( + bcast_info.broadcast_multiples[1].lhs_multiple == 7 * 4 && + bcast_info.broadcast_multiples[1].rhs_multiple == 8 * 4); + + TORCH_CHECK( + bcast_info.broadcast_multiples[2].lhs_multiple == 7 * 4 && + bcast_info.broadcast_multiples[2].rhs_multiple == 7 * 4); + + TORCH_CHECK( + bcast_info.broadcast_multiples[3].lhs_multiple == 8 * 4 && + bcast_info.broadcast_multiples[3].rhs_multiple == 7 * 4); + + TORCH_CHECK( + bcast_info.broadcast_multiples[4].lhs_multiple == 8 * 4 && + bcast_info.broadcast_multiples[4].rhs_multiple == 7 * 4); + + TORCH_CHECK( + bcast_info.broadcast_multiples[5].lhs_multiple == 8 * 4 && + bcast_info.broadcast_multiples[5].rhs_multiple == 7 * 4); +} + +TEST_F(NVFuserTest, FusionTVDomainGuard_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector all_true = {true, true}; + std::vector all_false = {false, false}; + std::vector false_true = {false, true}; + auto tv = TensorViewBuilder().ndims(2).contiguity(false_true).build(); + TORCH_CHECK(tv->domain()->contiguity() == false_true); + { + auto guard = ir_utils::overrideContiguityGuard(tv, true); + TORCH_CHECK(tv->domain()->contiguity() == all_true); + } + TORCH_CHECK(tv->domain()->contiguity() == false_true); + { + auto guard = ir_utils::overrideContiguityGuard(tv, false); + TORCH_CHECK(tv->domain()->contiguity() == all_false); + } + TORCH_CHECK(tv->domain()->contiguity() == false_true); + { + auto guard1 = ir_utils::overrideContiguityGuard(tv, true); + auto guard2 = std::move(guard1); + TORCH_CHECK(tv->domain()->contiguity() == all_true); + } + TORCH_CHECK(tv->domain()->contiguity() == false_true); +} + +} // namespace jit +} // namespace torch +#endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h index 0adaaa9786c30..f70c7a80f76fb 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_validator.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -5,43 +7,16 @@ #include #include -#include -#include #include +// Tests go in torch::jit namespace torch { namespace jit { -namespace fuser { -namespace cuda { - -inline bool deviceMajorMinorCheck(int major, int minor = 0) { - auto dev_prop = at::cuda::getCurrentDeviceProperties(); - if (dev_prop->major < major || - (dev_prop->major == major && dev_prop->minor < minor)) { - return false; - } - return true; -} -inline int deviceSMCount() { - int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - return sm_count; -} +using namespace torch::jit::fuser::cuda; -class NVFuserTest : public ::testing::Test { - protected: - void SetUp() override { - // requires PASCAL or newer - if (!deviceMajorMinorCheck(6)) { - GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs"; - } - } - - void TearDown() override { - c10::cuda::CUDACachingAllocator::emptyCache(); - } -}; +namespace { struct ValidationConstants { // Tolerances generated from randn + add + sum fusion @@ -72,8 +47,6 @@ struct ValidationConstants { double base_float_rel_tol = -1; }; -namespace { - // Returns abs and relative values to use for validation std::pair getTolerance( DataType dtype, @@ -336,15 +309,13 @@ ExpressionEvaluator bindInputsAndLaunchParams( return expr_eval; } -} // namespace - // Validation will look through the fusion and figure out how many elements were // reduced to create each output. It will then compute a tolernace to use for // allclose based on experimental results. The experimental results were based // on adding two tensors then summing them. This of course has an assumption // that we're always summing values between -2 and 2. If we start summing values // larger than that this approach might not hold. -inline void testValidate( +void testValidate( Fusion* fusion, const std::vector& fusion_outputs, const at::ArrayRef& aten_inputs, @@ -464,18 +435,6 @@ inline void testValidate( } } -inline void clearL2Cache() { - torch::NoGradGuard no_grad; - auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; - auto options = - torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0); - - auto l2_elems = l2_cache_size / 4; - torch::Tensor t0 = torch::empty(l2_elems, options); - torch::Tensor t1 = torch::clone(t0); -}; - -} // namespace cuda -} // namespace fuser +} // namespace } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 60194ade674d1..9785e089052af 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -1270,6 +1272,9 @@ TEST_F(NVFuserTest, FusionViewVectorize_CUDA) { } TEST_F(NVFuserTest, FusionExpandFlatten_CUDA) { +#ifdef FBCODE_CAFFE2 + GTEST_SKIP() << "Fails accuracy on V100 32gb"; +#endif auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1341,6 +1346,517 @@ TEST_F(NVFuserTest, FusionReductionFlatten1_CUDA) { executor_cache.fusion(), cg_outputs, {t0}, {ref}, __LINE__, __FILE__); } +TEST_F(NVFuserTest, FusionPwiseViewSchedule_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 31, y = 65, z = 103; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv3); + + auto tv4 = view(tv3, {x, y, z}, {x, y * z}); + fusion.addOutput(tv4); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv5 = add(tv0, tv3); + fusion.addOutput(tv5); + + TORCH_INTERNAL_ASSERT(scheduler_utils::allMatchingViews(&fusion)); + { + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + } + + for (auto i : c10::irange(tv5->nDims() - 1)) { + tv5->merge(0); + } + tv5->split(0, 32); + tv5->split(0, 4); + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::Unroll); + tv5->axis(2)->parallelize(ParallelType::TIDx); + + { + TransformPropagator propagator(tv5); + MaxRootDomainInfoSpanningTree spanning_tree(tv5); + spanning_tree.traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv5); + + // Inline the schedule + inlineMost(); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t3 = at::randn({x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {x, y * z}); + auto t4 = at::native::view(t3, {x, y * z}); + auto t5 = t0 + t3; + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t3}); + auto cg_outputs = fe.runFusion({t0, t3}); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t2, t4, t5}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionSumViewSchedule_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + int x = 31, y = 65, z = 103; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv3); + + auto tv4 = view(tv3, {x, y, z}, {x, y * z}); + auto tv5 = sum(tv4, {1}); + fusion.addOutput(tv5); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv6 = add(tv0, tv3); + fusion.addOutput(tv6); + + TORCH_INTERNAL_ASSERT(scheduler_utils::allMatchingViews(&fusion)); + { + TransformPropagator propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + } + + tv5->split(1, 128); + tv5->split(1, 4); + + auto tv5_rf = tv5->rFactor({1, 2}); + tv5_rf->axis(0)->parallelize(ParallelType::BIDx); + tv5_rf->axis(2)->parallelize(ParallelType::Unroll); + tv5_rf->axis(3)->parallelize(ParallelType::TIDx); + + { + TransformPropagator propagator(tv5_rf); + MaxRootDomainInfoSpanningTree spanning_tree(tv5_rf); + spanning_tree.traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv5_rf); + + // Inline the schedule + inlineMost(); + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t3 = at::randn({x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {x, y * z}); + auto t4 = at::native::view(t3, {x, y * z}); + auto t5 = t4.sum({1}); + auto t6 = t0 + t3; + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t3}); + auto cg_outputs = fe.runFusion({t0, t3}); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t2, t5, t6}, __LINE__, __FILE__); +} + +// Make sure matching views are segmented into the same kernel +TEST_F(NVFuserTest, FusionViewMagicSchedule1_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int x = 31, y = 65, z = 103; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv3); + + auto tv4 = view(tv3, {x, y, z}, {x, y * z}); + fusion.addOutput(tv4); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv5 = add(tv0, tv3); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t3 = at::randn({x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {x, y * z}); + auto t4 = at::native::view(t3, {x, y * z}); + auto t5 = t0 + t3; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3}); + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t2, t4, t5}, __LINE__, __FILE__); +} + +// Make sure views of views are correct +TEST_F(NVFuserTest, FusionViewMagicSchedule2_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int x = 31, y = 65, z = 103; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + auto tv3 = view(tv2, {x, y * z}, {x * y, z}); + auto tv4 = view(tv3, {x * y, z}, {y, x * z}); + auto tv5 = view(tv4, {y, x * z}, {x, y, z}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + auto aten_out = sin(t0); + + // For now pointwise scheduler only accepts a single view at a time, so this + // will be broken up into multiple kernels. This is due to the reference check + // looking for all mappings to all input IDs. + // TODO: Fix the reference check for this case + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = executor_cache.runFusionWithInputs({t0}); + + testValidate(&fusion, cg_outputs, {t0}, {aten_out}, __LINE__, __FILE__); +} + +// Make sure broadcasts not on the view path that don't interfere with view are +// segmented in one kernel and correctly trigger 2D pointwise scheduling +TEST_F(NVFuserTest, FusionViewMagicSchedule3_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int w = 15, x = 31, y = 49, z = 65; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv3); + + auto tv4 = view(tv3, {x, y, z}, {x, y * z}); + fusion.addOutput(tv4); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv5 = add(tv0, tv3); + fusion.addOutput(tv5); + + // Broadcast on another branch to drive the pointwise reference to not be on + // the view paths. + + auto tv6 = makeConcreteTensor({w, x, y, z}); + fusion.addInput(tv6); + auto tv7 = broadcast(tv0, {true, false, false, false}); + auto tv8 = add(tv6, tv7); + // tv8 should be the reference for the pointwise fusion. This broadcast + // pattern doesn't interfere with the views, so this should also be scheduled + // as 2D. + fusion.addOutput(tv8); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t3 = at::randn({x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {x, y * z}); + auto t4 = at::native::view(t3, {x, y * z}); + auto t5 = t0 + t3; + at::Tensor t6 = at::randn({w, x, y, z}, options); + auto t8 = t6.add(t0); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // Collect the heuristic params + executor_cache.profile(true); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3, t6}); + + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + TORCH_CHECK(executor_cache.getMostRecentExecutorInfo() + .params->isA()); + auto pparams = + executor_cache.getMostRecentExecutorInfo().params->as(); + TORCH_CHECK(pparams->break_point == 1); + + testValidate( + &fusion, cg_outputs, {t0, t3, t6}, {t2, t4, t5, t8}, __LINE__, __FILE__); +} + +// Make sure broadcasts through views when not conflicting with view are +// segmented into one kernel and trigger 2D pointwise scheduler. +TEST_F(NVFuserTest, FusionViewMagicSchedule4_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int w = 15, x = 31, y = 49, z = 65; + + auto tv0 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = view(tv1, {x, y, z}, {x, y * z}); + fusion.addOutput(tv2); + + auto tv3 = makeConcreteTensor({x, y, z}); + fusion.addInput(tv3); + + auto tv4 = makeConcreteTensor({x, 1, 1}); + fusion.addInput(tv4); + + auto tv5 = add(tv4, tv3); + + auto tv6 = view(tv5, {x, y, z}, {x, y * z}); + fusion.addOutput(tv6); + + // Link 0 and 3 together for view analysis done based on before the views + // actually happened. + auto tv7 = add(tv0, tv3); + fusion.addOutput(tv7); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t3 = at::randn({x, y, z}, options); + at::Tensor t4 = at::randn({x, 1, 1}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {x, y * z}); + auto t5 = t4 + t3; + auto t6 = at::native::view(t5, {x, y * z}); + auto t7 = t0 + t3; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // Collect the heuristic params + executor_cache.profile(true); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3, t4}); + + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + TORCH_CHECK(executor_cache.getMostRecentExecutorInfo() + .params->isA()); + auto pparams = + executor_cache.getMostRecentExecutorInfo().params->as(); + TORCH_CHECK(pparams->break_point == 1); + + testValidate( + &fusion, cg_outputs, {t0, t3, t4}, {t2, t6, t7}, __LINE__, __FILE__); +} + +// Make sure different views that are consumed by the reference are segmented +// into a single kernel. +TEST_F(NVFuserTest, FusionViewMagicSchedule5_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int w = 15, x = 31, y = 49, z = 65; + + auto tv0 = makeConcreteTensor({w, x, y * z}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = view(tv1, {w, x, y * z}, {z, y, x, w}); + + auto tv3 = makeConcreteTensor({w, x * y, z}); + fusion.addInput(tv3); + auto tv4 = cos(tv3); + auto tv5 = view(tv4, {w, x * y, z}, {z, y, x, w}); + + auto tv6 = add(tv2, tv5); + fusion.addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({w, x, y * z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {z, y, x, w}); + at::Tensor t3 = at::randn({w, x * y, z}, options); + auto t4 = cos(t3); + auto t5 = at::native::view(t4, {z, y, x, w}); + auto t6 = add(t2, t5); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // Collect the heuristic params + executor_cache.profile(true); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3}); + + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + TORCH_CHECK(executor_cache.getMostRecentExecutorInfo() + .params->isA()); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t6}, __LINE__, __FILE__); +} + +// Make sure different views that are consumed by the reference are segmented +// into a single kernel. +TEST_F(NVFuserTest, FusionViewMapping_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int w = 15, x = 31, y = 49, z = 65; + + auto tv0 = makeConcreteTensor({w, x, y * z}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = view(tv1, {w, x, y * z}, {z, y, x, w}); + + auto tv3 = makeConcreteTensor({w, x * y, z}); + fusion.addInput(tv3); + auto tv4 = cos(tv3); + auto tv5 = view(tv4, {w, x * y, z}, {z, y, x, w}); + + auto tv6 = add(tv2, tv5); + fusion.addOutput(tv6); + + tv6->merge(0); + tv6->merge(0); + tv6->merge(0); + tv6->split(0, 128); + tv6->split(0, 4); + tv6->axis(0)->parallelize(ParallelType::BIDx); + tv6->axis(1)->parallelize(ParallelType::Unroll); + tv6->axis(2)->parallelize(ParallelType::TIDx); + + TransformPropagator propagator(tv6); + MaxRootDomainInfoSpanningTree spanning_tree(tv6); + spanning_tree.traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv6); + + // Inline the schedule + inlineMost(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({w, x, y * z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {z, y, x, w}); + at::Tensor t3 = at::randn({w, x * y, z}, options); + auto t4 = cos(t3); + auto t5 = at::native::view(t4, {z, y, x, w}); + auto t6 = add(t2, t5); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t3}); + auto cg_outputs = fe.runFusion({t0, t3}); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t6}, __LINE__, __FILE__); +} + +TEST_F(NVFuserTest, FusionLowerDivisibleSplits_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int w = 15, x = 31, y = 49, z = 65; + + auto tv0 = makeContigTensor(4); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = view(tv1, {w, x, y, z}, {z, y, x, w}); + + fusion.addOutput(tv2); + + tv2->merge(0)->merge(0)->merge(0)->split(0, 4)->split(0, 8, false); + + TransformPropagator propagator(tv2); + MaxRootDomainInfoSpanningTree spanning_tree(tv2); + spanning_tree.traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv2); + + // Inline the schedule + inlineMost(); + + auto divisible_splits = getAllDivisibleSplits(&fusion); + + // Operations on all tensors are basically: + // [10] merge(0) [9]->outer->definition + // [9] merge(0) [8]->outer->definition + // [8] merge(0) [7]->in->definition + // [7] split(0, z, false) [6]->in->definition + // [6] split(1, y, false) [5]->in->definition + // [5] split(2, x, false) [3]->inner->definition + // RFactor of tv2 + // [4] merge(0) [3]->outer->definition + // [3] merge(0) [2]->outer->definition + // [2] merge(0) [1]->in->definition + // [1] split(0, 4) [0]->in->definition + // [0] split(0, 8, false) tv->axis(0)->definition + + for (auto tv : std::vector({tv2, tv1, tv0})) { + auto transform_0 = tv->axis(0)->definition()->as(); + auto transform_1 = transform_0->in()->definition()->as(); + auto transform_2 = transform_1->in()->definition()->as(); + auto transform_3 = transform_2->outer()->definition()->as(); + + auto transform_5 = transform_3->inner()->definition()->as(); + auto transform_6 = transform_5->in()->definition()->as(); + auto transform_7 = transform_6->in()->definition()->as(); + + TORCH_CHECK( + divisible_splits.find(transform_5) != divisible_splits.end(), + "Expecting: ", + transform_5->toString(), + "\nFrom TV: ", + tv, + "\nTo be a divisible split."); + TORCH_CHECK( + divisible_splits.find(transform_6) != divisible_splits.end(), + "Expecting: ", + transform_6->toString(), + "\nFrom TV: ", + tv, + "\nTo be a divisible split."); + TORCH_CHECK( + divisible_splits.find(transform_7) != divisible_splits.end(), + "Expecting: ", + transform_7->toString(), + "\nFrom TV: ", + tv, + "\nTo be a divisible split."); + } +} + } // namespace jit } // namespace torch #endif // #if defined(USE_CUDA) diff --git a/torch/csrc/jit/codegen/cuda/test/test_utils.h b/torch/csrc/jit/codegen/cuda/test/test_utils.h index c8bf546daf4a0..8b199b930f247 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_utils.h +++ b/torch/csrc/jit/codegen/cuda/test/test_utils.h @@ -1,8 +1,21 @@ #pragma once -#include - +#include +#include #include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include // Tests go in torch::jit namespace torch { @@ -11,7 +24,7 @@ namespace jit { using namespace torch::jit::fuser::cuda; namespace { - +bool var; // Make a tensor that is known to be fully contiguous of dimensionality=ndims, // but unknown sizes TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) { @@ -84,6 +97,277 @@ int64_t prime_numbers[] = { 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223}; +bool deviceMajorMinorCheck(int major, int minor = 0) { + auto dev_prop = at::cuda::getCurrentDeviceProperties(); + if (dev_prop->major < major || + (dev_prop->major == major && dev_prop->minor < minor)) { + return false; + } + return true; +} + +int deviceSMCount() { + int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + return sm_count; +} + +void clearL2Cache() { + torch::NoGradGuard no_grad; + auto l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize; + auto options = + torch::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA, 0); + + auto l2_elems = l2_cache_size / 4; + torch::Tensor t0 = torch::empty(l2_elems, options); + torch::Tensor t1 = torch::clone(t0); +}; + +TensorView* loweredTv(TensorView* tv, GpuLower& gpulw) { + auto used_tvs = ir_utils::allTvs(gpulw.kernel()->as()); + TensorView* matching_tv = nullptr; + for (auto lowered_tv : used_tvs) { + if (lowered_tv->name() == tv->name()) { + matching_tv = lowered_tv; + } + } + TORCH_INTERNAL_ASSERT(matching_tv != nullptr); + return matching_tv; +} + +class PredicatedChecker : public kir::IrVisitor { + public: + // Checks if the provided tv is written to within a non-trivial conditional + static bool isPredicated(TensorView* tv, GpuLower& gpulw) { + PredicatedChecker checker( + loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); + return checker.is_predicated_; + } + + private: + PredicatedChecker() = delete; + + PredicatedChecker(TensorView* tv, std::vector exprs) : tv_(tv) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool is_predicated_ = false; + bool predicated_ite_ = false; + TensorView* tv_ = nullptr; + + void handle(kir::IfThenElse* ite) final { + auto prev_ite = predicated_ite_; + predicated_ite_ = !ite->predicate()->value()->isConstScalar(); + kir::IrVisitor::handle(ite); + predicated_ite_ = prev_ite; + } + + void handle(Expr* expr) final { + if (expr->outputs().size() && expr->outputs()[0]->isA()) { + auto ti = expr->outputs()[0]->as(); + if (ti->view() == tv_) { + is_predicated_ = is_predicated_ | predicated_ite_; + if (expr->predicate() != nullptr && + !expr->predicate()->value()->isConst()) { + is_predicated_ = true; + } + } + } + kir::IrVisitor::handle(expr); + } +}; + +class UnswitchInElseChecker : public kir::IrVisitor { + public: + // Checks if there are any unswitched for loops within an else clause + static bool check(GpuLower& gpulw) { + UnswitchInElseChecker checker(gpulw.kernel()->topLevelExprs()); + return checker.found_in_else_; + } + + private: + UnswitchInElseChecker() = delete; + UnswitchInElseChecker(std::vector exprs) { + kir::IrVisitor::handle(exprs); + } + + using kir::IrVisitor::handle; + bool within_else_ = false; + bool found_in_else_ = false; + + void handle(kir::IfThenElse* ite) final { + auto prev_within_else = within_else_; + within_else_ = true; + kir::IrVisitor::handle(ite->elseBody().exprs()); + within_else_ = prev_within_else; + } + + void handle(kir::ForLoop* for_loop) final { + if (for_loop->iter_domain()->getParallelType() == ParallelType::Unswitch) { + found_in_else_ = found_in_else_ || within_else_; + } + kir::IrVisitor::handle(for_loop); + } +}; + +class PredicateMagicZeroChecker : public kir::IrVisitor { + public: + // Checks if all predicated domains of the provided tv are protected with + // magic zero + static bool isProtected(TensorView* tv, GpuLower& gpulw) { + PredicateMagicZeroChecker checker( + loweredTv(tv, gpulw), gpulw.kernel()->topLevelExprs()); + return checker.is_protected_; + } + + private: + using kir::IrVisitor::handle; + + PredicateMagicZeroChecker(TensorView* tv, std::vector exprs) + : tv_(tv) { + handle(exprs); + } + + void handle(kir::IfThenElse* ite) final { + auto prev_predicate = predicate_; + predicate_ = ite->predicate()->value(); + kir::IrVisitor::handle(ite); + predicate_ = prev_predicate; + } + + void handle(Expr* expr) final { + if (expr->outputs().size() && expr->outputs()[0]->isA()) { + auto ti = expr->outputs()[0]->as(); + if (ti->view() == tv_) { + is_protected_ = checkPredicateOfTensor(predicate_); + return; + } + } + + if (expr->isA()) { + handle(expr->as()); + } else if (expr->isA()) { + handle(expr->as()); + } else { + for (auto input : expr->inputs()) { + handle(input); + } + } + } + + // Return true If all predicated domains are protected + bool checkPredicateOfTensor(Val* predicate) { + auto id_predicates = decomposeCompoundPredicate(predicate); + for (auto id_predicate : id_predicates) { + // Just check if nvfuser_zero is used. Not perfect but probably + // good enough. + is_magic_zero_found_ = false; + handle(id_predicate); + if (!is_magic_zero_found_) { + return false; + } + } + return true; + } + + // Decompose "X && Y" to a vector of {X, Y}. + std::vector decomposeCompoundPredicate(Val* predicate) { + if (auto binary_op = dynamic_cast(predicate->definition())) { + if (binary_op->getBinaryOpType() == BinaryOpType::And) { + auto pred = decomposeCompoundPredicate(binary_op->lhs()); + auto rhs_pred = decomposeCompoundPredicate(binary_op->rhs()); + pred.insert(pred.end(), rhs_pred.begin(), rhs_pred.end()); + return pred; + } + } + + return {predicate}; + } + + void handle(Val* val) final { + if (isMagicZero(val)) { + is_magic_zero_found_ = true; + return; + } + + auto def = val->definition(); + if (def != nullptr) { + handle(def); + } + } + + private: + bool is_protected_ = false; + Val* predicate_ = nullptr; + TensorView* tv_ = nullptr; + bool is_magic_zero_found_ = false; +}; + +// Basically just TransformPropagator, except that it checks the consistency +// replayPasC with getMatchedLeafPosWithoutReplayPasC, replayCasP with +// getMatchedLeafPosWithoutReplayCasP, and fullSelfReplay with fullSelfMatching: +// - After replayPasC, getMatchedLeafPosWithoutReplayPasC should return the same +// replayed position +// - After replayCasP, getMatchedLeafPosWithoutReplayCasP should return the same +// replayed position +// - After fullSelfReplay, fullSelfMatching should return true +struct TransformPropagatorWithCheck : public TransformPropagator { + public: + virtual void propagateC2P(TensorView* from, TensorView* to) override { + TransformPropagator::propagateC2P(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayPasC( + to, from, from_pos) == to_pos); + } + virtual void propagateP2C(TensorView* from, TensorView* to) override { + TransformPropagator::propagateP2C(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK( + TransformReplay::getMatchedLeafPosWithoutReplayCasP( + to, from, from_pos) == to_pos); + } + virtual void propagateSibling(TensorView* from, TensorView* to) override { + TransformPropagator::propagateSibling(from, to); + auto from_pos = replayed_pos_.at(from); + auto to_pos = replayed_pos_.at(to); + TORCH_CHECK(from_pos == to_pos); + TORCH_CHECK(TransformReplay::fullSelfMatching(from, to)); + } + using TransformPropagator::TransformPropagator; +}; + } // namespace + +class ContextCudnnTF32Disabled { + public: + ContextCudnnTF32Disabled() { + flag_ = at::globalContext().allowTF32CuDNN(); + at::globalContext().setAllowTF32CuDNN(false); + } + + ~ContextCudnnTF32Disabled() { + at::globalContext().setAllowTF32CuDNN(flag_); + } + + private: + bool flag_; +}; + +// Fixture class must be uniquely identified, i.e., can't be in an +// anonymous namespace +class NVFuserTest : public ::testing::Test { + protected: + void SetUp() override { + // requires PASCAL or newer + if (!deviceMajorMinorCheck(6)) { + GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs"; + } + setFillAllocationWithNan(true); + } +}; + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 3e0473665b478..ab683e79ce9aa 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -137,7 +137,7 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) { auto id_in_y = swizzle_2d->inY(); // Make sure we have a corresponding entry in our map pointing to the ID we're - // going to replay the split on + // going to replay the swizzle on auto it_x = id_map_.find(id_in_x); auto it_y = id_map_.find(id_in_y); @@ -162,7 +162,7 @@ void ReplayTransformations::handle(Swizzle2D* swizzle_2d) { auto outs = std::make_pair(mapped_x, mapped_y); if (replay_swizzle_) { - // Replay the split onto mapped + // Replay the swizzle onto mapped outs = IterDomain::swizzle(swizzle_2d->swizzleType(), mapped_x, mapped_y); // Remove mapped from the leaf IDs @@ -224,7 +224,7 @@ void ReplayTransformations::runReplay() { // Switch outDomain to a vector to start the traversal std::vector traversal_vals( target_domain_.begin(), target_domain_.end()); - traverseFrom(traversal_vals[0]->fusion(), traversal_vals); + traverseTo(traversal_vals[0]->fusion(), traversal_vals); if (error_on_failure_) TORCH_INTERNAL_ASSERT( @@ -762,14 +762,6 @@ struct ProducerForwardingInfo { (outer->isTrivialReduction() && !inner->isReduction())) { auto compliment_id = inner->isTrivialReduction() ? inner : outer; auto forwarded_id = inner->isTrivialReduction() ? outer : inner; - // Only allow forwarding when the trivial reduction domain is - // an root domain - if (std::find( - producer->getMaybeRFactorDomain().begin(), - producer->getMaybeRFactorDomain().end(), - compliment_id) == producer->getMaybeRFactorDomain().end()) { - continue; - } forwarding_map.emplace(std::make_pair(forwarded_id, merge->out())); compliment_map.emplace(std::make_pair( forwarded_id, std::vector{compliment_id})); diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index dc5973c0ecd6a..8d5151074563e 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -262,7 +262,7 @@ std::pair TransformRFactor::runReplay( std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { TORCH_CHECK( i >= -ndims && i < ndims, - "Rfactor replay recieved an axis outside the number of dims in the tensor, acceptable inclusive range is ", + "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1); diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index e5f9c068f16c1..c617f548649ec 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -509,7 +509,7 @@ class AnalyzeViewTransformation { "View is complete, but there's still some elements to distribute."); } - if ((new_view_index == new_view_.size() || + if ((new_view_index + 1 >= new_view_.size() || (new_view_[new_view_index + 1] != 1)) && original_view_index + 1 < original_view_.size() && original_view_[original_view_index + 1] == 1 && @@ -732,7 +732,7 @@ AnalyzeViewResult analyzeView( FUSER_PERF_SCOPE("analyzeView"); TORCH_INTERNAL_ASSERT( original_sizes.size() > 0, - "Empty original size not supported for view operatioon."); + "Empty original size not supported for view operation."); TORCH_INTERNAL_ASSERT( TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain()) diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp index e3d61efac9722..3b8f380683ed2 100644 --- a/torch/csrc/jit/codegen/cuda/type.cpp +++ b/torch/csrc/jit/codegen/cuda/type.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include @@ -160,6 +162,17 @@ DataType getTypeFromComplexType(DataType dtype) { } } +bool isSupportedTypeByDevice(DataType dtype) { + auto prop = at::cuda::getCurrentDeviceProperties(); + auto major_ver = prop->major; + switch (dtype) { + case DataType::BFloat16: + return major_ver >= 8; + default: + return true; + } +} + bool isIntegerOp(const BinaryOpType bopt) { return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift; } @@ -290,8 +303,12 @@ static const char* predicate_type2string(PredicateType t) { static const char* expr_type2string(ExprType t) { switch (t) { + case ExprType::FullOp: + return "FullOp"; case ExprType::ARangeOp: return "ARangeOp"; + case ExprType::EyeOp: + return "EyeOp"; case ExprType::UnaryOp: return "UnaryOp"; case ExprType::BinaryOp: @@ -656,6 +673,8 @@ static const char* rng_op_type_inline_op2string(RNGOpType t) { switch (t) { case RNGOpType::Uniform: return "rng_uniform"; + case RNGOpType::UniformRange: + return "rng_uniform_range"; default: break; } @@ -694,6 +713,8 @@ static const char* rng_op_type2string(RNGOpType t) { switch (t) { case RNGOpType::Uniform: return "rng_uniform"; + case RNGOpType::UniformRange: + return "rng_uniform_range"; default: TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType"); } diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h index 224922febc3fa..4aa894113e993 100644 --- a/torch/csrc/jit/codegen/cuda/type.h +++ b/torch/csrc/jit/codegen/cuda/type.h @@ -101,10 +101,14 @@ int getVectorSizeFromType(DataType dtype); DataType getTypeFromVectorType(DataType dtype); // Return the corresponding scalar of a complex type DataType getTypeFromComplexType(DataType dtype); +// Return if the datatype is supported on the current device +TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype); enum class ExprType { Invalid, + FullOp, ARangeOp, + EyeOp, UnaryOp, BinaryOp, TernaryOp, @@ -244,7 +248,8 @@ enum class BinaryOpType { }; enum class RNGOpType { - Uniform, + Uniform, // Uniform in [0, 1) + UniformRange, // Uniform in [low, high] }; // Return if output of operator should be a boolean diff --git a/torch/csrc/jit/codegen/cuda/type_inference.cpp b/torch/csrc/jit/codegen/cuda/type_inference.cpp index 534fa91488cee..7422cf20d7c2b 100644 --- a/torch/csrc/jit/codegen/cuda/type_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/type_inference.cpp @@ -445,13 +445,16 @@ class NaiveTypePropagator { copyScalarTypeAndDeviceToOutput(out_type->withDim(c10::nullopt), node); break; } - case prim::unsqueeze_copy: case prim::expand_copy: case prim::expand_as_copy: - case prim::squeeze_copy: + case prim::flatten_copy: + case prim::permute_copy: case prim::reshape_copy: - case prim::view_copy: - case prim::flatten_copy: { + case prim::squeeze_copy: + case prim::t_copy: + case prim::transpose_copy: + case prim::unsqueeze_copy: + case prim::view_copy: { auto out_type = node->input(0)->type()->cast(); copyScalarTypeAndDeviceToOutput(out_type, node); break; diff --git a/torch/csrc/jit/codegen/cuda/utils.cpp b/torch/csrc/jit/codegen/cuda/utils.cpp index d7409c98db658..33395692fb39e 100644 --- a/torch/csrc/jit/codegen/cuda/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/utils.cpp @@ -41,9 +41,10 @@ auto parseDebugDumpOptions() { {DebugDumpOption::PythonDefinition, false}, {DebugDumpOption::PythonFrontendDebug, false}, {DebugDumpOption::TransformPropagator, false}, - {DebugDumpOption::InlinePropagator, false}, {DebugDumpOption::Cubin, false}, - {DebugDumpOption::Ptx, false}}; + {DebugDumpOption::Ptx, false}, + {DebugDumpOption::BankConflictInfo, false}, + {DebugDumpOption::SyncMap, false}}; if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_DUMP")) { c10::string_view options_view(dump_options); @@ -100,12 +101,14 @@ auto parseDebugDumpOptions() { options_map[DebugDumpOption::PythonFrontendDebug] = true; } else if (token == "transform_propagator") { options_map[DebugDumpOption::TransformPropagator] = true; - } else if (token == "inline_propagator") { - options_map[DebugDumpOption::InlinePropagator] = true; } else if (token == "cubin") { options_map[DebugDumpOption::Cubin] = true; } else if (token == "ptx") { options_map[DebugDumpOption::Ptx] = true; + } else if (token == "bank_conflict") { + options_map[DebugDumpOption::BankConflictInfo] = true; + } else if (token == "sync_map") { + options_map[DebugDumpOption::SyncMap] = true; } else { TORCH_CHECK( false, @@ -118,7 +121,7 @@ auto parseDebugDumpOptions() { "\tdraw_segmented_fusion, scheduler_params, parallel_dimensions,\n", "\tbuffer_reuse_verbose, ptxas_verbose, halo, segmenter_logging,\n", "\tperf_debug_verbose, python_definition, python_frontend_debug,\n", - "\ttransform_propagator, inline_propagator, cubin, ptx\n"); + "\ttransform_propagator, cubin, ptx, bank_conflict, sync_map\n"); } options_view = (end_pos != c10::string_view::npos) ? options_view.substr(end_pos + 1) @@ -132,6 +135,7 @@ auto parseDebugDumpOptions() { auto parseDisableOptions() { std::unordered_map options_map = { {DisableOption::ArchCheck, false}, + {DisableOption::CompileToSass, false}, {DisableOption::Fallback, false}, {DisableOption::Fma, false}, {DisableOption::IndexHoist, false}, @@ -145,6 +149,8 @@ auto parseDisableOptions() { const auto token = options_view.substr(0, end_pos); if (token == "arch_check") { options_map[DisableOption::ArchCheck] = true; + } else if (token == "compile_to_sass") { + options_map[DisableOption::CompileToSass] = true; } else if (token == "fallback") { options_map[DisableOption::Fallback] = true; } else if (token == "fma") { @@ -179,8 +185,7 @@ auto parseEnableOptions() { {EnableOption::Complex, false}, {EnableOption::KernelProfile, false}, {EnableOption::LinearDecomposition, false}, - {EnableOption::ConvDecomposition, false}, - {EnableOption::TransposeScheduler, false}}; + {EnableOption::ConvDecomposition, false}}; if (const char* dump_options = std::getenv("PYTORCH_NVFUSER_ENABLE")) { c10::string_view options_view(dump_options); @@ -195,8 +200,6 @@ auto parseEnableOptions() { options_map[EnableOption::LinearDecomposition] = true; } else if (token == "conv_decomposition") { options_map[EnableOption::ConvDecomposition] = true; - } else if (token == "transpose_scheduler") { - options_map[EnableOption::TransposeScheduler] = true; } else { TORCH_CHECK( false, @@ -204,7 +207,7 @@ auto parseEnableOptions() { token, "'\nAvailable options:\n", "\tcomplex, kernel_profile, linear_decomposition,", - "conv_decomposition, transpose_scheduler"); + "conv_decomposition"); } options_view = (end_pos != c10::string_view::npos) ? options_view.substr(end_pos + 1) diff --git a/torch/csrc/jit/codegen/cuda/utils.h b/torch/csrc/jit/codegen/cuda/utils.h index 5b5c794f3810d..61f7fee7cd4cf 100644 --- a/torch/csrc/jit/codegen/cuda/utils.h +++ b/torch/csrc/jit/codegen/cuda/utils.h @@ -57,10 +57,10 @@ enum class DebugDumpOption { PythonFrontendDebug, //! Python Frontend debug information. TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result - InlinePropagator, //! When running InlinePropagator, print propagation - //! path and inlining result Cubin, //! Dump compiled CUBIN - Ptx //! Dump compiled PTX + Ptx, //! Dump compiled PTX + BankConflictInfo, //! Dump bank confliction info + SyncMap //! RAW dependency info }; TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option); @@ -71,6 +71,8 @@ TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option); //! enum class DisableOption { ArchCheck, //! Disable hardware-specific checks to enable cross arch debug + CompileToSass, //! Disable direct compilation to sass so the ptx can be + //! examined Fallback, //! Disable fallback Fma, //! Disable FMA instructions IndexHoist, //! Disable index hoisting @@ -89,7 +91,6 @@ enum class EnableOption { KernelProfile, //! Enable intra-kernel performance profiling LinearDecomposition, //! Enable linear-bias decomposition ConvDecomposition, //! Enable conv-bias decomposition - TransposeScheduler //! Enable the experimental transpose scheduler }; TORCH_CUDA_CU_API bool isOptionEnabled(EnableOption option); diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 0665d21a7a4fc..c28ad2ba1ae09 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -490,7 +490,7 @@ std::string generateKernel( env.s("access", format("t${formal}.data[t${formal}_offset]", env)); env.s("access_vec4", format("t${formal}_buf[i]", env)); } - env.s("lhs_type", calcScalarTypeName(input.second.value().scalar_type)); + env.s("lhs_type", calcScalarTypeName(input.second->scalar_type)); // load input in vectorized code path auto ele_size = at::elementSize((*input.second).scalar_type); diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index 013a8e8b4adbf..8da7e63a69355 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -263,7 +263,7 @@ static const std::string compile_string = #ifndef __PPC64__ // "-march=native " #endif - "-std=c++14 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; + "-std=c++17 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; #endif static void runCompiler( const std::string& cpp_file, diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h index ce5d6ee2c5546..2e6d59596323d 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/jit/codegen/fuser/cpu/temp_file.h b/torch/csrc/jit/codegen/fuser/cpu/temp_file.h index 080d76bde2225..9fb53bc962c5b 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/temp_file.h +++ b/torch/csrc/jit/codegen/fuser/cpu/temp_file.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #ifdef _WIN32 #include @@ -61,7 +61,7 @@ int wmkstemps(wchar_t* tmpl, int suffix_len) { #endif struct TempFile { - TH_DISALLOW_COPY_AND_ASSIGN(TempFile); + AT_DISALLOW_COPY_AND_ASSIGN(TempFile); TempFile(const std::string& t, int suffix) { #ifdef _MSC_VER diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index 85bd74bfdbae4..72a011febe762 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -64,6 +64,8 @@ void codegenOutputQuery( max_dev_version = CudaVersion(7, 5); } else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0 max_dev_version = CudaVersion(8, 0); + } else if (nvrtc_version.first == 11 && nvrtc_version.second < 8) { + max_dev_version = CudaVersion(8, 6); } else { // If the driver version is unknown (i.e. newer than this code) // assume the driver supports this device @@ -125,7 +127,7 @@ FusedKernelCUDA::FusedKernelCUDA( &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); #if defined(USE_ROCM) - std::vector args = {"--std=c++14"}; + std::vector args = {"--std=c++17"}; #if ROCM_VERSION >= 40200 args.push_back("-hip-pch"); #endif @@ -146,7 +148,7 @@ FusedKernelCUDA::FusedKernelCUDA( std::to_string(major) + std::to_string(minor); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const std::vector args = { - "--std=c++14", compute.c_str(), "-default-device"}; + "--std=c++17", compute.c_str(), "-default-device"}; #endif const auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); diff --git a/torch/csrc/jit/codegen/fuser/fused_kernel.h b/torch/csrc/jit/codegen/fuser/fused_kernel.h index 3d34082ff771b..29ab3e7ed51c0 100644 --- a/torch/csrc/jit/codegen/fuser/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -1,9 +1,9 @@ #pragma once #include +#include #include #include -#include #include #include @@ -14,7 +14,7 @@ namespace jit { namespace fuser { struct FusedKernel { - TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel); + AT_DISALLOW_COPY_AND_ASSIGN(FusedKernel); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FusedKernel( diff --git a/torch/csrc/jit/codegen/onednn/README.md b/torch/csrc/jit/codegen/onednn/README.md index ca2a644372dd2..e3f3ec66734b2 100644 --- a/torch/csrc/jit/codegen/onednn/README.md +++ b/torch/csrc/jit/codegen/onednn/README.md @@ -104,7 +104,7 @@ with torch.no_grad(): # run the model with torch.no_grad(): - # oneDNN graph fusion will be trigerred during runtime + # oneDNN graph fusion will be triggered during runtime output = model(images) ``` diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index a8a202acf0dac..a14dce108dd12 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -505,7 +505,6 @@ Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) { auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( n, prim::oneDNNFusionGroup, aliasDb); opToOwningPartition_.add(group, partitionId); - LlgaNodeWrapper(group).initOutputLayouts(); return group; } @@ -585,25 +584,29 @@ LlgaNodeWrapper::LlgaNodeWrapper(const Node* node) } void LlgaNodeWrapper::setOpaqueLayout(size_t offset) { - TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); + const auto num_output = n->is(attr::output_layouts).size(); + TORCH_CHECK( + offset < num_output, + "Out of range. (Invalid index ", + offset, + " for attr::output_layouts with size ", + num_output, + ")"); auto& layouts = const_cast&>(n->is(attr::output_layouts)); // NOLINT - layouts.at(offset) = 1; + layouts.at(offset) = OPAQUE_LAYOUT; } bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const { - TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); - return n->is(attr::output_layouts)[offset] == 1; -} - -void LlgaNodeWrapper::initOutputLayouts() { - if (n->hasAttribute(attr::output_layouts)) { - return; - } - - // Init all output layouts as undef - std::vector layouts(n->outputs().size(), 0); - n->is_(attr::output_layouts, layouts); + const auto num_output = n->is(attr::output_layouts).size(); + TORCH_CHECK( + offset < num_output, + "Out of range. (Invalid index ", + offset, + " for attr::output_layouts with size ", + num_output, + ")"); + return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT; } } // namespace onednn diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.h b/torch/csrc/jit/codegen/onednn/graph_helper.h index 5422a90d9e97b..fbb5eaa84aec7 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.h +++ b/torch/csrc/jit/codegen/onednn/graph_helper.h @@ -10,6 +10,9 @@ namespace jit { namespace fuser { namespace onednn { +#define STRIDED_LAYOUT 0 +#define OPAQUE_LAYOUT 1 + struct OpPartitionMap { void add(uint64_t opId, uint64_t partitionId) { opmap_[opId] = partitionId; @@ -92,8 +95,6 @@ class LlgaNodeWrapper { friend class LlgaGraphHelper; private: - void initOutputLayouts(); - Node* n; }; diff --git a/torch/csrc/jit/codegen/onednn/layout_propagation.cpp b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp index 448e1cf858849..4201282fb083b 100644 --- a/torch/csrc/jit/codegen/onednn/layout_propagation.cpp +++ b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace torch { namespace jit { @@ -10,6 +11,14 @@ void LayoutPropagation(Node* n) { if (!LlgaGraphHelper::isLlgaSubgraph(n)) return; + // initial attr::output_layouts if undefined + if (!n->hasAttribute(attr::output_layouts)) { + const auto num_output = n->outputs().size(); + GRAPH_DEBUG("Initial output_layouts of size ", num_output); + std::vector layouts(num_output, STRIDED_LAYOUT); + n->is_(attr::output_layouts, layouts); + } + for (auto input : n->inputs()) { auto prev = input->node(); auto offset = input->offset(); diff --git a/torch/csrc/jit/docs/serialization.md b/torch/csrc/jit/docs/serialization.md index 8c3461a9abe83..a374f5bed40ba 100644 --- a/torch/csrc/jit/docs/serialization.md +++ b/torch/csrc/jit/docs/serialization.md @@ -127,7 +127,7 @@ its methods or attributes. **Uses of tensor constants**. Most constants are inlined as literals, like strings or ints. But since tensors are potentially very large, when -`PythonPrint` encouters a constant tensor it will emit a reference to a +`PythonPrint` encounters a constant tensor it will emit a reference to a global `CONSTANTS` table (like `foo = CONSTANTS.c0`). When importing, the importer will know how to resolve this reference into an diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index d60dd77bc8dad..7c53dbd0b3392 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -5640,7 +5640,7 @@ void CompilationUnit::define_interface( for (const Stmt& stmt : classDef.body()) { if (stmt.kind() != TK_DEF) { throw ErrorReport(stmt) - << "interface declartions can only contain method definitions"; + << "interface declarations can only contain method definitions"; } auto method_def = Def(stmt); if (!method_def.decl().return_type().present()) { diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f5d6f640d413d..d05ec95fb9fa2 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -316,7 +316,7 @@ std::vector ScriptTypeParser::evaluateDefaults( // We then run constant prop on this graph and check the results are // constant. This approach avoids having to have separate handling of // default arguments from standard expressions by piecing together existing - // machinery for graph generation, constant propgation, and constant + // machinery for graph generation, constant propagation, and constant // extraction. auto tuple_type = Subscript::create( r, diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index fe9e340fbe02d..402dd58b00846 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -7,10 +7,10 @@ #include #include -#include #include #include +#include #include #include #include @@ -177,7 +177,7 @@ struct Wrap { }; struct Value { - TH_DISALLOW_COPY_AND_ASSIGN(Value); + AT_DISALLOW_COPY_AND_ASSIGN(Value); Value(Node* node_, size_t offset_); private: @@ -239,6 +239,11 @@ struct Value { const Node* node() const { return node_; } + + /** + * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. + * Check #87343 for details. + */ Graph* owningGraph(); const Graph* owningGraph() const; // TODO: make this more const correct @@ -310,7 +315,7 @@ struct Value { }; struct TORCH_API Node { - TH_DISALLOW_COPY_AND_ASSIGN(Node); + AT_DISALLOW_COPY_AND_ASSIGN(Node); friend struct Graph; friend struct Block; friend struct Value; @@ -398,6 +403,10 @@ struct TORCH_API Node { } SourceRange sourceRange() const; + /** + * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. + * Check #87343 for details. + */ Graph* owningGraph() { return graph_; } @@ -1015,7 +1024,7 @@ struct Block { friend struct Node; friend struct Graph; - TH_DISALLOW_COPY_AND_ASSIGN(Block); + AT_DISALLOW_COPY_AND_ASSIGN(Block); TORCH_API Block(Graph* graph_, Node* node_); at::ArrayRef inputs() { @@ -1049,6 +1058,10 @@ struct Block { const Node* param_node() const { return input_; } + /** + * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. + * Check #87343 for details. + */ Graph* owningGraph() { return graph_; } @@ -1163,8 +1176,8 @@ struct Block { std::shared_ptr> wrap_; }; -struct Graph { - TH_DISALLOW_COPY_AND_ASSIGN(Graph); +struct Graph : std::enable_shared_from_this { + AT_DISALLOW_COPY_AND_ASSIGN(Graph); friend struct Node; friend struct Value; friend struct Block; diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 1f790de92cb1a..0673645731da0 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -237,7 +237,7 @@ ParsedLiteral IRParser::parseScalarLiteral(Node* n) { auto text = L.expect(TK_NUMBER); if (!parse_tensor_constants_) { throw ErrorReport(token.range) - << "Single-element tensor constant encoutered but " + << "Single-element tensor constant encountered but " << "`parse_tensor_constants` is set to false " << token.text(); } L.expect('}'); diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 489084912445f..2bad08c0765a2 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -504,7 +503,6 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) { ExtraFilesMap extra_files; - register_flatbuffer_all(); Module torch_script = torch::jit::load(input_model_stream, c10::nullopt, extra_files); std::stringstream intermediate_model_stream; diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index 089c116179ef9..9ce71eba9ce75 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -4,6 +4,7 @@ #include // removed after using simple type_resolver/obj_loader #include #include +#include #include // removed after using simple type_resolver/obj_loader #include #include @@ -111,13 +112,7 @@ uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) { auto format = getFileFormat(data); switch (format) { case FileFormat::FlatbufferFileFormat: { - if (get_flatbuffer_bytecode_version == nullptr) { - TORCH_CHECK( - false, - "Flatbuffer input file but the build hasn't enabled flatbuffer"); - } else { - return get_flatbuffer_bytecode_version(data); - } + return get_bytecode_version_from_bytes(data); } case FileFormat::ZipFileFormat: { auto rai = diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 45e31fb5e1747..ec18e489b5cd7 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -1,5 +1,3 @@ -#include - #ifdef FLATBUFFERS_VERSION_MAJOR #error "flatbuffer_loader.h must not include any flatbuffers headers" #endif // FLATBUFFERS_VERSION_MAJOR @@ -24,8 +22,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -718,7 +716,7 @@ void FlatbufferLoader::extractJitSourceAndConstants( std::vector* constants) { AT_ASSERT( module_parsed_, - "Need to first parse a flatbuffer file before extracing jit_sources"); + "Need to first parse a flatbuffer file before extracting jit_sources"); const auto* ivalues = module_->ivalues(); for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) { @@ -882,7 +880,6 @@ mobile::Module load_mobile_module_from_stream_with_copy( std::move(data), size, device, extra_files); } -namespace { mobile::Module parse_flatbuffer_no_object( std::shared_ptr data, size_t size, @@ -912,16 +909,10 @@ mobile::Module parse_flatbuffer_no_object( m.set_delete_memory(std::move(data)); return m; } -} // namespace bool register_flatbuffer_loader() { - load_flatbuffer_bytes = parse_and_initialize_mobile_module; - load_flatbuffer_bytes_no_object = parse_flatbuffer_no_object; - get_flatbuffer_bytecode_version = get_bytecode_version_from_bytes; return true; } -const bool kRegisteredFlatbufferLoader = register_flatbuffer_loader(); - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.h b/torch/csrc/jit/mobile/flatbuffer_loader.h index eee44d4b647ed..f29fe5b2e4942 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.h +++ b/torch/csrc/jit/mobile/flatbuffer_loader.h @@ -117,10 +117,19 @@ TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( c10::optional device = c10::nullopt, ExtraFilesMap* extra_files = nullptr); -// This function will make the capabilities to load -// Module as a flatbuffer file available for use by _load_for_mobile -// and friends. This is NOT needed if using the other functions -// in this file directly. +TORCH_API mobile::Module parse_flatbuffer_no_object( + std::shared_ptr data, + size_t size, + c10::optional device); + +TORCH_API mobile::Module parse_and_initialize_mobile_module( + void* data, + size_t, + c10::optional, + ExtraFilesMap* extra_files, + bool should_copy_tensor_memory); + +// no op, TODO(qihan) delete TORCH_API bool register_flatbuffer_loader(); } // namespace jit diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 2270418dbbcff..5acd5cab39854 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -88,19 +90,6 @@ using caffe2::serialize::MemoryReadAdapter; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::ReadAdapterInterface; -mobile::Module (*load_flatbuffer_bytes)( - std::shared_ptr, - size_t size, - c10::optional, - ExtraFilesMap*) = nullptr; - -mobile::Module (*load_flatbuffer_bytes_no_object)( - std::shared_ptr, - size_t size, - c10::optional) = nullptr; - -uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content) = nullptr; - OpCode parseOpCode(const char* str); TypePtr resolveTypeNameMobile( @@ -630,13 +619,8 @@ mobile::Module _load_mobile_from_bytes( std::move(rai), device, extra_files, module_load_options); } case FileFormat::FlatbufferFileFormat: { - if (load_flatbuffer_bytes != nullptr) { - return load_flatbuffer_bytes(data, size, device, &extra_files); - } else { - TORCH_CHECK( - false, - "Flatbuffer input file but the build hasn't enabled flatbuffer"); - } + return parse_and_initialize_mobile_module( + data, size, device, &extra_files); } default: { TORCH_CHECK(false, "Format error"); @@ -726,16 +710,7 @@ void _load_extra_only_for_mobile( // TODO: the current flatbuffers implementation will always load the // whole module including the extra files. Ideally it should be // possible to just get the extra files given data - std::shared_ptr data; - size_t size = 0; - std::tie(data, size) = get_file_content(filename.c_str()); - if (load_flatbuffer_bytes != nullptr) { - load_flatbuffer_bytes(data, size, device, &extra_files); - } else { - TORCH_CHECK( - false, - "Flatbuffer input file but the build hasn't enabled flatbuffer"); - } + load_mobile_module_from_file(filename, c10::nullopt, &extra_files); break; } default: { diff --git a/torch/csrc/jit/mobile/import.h b/torch/csrc/jit/mobile/import.h index b17a4bb341ca1..643ca57858a36 100644 --- a/torch/csrc/jit/mobile/import.h +++ b/torch/csrc/jit/mobile/import.h @@ -107,18 +107,5 @@ TORCH_API std::set _export_operator_list( } // namespace mobile -extern mobile::Module (*load_flatbuffer_bytes)( - std::shared_ptr, - size_t size, - c10::optional, - ExtraFilesMap*); - -extern mobile::Module (*load_flatbuffer_bytes_no_object)( - std::shared_ptr, - size_t size, - c10::optional); - -extern uint64_t (*get_flatbuffer_bytecode_version)(char* flatbuffer_content); - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/import_data.cpp b/torch/csrc/jit/mobile/import_data.cpp index 01c6ea7ac579c..309b238a8d41b 100644 --- a/torch/csrc/jit/mobile/import_data.cpp +++ b/torch/csrc/jit/mobile/import_data.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -247,14 +248,8 @@ std::map _load_parameters_bytes( std::map map; switch (format) { case FileFormat::FlatbufferFileFormat: { - if (load_flatbuffer_bytes_no_object != nullptr) { - auto m = load_flatbuffer_bytes_no_object(data, size, device); - map = mobile_module_to_parameter_map(m); - } else { - TORCH_CHECK( - false, - "Flatbuffer input file but the build hasn't enabled flatbuffer"); - } + auto m = parse_flatbuffer_no_object(data, size, device); + map = mobile_module_to_parameter_map(m); break; } diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 5da8cb4a55da6..8f61cc2402e1b 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -97,6 +97,10 @@ c10::optional Module::find_method(const std::string& basename) const { } namespace { +// For JIT, there is a private function to get all modules by iteration in +// struct slot_iterator_impl (jit/api/module.h). The following function use +// recursion to mimic the logic without allocating extra memory to get module +// list and set training attribute directly. void set_train_recurse( const c10::intrusive_ptr& obj, bool on) { @@ -109,7 +113,9 @@ void set_train_recurse( "call .eval() before saving your model?"); } for (const auto& slot : obj->slots()) { - if (slot.isObject()) { + // slots is a list of IValue. Continue setting training attribute only + // if the slot is an object and a module. + if (slot.isObject() && slot.toObjectRef().type()->is_module()) { set_train_recurse(slot.toObject(), on); } } diff --git a/torch/csrc/jit/mobile/profiler_edge.cpp b/torch/csrc/jit/mobile/profiler_edge.cpp index d3dc596ca3dcc..8fdd1654082ae 100644 --- a/torch/csrc/jit/mobile/profiler_edge.cpp +++ b/torch/csrc/jit/mobile/profiler_edge.cpp @@ -18,15 +18,23 @@ KinetoEdgeCPUProfiler::KinetoEdgeCPUProfiler( const bool profile_memory, const bool with_stack, const bool with_flops, - const bool with_modules) + const bool with_modules, + std::vector events) : m_(m), trace_file_name_(fname) { + torch::profiler::impl::ExperimentalConfig experimental_config; + // Enable hardware counters + if (events.size()) { + experimental_config.performance_events = std::move(events); + } + torch::profiler::impl::ProfilerConfig config( torch::profiler::impl::ProfilerState::KINETO, report_input_shapes, profile_memory, with_stack, with_flops, - with_modules); + with_modules, + experimental_config); torch::autograd::profiler::prepareProfiler( config, {torch::autograd::profiler::ActivityType::CPU}); if (with_modules || with_stack) { diff --git a/torch/csrc/jit/mobile/profiler_edge.h b/torch/csrc/jit/mobile/profiler_edge.h index 52dc26d1221a7..2a89819e700cd 100644 --- a/torch/csrc/jit/mobile/profiler_edge.h +++ b/torch/csrc/jit/mobile/profiler_edge.h @@ -55,7 +55,8 @@ class TORCH_API KinetoEdgeCPUProfiler { const bool profile_memory = false, const bool with_stack = false, const bool with_flops = false, - const bool with_modules = false); + const bool with_modules = false, + std::vector events = {}); const std::unique_ptr& disableProfiler(); diff --git a/torch/csrc/jit/mobile/train/export_data.cpp b/torch/csrc/jit/mobile/train/export_data.cpp index da75ef991ba2c..731ffef15424a 100644 --- a/torch/csrc/jit/mobile/train/export_data.cpp +++ b/torch/csrc/jit/mobile/train/export_data.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -131,14 +132,7 @@ void _save_parameters( }; if (use_flatbuffer) { - if (_save_mobile_module_to != nullptr) { - _save_mobile_module_to(mobile::tensor_dict_to_mobile(dict), write_func); - } else { - TORCH_CHECK( - false, - "Trying to export as flatbuffer file but " - "the build hasn't enabled flatbuffer"); - } + save_mobile_module_to_func(mobile::tensor_dict_to_mobile(dict), write_func); } else { // For Pickle, we only serialize the dict itself. mobile::IValuePickler pickler(write_func); diff --git a/torch/csrc/jit/operator_upgraders/README.md b/torch/csrc/jit/operator_upgraders/README.md index 084e6688f148e..75639006e5034 100644 --- a/torch/csrc/jit/operator_upgraders/README.md +++ b/torch/csrc/jit/operator_upgraders/README.md @@ -1,6 +1,6 @@ # Guidance for Operator Developer -PyTorch’s operators sometimes require changes for different reasons (e.g. from improving their usability to fixing bugs). These changes can be backward compatibility (BC) breaking, where older programs will no longer run as expected (or at all) on the latest version of PyTorch (an old program / new runtime problem), or forward compatibility (FC) breaking, where new programs will not run on older versions of PyTorch (a new program / old runtime problem). This guidance focuses on the requirements for maintaining backwards comatibility when making changes to an operator. +PyTorch’s operators sometimes require changes for different reasons (e.g. from improving their usability to fixing bugs). These changes can be backward compatibility (BC) breaking, where older programs will no longer run as expected (or at all) on the latest version of PyTorch (an old program / new runtime problem), or forward compatibility (FC) breaking, where new programs will not run on older versions of PyTorch (a new program / old runtime problem). This guidance focuses on the requirements for maintaining backwards compatibility when making changes to an operator. In order to do this we introduce the concept of the *upgrader*: a method to adapt the new operator to mimic the old operator behavior. When a new runtime reads an old program containing the old operator definition, the upgrader will adapt the old operator definition to comply with the new operator implementation. As you would expect, an upgrader is only applied when an old operation definition is encountered (i.e. if there are no "old" operators in the program, no upgrader would be used). For more details on the reasoning behind this new requirement please refer to the [PyTorch Operator Versioning RFC](https://github.com/pytorch/rfcs/blob/master/RFC-0017-PyTorch-Operator-Versioning.md). @@ -177,7 +177,7 @@ When making changes to the operators, the first thing to identify is if it's BC/ except Exception as e: self.skipTest("Failed to load fixture!") - # Step4. Load the new model and it won't apply the ugprader + # Step4. Load the new model and it won't apply the upgrader current_mobile_module_float = self._save_load_mobile_module(MyModuleFloat) current_server_module_float = self._save_load_module(MyModuleFloat) @@ -226,7 +226,7 @@ def foo(x, y, z=100): return x, y, z ``` -2. To help understanding the BC/FC breakage changes, here are some FC breaking changes examples. The solution to resolve it is not there yet. If it's desired, please report it in either [PyTorch Forum](https://discuss.pytorch.org/) or [PyTorch Github](https://github.com/pytorch/pytorch). We will prioritize it accordingly. +2. To help understanding the BC/FC breakage changes, here are some FC breaking changes examples. The solution to resolve it is not there yet. If it's desired, please report it in either [PyTorch Forum](https://discuss.pytorch.org/) or [PyTorch GitHub](https://github.com/pytorch/pytorch). We will prioritize it accordingly. - Adding new default argument: - Adding a new default argument not RIGHT BEFORE the out arguments which can be 0 or more. diff --git a/torch/csrc/jit/passes/mobile_optimizer_type.h b/torch/csrc/jit/passes/mobile_optimizer_type.h new file mode 100644 index 0000000000000..d11f288dca343 --- /dev/null +++ b/torch/csrc/jit/passes/mobile_optimizer_type.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +enum class MobileOptimizerType : int8_t { + CONV_BN_FUSION, + INSERT_FOLD_PREPACK_OPS, + REMOVE_DROPOUT, + FUSE_ADD_RELU, + HOIST_CONV_PACKED_PARAMS, + CONV_1D_TO_2D, + VULKAN_AUTOMATIC_GPU_TRANSFER, +}; diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 98f3cb42aea0f..d4e7aa6c7f98f 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -8,13 +8,14 @@ #include #include #include +#include +#include #include #include #include #include #include #include - namespace torch { namespace jit { @@ -59,13 +60,13 @@ void checkONNXCompatibility(const c10::FunctionSchema& schema) { if (type->kind() == TypeKind::OptionalType) { type = reinterpret_cast(type.get())->getElementType(); // recursive optional type is not supported - AT_ASSERT(type->kind() != TypeKind::OptionalType); + TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType); } if (type->kind() == TypeKind::ListType) { const auto& elem_type = reinterpret_cast(type.get())->getElementType(); if (elem_type->isSubtypeOf(*TensorType::get())) { - AT_ASSERTM( + TORCH_INTERNAL_ASSERT( !has_tensor_list, "ONNX export supports at most one TensorList as input."); has_tensor_list = true; @@ -92,7 +93,7 @@ void preprocessCaffe2Ops(Block* block) { size_t origin_inputs_index = 0; for (const auto& arg : args) { auto type = arg.type(); - AT_ASSERT(origin_inputs_index < origin_inputs.size()); + TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size()); const auto& origin_input = origin_inputs[origin_inputs_index++]; if (type->kind() == TypeKind::OptionalType && origin_input->mustBeNone()) { @@ -104,24 +105,24 @@ void preprocessCaffe2Ops(Block* block) { type->kind() == TypeKind::BoolType || type->kind() == TypeKind::IntType) { const auto* constant_node = origin_input->node(); - AT_ASSERT(constant_node->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value)); } else if (type->kind() == TypeKind::FloatType) { const auto* constant_node = origin_input->node(); - AT_ASSERT(constant_node->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value)); } else if (type->kind() == TypeKind::StringType) { const auto* constant_node = origin_input->node(); - AT_ASSERT(constant_node->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value)); } else if (type->kind() == TypeKind::ListType) { const auto& list_node = origin_input->node(); const auto& elem_type = type->castRaw()->getElementType(); - AT_ASSERT( + TORCH_INTERNAL_ASSERT( list_node->kind() == prim::ListConstruct || list_node->kind() == prim::Constant); if (elem_type->isSubtypeOf(*TensorType::get())) { - AT_ASSERT(list_node->kind(), prim::ListConstruct); + TORCH_INTERNAL_ASSERT(list_node->kind(), prim::ListConstruct); const auto& tensor_list = origin_input->node()->inputs(); for (const auto& t : tensor_list) { it->addInput(t); @@ -131,7 +132,7 @@ void preprocessCaffe2Ops(Block* block) { if (list_node->kind() == prim::ListConstruct) { for (const auto* elem_input : list_node->inputs()) { const auto* constant_node = elem_input->node(); - AT_ASSERT(constant_node->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); values.push_back(constant_node->f(attr::value)); } } else { // is a constant list @@ -326,10 +327,20 @@ void NodeToONNX( ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); env[old] = const_node->output(); } else { - // ConstantValueMap has been set in shape inference, - // set_constant_value_map = false here to avoid redundancy. + // An update in ConstantValueMap is also needed here, since + // the user setType can be only accessed in this step, and it + // should be reliable. MergeInferredTypeAndSetMap( - outputs[i], old->type(), outputs[i]->type(), false); + outputs[i], old->type(), outputs[i]->type()); + // non ONNX node with no type given will throw out the warnings here. + UpdateReliable( + outputs[i], + AreInputsReliableOrStatic(outputs[i]->node()), + /*no_type_warning=*/true); + // For the node type that does not have ComputeConstant logic, it may + // have reliable shape but its shape is not in ConstantValueMap. So we + // need to update ConstantValueMap. + UpdateShapeConstantIfReliable(outputs[i]); // Copy over source location and scope information to all nodes // created by the symbolic @@ -426,8 +437,16 @@ void NodeToONNX( WithInsertPoint insert_point_guard(new_block); WithCurrentScope scope_guard(*g, n->scope()); + + // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to + // Python. Check #87343 for details. py::object raw_output = onnx.attr("_run_symbolic_function")( - g, new_block, n, py_inputs, env, operator_export_type); + g->shared_from_this(), + new_block, + n, + py_inputs, + env, + operator_export_type); // Find new nodes that have been created by _run_symbolic_function and // propagate metadata @@ -530,8 +549,11 @@ void NodeToONNX( opset_version, pyobj.attr("symbolic"), /* custom */ true); + + // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to + // Python. Check #87343 for details. py::object raw_output = onnx.attr("_run_symbolic_method")( - new_block->owningGraph(), + new_block->owningGraph()->shared_from_this(), op->name(), pyobj.attr("symbolic"), py_symbolic_args); @@ -542,8 +564,10 @@ void NodeToONNX( Node* n = static_cast(op); n->s_(attr::name, op->name()); // Call symbolic function + // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to + // Python. Check #87343 for details. py::object raw_output = onnx.attr("_run_symbolic_function")( - new_block->owningGraph(), + new_block->owningGraph()->shared_from_this(), new_block, n, py_symbolic_args, diff --git a/torch/csrc/jit/passes/onnx.h b/torch/csrc/jit/passes/onnx.h index e3c6cd23ecc3e..11bee67916404 100644 --- a/torch/csrc/jit/passes/onnx.h +++ b/torch/csrc/jit/passes/onnx.h @@ -1,8 +1,6 @@ #pragma once #include -#include -#include #include #include diff --git a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp index 2280ea6eb30bb..d93e34f87c6e9 100644 --- a/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp +++ b/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.cpp @@ -146,8 +146,8 @@ std::unordered_map MergeSliceAndSelectToIndices( std::forward_as_tuple(index_tensor, aten::select)); dim_offset++; } else { - AT_ERROR( - "Unexpected node kind ", + TORCH_CHECK( + false, node->kind().toDisplayString(), " Expected aten::slice or aten::select."); } @@ -202,7 +202,8 @@ std::vector ReshapeToAdvancedIndexingFormat( if (((max_index_dim - min_index_dim + 1) != tensor_ind_count) && tensor_ind_count != 0) { - AT_ERROR( + TORCH_CHECK( + false, "Only consecutive 1-d tensor indices are supported in exporting aten::index_put to ONNX.", "Check https://pytorch.org/docs/stable/onnx.html#indexing for details"); } @@ -230,7 +231,8 @@ std::vector ReshapeToAdvancedIndexingFormat( break; } default: - AT_ERROR("Unexpected node kind ", index_i->second.orig_node_kind); + TORCH_CHECK( + false, "Unexpected node kind ", index_i->second.orig_node_kind); } if (ind_size != 1) { diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index db74dca360e3f..efb7686fae3fe 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -136,6 +136,21 @@ Node* addDummyClone( orig_data->type()->kind() == TypeKind::BoolType) { auto* noneNode = graph->create(prim::Constant); noneNode->output()->setType(NoneType::get()); + // For scripting mode, aten::clone requires input to be a TensorType + // Hence if we encounter an IntType, FloatType, or BoolType, + // we set the input to the appropriate TensorType + if (orig_data->type()->kind() == TypeKind::IntType && + insertBefore == false) { + orig_data->setType(TensorType::fromNumberType(*IntType::get())); + } else if ( + orig_data->type()->kind() == TypeKind::FloatType && + insertBefore == false) { + orig_data->setType(TensorType::fromNumberType(*FloatType::get())); + } else if ( + orig_data->type()->kind() == TypeKind::BoolType && + insertBefore == false) { + orig_data->setType(TensorType::fromBoolType()); + } newNode = graph->create(aten::clone, /*num_outputs =*/1); newNode->addInput(orig_data); newNode->addInput(noneNode->output()); diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 657c27f70c7d9..3af0360b7e011 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -48,6 +48,7 @@ static const std::unordered_set standardOps = { onnx::Div, onnx::Gemm, onnx::Min, + onnx::Max, onnx::Mod, onnx::Mul, onnx::Pow, diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 248733f746a63..a9087508e6ad2 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -76,16 +76,13 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, TypePtr existing_type, - TypePtr inferred_type, - bool set_constant_value_map) { + TypePtr inferred_type) { TypePtr mergedType; bool inferred; std::tie(mergedType, inferred) = MergeInferredType(existing_type, inferred_type); dest_v->setType(mergedType); - if (set_constant_value_map) { - ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); - } + ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); } namespace { @@ -232,6 +229,28 @@ bool IsValidONNXNode(const Node* n) { return true; } +bool CustomSettype(Node* node) { + // This is a helper function to decide if the non-ONNX node actually has + // custom setType from user + // Go through every symbolic_sizes and if any one of them is static, we say + // this is set by user. On the other hand, if all of them are * (dynamic), we + // take this node does not have given type, since unreliable nodes have * + // shape anyway. + auto all_output_has_type = [](Value* output) { + if (auto output_type = output->type()->cast()) { + if (auto sizes = output_type->symbolic_sizes().sizes()) { + return std::any_of(std::begin(*sizes), std::end(*sizes), [](auto size) { + return size.is_static(); + }); + } + } + return false; + }; + + return std::all_of( + node->outputs().begin(), node->outputs().end(), all_output_has_type); +} + Value* CloneValueFromListConstruct( Value* v, std::shared_ptr n_graph, @@ -700,18 +719,25 @@ std::vector<::c10::ShapeSymbol> Broadcast( const c10::ShapeSymbol& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx]; bool is_static_0 = ss_shape_0.is_static(); bool is_static_1 = ss_shape_1.is_static(); + size_t shape_idx = rank_max - 1 - idx; if (is_static_0 && is_static_1) { int64_t static_0_sz = ss_shape_0.static_size(); int64_t static_1_sz = ss_shape_1.static_size(); - final_shape[rank_max - 1 - idx] = ::c10::ShapeSymbol::fromStaticSize( - std::max(static_0_sz, static_1_sz)); + // condition for corner case of 0d tensor + // 0d tensor with 1d tensor would give us 0d tensor + if (std::min(static_0_sz, static_1_sz) == 0) { + final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize( + std::min(static_0_sz, static_1_sz)); + } else { + final_shape[shape_idx] = ::c10::ShapeSymbol::fromStaticSize( + std::max(static_0_sz, static_1_sz)); + } } else if (!is_static_0 && !is_static_1) { if (ss_shape_0.value() == ss_shape_1.value()) { - final_shape[rank_max - 1 - idx] = ss_shape_0; + final_shape[shape_idx] = ss_shape_0; } } } - if (rank_0 < rank_1) { for (size_t idx = rank_min; idx < rank_max; idx++) { size_t shape_idx = rank_max - 1 - idx; @@ -1872,7 +1898,8 @@ static std::unordered_set nodeTypeReliableForTracer = { void UpdateReliable( torch::jit::Value* output, - const std::pair& inferred_type_reliable) { + const std::pair& inferred_type_reliable, + bool no_type_warning) { auto inferred = ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false); auto isTypeReliableForTracer = @@ -1880,7 +1907,9 @@ void UpdateReliable( output->node()->kind().toDisplayString()) != nodeTypeReliableForTracer.end(); if (!inferred && !isTypeReliableForTracer && - !output->node()->kind().is_onnx()) { + !output->node()->kind().is_onnx() && no_type_warning) { + // TODO(84661): This warning comes before setType in symbolic_fn. + // tracked in #84661 TORCH_WARN( "The shape inference of ", output->node()->kind().toDisplayString(), @@ -1890,7 +1919,7 @@ void UpdateReliable( diagnostics::Diagnose( diagnostics::Rule::kNodeMissingOnnxShapeInference, diagnostics::Level::kWarning, - {output->node()->kind().toDisplayString()}); + {{"op_name", output->node()->kind().toDisplayString()}}); } auto reliable = false; if (inferred) { @@ -1942,6 +1971,7 @@ void ONNXShapeTypeInference( SetGraphInputTypeReliable(n->owningGraph()); GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); + if (IsValidONNXNode(n)) { // Create a Graph containing only the single node n. // This graph is later converted to ONNX to run shape inference. @@ -2034,6 +2064,15 @@ void ONNXShapeTypeInference( GRAPH_DEBUG( "ONNX graph after shape inference: ", prettyPrint(*model_proto)); } + } else if (CustomSettype(n)) { + // If the node is not ONNX standard, go through every output to check if + // they all have shape. If they all do, this should be reliable even if the + // Op is not from ONNX. + for (auto node_output : n->outputs()) { + // Custom setType output should get in here if it's set correctly. They + // will be updated to inferred for later updatereliable function. + ConstantValueMap::SetUseInferredType(node_output->debugName(), true); + } } SpecialPostProcess(n); @@ -2075,20 +2114,7 @@ void ONNXShapeTypeInference( // reliable shape but its shape is not in ConstantValueMap. So we need this // logic to update ConstantValueMap. for (auto node_output : n->outputs()) { - if (ConstantValueMap::HasTypeReliable(node_output->debugName())) { - auto reliable = - ConstantValueMap::GetTypeReliable(node_output->debugName()) - .value_or(false); - if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) { - // TODO: ListType case - if (auto output_tensor_type = node_output->type()->cast()) { - if (output_tensor_type->dim()) { - auto symbolic_sizes = output_tensor_type->symbolic_sizes(); - UpdateShapeConstantValueMap(node_output, symbolic_sizes); - } - } - } - } + UpdateShapeConstantIfReliable(node_output); } GRAPH_DEBUG( @@ -2221,7 +2247,7 @@ size_t ONNXAssignOutputShape( auto& new_var = THPVariable_Unpack(list_elem); TORCH_CHECK( var.scalar_type() == new_var.scalar_type(), - "Unsupported sequence with mixed elment types in model outputs. " + "Unsupported sequence with mixed element types in model outputs. " "ONNX supports only sequences of elements of the same data type."); } auto elem_type = graph->outputs() @@ -2273,10 +2299,10 @@ size_t ONNXAssignOutputShape( // Tracing: // Ignore None, since it is not captured in IR graph as output. // Scripting: - // Ignore None, if observing a fixed `None` node in IR graph. Because it - // is meaningless to include it as graph output as it carries no - // data/information. Plus that static `None` is not supported in ONNX IR. - // Otherwise, the output should have type `Optional`, and should be + // Ignore None, if observing a fixed `None` node in IR graph. Because + // it is meaningless to include it as graph output as it carries no + // data/information. Plus that static `None` is not supported in ONNX + // IR. Otherwise, the output should have type `Optional`, and should be // converted to ONNX `Optional`. // More context: @@ -2336,5 +2362,21 @@ void ONNXShapeTypeInference( ConstantValueMap::ClearMaps(); } +void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) { + if (ConstantValueMap::HasTypeReliable(node_output->debugName())) { + auto reliable = ConstantValueMap::GetTypeReliable(node_output->debugName()) + .value_or(false); + if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) { + // TODO: ListType case + if (auto output_tensor_type = node_output->type()->cast()) { + if (output_tensor_type->dim()) { + auto symbolic_sizes = output_tensor_type->symbolic_sizes(); + UpdateShapeConstantValueMap(node_output, symbolic_sizes); + } + } + } + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index afda5b1765377..39350ed273d48 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -34,8 +34,7 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, TypePtr existing_type, - TypePtr inferred_type, - bool set_constant_value_map = true); + TypePtr inferred_type); // Update graph input types with dynamic axes info. // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. @@ -80,9 +79,11 @@ TORCH_API void ONNXShapeTypeInference( std::pair AreInputsReliableOrStatic(Node* n); void UpdateReliable( torch::jit::Value* output, - const std::pair& input_reliable); + const std::pair& input_reliable, + bool no_type_warning = false); void UpdateReliable(torch::jit::Node* n); +void UpdateShapeConstantIfReliable(torch::jit::Value* output); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index f5a50e76fcae4..300e3452a8d17 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -299,7 +299,10 @@ void ConvertQuantizedWeight( } } -enum class QuantizedParamsType { CONV, LINEAR }; +// CONV1D needs a different unpacking from CONV, since it's +// packed as CONV2D intentionally at the first place. +// See: https://github.com/pytorch/pytorch/pull/38248 +enum class QuantizedParamsType { CONV1D, CONV, LINEAR }; // This is called before the onnx pass. Using pattern matching we // find the relevant nodes and extract the packed_params. The packed_params are @@ -413,7 +416,8 @@ void unpackQuantizedWeightsHelper( groups = groups_int; transpose = transpose_int; } else if ( - params_type == QuantizedParamsType::CONV && + (params_type == QuantizedParamsType::CONV || + params_type == QuantizedParamsType::CONV1D) && ser_tup->elements()[0].isString()) { const auto& elements = ser_tup->elements(); auto version = elements[0].toStringRef(); @@ -426,25 +430,32 @@ void unpackQuantizedWeightsHelper( const int64_t kSpatialDim = conv_params_packed[0].item(); // skip kSpatialDim int64_t idx = 1; + // kSpatialDim = 2 even it's for Conv1D from torch.op to adopt Conv2D, + // so we need a special unpack for Conv1D which has Conv2D dim. + // See: https://github.com/pytorch/pytorch/pull/38248 for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning - stride_int.emplace_back(conv_params_packed[idx].item()); + if (params_type != QuantizedParamsType::CONV1D || i != 0) { + stride_int.emplace_back(conv_params_packed[idx].item()); + } idx++; } for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning - padding_int.emplace_back(conv_params_packed[idx].item()); + if (params_type != QuantizedParamsType::CONV1D || i != 0) { + padding_int.emplace_back(conv_params_packed[idx].item()); + } idx++; } for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning - dilation_int.emplace_back(conv_params_packed[idx].item()); + if (params_type != QuantizedParamsType::CONV1D || i != 0) { + dilation_int.emplace_back(conv_params_packed[idx].item()); + } idx++; } for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning - output_padding_int.emplace_back( - conv_params_packed[idx].item()); + if (params_type != QuantizedParamsType::CONV1D || i != 0) { + output_padding_int.emplace_back( + conv_params_packed[idx].item()); + } idx++; } groups_int = conv_params_packed[idx].item(); @@ -461,6 +472,9 @@ void unpackQuantizedWeightsHelper( torch::List optional = elements[2].toList(); bias = optional.get(0).toOptional(); + if (params_type == QuantizedParamsType::CONV1D) { + unpacked_weight = unpacked_weight.squeeze_(2); + } stride = stride_int; padding = padding_int; dilation = dilation_int; @@ -638,6 +652,10 @@ void UnpackQuantizedWeights( graph(%input, %packed_weight, %w_scale, %w_zero_point): %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point) return (%r) )"; + std::string qconv1d_relu = R"( + graph(%input, %packed_params, %scale, %zero_point): + %r = quantized::conv1d_relu(%input, %packed_params, %scale, %zero_point) + return (%r) )"; std::string qconv2d = R"( graph(%input, %packed_params, %scale, %zero_point): %r = quantized::conv2d(%input, %packed_params, %scale, %zero_point) @@ -668,6 +686,13 @@ void UnpackQuantizedWeights( "quantized::conv2d_unpack", QuantizedParamsType::CONV, caffe2); + unpackQuantizedWeightsHelper( + graph, + paramsDict, + qconv1d_relu, + "quantized::conv1d_unpack", + QuantizedParamsType::CONV1D, + caffe2); unpackQuantizedWeightsHelper( graph, paramsDict, diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp index c114ea759e52c..10ff3db0586a0 100644 --- a/torch/csrc/jit/passes/peephole_non_tensor.cpp +++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp @@ -15,7 +15,7 @@ namespace { * constant int value if there exists one. * * @pre node is integer arithmetic. - * @post if there's one constant in two oprands, then the second operand is + * @post if there's one constant in two operands, then the second operand is * constant. */ c10::optional checkArithNode(Node& node) { diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 54bd6679980e6..c852696c62d78 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -257,19 +257,6 @@ at::ScalarType getObserverDtype(Module& module, Value* v) { return at::ScalarType::Undefined; } -at::ScalarType getObserverComputeDtype(Module& module, Value* v) { - auto observer_name = findObserverName(v); - if (observer_name.has_value()) { - auto observer_module = module.attr(observer_name.value()).toModule(); - if (observer_module.hasattr("compute_dtype")) { - at::ScalarType scalar_type = - observer_module.attr("compute_dtype").toScalarType(); - return scalar_type; - } - } - return at::ScalarType::Undefined; -} - c10::optional getEmbeddingBagObsName( script::Module& module, Node* n) { @@ -480,12 +467,8 @@ void insertQuantizationOps( dequant = insertFP16CastOps(g, observer_out); } else if (!isWeight(module, observer_out)) { auto observer_dtype = getObserverDtype(module, observer_out); - auto observer_compute_dtype = - getObserverComputeDtype(module, observer_out); if (observer_dtype == at::ScalarType::QUInt8 || - observer_dtype == at::ScalarType::QInt8 || - observer_compute_dtype == at::ScalarType::QUInt8 || - observer_compute_dtype == at::ScalarType::QInt8) { + observer_dtype == at::ScalarType::QInt8) { // For activation tensors we insert choose_qparams, quant, dequant ops. Value* dtype = g->insertGetAttr(self, qparams.back()); std::tie(choose_qparams, quant, dequant) = @@ -1092,9 +1075,10 @@ std::tuple InsertQuantDeQuantHelper:: auto scalar_type = observer_module.attr("dtype"); if (isPlaceholderObserver(n->input(0))) { // get compute_dtype for dynamic quantization - if (observer_module.hasattr("compute_dtype")) { + if (observer_module.hasattr("is_dynamic") && + observer_module.attr("is_dynamic").toBool()) { qparams.push_back( - std::make_pair(kScalarType, observer_module.attr("compute_dtype"))); + std::make_pair(kScalarType, observer_module.attr("dtype"))); } return std::make_tuple(qscheme, qparams); } else if (scalar_type == at::ScalarType::Half) { @@ -1554,7 +1538,7 @@ Node* insertQuantDequantNodes( void checkCalculateQParamsResultTypes(const Node* out) { TORCH_CHECK( out->outputs().size() == 2, - "cacluate_qparams should produce output of size 2 (scale, zero_point)."); + "calculate_qparams should produce output of size 2 (scale, zero_point)."); Value* scale = out->output(0); Value* zp = out->output(1); TORCH_CHECK( diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 9a0f45ff84020..0c37d5b503477 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -18,6 +19,24 @@ namespace jit { namespace { +void insertPrePackedBatchNormOp(std::shared_ptr& graph) { + std::string batchnorm_pattern = R"( + graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable): + %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable) + return (%r))"; + std::string prepacked_ops_pattern = R"( + graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable): + %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context( + %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable) + %res = vulkan_prepack::run_batchnorm_context(%input, %op_context) + return (%res))"; + + SubgraphRewriter batchnorm_rewriter; + batchnorm_rewriter.RegisterRewritePattern( + batchnorm_pattern, prepacked_ops_pattern); + batchnorm_rewriter.runOnGraph(graph); +} + void insertPrePackedLinearOp(std::shared_ptr& graph) { // fuse decomposed linear into aten::linear FuseLinear(graph); @@ -82,6 +101,51 @@ void insertPrePackedConv2dOp(std::shared_ptr& graph) { transpose_rewriter.runOnGraph(graph); } +void transferInputOutputBackends(std::shared_ptr& graph) { + // Move inputs to Vulkan backend + for (Value* input : graph->inputs()) { + NamedValue named_input = NamedValue("", input); + if (named_input.type()->kind() == TypeKind::TensorType) { + // find the insertion point + WithInsertPoint ip(input->uses()[0].user->prev()); + Value* replaced_input = graph->insert( + Symbol::fromQualString("aten::to"), {named_input, "vulkan"}); + // replace the input + input->replaceAllUsesAfterNodeWith( + replaced_input->node(), replaced_input); + } + } + + // Move outputs to CPU backend + at::ArrayRef&& outputs = graph->outputs(); + for (size_t i = 0; i < outputs.size(); i++) { + Value* output = outputs[i]; + NamedValue named_output = NamedValue("", output); + if (named_output.type()->kind() == TypeKind::TensorType) { + // find the insertion point + WithInsertPoint ip(output->node()->next()); + Value* replaced_output = graph->insert( + Symbol::fromQualString("aten::to"), {named_output, "cpu"}); + // replace the output + graph->block()->replaceOutput(i, replaced_output); + } + } + + SubgraphRewriter rewriter; + rewriter.runOnGraph(graph); +} + +void transferInputOutputBackends(script::Module& module) { + std::shared_ptr graph = module.get_methods()[0].graph(); + transferInputOutputBackends(graph); +} + +void eliminateDeadCode(script::Module& module) { + for (auto& method : module.get_methods()) { + EliminateDeadCode(method.graph()); + } +} + void insertPrePackedGruOp(std::shared_ptr& graph) { std::string gru_pattern = R"( graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool): @@ -219,6 +283,7 @@ void vulkanInsertPrePackedOps(std::shared_ptr& graph) { insertPrePackedConv2dOp(graph); insertPrePackedGruOp(graph); insertPrePackedLstmOp(graph); + insertPrePackedBatchNormOp(graph); } void vulkanInsertPrePackedOps(script::Module& module) { @@ -249,7 +314,9 @@ void vulkanFoldPrePackingOps(script::Module& m) { (n->kind() == Symbol::fromQualString("vulkan_prepack::create_gru_context")) || (n->kind() == - Symbol::fromQualString("vulkan_prepack::create_lstm_context"))); + Symbol::fromQualString("vulkan_prepack::create_lstm_context")) || + (n->kind() == + Symbol::fromQualString("vulkan_prepack::create_batchnorm_context"))); }; PrePackingOpsFolder(m, filter_fn, "prepack_folding"); } @@ -269,18 +336,28 @@ void vulkanRunCanonicalOptimizations(script::Module& module) { script::Module vulkanOptimizeForMobile( const script::Module& m, + const std::set& optimization_blocklist, const std::vector& preserved_methods) { auto cloned_module = m.clone(); cloned_module.eval(); cloned_module = FoldConvBatchNorm(cloned_module); - vulkanInsertPrePackedOps(cloned_module); cloned_module = freeze_module(cloned_module, preserved_methods); + vulkanInsertPrePackedOps(cloned_module); vulkanFusePrePackedConvWithClamp(cloned_module); vulkanFoldPrePackingOps(cloned_module); removeDropout(cloned_module); vulkanRemoveMutation(cloned_module); + + if (!optimization_blocklist.count( + MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) { + transferInputOutputBackends(cloned_module); + cloned_module.register_attribute( + "requires_backend_transfers", BoolType::get(), false); + } + // remove duplicated constants vulkanRunCanonicalOptimizations(cloned_module); + eliminateDeadCode(cloned_module); cloned_module.register_attribute( "optimized_for_vulkan", BoolType::get(), true); diff --git a/torch/csrc/jit/passes/vulkan_rewrite.h b/torch/csrc/jit/passes/vulkan_rewrite.h index 8e67dce70f542..395d885e8e2c3 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.h +++ b/torch/csrc/jit/passes/vulkan_rewrite.h @@ -2,6 +2,7 @@ #include #include +#include namespace torch { namespace jit { @@ -11,6 +12,7 @@ TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module); TORCH_API void vulkanFoldPrePackingOps(script::Module& module); TORCH_API script::Module vulkanOptimizeForMobile( const script::Module& module, + const std::set& optimization_blocklist, const std::vector& preserved_methods); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index 2476d1be4df61..0e2163f7a19f8 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h index 498dcd006fe3c..d1a64c52c9230 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.h +++ b/torch/csrc/jit/passes/xnnpack_rewrite.h @@ -2,19 +2,11 @@ #include #include +#include namespace torch { namespace jit { -enum class MobileOptimizerType : int8_t { - CONV_BN_FUSION, - INSERT_FOLD_PREPACK_OPS, - REMOVE_DROPOUT, - FUSE_ADD_RELU, - HOIST_CONV_PACKED_PARAMS, - CONV_1D_TO_2D, -}; - TORCH_API void transformConv1dToConv2d(std::shared_ptr& graph); TORCH_API void transformConv1dToConv2d(script::Module& module); TORCH_API void insertPrePackedOps(std::shared_ptr& graph); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 0bb959a3c61e0..7ee48635cdffc 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -13,7 +13,7 @@ #if (!defined(FBCODE_CAFFE2) && defined(BUILD_ONEDNN_GRAPH)) #include #endif -#include +#include #include #include #include @@ -52,6 +52,7 @@ #include #include #include +#include #include #include #include @@ -99,7 +100,6 @@ #include #include -#include #include #include #include @@ -126,249 +126,11 @@ using c10::Argument; using c10::FunctionSchema; using c10::SchemaArgType; using c10::SchemaArgument; -using c10::SymFloat; -using c10::SymFloatNode; -using c10::SymIntNode; +using c10::SymNode; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::PyTorchStreamWriter; using torch::utils::SchemaInfo; -static c10::SymIntNode toSymIntNode(c10::SymIntNode a, py::object b) { - return torch::is_symint_node(b) ? b.cast() - : a->wrap(b.cast()); -} - -static c10::SymFloatNode toSymFloatNode(c10::SymFloatNode a, py::object b) { - if (torch::is_symfloat_node(b)) { - return b.cast(); - } else if (torch::is_symint_node(b)) { - return b.cast()->sym_float(); - } else { - return a->wrap(b.cast()); - } -} - -class PythonSymIntNodeImpl : public c10::SymIntNodeImpl { - public: - PythonSymIntNodeImpl(py::object pyobj) : c10::SymIntNodeImpl() { - pyobj_ = std::make_shared( - pyobj.release().ptr(), getPyInterpreter()); - }; - - virtual SymIntNode clone() override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("clone")(); - return c10::make_intrusive(r); - } - - virtual SymIntNode wrap(int64_t num) override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("wrap")(num); - return c10::make_intrusive(r); - } - - virtual bool bool_() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__bool__")().is(py::handle(Py_True)); - } - - virtual int64_t guard_int(const char* file, int64_t line) override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("guard_int")(file, line).cast(); - } - - virtual int64_t int_() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__int__")().cast(); - } - - SymFloatNode sym_float() override; - - virtual std::string str() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__str__")().cast(); - } - - virtual SymIntNode dispatch_common_( - const char* fname, - const SymIntNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(pother->getPyObj()); - return c10::make_intrusive(r); - } - - virtual SymIntNode dispatch_common_(const char* fname) { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(); - return c10::make_intrusive(r); - } - - virtual SymIntNode add(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode sub(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode mul(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymFloatNode truediv(const SymIntNode& other) override; - - virtual SymIntNode floordiv(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode mod(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode eq(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode gt(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode lt(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode le(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode ge(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode min(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - virtual SymIntNode max(const SymIntNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - virtual SymIntNode ceil() override { - return dispatch_common_(__FUNCTION__); - } - - virtual SymIntNode neg() override { - return dispatch_common_(__FUNCTION__); - } - - py::handle getPyObj() { - return py::handle(pyobj_.get()->ptr(getPyInterpreter())); - } - std::shared_ptr pyobj_ = nullptr; -}; - -class PythonSymFloatNodeImpl : public c10::SymFloatNodeImpl { - public: - PythonSymFloatNodeImpl(py::object pyobj) : c10::SymFloatNodeImpl() { - pyobj_ = std::make_shared( - pyobj.release().ptr(), getPyInterpreter()); - }; - - virtual SymFloatNode wrap(double num) override { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("wrap")(num); - return c10::make_intrusive(r); - } - - virtual std::string str() override { - py::gil_scoped_acquire acquire; - return getPyObj().attr("__str__")().cast(); - } - - SymFloatNode dispatch_common_(const char* fname, const SymFloatNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr(fname)(pother->getPyObj()); - return c10::make_intrusive(r); - } - - SymFloatNode add(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode sub(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode mul(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode truediv(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode pow(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode eq(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode gt(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode lt(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode le(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymFloatNode ge(const SymFloatNode& other) override { - return dispatch_common_(__FUNCTION__, other); - } - - SymIntNode ceil() override; - SymIntNode floor() override; - - py::handle getPyObj() { - return py::handle(pyobj_.get()->ptr(getPyInterpreter())); - } - std::shared_ptr pyobj_ = nullptr; -}; - -SymFloatNode PythonSymIntNodeImpl::truediv(const SymIntNode& other) { - auto pother = dynamic_cast(other.get()); - TORCH_CHECK(pother); - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("truediv")(pother->getPyObj()); - return c10::make_intrusive(r); -} - -SymFloatNode PythonSymIntNodeImpl::sym_float() { - py::gil_scoped_acquire acquire; - return c10::make_intrusive( - getPyObj().attr("__sym_float__")()); -} - -SymIntNode PythonSymFloatNodeImpl::ceil() { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("ceil")(); - return c10::make_intrusive(r); -} - -SymIntNode PythonSymFloatNodeImpl::floor() { - py::gil_scoped_acquire acquire; - auto r = getPyObj().attr("floor")(); - return c10::make_intrusive(r); -} - namespace { using autograd::variable_list; @@ -895,7 +657,7 @@ void initJITBindings(PyObject* module) { .def( "_jit_pass_create_autodiff_subgraphs", [](const std::shared_ptr& graph, py::object threshold) { - if (threshold.is(py::none())) { + if (threshold.is_none()) { CreateAutodiffSubgraphs(graph); } else { CreateAutodiffSubgraphs(graph, py::cast(threshold)); @@ -1320,8 +1082,10 @@ void initJITBindings(PyObject* module) { .def( "_jit_pass_vulkan_optimize_for_mobile", [](script::Module& module, + std::set& optimization_blocklist, std::vector& preserved_methods) { - return vulkanOptimizeForMobile(module, preserved_methods); + return vulkanOptimizeForMobile( + module, optimization_blocklist, preserved_methods); }) .def( "_jit_pass_metal_insert_prepacked_ops", @@ -1381,276 +1145,68 @@ void initJITBindings(PyObject* module) { } }); - auto symint_class = - py::class_(m, "SymIntNode") - .def_static( - "new_symint", - [](py::object obj) -> c10::SymIntNode { - return c10::make_intrusive(obj); - }) - .def( - "get_pyobj", - [](c10::SymIntNode a) -> py::object { - if (auto* psn = dynamic_cast(a.get())) { - return py::reinterpret_borrow(psn->getPyObj()); - } - return py::none(); - }) - .def( - "__add__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->add(snb); - }) - .def( - "__radd__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->add(a); - }) - .def( - "__sub__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->sub(snb); - }) - .def( - "__rsub__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->sub(a); - }) - .def( - "__mul__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->mul(snb); - }) - .def( - "__rmul__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->mul(a); - }) - .def( - "__truediv__", - [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymIntNode(a, b); - return a->truediv(snb); - }) - .def( - "__rtruediv__", - [](c10::SymIntNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymIntNode(a, b); - return snb->truediv(a); - }) - .def( - "__floordiv__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->floordiv(snb); - }) - .def( - "__rfloordiv__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->floordiv(a); - }) - .def( - "__mod__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->mod(snb); - }) - .def( - "__rmod__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return snb->mod(a); - }) - .def( - "__pow__", - [](c10::SymIntNode a, py::object b) -> py::object { - if (PyFloat_Check(b.ptr())) { - auto float_a = a->sym_float(); - return py::cast( - float_a->pow(float_a->wrap(py::cast(b)))); - } - // TODO: integer pow - return py::reinterpret_borrow(Py_NotImplemented); - }) - // TODO: rpow - .def( - "__eq__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->eq(snb); - }) - .def( - "__gt__", - [](c10::SymIntNode a, py::object b) { - auto snb = toSymIntNode(a, b); - return a->gt(snb); - }) - .def( - "__lt__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->lt(snb); - }) - .def( - "__le__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->le(snb); - }) - .def( - "__ge__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->ge(snb); - }) - .def( - "__ceil__", - [](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); }) - .def( - "__neg__", - [](c10::SymIntNode a) -> c10::SymIntNode { return a->neg(); }) - .def( - "__min__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->min(snb); - }) - .def( - "__max__", - [](c10::SymIntNode a, py::object b) -> c10::SymIntNode { - auto snb = toSymIntNode(a, b); - return a->max(snb); - }) - .def("__bool__", [](c10::SymIntNode a) { return a->bool_(); }) - .def("__int__", [](c10::SymIntNode a) { return a->int_(); }) - // Intentionally don't set file line, as the Python backtrace matters - // more here - .def( - "guard_int", - [](c10::SymIntNode a) { return a->guard_int(nullptr, 0); }) - .def( - "__sym_float__", - [](c10::SymIntNode a) { - // TODO: remove dynamic cast when sym_float is in base class - auto* psn = dynamic_cast(a.get()); - TORCH_INTERNAL_ASSERT(psn); - return psn->sym_float(); - }) - .def("__str__", [](c10::SymIntNode a) { return a->str(); }) - .def("__repr__", [](c10::SymIntNode a) { return a->str(); }); - - py::class_(m, "SymFloatNode") - .def_static( - "new_symfloat", - [](py::object obj) -> c10::SymFloatNode { - return c10::make_intrusive(obj); - }) - .def( - "__add__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->add(snb); - }) - .def( - "__radd__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->add(a); - }) - .def( - "__sub__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->sub(snb); - }) - .def( - "__mul__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->mul(snb); - }) - .def( - "__rmul__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->mul(a); - }) - .def( - "__truediv__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->truediv(snb); - }) - .def( - "__rtruediv__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->truediv(a); - }) - .def( - "__eq__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->eq(snb); - }) - .def( - "__gt__", - [](c10::SymFloatNode a, py::object b) { - auto snb = toSymFloatNode(a, b); - return a->gt(snb); - }) - .def( - "__lt__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->lt(snb); + // NB: This isn't actually used for regular PyTorch symbolic tracing; + // XLA is what needs this +#define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); }) +#define SYMNODE_BINARY(n) \ + .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); }) + auto symnode_class = + py::class_(m, "_SymNode") + // clang-format off + // These DO NOT install magic methods; the SymInt/SymFloat wrapper in + // Python is responsible for this + SYMNODE_UNARY(clone) + SYMNODE_UNARY(is_int) + SYMNODE_UNARY(is_float) + SYMNODE_UNARY(bool_) + SYMNODE_UNARY(int_) + SYMNODE_UNARY(sym_float) + SYMNODE_BINARY(add) + SYMNODE_BINARY(sub) + SYMNODE_BINARY(mul) + SYMNODE_BINARY(truediv) + SYMNODE_BINARY(pow) + SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(mod) + SYMNODE_BINARY(eq) + SYMNODE_BINARY(gt) + SYMNODE_BINARY(lt) + SYMNODE_BINARY(le) + SYMNODE_BINARY(ge) + SYMNODE_BINARY(min) + SYMNODE_BINARY(max) + SYMNODE_UNARY(ceil) + SYMNODE_UNARY(floor) + SYMNODE_UNARY(neg) + // Intentionally don't set file line, as the + // Python backtrace matters more here + .def( + "guard_int", + [](c10::SymNode a) { + return a->guard_int(nullptr, 0); + }) + .def( + "guard_float", + [](c10::SymNode a) { + return a->guard_float(nullptr, 0); + }) + .def( + "wrap_int", + [](c10::SymNode a, int64_t b) { + return a->wrap_int(b); + }) + .def( + "wrap_float", + [](c10::SymNode a, double b) { + return a->wrap_float(b); }) .def( - "__le__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->le(snb); - }) - .def( - "__ge__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->ge(snb); - }) - .def( - "__pow__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return a->pow(snb); - }) - .def( - "__rpow__", - [](c10::SymFloatNode a, py::object b) -> c10::SymFloatNode { - auto snb = toSymFloatNode(a, b); - return snb->pow(a); - }) - .def( - "__ceil__", - [](c10::SymFloatNode a) -> c10::SymIntNode { return a->ceil(); }) - .def( - "__floor__", - [](c10::SymFloatNode a) -> c10::SymIntNode { return a->floor(); }) - .def( - "get_pyobj", - [](c10::SymFloatNode a) -> py::object { - if (auto* psn = dynamic_cast(a.get())) { - return py::reinterpret_borrow(psn->getPyObj()); - } - return py::none(); - }) - .def("__str__", [](c10::SymFloatNode a) { return a->str(); }); + "__str__", + [](c10::SymNode a) { return a->str(); }) + .def("__repr__", [](c10::SymNode a) { + return a->str(); + }); + // clang-format on // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") @@ -1724,7 +1280,7 @@ void initJITBindings(PyObject* module) { .def(py::init()) .def(py::init([](const py::object& buffer) { auto writer_func = [=](const void* data, size_t size) { - // Writting an empty file is a noop + // Writing an empty file is a noop if (size == 0) { return size; } @@ -1768,6 +1324,9 @@ void initJITBindings(PyObject* module) { .value( "HOIST_CONV_PACKED_PARAMS", MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) + .value( + "VULKAN_AUTOMATIC_GPU_TRANSFER", + MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER) .export_values(); // This allows PyTorchStreamReader to read from a Python buffer. It requires diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 60c7247ada62a..47089fcc89694 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -80,10 +80,10 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj.ptr())); } else if (THPUtils_checkDouble(obj.ptr())) { scalar = at::Scalar(THPUtils_unpackDouble(obj.ptr())); - } else if (torch::is_symint_node(py::handle(obj))) { + } else if (torch::is_symint(py::handle(obj))) { save_symint = true; scalar = at::Scalar(7777777); - } else if (torch::is_symfloat_node(py::handle(obj))) { + } else if (torch::is_symfloat(py::handle(obj))) { save_symint = true; scalar = at::Scalar(std::numeric_limits::quiet_NaN()); } else { @@ -161,12 +161,12 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { return py::cast(obj); } case TypeKind::SymIntType: - if (torch::is_symint_node(obj.ptr())) { + if (torch::is_symint(obj.ptr())) { return py::cast(obj); } return py::cast(obj); case TypeKind::SymFloatType: - if (torch::is_symfloat_node(obj.ptr())) { + if (torch::is_symfloat(obj.ptr())) { return py::cast(obj); } return py::cast(obj); @@ -253,7 +253,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { bool is_symbolic = false; for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; - if (torch::is_symint_node(elm)) { + if (torch::is_symint(elm)) { is_symbolic = true; break; } @@ -269,7 +269,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { for (auto it = obj.begin(); it != obj.end(); it++) { auto elm = *it; // TODO: what about SymInt conversion to SymFloat? - if (torch::is_symfloat_node(elm)) { + if (torch::is_symfloat(elm)) { is_symbolic = true; break; } @@ -442,9 +442,9 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { } else if (PyComplex_CheckExact(obj.ptr())) { auto c_obj = py::cast>(obj.ptr()); return static_cast>(c_obj); - } else if (torch::is_symint_node(obj)) { + } else if (torch::is_symint(obj)) { return py::cast(obj); - } else if (torch::is_symfloat_node(obj)) { + } else if (torch::is_symfloat(obj)) { return py::cast(obj); } else { throw py::cast_error( @@ -755,8 +755,7 @@ py::object _get_operation_for_overload_or_packet( total_arg_num, false /* throw_error */); } - if (overloaded_args.size() > 0 || - at::impl::PythonTorchFunctionTLS::get_mode()) { + if (overloaded_args.size() > 0 || at::impl::torch_function_mode_enabled()) { py::object ret; std::string ns = symbol.ns().toUnqualString(); std::string method_name = symbol.toUnqualString(); diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 835c7d0dc709a..5dfe28e92fd72 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -299,7 +299,7 @@ inline InferredType tryToInferType(py::handle input) { return InferredType(TensorType::get()); } - if (input.is(py::none())) { + if (input.is_none()) { return InferredType(NoneType::get()); } diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index a19a8cd011db3..c1cae6eb300c6 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -382,7 +382,11 @@ void initPythonIRBindings(PyObject* module_) { "Find all nodes", py::arg("kind"), py::arg("recurse") = true) - .def("addInput", [](Graph& g) { return g.addInput(); }) + .def( + "addInput", + [](Graph& g, const std::string& name) { return g.addInput(name); }, + "Add input to graph with optional name seed", + py::arg("name") = "") .def("copy", [](Graph& g) { return g.copy(); }) .GS(eraseInput) .GS(eraseOutput) diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 705731778dc35..12d565427ae48 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1149,7 +1149,7 @@ std::shared_ptr toSugaredValue( g.insertConstant(static_cast>(c_obj), loc)); } else if (py::isinstance(obj)) { return toSimple(g.insertConstant(py::cast(obj), loc)); - } else if (obj.is(py::none())) { + } else if (obj.is_none()) { return toSimple(g.insertConstant(IValue(), loc)); } else if (THPDevice_Check(obj.ptr())) { auto device = reinterpret_cast(obj.ptr()); diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 83570c85e9b4c..c89d54872a07b 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -27,7 +27,7 @@ namespace tracer { std::vector _pythonCallstack() { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); - Py_INCREF(frame); + Py_XINCREF(frame); std::vector entries; while (nullptr != frame) { diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index ee9509588932c..2c6f8b1daca83 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -112,7 +112,7 @@ struct PythonResolver : public Resolver { const SourceRange& loc) override { pybind11::gil_scoped_acquire ag; py::object obj = rcb_(name); - if (obj.is(py::none())) { + if (obj.is_none()) { return nullptr; } return toSugaredValue(obj, m, loc); @@ -153,7 +153,7 @@ struct PythonResolver : public Resolver { } pybind11::gil_scoped_acquire ag; py::object obj = rcb_(name); - if (obj.is(py::none())) { + if (obj.is_none()) { return nullptr; } @@ -366,7 +366,7 @@ static StrongFunctionPtr script_compile_overloaded_function( const ResolutionCallback& rcb, const FunctionDefaults& implementation_defaults, const py::object& signature) { - if (signature.is(py::none())) { + if (signature.is_none()) { throw ErrorReport(overload_decl.range()) << "Must explicitly add type annotations to overloaded functions"; } @@ -1774,7 +1774,7 @@ void initJitScriptBindings(PyObject* module) { if (def.kind() != TK_DEF) { throw ErrorReport(def.range()) << "Currently class bodies can only contain method " - "definitions. File an issue on Github if you want " + "definitions. File an issue on GitHub if you want " "something else!"; } methodDefs.emplace_back(Def(def)); @@ -1869,7 +1869,7 @@ void initJitScriptBindings(PyObject* module) { py::object map_location, const py::dict& extra_files) { c10::optional optional_device; - if (!map_location.is(py::none())) { + if (!map_location.is_none()) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; @@ -1889,7 +1889,7 @@ void initJitScriptBindings(PyObject* module) { py::object map_location, std::string ts_id) { c10::optional optional_device; - if (!map_location.is(py::none())) { + if (!map_location.is_none()) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; @@ -1909,7 +1909,7 @@ void initJitScriptBindings(PyObject* module) { const py::dict& extra_files) { std::istringstream in(buffer); c10::optional optional_device; - if (!map_location.is(py::none())) { + if (!map_location.is_none()) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; @@ -1924,7 +1924,7 @@ void initJitScriptBindings(PyObject* module) { "_load_for_lite_interpreter", [](const std::string& filename, py::object map_location) { c10::optional optional_device; - if (!map_location.is(py::none())) { + if (!map_location.is_none()) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; @@ -1936,7 +1936,7 @@ void initJitScriptBindings(PyObject* module) { [](const std::string& buffer, py::object map_location) { std::istringstream in(buffer); c10::optional optional_device; - if (!map_location.is(py::none())) { + if (!map_location.is_none()) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 66e53da24d1df..d09918522a812 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -66,7 +66,7 @@ struct ArgumentInfo { }; static_assert( - std::is_pod::value, + std::is_standard_layout::value, "ArgumentInfo is to be a POD struct"); static_assert( sizeof(ArgumentInfo) == sizeof(ArgumentInfo::plain_data_type), diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index c2c84eb9e4e47..88a092c39fe05 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -922,13 +922,12 @@ void runNondiffOptimization( std::shared_ptr& graph, bool strict_fuser_check) { GRAPH_DEBUG( - "Before customPrePassses (beginning of runNondiffOptimization)\n", - *graph); + "Before customPrePasses (beginning of runNondiffOptimization)\n", *graph); // Run custom passes that different backends can register. for (const auto& passPair : getCustomPrePasses()) { passPair.first(graph); } - GRAPH_DEBUG("After customPrePassses\n", *graph); + GRAPH_DEBUG("After customPrePasses\n", *graph); // decomposition pass, decompose certain ops that will be used in the // following passes (like batchmm and jit fusion) @@ -960,7 +959,7 @@ void runNondiffOptimization( passPair.first(graph); } GRAPH_DEBUG( - "After customPostPassses (end of runNondiffOptimization)\n", *graph); + "After customPostPasses (end of runNondiffOptimization)\n", *graph); } void runOptimization( diff --git a/torch/csrc/jit/runtime/static/README.md b/torch/csrc/jit/runtime/static/README.md index 82d42d4b9f4c7..9b72db912684a 100644 --- a/torch/csrc/jit/runtime/static/README.md +++ b/torch/csrc/jit/runtime/static/README.md @@ -141,10 +141,10 @@ is selected instead. When loading a model, ops are selected for each `torch::jit::Node` in the graph as follows: -1) If an out variant is registered, pass the node to the function that prodcues the `SROperator`. If -the result is not `nulltpr`, use that op. -2) If a native function is registered, pass the node to the function that prodcues the `SROperator`. If -the result is not `nulltpr`, use that op. +1) If an out variant is registered, pass the node to the function that produces the `SROperator`. If +the result is not `nullptr`, use that op. +2) If a native function is registered, pass the node to the function that produces the `SROperator`. If +the result is not `nullptr`, use that op. 3) Use the JIT implementation. Static runtime will throw an exception if it does not exist. ## Implementation Details diff --git a/torch/csrc/jit/runtime/static/generated_ops.cpp b/torch/csrc/jit/runtime/static/generated_ops.cpp index 69cc98bf14ec6..2ad1741ef56de 100644 --- a/torch/csrc/jit/runtime/static/generated_ops.cpp +++ b/torch/csrc/jit/runtime/static/generated_ops.cpp @@ -2431,25 +2431,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::addbmm, aten_addbmm, [](Node* n) -> SROperator { return nullptr; }); -REGISTER_OPERATOR_FUNCTOR(aten::diag, aten_diag, [](Node* n) -> SROperator { - if (n->matches( - torch::schema("aten::diag(Tensor self, int diagonal=0) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto diagonal = p_node->Input(1).toInt(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::diag(self, diagonal); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::diag_cpu_out(self, diagonal, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - REGISTER_OPERATOR_FUNCTOR(aten::cross, aten_cross, [](Node* n) -> SROperator { if (n->matches(torch::schema( "aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"))) { @@ -3427,96 +3408,6 @@ REGISTER_OPERATOR_FUNCTOR( return nullptr; }); -REGISTER_OPERATOR_FUNCTOR(aten::nll_loss, aten_nll_loss, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto& target = p_node->Input(1).toTensor(); - const auto weight = p_node->Input(2).toOptional(); - const auto reduction = p_node->Input(3).toInt(); - const auto ignore_index = p_node->Input(4).toInt(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = - at::native::nll_loss(self, target, weight, reduction, ignore_index); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::nll_loss_out( - self, target, weight, reduction, ignore_index, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - -REGISTER_OPERATOR_FUNCTOR( - aten::nll_loss_backward, - aten_nll_loss_backward, - [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& grad_output = p_node->Input(0).toTensor(); - const auto& self = p_node->Input(1).toTensor(); - const auto& target = p_node->Input(2).toTensor(); - const auto weight = p_node->Input(3).toOptional(); - const auto reduction = p_node->Input(4).toInt(); - const auto ignore_index = p_node->Input(5).toInt(); - const auto& total_weight = p_node->Input(6).toTensor(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::cpu::nll_loss_backward( - grad_output, - self, - target, - weight, - reduction, - ignore_index, - total_weight); - return; - } - auto& grad_input = p_node->Output(0).toTensor(); - fastResizeToZero(grad_input); - at::cpu::nll_loss_backward_out( - grad_input, - grad_output, - self, - target, - weight, - reduction, - ignore_index, - total_weight); - }; - } - LogAndDumpSchema(n); - return nullptr; - }); - -REGISTER_OPERATOR_FUNCTOR(aten::nll_loss2d, aten_nll_loss2d, [](Node* n) -> SROperator { - if (n->matches(torch::schema( - "aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"))) { - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto& target = p_node->Input(1).toTensor(); - const auto weight = p_node->Input(2).toOptional(); - const auto reduction = p_node->Input(3).toInt(); - const auto ignore_index = p_node->Input(4).toInt(); - if (p_node->Output(0).isNone()) { - p_node->Output(0) = at::native::nll_loss2d( - self, target, weight, reduction, ignore_index); - return; - } - auto& out = p_node->Output(0).toTensor(); - fastResizeToZero(out); - at::native::nll_loss2d_out( - self, target, weight, reduction, ignore_index, out); - }; - } - LogAndDumpSchema(n); - return nullptr; -}); - REGISTER_OPERATOR_FUNCTOR( aten::soft_margin_loss, aten_soft_margin_loss, diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 897f3b5eee644..3f87df14f555e 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -56,9 +56,9 @@ namespace jit { namespace { -bool allArgsAreTensors(Node* node) { +bool allArgsAreTensors(const Node* node) { const auto& inputs = node->inputs(); - return std::all_of(inputs.begin(), inputs.end(), [](Value* value) { + return std::all_of(inputs.begin(), inputs.end(), [](const Value* value) { return value->type()->kind() == TypeKind::TensorType; }); } @@ -69,7 +69,7 @@ bool allArgsAreTensors(Node* node) { // These are rarely-used ops. Disallowing them typically eliminates // corner cases in graph optimizations, allowing for more aggressive // optimizations and better performance. -bool isUnsupportedOp(Node* node) { +bool isUnsupportedOp(const Node* node) { auto kind = node->kind(); if (kind != aten::__is__ && kind != aten::__isnot__) { return false; @@ -87,12 +87,21 @@ bool isUnsupportedOp(Node* node) { return allArgsAreTensors(node); } -// graph must be frozen or canEnableStaticRuntime would return false -// if there's any prim::CallMethod op left in the graph -bool canEnableStaticRuntime(const std::shared_ptr& graph) { - // check for sub-blocks +namespace { + +bool canEnableStaticRuntimeImpl(const Block* block) { + if (block == nullptr) { + return false; + } + bool can_support = true; - for (auto* node : graph->block()->nodes()) { + for (auto* node : block->nodes()) { + for (auto* subblock : node->blocks()) { + // The ordering prevents && from short circuiting, which we want - + // it's useful to see *all* the unsupported ops. + can_support = canEnableStaticRuntimeImpl(subblock) && can_support; + } + const auto kind = node->kind(); if (kind == prim::Constant) { continue; @@ -107,6 +116,14 @@ bool canEnableStaticRuntime(const std::shared_ptr& graph) { return can_support; } +} // namespace + +// Graph must be frozen. canEnableStaticRuntime will return false +// if there's any prim::CallMethod ops left in the graph. +bool canEnableStaticRuntime(const std::shared_ptr& graph) { + return canEnableStaticRuntimeImpl(graph->block()); +} + namespace { auto sr_metadata_registerer = torch::class_( @@ -155,7 +172,6 @@ void OptimizeGraph( UseVariadicStack(graph); EliminateTrivialEquallySplit(graph); EliminateExtraPermuteOps(graph); - PrepackWeights(graph); if (opts.enable_out_variant) { UseVariadicOp( @@ -182,6 +198,7 @@ void OptimizeGraph( } FuseListUnpack(graph); RemoveUnnecessaryOutputs(graph); + PrepackWeights(graph); #endif } @@ -2045,7 +2062,7 @@ bool ProcessedNode::verify_inputs_dont_overlap_outputs(bool force_check) const { bool skip_check = !schema || ((schema->is_mutable() || !fn_->checkMemoryOverlap()) && num_outputs() == 1); - if (!force_check && skip_check) { + if (!schema || (!force_check && skip_check)) { if (!schema) { VLOG(2) << "Detected that op schema is null"; return true; diff --git a/torch/csrc/jit/runtime/static/native_ops.cpp b/torch/csrc/jit/runtime/static/native_ops.cpp index 790d54b5c0023..1c8fb0791389c 100644 --- a/torch/csrc/jit/runtime/static/native_ops.cpp +++ b/torch/csrc/jit/runtime/static/native_ops.cpp @@ -870,9 +870,9 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::tensor_split, aten_tensor_split, [](Node* "aten::tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]"))) { return [](ProcessedNode* pnode) { const auto& a = pnode->Input(0).toTensor(); - const auto b = pnode->Input(1).toInt(); + const auto b = pnode->Input(1).toSymInt(); const auto c = pnode->Input(2).toInt(); - pnode->Output(0) = at::native::tensor_split(a, b, c); + pnode->Output(0) = at::native::tensor_split_sections_symint(a, b, c); }; } diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 92044ca565a9c..e2a154ad069e9 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -45,7 +45,7 @@ C10_DEFINE_bool( static_runtime_enable_fast_math, true, - "If on, static runtime may use use optimizations that cause accurary loss " + "If on, static runtime may use use optimizations that cause accuracy loss " "vs the jit interpreter"); namespace at { @@ -1675,36 +1675,6 @@ REGISTER_OPERATOR_FUNCTOR( }; }); -namespace { - -std::vector permute_output_sizes( - c10::IntArrayRef self_sizes, - c10::IntArrayRef dims) { - const auto nDim = dims.size(); - TORCH_CHECK( - self_sizes.size() == nDim, - "permute input and output tensors must have the same rank, got input rank=", - self_sizes.size(), - "; output rank=", - nDim); - std::vector dims_seen(nDim, false); - std::vector output_sizes; - output_sizes.reserve(nDim); - for (size_t i = 0; i < nDim; ++i) { - auto dim = c10::maybe_wrap_dim(dims[i], nDim); - TORCH_CHECK( - !dims_seen[dim], - "permute dims must be unique, found duplicate dim=", - dim); - - output_sizes.push_back(self_sizes[dim]); - dims_seen[dim] = true; - } - return output_sizes; -} - -} // namespace - // Out variants for view ops are registered to a separate registry because // their outputs (views) can't participate in memory reuse. REGISTER_OPERATOR_FUNCTOR( @@ -1729,29 +1699,6 @@ REGISTER_OPERATOR_FUNCTOR( }; }); -REGISTER_OPERATOR_FUNCTOR( - static_runtime::permute_copy, - sr_permute_copy, - [](Node* n) -> SROperator { - if (!n->matches(torch::schema( - "static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor"))) { - LogAndDumpSchema(n); - return nullptr; - } - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto dims = p_node->Input(1).toDimVector(); - - if (p_node->Output(0).isNone()) { - p_node->Output(0) = create_empty_from(self); - } - auto& output = p_node->Output(0).toTensor(); - at::native::resize_( - output, permute_output_sizes(self.sizes(), dims), c10::nullopt); - at::native::permute_copy_out(self, dims, output); - }; - }); - REGISTER_OPERATOR_FUNCTOR( static_runtime::flatten_copy, aten_flatten, diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h index 35c1678adca86..d61d7baa4947e 100644 --- a/torch/csrc/jit/runtime/static/passes.h +++ b/torch/csrc/jit/runtime/static/passes.h @@ -21,7 +21,7 @@ TORCH_API void ReplacePermuteWithCopy( std::shared_ptr& graph, bool outputs_are_immutable = true); -void ReplaceWithMaybeCopy( +TORCH_API void ReplaceWithMaybeCopy( std::shared_ptr& graph, bool outputs_are_immutable = true); diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 2f178addda955..f5f5ab7c99088 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -729,7 +729,7 @@ void GraphEncoder::EncodeBlock( bool add_node_names, bool use_external_data_format, const std::string& onnx_file_path) { - AT_ASSERT(graph_proto != nullptr); + TORCH_INTERNAL_ASSERT(graph_proto != nullptr); std::string block_name = "torch_jit"; if (num_blocks_) { block_name += std::to_string(num_blocks_); @@ -806,7 +806,7 @@ void GraphEncoder::AddInitializersIntoGraphProto( const std::map& initializers, bool use_external_data_format, const std::string& onnx_file_path) { - AT_ASSERT(block->inputs().size() >= initializers.size()); + TORCH_INTERNAL_ASSERT(block->inputs().size() >= initializers.size()); for (auto input : block->inputs()) { auto name_tensor_pair = initializers.find(input->debugName()); if (name_tensor_pair == initializers.end()) { @@ -888,7 +888,7 @@ void GraphEncoder::EncodeNode( node_proto->set_domain(domain); } if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) { - AT_ASSERT( + TORCH_INTERNAL_ASSERT( !node->kind().is_aten() && !node->kind().is_prim() && !node->kind().is_attr()); } @@ -923,7 +923,7 @@ void GraphEncoder::EncodeNode( node_proto, node, attr_name, use_external_data_format, onnx_file_path); } if (node->kind() == ::c10::onnx::Loop) { - AT_ASSERT(node->blocks().size() == 1); + TORCH_INTERNAL_ASSERT(node->blocks().size() == 1); auto body = node_proto->add_attribute(); body->set_name("body"); @@ -940,7 +940,7 @@ void GraphEncoder::EncodeNode( onnx_file_path); } if (node->kind() == ::c10::onnx::If) { - AT_ASSERT(node->blocks().size() == 2); + TORCH_INTERNAL_ASSERT(node->blocks().size() == 2); auto then_branch = node_proto->add_attribute(); then_branch->set_name("then_branch"); @@ -978,7 +978,7 @@ void GraphEncoder::AddAttribute( const std::string& ref_attr_name, const AttributeKind attr_kind) { auto attr = node_proto->add_attribute(); - AT_ASSERT(name.is_attr()); + TORCH_INTERNAL_ASSERT(name.is_attr()); attr->set_name(name.toUnqualString()); attr->set_ref_attr_name(ref_attr_name); attr->set_type(ATenAttributeKindToOnnxAttributeType(attr_kind, name)); @@ -1009,7 +1009,7 @@ void GraphEncoder::AddAttribute( }; auto attr = node_proto->add_attribute(); - AT_ASSERT(name.is_attr()); + TORCH_INTERNAL_ASSERT(name.is_attr()); attr->set_name(name.toUnqualString()); attr->set_type( ATenAttributeKindToOnnxAttributeType(node->kindOf(name), name)); @@ -1236,7 +1236,7 @@ void GraphEncoder::EncodeTensor( // or use_external_data_format should be true, not both at the same time. They // can both be false at the same time (for ONNX export for regular model // size). - AT_ASSERT( + TORCH_INTERNAL_ASSERT( !((defer_weight_export_ && external_ref) && use_external_data_format)); // Add a buffer to the raw_data_export_map for the caller to dump into an // external data store. If external_ref is not specified, we instead dump @@ -1244,18 +1244,19 @@ void GraphEncoder::EncodeTensor( if (defer_weight_export_ && external_ref) { // For now, we use the name of the tensor as the external lookup name to // avoid ONNX protobuf changes. - AT_ASSERT(external_ref.value() == tensor_proto->name()); - AT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0); + TORCH_INTERNAL_ASSERT(external_ref.value() == tensor_proto->name()); + TORCH_INTERNAL_ASSERT( + raw_data_export_map_.count(external_ref.value()) == 0); raw_data_export_map_[external_ref.value()] = t; tensor_proto->set_raw_data("__EXTERNAL"); } else { - AT_ASSERT(t.is_contiguous()); + TORCH_INTERNAL_ASSERT(t.is_contiguous()); size_t tensorSize = static_cast(c10::multiply_integers( std::begin(tensor.sizes()), std::end(tensor.sizes()))); if (use_external_data_format && tensorSize > ParamSizeThresholdForExternalStorage) { - AT_ASSERT(!onnx_file_path.empty()); - AT_ASSERT(tensor_proto->has_name()); + TORCH_INTERNAL_ASSERT(!onnx_file_path.empty()); + TORCH_INTERNAL_ASSERT(tensor_proto->has_name()); auto tensorName = GetExternalFileName(tensor_proto->name()); CreateExternalFile(t, tensorName, onnx_file_path); onnx::StringStringEntryProto* location = diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 06670a5716450..da5d5e6a70959 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -4,12 +4,12 @@ #include #include #include +#include #include #include #include #include #include - #include namespace ONNX_NAMESPACE { @@ -260,9 +260,18 @@ Table(const std::vector>& entries); TORCH_API void enableMobileInterfaceCallExport(); bool getMobileInterfaceCallExport(); -CompilationOptions getOptionsFromGlobal(); +TORCH_API CompilationOptions getOptionsFromGlobal(); + +TORCH_API void save_jit_module( + const Module& module, + const std::string& filename, + const ExtraFilesMap& extra_files = ExtraFilesMap()); + +TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes( + const Module& module, + const ExtraFilesMap& extra_files = ExtraFilesMap()); -extern void (*_save_jit_module_to)( +TORCH_API void save_jit_module_to_write_func( const Module& module, const ExtraFilesMap& extra_files, bool save_mobile_debug_info, diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index b56c4980211a8..6f30f82899ed4 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -212,7 +212,7 @@ mobile::Code compileGraphToMobileCode( for (const TypePtr& element_type : input_type->containedTypes()) { TORCH_CHECK( element_type->kind() != TypeKind::ClassType, - "Returining a list or dictionary with pytorch class type ", + "Returning a list or dictionary with pytorch class type ", "is not supported in mobile module " "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " "Workaround: instead of using pytorch class as their element type, ", diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index b29f1e2914c0c..0ff9b78478462 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -95,7 +96,7 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() { * ] * ]" * - * @param compilation_unit Jit compilcation unit to look up function schema. + * @param compilation_unit Jit compilation unit to look up function schema. * @param type_ptr A type pointer and it can be possibly any type. * @param default_type_str The default string representation. The string can * either from type_ptr->str(), type_ptr->annotation_str(), or @@ -874,11 +875,37 @@ void ExportModule( use_flatbuffer); } -void (*_save_jit_module_to)( +void save_jit_module( + const Module& module, + const std::string& filename, + const ExtraFilesMap& extra_files) { + auto buffer = save_jit_module_to_bytes(module, extra_files); + std::fstream ofile(filename, std::ios::binary | std::ios::out); + ofile.write( + reinterpret_cast(buffer->data()), buffer->size()); // NOLINT + ofile.close(); +} + +DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes( + const Module& module, + const ExtraFilesMap& extra_files) { + ExtraFilesMap jitfiles; + std::vector constants; + jitModuleToPythonCodeAndConstants(module, &jitfiles, &constants); + CompilationOptions options = getOptionsFromGlobal(); + mobile::Module mobilem = jitModuleToMobile(module, options); + return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants); +} + +void save_jit_module_to_write_func( const Module& module, const ExtraFilesMap& extra_files, bool save_mobile_debug_info, - const std::function& writer_func) = nullptr; + const std::function& writer_func) { + (void)save_mobile_debug_info; + auto buffer = save_jit_module_to_bytes(module, extra_files); + writer_func(reinterpret_cast(buffer->data()), buffer->size()); +} void ExportModule( const Module& module, @@ -888,14 +915,8 @@ void ExportModule( bool save_mobile_debug_info, bool use_flatbuffer) { if (use_flatbuffer) { - if (_save_jit_module_to != nullptr) { - _save_jit_module_to( - module, extra_files, save_mobile_debug_info, writer_func); - } else { - TORCH_CHECK( - false, - "Trying to export as flatbuffer file but the build hasn't enabled flatbuffer"); - } + save_jit_module_to_write_func( + module, extra_files, save_mobile_debug_info, writer_func); } else { caffe2::serialize::PyTorchStreamWriter writer(writer_func); ScriptModuleSerializer serializer(writer); diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index 54ec7c7b6ed3e..ccacf7beab846 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -819,7 +819,7 @@ DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes( return DetachedBufferFriend::make_unique_detached_buffer(ret); } -static void save_mobile_module_to_func( +void save_mobile_module_to_func( const mobile::Module& module, const std::function& writer_func) { auto buffer = save_mobile_module_to_bytes(module); @@ -827,15 +827,8 @@ static void save_mobile_module_to_func( } bool register_flatbuffer_serializer() { - _save_mobile_module_to = save_mobile_module_to_func; return true; } -// iOS builds are often build with -Wglobal-constructor to minimize -// startup time. So let them call register manually if needed. -#if !defined(__APPLE__) -const bool kFlatbufferSerializerRegistered = register_flatbuffer_serializer(); -#endif - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.h b/torch/csrc/jit/serialization/flatbuffer_serializer.h index 24da6b5527922..43e8062ef2dce 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.h +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.h @@ -83,10 +83,11 @@ TORCH_API DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes( const ExtraFilesMap& jit_sources = ExtraFilesMap(), const std::vector& jit_constants = {}); -// This function will make the capabilities to load and safe -// Module as a flatbuffer file available for use by _load_for_mobile -// and friends. This is NOT needed if using the other functions -// in this file directly. +TORCH_API void save_mobile_module_to_func( + const mobile::Module& module, + const std::function& writer_func); + +// TODO(qihan): delete TORCH_API bool register_flatbuffer_serializer(); } // namespace jit diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp index 321068311da25..9cbb0f1cd2f80 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp @@ -15,87 +15,9 @@ namespace torch { namespace jit { -Module parse_and_initialize_jit_module( - std::shared_ptr data, - size_t size, - ExtraFilesMap& extra_files, - c10::optional device) { - populate_upgraders_graph_map(); - ExtraFilesMap jit_files; - std::vector jit_constants; - mobile::Module mobilem = parse_and_initialize_mobile_module_for_jit( - data.get(), size, jit_files, jit_constants, device, &extra_files); - - Module m = jitModuleFromSourceAndConstants( - mobilem._ivalue(), - jit_files, - jit_constants, - static_cast(mobilem.bytecode_version())); - m.set_delete_memory(data); - return m; -} - -Module load_jit_module_from_file( - const std::string& filename, - ExtraFilesMap& extra_files, - c10::optional device) { - auto data = get_file_content(filename.c_str()); - return parse_and_initialize_jit_module( - std::move(std::get<0>(data)), std::get<1>(data), extra_files, device); -} - -Module load_jit_module_from_stream( - std::istream& in, - ExtraFilesMap& extra_files, - c10::optional device) { - auto data = get_stream_content(in); - return parse_and_initialize_jit_module( - std::move(std::get<0>(data)), std::get<1>(data), extra_files, device); -} - -void save_jit_module( - const Module& module, - const std::string& filename, - const ExtraFilesMap& extra_files) { - auto buffer = save_jit_module_to_bytes(module, extra_files); - std::fstream ofile(filename, std::ios::binary | std::ios::out); - ofile.write( - reinterpret_cast(buffer->data()), buffer->size()); // NOLINT - ofile.close(); -} - -DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes( - const Module& module, - const ExtraFilesMap& extra_files) { - ExtraFilesMap jitfiles; - std::vector constants; - jitModuleToPythonCodeAndConstants(module, &jitfiles, &constants); - CompilationOptions options = getOptionsFromGlobal(); - mobile::Module mobilem = jitModuleToMobile(module, options); - return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants); -} - -static void save_jit_module_to_write_func( - const Module& module, - const ExtraFilesMap& extra_files, - bool save_mobile_debug_info, - const std::function& writer_func) { - (void)save_mobile_debug_info; - auto buffer = save_jit_module_to_bytes(module, extra_files); - writer_func(reinterpret_cast(buffer->data()), buffer->size()); -} - bool register_flatbuffer_all() { - (void)register_flatbuffer_loader(); - (void)register_flatbuffer_serializer(); - _save_jit_module_to = save_jit_module_to_write_func; - _load_jit_module_from_flatbuffer_bytes = parse_and_initialize_jit_module; return true; } -#if !defined(__APPLE__) -const bool kFlatbufferSerializerJitInitialized = register_flatbuffer_all(); -#endif - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.h b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.h index 1f605f18ba1e5..b43ab831f1773 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer_jit.h +++ b/torch/csrc/jit/serialization/flatbuffer_serializer_jit.h @@ -5,35 +5,6 @@ namespace torch { namespace jit { -TORCH_API void save_jit_module( - const Module& module, - const std::string& filename, - const ExtraFilesMap& extra_files = ExtraFilesMap()); - -TORCH_API DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes( - const Module& module, - const ExtraFilesMap& extra_files = ExtraFilesMap()); - -TORCH_API Module parse_and_initialize_jit_module( - std::shared_ptr data, - size_t size, - ExtraFilesMap& extra_files, - c10::optional device = c10::nullopt); - -TORCH_API Module load_jit_module_from_file( - const std::string& filename, - ExtraFilesMap& extra_files, - c10::optional device = c10::nullopt); - -TORCH_API Module load_jit_module_from_stream( - std::istream& in, - ExtraFilesMap& extra_files, - c10::optional device = c10::nullopt); - -// This function will make the capabilities to load and safe -// Module as a flatbuffer file available for use by _load_for_mobile -// and friends. This is NOT needed if using the other functions -// in this file directly. TORCH_API bool register_flatbuffer_all(); } // namespace jit diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index a72abeaede8e1..56087f1fe0d3b 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -294,12 +295,6 @@ Module import_ir_module( return import_ir_module(std::move(cu), in, device, extra_files); } -Module (*_load_jit_module_from_flatbuffer_bytes)( - std::shared_ptr, - size_t, - ExtraFilesMap&, - c10::optional) = nullptr; - static Module _load_jit_module_from_bytes( std::shared_ptr data, size_t size, @@ -307,6 +302,44 @@ static Module _load_jit_module_from_bytes( c10::optional device, ExtraFilesMap& extra_files); +Module parse_and_initialize_jit_module( + std::shared_ptr data, + size_t size, + ExtraFilesMap& extra_files, + c10::optional device) { + populate_upgraders_graph_map(); + ExtraFilesMap jit_files; + std::vector jit_constants; + mobile::Module mobilem = parse_and_initialize_mobile_module_for_jit( + data.get(), size, jit_files, jit_constants, device, &extra_files); + + Module m = jitModuleFromSourceAndConstants( + mobilem._ivalue(), + jit_files, + jit_constants, + static_cast(mobilem.bytecode_version())); + m.set_delete_memory(data); + return m; +} + +Module load_jit_module_from_file( + const std::string& filename, + ExtraFilesMap& extra_files, + c10::optional device) { + auto data = get_file_content(filename.c_str()); + return parse_and_initialize_jit_module( + std::move(std::get<0>(data)), std::get<1>(data), extra_files, device); +} + +Module load_jit_module_from_stream( + std::istream& in, + ExtraFilesMap& extra_files, + c10::optional device) { + auto data = get_stream_content(in); + return parse_and_initialize_jit_module( + std::move(std::get<0>(data)), std::get<1>(data), extra_files, device); +} + Module import_ir_module( std::shared_ptr cu, std::istream& in, @@ -444,18 +477,11 @@ Module _load_jit_module_from_bytes( std::shared_ptr cu, c10::optional device, ExtraFilesMap& extra_files) { - TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecorgnized data format"); + TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format"); auto format = getFileFormat(data.get()); switch (format) { case FileFormat::FlatbufferFileFormat: { - if (_load_jit_module_from_flatbuffer_bytes != nullptr) { - return _load_jit_module_from_flatbuffer_bytes( - data, size, extra_files, device); - } else { - TORCH_CHECK( - false, - "Flatbuffer input file but the build hasn't enable flatbuffer") - } + return parse_and_initialize_jit_module(data, size, extra_files, device); } case FileFormat::ZipFileFormat: { auto rai = std::make_unique(data.get(), size); diff --git a/torch/csrc/jit/serialization/import.h b/torch/csrc/jit/serialization/import.h index 581ad681a3d25..2b56914472b6c 100644 --- a/torch/csrc/jit/serialization/import.h +++ b/torch/csrc/jit/serialization/import.h @@ -110,19 +110,27 @@ TORCH_API Module jitModuleFromSourceAndConstants( const std::vector& constants, int32_t version); -extern Module (*_load_jit_module_from_flatbuffer_bytes)( - // comp unit - std::shared_ptr, - size_t, - ExtraFilesMap&, - c10::optional); - -extern Module (*_load_jit_module_from_flatbuffer_bytes)( - // comp unit - std::shared_ptr, - size_t, - ExtraFilesMap&, - c10::optional); +TORCH_API Module parse_and_initialize_jit_module( + std::shared_ptr data, + size_t size, + ExtraFilesMap& extra_files, + c10::optional device = c10::nullopt); + +TORCH_API Module load_jit_module_from_file( + const std::string& filename, + ExtraFilesMap& extra_files, + c10::optional device = c10::nullopt); + +TORCH_API Module load_jit_module_from_stream( + std::istream& in, + ExtraFilesMap& extra_files, + c10::optional device = c10::nullopt); + +TORCH_API Module parse_and_initialize_jit_module( + std::shared_ptr data, + size_t size, + ExtraFilesMap& extra_files, + c10::optional device); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 22efbf1b47607..364d603b4c43c 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -480,6 +480,20 @@ void Pickler::pushLiteralTensor(const IValue& ivalue) { // Construct the collections.OrderedDict for the backward_hooks push(PickleOpCode::REDUCE); + if (!quantized) { + // Only push it for regular tensor if the dictionary is not empty. + auto metadata = torch::jit::getTensorMetadata(tensor); + if (!metadata.empty()) { + // IValues based on std::unordered_map are slow and deprecated. + // Thus, pass a c10::Dict to pushDict. + c10::Dict math_bits_; + for (const auto& pair : metadata) { + math_bits_.insert(pair.first, pair.second); + } + pushDict(math_bits_); + } + } + push(PickleOpCode::TUPLE); // Call torch._utils._rebuild_tensor_v2 diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 399d7c232de13..26f9fcf423965 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -5,11 +5,11 @@ #include #include +#include #include #include #include #include -#include namespace torch { namespace jit { @@ -118,7 +118,7 @@ void setTypeTags(bool state); bool getTypeTags(); class TORCH_API Pickler { - TH_DISALLOW_COPY_AND_ASSIGN(Pickler); + AT_DISALLOW_COPY_AND_ASSIGN(Pickler); public: Pickler(std::function writer) @@ -296,5 +296,60 @@ uint64_t getStorageKey(const at::Tensor& tensor); // otherwise return false bool checkHasValidSetGetState(const std::shared_ptr& cls); +// Return a map of Tensor Metadata for serialization. +// For now, it only takes care of `conj` and `neg` bit. +inline std::unordered_map getTensorMetadata( + const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); + std::unordered_map metadata{}; + + // Only add meta-data if the value is not default. + if (t.is_conj()) { + metadata["conj"] = true; + } + if (t.is_neg()) { + metadata["neg"] = true; + } + return metadata; +} + +// set Tensor Metadata based on the map. +// Refer: getTensorMathdata +inline void setTensorMetadata( + const at::Tensor& t, + std::unordered_map metadata) { + for (auto& key_value_pair : metadata) { + if (key_value_pair.first == "conj") { + t._set_conj(true); + } else if (key_value_pair.first == "neg") { + t._set_neg(true); + } else { + TORCH_CHECK( + false, + "Unexpected key `", + key_value_pair.first, + "` passed to setTensorMetadata."); + } + } +} + +// set Tensor metadata based on the map. +// NOTE: This overload is required by unpickler.cpp +inline void setTensorMetadata( + const at::Tensor& t, + c10::Dict metadata_idict) { + std::unordered_map metadata; + for (auto& pair : metadata_idict) { + auto key = *pair.key().toString(); + metadata[key] = pair.value().toBool(); + } + setTensorMetadata(t, metadata); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 7b40f138c600f..4bbf7a783a232 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -532,6 +532,21 @@ PickleOpCode Unpickler::readInstruction() { } stack_.emplace_back(std::move(tensor)); } break; + case PickleOpCode::SETITEM: { + // At this OpCode, stack looks like + // | Stack Bottom | + // | ...... | + // | Dict | -> (stack_size - 3) + // | Key | -> (stack_size - 2) + // | Value | -> (stack_size - 1) + auto stack_size = stack_.size(); + auto dict_pos = stack_size - 3; + auto key_pos = stack_size - 2; + auto val_pos = stack_size - 1; + auto dict = stack_.at(dict_pos).toGenericDict(); + dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos)); + stack_.erase(stack_.begin() + (key_pos), stack_.end()); + } break; default: { AT_ERROR( "Unknown opcode for unpickling at ", @@ -546,6 +561,23 @@ PickleOpCode Unpickler::readInstruction() { void Unpickler::readGlobal( const std::string& module_name, const std::string& class_name) { + if (this->skip_next_read_global) { + // See [NOTE] skip_next_read_global + this->skip_next_read_global--; + if (this->skip_next_read_global == 1) { + // Pass through to the correct handler + } else if (this->skip_next_read_global == 0) { + // Corresponds to the type of `Tensor` being unpickled + if (module_name != "torch" || class_name != "Tensor") { + TORCH_WARN( + "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++"); + } + stack_.emplace_back(int64_t(globals_.size() - 1)); + return; + } else { + TORCH_CHECK(false, "INVALID VALUES") + } + } // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this // is only here for bc-compatibility reasons if (module_name == "__main__") { @@ -631,6 +663,12 @@ void Unpickler::readGlobal( // Unpickle a tensor bool quantized = class_name == "_rebuild_qtensor"; rebuildTensor(quantized); + } else if ( + module_name == "torch._tensor" && + (class_name == "_rebuild_from_type_v2")) { + // Unpickle a Tensor with Python attributes or + // a Subclassed Tensor. + rebuildTensorFromTypeV2(); } else if ( module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { rebuildSparseTensor(); @@ -823,17 +861,65 @@ void Unpickler::rebuildTensor(bool quantized) { } else { result = at::empty({0}, storage_tensor.options()); } - bool requires_grad = elements.at(idx).toBool(); - // elements[idx++] is empty backwards hooks + bool requires_grad = elements.at(idx++).toBool(); + idx++; // backwards hooks is empty at::TensorImpl* impl = result.unsafeGetTensorImpl(); impl->set_storage_keep_dtype(storage_tensor.storage()); impl->set_storage_offset(storage_offset); impl->set_sizes_and_strides(size, stride); result = autograd::make_variable(result, requires_grad); + + // Handle if math_bits were pickled. + // See `args` of _reduce_ex_internal + // for a regular tensor (final else case). + // Tensors pickled before this patch didn't + // have this argument for storing MathBits, + // in that case, we do nothing. + // NOTE: `math_bits` is the 7th arg. + // NOTE: This is only meant for regular tensor and not quantized + // which also has 7 args serialized. + if (!quantized && elements.size() == 7) { + auto math_bits = elements.at(idx++).toGenericDict(); + torch::jit::setTensorMetadata(result, math_bits); + } + stack_.emplace_back(std::move(result)); }); } +void Unpickler::rebuildTensorFromTypeV2() { + // [NOTE] skip_next_read_global + // When rebuilding Tensor with Python Attr or Subclassed Tensor, + // we receive `(func, type(self), args, state)` on stack for + // `rebuildTensorFromTypeV2`. + // Thus next call to readGlobal corresponds to `func` which is + // the function to rebuild the base tensor. + // The call after `func` to readGlobal corresponds to `type` of the + // Tensor where we raise warning if the type is not `torch.Tensor`. + this->skip_next_read_global = 2; + auto curr_globals_idx = globals_.size(); + globals_.emplace_back([this, curr_globals_idx] { + // args is a tuple with following data + // (function to rebuild base tensor, type of tensor, + // arguments to construct base tensor, Python State (as dict)) + auto args = pop(stack_).toTuple(); + size_t tup_idx = 0; + const auto args_elems = args->elements(); + auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple(); + auto py_state = args_elems.at(tup_idx + 3).toGenericDict(); + if (py_state.size() > 0) { + TORCH_WARN( + "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded"); + } + // This calls the function to rebuild the + // base tensor. + // Eg. `rebuildTensor`, `rebuildSpareTensor`. + stack_.emplace_back(base_tensor_args); + globals_[curr_globals_idx + 1](); + stack_.emplace_back(pop(stack_)); + }); +} + #ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index c57aa2556d73c..de00e7eacff21 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -23,7 +23,7 @@ class DeserializationStorageContext; // deleted at some point, the Pickler doesn't produce it and it's only around to // support models saved before 1.1 class TORCH_API Unpickler { - TH_DISALLOW_COPY_AND_ASSIGN(Unpickler); + AT_DISALLOW_COPY_AND_ASSIGN(Unpickler); using TypeParserT = c10::TypePtr (*)(const std::string&); @@ -120,6 +120,7 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); + void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED void rebuildRRef(); @@ -176,6 +177,9 @@ class TORCH_API Unpickler { // See [type tag serialization] uint64_t version_; + + // See [NOTE] skip_next_read_global + uint8_t skip_next_read_global = 0; }; void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp index ef17c85002904..cfbac9b398f95 100644 --- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -1314,7 +1314,7 @@ void CudaCodeGen::CompileToNVRTC( &program, code.c_str(), nullptr, 0, nullptr, nullptr)); #if defined(USE_ROCM) - std::vector args = {"--std=c++14"}; + std::vector args = {"--std=c++17"}; #if ROCM_VERSION >= 40200 args.push_back("-hip-pch"); #endif @@ -1335,7 +1335,7 @@ void CudaCodeGen::CompileToNVRTC( std::to_string(major) + std::to_string(minor); // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const std::vector args = { - "--std=c++14", compute.c_str(), "-default-device"}; + "--std=c++17", compute.c_str(), "-default-device"}; #endif auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h index af146a62baa09..f095c79fbb5a1 100644 --- a/torch/csrc/jit/tensorexpr/half_support.h +++ b/torch/csrc/jit/tensorexpr/half_support.h @@ -77,12 +77,20 @@ class HalfRewriter : public IRMutator { // get the dtype of the `value()` before that is mutated. auto newType = v->value()->dtype(); ExprPtr new_val = v->value()->accept_mutator(this); + auto bufType = v->buf()->dtype(); if (isHalf(newType.scalar_type())) { new_val = alloc(newType, new_val); inserted_half_casts_.insert(new_val); } + // The scalar_type of value is not Half while the buf is Half + if (!isHalf(newType.scalar_type()) && isHalf(bufType.scalar_type())) { + new_val = alloc( + newType.cloneWithScalarType(bufType.scalar_type()), new_val); + inserted_half_casts_.insert(new_val); + } + v->set_value(new_val); return v; } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index c30ed316e48b1..eb108abfb0296 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -29,7 +29,7 @@ namespace tensorexpr { std::string buildErrorMessage(const std::string& s) { static const std::string generic_error_message = - "This error occured in the fuser. You can turn off the fuser with " + "This error occurred in the fuser. You can turn off the fuser with " "torch.jit.enable_fusion(False)."; if (s.empty()) { return generic_error_message; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 78521efc240ee..a889420f944ad 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -8,6 +8,7 @@ #include #include +// Note [llvm::SCEVPredicate non-virtual destructor] // llvm::SCEVPredicate has virtual function but non-virtual destructor // https://github.com/llvm/llvm-project/blob/c1a0a213378a458fbea1a5c77b315c7dce08fd05/llvm/include/llvm/Analysis/ScalarEvolution.h#L198 #pragma GCC diagnostic push @@ -15,15 +16,30 @@ #include #pragma GCC diagnostic pop +#include +#include #include #include #include #include #include +#include #include #include +#include + +// see Note [llvm::SCEVPredicate non-virtual destructor] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" +#include +#pragma GCC diagnostic pop + #include #include +#include +#include +#include +#include #if LLVM_VERSION_MAJOR >= 10 #include @@ -446,6 +462,9 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( irb_(getContext()), kernel_func_name_(std::move(kernel_func_name)), bufsExtAlloc_(ExternalAllocBufFinder::find(stmt)) { +#if LLVM_VERSION_MAJOR >= 15 + context_->setOpaquePointers(false); +#endif if (!triple) { triple = LLVMTargetTriple(); } @@ -691,7 +710,7 @@ void LLVMCodeGenImpl::visit(AddPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateAdd(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Add", v); + throw malformed_input("llvm_codegen: bad type in Add", v); } } @@ -709,7 +728,7 @@ void LLVMCodeGenImpl::visit(SubPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateSub(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Sub", v); + throw malformed_input("llvm_codegen: bad type in Sub", v); } } @@ -727,7 +746,7 @@ void LLVMCodeGenImpl::visit(MulPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateMul(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Mul", v); + throw malformed_input("llvm_codegen: bad type in Mul", v); } } @@ -745,7 +764,7 @@ void LLVMCodeGenImpl::visit(DivPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateSDiv(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Div", v); + throw malformed_input("llvm_codegen: bad type in Div", v); } } @@ -760,7 +779,7 @@ void LLVMCodeGenImpl::visit(AndPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateAnd(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in And", v); + throw malformed_input("llvm_codegen: bad type in And", v); } } @@ -775,7 +794,7 @@ void LLVMCodeGenImpl::visit(OrPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateOr(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Or", v); + throw malformed_input("llvm_codegen: bad type in Or", v); } } @@ -790,7 +809,7 @@ void LLVMCodeGenImpl::visit(XorPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateXor(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Xor", v); + throw malformed_input("llvm_codegen: bad type in Xor", v); } } @@ -805,7 +824,7 @@ void LLVMCodeGenImpl::visit(LshiftPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateShl(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Lshift", v); + throw malformed_input("llvm_codegen: bad type in Lshift", v); } } @@ -824,7 +843,7 @@ void LLVMCodeGenImpl::visit(RshiftPtr v) { value_ = irb_.CreateLShr(lhs, rhs); } } else { - throw malformed_input("llvm_codgen: bad type in Rshift", v); + throw malformed_input("llvm_codegen: bad type in Rshift", v); } } @@ -839,7 +858,7 @@ void LLVMCodeGenImpl::visit(ModPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateSRem(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Mod", v); + throw malformed_input("llvm_codegen: bad type in Mod", v); } } @@ -2443,6 +2462,54 @@ void LLVMCodeGenImpl::visit(CondPtr v) { irb_.SetInsertPoint(end_block); } +// "New" PassManager needed to replace TM.adjustPassManager +#if LLVM_VERSION_MAJOR >= 15 +void LLVMCodeGenImpl::optimize(llvm::Module& M) { + // Add internal analysis passes from the target machine. + auto& TM = jit_->getTargetMachine(); + + // Create the analysis managers. + llvm::LoopAnalysisManager LAM; + llvm::FunctionAnalysisManager FAM; + llvm::CGSCCAnalysisManager CGAM; + llvm::ModuleAnalysisManager MAM; + + // Create the new pass manager builder. + // Take a look at the PassBuilder constructor parameters for more + // customization, e.g. specifying a TargetMachine or various debugging + // options. + llvm::PassBuilder PB(&TM); + + TM.registerPassBuilderCallbacks(PB); + + // Register all the basic analyses with the managers. + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + llvm::ModulePassManager MPM = + PB.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3); + llvm::FunctionPassManager FPM = PB.buildFunctionSimplificationPipeline( + llvm::OptimizationLevel::O3, llvm::ThinOrFullLTOPhase::None); + + FAM.registerPass([&] { return TM.getTargetIRAnalysis(); }); + + FPM.addPass(llvm::LoopVectorizePass()); + FPM.addPass(llvm::SLPVectorizerPass()); + + FPM.addPass(llvm::DCEPass()); + MPM.addPass(llvm::AlwaysInlinerPass()); + + MPM.run(M, MAM); + for (auto& FF : M) { + if (!FF.empty()) { + FPM.run(FF, FAM); + } + } +} +#else // "Old" PassManager void LLVMCodeGenImpl::optimize(llvm::Module& M) { llvm::legacy::FunctionPassManager FPM(&M); llvm::legacy::PassManager PM; @@ -2469,6 +2536,7 @@ void LLVMCodeGenImpl::optimize(llvm::Module& M) { } FPM.doFinalization(); } +#endif RegisterCodeGen llvm_codegen_reg("llvm_codegen"); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 6b66d48fe505e..a9cab316aa3e4 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1506,12 +1506,12 @@ void LoopNest::sliceHead(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { } if (!f) { - throw malformed_input("sliceHead attempted on null loop", f); + throw malformed_input("sliceHead attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { - throw malformed_input("sliceHead attempted on loop with no parent", p); + throw malformed_input("sliceHead attempted on loop with no parent"); } ExprPtr head_end = alloc( @@ -1546,12 +1546,12 @@ void LoopNest::sliceTail(ForPtr f, int factor, ForPtr* head, ForPtr* tail) { } if (!f) { - throw malformed_input("sliceTail attempted on null loop", f); + throw malformed_input("sliceTail attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { - throw malformed_input("sliceTail attempted on loop with no parent", p); + throw malformed_input("sliceTail attempted on loop with no parent"); } ExprPtr tail_start = alloc( @@ -1585,12 +1585,12 @@ void LoopNest::splitWithTail( ForPtr* inner, ForPtr* tail) { if (!f) { - throw malformed_input("splitWithTail attempted on null loop", f); + throw malformed_input("splitWithTail attempted on null loop"); } BlockPtr p = to(f->get_parent()); if (!p) { - throw malformed_input("splitWithTail attempted on loop with no parent", p); + throw malformed_input("splitWithTail attempted on loop with no parent"); } // Normalize the loop to simplify start and stop bound computation diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index d2894ea157e6e..37a204bfa3294 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -754,13 +754,13 @@ class TORCH_API For : public StmtNode { stop_(stop), loop_options_(std::move(loop_options)) { if (!var) { - throw malformed_input("invalid Var in For loop", var); + throw malformed_input("invalid Var in For loop"); } else if (!start) { - throw malformed_input("invalid Start in For loop", start); + throw malformed_input("invalid Start in For loop"); } else if (!stop) { - throw malformed_input("invalid Stop in For loop", stop); + throw malformed_input("invalid Stop in For loop"); } else if (!body || body->get_parent()) { - throw malformed_input("invalid Body in For loop", body); + throw malformed_input("invalid Body in For loop"); } BlockPtr b = to(body); diff --git a/torch/csrc/lazy/backend/backend_interface.cpp b/torch/csrc/lazy/backend/backend_interface.cpp index cbcd92b6a9924..0fb3257c90a91 100644 --- a/torch/csrc/lazy/backend/backend_interface.cpp +++ b/torch/csrc/lazy/backend/backend_interface.cpp @@ -18,11 +18,6 @@ const BackendImplInterface* getBackend() { return interface; } -// default implementation -bool BackendImplInterface::ShouldSyncTensor(const LazyTensorPtr tensor) const { - return tensor->GetIrValue()->op() != ltc_not_supported; -} - BackendRegistrar::BackendRegistrar( const BackendImplInterface* backend_impl_interface) { backend_impl_registry.store(backend_impl_interface); @@ -43,7 +38,7 @@ at::Tensor MakeTensorFromComputationData( std::unique_ptr LoweringContext::Create( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) { return getBackend()->CreateLoweringContext( name, device, post_order, emit_status); diff --git a/torch/csrc/lazy/backend/backend_interface.h b/torch/csrc/lazy/backend/backend_interface.h index 2936105dc6a3d..f94d3b602e52c 100644 --- a/torch/csrc/lazy/backend/backend_interface.h +++ b/torch/csrc/lazy/backend/backend_interface.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -41,8 +42,6 @@ class TORCH_API BackendImplInterface { virtual const IrBuilder* GetIrBuilder() const = 0; - virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const; - /** * Data Transfer * */ @@ -60,7 +59,7 @@ class TORCH_API BackendImplInterface { // Gets backend data if the node is a device data node. Otherwise returns // nullptr - virtual BackendDataPtr GetComputationDataFromNode(Node*) const = 0; + virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0; virtual at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, @@ -73,7 +72,7 @@ class TORCH_API BackendImplInterface { virtual std::unique_ptr CreateLoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) const = 0; virtual std::unique_ptr CreateLoweringContext( diff --git a/torch/csrc/lazy/backend/lowering_context.cpp b/torch/csrc/lazy/backend/lowering_context.cpp index 64922a1b3e136..635ee4891cc7f 100644 --- a/torch/csrc/lazy/backend/lowering_context.cpp +++ b/torch/csrc/lazy/backend/lowering_context.cpp @@ -9,7 +9,7 @@ LoweringContext::LoweringContext(const std::string& name, BackendDevice device) LoweringContext::LoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) : device_(std::move(device)), emit_status_(std::move(emit_status)) {} diff --git a/torch/csrc/lazy/backend/lowering_context.h b/torch/csrc/lazy/backend/lowering_context.h index 6f487aef7f741..49e7b8be58cbf 100644 --- a/torch/csrc/lazy/backend/lowering_context.h +++ b/torch/csrc/lazy/backend/lowering_context.h @@ -42,7 +42,7 @@ class TORCH_API LoweringContext { LoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status); virtual ~LoweringContext() = default; @@ -50,7 +50,7 @@ class TORCH_API LoweringContext { static std::unique_ptr Create( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status); static std::unique_ptr Create( diff --git a/torch/csrc/lazy/core/config.cpp b/torch/csrc/lazy/core/config.cpp index d87036767be59..c39fd8fef75a4 100644 --- a/torch/csrc/lazy/core/config.cpp +++ b/torch/csrc/lazy/core/config.cpp @@ -10,7 +10,7 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_lazy_handle_special_scalars, false, - "Handle special scalars 0 and 1 diffrently"); + "Handle special scalars 0 and 1 differently"); C10_DEFINE_bool( torch_lazy_all_numbers_special_scalars, diff --git a/torch/csrc/lazy/core/debug_util.cpp b/torch/csrc/lazy/core/debug_util.cpp index 50f42b718128e..50077d498a751 100644 --- a/torch/csrc/lazy/core/debug_util.cpp +++ b/torch/csrc/lazy/core/debug_util.cpp @@ -88,7 +88,7 @@ std::string DebugUtil::GetTensorsGraphInfo( c10::ArrayRef tensors, const std::vector* indices, GraphFormat format) { - std::vector root_nodes; + std::vector root_nodes; std::vector root_values; std::vector root_hashes; torch::lazy::Unique unique_device; diff --git a/torch/csrc/lazy/core/internal_ops/ltc_ops.h b/torch/csrc/lazy/core/internal_ops/ltc_ops.h index 3f195d8b445cf..ce62f2e51f539 100644 --- a/torch/csrc/lazy/core/internal_ops/ltc_ops.h +++ b/torch/csrc/lazy/core/internal_ops/ltc_ops.h @@ -48,13 +48,5 @@ const OpKindWrapper ltc_replication_pad_backward( "lazy_tensors::replication_pad_backward"); const OpKindWrapper ltc_tensor_data("lazy_tensors::tensor_data"); -// For view ops -const OpKindWrapper ltc_as_strided_view_update( - "lazy_tensors::as_strided_view_update"); -const OpKindWrapper ltc_diagonal_view_update( - "lazy_tensors::diagonal_view_update"); -const OpKindWrapper ltc_narrow_view_update("lazy_tensors::narrow_view_update"); -const OpKindWrapper ltc_select_view_update("lazy_tensors::select_view_update"); - } // namespace lazy } // namespace torch diff --git a/torch/csrc/lazy/core/ir_builder.h b/torch/csrc/lazy/core/ir_builder.h index 8e645c485158e..9cc974236cd8f 100644 --- a/torch/csrc/lazy/core/ir_builder.h +++ b/torch/csrc/lazy/core/ir_builder.h @@ -58,9 +58,6 @@ struct IrBuilder { const Value& input0, const std::vector& size, const bool& is_scalar_expand) const = 0; - virtual NodePtr MakeView( - const Value& input0, - const std::vector& output_size) const = 0; virtual NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, @@ -73,59 +70,6 @@ struct IrBuilder { const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const = 0; - // View op nodes - virtual NodePtr MakeAsStridedViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) const = 0; - virtual NodePtr MakeAsStrided( - const Value& input0, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) const = 0; - virtual NodePtr MakeDiagonalViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) const = 0; - virtual NodePtr MakeDiagonal( - const Value& input0, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) const = 0; - virtual NodePtr MakeNarrowViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& base_indices) const = 0; - virtual NodePtr MakeNarrow( - const Value& input0, - const std::vector& base_indices, - const std::vector& sizes) const = 0; - virtual NodePtr MakePermute( - const Value& input0, - const std::vector& dims) const = 0; - virtual NodePtr MakeResize( - const Value& input0, - const std::vector& size) const = 0; - virtual NodePtr MakeSelectViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) const = 0; - virtual NodePtr MakeSelect( - const Value& input0, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) const = 0; - virtual NodePtr MakeSqueeze(const Value& input0, const int& dim) const = 0; - virtual NodePtr MakeUnsqueeze(const Value& input0, const int& dim) const = 0; - // dynamic ir nodes virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0; virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0; @@ -149,11 +93,6 @@ static inline NodePtr MakeExpand( const bool& is_scalar_expand) { return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand); } -static inline NodePtr MakeView( - const Value& input0, - const std::vector& output_size) { - return getIrBuilder()->MakeView(input0, output_size); -} static inline NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, @@ -173,86 +112,6 @@ static inline NodePtr MakeGeneric( op, operands, shape, num_outputs, hash_seed); } -// View op nodes -static inline NodePtr MakeAsStridedViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) { - return getIrBuilder()->MakeAsStridedViewUpdate( - input0, input1, size, stride, storage_offset); -} -static inline NodePtr MakeAsStrided( - const Value& input0, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) { - return getIrBuilder()->MakeAsStrided(input0, size, stride, storage_offset); -} -static inline NodePtr MakeDiagonalViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) { - return getIrBuilder()->MakeDiagonalViewUpdate( - input0, input1, offset, dim1, dim2); -} -static inline NodePtr MakeDiagonal( - const Value& input0, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) { - return getIrBuilder()->MakeDiagonal(input0, offset, dim1, dim2); -} -static inline NodePtr MakeNarrowViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& base_indices) { - return getIrBuilder()->MakeNarrowViewUpdate(input0, input1, base_indices); -} -static inline NodePtr MakeNarrow( - const Value& input0, - const std::vector& base_indices, - const std::vector& sizes) { - return getIrBuilder()->MakeNarrow(input0, base_indices, sizes); -} -static inline NodePtr MakePermute( - const Value& input0, - const std::vector& dims) { - return getIrBuilder()->MakePermute(input0, dims); -} -static inline NodePtr MakeResize( - const Value& input0, - const std::vector& size) { - return getIrBuilder()->MakeResize(input0, size); -} -static inline NodePtr MakeSelectViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) { - return getIrBuilder()->MakeSelectViewUpdate( - input0, input1, dim, start, end, stride); -} -static inline NodePtr MakeSelect( - const Value& input0, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) { - return getIrBuilder()->MakeSelect(input0, dim, start, end, stride); -} -static inline NodePtr MakeSqueeze(const Value& input0, const int& dim) { - return getIrBuilder()->MakeSqueeze(input0, dim); -} -static inline NodePtr MakeUnsqueeze(const Value& input0, const int& dim) { - return getIrBuilder()->MakeUnsqueeze(input0, dim); -} - // dynamic ir nodes static inline NodePtr MakeSizeNode(const Value& input, size_t dim) { return getIrBuilder()->MakeSizeNode(input, dim); @@ -269,10 +128,10 @@ static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) { inline Value GetSymIntValue(c10::SymInt a) { return Value( - a.is_symbolic() ? dynamic_cast( - a.toSymIntNodeImpl().get()) - ->node_ - : MakeScalar(a.as_int_unchecked(), at::kLong), + a.is_symbolic() + ? dynamic_cast(a.toSymNodeImpl().get()) + ->node_ + : MakeScalar(a.as_int_unchecked(), at::kLong), 0); } diff --git a/torch/csrc/lazy/core/ir_dump_util.cpp b/torch/csrc/lazy/core/ir_dump_util.cpp index eff2873d668d7..19cb2ae7b1624 100644 --- a/torch/csrc/lazy/core/ir_dump_util.cpp +++ b/torch/csrc/lazy/core/ir_dump_util.cpp @@ -80,7 +80,7 @@ c10::optional ParseAttrTag( return tag; } -NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { +NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { NodeIdMap id_map; for (auto node : post_order) { TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString()); @@ -89,7 +89,7 @@ NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { } std::unordered_map GetRootsIds( - c10::ArrayRef roots) { + c10::ArrayRef roots) { std::unordered_map roots_ids; for (const auto i : c10::irange(roots.size())) { roots_ids[roots[i]] = i; @@ -178,14 +178,14 @@ std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) { } // namespace -std::string DumpUtil::ToDot(c10::ArrayRef nodes) { +std::string DumpUtil::ToDot(c10::ArrayRef nodes) { auto post_order = Util::ComputePostOrder(nodes); return PostOrderToDot(post_order, nodes); } std::string DumpUtil::PostOrderToDot( - c10::ArrayRef post_order, - c10::ArrayRef roots) { + c10::ArrayRef post_order, + c10::ArrayRef roots) { std::unordered_map roots_ids = GetRootsIds(roots); NodeIdMap id_map = GenerateIdMap(post_order); std::stringstream ss; @@ -218,14 +218,14 @@ std::string DumpUtil::PostOrderToDot( return ss.str(); } -std::string DumpUtil::ToText(c10::ArrayRef nodes) { +std::string DumpUtil::ToText(c10::ArrayRef nodes) { auto post_order = Util::ComputePostOrder(nodes); return PostOrderToText(post_order, nodes); } std::string DumpUtil::PostOrderToText( - c10::ArrayRef post_order, - c10::ArrayRef roots) { + c10::ArrayRef post_order, + c10::ArrayRef roots) { std::unordered_map roots_ids = GetRootsIds(roots); NodeIdMap id_map = GenerateIdMap(post_order); std::stringstream ss; diff --git a/torch/csrc/lazy/core/ir_dump_util.h b/torch/csrc/lazy/core/ir_dump_util.h index 22cf139bfbd64..4b4e1e0749b24 100644 --- a/torch/csrc/lazy/core/ir_dump_util.h +++ b/torch/csrc/lazy/core/ir_dump_util.h @@ -11,17 +11,17 @@ class BackendDevice; class TORCH_API DumpUtil { public: - static std::string ToDot(c10::ArrayRef nodes); + static std::string ToDot(c10::ArrayRef nodes); static std::string PostOrderToDot( - c10::ArrayRef post_order, - c10::ArrayRef roots); + c10::ArrayRef post_order, + c10::ArrayRef roots); - static std::string ToText(c10::ArrayRef nodes); + static std::string ToText(c10::ArrayRef nodes); static std::string PostOrderToText( - c10::ArrayRef post_order, - c10::ArrayRef roots); + c10::ArrayRef post_order, + c10::ArrayRef roots); static std::string ToBackend( c10::ArrayRef values, diff --git a/torch/csrc/lazy/core/ir_util.cpp b/torch/csrc/lazy/core/ir_util.cpp index 2d463bb99d5f5..b2a2a8ecfa20a 100644 --- a/torch/csrc/lazy/core/ir_util.cpp +++ b/torch/csrc/lazy/core/ir_util.cpp @@ -5,13 +5,12 @@ namespace torch { namespace lazy { -std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { - std::vector post_order; - std::vector queue; - // std::vector to c10::ArrayRef conversion is not supported, - // so we need to drop const in the return vector and use const_cast here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - queue.push_back(const_cast(node)); +std::vector Util::ComputePostOrder( + const Node* node, + EmissionMap* emap) { + std::vector post_order; + std::vector queue; + queue.push_back(node); while (!queue.empty()) { node = queue.back(); auto it = emap->find(node); @@ -20,8 +19,7 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { for (auto& output : node->operands()) { auto oit = emap->find(output.node); if (oit == emap->end()) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - queue.push_back(const_cast(output.node)); + queue.push_back(output.node); } else { TORCH_CHECK( oit->second != kEmitting, @@ -38,8 +36,7 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { output.node->ToString()); } (*emap)[node] = kEmitted; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - post_order.push_back(const_cast(node)); + post_order.push_back(node); queue.pop_back(); } else { TORCH_CHECK(it->second == kEmitted); @@ -49,10 +46,10 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { return post_order; } -std::vector Util::ComputePostOrder( - c10::ArrayRef nodes, +std::vector Util::ComputePostOrder( + c10::ArrayRef nodes, EmissionMap* emap) { - std::vector post_order; + std::vector post_order; for (auto node : nodes) { auto node_post_order = ComputePostOrder(node, emap); post_order.insert( @@ -61,12 +58,13 @@ std::vector Util::ComputePostOrder( return post_order; } -std::vector Util::ComputePostOrder(c10::ArrayRef nodes) { +std::vector Util::ComputePostOrder( + c10::ArrayRef nodes) { EmissionMap emap; return ComputePostOrder(nodes, &emap); } -size_t Util::GetGraphSize(c10::ArrayRef nodes) { +size_t Util::GetGraphSize(c10::ArrayRef nodes) { return ComputePostOrder(nodes).size(); } diff --git a/torch/csrc/lazy/core/ir_util.h b/torch/csrc/lazy/core/ir_util.h index a95b1a523bfa9..df3d0fd7ac406 100644 --- a/torch/csrc/lazy/core/ir_util.h +++ b/torch/csrc/lazy/core/ir_util.h @@ -25,21 +25,22 @@ class TORCH_API Util { // this API. The returned post-order can be empty if the node has already been // emitted inside the emission map. An error is generated if a loop is // detected. - static std::vector ComputePostOrder( + static std::vector ComputePostOrder( const Node* node, EmissionMap* emap); - static std::vector ComputePostOrder( - c10::ArrayRef nodes, + static std::vector ComputePostOrder( + c10::ArrayRef nodes, EmissionMap* emap); // Same as above, but computes the post order on the set of nodes specified as // argument. - static std::vector ComputePostOrder(c10::ArrayRef nodes); + static std::vector ComputePostOrder( + c10::ArrayRef nodes); // Retrieves the number of nodes within the graph whose sink are passed in the // nodes argument. - static size_t GetGraphSize(c10::ArrayRef nodes); + static size_t GetGraphSize(c10::ArrayRef nodes); }; } // namespace lazy diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index 96476e4a9663b..acab845c1f346 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -390,10 +390,15 @@ bool TensorsHaveIR(const std::vector& tensors) { return false; } +std::atomic lazy_graph_executor_registry; } // namespace +void LazyGraphExecutor::Register(LazyGraphExecutor* executor) { + lazy_graph_executor_registry.store(executor); +} LazyGraphExecutor* LazyGraphExecutor::Get() { - static LazyGraphExecutor* executor = new LazyGraphExecutor(); + auto* executor = lazy_graph_executor_registry.load(); + TORCH_CHECK(executor, "Lazy graph executor not registered."); return executor; } @@ -604,6 +609,10 @@ void LazyGraphExecutor::Async::Wait() { } } +bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr tensor) const { + return tensor->GetIrValue()->op() != ltc_not_supported; +} + LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config) { @@ -635,7 +644,7 @@ LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors( tensors[i]->CurrentDataHandle() == nullptr) { Value ir_value = tensors[i]->CurrentIrValue(); if (ir_value) { - if (getBackend()->ShouldSyncTensor(tensors[i])) { + if (ShouldSyncTensor(tensors[i])) { // Add only tensors which need to be synced. coll.hash = HashCombine(coll.hash, ir_value.hash()); coll.indices.push_back(i); @@ -679,13 +688,37 @@ std::vector LazyGraphExecutor::CollectRoots( return roots; } -std::vector LazyGraphExecutor::FetchTensorData( +void LazyGraphExecutor::ExtractIRAndPrepareTensorData( std::vector* tensors, const SyncTensorsConfig& config, - c10::ArrayRef indices) { + c10::ArrayRef indices, + std::vector& ir_values, + std::vector& tensor_data_vec) { + ir_values.reserve(indices.size()); + tensor_data_vec.reserve(indices.size()); + for (auto index : indices) { + LazyTensorPtr& tensor = (*tensors)[index]; + Value ir_value = tensor->CurrentIrValue(); + ir_values.push_back(ir_value); + const BackendDevice& tensor_device = tensor->GetDevice(); + BackendDataPtr handle = getBackend()->CreateDataPlaceholder( + tensor_device, std::move(tensor->shape())); + tensor_data_vec.push_back(handle); + if (tensor->CurrentDataHandle() == nullptr && config.sync_ltc_data) { + tensor->AssignIrValue(Value()); + } + } +} + +std::vector LazyGraphExecutor::SetTensorData( + std::vector* tensors, + const SyncTensorsConfig& config, + c10::ArrayRef indices, + const std::vector& tensor_data_vec) { std::vector tensors_data; tensors_data.reserve(indices.size()); - for (auto index : indices) { + for (int i = 0; i < indices.size(); i++) { + auto index = indices[i]; LazyTensorPtr& tensor = (*tensors)[index]; // If the config.force_ltc_data flag is true, the purpose of this tensor // sync operation is to truncate the IR graph and materialize device data in @@ -698,11 +731,12 @@ std::vector LazyGraphExecutor::FetchTensorData( // completes. BackendDataPtr handle = tensor->CurrentDataHandle(); if (handle == nullptr && config.force_ltc_data) { - const BackendDevice& tensor_device = tensor->GetDevice(); - handle = getBackend()->CreateDataPlaceholder( - tensor_device, std::move(tensor->shape())); - - tensor->SetDataHandle(handle, config.sync_ltc_data); + handle = tensor_data_vec[i]; + // Note: We are not using SetHandleData method here since that method + // resets the ir_value. We have already done the resetting as part + // of ExtractIRAndPrepareTensorData to overlap with previous execution. + tensor->data()->handle = handle; + tensor->data()->tensor_data = c10::nullopt; } tensors_data.emplace_back(std::move(handle)); } @@ -710,12 +744,11 @@ std::vector LazyGraphExecutor::FetchTensorData( } LazyGraphExecutor::PostOrderData LazyGraphExecutor::RunPostOrder( - const std::vector& tensors, + const std::vector& ir_values, SyncTensorCollection* coll) { - std::vector roots; - roots.reserve(coll->indices.size()); - for (auto index : coll->indices) { - Value ir_value = tensors.at(index)->CurrentIrValue(); + std::vector roots; + roots.reserve(ir_values.size()); + for (auto ir_value : ir_values) { roots.push_back(ir_value.node.get()); } PostOrderData po_data; @@ -746,7 +779,8 @@ LazyGraphExecutor::PostOrderData LazyGraphExecutor::RunPostOrder( std::shared_ptr LazyGraphExecutor::TryRunCachedSync( std::vector* tensors, SyncTensorCollection* coll, - PostOrderData* po_data) { + PostOrderData* po_data, + const std::vector& tensor_data_vec) { ComputationCache::TypePtr cached_computation = LookupCachedCompile(coll->hash); if (cached_computation == nullptr) { @@ -763,50 +797,24 @@ std::shared_ptr LazyGraphExecutor::TryRunCachedSync( tensors, coll, std::move(po_data->parameters_data), - std::move(cached_computation)); + std::move(cached_computation), + tensor_data_vec); } LazyGraphExecutor::CompilationResult LazyGraphExecutor::Compile( const std::vector& tensors, c10::ArrayRef devices, const SyncTensorCollection& coll, - PostOrderData* po_data) { + PostOrderData* po_data, + const std::vector& ir_values) { auto lowering_ctx = LoweringContext::Create( "SyncTensorsGraph", coll.device, po_data->post_order, std::move(po_data->emission_map)); - for (auto index : coll.indices) { - Value ir_value = tensors[index]->CurrentIrValue(); + for (auto ir_value : ir_values) { lowering_ctx->AddResult(ir_value); } - if (FLAGS_torch_lazy_param_aliasing && coll.config.sync_ltc_data) { - // We can only alias at the step barrier, when force_ltc_data is true. - // Consider the case: - // 1. Tensor A(DEVICE_DATA) - // 2. Tensor B = A + 0.9 - // 3. A += 0.4 - // If we activate aliasing for A's graph, and we do: - // print(A) - // print(A) - // The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the - // second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4, which - // will lead to incorrect results. - // We cannot normally turn A's state into DEVICE_DATA, as if any of the - // sources is a view, this will not lead to correct results (as A's value - // taken at different times need to reflect view source changes): - // 1. Tensor A = some_graph_with_view_source(V) - // 2. print(A) - // 3. V += 1 - // 4. print(A) - // The second print should reflect the new value due to V's changes. - // Also in the first example, unless we are doing a step barrier and hence - // include all live tensors, if the B value is not part of the graph, it - // will later fetch the new value of A, which is incorrect. - // But, when we issue a step barrier (force_ltc_data == true) we have to - // turn everything into DEVICE_DATA, so we can activate aliasing. - BuildInputOutputAliases(tensors, coll.indices, lowering_ctx.get()); - } ComputationPtr computation = lowering_ctx->Build(); // If force_ltc_data is true it means that we did a proper sync and are @@ -857,40 +865,6 @@ LazyGraphExecutor::ComputationCache::TypePtr LazyGraphExecutor:: typedef SSIZE_T ssize_t; #endif -void LazyGraphExecutor::BuildInputOutputAliases( - const std::vector& tensors, - c10::ArrayRef indices, - LoweringContext* lowering_ctx) { - std::unordered_map output_tensor_id_map; - for (const auto i : c10::irange(indices.size())) { - size_t tensor_index = indices[i]; - int64_t tensor_id = tensors[tensor_index]->GetUniqueId(); - output_tensor_id_map[tensor_id] = i; - } - const std::vector& parameters_data = - lowering_ctx->GetParametersData(); - std::vector alias_map(indices.size(), -1); - for (const auto i : c10::irange(parameters_data.size())) { - DeviceDataInfo* data_info = - dynamic_cast(parameters_data[i]->info()); - if (data_info != nullptr && !data_info->read_only) { - auto it = output_tensor_id_map.find(data_info->tensor_id); - if (it != output_tensor_id_map.end()) { - size_t output_index = it->second; - if (lowering_ctx->CheckResultShape(parameters_data[i], output_index) && - alias_map[output_index] < 0) { - lowering_ctx->SetUpAlias({static_cast(output_index)}, i, {}); - alias_map[output_index] = i; - - VLOG(6) << "Aliased parameter " << i << " with output " - << output_index << ": " << Shape(parameters_data[i]->shape()); - } - } - } - } - TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", alias_map.size()); -} - std::shared_ptr LazyGraphExecutor:: SyncTensorsGraphInternal( std::vector* tensors, @@ -903,17 +877,23 @@ std::shared_ptr LazyGraphExecutor:: TensorCollectionBarrier(&coll); return nullptr; } - PostOrderData po_data = RunPostOrder(*tensors, &coll); DebugUtil::SaveTensorsGraphInfo( "ScheduleSyncTensorsGraph", *tensors, &coll.indices); + std::vector ir_values; + std::vector tensor_data_vec; + ExtractIRAndPrepareTensorData( + tensors, coll.config, coll.indices, ir_values, tensor_data_vec); + PostOrderData po_data = RunPostOrder(ir_values, &coll); coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence)); VLOG(4) << "Parameter sequence graph hash " << HashToString(coll.hash); - std::shared_ptr async = TryRunCachedSync(tensors, &coll, &po_data); + std::shared_ptr async = + TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec); if (async != nullptr) { return async; } - CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data); + CompilationResult compile_result = + Compile(*tensors, devices, coll, &po_data, ir_values); if (GRAPH_DUMP_ENABLED) { auto* comp = compile_result.computation.get(); LOG(ERROR) << "Add a cached computation with hash " << coll.hash @@ -932,7 +912,8 @@ std::shared_ptr LazyGraphExecutor:: tensors, &coll, std::move(compile_result.parameters_data), - std::move(cached_computation)); + std::move(cached_computation), + tensor_data_vec); } std::shared_ptr LazyGraphExecutor:: @@ -948,12 +929,7 @@ std::shared_ptr LazyGraphExecutor:: std::move(tensors_data), std::move(cached_computation)); - auto syncfn = [this, async, hash = coll->hash]() { - // For profiling lazy trace overhead - if (noop_execution_mode_) { - return; - } - + auto syncfn = [async, hash = coll->hash]() { try { VLOG(3) << "Executing IR graph hash " << HashToString(hash) << " on device " << async->device << " ..."; @@ -988,10 +964,8 @@ std::shared_ptr LazyGraphExecutor:: // even in case the caller does not wait, and that is accomplished by // setting the unlockers status. In that case the exception will be // surfaced when the user tries to acquire the device locks the next time. - // std::exception_ptr exptr = std::current_exception(); for (auto& unlocker : async->unlocker) { - std::exception_ptr exptr = std::current_exception(); - unlocker.SetStatus(std::move(exptr)); + unlocker.SetStatus(std::current_exception()); } throw; } @@ -1010,8 +984,10 @@ std::shared_ptr LazyGraphExecutor:: std::vector* tensors, SyncTensorCollection* coll, std::vector parameters_data, - ComputationCache::TypePtr cached_computation) { - auto tensors_data = FetchTensorData(tensors, coll->config, coll->indices); + ComputationCache::TypePtr cached_computation, + const std::vector& tensor_data_vec) { + auto tensors_data = + SetTensorData(tensors, coll->config, coll->indices, tensor_data_vec); return ScheduleSyncTensorsGraph( coll, std::move(parameters_data), @@ -1131,7 +1107,12 @@ hash_t LazyGraphExecutor::GetGraphHash( config.sync_ltc_data = false; auto coll = CollectSyncTensors(tensors, config); - auto po_data = RunPostOrder(tensors, &coll); + std::vector ir_values; + for (auto index : coll.indices) { + Value ir_value = tensors[index]->CurrentIrValue(); + ir_values.push_back(ir_value); + } + auto po_data = RunPostOrder(ir_values, &coll); coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence)); return coll.hash; } diff --git a/torch/csrc/lazy/core/lazy_graph_executor.h b/torch/csrc/lazy/core/lazy_graph_executor.h index 8116ad23ff068..10b41b64a6174 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.h +++ b/torch/csrc/lazy/core/lazy_graph_executor.h @@ -21,10 +21,18 @@ class TORCH_API LazyGraphExecutor { bool read_only = false; }; + // Register a lazy graph executor instance that can be retrieved using Get() + static void Register(LazyGraphExecutor*); static LazyGraphExecutor* Get(); - void RegisterTensor(std::shared_ptr data); - void UnregisterTensor(LazyTensor::Data* data); + virtual ~LazyGraphExecutor() = default; + + // Override these methods to perform custom tensor registration and + // unregistration Note: It is vital that the parent implementations are also + // called + // in order for the tensors to show up in the live tensor list + virtual void RegisterTensor(std::shared_ptr data); + virtual void UnregisterTensor(LazyTensor::Data* data); // Seed for random generator Value GetRngSeed(const BackendDevice& device); @@ -110,12 +118,6 @@ class TORCH_API LazyGraphExecutor { const Shape& shape, const BackendDevice& device); - // Configure the executor treat compile/execute API calls as no-ops - // for use when profiling lazy trace overheads - void SetNoOpExecutionMode(bool enable_noop) { - noop_execution_mode_ = enable_noop; - } - struct CachedComputation { explicit CachedComputation(ComputationPtr computation) : computation(std::move(computation)) {} @@ -129,7 +131,10 @@ class TORCH_API LazyGraphExecutor { hash_t GetGraphHash(const std::vector& tensors); - private: + protected: + // TODO(alanwaketan): Revisit if all of them need to be accessible to + // derived classes. + struct SyncTensorsConfig { // Whether we want to force data on the target tensors (hence trimming // the IR graph above them). @@ -150,12 +155,13 @@ class TORCH_API LazyGraphExecutor { }; struct PostOrderData { - std::vector post_order; + std::vector post_order; Util::EmissionMap emission_map; std::vector parameters_data; std::vector parameter_sequence; }; + private: struct CompilationResult { BackendDevice device; size_t emitted_nodes = 0; @@ -181,6 +187,8 @@ class TORCH_API LazyGraphExecutor { std::vector tensors_data; }; + virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const; + SyncTensorCollection CollectSyncTensors( const std::vector& tensors, const SyncTensorsConfig& config); @@ -192,32 +200,38 @@ class TORCH_API LazyGraphExecutor { const std::vector& tensors, c10::ArrayRef indices); - std::vector FetchTensorData( + std::vector SetTensorData( std::vector* tensors, const SyncTensorsConfig& config, - c10::ArrayRef indices); + c10::ArrayRef indices, + const std::vector& tensor_data_vec); + + void ExtractIRAndPrepareTensorData( + std::vector* tensors, + const SyncTensorsConfig& config, + c10::ArrayRef indices, + std::vector& ir_values, + std::vector& tensor_data_vec); PostOrderData RunPostOrder( - const std::vector& tensors, + const std::vector& ir_values, SyncTensorCollection* coll); + std::shared_ptr TryRunCachedSync( std::vector* tensors, SyncTensorCollection* coll, - PostOrderData* po_data); + PostOrderData* po_data, + const std::vector& tensor_data_vec); CompilationResult Compile( const std::vector& tensors, c10::ArrayRef devices, const SyncTensorCollection& coll, - PostOrderData* po_data); + PostOrderData* po_data, + const std::vector& ir_values); ComputationCache::TypePtr LookupCachedCompile(const hash_t& hash); - void BuildInputOutputAliases( - const std::vector& tensors, - c10::ArrayRef indices, - LoweringContext* lowering_ctx); - std::shared_ptr SyncTensorsGraphInternal( std::vector* tensors, c10::ArrayRef devices, @@ -236,7 +250,8 @@ class TORCH_API LazyGraphExecutor { std::vector* tensors, SyncTensorCollection* coll, std::vector parameters_data, - ComputationCache::TypePtr cached_computation); + ComputationCache::TypePtr cached_computation, + const std::vector& tensor_data_vec); std::vector GetTensorsFused(std::vector* tensors); @@ -251,8 +266,6 @@ class TORCH_API LazyGraphExecutor { const std::vector& tensors, c10::ArrayRef indices, c10::ArrayRef tensors_data); - - bool noop_execution_mode_ = false; }; } // namespace lazy diff --git a/torch/csrc/lazy/core/lazy_view.cpp b/torch/csrc/lazy/core/lazy_view.cpp deleted file mode 100644 index d52c0f62fb77e..0000000000000 --- a/torch/csrc/lazy/core/lazy_view.cpp +++ /dev/null @@ -1,262 +0,0 @@ -#include - -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace torch { -namespace lazy { -namespace { - -Value ApplyViewInfo(Value ir_value, const ViewInfo& view_info) { - switch (view_info.view_type) { - case ViewInfo::Type::kSelect: - return MakeSelect( - ir_value, - view_info.select->dim, - view_info.select->start, - view_info.select->end, - view_info.select->stride); - case ViewInfo::Type::kNarrow: - return MakeNarrow( - ir_value, view_info.indices, view_info.shape.sizes().vec()); - case ViewInfo::Type::kNoOp: - return ir_value; - case ViewInfo::Type::kPermute: - return MakePermute(ir_value, view_info.permutation); - case ViewInfo::Type::kReshape: - return MakeView(ir_value, view_info.shape.sizes().vec()); - case ViewInfo::Type::kResize: - return MakeResize(ir_value, view_info.shape.sizes().vec()); - case ViewInfo::Type::kSqueeze: - return MakeSqueeze(ir_value, view_info.squeeze_index); - case ViewInfo::Type::kUnsqueeze: - return MakeUnsqueeze(ir_value, view_info.squeeze_index); - case ViewInfo::Type::kAsStrided: - return MakeAsStrided( - ir_value, - view_info.shape.sizes().vec(), - view_info.as_strided->stride, - view_info.as_strided->offset); - case ViewInfo::Type::kDiagonal: - return MakeDiagonal( - ir_value, - view_info.diagonal->offset, - view_info.diagonal->dim1, - view_info.diagonal->dim2); - default: - TORCH_INTERNAL_ASSERT( - false, "Invalid view type: ", GetEnumValue(view_info.view_type)); - } -} - -// Here we are trying to populate inplace updated values from the latest view -// all the way back to the original tensor. -// For example: -// a = torch.diagonal(b) -// b.add_(1) # a should be updated as well. -// -// Ideally we should all have a *ViewUpdate IR which updates the original -// tensor/view withe current value. See DiagonalViewUpdate and corresponding -// LowerDiagonalViewUpdate in ts_node_lowering.cpp. There are some "edge cases" -// here simply because they can smartly reuse some other ops to undo themselves. -Value ApplyUpdate(Value ir_value, const Alias::UpdateData& update_data) { - // We first bring the source IR value forward, by reshaping and slicing. - std::vector tmp_values({ir_value}); - for (const ViewInfo& view_info : update_data.view_infos) { - tmp_values.push_back(ApplyViewInfo(tmp_values.back(), view_info)); - } - // We then move backward given the source update value, by reshaping and - // slice-updating. - Value result = update_data.ir_value; - for (size_t i = update_data.view_infos.size(); i > 0; --i) { - const ViewInfo& view_info = update_data.view_infos[i - 1]; - switch (view_info.view_type) { - case ViewInfo::Type::kSelect: - result = MakeSelectViewUpdate( - tmp_values[i - 1], - result, - view_info.select->dim, - view_info.select->start, - view_info.select->end, - view_info.select->stride); - break; - case ViewInfo::Type::kNarrow: - result = - MakeNarrowViewUpdate(tmp_values[i - 1], result, view_info.indices); - break; - case ViewInfo::Type::kNoOp: - break; - case ViewInfo::Type::kPermute: - result = MakePermute(result, InversePermutation(view_info.permutation)); - break; - case ViewInfo::Type::kReshape: - result = MakeView(result, view_info.source_shape.sizes().vec()); - break; - case ViewInfo::Type::kResize: - result = MakeResize(result, view_info.source_shape.sizes().vec()); - break; - case ViewInfo::Type::kSqueeze: - result = MakeUnsqueeze(ir_value, view_info.squeeze_index); - break; - case ViewInfo::Type::kUnsqueeze: - result = MakeSqueeze(ir_value, view_info.squeeze_index); - break; - case ViewInfo::Type::kAsStrided: - result = MakeAsStridedViewUpdate( - tmp_values[i - 1], - result, - view_info.source_shape.sizes().vec(), - view_info.as_strided->stride, - view_info.as_strided->offset); - break; - case ViewInfo::Type::kDiagonal: - result = MakeDiagonalViewUpdate( - tmp_values[i - 1], - result, - view_info.diagonal->offset, - view_info.diagonal->dim1, - view_info.diagonal->dim2); - break; - default: - TORCH_INTERNAL_ASSERT( - false, "Invalid view type: ", GetEnumValue(view_info.view_type)); - } - } - return result; -} - -} // namespace - -ViewInfo::ViewInfo(Type view_type, Shape shape, Shape source_shape) - : view_type(view_type), - shape(std::move(shape)), - indices(source_shape.dim(), 0), - source_shape(std::move(source_shape)) {} - -ViewInfo::ViewInfo(Type view_type, Shape shape, Shape source_shape, int64_t sqi) - : view_type(view_type), - shape(std::move(shape)), - source_shape(std::move(source_shape)), - squeeze_index(sqi) { - TORCH_CHECK(view_type == Type::kSqueeze); -} - -ViewInfo::ViewInfo( - Type view_type, - Shape source_shape, - std::vector permutation) - : view_type(view_type), - shape(MakePermuteShape(source_shape, permutation)), - source_shape(std::move(source_shape)), - permutation(std::move(permutation)) { - TORCH_CHECK(view_type == Type::kPermute); -} - -ViewInfo::ViewInfo(Type view_type, const Shape& source_shape, SelectInfo select) - : view_type(view_type), - shape(MakeSelectShape( - source_shape, - select.dim, - select.start, - select.end, - select.stride)), - source_shape(source_shape), - select(select) { - TORCH_CHECK(view_type == Type::kSelect); -} - -ViewInfo::ViewInfo( - Type view_type, - Shape shape, - Shape source_shape, - AsStridedInfo as_strided) - : view_type(view_type), - shape(std::move(shape)), - source_shape(std::move(source_shape)), - as_strided(std::move(as_strided)) { - TORCH_CHECK(view_type == Type::kAsStrided); -} - -ViewInfo::ViewInfo( - Type view_type, - const Shape& source_shape, - DiagonalInfo diagonal) - : view_type(view_type), - shape(MakeDiagonalShape( - source_shape, - diagonal.offset, - diagonal.dim1, - diagonal.dim2)), - source_shape(source_shape), - diagonal(diagonal) { - TORCH_CHECK(view_type == Type::kDiagonal); -} - -void Alias::Update(Value ir_value, std::vector view_infos) { - if (!updates_.empty() && updates_.back().view_infos == view_infos) { - updates_.back().ir_value = std::move(ir_value); - } else { - updates_.push_back({std::move(ir_value), std::move(view_infos)}); - } - ++generation_; -} - -Value Alias::SyncUpdateOperations() { - for (auto& update_data : updates_) { - root_ir_value_ = ApplyUpdate(root_ir_value_, update_data); - } - updates_.clear(); - return root_ir_value_; -} - -LazyView::LazyView( - Shape shape, - std::shared_ptr alias, - ViewInfo view_info) - : shape_(std::move(shape)), alias_(std::move(alias)) { - view_infos_.push_back(std::move(view_info)); -} - -LazyView::LazyView( - Shape shape, - std::shared_ptr alias, - std::vector view_infos) - : view_infos_(std::move(view_infos)), - shape_(std::move(shape)), - alias_(std::move(alias)) {} - -void LazyView::Update(Value ir_value) { - alias_->Update(std::move(ir_value), view_infos_); -} - -std::shared_ptr LazyView::CreateSubView( - Shape shape, - ViewInfo view_info) { - std::vector view_infos(view_infos_); - view_infos.push_back(std::move(view_info)); - return std::make_shared( - std::move(shape), alias_, std::move(view_infos)); -} - -std::tuple LazyView::GetViewIrNode() { - if (IsUpToDate()) { - return std::make_tuple(ir_value_, false); - } - Value update = alias_->SyncUpdateOperations(); - for (auto& view_info : view_infos_) { - update = ApplyViewInfo(update, view_info); - } - ir_value_ = update; - generation_ = alias_->generation(); - return std::make_tuple(ir_value_, true); -} - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/core/lazy_view.h b/torch/csrc/lazy/core/lazy_view.h deleted file mode 100644 index 5e1a106494cfb..0000000000000 --- a/torch/csrc/lazy/core/lazy_view.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include - -namespace torch { -namespace lazy { - -struct TORCH_API SelectInfo { - bool operator==(const SelectInfo& ref) const { - return dim == ref.dim && start == ref.start && end == ref.end && - stride == ref.stride; - } - - int64_t dim = 0; - int64_t start = 0; - int64_t end = 0; - int64_t stride = 0; -}; - -struct TORCH_API AsStridedInfo { - bool operator==(const AsStridedInfo& ref) const { - return offset == ref.offset && stride == ref.stride; - } - - std::vector stride; - int64_t offset = 0; -}; - -struct TORCH_API DiagonalInfo { - bool operator==(const DiagonalInfo& ref) const { - return offset == ref.offset && dim1 == ref.dim1 && dim2 == ref.dim2; - } - - int64_t offset = 0; - int64_t dim1 = 0; - int64_t dim2 = 1; -}; - -struct TORCH_API ViewInfo { - enum class Type { - kInvalid, - kNarrow, - kNoOp, - kPermute, - kReshape, - kResize, - kSelect, - kAsStrided, - kDiagonal, - kSqueeze, - kUnsqueeze, - }; - - ViewInfo() = default; - ViewInfo(Type view_type, Shape shape, Shape source_shape); - ViewInfo(Type view_type, Shape shape, Shape source_shape, int64_t sqi); - ViewInfo( - Type view_type, - Shape source_shape, - std::vector permutation); - ViewInfo(Type view_type, const Shape& source_shape, SelectInfo select); - ViewInfo( - Type view_type, - Shape shape, - Shape source_shape, - AsStridedInfo as_strided); - ViewInfo(Type view_type, const Shape& source_shape, DiagonalInfo diagonal); - - bool operator==(const ViewInfo& ref) const { - return view_type == ref.view_type && shape == ref.shape && - indices == ref.indices && source_shape == ref.source_shape && - permutation == ref.permutation && select == ref.select && - as_strided == ref.as_strided && diagonal == ref.diagonal; - } - - Type view_type = Type::kInvalid; - // The shape of the result of a view. In case of narrowing, this represents - // the size of the narrow slice. - Shape shape; - // In case of narrowing, the starting indices from where the narrow slice is - // cut. - std::vector indices; - // The shape of the source of this view. - Shape source_shape; - // The permutation to be used. If empty, this is not a permute operation. - std::vector permutation; - // Information used for sliced views. - c10::optional select; - // Information used for as_strided views. - c10::optional as_strided; - // Information used for diagonal views. - c10::optional diagonal; - // Squeeze/Unsqueeze Index - int64_t squeeze_index; -}; - -// When a "view" (capture by reference) is taken on a node, an Alias object is -// created on the captured node itself, with its current IR Node value. -class TORCH_API Alias { - public: - struct UpdateData { - Value ir_value; - std::vector view_infos; - }; - - explicit Alias(Value ir_value) : root_ir_value_(std::move(ir_value)) {} - - size_t generation() const { - return generation_; - } - - // Appends an update to the IR value stored within the alias. The ir_value is - // the value to be written, and view_infos represents the forward path from - // the alias's ir_value to the update ir_value. - void Update(Value ir_value, std::vector view_infos); - - Value SyncUpdateOperations(); - - private: - // The IR value which is the root at which the view was created. - Value root_ir_value_; - // The stacked updates on the view. Orders matter, as most recent updates - // might overwrite older ones. - std::vector updates_; - // Incremented every time an update happens. Used by view to track alias - // changes and regenerate the most current value. - size_t generation_ = 0; -}; - -class TORCH_API LazyView { - public: - LazyView(Shape shape, std::shared_ptr alias, ViewInfo view_info); - LazyView( - Shape shape, - std::shared_ptr alias, - std::vector view_infos); - - void Update(Value ir_value); - - const Shape& shape() const { - return shape_; - } - - const std::shared_ptr& alias() const { - return alias_; - } - - std::shared_ptr CreateSubView(Shape shape, ViewInfo view_info); - - // Extracts the current IrNode out of a view, into a IrNode structure - // where the updated fields tells whether a new IR value has been created, or - // the cached one returned. - std::tuple GetViewIrNode(); - - bool IsUpToDate() const { - return ir_value_ && generation_ == alias_->generation(); - } - - private: - std::vector view_infos_; - Shape shape_; - std::shared_ptr alias_; - Value ir_value_; - size_t generation_ = 0; -}; - -} // namespace lazy -} // namespace torch diff --git a/torch/csrc/lazy/core/metrics.cpp b/torch/csrc/lazy/core/metrics.cpp index cb8120c1d45c9..7f12793a66001 100644 --- a/torch/csrc/lazy/core/metrics.cpp +++ b/torch/csrc/lazy/core/metrics.cpp @@ -106,7 +106,7 @@ MetricsArena* MetricsArena::Get() { return arena; } -void MetricsArena::Reset() { +void MetricsArena::ResetCounters() { for (auto& pair : counters_) { if (pair.second) { pair.second->Reset(); @@ -114,6 +114,14 @@ void MetricsArena::Reset() { } } +void MetricsArena::ResetMetrics() { + for (auto& pair : metrics_) { + if (pair.second) { + pair.second->Reset(); + } + } +} + void MetricsArena::RegisterMetric( const std::string& name, MetricReprFn repr_fn, @@ -141,6 +149,9 @@ void MetricsArena::ForEachMetric( const std::function& metric_func) { std::lock_guard lock(lock_); for (auto& name_data : metrics_) { + if (!name_data.second->IsValid()) { + continue; + } metric_func(name_data.first, name_data.second.get()); } } @@ -149,38 +160,44 @@ void MetricsArena::ForEachCounter( const std::function& counter_func) { std::lock_guard lock(lock_); for (auto& name_data : counters_) { + if (!name_data.second->IsValid()) + continue; counter_func(name_data.first, name_data.second.get()); } } std::vector MetricsArena::GetMetricNames() { std::vector names; - std::lock_guard lock(lock_); - for (auto& name_data : metrics_) { - names.push_back(name_data.first); - } + ForEachMetric([&names](const std::string& name, MetricData* data) { + names.push_back(name); + }); return names; } MetricData* MetricsArena::GetMetric(const std::string& name) { std::lock_guard lock(lock_); auto it = metrics_.find(name); - return it != metrics_.end() ? it->second.get() : nullptr; + if (it == metrics_.end()) { + return nullptr; + } + return it->second->IsValid() ? it->second.get() : nullptr; } std::vector MetricsArena::GetCounterNames() { std::vector names; - std::lock_guard lock(lock_); - for (auto& name_data : counters_) { - names.push_back(name_data.first); - } + ForEachCounter([&names](const std::string& name, CounterData* data) { + names.push_back(name); + }); return names; } CounterData* MetricsArena::GetCounter(const std::string& name) { std::lock_guard lock(lock_); auto it = counters_.find(name); - return it != counters_.end() ? it->second.get() : nullptr; + if (it == counters_.end()) { + return nullptr; + } + return it->second->IsValid() ? it->second.get() : nullptr; } MetricData::MetricData(MetricReprFn repr_fn, size_t max_samples) @@ -226,6 +243,14 @@ std::vector MetricData::Samples( return samples; } +void MetricData::Reset() { + std::lock_guard lock(lock_); + count_ = 0; + // Don't clear. samples_ are init with placeholders. + samples_ = std::vector(samples_.size()); + accumulator_ = 0.0; +} + Metric::Metric(std::string name, MetricReprFn repr_fn, size_t max_samples) : name_(std::move(name)), repr_fn_(std::move(repr_fn)), @@ -353,6 +378,39 @@ std::string CreateMetricReport() { return ss.str(); } +std::string CreateMetricReport( + const std::vector& counter_names, + const std::vector& metric_names) { + MetricsArena* arena = MetricsArena::Get(); + std::stringstream ss; + std::set metric_name_set( + metric_names.begin(), metric_names.end()); + arena->ForEachMetric( + [&ss, &metric_name_set](const std::string& name, MetricData* data) { + if (metric_name_set.find(name) != metric_name_set.end()) { + EmitMetricInfo(name, data, &ss); + } + }); + std::set counter_name_set( + counter_names.begin(), counter_names.end()); + arena->ForEachCounter( + [&ss, &counter_name_set](const std::string& name, CounterData* data) { + if (counter_name_set.find(name) != counter_name_set.end()) { + EmitCounterInfo(name, data, &ss); + } + }); + + static std::string fall_back_counter_prefix = "aten::"; + arena->ForEachCounter([&ss](const std::string& name, CounterData* data) { + if (name.rfind(fall_back_counter_prefix, 0) == 0) { + // it might emit duplicated counter if user also specified exact aten + // counter in the `counter_names` but it should be very rare. + EmitCounterInfo(name, data, &ss); + } + }); + return ss.str(); +} + std::vector GetMetricNames() { return MetricsArena::Get()->GetMetricNames(); } diff --git a/torch/csrc/lazy/core/metrics.h b/torch/csrc/lazy/core/metrics.h index 43fb617c1ba16..40bc606326eea 100644 --- a/torch/csrc/lazy/core/metrics.h +++ b/torch/csrc/lazy/core/metrics.h @@ -55,6 +55,12 @@ class TORCH_API MetricData { return repr_fn_(value); } + void Reset(); + + bool IsValid() const { + return TotalSamples() > 0; + } + private: mutable std::mutex lock_; MetricReprFn repr_fn_; @@ -81,6 +87,10 @@ class TORCH_API CounterData { value_ = 0; } + bool IsValid() const { + return value_ > 0; + } + private: std::atomic value_; }; @@ -89,7 +99,8 @@ class TORCH_API MetricsArena { public: static MetricsArena* Get(); - void Reset(); + void ResetCounters(); + void ResetMetrics(); // Registers a new metric in the global arena. void RegisterMetric( @@ -216,6 +227,11 @@ class TORCH_API Counter { // Creates a report with the current metrics statistics. TORCH_API std::string CreateMetricReport(); +// Creates a report with the selected metrics statistics. +TORCH_API std::string CreateMetricReport( + const std::vector& counter_names, + const std::vector& metric_names); + // Returns the currently registered metric names. Note that the list can grow // since metrics are usually function intialized (they are static function // variables). diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index bcc73a3ed79fd..df82fd45fe29b 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -451,11 +451,11 @@ std::vector compute_shape_expand( std::vector target_size(_sizes.size()); for (const auto idx : c10::irange(_sizes.size())) { if (_sizes[idx].is_symbolic()) { - c10::SymIntNode symbolicIntNode = _sizes[idx].toSymIntNodeImpl(); - auto* lazySymIntNode = - dynamic_cast(symbolicIntNode.get()); - TORCH_INTERNAL_ASSERT(lazySymIntNode); - auto size_node = lazySymIntNode->node_; + c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl(); + auto* lazySymNode = + dynamic_cast(symbolicIntNode.get()); + TORCH_INTERNAL_ASSERT(lazySymNode); + auto size_node = lazySymNode->node_; auto static_value = std::dynamic_pointer_cast(size_node) ->getStaticValue(); diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index a1b51495fb3fd..9ceb45d6b23d9 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/torch/csrc/lazy/core/tensor.cpp b/torch/csrc/lazy/core/tensor.cpp index bf673a72361d3..a7890fc3e0635 100644 --- a/torch/csrc/lazy/core/tensor.cpp +++ b/torch/csrc/lazy/core/tensor.cpp @@ -36,30 +36,21 @@ LazyTensorPtr LazyTensor::Create( TORCH_CHECK(tensor.device().type() != at::kLazy); LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(tensor, device)); - LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); + LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data()); return lazy_tensor; } LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) { LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(std::move(ir_value), device)); - LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); - return lazy_tensor; -} - -LazyTensorPtr LazyTensor::Create( - std::shared_ptr view, - const BackendDevice& device) { - LazyTensorPtr lazy_tensor = - c10::make_intrusive(LazyTensor(std::move(view), device)); - LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); + LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data()); return lazy_tensor; } LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) { LazyTensorPtr lazy_tensor = c10::make_intrusive(LazyTensor(std::move(handle))); - LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data_ptr()); + LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data()); return lazy_tensor; } @@ -78,21 +69,11 @@ LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device) TryLimitGraphSize(); } -LazyTensor::LazyTensor( - std::shared_ptr view, - const BackendDevice& device) - : LazyTensor(std::make_shared(std::move(view), device)) {} +LazyTensor::LazyTensor(std::shared_ptr data) : data_(std::move(data)) {} -LazyTensor::LazyTensor(std::shared_ptr data) - : data_(std::move(data)), - storage_(c10::Storage( - {}, - 0, - c10::DataPtr(nullptr, backendDeviceToAtenDevice(data_->device)))) {} - -LazyTensor::Data* LazyTensor::data() const { +auto LazyTensor::data() const -> const std::shared_ptr& { TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor"); - return data_.get(); + return data_; } int64_t LazyTensor::size(int64_t dim) const { @@ -107,9 +88,6 @@ at::ScalarType LazyTensor::dtype() const { } MaybeRef LazyTensor::shape() const { - if (data()->view != nullptr) { - return data()->view->shape(); - } if (data()->handle != nullptr) { return Shape(data()->handle->shape()); } @@ -131,45 +109,23 @@ int64_t LazyTensor::GetUniqueId() const { return data()->unique_id; } -std::ptrdiff_t LazyTensor::GetViewAliasId() const { - return data()->view != nullptr - ? reinterpret_cast(data()->view->alias().get()) - : 0; -} - BackendDataPtr LazyTensor::GetDataHandle() { - // Data can coexist with a view, but we need to check that the view did - // not receive any updates before calling the current IR valid. - bool up_to_date = true; - Value ir_value; - if (data()->view != nullptr) { - bool updated = false; - std::tie(ir_value, updated) = GetViewUpdate(data()->view); - up_to_date = !updated; - } - if (up_to_date) { - BackendDataPtr handle = CurrentDataHandle(); - if (handle != nullptr) { - TORCH_CHECK( - handle->HasValue(), - "Trying to access data while an async operation is in flight: ", - handle->shape().to_string()); - return handle; - } - } - if (ir_value) { - // The view gave us an updated IR value. We usually do not have a valid IR - // value field together with a view, but to allow code reuse in - // ApplyPendingGraph() we temporarily set it here. The following call to - // ApplyPendingGraph() will clear it. - AssignIrValue(std::move(ir_value)); + BackendDataPtr handle = CurrentDataHandle(); + if (handle != nullptr) { + TORCH_CHECK( + handle->HasValue(), + "Trying to access data while an async operation is in flight: ", + handle->shape().to_string()); + return handle; } + if (data()->ir_value) { ApplyPendingGraph(); } else { TORCH_CHECK(data()->tensor_data); data()->handle = TensorToDataHandle(*data()->tensor_data, GetDevice()); } + return data()->handle; } @@ -184,10 +140,9 @@ void LazyTensor::SetDataHandle(BackendDataPtr handle) { void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) { data()->handle = std::move(handle); // Assigning a device data should always clear the IR node, to allow graph - // trimming. A view cannot be reset though, unless we are at a step-end sync. + // trimming. AssignIrValue(Value()); if (sync) { - data()->view = nullptr; data()->tensor_data = c10::nullopt; } } @@ -195,16 +150,8 @@ void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) { void LazyTensor::SetIrValue(Value ir_value) { data()->handle = nullptr; data()->tensor_data = c10::nullopt; - if (data()->view != nullptr) { - // If we have an active view, and a SetIrValue() happens, it means we are - // within an in-place execution context, and we need to update the view's - // alias as well. - data()->view = UpdateView(data()->view, std::move(ir_value)); - data()->generation += 1; - } else { - AssignIrValue(std::move(ir_value)); - TryLimitGraphSize(); - } + AssignIrValue(std::move(ir_value)); + TryLimitGraphSize(); } void LazyTensor::SetInPlaceIrValue(Value ir_value) { @@ -257,9 +204,6 @@ Value LazyTensor::GetIrValue() const { } Value LazyTensor::CurrentIrValue() const { - if (data()->view != nullptr) { - return std::get<0>(GetViewUpdate(data()->view)); - } return data()->ir_value; } @@ -268,9 +212,6 @@ void LazyTensor::SetTensorData(at::Tensor tensor_data) { } c10::optional LazyTensor::CurrentTensorData() const { - if (data()->view != nullptr && !data()->view->IsUpToDate()) { - return c10::nullopt; - } return data()->tensor_data; } @@ -293,71 +234,6 @@ Value LazyTensor::GetIrValueForTensor( return CreateTensorNode(std::move(data), read_only); } -std::tuple LazyTensor::GetViewUpdate( - const std::shared_ptr& view) const { - auto value_with_update = view->GetViewIrNode(); - if (std::get<1>(value_with_update)) { - data()->handle = nullptr; - data()->tensor_data = c10::nullopt; - } - return value_with_update; -} - -std::shared_ptr LazyTensor::UpdateView( - std::shared_ptr view, - Value ir_value) const { - if (ir_value.shape().sizes() != view->shape().sizes()) { - TORCH_CHECK(ir_value.shape().numel() == view->shape().numel()); - - ViewInfo view_info( - ViewInfo::Type::kReshape, ir_value.shape(), view->shape()); - view = view->CreateSubView(view_info.shape, view_info); - } - view->Update(std::move(ir_value)); - return view; -} - -void LazyTensor::SetSubView(ViewInfo view_info) const { - data()->view = data()->view->CreateSubView(view_info.shape, view_info); - data()->generation += 1; -} - -void LazyTensor::ModifyCurrentView(ViewInfo view_info) const { - if (data()->view != nullptr) { - SetSubView(view_info); - return; - } - // This node is not a view. Since this function is meant to modify a view - // in place, we need to turn this existing tensor into a view. - Value ir_value = GetIrValue(); - std::shared_ptr alias = std::make_shared(ir_value); - data()->view = std::make_shared(view_info.shape, alias, view_info); - AssignIrValue(Value()); -} - -std::shared_ptr LazyTensor::CreateView(ViewInfo view_info) const { - if (data()->view != nullptr) { - return data()->view->CreateSubView(view_info.shape, view_info); - } - // This node is not a view, and creating a view forks the current node into - // becoming one itself. This means creating an alias with the current IR - // Node, and using the same alias for the created IR Node. - Value ir_value = GetIrValue(); - std::shared_ptr alias = std::make_shared(ir_value); - ViewInfo this_view_info( - ViewInfo::Type::kNoOp, ir_value.shape(), ir_value.shape()); - data()->view = std::make_shared( - ir_value.shape(), alias, std::move(this_view_info)); - AssignIrValue(Value()); - return std::make_shared(view_info.shape, alias, view_info); -} - -LazyTensorPtr LazyTensor::CreateViewTensor(ViewInfo view_info) const { - auto new_tensor = Create(CreateView(std::move(view_info)), GetDevice()); - new_tensor->storage_ = Storage(); - return new_tensor; -} - at::Tensor LazyTensor::ToTensor(bool detached) { at::Tensor tensor; c10::optional tensor_data = CurrentTensorData(); @@ -374,8 +250,7 @@ at::Tensor LazyTensor::ToTensor(bool detached) { } else { tensor = *tensor_data; if (detached) { - if (data()->ir_value || data()->handle != nullptr || - data()->view != nullptr) { + if (data()->ir_value || data()->handle != nullptr) { // If we have other authoritive sources, just drop our reference and // transfer it to the caller. data()->tensor_data = c10::nullopt; @@ -395,7 +270,6 @@ void LazyTensor::ShallowCopyTo(LazyTensorPtr dest) const { void LazyTensor::SetTensor(at::Tensor tensor) { SetTensorData(tensor); - data()->view = nullptr; data()->handle = nullptr; AssignIrValue(Value()); } @@ -408,25 +282,14 @@ void LazyTensor::UpdateFromTensor(at::Tensor tensor, bool sync) { SetTensorData(tensor); data()->handle = nullptr; AssignIrValue(Value()); - if (data()->view != nullptr) { - Value ir_value = GetIrValueForTensor(tensor, GetDevice()); - data()->view = UpdateView(data()->view, std::move(ir_value)); - } } } void LazyTensor::UpdateFromTensorOut(at::Tensor tensor) { - if (data()->view != nullptr && shape().Get().numel() != tensor.numel()) { - data()->view = nullptr; - } UpdateFromTensor(std::move(tensor), /*sync=*/false); } void LazyTensor::UpdateFromTensorOut(const LazyTensorPtr& tensor) { - if (data()->view != nullptr && - shape().Get().numel() != tensor->shape().Get().numel()) { - data()->view = nullptr; - } SetIrValue(tensor->GetIrValue()); } diff --git a/torch/csrc/lazy/core/tensor.h b/torch/csrc/lazy/core/tensor.h index 12cfdd2827d74..5c1bee431c180 100644 --- a/torch/csrc/lazy/core/tensor.h +++ b/torch/csrc/lazy/core/tensor.h @@ -1,22 +1,18 @@ #pragma once -#include +#include #include #include #include #include -#include #include namespace torch { namespace lazy { -class TORCH_API SymIntNodeImpl : public c10::SymIntNodeImpl { +class TORCH_API SymNodeImpl : public c10::SymNodeImpl { public: - SymIntNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; - c10::SymIntNode add(const c10::SymIntNode& other) override { - TORCH_CHECK(false, "NYI"); - } + SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)){}; NodePtr node_; }; @@ -37,20 +33,19 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { : ir_value(std::move(ir_value)), device(std::move(device)), unique_id(GetNextTensorId()) {} - Data(std::shared_ptr view, BackendDevice device) - : view(std::move(view)), - device(std::move(device)), - unique_id(GetNextTensorId()) {} Data(at::Tensor tensor_data, BackendDevice device) : tensor_data(std::move(tensor_data)), device(std::move(device)), unique_id(GetNextTensorId()) {} + // TODO(alanwaketan): Remove this ctor. This is a + // temporary ctor to ease XLA LTC migration. + Data(BackendDevice device) + : device(std::move(device)), unique_id(GetNextTensorId()) {} ~Data(); BackendDataPtr handle; Value ir_value; - std::shared_ptr view; c10::optional tensor_data; const BackendDevice device; const int64_t unique_id = 0; @@ -70,16 +65,14 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { // have to check both lazy_tensor_ptr && *lazy_tensor_ptr, so everywhere that // used to rely on a LazyTensor obj with a null Data can now rely on a null // LazyTensorPtr instead. - LazyTensor() = delete; + // TODO(alanwaketan): This is a temporarily change to make XLA LTC migration + // easier. Restore it back to delete. + LazyTensor() = default; size_t generation() const { return data()->generation; } - LazyTensorPtr alias() const { - return c10::make_intrusive(LazyTensor(data_ptr())); - } - int64_t size(int64_t dim) const; at::Tensor ToTensor(bool detached); @@ -93,7 +86,7 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { void UpdateFromTensorOut(at::Tensor tensor); void UpdateFromTensorOut(const LazyTensorPtr& tensor); - Data* data() const; + const std::shared_ptr& data() const; at::ScalarType dtype() const; @@ -102,10 +95,6 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { const BackendDevice& GetDevice() const; int64_t GetUniqueId() const; - // Retrieves an opaque ID of the alias object upon which the tensor's view is - // rooted, or 0 if this tensor is not a view. - std::ptrdiff_t GetViewAliasId() const; - // Fetches the data behind the tensor. If the tensor has a graph defining // its current value, executes the graph and fetches the data result. BackendDataPtr GetDataHandle(); @@ -129,59 +118,29 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { void SetIrValue(Value ir_value); void SetInPlaceIrValue(Value ir_value); - void SetSubView(ViewInfo view_info) const; - c10::optional CurrentTensorData() const; std::vector MakeOutputTensors(NodePtr node) const; - LazyTensorPtr CreateViewTensor(ViewInfo view_info) const; LazyTensorPtr CopyTensorToDevice(const BackendDevice& device); - void ModifyCurrentView(ViewInfo view_info) const; - // Applies the queue of operations in preparation for using the data. void ApplyPendingGraph(); - const c10::Storage& Storage() const { - return storage_; - } - // This is currently only used by outlier view ops such as expand that - // don't go through CreateViewTensor to support Tensor.is_alias_of. - void SetStorage(const c10::Storage& storage) { - storage_ = storage; - } + void AssignIrValue(Value ir_value) const; + + protected: + explicit LazyTensor(std::shared_ptr data); + + void SetTensorData(at::Tensor tensor_data); private: LazyTensor(const at::Tensor& tensor, const BackendDevice& device); LazyTensor(Value ir_value, const BackendDevice& device); - LazyTensor(std::shared_ptr view, const BackendDevice& device); explicit LazyTensor(BackendDataPtr handle); - explicit LazyTensor(std::shared_ptr data); - - static LazyTensorPtr Create( - std::shared_ptr view, - const BackendDevice& device); - - std::shared_ptr data_ptr() const { - return data_; - } - - void AssignIrValue(Value ir_value) const; - - void SetTensorData(at::Tensor tensor_data); Value CreateTensorNode(BackendDataPtr data, bool read_only) const; - std::tuple GetViewUpdate( - const std::shared_ptr& view) const; - - std::shared_ptr UpdateView( - std::shared_ptr view, - Value ir_value) const; - - std::shared_ptr CreateView(ViewInfo view_info) const; - // We build a graph accumulating operations, but at a given point we // need to force a rendering, otherwise the graph can grow without control. // Think: @@ -196,12 +155,6 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target { static int64_t GetNextTensorId(); std::shared_ptr data_; - // Temporarily used to suport Tensor.is_alias_of(). - // This is a fake storage that doesn't store anything. - // Instead it serves as a marker to mark LazyTensors that - // points to the same storage, and thus alias of each other. - // FIXME(alanwaketan): Remove this once we have functionalization (bdhirsh). - c10::Storage storage_; }; // Utils to convert at::Tensor to LazyTensor, and vice versa. diff --git a/torch/csrc/lazy/core/tensor_impl.h b/torch/csrc/lazy/core/tensor_impl.h index de1191a3de3e2..710230605cc1f 100644 --- a/torch/csrc/lazy/core/tensor_impl.h +++ b/torch/csrc/lazy/core/tensor_impl.h @@ -49,15 +49,6 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymIntArrayRef sym_strides_custom() const override; -#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY - const at::Storage& storage() const override { - return tensor_->Storage(); - } - bool has_storage() const override { - return tensor_->Storage(); - } -#endif // C10_DISABLE_TENSORIMPL_EXTENSIBILITY - private: void setup_size_properties(); diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 2d421a3eb2ae7..fe74d29d87ac1 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -42,9 +42,9 @@ std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { std::string GetTensorsDump( const std::vector& tensors, - const std::function)>& + const std::function)>& coverter) { - std::vector nodes; + std::vector nodes; std::vector values; for (auto& tensor : tensors) { auto inner = at::functionalization::impl::from_functional_tensor(tensor); @@ -126,8 +126,10 @@ void initLazyBindings(PyObject* module) { torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({}); }, py::arg("devices")); - lazy.def( - "_reset_metrics", []() { torch::lazy::MetricsArena::Get()->Reset(); }); + lazy.def("_reset_metrics", []() { + torch::lazy::MetricsArena::Get()->ResetCounters(); + torch::lazy::MetricsArena::Get()->ResetMetrics(); + }); lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); }); lazy.def( "_metrics_report", []() { return torch::lazy::CreateMetricReport(); }); @@ -142,7 +144,7 @@ void initLazyBindings(PyObject* module) { lazy.def( "_get_tensors_text", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto coverter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToText(nodes); }; return GetTensorsDump(tensors, coverter); @@ -150,7 +152,7 @@ void initLazyBindings(PyObject* module) { lazy.def( "_get_tensors_dot", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto coverter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToDot(nodes); }; return GetTensorsDump(tensors, coverter); @@ -222,7 +224,7 @@ void initLazyBindings(PyObject* module) { [](const std::vector& tensors) -> std::pair, std::vector> { #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - std::vector roots; + std::vector roots; for (auto& tensor : tensors) { auto xtensor = TryGetLtcTensor(tensor); roots.push_back(xtensor->GetIrValue().node.get()); @@ -305,6 +307,19 @@ void initLazyBindings(PyObject* module) { #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) return result; }); + + // GetPythonFramesFunction() has not ever worked with torchdeploy/multipy + // possibly becuase GetPythonFrames resolves to external cpython rather + // than embedded cpython. So far this problem has only been observed + // internally, so we will just block it off there. + +#if !(defined(USE_DEPLOY)) + + // When libtorch_python is loaded, we register the python frame getter + // otherwise, debug util simply omits python frames + GetPythonFramesFunction() = GetPythonFrames; + +#endif // USE_DEPLOY } } // namespace lazy diff --git a/torch/csrc/lazy/ts_backend/ir_builder.h b/torch/csrc/lazy/ts_backend/ir_builder.h index 600243b67f622..1f32a3521ba8a 100644 --- a/torch/csrc/lazy/ts_backend/ir_builder.h +++ b/torch/csrc/lazy/ts_backend/ir_builder.h @@ -30,10 +30,6 @@ struct TorchScriptIrBuilder : IrBuilder { const bool& is_scalar_expand) const override { return ReuseOrMakeNode(input0, size, is_scalar_expand); } - NodePtr MakeView(const Value& input0, const std::vector& output_size) - const override { - return ReuseOrMakeNode(input0, output_size); - } NodePtr MakeCast( const Value& input0, const at::ScalarType& dtype, @@ -55,84 +51,6 @@ struct TorchScriptIrBuilder : IrBuilder { return MakeNode(op, operands, shape, num_outputs, hash_seed); } - // View op nodes - NodePtr MakeAsStridedViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) const override { - return ReuseOrMakeNode( - input0, input1, size, stride, storage_offset); - } - NodePtr MakeAsStrided( - const Value& input0, - const std::vector& size, - const std::vector& stride, - const int64_t& storage_offset) const override { - return ReuseOrMakeNode(input0, size, stride, storage_offset); - } - NodePtr MakeDiagonalViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) const override { - return ReuseOrMakeNode( - input0, input1, offset, dim1, dim2); - } - NodePtr MakeDiagonal( - const Value& input0, - const int64_t& offset, - const int64_t& dim1, - const int64_t& dim2) const override { - return ReuseOrMakeNode(input0, offset, dim1, dim2); - } - NodePtr MakeNarrowViewUpdate( - const Value& input0, - const Value& input1, - const std::vector& base_indices) const override { - return ReuseOrMakeNode(input0, input1, base_indices); - } - NodePtr MakeNarrow( - const Value& input0, - const std::vector& base_indices, - const std::vector& sizes) const override { - return ReuseOrMakeNode(input0, base_indices, sizes); - } - NodePtr MakePermute(const Value& input0, const std::vector& dims) - const override { - return ReuseOrMakeNode(input0, dims); - } - NodePtr MakeResize(const Value& input0, const std::vector& size) - const override { - return ReuseOrMakeNode(input0, size); - } - NodePtr MakeSelectViewUpdate( - const Value& input0, - const Value& input1, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) const override { - return ReuseOrMakeNode( - input0, input1, dim, start, end, stride); - } - NodePtr MakeSelect( - const Value& input0, - const int64_t& dim, - const int64_t& start, - const int64_t& end, - const int64_t& stride) const override { - return ReuseOrMakeNode